twins/server.go

313 lines
6.7 KiB
Go

package main
import (
"bufio"
"bytes"
"crypto/tls"
"crypto/x509"
"fmt"
"log"
"net"
"net/url"
"path"
"regexp"
"strconv"
"strings"
"time"
"unicode/utf8"
)
// Per-connection limits.
const (
	// readTimeout is added to the current time to form the read deadline set
	// on each accepted connection before its request is read.
	readTimeout = 30 * time.Second

	// urlMaxLength is the maximum accepted request URL length, in bytes.
	urlMaxLength = 1024
)

// Gemini response status codes.
const (
	// 1x: input expected.
	statusInput          = 10
	statusSensitiveInput = 11

	// 2x: success.
	statusSuccess = 20

	// 3x: redirection.
	statusRedirectTemporary = 30
	statusRedirectPermanent = 31

	// 4x: temporary failure.
	statusTemporaryFailure = 40
	statusUnavailable      = 41
	statusCGIError         = 42
	statusProxyError       = 43

	// 5x: permanent failure.
	statusPermanentFailure    = 50
	statusNotFound            = 51
	statusGone                = 52
	statusProxyRequestRefused = 53
	statusBadRequest          = 59
)

// slashesRegexp matches an unescaped '/' together with the single character
// that precedes it. servePath counts its matches in a configured path to
// decide how many leading request segments that path covers. A slash at the
// very start of a string has no preceding character, so it is never matched.
var slashesRegexp = regexp.MustCompile(`[^\\]\/`)
// writeHeader sends a response header line of the form "<code> <meta>\r\n"
// to c in a single write and, when verbose is set, logs the same header.
// The result of the write is not checked.
func writeHeader(c net.Conn, code int, meta string) {
	line := fmt.Sprintf("%d %s\r\n", code, meta)
	c.Write([]byte(line))
	if !verbose {
		return
	}
	log.Printf("< %d %s\n", code, meta)
}
func writeStatus(c net.Conn, code int) {
var meta string
switch code {
case statusTemporaryFailure:
meta = "Temporary failure"
case statusProxyError:
meta = "Proxy error"
case statusBadRequest:
meta = "Bad request"
case statusNotFound:
meta = "Not found"
case statusProxyRequestRefused:
meta = "Proxy request refused"
}
writeHeader(c, code, meta)
}
// scanCRLF is a bufio.SplitFunc that yields one line per token.
//
// A line ends at CR, LF or CRLF, and the terminator is not part of the
// token. Accepting a bare LF keeps clients that never send CR from stalling
// until the read deadline.
//
// When a CR is the last byte buffered so far, it is taken as a complete
// terminator rather than waiting for a possible LF. A client that sends only
// CR is therefore answered at once; in that case a following LF is left for
// the next token.
//
// At EOF, any remaining unterminated data is returned as the final token.
func scanCRLF(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	if i := bytes.IndexAny(data, "\r\n"); i >= 0 {
		advance = i + 1
		// Consume the LF of a CRLF pair when it is already buffered.
		if data[i] == '\r' && advance < len(data) && data[advance] == '\n' {
			advance++
		}
		return advance, data[:i], nil
	}
	if atEOF {
		// Final, non-terminated line.
		return len(data), data, nil
	}
	// No terminator yet: ask the scanner for more data.
	return 0, nil, nil
}
// replaceWithUserInput returns a copy of command in which every occurrence of
// "$USERINPUT" is replaced by the request's percent-decoded query string. The
// input slice is never modified. If the query is not valid percent-encoding,
// the copy is returned with its placeholders left untouched.
//
// Gemini clients send input percent-encoded per RFC 3986, so the query is
// decoded with PathUnescape. QueryUnescape would turn a literal '+' into a
// space, as HTML form encoding does.
//
// NOTE(review): the decoded input is spliced into argument strings. This is
// only safe if the caller executes the command without a shell — confirm in
// serveCommand.
func replaceWithUserInput(command []string, request *url.URL) []string {
	newCommand := make([]string, len(command))
	copy(newCommand, command)

	input, err := url.PathUnescape(request.RawQuery)
	if err != nil {
		return newCommand
	}
	for i, piece := range newCommand {
		newCommand[i] = strings.ReplaceAll(piece, "$USERINPUT", input)
	}
	return newCommand
}
// servePath serves a request that matched the path configuration serve.
// Exactly one backend handles it, checked in this order:
//   - a proxy (serve.Proxy),
//   - a FastCGI pool (serve.FastCGI),
//   - a command (serve.cmd),
//   - otherwise, static files under serve.Root.
//
// Request paths are cleaned as rooted paths before being joined with
// serve.Root. url.Parse does not resolve ".." segments, so without this step a
// request such as "gemini://host/../../etc/passwd" could reach files outside
// the root.
func servePath(c *tls.Conn, request *url.URL, serve *pathConfig) {
	switch {
	case serve.Proxy != "":
		serveProxy(c, request, serve.Proxy)

	case serve.FastCGI != "":
		contentType := "text/gemini; charset=utf-8"
		if serve.Type != "" {
			contentType = serve.Type
		}
		writeHeader(c, statusSuccess, contentType)
		// Clean("/"+p) drops "." and ".." without climbing above "/". It also
		// tolerates an empty path, where the former Path[1:] would panic.
		filePath := path.Join(serve.Root, path.Clean("/"+request.Path))
		serveFastCGI(c, config.fcgiPools[serve.FastCGI], request, filePath)

	case serve.cmd != nil:
		command := serve.cmd
		if serve.Input != "" || serve.SensitiveInput != "" {
			command = replaceWithUserInput(serve.cmd, request)
		}
		serveCommand(c, request, command)

	default:
		// Drop the leading request segments covered by serve.Path, so the
		// remainder is resolved relative to serve.Root. The number of
		// segments is the count of unescaped slashes in serve.Path.
		// NOTE(review): a trailing slash in serve.Path counts as one more
		// segment — confirm this matches the intended config semantics.
		resolvedPath := request.Path
		segments := strings.Split(request.Path, "/")
		pathSlashes := len(slashesRegexp.FindAllStringIndex(serve.Path, -1))
		if strings.HasPrefix(request.Path, "/") {
			pathSlashes++ // slashesRegexp never matches a leading slash
		}
		if len(segments) >= pathSlashes+1 {
			resolvedPath = "/" + strings.Join(segments[pathSlashes+1:], "/")
		}
		// Rooting the path before cleaning keeps ".." from escaping serve.Root.
		filePath := path.Clean("/" + resolvedPath)
		serveFile(c, request, path.Join(serve.Root, filePath), serve.ListDirectory)
	}
}
// serveConn reads a single request line from c, validates it, and hands it to
// the first path configuration of the requested host that matches.
//
// Requests that are malformed, too long or not valid UTF-8 get
// statusBadRequest. Requests for another scheme, another port or an unknown
// host get statusProxyRequestRefused. A request with an empty path is
// redirected to "/". When the host is known but no path matches, the
// response is statusNotFound.
func serveConn(c *tls.Conn) {
	var requestData string
	scanner := bufio.NewScanner(c)
	scanner.Split(scanCRLF)
	if scanner.Scan() {
		requestData = scanner.Text()
	}
	if err := scanner.Err(); err != nil {
		writeStatus(c, statusBadRequest)
		return
	}

	// Public keys of any client certificates presented in the handshake.
	// NOTE(review): collected but not consulted anywhere in this function.
	var clientCertKeys [][]byte
	for _, cert := range c.ConnectionState().PeerCertificates {
		pubKey, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
		if err != nil {
			continue
		}
		clientCertKeys = append(clientCertKeys, pubKey)
	}

	if verbose {
		log.Printf("> %s\n", requestData)
	}

	if len(requestData) > urlMaxLength || !utf8.ValidString(requestData) {
		writeStatus(c, statusBadRequest)
		return
	}
	request, err := url.Parse(requestData)
	if err != nil {
		writeStatus(c, statusBadRequest)
		return
	}
	requestHostname := request.Hostname()
	if requestHostname == "" || strings.ContainsRune(requestHostname, ' ') {
		writeStatus(c, statusBadRequest)
		return
	}

	// An unparsable port is treated the same as an absent one.
	var requestPort int
	if p := request.Port(); p != "" {
		if n, err := strconv.Atoi(p); err == nil {
			requestPort = n
		}
	}

	if request.Scheme == "" {
		request.Scheme = "gemini"
	}
	if request.Scheme != "gemini" || (requestPort > 0 && requestPort != config.port) {
		writeStatus(c, statusProxyRequestRefused)
		// Stop here. Continuing would serve the request after the refusal
		// header has already been sent.
		return
	}

	if request.Path == "" {
		// Redirect to the root path. Setting Path on a copy of the URL keeps
		// any query or fragment in place; appending "/" to the raw request
		// would put the slash after them.
		redirect := *request
		redirect.Path = "/"
		writeHeader(c, statusRedirectPermanent, redirect.String())
		return
	}

	host, ok := config.Hosts[requestHostname]
	if !ok {
		writeStatus(c, statusProxyRequestRefused)
		return
	}

	pathBytes := []byte(request.Path)
	for _, serve := range host.Paths {
		matchedRegexp := serve.r != nil && serve.r.Match(pathBytes)
		matchedPrefix := serve.r == nil && strings.HasPrefix(request.Path, serve.Path)
		if !matchedRegexp && !matchedPrefix {
			continue
		}
		// Paths that take input prompt for it until the client supplies a query.
		if request.RawQuery == "" {
			if serve.Input != "" {
				writeHeader(c, statusInput, serve.Input)
				return
			}
			if serve.SensitiveInput != "" {
				writeHeader(c, statusSensitiveInput, serve.SensitiveInput)
				return
			}
		}
		servePath(c, request, serve)
		return
	}
	writeStatus(c, statusNotFound)
}
// handleConn serves one accepted connection and always closes it afterwards.
// When verbose is set, it also logs how long handling took. Durations over one
// second are rounded to whole seconds, shorter ones to milliseconds. The
// timing line is logged after the connection has been closed.
func handleConn(c *tls.Conn) {
	if verbose {
		defer func(start time.Time) {
			elapsed := time.Since(start)
			precision := time.Millisecond
			if elapsed > time.Second {
				precision = time.Second
			}
			log.Printf("took %s", elapsed.Round(precision))
		}(time.Now())
	}
	defer c.Close()

	// Reads on c fail once readTimeout has elapsed from this point.
	c.SetReadDeadline(time.Now().Add(readTimeout))
	serveConn(c)
}
// getCertificate selects the server certificate for a TLS handshake.
//
// If the client's SNI server name is a configured host, that host's
// certificate is returned. Otherwise the certificate of the configured host
// whose name sorts first is returned, so clients that send no SNI name (or an
// unknown one) get the same certificate on every connection. When no hosts
// are configured, it returns nil, nil.
func getCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
	if host := config.Hosts[info.ServerName]; host != nil {
		return host.cert, nil
	}
	// Map iteration order is randomized, so choose the lexicographically
	// smallest name explicitly to keep the fallback stable.
	fallback, found := "", false
	for name := range config.Hosts {
		if !found || name < fallback {
			fallback, found = name, true
		}
	}
	if !found {
		return nil, nil
	}
	return config.Hosts[fallback].cert, nil
}
// handleListener accepts connections from l forever, serving each one on its
// own goroutine. Any Accept error terminates the process via log.Fatal.
// Every accepted connection must be a *tls.Conn (listen supplies a
// tls.Listen listener); anything else panics in the accepting goroutine.
func handleListener(l net.Listener) {
	for {
		netConn, acceptErr := l.Accept()
		if acceptErr != nil {
			log.Fatal(acceptErr)
		}
		tlsConn := netConn.(*tls.Conn)
		go handleConn(tlsConn)
	}
}
// listen opens a TLS listener on address and serves connections on it via
// handleListener. That call does not return normally, because Accept errors
// are fatal. A listen failure terminates the process.
//
// Server certificates are chosen per handshake by getCertificate. Client
// certificates are requested but not required.
func listen(address string) {
	tlsConfig := &tls.Config{
		// The Gemini specification requires TLS 1.2 or newer. Set it
		// explicitly instead of relying on the toolchain's default minimum.
		MinVersion:     tls.VersionTLS12,
		ClientAuth:     tls.RequestClientCert,
		GetCertificate: getCertificate,
	}
	listener, err := tls.Listen("tcp", address, tlsConfig)
	if err != nil {
		log.Fatalf("failed to listen on %s: %s", address, err)
	}
	handleListener(listener)
}