diff --git a/cmd/queenbee/main.go b/cmd/queenbee/main.go index 19146a9..d38d654 100644 --- a/cmd/queenbee/main.go +++ b/cmd/queenbee/main.go @@ -5,6 +5,8 @@ import ( "log" "code.rocketnine.space/tslocum/beehive" + + _ "github.com/lib/pq" ) func main() { @@ -16,6 +18,16 @@ func main() { parseConfig(configPath) + log.Println("Connecting to database...") + connStr := "postgres://queenbee:queenbee@localhost/queenbee?sslmode=disable&connect_timeout=5" + var err error + beehive.DB, err = beehive.NewDatabase("postgres", connStr) + if err != nil { + log.Fatal(err) + } + log.Println("Connected successfully") + + log.Println("Listening for worker connections") s, err := beehive.NewServer(config.Listen, config.Workers) if err != nil { log.Fatal(err) diff --git a/database.go b/database.go new file mode 100644 index 0000000..f00a8f6 --- /dev/null +++ b/database.go @@ -0,0 +1,122 @@ +package beehive + +import ( + "database/sql" + "log" +) + +const initSchema = ` +CREATE TABLE worker ( + id serial PRIMARY KEY +); + +CREATE TABLE deployment ( + id serial PRIMARY KEY, + created timestamptz, + modified timestamptz, + worker_id integer REFERENCES worker, + festoon text +); + +CREATE TABLE task ( + id serial PRIMARY KEY, + type integer, + created timestamptz, + started timestamptz, + completed timestamptz, + deployment_id integer REFERENCES deployment +); + +CREATE INDEX ON task (started); +CREATE INDEX ON task (completed); +CREATE INDEX ON task (deployment_id); + +CREATE FUNCTION task_notify() RETURNS trigger AS $$ +BEGIN + NOTIFY task; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER task_notify_trigger +AFTER INSERT ON task +EXECUTE FUNCTION task_notify(); +` + +var DB *Database + +type Database struct { + db *sql.DB +} + +func NewDatabase(driverName string, dataSource string) (*Database, error) { + db, err := sql.Open(driverName, dataSource) + if err != nil { + return nil, err + } + + // Test connection. + _, err = db.Exec("SET search_path TO queenbee") + if err != nil { + log.Fatal(err) + } + + // Initialize database. + var result int + err = db.QueryRow("SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'queenbee' AND table_name = 'deployment'").Scan(&result) + if err != nil { + log.Fatal(err) + } else if result == 0 { + _, err = db.Exec(initSchema) + if err != nil { + log.Fatalf("failed to initialize database: %s", err) + } + log.Println("Initialized database schema") + } + + d := &Database{ + db: db, + } + return d, nil +} + +func (d *Database) StartTransaction() error { + _, err := d.db.Exec("BEGIN") + return err +} + +func (d *Database) CancelTransaction() error { + _, err := d.db.Exec("ROLLBACK") + return err +} + +func (d *Database) CommitTransaction() error { + _, err := d.db.Exec("COMMIT") + return err +} + +func (d *Database) PendingTasks() ([]*Task, error) { + var tasks []*Task + rows, err := d.db.Query("SELECT id, created, started, completed, deployment_id from task where started = 0 ORDER BY id ASC") + if err != nil { + return nil, err + } + for rows.Next() { + task := &Task{} + err = rows.Scan(task.ID, task.Created, task.Started, task.Completed, task.DeploymentID) + if err != nil { + return nil, err + } + tasks = append(tasks, task) + } + return tasks, nil +} + +func (d *Database) AddTask(t *Task) error { + _, err := d.db.Exec("INSERT INTO task (created, started, completed, deployment_id) VALUES (null, ?, ?, ?, ?)", t.Created, t.Started, t.Completed, t.DeploymentID) + return err +} + +func (d *Database) UpdateTask(t *Task) error { + _, err := d.db.Exec("UPDATE task SET started=?, completed=?, deployment_id=? WHERE id=?", t.Started, t.Completed, t.DeploymentID, t.ID) + return err +} diff --git a/go.mod b/go.mod index fc4850e..3989415 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,9 @@ module code.rocketnine.space/tslocum/beehive go 1.19 -require sigs.k8s.io/yaml v1.3.0 +require ( + github.com/lib/pq v1.10.9 + sigs.k8s.io/yaml v1.3.0 +) require gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 417df28..7512d80 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/server.go b/server.go index 609269e..b3ec1fc 100644 --- a/server.go +++ b/server.go @@ -210,7 +210,7 @@ func (s *Server) handleConnection(c *Client) { log.Printf("Worker %d disconnected", c.Worker) } -func (s *Server) sendTask(workerID int, t *Task) bool { +func (s *Server) sendTask(workerID int, t *TaskMessage) bool { for i := range s.Clients { if s.Clients[i].Worker == workerID { taskJSON, err := json.Marshal(t) diff --git a/task.go b/task.go index 4858dee..9a50e50 100644 --- a/task.go +++ b/task.go @@ -1,5 +1,16 @@ package beehive +import "time" + +type Task struct { + ID int + Type TaskType + Created time.Time + Started time.Time + Completed time.Time + DeploymentID int +} + type TaskType int // Note: Task types must only be appended to preserve values. @@ -11,13 +22,13 @@ const ( TaskStop ) -type Task struct { +type TaskMessage struct { Type TaskType Parameters map[string]string } -func NewTask(t TaskType, parameters map[string]string) *Task { - return &Task{ +func NewTask(t TaskType, parameters map[string]string) *TaskMessage { + return &TaskMessage{ Type: t, Parameters: parameters, } diff --git a/worker.go b/worker.go index cfe49e8..58143ea 100644 --- a/worker.go +++ b/worker.go @@ -22,7 +22,7 @@ type Worker struct { Deployments []*Deployment - TaskQueue chan *Task + TaskQueue chan *TaskMessage requestPortsFunc func(d *Deployment) []int } @@ -33,7 +33,7 @@ func NewWorker(id int, ip string, festoonsDir string, deploymentsDir string) *Wo IP: ip, FestoonsDir: festoonsDir, DeploymentsDir: deploymentsDir, - TaskQueue: make(chan *Task), + TaskQueue: make(chan *TaskMessage), } go w.handleTaskQueue() @@ -69,7 +69,7 @@ func (w *Worker) handleTaskQueue() { } } -func (w *Worker) ExecuteTask(t *Task) error { +func (w *Worker) ExecuteTask(t *TaskMessage) error { return nil } @@ -88,7 +88,7 @@ func (w *Worker) HandleRead(c *Client) { continue } - task := &Task{} + task := &TaskMessage{} err := json.Unmarshal(scanner.Bytes(), task) if err != nil { log.Fatalf("failed to unmarshal %s: %s", scanner.Bytes(), err)