Implement BAN and KILL

This commit is contained in:
Trevor Slocum 2017-12-14 17:39:18 -08:00
parent 54c43696ac
commit 12cf575569
6 changed files with 236 additions and 95 deletions

View File

@ -26,7 +26,7 @@ type ChannelLog struct {
Timestamp int64
Client string
IP string
Account int
Account int64
Action string
Message string
}
@ -62,7 +62,7 @@ func (c *Channel) Log(client *Client, action string, message string) {
c.logs[nano] = &ChannelLog{Timestamp: nano, Client: client.identifier, IP: client.iphash, Account: client.account, Action: action, Message: message}
}
func (c *Channel) RevealLog(page int, full bool) []string {
func (c *Channel) RevealLog(page int, showAll bool) []string {
c.RLock()
defer c.RUnlock()
@ -88,7 +88,7 @@ func (c *Channel) RevealLog(page int, full bool) []string {
}
if page == -1 || i >= (CHANNEL_LOGS_PER_PAGE*(page-1)) {
if full || (l.Action != irc.JOIN && l.Action != irc.PART) {
if showAll || (l.Action != irc.JOIN && l.Action != irc.PART) {
if page > -1 && j == CHANNEL_LOGS_PER_PAGE {
logsRemain = true
break
@ -102,7 +102,7 @@ func (c *Channel) RevealLog(page int, full bool) []string {
if len(ls) == 0 {
ls = append(ls, "No log entries match criteria")
} else {
filterType := "all"
filterType := "all entries"
if page > -1 {
filterType = fmt.Sprintf("page %d", page)
}
@ -118,7 +118,7 @@ func (c *Channel) RevealLog(page int, full bool) []string {
return ls
}
func (c *Channel) RevealInfo(identifier string) (string, int) {
func (c *Channel) RevealInfo(identifier string) (string, int64) {
if len(identifier) != 5 {
return "", 0
}

View File

@ -21,7 +21,7 @@ type Client struct {
nick string
user string
host string
account int
account int64
conn net.Conn
writebuffer chan *irc.Message
@ -78,16 +78,16 @@ func (c *Client) getPrefix() *irc.Prefix {
return &irc.Prefix{Name: c.nick, User: c.user, Host: c.host}
}
func (c *Client) write(msg *irc.Message) {
func (c *Client) write(prefix *irc.Prefix, command string, params []string) {
if c.state == ENTITY_STATE_TERMINATING {
return
}
c.writebuffer <- msg
c.writebuffer <- &irc.Message{Prefix: prefix, Command: command, Params: params}
}
func (c *Client) writeMessage(command string, params []string) {
c.write(&irc.Message{&prefixAnonIRC, command, params})
c.write(&prefixAnonIRC, command, params)
}
func (c *Client) sendMessage(message string) {
@ -149,7 +149,7 @@ func (c *Client) getPermission(channel string) int {
}
func (c *Client) globalPermission() int {
return c.getPermission("&")
return c.getPermission(CHANNEL_SERVER)
}
func (c *Client) canUse(command string, channel string) bool {

View File

@ -6,6 +6,7 @@ import (
"log"
"strconv"
"strings"
"time"
"github.com/gorilla/securecookie"
"github.com/jmoiron/sqlx"
@ -49,7 +50,7 @@ const (
)
type DBAccount struct {
ID int
ID int64
Username string
Password string
}
@ -57,13 +58,13 @@ type DBAccount struct {
type DBChannel struct {
Channel string
Topic string
TopicTime int
TopicTime int64
Password string
}
type DBPermission struct {
Channel string
Account int
Account int64
Permission int
}
@ -77,7 +78,7 @@ type DBBan struct {
Channel string
Type int
Target string
Expires int
Expires int64
Reason string
}
@ -189,7 +190,7 @@ func (d *Database) Close() error {
// Accounts
func (d *Database) Account(id int) (DBAccount, error) {
func (d *Database) Account(id int64) (DBAccount, error) {
a := DBAccount{}
err := d.db.Get(&a, "SELECT * FROM accounts WHERE id=? LIMIT 1", id)
if p(err) {
@ -210,7 +211,7 @@ func (d *Database) AccountU(username string) (DBAccount, error) {
}
// TODO: Lockout on too many failed attempts
func (d *Database) Auth(username string, password string) (int, error) {
func (d *Database) Auth(username string, password string) (int64, error) {
// TODO: Salt in config
a := DBAccount{}
err := d.db.Get(&a, "SELECT * FROM accounts WHERE username=? AND password=? LIMIT 1", generateHash(username), generateHash(username+"-"+password))
@ -241,7 +242,7 @@ func (d *Database) AddAccount(username string, password string) error {
return nil
}
func (d *Database) SetUsername(accountid int, username string, password string) error {
func (d *Database) SetUsername(accountid int64, username string, password string) error {
ex, err := d.AccountU(username)
if err != nil {
return errors.Wrap(err, "failed to search for existing account while setting username")
@ -257,7 +258,7 @@ func (d *Database) SetUsername(accountid int, username string, password string)
return nil
}
func (d *Database) SetPassword(accountid int, username string, password string) error {
func (d *Database) SetPassword(accountid int64, username string, password string) error {
_, err := d.db.Exec("UPDATE accounts SET password=? WHERE id=?", generateHash(username+"-"+password), accountid)
if err != nil {
return errors.Wrap(err, "failed to set password")
@ -268,7 +269,7 @@ func (d *Database) SetPassword(accountid int, username string, password string)
// Channels
func (d *Database) ChannelID(id int) (DBChannel, error) {
func (d *Database) ChannelID(id int64) (DBChannel, error) {
c := DBChannel{}
err := d.db.Get(&c, "SELECT * FROM channels WHERE id=? LIMIT 1", id)
if p(err) {
@ -288,7 +289,7 @@ func (d *Database) Channel(channel string) (DBChannel, error) {
return c, nil
}
func (d *Database) AddChannel(accountid int, channel *DBChannel) error {
func (d *Database) AddChannel(accountid int64, channel *DBChannel) error {
ex, err := d.Channel(channel.Channel)
if err != nil {
return errors.Wrap(err, "failed to search for existing channel while adding channel")
@ -313,7 +314,7 @@ func (d *Database) AddChannel(accountid int, channel *DBChannel) error {
// Permissions
func (d *Database) GetPermission(accountid int, channel string) (DBPermission, error) {
func (d *Database) GetPermission(accountid int64, channel string) (DBPermission, error) {
dbp := DBPermission{}
// Return REGISTERED by default
@ -327,7 +328,7 @@ func (d *Database) GetPermission(accountid int, channel string) (DBPermission, e
return dbp, nil
}
func (d *Database) SetPermission(accountid int, channel string, permission int) error {
func (d *Database) SetPermission(accountid int64, channel string, permission int) error {
acc, err := d.Account(accountid)
if err != nil {
log.Panicf("%+v", err)
@ -377,7 +378,11 @@ func (d *Database) Ban(banid int) (DBBan, error) {
func (d *Database) BanAddr(addrhash string, channel string) (DBBan, error) {
b := DBBan{}
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=?", generateHash(channel), BAN_TYPE_ADDRESS, addrhash)
if addrhash == "" {
return b, nil
}
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=? AND (`expires` = 0 OR `expires` > ?)", generateHash(channel), BAN_TYPE_ADDRESS, addrhash, time.Now().Unix())
if p(err) {
return b, errors.Wrap(err, "failed to fetch ban")
}
@ -385,9 +390,13 @@ func (d *Database) BanAddr(addrhash string, channel string) (DBBan, error) {
return b, nil
}
func (d *Database) BanAccount(accountid int, channel string) (DBBan, error) {
func (d *Database) BanAccount(accountid int64, channel string) (DBBan, error) {
b := DBBan{}
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=?", generateHash(channel), BAN_TYPE_ACCOUNT, accountid)
if accountid == 0 {
return b, nil
}
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=? AND (`expires` = 0 OR `expires` > ?)", generateHash(channel), BAN_TYPE_ACCOUNT, accountid, time.Now().Unix())
if p(err) {
return b, errors.Wrap(err, "failed to fetch ban")
}
@ -396,19 +405,7 @@ func (d *Database) BanAccount(accountid int, channel string) (DBBan, error) {
}
func (d *Database) AddBan(b DBBan) error {
var err error
// Channel-specific (not server-wide)
if b.Channel != "&" {
ex, err := d.Channel(b.Channel)
if err != nil {
return errors.Wrap(err, "failed to search for existing ban while adding ban")
} else if ex.Channel == "" {
return ErrChannelDoesNotExist
}
}
_, err = d.db.Exec("INSERT INTO bans (`channel`, `type`, `target`, `expires`, `reason`) VALUES (?, ?, ?, ?, ?, ?)", b.Channel, b.Type, b.Target, b.Expires, b.Reason)
_, err := d.db.Exec("INSERT INTO bans (`channel`, `type`, `target`, `expires`, `reason`) VALUES (?, ?, ?, ?, ?)", b.Channel, b.Type, b.Target, b.Expires, b.Reason)
if p(err) {
return errors.Wrap(err, "failed to add ban")
}

View File

@ -32,8 +32,8 @@ import (
"gopkg.in/sorcix/irc.v2"
)
var prefixAnonymous = irc.Prefix{"Anonymous", "Anon", "IRC"}
var prefixAnonIRC = irc.Prefix{Name: "AnonIRC"}
var prefixAnonymous = irc.Prefix{Name: "Anonymous", User: "Anon", Host: "IRC"}
const motd = `
_|_| _|_|_| _|_|_| _|_|_|

208
server.go
View File

@ -61,7 +61,7 @@ const (
var permissionLabels = map[int]string{
PERMISSION_CLIENT: "Client",
PERMISSION_REGISTERED: "Registered",
PERMISSION_REGISTERED: "Registered Client",
PERMISSION_VIP: "VIP",
PERMISSION_MODERATOR: "Moderator",
PERMISSION_ADMIN: "Administrator",
@ -269,7 +269,7 @@ func (s *Server) revealClient(channel string, identifier string) *Client {
log.Println("hash not found")
return nil
}
log.Println("have hash")
cls := s.getClients("")
for _, rcl := range cls {
if rcl.iphash == riphash || (rcl.account > 0 && rcl.account == raccount) {
@ -280,7 +280,7 @@ func (s *Server) revealClient(channel string, identifier string) *Client {
return nil
}
func (s *Server) revealClientInfo(channel string, identifier string) (string, int) {
func (s *Server) revealClientInfo(channel string, identifier string) (string, int64) {
if len(identifier) != 5 {
return "", 0
}
@ -344,7 +344,7 @@ func (s *Server) joinChannel(channel string, client string) {
}
ch.clients.Store(client, s.clientsInChannel(channel, client)+1)
cl.write(&irc.Message{cl.getPrefix(), irc.JOIN, []string{channel}})
cl.write(cl.getPrefix(), irc.JOIN, []string{channel})
ch.Log(cl, irc.JOIN, "")
s.sendNames(channel, client)
@ -360,7 +360,7 @@ func (s *Server) partChannel(channel string, client string, reason string) {
return
}
cl.write(&irc.Message{cl.getPrefix(), irc.PART, []string{channel, reason}})
cl.write(cl.getPrefix(), irc.PART, []string{channel, reason})
ch.Log(cl, irc.PART, reason)
ch.clients.Delete(client)
@ -374,7 +374,7 @@ func (s *Server) partAllChannels(client string, reason string) {
}
}
func (s *Server) revealChannelLog(channel string, client string, page int, full bool) {
func (s *Server) revealChannelLog(channel string, client string, page int, showAll bool) {
cl := s.getClient(client)
if cl == nil {
return
@ -390,7 +390,7 @@ func (s *Server) revealChannelLog(channel string, client string, page int, full
return
}
r := ch.RevealLog(page, full)
r := ch.RevealLog(page, showAll)
for _, rev := range r {
cl.sendMessage(rev)
}
@ -452,7 +452,7 @@ func (s *Server) updateClientCount(channel string, client string, reason string)
if ccount < chancount {
for i := ccount; i < chancount; i++ {
cl.write(&irc.Message{s.getAnonymousPrefix(i), irc.JOIN, []string{channel}})
cl.write(s.getAnonymousPrefix(i), irc.JOIN, []string{channel})
}
ch.clients.Store(cl.identifier, chancount)
@ -463,7 +463,7 @@ func (s *Server) updateClientCount(channel string, client string, reason string)
pr = reason
}
cl.write(&irc.Message{s.getAnonymousPrefix(i - 1), irc.PART, []string{channel, pr}})
cl.write(s.getAnonymousPrefix(i-1), irc.PART, []string{channel, pr})
reasonShown = true
}
} else {
@ -525,7 +525,7 @@ func (s *Server) sendTopic(channel string, client string, changed bool) {
tprefix = prefixAnonIRC
tcommand = irc.RPL_TOPIC
}
cl.write(&irc.Message{&tprefix, tcommand, []string{channel, ch.topic}})
cl.write(&tprefix, tcommand, []string{channel, ch.topic})
if !changed {
cl.writeMessage(strings.Join([]string{irc.RPL_TOPICWHOTIME, cl.nick, channel, prefixAnonymous.Name, fmt.Sprintf("%d", ch.topictime)}, " "), nil)
@ -627,7 +627,7 @@ func (s *Server) handleMode(c *Client, params []string) {
ch.clients.Range(func(k, v interface{}) bool {
cl := s.getClient(k.(string))
if cl != nil {
cl.write(&irc.Message{&prefixAnonymous, irc.MODE, []string{params[0], ch.printModes(addedmodes, removedmodes)}})
cl.write(&prefixAnonymous, irc.MODE, []string{params[0], ch.printModes(addedmodes, removedmodes)})
}
return true
@ -706,18 +706,112 @@ func (s *Server) sendUsage(cl *Client, command string) {
}
sort.Strings(commands)
var printedLabel bool
var usage []string
for _, cmd := range commands {
usage = u[cmd]
if command == COMMAND_HELP {
// Print all commands
var perms []int
for permission := range permissionLabels {
perms = append(perms, permission)
}
sort.Ints(perms)
cl.sendMessage(cmd + " " + usage[0])
for _, ul := range usage[1:] {
cl.sendMessage(" " + ul)
for i := 0; i < 2; i++ {
serverLabel := ""
if i == 1 {
serverLabel = "Server "
}
for _, permission := range perms {
printedLabel = false
for _, cmd := range commands {
if cl.permissionRequired(cmd) != permission {
continue
} else if (i == 0 && containsString(serverCommands, cmd)) || (i == 1 && !containsString(serverCommands, cmd)) {
continue
}
if !printedLabel {
cl.sendNotice(serverLabel + permissionLabels[permission] + " Commands")
printedLabel = true
}
usage = u[cmd]
cl.sendMessage(cmd + " " + usage[0])
for _, ul := range usage[1:] {
cl.sendMessage(" " + ul)
}
}
}
}
} else {
// TODO: Cleanup/merge
for _, cmd := range commands {
usage = u[cmd]
cl.sendMessage(cmd + " " + usage[0])
for _, ul := range usage[1:] {
cl.sendMessage(" " + ul)
}
}
}
}
func (s *Server) ban(channel string, iphash string, accountid int64, expires int64, reason string) error {
if channel == "" || expires < 0 {
log.Println("invalid args")
return nil
}
b := DBBan{}
if iphash != "" {
b = DBBan{Channel: generateHash(channel), Type: BAN_TYPE_ADDRESS, Target: iphash, Expires: expires}
err := db.AddBan(b)
if err != nil {
return err
}
}
if accountid > 0 {
b = DBBan{Channel: generateHash(channel), Type: BAN_TYPE_ACCOUNT, Target: fmt.Sprintf("%d", accountid), Expires: expires}
err := db.AddBan(b)
if err != nil {
return err
}
}
if b.Channel == "" {
log.Println("blank chan")
return nil
}
ch := channel
rs := formatAction("Banned", reason)
if channel == CHANNEL_SERVER {
ch = ""
rs = formatAction("Killed", reason)
}
cls := s.getClients(ch)
for _, cl := range cls {
if cl == nil {
continue
}
if (iphash != "" && cl.iphash == iphash) || (accountid > 0 && cl.account == accountid) {
if channel == CHANNEL_SERVER {
s.killClient(cl, rs)
} else {
s.partChannel(channel, cl.identifier, rs)
}
}
}
return nil
}
func (s *Server) handleUserCommand(client string, command string, params []string) {
cl := s.getClient(client)
if cl == nil {
@ -860,22 +954,28 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
}
page := 1
all := false
if len(params) > 1 {
page, err = strconv.Atoi(params[1])
if err != nil || page < -1 || page == 0 {
cl.sendError("Unable to reveal, invalid page specified")
return
if strings.ToLower(params[1]) == "all" {
page = -1
all = true
} else {
cl.sendError("Unable to reveal, invalid page specified")
return
}
}
}
full := false
if len(params) > 2 {
if strings.ToLower(params[2]) == "full" {
full = true
if strings.ToLower(params[2]) == "all" {
all = true
}
}
s.revealChannelLog(params[0], cl.identifier, page, full)
s.revealChannelLog(params[0], cl.identifier, page, all)
case COMMAND_KICK:
if len(params) < 2 {
s.sendUsage(cl, command)
@ -900,7 +1000,7 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
}
s.partChannel(ch.identifier, rcl.identifier, reason)
cl.sendMessage(fmt.Sprintf("Kicked %s %s", params[0], params[1]))
case COMMAND_BAN:
case COMMAND_BAN, COMMAND_KILL:
if len(params) < 3 {
s.sendUsage(cl, command)
return
@ -908,44 +1008,40 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
ch := s.getChannel(params[0])
if ch == nil {
cl.sendError("Unable to ban, invalid channel specified")
cl.sendError(fmt.Sprintf("Unable to %s, invalid channel specified", strings.ToLower(command)))
return
}
rcl := s.revealClient(params[0], params[1])
if rcl == nil {
cl.sendError("Unable to ban, client not found or no longer connected")
cl.sendError(fmt.Sprintf("Unable to %s, client not found or no longer connected", strings.ToLower(command)))
return
}
reason := strings.Join(params[3:], " ")
expires := parseDuration(params[2])
if expires < 0 {
cl.sendError(fmt.Sprintf("Unable to %s, invalid duration supplied", strings.ToLower(command)))
return
} else if expires > 0 {
expires = time.Now().Unix() + expires
}
partmsg := "Banned"
reason := ""
if len(params) > 3 {
partmsg = fmt.Sprintf("%s: %s", partmsg, reason)
strings.Join(params[3:], " ")
}
// TODO: Apply ban in DB
s.partChannel(ch.identifier, rcl.identifier, partmsg)
cl.sendMessage(fmt.Sprintf("Banned %s %s", params[0], params[1]))
case COMMAND_KILL:
if len(params) < 3 {
s.sendUsage(cl, command)
bch := ch.identifier
if command == COMMAND_KILL {
bch = CHANNEL_SERVER
}
err := s.ban(bch, rcl.iphash, rcl.account, expires, reason)
if err != nil {
cl.sendError(fmt.Sprintf("Unable to %s, %v", strings.ToLower(command), err))
return
}
rcl := s.revealClient(params[0], params[1])
if rcl == nil {
cl.sendError("Unable to kill, client not found or no longer connected")
return
}
reason := "Killed"
if len(params) > 3 {
reason = fmt.Sprintf("%s: %s", reason, strings.Join(params[3:], " "))
}
s.partAllChannels(rcl.identifier, reason)
s.killClient(rcl)
cl.sendMessage(fmt.Sprintf("Killed %s %s", params[0], params[1]))
cl.sendMessage(fmt.Sprintf("%sed %s %s", strings.ToLower(command), params[0], params[1]))
case COMMAND_STATS:
cl.sendMessage(fmt.Sprintf("%d clients in %d channels", s.clientCount(), s.channelCount()))
@ -996,7 +1092,7 @@ func (s *Server) handlePrivmsg(channel string, client string, message string) {
ch.clients.Range(func(k, v interface{}) bool {
chcl := s.getClient(k.(string))
if chcl != nil && chcl.identifier != client {
chcl.write(&irc.Message{&prefixAnonymous, irc.PRIVMSG, []string{channel, message}})
chcl.write(&prefixAnonymous, irc.PRIVMSG, []string{channel, message})
}
return true
@ -1013,14 +1109,14 @@ func (s *Server) handleRead(c *Client) {
c.conn.SetReadDeadline(time.Now().Add(300 * time.Second))
if _, ok := s.clients.Load(c.identifier); !ok {
s.killClient(c)
s.killClient(c, "")
return
}
msg, err := c.reader.Decode()
if msg == nil || err != nil {
// Error decoding message, client probably disconnected
s.killClient(c)
s.killClient(c, "")
return
}
if debugMode && (verbose || (len(msg.Command) >= 4 && msg.Command[0:4] != irc.PING && msg.Command[0:4] != irc.PONG)) {
@ -1065,7 +1161,7 @@ func (s *Server) handleRead(c *Client) {
if !authSuccess {
c.sendPasswordIncorrect()
s.killClient(c)
s.killClient(c, "")
}
} else if msg.Command == irc.CAP && len(msg.Params) > 0 && len(msg.Params[0]) > 0 && msg.Params[0] == irc.CAP_LS {
c.writeMessage(irc.CAP, []string{irc.CAP_LS, "userhost-in-names"})
@ -1185,7 +1281,7 @@ func (s *Server) handleRead(c *Client) {
s.partChannel(channel, c.identifier, "")
}
} else if msg.Command == irc.QUIT {
s.killClient(c)
s.killClient(c, "")
} else {
s.handleUserCommand(c.identifier, msg.Command, msg.Params)
}
@ -1251,10 +1347,10 @@ func (s *Server) handleConnection(conn net.Conn, ssl bool) {
go s.handleWrite(client)
s.handleRead(client) // Block until the connection is closed
s.killClient(client)
s.killClient(client, "")
}
func (s *Server) killClient(c *Client) {
func (s *Server) killClient(c *Client, reason string) {
if c == nil || c.state == ENTITY_STATE_TERMINATING {
return
}
@ -1263,7 +1359,7 @@ func (s *Server) killClient(c *Client) {
select {
case c.terminate <- true:
if _, ok := s.clients.Load(c.identifier); ok {
s.partAllChannels(c.identifier, "")
s.partAllChannels(c.identifier, reason)
}
c.wg.Wait()
default:
@ -1343,7 +1439,7 @@ func (s *Server) pingClients() {
s.clients.Range(func(k, v interface{}) bool {
cl := v.(*Client)
if cl != nil {
cl.write(&irc.Message{nil, irc.PING, []string{fmt.Sprintf("anonirc%d%d", int32(time.Now().Unix()), rand.Intn(1000))}})
cl.write(nil, irc.PING, []string{fmt.Sprintf("anonirc%d%d", int32(time.Now().Unix()), rand.Intn(1000))})
}
return true

View File

@ -7,6 +7,7 @@ import (
"fmt"
"math/rand"
"sort"
"strconv"
"strings"
"golang.org/x/crypto/sha3"
@ -81,3 +82,50 @@ func containsString(s []string, e string) bool {
func p(err error) bool {
return err != nil && err != sql.ErrNoRows
}
func formatAction(action string, reason string) string {
rs := action
if reason != "" {
rs += ": " + reason
}
return rs
}
func parseDuration(duration string) int64 {
duration = strings.TrimSpace(duration)
if intval, err := strconv.Atoi(duration); err == nil {
if intval == 0 {
return 0 // Never expire
}
}
if len(duration) < 2 {
return -1 // Value and unit are required
}
sv := duration[0 : len(duration)-1]
unit := strings.ToLower(duration[len(duration)-1:])
value, err := strconv.ParseInt(sv, 10, 64)
if err != nil || value < 0 {
return -1
}
switch unit {
case "y":
return value * 3600 * 24 * 365
case "w":
return value * 3600 * 24 * 7
case "d":
return value * 3600 * 24
case "h":
return value * 3600
case "m":
return value * 60
case "s":
return value
}
return -1
}