sage/sage.go

378 lines
8.3 KiB
Go

// sage - Markov chain IRC bot
// https://gitlab.com/tslocum/sage
// Written by Trevor 'tee' Slocum <tslocum@gmail.com>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package main
import (
"bufio"
"flag"
"fmt"
"log"
"math/rand"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"path"
"regexp"
"strconv"
"strings"
"syscall"
"time"
"github.com/BurntSushi/toml"
"github.com/dustin/go-humanize"
irc "github.com/fluffle/goirc/client"
"github.com/tslocum/markov"
)
// Config holds the bot settings that loadConfig decodes from sage.conf.
type Config struct {
	// IRC identity and connection. Server and Nick are required; Ident and
	// Name fall back to Nick when left empty. NickPassword, when set, is sent
	// to NickServ after connecting.
	Server, Nick, NickPassword, Ident, Name string

	// Channels lists the channels joined once the connection is established.
	Channels []string

	// Markov generation settings; loadConfig replaces non-positive values
	// with defaults (1, 1 and 1000 respectively).
	MarkovOrder   int // order passed to the accumulator and generator
	MarkovWords   int // minimum number of words in a reply
	MarkovTimeout int // milliseconds respond may spend composing a reply

	// DebugPort, when positive, serves the net/http default mux (with pprof).
	DebugPort int
}
// Package-level state shared by main, the IRC handlers and the background
// goroutines.
var (
	config     *Config                // settings decoded from sage.conf by loadConfig
	self       *markov.BoltTableStore // store opened from sage.db in the data directory
	memory     *markov.Accumulator    // words accepted by learn are added here
	brain      *markov.Model          // model that respond generates replies from
	experience = make(chan string, 100) // lines queued by hear and consumed by learn
	client     *irc.Conn              // IRC connection created in main
)
var validword = regexp.MustCompile("^[a-zA-Z]([a-zA-Z \\-'/]+[a-zA-Z])?$")
var invalidchars = regexp.MustCompile("[^a-zA-Z0-9 \\-'/]+")
var commonwords = []string{"a", "an", "at", "are", "arent", "and", "is", "isnt", "not", "of", "i", "you", "he", "she", "him", "her", "his", "hers", "they", "them", "theirs", "us", "our", "ours", "get", "got", "it", "in", "if", "of", "or", "on", "just", "no", "yes", "yeah", "ya", "yea", "yep", "yup"}
// Command-line flags, parsed at the start of main.
var (
	// dataDir is the directory holding sage.conf and the sage.db store.
	dataDir = flag.String("data", "", "Data directory (contains sage.conf, memories are stored here)")
	// importFile, when set, is a plaintext file whose lines are learned at startup.
	importFile = flag.String("import", "", "Import plaintext file into memory")
)
// loadConfig decodes sage.conf from the -data directory into the package-level
// config, fills in defaults for the Markov settings, and starts the optional
// debug HTTP server. It terminates the process via log.Fatal when -data is
// unset, sage.conf cannot be stat'ed or decoded, or Server/Nick are missing.
func loadConfig() {
	// Without -data, cfile holds instructions instead of a path so that the
	// message below tells the user what to do.
	cfile := "a new data folder\nThen supply the folder path: sage -data /home/sage/data"
	if *dataDir != "" {
		cfile = path.Join(*dataDir, "sage.conf")
	}
	nonexistmsg := fmt.Sprintf("Error! Unable to read sage.conf, please copy sage.default.conf to %s", cfile)

	if *dataDir == "" {
		log.Fatal(nonexistmsg)
	}
	if _, err := os.Stat(cfile); err != nil {
		log.Fatalf("%s\n%v", nonexistmsg, err)
	}

	config = new(Config)
	// config is already a pointer; decode into it directly rather than
	// through a pointer-to-pointer.
	if _, err := toml.DecodeFile(cfile, config); err != nil {
		log.Fatalf("Failed to read %s: %v", cfile, err)
	}
	if config.Server == "" || config.Nick == "" {
		log.Fatal("Server and Nick parameters in sage.conf are required")
	}

	if config.MarkovOrder <= 0 {
		config.MarkovOrder = 1
	}
	if config.MarkovWords <= 0 {
		config.MarkovWords = 1
	}
	if config.MarkovTimeout <= 0 {
		config.MarkovTimeout = 1000
	}

	if config.DebugPort > 0 {
		// NOTE(review): ":port" listens on every interface, so the
		// /debug/pprof endpoints registered by the blank net/http/pprof
		// import are reachable remotely; consider binding to localhost.
		addr := ":" + strconv.Itoa(config.DebugPort)
		go func() {
			// ListenAndServe only returns on failure (e.g. port in use);
			// report it instead of silently losing the debug server.
			if err := http.ListenAndServe(addr, nil); err != nil {
				log.Printf("Error! Debug server on %s stopped: %v", addr, err)
			}
		}()
	}
}
// hear queues message on the experience channel for learn to process. The
// send blocks while the channel's buffer is full.
func hear(message string) { experience <- message }
func learn() {
var m []string
var valid bool
var err error
for message := range experience {
valid = false
m = strings.Split(stripCodes(strings.ToLower(message)), " ")
for _, tidbit := range m {
if !validword.MatchString(tidbit) {
continue // Contains number or other invalid char
} else if tidbit == config.Nick {
continue
}
err = memory.Add(tidbit)
if err != nil {
log.Fatalf("Failed to add memory: %v", err)
}
valid = true
}
if valid {
err = memory.Add("")
if err != nil {
log.Fatalf("Failed to add memory: %v", err)
}
}
}
}
func respond(message string) string {
message = strings.ToLower(message)
pieces := strings.Split(stripCommonWords(message), " ")
timeout := time.After(time.Duration(config.MarkovTimeout) * time.Millisecond)
var thought []string
var response string
var perspective *markov.Generator
Vocalize:
for {
thought = nil
perspective = markov.NewGenerator(brain, uint(config.MarkovOrder), rand.New(rand.NewSource(time.Now().UTC().UnixNano())))
for {
tidbit, err := perspective.Get()
if err != nil {
log.Fatalf("Unable to think: %v", err)
}
if tidbit == "" {
break
}
thought = append(thought, tidbit)
}
if len(thought) >= config.MarkovWords {
if len(pieces) == 0 || pieces[0] == "" {
break Vocalize
} else if response == "" {
response = strings.Join(thought, " ")
} else {
for _, piece := range pieces {
for _, tidbit := range thought {
if tidbit == piece {
break Vocalize
}
}
}
}
}
select {
case <-timeout:
break Vocalize
default:
}
}
if len(thought) >= config.MarkovWords {
return strings.Join(thought, " ")
}
return response
}
func containsNick(message string) bool {
message = stripCodes(strings.ToLower(message))
nick := strings.ToLower(config.Nick)
m := strings.Split(message, " ")
for _, mword := range m {
if mword == nick {
return true
}
}
return false
}
// colorcodes matches mIRC colour sequences: \x03 optionally followed by a one-
// or two-digit foreground and an optional ",background". They must be removed
// as a unit. Stripping only the \x03 byte (as invalidchars does) leaves the
// colour digits glued to the next word, e.g. "\x0304hello" -> "04hello". Such a
// word then fails validword and hides the nick from containsNick.
var colorcodes = regexp.MustCompile(`\x03(?:[0-9]{1,2}(?:,[0-9]{1,2})?)?`)

