1
2 package portal
3
4 import (
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log"
10 "os"
11 "os/exec"
12 "path"
13 "syscall"
14 "time"
15 "unsafe"
16
17 "github.com/creack/pty"
18 "github.com/gliderlabs/ssh"
19 gossh "golang.org/x/crypto/ssh"
20 )
21
const (
	// ListenTimeout is how long New waits for the SSH server's
	// ListenAndServe call to report a startup error before it assumes the
	// server is listening and returns successfully.
	ListenTimeout = time.Second

	// IdleTimeout is the value New assigns to the SSH server's IdleTimeout
	// field.
	IdleTimeout = time.Minute
)
29
30
31 type Portal struct {
32 Name string
33 Address string
34 Command []string
35 Server *ssh.Server
36 }
37
38
39 func New(name string, address string, command []string) (*Portal, error) {
40 if address == "" {
41 return nil, errors.New("no address supplied")
42 } else if command == nil || command[0] == "" {
43 return nil, errors.New("no command supplied")
44 }
45
46 server := &ssh.Server{
47 Addr: address,
48 IdleTimeout: IdleTimeout,
49 Handler: func(sshSession ssh.Session) {
50 ptyReq, winCh, isPty := sshSession.Pty()
51 if !isPty {
52 io.WriteString(sshSession, "failed to start command: non-interactive terminals are not supported\n")
53 sshSession.Exit(1)
54 return
55 }
56
57 cmdCtx, cancelCmd := context.WithCancel(sshSession.Context())
58 defer cancelCmd()
59
60 var args []string
61 if len(command) > 1 {
62 args = command[1:]
63 }
64 cmd := exec.CommandContext(cmdCtx, command[0], args...)
65
66 cmd.Env = append(sshSession.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
67
68 stderr, err := cmd.StderrPipe()
69 if err != nil {
70 log.Printf("error: failed to create stderr pipe for portal %s: %s", name, err)
71 return
72 }
73 go func() {
74 io.Copy(sshSession.Stderr(), stderr)
75 }()
76
77 f, err := pty.Start(cmd)
78 if err != nil {
79 io.WriteString(sshSession, fmt.Sprintf("failed to start command: failed to initialize pseudo-terminal: %s\n", err))
80 sshSession.Exit(1)
81 return
82 }
83 go func() {
84 for win := range winCh {
85 setWinsize(f, win.Width, win.Height)
86 }
87 }()
88
89 go func() {
90 io.Copy(f, sshSession)
91 }()
92 io.Copy(sshSession, f)
93
94 f.Close()
95 cmd.Wait()
96 },
97 PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
98 return true
99 },
100 PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
101 return true
102 },
103 PasswordHandler: func(ctx ssh.Context, password string) bool {
104 return true
105 },
106 KeyboardInteractiveHandler: func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
107 return true
108 },
109 }
110
111 homeDir, err := os.UserHomeDir()
112 if err != nil {
113 return nil, fmt.Errorf("failed to retrieve user home dir: %s", err)
114 }
115
116 err = server.SetOption(ssh.HostKeyFile(path.Join(homeDir, ".ssh", "id_rsa")))
117 if err != nil {
118 return nil, fmt.Errorf("failed to set host key file: %s", err)
119 }
120
121 t := time.NewTimer(ListenTimeout)
122 errs := make(chan error)
123 go func() {
124 err := server.ListenAndServe()
125 if err != nil {
126 errs <- fmt.Errorf("failed to start SSH server: %s", err)
127 }
128 }()
129 select {
130 case err = <-errs:
131 return nil, err
132 case <-t.C:
133
134 }
135
136 p := Portal{Name: name, Address: address, Command: command, Server: server}
137
138 return &p, nil
139 }
140
141
142 func (p *Portal) Close() {
143 p.Server.Close()
144 }
145
146
147 func (p *Portal) Shutdown() {
148 p.Server.Shutdown(context.Background())
149 }
150
// setWinsize asks the terminal behind f to use a window of w columns by h
// rows, using the TIOCSWINSZ ioctl. The ioctl's result is ignored, so
// calling it on a non-terminal is a silent no-op. w and h are truncated to
// uint16.
func setWinsize(f *os.File, w, h int) {
	// Field order mirrors the kernel's struct winsize: rows, cols, then the
	// (unused, zero) pixel dimensions.
	ws := struct{ rows, cols, xpix, ypix uint16 }{
		rows: uint16(h),
		cols: uint16(w),
	}
	// The uintptr(unsafe.Pointer(...)) conversion must stay inside the
	// Syscall argument list so the pointee is kept alive for the call.
	_, _, _ = syscall.Syscall(
		syscall.SYS_IOCTL,
		f.Fd(),
		uintptr(syscall.TIOCSWINSZ),
		uintptr(unsafe.Pointer(&ws)),
	)
}
155
View as plain text