You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
414 lines
10 KiB
414 lines
10 KiB
package main |
|
|
|
import ( |
|
"encoding/base64" |
|
"fmt" |
|
"log" |
|
"strconv" |
|
"strings" |
|
"time" |
|
|
|
"github.com/gorilla/securecookie" |
|
"github.com/jmoiron/sqlx" |
|
_ "github.com/mattn/go-sqlite3" |
|
"github.com/pkg/errors" |
|
) |
|
|
|
// databaseVersion is the schema version this build records in the meta
// table; Migrate compares the stored value against it.
const databaseVersion = 1

// Sentinel errors returned unwrapped so callers can compare against them.
var (
	// ErrAccountExists is returned by AddAccount and SetUsername when the
	// requested username is already taken.
	ErrAccountExists = errors.New("account already exists")

	// ErrChannelExists is returned by AddChannel when the channel is
	// already stored.
	ErrChannelExists = errors.New("channel already exists")

	// ErrChannelDoesNotExist is declared for callers; nothing in this file
	// returns it.
	ErrChannelDoesNotExist = errors.New("channel does not exist")
)
|
|
|
// tables maps each table name to its column definitions. CreateTables joins
// each list with "," inside a CREATE TABLE IF NOT EXISTS statement.
var tables = map[string][]string{
	"accounts": {
		"`id` INTEGER PRIMARY KEY AUTOINCREMENT",
		"`username` TEXT NULL",
		"`password` TEXT NULL",
	},
	"bans": {
		"`channel` TEXT NULL",
		"`type` INTEGER NULL",
		"`target` TEXT NULL",
		"`expires` INTEGER NULL",
		"`reason` TEXT NULL",
	},
	"channels": {
		"`channel` TEXT PRIMARY KEY",
		"`topic` TEXT NULL",
		"`topictime` INTEGER NULL",
		"`password` TEXT NULL",
	},
	// meta is a key/value store; Migrate reads and writes the "version" key.
	"meta": {
		"`key` TEXT NULL PRIMARY KEY",
		"`value` TEXT NULL",
	},
	"permissions": {
		"`channel` TEXT NULL",
		"`account` INTEGER NULL",
		"`permission` INTEGER NULL",
	},
}
|
|
|
// Ban kinds stored in the `type` column of the bans table.
const (
	// BAN_TYPE_ADDRESS marks a ban whose target is an address hash (see BanAddr).
	BAN_TYPE_ADDRESS = iota + 1
	// BAN_TYPE_ACCOUNT marks a ban whose target is an account ID (see BanAccount).
	BAN_TYPE_ACCOUNT
)
|
|
|
// Row types for the tables declared in tables. Field names match the column
// names case-insensitively (e.g. TopicTime and `topictime`), which is what
// lets the `SELECT *` queries in this file fill them. This presumably relies
// on sqlx's default lower-casing name mapper; confirm before adding struct
// tags. Field order is kept stable because callers may use positional
// literals.
type (
	// DBAccount is a row of the accounts table. Username and Password hold
	// generateHash output rather than plain text (see AddAccount).
	DBAccount struct {
		ID                 int64
		Username, Password string
	}

	// DBChannel is a row of the channels table. When read back from the
	// database, Channel holds the hash that AddChannel stored. Callers
	// handing a DBChannel to AddChannel put the plain name there.
	DBChannel struct {
		Channel, Topic string
		TopicTime      int64
		Password       string
	}

	// DBPermission is a row of the permissions table. Channel is the hashed
	// channel name, Account an account ID, and Permission a level such as
	// permissionRegistered or permissionSuperAdmin.
	DBPermission struct {
		Channel    string
		Account    int64
		Permission int
	}

	// DBMode has no corresponding entry in tables and is not used by the
	// code in this file.
	DBMode struct {
		Channel, Mode, Value string
	}

	// DBBan is a row of the bans table. Type is BAN_TYPE_ADDRESS or
	// BAN_TYPE_ACCOUNT, and Target is the matching address hash or account
	// ID. Expires is a Unix timestamp; 0 means the ban never expires, which
	// is how BanAddr and BanAccount query it.
	DBBan struct {
		Channel string
		Type    int
		Target  string
		Expires int64
		Reason  string
	}
)
|
|
|
type Database struct { |
|
db *sqlx.DB |
|
} |
|
|
|
func (d *Database) Connect(driver string, dataSource string) error { |
|
var err error |
|
d.db, err = sqlx.Connect(driver, dataSource) |
|
if err != nil { |
|
return errors.Wrapf(err, "failed to connect to %s database", driver) |
|
} |
|
|
|
err = d.CreateTables() |
|
if err != nil { |
|
return errors.Wrap(err, "failed to create tables") |
|
} |
|
|
|
err = d.Migrate() |
|
if err != nil { |
|
return errors.Wrap(err, "failed to migrate database") |
|
} |
|
|
|
err = d.Initialize() |
|
if err != nil { |
|
return errors.Wrap(err, "failed to initialize database") |
|
} |
|
|
|
return err |
|
} |
|
|
|
func (d *Database) CreateTables() error { |
|
for tname, tcolumns := range tables { |
|
_, err := d.db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS `%s` (%s)", tname, strings.Join(tcolumns, ","))) |
|
if err != nil { |
|
return errors.Wrapf(err, "failed to create %s table", tname) |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (d *Database) Migrate() error { |
|
rows, err := d.db.Query("SELECT `value` FROM meta WHERE `key`=? LIMIT 1", "version") |
|
if p(err) { |
|
return errors.Wrap(err, "failed to fetch database version") |
|
} |
|
|
|
version := 0 |
|
for rows.Next() { |
|
v := "" |
|
err = rows.Scan(&v) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to fetch database version") |
|
} |
|
|
|
version, err = strconv.Atoi(v) |
|
if err != nil { |
|
version = -1 |
|
} |
|
} |
|
|
|
if version == -1 { |
|
log.Panic("Unable to migrate database: database version unknown") |
|
} else if version == 0 { |
|
_, err := d.db.Exec("UPDATE meta SET `value`=? WHERE `key`=?", strconv.Itoa(databaseVersion), "version") |
|
if err != nil { |
|
return errors.Wrap(err, "failed to save database version") |
|
} |
|
} else if version < databaseVersion { |
|
// databaseVersion 2 migration queries will go here |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (d *Database) Initialize() error { |
|
a, err := db.Account(1) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to initialize") |
|
} |
|
|
|
if a.ID > 0 { |
|
return nil // Admin account exists |
|
} |
|
|
|
err = d.AddAccount("admin", "password") |
|
if err != nil { |
|
return errors.Wrap(err, "failed to create initial administrator account") |
|
} |
|
|
|
ac := &DBChannel{Channel: channelServer, Topic: "Secret Area of VIP Quality"} |
|
d.AddChannel(1, ac) |
|
|
|
uc := &DBChannel{Channel: channelLobby, Topic: "Welcome to AnonIRC"} |
|
d.AddChannel(1, uc) |
|
|
|
return nil |
|
} |
|
|
|
func (d *Database) Close() error { |
|
err := d.db.Close() |
|
if err != nil { |
|
err = errors.Wrap(err, "failed to close database") |
|
} |
|
return err |
|
} |
|
|
|
// Accounts |
|
|
|
func (d *Database) Account(id int64) (DBAccount, error) { |
|
a := DBAccount{} |
|
err := d.db.Get(&a, "SELECT * FROM accounts WHERE id=? LIMIT 1", id) |
|
if p(err) { |
|
return a, errors.Wrap(err, "failed to fetch account") |
|
} |
|
|
|
return a, nil |
|
} |
|
|
|
func (d *Database) AccountU(username string) (DBAccount, error) { |
|
a := DBAccount{} |
|
err := d.db.Get(&a, "SELECT * FROM accounts WHERE username=? LIMIT 1", generateHash(username)) |
|
if p(err) { |
|
return a, errors.Wrap(err, "failed to fetch account by username") |
|
} |
|
|
|
return a, nil |
|
} |
|
|
|
// TODO: Lockout on too many failed attempts |
|
func (d *Database) Auth(username string, password string) (int64, error) { |
|
// TODO: Salt in config |
|
a := DBAccount{} |
|
err := d.db.Get(&a, "SELECT * FROM accounts WHERE username=? AND password=? LIMIT 1", generateHash(username), generateHash(username+"-"+password)) |
|
if p(err) { |
|
return 0, errors.Wrap(err, "failed to authenticate account") |
|
} |
|
|
|
return a.ID, nil |
|
} |
|
|
|
func (d *Database) GenerateToken() string { |
|
return base64.URLEncoding.EncodeToString(securecookie.GenerateRandomKey(64)) |
|
} |
|
|
|
func (d *Database) AddAccount(username string, password string) error { |
|
ex, err := d.AccountU(username) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to search for existing account while adding account") |
|
} else if ex.ID > 0 { |
|
return ErrAccountExists |
|
} |
|
|
|
_, err = d.db.Exec("INSERT INTO accounts (username, password) VALUES (?, ?)", generateHash(username), generateHash(username+"-"+password)) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to add account") |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (d *Database) SetUsername(accountid int64, username string, password string) error { |
|
ex, err := d.AccountU(username) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to search for existing account while setting username") |
|
} else if ex.ID > 0 { |
|
return ErrAccountExists |
|
} |
|
|
|
_, err = d.db.Exec("UPDATE accounts SET username=?, password=? WHERE id=?", generateHash(username), generateHash(username+"-"+password), accountid) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to set username") |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (d *Database) SetPassword(accountid int64, username string, password string) error { |
|
_, err := d.db.Exec("UPDATE accounts SET password=? WHERE id=?", generateHash(username+"-"+password), accountid) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to set password") |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// Channels |
|
|
|
func (d *Database) ChannelID(id int64) (DBChannel, error) { |
|
c := DBChannel{} |
|
err := d.db.Get(&c, "SELECT * FROM channels WHERE id=? LIMIT 1", id) |
|
if p(err) { |
|
return c, errors.Wrap(err, "failed to fetch channel") |
|
} |
|
|
|
return c, nil |
|
} |
|
|
|
func (d *Database) Channel(channel string) (DBChannel, error) { |
|
c := DBChannel{} |
|
err := d.db.Get(&c, "SELECT * FROM channels WHERE channel=? LIMIT 1", generateHash(channel)) |
|
if p(err) { |
|
return c, errors.Wrap(err, "failed to fetch channel by key") |
|
} |
|
|
|
return c, nil |
|
} |
|
|
|
func (d *Database) AddChannel(accountid int64, channel *DBChannel) error { |
|
ex, err := d.Channel(channel.Channel) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to search for existing channel while adding channel") |
|
} else if ex.Channel != "" { |
|
return ErrChannelExists |
|
} |
|
|
|
chch := channel.Channel |
|
channel.Channel = generateHash(strings.ToLower(channel.Channel)) |
|
_, err = d.db.Exec("INSERT INTO channels (channel, topic, topictime, password) VALUES (?, ?, ?, ?)", channel.Channel, channel.Topic, channel.TopicTime, channel.Password) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to add channel") |
|
} |
|
|
|
err = d.SetPermission(accountid, chch, permissionSuperAdmin) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to set permission on newly added channel") |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// Permissions |
|
|
|
func (d *Database) GetPermission(accountid int64, channel string) (DBPermission, error) { |
|
dbp := DBPermission{} |
|
|
|
// Return REGISTERED by default |
|
dbp.Permission = permissionRegistered |
|
|
|
err := d.db.Get(&dbp, "SELECT * FROM permissions WHERE account=? AND channel=? LIMIT 1", accountid, generateHash(channel)) |
|
if p(err) { |
|
return dbp, errors.Wrap(err, "failed to fetch permission") |
|
} |
|
|
|
return dbp, nil |
|
} |
|
|
|
func (d *Database) SetPermission(accountid int64, channel string, permission int) error { |
|
acc, err := d.Account(accountid) |
|
if err != nil { |
|
log.Panicf("%+v", err) |
|
} else if acc.ID == 0 { |
|
return nil |
|
} |
|
|
|
ch, err := d.Channel(channel) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to fetch channel while setting permission") |
|
} else if ch.Channel == "" { |
|
return nil |
|
} |
|
chh := generateHash(channel) |
|
|
|
dbp, err := d.GetPermission(accountid, chh) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to set permission") |
|
} |
|
|
|
if dbp.Channel != "" { |
|
_, err = d.db.Exec("UPDATE permissions SET permission=? WHERE account=? AND channel=?", permission, accountid, chh) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to set permission") |
|
} |
|
} else { |
|
_, err = d.db.Exec("INSERT INTO permissions (channel, account, permission) VALUES (?, ?, ?)", chh, accountid, permission) |
|
if err != nil { |
|
return errors.Wrap(err, "failed to set permission") |
|
} |
|
} |
|
|
|
return nil |
|
} |
|
|
|
// Bans |
|
|
|
func (d *Database) Ban(banid int) (DBBan, error) { |
|
b := DBBan{} |
|
err := d.db.Get(&b, "SELECT * FROM bans WHERE id=? LIMIT 1", banid) |
|
if p(err) { |
|
return b, errors.Wrap(err, "failed to fetch ban") |
|
} |
|
|
|
return b, nil |
|
} |
|
|
|
func (d *Database) BanAddr(addrhash string, channel string) (DBBan, error) { |
|
b := DBBan{} |
|
if addrhash == "" { |
|
return b, nil |
|
} |
|
|
|
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=? AND (`expires` = 0 OR `expires` > ?)", generateHash(channel), BAN_TYPE_ADDRESS, addrhash, time.Now().Unix()) |
|
if p(err) { |
|
return b, errors.Wrap(err, "failed to fetch ban") |
|
} |
|
|
|
return b, nil |
|
} |
|
|
|
func (d *Database) BanAccount(accountid int64, channel string) (DBBan, error) { |
|
b := DBBan{} |
|
if accountid == 0 { |
|
return b, nil |
|
} |
|
|
|
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=? AND (`expires` = 0 OR `expires` > ?)", generateHash(channel), BAN_TYPE_ACCOUNT, accountid, time.Now().Unix()) |
|
if p(err) { |
|
return b, errors.Wrap(err, "failed to fetch ban") |
|
} |
|
|
|
return b, nil |
|
} |
|
|
|
func (d *Database) AddBan(b DBBan) error { |
|
_, err := d.db.Exec("INSERT INTO bans (`channel`, `type`, `target`, `expires`, `reason`) VALUES (?, ?, ?, ?, ?)", b.Channel, b.Type, b.Target, b.Expires, b.Reason) |
|
if p(err) { |
|
return errors.Wrap(err, "failed to add ban") |
|
} |
|
|
|
return nil |
|
}
|
|
|