1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
| func processDirectTcpIpNewChannel(serverConn *ssh.ServerConn, newChannel ssh.NewChannel) { var payload forwardedTCPPayload if err := ssh.Unmarshal(newChannel.ExtraData(), &payload); err != nil { log.Println(err) newChannel.Reject(ssh.ConnectionFailed, "can't parse tcp forward payload") }
log.Printf("process tcp port forwarding, host: %s, port: %d, origial host: %s, original port: %d\n", payload.Addr, payload.Port, payload.OriginAddr, payload.OriginPort)
localAddr := serverConn.LocalAddr()
tcpAddr, _ := net.ResolveTCPAddr(localAddr.Network(), localAddr.String())
var dialerIp net.IP if !tcpAddr.IP.IsLoopback() { dialerIp = tcpAddr.IP }
log.Printf("server local addr: %s, dialer ip: %s\n", localAddr.String(), tcpAddr.IP) dialer := net.Dialer{ LocalAddr: &net.TCPAddr{IP: dialerIp}, }
remoteAddr := fmt.Sprintf("%s:%d", payload.Addr, payload.Port) conn, err := dialer.Dial("tcp", remoteAddr) if err != nil { log.Println(err) newChannel.Reject(ssh.ConnectionFailed, "connect to dest host failed") return }
channel, requests, err := newChannel.Accept() if err != nil { log.Println(err) newChannel.Reject(ssh.ConnectionFailed, "failed to accept") return }
go func(in <-chan *ssh.Request) { for req := range in { log.Println(req.Type) } }(requests)
done := make(chan struct{})
go forward(conn, channel, done)
go forward(channel, conn, done)
<-done <-done }
func sshProxyConnectionManager(nConn net.Conn, config *ssh.ServerConfig, conns map[ssh.ServerConn]struct{}) { conn, chans, reqs, err := ssh.NewServerConn(nConn, config) if err != nil { log.Println("failed to handshake: ", err) return }
if conn.Permissions == nil { log.Printf("logged in with username %s", conn.Conn.User()) } else { log.Printf("logged in with key %s", conn.Permissions.Extensions["username"]) }
conns[*conn] = struct{}{}
go func() { ssh.DiscardRequests(reqs) }()
for newChannel := range chans { switch newChannel.ChannelType() { case DirectTcpIpChannelType: go processDirectTcpIpNewChannel(conn, newChannel) default: newChannel.Reject(ssh.UnknownChannelType, ssh.UnknownChannelType.String()) continue } } }
func sshProxy(l net.Listener, config *ssh.ServerConfig, done chan bool) { running := true conns := make(map[ssh.ServerConn]struct{}, 10)
go func() { <-done if len(conns) > 0 { log.Println("begin to close all current connections") for conn := range conns { conn.Close() } }
log.Println("close listener") l.Close() running = false }()
var wg sync.WaitGroup log.Println("begin to accept ssh connections") wg.Add(1) for running { nConn, err := l.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { log.Println("listener closed, quit accept loop") wg.Done() break } else { wg.Done() log.Fatal("failed to accept incoming connection: ", err) } }
go sshProxyConnectionManager(nConn, config, conns) }
wg.Wait() }
|