...

Source file src/gitlab.com/tslocum/sshtargate/portal/portal.go

Documentation: gitlab.com/tslocum/sshtargate/portal

     1  // Package portal provides SSH portals to applications.
     2  package portal
     3  
     4  import (
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"os"
    11  	"os/exec"
    12  	"path"
    13  	"syscall"
    14  	"time"
    15  	"unsafe"
    16  
    17  	"github.com/creack/pty"
    18  	"github.com/gliderlabs/ssh"
    19  	gossh "golang.org/x/crypto/ssh"
    20  )
    21  
    22  const (
    23  	// ListenTimeout is the maximum time to start listening on an address.
    24  	ListenTimeout = 1 * time.Second
    25  
    26  	// IdleTimeout is the maximum time for a connection to be inactive.
    27  	IdleTimeout = 1 * time.Minute
    28  )
    29  
    30  // Portal is an SSH portal to an application.
    31  type Portal struct {
    32  	Name    string
    33  	Address string
    34  	Command []string
    35  	Server  *ssh.Server
    36  }
    37  
    38  // New opens an SSH portal to an application.
    39  func New(name string, address string, command []string) (*Portal, error) {
    40  	if address == "" {
    41  		return nil, errors.New("no address supplied")
    42  	} else if command == nil || command[0] == "" {
    43  		return nil, errors.New("no command supplied")
    44  	}
    45  
    46  	server := &ssh.Server{
    47  		Addr:        address,
    48  		IdleTimeout: IdleTimeout,
    49  		Handler: func(sshSession ssh.Session) {
    50  			ptyReq, winCh, isPty := sshSession.Pty()
    51  			if !isPty {
    52  				io.WriteString(sshSession, "failed to start command: non-interactive terminals are not supported\n")
    53  				sshSession.Exit(1)
    54  				return
    55  			}
    56  
    57  			cmdCtx, cancelCmd := context.WithCancel(sshSession.Context())
    58  			defer cancelCmd()
    59  
    60  			var args []string
    61  			if len(command) > 1 {
    62  				args = command[1:]
    63  			}
    64  			cmd := exec.CommandContext(cmdCtx, command[0], args...)
    65  
    66  			cmd.Env = append(sshSession.Environ(), fmt.Sprintf("TERM=%s", ptyReq.Term))
    67  
    68  			stderr, err := cmd.StderrPipe()
    69  			if err != nil {
    70  				log.Printf("error: failed to create stderr pipe for portal %s: %s", name, err)
    71  				return
    72  			}
    73  			go func() {
    74  				io.Copy(sshSession.Stderr(), stderr)
    75  			}()
    76  
    77  			f, err := pty.Start(cmd)
    78  			if err != nil {
    79  				io.WriteString(sshSession, fmt.Sprintf("failed to start command: failed to initialize pseudo-terminal: %s\n", err))
    80  				sshSession.Exit(1)
    81  				return
    82  			}
    83  			go func() {
    84  				for win := range winCh {
    85  					setWinsize(f, win.Width, win.Height)
    86  				}
    87  			}()
    88  
    89  			go func() {
    90  				io.Copy(f, sshSession)
    91  			}()
    92  			io.Copy(sshSession, f)
    93  
    94  			f.Close()
    95  			cmd.Wait()
    96  		},
    97  		PtyCallback: func(ctx ssh.Context, pty ssh.Pty) bool {
    98  			return true
    99  		},
   100  		PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
   101  			return true
   102  		},
   103  		PasswordHandler: func(ctx ssh.Context, password string) bool {
   104  			return true
   105  		},
   106  		KeyboardInteractiveHandler: func(ctx ssh.Context, challenger gossh.KeyboardInteractiveChallenge) bool {
   107  			return true
   108  		},
   109  	}
   110  
   111  	homeDir, err := os.UserHomeDir()
   112  	if err != nil {
   113  		return nil, fmt.Errorf("failed to retrieve user home dir: %s", err)
   114  	}
   115  
   116  	err = server.SetOption(ssh.HostKeyFile(path.Join(homeDir, ".ssh", "id_rsa")))
   117  	if err != nil {
   118  		return nil, fmt.Errorf("failed to set host key file: %s", err)
   119  	}
   120  
   121  	t := time.NewTimer(ListenTimeout)
   122  	errs := make(chan error)
   123  	go func() {
   124  		err := server.ListenAndServe()
   125  		if err != nil {
   126  			errs <- fmt.Errorf("failed to start SSH server: %s", err)
   127  		}
   128  	}()
   129  	select {
   130  	case err = <-errs:
   131  		return nil, err
   132  	case <-t.C:
   133  		// Server started
   134  	}
   135  
   136  	p := Portal{Name: name, Address: address, Command: command, Server: server}
   137  
   138  	return &p, nil
   139  }
   140  
   141  // Close closes the portal immediately.
   142  func (p *Portal) Close() {
   143  	p.Server.Close()
   144  }
   145  
   146  // Shutdown closes the portal without interrupting active connections.
   147  func (p *Portal) Shutdown() {
   148  	p.Server.Shutdown(context.Background())
   149  }
   150  
   151  func setWinsize(f *os.File, w, h int) {
   152  	syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
   153  		uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
   154  }
   155  

View as plain text