// stripCodes removes IRC colour sequences, then every character matched by
// invalidchars. Only ASCII letters, digits, spaces, '-', '\'' and '/' remain.
func stripCodes(message string) string {
	message = colorcodes.ReplaceAllLiteralString(message, "")
	return invalidchars.ReplaceAllString(message, "")
}
// stripCommonWords blanks out every occurrence of the bot's nick (lowercased)
// and of each word in commonwords. message is split on single spaces and
// rejoined with single spaces, so each removed word leaves an empty slot.
// message itself is not lowercased here; callers pass lowercase text.
func stripCommonWords(message string) string {
	nick := strings.ToLower(config.Nick)
	words := strings.Split(message, " ")
	for i := range words {
		if words[i] == nick || isCommonWord(words[i]) {
			words[i] = ""
		}
	}
	return strings.Join(words, " ")
}

// isCommonWord reports whether word is listed in commonwords.
func isCommonWord(word string) bool {
	for _, common := range commonwords {
		if common == word {
			return true
		}
	}
	return false
}
// importMemories passes every non-blank line of the -import file to hear,
// logging progress every 1000 lines (counted including blank lines). It does
// nothing when -import is unset. It exits the process if the file does not
// exist or cannot be opened. A read error part-way through is logged and ends
// the import early.
//
// While importing, Bolt's NoSync is enabled, presumably to speed up the bulk
// load; it is restored and a single Sync is issued at the end.
//
// NOTE(review): hear only queues lines; learn stores them on another
// goroutine. The final Sync (and resetting NoSync) can therefore happen before
// the last queued lines are written. NoSync is also modified here while learn
// may be using the store concurrently — confirm whether that race matters.
func importMemories() {
	if *importFile == "" {
		return
	}
	lines, err := os.Open(*importFile)
	if err != nil {
		if os.IsNotExist(err) {
			log.Fatalf("Import file %s does not exist", *importFile)
		}
		log.Fatalf("Failed to read import file %s: %v", *importFile, err)
	}
	// Deferred only after the error check, so a failed Open is never "closed".
	defer lines.Close()

	self.Bolt.NoSync = true
	log.Printf("Importing %s into memory...", *importFile)
	linecount := int64(0)
	scanner := bufio.NewScanner(lines)
	for scanner.Scan() {
		linecount++
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}
		hear(line)
		if linecount%1000 == 0 {
			// Constant format string (go vet printf check).
			log.Printf("Import progress: %s lines", humanize.Comma(linecount))
		}
	}
	// Scan stops silently on errors such as bufio.ErrTooLong for a line
	// longer than the Scanner's buffer; surface it instead of reporting a
	// complete import.
	if err := scanner.Err(); err != nil {
		log.Printf("Error! Import of %s stopped early: %v", *importFile, err)
	}

	self.Bolt.NoSync = false
	if err := self.Bolt.Sync(); err != nil {
		log.Printf("Error! Unable to write memories to file: %v", err)
	}
	log.Printf("Imported %s lines into memory", humanize.Comma(linecount))
}
// saveMemories calls self.Bolt.Sync once an hour, forever. A failed sync is
// logged and retried at the next interval.
func saveMemories() {
	const interval = time.Hour
	for {
		time.Sleep(interval)
		err := self.Bolt.Sync()
		if err == nil {
			continue
		}
		log.Printf("Error! Unable to write memories to file: %v", err)
	}
}
// terminate closes the memory store and exits the process with status 0. It is
// invoked by the SIGINT/SIGTERM handler installed in main.
func terminate() {
	const exitSuccess = 0
	self.Close()
	os.Exit(exitSuccess)
}
// main parses flags, loads the configuration and opens the memory store. It
// then starts the background goroutines (learning, hourly saving, optional
// import) and keeps the bot connected to IRC, retrying 30 seconds after every
// disconnect or failed connection attempt. SIGINT and SIGTERM close the store
// and exit via terminate.
func main() {
	flag.Parse()
	rand.Seed(time.Now().UTC().UnixNano())
	var err error // predeclared so the package-level self is assigned, not shadowed

	// Intend
	loadConfig()

	// Remember
	dbfile := path.Join(*dataDir, "sage.db")
	self, err = markov.NewBoltTableStore(dbfile)
	if err != nil {
		log.Fatalf("Failed to open %s: %v", dbfile, err)
	}

	// Become
	brain = markov.NewModel(self)
	memory = markov.NewAccumulator(brain, uint(config.MarkovOrder))
	// learn writes to memory, so it is started only once memory exists.
	go learn()
	go saveMemories()
	go importMemories()

	// Explore
	ident := config.Ident
	if ident == "" {
		ident = config.Nick
	}
	name := config.Name
	if name == "" {
		name = config.Nick
	}
	cfg := irc.NewConfig(config.Nick, ident, name)
	cfg.Server = config.Server
	cfg.Version = "sage https://gitlab.com/tslocum/sage"
	cfg.NewNick = func(n string) string { return n + "^" }
	client = irc.Client(cfg)
	client.HandleFunc(irc.CONNECTED,
		func(conn *irc.Conn, line *irc.Line) {
			log.Println("Connected!")
			if config.NickPassword != "" {
				conn.Privmsg("NickServ", "IDENTIFY "+config.NickPassword)
			}
			for _, channel := range config.Channels {
				conn.Join(channel)
			}
		})
	client.HandleFunc(irc.PRIVMSG,
		func(conn *irc.Conn, line *irc.Line) {
			if len(line.Args) < 2 {
				return // malformed line: nothing to learn or answer
			}
			message := line.Args[1]
			hear(message)
			// Always answer when addressed by nick; otherwise chime in on
			// roughly one message in ten.
			if !containsNick(message) && rand.Intn(10) != 7 {
				return
			}
			// respond yields "" when no long-enough reply could be composed;
			// sending that would produce an empty PRIVMSG.
			reply := respond(message)
			if reply == "" {
				return
			}
			// Target is the channel for channel messages and the sender for
			// private queries, where Args[0] would be our own nick.
			conn.Privmsg(line.Target(), reply)
		})
	quit := make(chan bool)
	client.HandleFunc(irc.DISCONNECTED,
		func(conn *irc.Conn, line *irc.Line) {
			quit <- true
		})

	termc := make(chan os.Signal, 2)
	signal.Notify(termc, os.Interrupt, syscall.SIGTERM)
	go func() {
		<-termc
		terminate()
	}()

	for {
		log.Printf("Connecting to %s as %s...", config.Server, config.Nick)
		if err := client.Connect(); err != nil {
			// A failed attempt is not expected to dispatch DISCONNECTED
			// (NOTE(review): confirm for the goirc version in use). Waiting
			// on quit here would then hang forever instead of retrying.
			log.Printf("Error! Unable to connect: %v", err)
		} else {
			<-quit
			log.Println("Disconnected...")
		}
		time.Sleep(30 * time.Second)
	}
}