Resolve client write buffer deadlock, add ProfilingPort config option

This commit is contained in:
Trevor Slocum 2017-04-29 14:36:33 -07:00
parent 80a5f88b0f
commit 4d0dd1b9c1
4 changed files with 43 additions and 28 deletions

View File

@ -25,6 +25,7 @@ import (
"github.com/orcaman/concurrent-map"
irc "gopkg.in/sorcix/irc.v2"
_ "net/http/pprof"
"os"
"os/signal"
"syscall"
@ -41,6 +42,7 @@ _| _| _| _| _| _| _| _| _| _| _| _|
_| _| _| _| _|_| _| _| _|_|_| _| _| _|_|_|
`
const letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
const writebuffersize = 10
type Pair struct {
Key string
@ -92,5 +94,6 @@ func main() {
server.reload()
}()
go server.startProfiling()
server.listen()
}

View File

@ -35,7 +35,10 @@ func (c *Client) write(msg *irc.Message) {
func (c *Client) handleWrite() {
for msg := range c.writebuffer {
c.Lock()
if msg == nil {
return
}
addnick := false
if _, err := strconv.Atoi(msg.Command); err == nil {
addnick = true
@ -51,7 +54,6 @@ func (c *Client) handleWrite() {
log.Println(c.identifier, "->", msg)
}
c.writer.Encode(msg)
c.Unlock()
}
}

View File

@ -10,6 +10,9 @@ import (
const ENTITY_CLIENT = 0
const ENTITY_CHANNEL = 1
const ENTITY_STATE_TERMINATING = 0
const ENTITY_STATE_NORMAL = 1
const CLIENT_MODES = "c"
const CHANNEL_MODES = "cipstz"
const CHANNEL_MODES_ARG = "kl"
@ -18,6 +21,7 @@ type Entity struct {
entitytype int
identifier string
created int64
state int
modes cmap.ConcurrentMap
*sync.RWMutex

View File

@ -14,13 +14,15 @@ import (
cmap "github.com/orcaman/concurrent-map"
irc "gopkg.in/sorcix/irc.v2"
"math/rand"
"net/http"
"os"
"reflect"
)
type Config struct {
SSLCert string
SSLKey string
SSLCert string
SSLKey string
ProfilingPort int
}
type Server struct {
@ -111,7 +113,7 @@ func (s *Server) joinChannel(channel string, client string) {
}
if ch == nil {
ch = &Channel{Entity{ENTITY_CHANNEL, channel, time.Now().Unix(), cmap.New(), new(sync.RWMutex)}, cmap.New(), "", 0}
ch = &Channel{Entity{ENTITY_CHANNEL, channel, time.Now().Unix(), ENTITY_STATE_NORMAL, cmap.New(), new(sync.RWMutex)}, cmap.New(), "", 0}
s.channels.Set(channel, ch)
} else if ch.hasMode("z") && !cl.ssl {
cl.sendNotice("Unable to join " + channel + ": SSL connections only (channel mode +z)")
@ -192,7 +194,6 @@ func (s *Server) updateClientCount(channel string) {
continue
}
cl.Lock()
chancount := s.getClientCount(channel, cclient)
if ccount < chancount {
@ -206,12 +207,10 @@ func (s *Server) updateClientCount(channel string) {
cl.write(&irc.Message{s.getAnonymousPrefix(i - 1), irc.PART, []string{channel}})
}
} else {
cl.Unlock()
continue
}
ch.clients.Set(cclient, chancount)
cl.Unlock()
}
}
@ -285,10 +284,8 @@ func (s *Server) handleTopic(channel string, client string, topic string) {
}
if topic != "" {
ch.Lock()
ch.topic = topic
ch.topictime = time.Now().Unix()
ch.Unlock()
for cls := range ch.clients.IterBuffered() {
s.sendTopic(channel, cls.Key, true)
@ -322,13 +319,11 @@ func (s *Server) handleMode(c *Client, params []string) {
lastmodes[ms.Key] = ms.Val.(string)
}
ch.Lock()
if params[1][0] == '+' {
ch.addModes(params[1][1:])
} else {
ch.removeModes(params[1][1:])
}
ch.Unlock()
s.enforceModes(params[0])
if !reflect.DeepEqual(ch.modes.Items(), lastmodes) {
@ -369,13 +364,11 @@ func (s *Server) handleMode(c *Client, params []string) {
lastmodes := c.getModes()
if len(params) > 1 && len(params[1]) > 0 && (params[1][0] == '+' || params[1][0] == '-') {
c.Lock()
if params[1][0] == '+' {
c.addModes(params[1][1:])
} else {
c.removeModes(params[1][1:])
}
c.Unlock()
}
if !reflect.DeepEqual(c.modes, lastmodes) {
@ -429,14 +422,13 @@ func (s *Server) handleRead(c *Client) {
c.conn.SetDeadline(time.Now().Add(300 * time.Second))
if !s.clients.Has(c.identifier) {
c.conn.Close()
s.killClient(c)
return
}
msg, err := c.reader.Decode()
if msg == nil || err != nil {
c.conn.Close()
s.partAllChannels(c.identifier)
s.killClient(c)
return
}
if len(msg.Command) >= 4 && msg.Command[0:4] != irc.PING && msg.Command[0:4] != irc.PONG {
@ -540,8 +532,7 @@ func (s *Server) handleRead(c *Client) {
s.partChannel(channel, c.identifier, "")
}
} else if msg.Command == irc.QUIT {
c.conn.Close()
s.partAllChannels(c.identifier)
s.killClient(c)
}
}
}
@ -550,8 +541,6 @@ func (s *Server) handleConnection(conn net.Conn, ssl bool) {
defer conn.Close()
var identifier string
s.Lock()
for {
identifier = randomIdentifier()
if !s.clients.Has(identifier) {
@ -559,13 +548,28 @@ func (s *Server) handleConnection(conn net.Conn, ssl bool) {
}
}
client := &Client{Entity{ENTITY_CLIENT, identifier, time.Now().Unix(), cmap.New(), new(sync.RWMutex)}, ssl, "*", "", "", conn, make(chan *irc.Message), irc.NewDecoder(conn), irc.NewEncoder(conn), false}
client := &Client{Entity{ENTITY_CLIENT, identifier, time.Now().Unix(), ENTITY_STATE_NORMAL, cmap.New(), new(sync.RWMutex)}, ssl, "*", "", "", conn, make(chan *irc.Message, writebuffersize), irc.NewDecoder(conn), irc.NewEncoder(conn), false}
s.clients.Set(client.identifier, client)
s.Unlock()
go client.handleWrite()
s.handleRead(client)
s.killClient(client)
close(client.writebuffer)
s.clients.Remove(identifier)
}
func (s *Server) killClient(c *Client) {
if c.state == ENTITY_STATE_TERMINATING {
return
}
c.state = ENTITY_STATE_TERMINATING
c.write(nil)
c.conn.Close()
if s.clients.Has(c.identifier) {
s.partAllChannels(c.identifier)
}
}
func (s *Server) listenPlain() {
@ -638,7 +642,6 @@ func (s *Server) listenSSL() {
func (s *Server) pingClients() {
for {
s.Lock()
for cls := range s.clients.IterBuffered() {
cl := s.getClient(cls.Key)
@ -646,19 +649,16 @@ func (s *Server) pingClients() {
cl.write(&irc.Message{nil, irc.PING, []string{fmt.Sprintf("anonirc%d%d", int32(time.Now().Unix()), rand.Intn(1000))}})
}
}
s.Unlock()
time.Sleep(90 * time.Second)
}
}
func (s *Server) loadConfig() {
s.Lock()
if _, err := os.Stat("anonircd.conf"); err == nil {
if _, err := toml.DecodeFile("anonircd.conf", &s.config); err != nil {
log.Fatalf("Failed to read anonircd.conf: %v", err)
}
}
s.Unlock()
}
func (s *Server) reload() {
@ -668,6 +668,12 @@ func (s *Server) reload() {
s.restartssl <- true
}
func (s *Server) startProfiling() {
if s.config.ProfilingPort > 0 {
http.ListenAndServe(fmt.Sprintf("localhost:%d", s.config.ProfilingPort), nil)
}
}
func (s *Server) listen() {
go s.listenPlain()
go s.listenSSL()