// ssh.go (4.8 KB)

  1. package main
import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"net"
	"os"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)
  14. type SSHServer struct {
  15. port int
  16. }
  17. func NewSSHServer(port int) *SSHServer {
  18. return &SSHServer{port: port}
  19. }
  20. func (s *SSHServer) Start() error {
  21. // SSH server configuration
  22. config := &ssh.ServerConfig{
  23. NoClientAuth: true, // Anonymous access
  24. }
  25. // Get or create persistent host key
  26. privateKey, err := getOrCreateHostKey()
  27. if err != nil {
  28. return fmt.Errorf("failed to get host key: %v", err)
  29. }
  30. config.AddHostKey(privateKey)
  31. // Listen for connections
  32. listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
  33. if err != nil {
  34. return err
  35. }
  36. defer listener.Close()
  37. fmt.Printf("SSH server listening on :%d\n", s.port)
  38. // Simple connection limiting
  39. sem := make(chan struct{}, 100) // Max 100 concurrent SSH connections
  40. for {
  41. conn, err := listener.Accept()
  42. if err != nil {
  43. // Connection error - continue accepting others
  44. continue
  45. }
  46. select {
  47. case sem <- struct{}{}:
  48. go func() {
  49. defer func() { <-sem }()
  50. s.handleConnection(conn, config)
  51. }()
  52. default:
  53. // Too many connections
  54. conn.Close()
  55. }
  56. }
  57. }
  58. func (s *SSHServer) handleConnection(netConn net.Conn, config *ssh.ServerConfig) {
  59. defer netConn.Close()
  60. // Rate limiting
  61. if !rateLimiter.Allow(netConn.RemoteAddr().String()) {
  62. netConn.Write([]byte("Rate limit exceeded. Please try again later.\r\n"))
  63. return
  64. }
  65. // Perform SSH handshake
  66. sshConn, chans, reqs, err := ssh.NewServerConn(netConn, config)
  67. if err != nil {
  68. // Handshake failed - continue accepting others
  69. return
  70. }
  71. defer sshConn.Close()
  72. go ssh.DiscardRequests(reqs)
  73. // Handle channels (sessions)
  74. for newChannel := range chans {
  75. if newChannel.ChannelType() != "session" {
  76. newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
  77. continue
  78. }
  79. channel, requests, err := newChannel.Accept()
  80. if err != nil {
  81. // Channel error - continue
  82. continue
  83. }
  84. go s.handleSession(channel, requests)
  85. }
  86. }
  87. func (s *SSHServer) handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
  88. defer channel.Close()
  89. // Handle session requests
  90. go func() {
  91. for req := range requests {
  92. switch req.Type {
  93. case "shell", "pty-req":
  94. req.Reply(true, nil)
  95. default:
  96. req.Reply(false, nil)
  97. }
  98. }
  99. }()
  100. fmt.Fprintf(channel, "Welcome to ch.at\r\n")
  101. fmt.Fprintf(channel, "Type your message and press Enter.\r\n")
  102. fmt.Fprintf(channel, "Exit: type 'exit', Ctrl+C, or Ctrl+D\r\n")
  103. fmt.Fprintf(channel, "> ")
  104. // Read line by line
  105. var input strings.Builder
  106. buf := make([]byte, 1024)
  107. for {
  108. n, err := channel.Read(buf)
  109. if err != nil {
  110. // EOF (Ctrl+D) or other error - exit cleanly
  111. return
  112. }
  113. data := string(buf[:n])
  114. for _, ch := range data {
  115. if ch == 3 { // Ctrl+C
  116. fmt.Fprintf(channel, "^C\r\n")
  117. return
  118. } else if ch == '\n' || ch == '\r' {
  119. fmt.Fprintf(channel, "\r\n") // Echo newline
  120. if input.Len() > 0 {
  121. query := strings.TrimSpace(input.String())
  122. input.Reset()
  123. if query == "exit" {
  124. return
  125. }
  126. // Get LLM response with streaming
  127. ctx := context.Background()
  128. stream, err := getLLMResponseStream(ctx, query)
  129. if err != nil {
  130. fmt.Fprintf(channel, "Error: %v\r\n", err)
  131. fmt.Fprintf(channel, "> ")
  132. continue
  133. }
  134. // Stream response as it arrives
  135. for chunk := range stream {
  136. fmt.Fprint(channel, chunk)
  137. if f, ok := channel.(interface{ Flush() }); ok {
  138. f.Flush()
  139. }
  140. }
  141. fmt.Fprintf(channel, "\r\n> ")
  142. }
  143. } else if ch == '\b' || ch == 127 { // Backspace or Delete
  144. if input.Len() > 0 {
  145. // Remove last character from buffer
  146. str := input.String()
  147. input.Reset()
  148. if len(str) > 0 {
  149. input.WriteString(str[:len(str)-1])
  150. // Move cursor back, overwrite with space, move back again
  151. fmt.Fprintf(channel, "\b \b")
  152. }
  153. }
  154. } else {
  155. // Echo the character back to the user
  156. fmt.Fprintf(channel, "%c", ch)
  157. input.WriteRune(ch)
  158. }
  159. }
  160. }
  161. }
  162. // getOrCreateHostKey loads existing key or generates new one
  163. func getOrCreateHostKey() (ssh.Signer, error) {
  164. keyPath := "ssh_host_key"
  165. // Try to load existing key
  166. if keyData, err := os.ReadFile(keyPath); err == nil {
  167. return ssh.ParsePrivateKey(keyData)
  168. }
  169. // Generate new ephemeral key (more private but less convenient)
  170. // Users will see "host key changed" warnings on each restart
  171. key, err := rsa.GenerateKey(rand.Reader, 2048)
  172. if err != nil {
  173. return nil, err
  174. }
  175. // Optionally save for convenience (comment out for max privacy)
  176. keyData := pem.EncodeToMemory(&pem.Block{
  177. Type: "RSA PRIVATE KEY",
  178. Bytes: x509.MarshalPKCS1PrivateKey(key),
  179. })
  180. if err := os.WriteFile(keyPath, keyData, 0600); err != nil {
  181. // Couldn't save host key - continue anyway
  182. }
  183. return ssh.NewSignerFromKey(key)
  184. }