143 lines
3.1 KiB
Go
143 lines
3.1 KiB
Go
package gate
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"os"
|
|
"os/exec"
|
|
"path"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/creack/pty"
|
|
"github.com/gliderlabs/ssh"
|
|
gossh "golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// Timing parameters used when starting and running a portal's SSH server.
const (
	// ListenTimeout is how long NewPortal waits for the SSH server to report
	// a startup error before assuming the server is running.
	ListenTimeout = time.Second

	// IdleTimeout is assigned to the IdleTimeout field of every portal's
	// ssh.Server.
	IdleTimeout = time.Minute
)
|
|
|
|
type Portal struct {
|
|
Name string
|
|
Address string
|
|
Command []string
|
|
Server *ssh.Server
|
|
}
|
|
|
|
// setWinsize resizes the terminal referred to by f to w columns by h rows
// using the TIOCSWINSZ ioctl.
//
// The kernel's window-size fields are 16-bit, so w and h are clamped to
// [0, 65535] instead of silently wrapping. It returns the ioctl's errno (for
// example when f is not a terminal) or nil on success. Callers that only
// want best-effort resizing may ignore the result.
func setWinsize(f *os.File, w, h int) error {
	clamp := func(v int) uint16 {
		switch {
		case v < 0:
			return 0
		case v > int(^uint16(0)):
			return ^uint16(0)
		default:
			return uint16(v)
		}
	}

	// Field order matches the kernel's winsize layout: rows, columns, then
	// pixel width/height (left at 0).
	ws := struct{ rows, cols, xpixel, ypixel uint16 }{clamp(h), clamp(w), 0, 0}

	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, f.Fd(),
		uintptr(syscall.TIOCSWINSZ), uintptr(unsafe.Pointer(&ws)))
	if errno != 0 {
		return errno
	}
	return nil
}
|
|
|
|
func NewPortal(name string, address string, command []string) (*Portal, error) {
|
|
if address == "" {
|
|
return nil, errors.New("no address supplied")
|
|
} else if command == nil || command[0] == "" {
|
|
return nil, errors.New("no command supplied")
|
|
}
|
|
|
|
server := &ssh.Server{
|
|
Addr: address,
|
|
IdleTimeout: IdleTimeout,
|
|
Handler: func(sshSession ssh.Session) {
|
|
ptyReq, winCh, isPty := sshSession.Pty()
|
|
if !isPty {
|
|
io.WriteString(sshSession, "failed to start command: non-interactive terminals are not supported\n")
|
|
sshSession.Exit(1)
|
|
return
|
|
}
|
|
|
|
cmdCtx, cancelCmd := context.WithCancel(sshSession.Context())
|
|
defer cancelCmd()
|
|
|
|
var args []string
|
|
if len(command) > 1 {
|
|
args = command[1:]
|
|
}
|
|
cmd := exec.CommandContext(cmdCtx, command[0], args...)
|
|
|
|
cmd.Env = append(sshSession.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
|
|
|
|
stderr, err := cmd.StderrPipe()
|
|
if err != nil {
|
|
log.Printf("error: failed to create stderr pipe for portal %s: %s", name, err)
|
|
return
|
|
}
|
|
go func() {
|
|
io.Copy(sshSession.Stderr(), stderr)
|
|
}()
|
|
|
|
f, err := pty.Start(cmd)
|
|
if err != nil {
|
|
io.WriteString(sshSession, fmt.Sprintf("failed to start command: failed to initialize pseudo-terminal: %s\n", err))
|
|
sshSession.Exit(1)
|
|
return
|
|
}
|
|
go func() {
|
|
for win := range winCh {
|
|
setWinsize(f, win.Width, win.Height)
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
io.Copy(f, sshSession)
|
|
}()
|
|
io.Copy(sshSession, f)
|
|
|
|
f.Close()
|
|
cmd.Wait()
|
|
},
|
|
PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
|
|
return true
|
|
},
|
|
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
|
|
return true
|
|
},
|
|
PasswordHandler: func(ctx ssh.Context, password string) bool {
|
|
return true
|
|
},
|
|
KeyboardInteractiveHandler: func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
|
|
return true
|
|
},
|
|
}
|
|
|
|
homeDir, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to retrieve user home dir: %s", err)
|
|
}
|
|
|
|
err = server.SetOption(ssh.HostKeyFile(path.Join(homeDir, ".ssh", "id_rsa")))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to set host key file: %s", err)
|
|
}
|
|
|
|
t := time.NewTimer(ListenTimeout)
|
|
errs := make(chan error)
|
|
go func() {
|
|
err := server.ListenAndServe()
|
|
if err != nil {
|
|
errs <- fmt.Errorf("failed to start SSH server: %s", err)
|
|
}
|
|
}()
|
|
select {
|
|
case err = <-errs:
|
|
return nil, err
|
|
case <-t.C:
|
|
// Server started
|
|
}
|
|
|
|
p := Portal{Name: name, Address: address, Command: command, Server: server}
|
|
|
|
return &p, nil
|
|
}
|
|
|
|
func (p *Portal) Shutdown() {
|
|
p.Server.Close()
|
|
}
|