// Source: sshtargate/portal/portal.go (155 lines, 3.6 KiB, Go)

// Package portal provides SSH portals to applications.
package portal
import (
"context"
"errors"
"fmt"
"io"
"log"
"os"
"os/exec"
"path"
"syscall"
"time"
"unsafe"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"
)
const (
	// ListenTimeout caps how long to wait for the server to start listening
	// on an address.
	ListenTimeout = time.Second
	// IdleTimeout caps how long a connection may remain inactive.
	IdleTimeout = time.Minute
)
// Portal is an SSH portal to an application.
//
// A Portal is created by New, which fills in every field; Close and Shutdown
// stop it.
type Portal struct {
// Name is the label passed to New; New uses it in its log messages.
Name string
// Address is the address the SSH server is configured to listen on.
Address string
// Command is the program started for each session: Command[0] is the
// executable and any remaining elements are its arguments.
Command []string
// Server is the underlying SSH server that Close and Shutdown act on.
Server *ssh.Server
}
// New opens an SSH portal to an application.
//
// It configures an SSH server on address whose session handler runs command
// (command[0] is the executable, the remaining elements are its arguments)
// inside a pseudo-terminal, and starts serving in a background goroutine.
// Sessions that did not request a PTY are rejected with exit status 1. Every
// public-key, password and keyboard-interactive authentication attempt is
// accepted unconditionally. The host key file is ".ssh/id_rsa" under the
// current user's home directory.
//
// name labels log messages and is stored on the returned Portal.
//
// New waits up to ListenTimeout for the server to report a failure. If none
// arrives in that window, the server is treated as started and the Portal is
// returned.
//
// An error is returned when address is empty, when command is empty or its
// first element is empty, when the home directory cannot be determined, when
// the host key option cannot be applied, or when the server fails within
// ListenTimeout. Underlying errors are wrapped and can be inspected with
// errors.Is / errors.As.
func New(name string, address string, command []string) (*Portal, error) {
	if address == "" {
		return nil, errors.New("no address supplied")
	}
	// len() also rejects a non-nil empty slice, which would otherwise make
	// command[0] panic.
	if len(command) == 0 || command[0] == "" {
		return nil, errors.New("no command supplied")
	}

	server := &ssh.Server{
		Addr:        address,
		IdleTimeout: IdleTimeout,
		Handler: func(sshSession ssh.Session) {
			ptyReq, winCh, isPty := sshSession.Pty()
			if !isPty {
				io.WriteString(sshSession, "failed to start command: non-interactive terminals are not supported\n")
				sshSession.Exit(1)
				return
			}

			// The command's context derives from the session's, and is also
			// cancelled when this handler returns.
			cmdCtx, cancelCmd := context.WithCancel(sshSession.Context())
			defer cancelCmd()

			cmd := exec.CommandContext(cmdCtx, command[0], command[1:]...)
			cmd.Env = append(sshSession.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))

			stderr, err := cmd.StderrPipe()
			if err != nil {
				log.Printf("error: failed to create stderr pipe for portal %s: %s", name, err)
				// Report failure to the client like the other early exits do.
				sshSession.Exit(1)
				return
			}

			f, err := pty.Start(cmd)
			if err != nil {
				fmt.Fprintf(sshSession, "failed to start command: failed to initialize pseudo-terminal: %s\n", err)
				sshSession.Exit(1)
				return
			}

			// Stderr is forwarded only once the command is running, so a
			// failed start leaves no copier goroutine behind.
			go func() {
				io.Copy(sshSession.Stderr(), stderr)
			}()
			// Apply every window-size change the client announces to the PTY.
			go func() {
				for win := range winCh {
					setWinsize(f, win.Width, win.Height)
				}
			}()
			// Client input -> command.
			go func() {
				io.Copy(f, sshSession)
			}()

			// Command output -> client; returns once reading from the PTY
			// ends or writing to the session fails.
			io.Copy(sshSession, f)
			f.Close()
			cmd.Wait()
		},
		// A PTY is always granted and every authentication method below
		// accepts all callers.
		PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
			return true
		},
		PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
			return true
		},
		PasswordHandler: func(ctx ssh.Context, password string) bool {
			return true
		},
		KeyboardInteractiveHandler: func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
			return true
		},
	}

	homeDir, err := os.UserHomeDir()
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve user home dir: %w", err)
	}
	if err := server.SetOption(ssh.HostKeyFile(path.Join(homeDir, ".ssh", "id_rsa"))); err != nil {
		return nil, fmt.Errorf("failed to set host key file: %w", err)
	}

	// Buffered so the serving goroutine can always hand over its error and
	// exit, even when ListenAndServe returns after the startup window below
	// has elapsed and nobody is receiving any more.
	errs := make(chan error, 1)
	go func() {
		if err := server.ListenAndServe(); err != nil {
			errs <- fmt.Errorf("failed to start SSH server: %w", err)
		}
	}()

	t := time.NewTimer(ListenTimeout)
	defer t.Stop()
	select {
	case err := <-errs:
		return nil, err
	case <-t.C:
		// No failure was reported within ListenTimeout: treat the server as
		// started.
	}

	return &Portal{Name: name, Address: address, Command: command, Server: server}, nil
}
// Close tears the portal down immediately by closing its underlying SSH
// server. Whatever error the server returns is deliberately dropped.
func (p *Portal) Close() {
	srv := p.Server
	_ = srv.Close()
}
// Shutdown stops the portal without interrupting active connections. The
// underlying server is shut down with a background context, so this method
// imposes no deadline of its own; the server's returned error is dropped.
func (p *Portal) Shutdown() {
	ctx := context.Background()
	_ = p.Server.Shutdown(ctx)
}
// setWinsize issues a TIOCSWINSZ ioctl on f to set its terminal window size to
// width w and height h. The outcome of the system call is not checked.
func setWinsize(f *os.File, w, h int) {
	// Layout handed to the ioctl: height, width, then two fields left at zero.
	size := struct{ height, width, x, y uint16 }{
		height: uint16(h),
		width:  uint16(w),
	}
	// The unsafe.Pointer -> uintptr conversion must stay inside the call's
	// argument list.
	syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&size)),
	)
}