Clean up client disconnection

This commit is contained in:
Trevor Slocum 2017-12-20 02:22:40 -08:00
parent dc0cb8b3c8
commit 5adbcbb640
2 changed files with 47 additions and 50 deletions

View File

@ -25,7 +25,6 @@ type Client struct {
conn net.Conn
writebuffer chan *irc.Message
terminate chan bool
reader *irc.Decoder
writer *irc.Encoder
@ -49,7 +48,6 @@ func NewClient(identifier string, conn net.Conn, ssl bool) *Client {
c.nick = "*"
c.conn = conn
c.writebuffer = make(chan *irc.Message, writebuffersize)
c.terminate = make(chan bool)
c.reader = irc.NewDecoder(conn)
c.writer = irc.NewEncoder(conn)
@ -83,6 +81,7 @@ func (c *Client) write(prefix *irc.Prefix, command string, params []string) {
return
}
c.wg.Add(1)
c.writebuffer <- &irc.Message{Prefix: prefix, Command: command, Params: params}
}
@ -106,6 +105,13 @@ func (c *Client) sendNotice(message string) {
c.sendMessage("*** " + message)
}
func (c *Client) sendBanned(reason string) {
if reason != "" {
reason = fmt.Sprintf(" (%s)", reason)
}
c.writeMessage(irc.ERR_YOUREBANNEDCREEP, []string{"You are banned from this server" + reason})
}
func (c *Client) accessDenied(permissionRequired int) {
ex := ""
if permissionRequired > PERMISSION_CLIENT {

View File

@ -283,7 +283,6 @@ func (s *Server) clientCount() int {
func (s *Server) revealClient(channel string, identifier string) *Client {
riphash, raccount := s.revealClientInfo(channel, identifier)
if riphash == "" && raccount == 0 {
log.Println("hash not found")
return nil
}
@ -605,7 +604,7 @@ func (s *Server) handleTopic(channel string, client string, topic string) {
chp, err := db.GetPermission(cl.account, channel)
if err != nil {
log.Panicf("%+v", err)
} else if ch.hasMode("t") && (chp.Permission < PERMISSION_VIP) {
} else if ch.hasMode("t") && chp.Permission < PERMISSION_VIP {
cl.accessDenied(PERMISSION_VIP)
return
}
@ -824,7 +823,6 @@ func (s *Server) sendUsage(cl *Client, command string) {
func (s *Server) ban(channel string, iphash string, accountid int64, expires int64, reason string) error {
if channel == "" || expires < 0 {
log.Println("invalid args")
return nil
}
@ -837,7 +835,6 @@ func (s *Server) ban(channel string, iphash string, accountid int64, expires int
return err
}
}
log.Println("B")
if accountid > 0 {
b = DBBan{Channel: generateHash(channel), Type: BAN_TYPE_ACCOUNT, Target: fmt.Sprintf("%d", accountid), Expires: expires}
err := db.AddBan(b)
@ -845,7 +842,6 @@ func (s *Server) ban(channel string, iphash string, accountid int64, expires int
return err
}
}
log.Println("C")
if b.Channel == "" {
return nil
}
@ -856,7 +852,6 @@ func (s *Server) ban(channel string, iphash string, accountid int64, expires int
ch = ""
rs = formatAction("Killed", reason)
}
log.Println("D")
cls := s.getClients(ch)
for _, cl := range cls {
if cl == nil {
@ -865,6 +860,7 @@ func (s *Server) ban(channel string, iphash string, accountid int64, expires int
if (iphash != "" && cl.iphash == iphash) || (accountid > 0 && cl.account == accountid) {
if channel == CHANNEL_SERVER {
cl.writeMessage(irc.KILL, []string{cl.nick, rs})
s.killClient(cl, rs)
} else {
s.partChannel(channel, cl.identifier, rs)
@ -1115,7 +1111,7 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
reason := ""
if len(params) > 3 {
strings.Join(params[3:], " ")
reason = strings.Join(params[3:], " ")
}
bch := ch.identifier
@ -1123,13 +1119,12 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
bch = CHANNEL_SERVER
}
err := s.ban(bch, rcl.iphash, rcl.account, expires, reason)
log.Println("A")
if err != nil {
cl.sendError(fmt.Sprintf("Unable to %s, %v", strings.ToLower(command), err))
return
}
cl.sendMessage(fmt.Sprintf("%sed %s %s", strings.ToLower(command), params[0], params[1]))
cl.sendMessage(fmt.Sprintf("%sed %s %s", strings.Title(strings.ToLower(command)), params[0], params[1]))
case COMMAND_STATS:
cl.sendMessage(fmt.Sprintf("%d clients in %d channels", s.clientCount(), s.channelCount()))
case COMMAND_REHASH:
@ -1205,7 +1200,9 @@ func (s *Server) handleRead(c *Client) {
c.conn.SetReadDeadline(time.Now().Add(300 * time.Second))
msg, err := c.reader.Decode()
if msg == nil || err != nil {
if c.state == ENTITY_STATE_TERMINATING {
return
} else if msg == nil || err != nil {
// Error decoding message, client probably disconnected
s.killClient(c, "")
return
@ -1380,34 +1377,34 @@ func (s *Server) handleRead(c *Client) {
}
func (s *Server) handleWrite(c *Client) {
for {
select {
case msg := <-c.writebuffer:
if c.state == ENTITY_STATE_TERMINATING {
continue
}
c.wg.Add(1)
addnick := false
if _, err := strconv.Atoi(msg.Command); err == nil {
addnick = true
} else if msg.Command == irc.CAP {
addnick = true
}
if addnick {
msg.Params = append([]string{c.nick}, msg.Params...)
}
if debugMode && (verbose || len(msg.Command) < 4 || (msg.Command[0:4] != irc.PING && msg.Command[0:4] != irc.PONG)) {
log.Printf("%s <- %s", c.identifier, msg)
}
c.writer.Encode(msg)
werror := false
for msg := range c.writebuffer {
if werror {
// We experienced a write error, stop writing
c.wg.Done()
case <-c.terminate:
close(c.writebuffer)
return
continue
}
addnick := false
if _, err := strconv.Atoi(msg.Command); err == nil {
addnick = true
} else if msg.Command == irc.CAP {
addnick = true
}
if addnick {
msg.Params = append([]string{c.nick}, msg.Params...)
}
if debugMode && (verbose || len(msg.Command) < 4 || (msg.Command[0:4] != irc.PING && msg.Command[0:4] != irc.PONG)) {
log.Printf("%s <- %s", c.identifier, msg)
}
err := c.writer.Encode(msg)
if err != nil {
werror = true
}
c.wg.Done()
}
}
@ -1434,11 +1431,7 @@ func (s *Server) handleConnection(conn net.Conn, ssl bool) {
s.clients.Store(c.identifier, c)
s.handleRead(c) // Block until the connection is closed
} else {
if reason != "" {
reason = fmt.Sprintf(" (%s)", reason)
}
c.writeMessage(irc.ERR_YOUREBANNEDCREEP, []string{"You are banned from this server" + reason})
time.Sleep(1 * time.Second)
c.sendBanned(reason)
}
s.killClient(c, "")
@ -1451,14 +1444,12 @@ func (s *Server) killClient(c *Client, reason string) {
}
c.state = ENTITY_STATE_TERMINATING
select {
case c.terminate <- true:
if _, ok := s.clients.Load(c.identifier); ok {
s.partAllChannels(c.identifier, reason)
}
c.wg.Wait()
default:
if _, ok := s.clients.Load(c.identifier); ok {
s.partAllChannels(c.identifier, reason)
}
c.wg.Wait()
close(c.writebuffer)
c.conn.Close()
}
func (s *Server) listenPlain() {