ssh.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. package main
import (
	"bufio"
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)
  15. type SSHServer struct {
  16. port int
  17. }
  18. func NewSSHServer(port int) *SSHServer {
  19. return &SSHServer{port: port}
  20. }
  21. func (s *SSHServer) Start() error {
  22. // SSH server configuration
  23. config := &ssh.ServerConfig{
  24. NoClientAuth: true, // Anonymous access
  25. }
  26. // Get or create persistent host key
  27. privateKey, err := getOrCreateHostKey()
  28. if err != nil {
  29. return fmt.Errorf("failed to get host key: %v", err)
  30. }
  31. config.AddHostKey(privateKey)
  32. // Listen for connections
  33. listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
  34. if err != nil {
  35. return err
  36. }
  37. defer listener.Close()
  38. fmt.Printf("SSH server listening on :%d\n", s.port)
  39. // Simple connection limiting
  40. sem := make(chan struct{}, 100) // Max 100 concurrent SSH connections
  41. for {
  42. conn, err := listener.Accept()
  43. if err != nil {
  44. // Connection error - continue accepting others
  45. continue
  46. }
  47. select {
  48. case sem <- struct{}{}:
  49. go func() {
  50. defer func() { <-sem }()
  51. s.handleConnection(conn, config)
  52. }()
  53. default:
  54. // Too many connections
  55. conn.Close()
  56. }
  57. }
  58. }
  59. func (s *SSHServer) handleConnection(netConn net.Conn, config *ssh.ServerConfig) {
  60. defer netConn.Close()
  61. // Rate limiting
  62. if !rateLimiter.Allow(netConn.RemoteAddr().String()) {
  63. netConn.Write([]byte("Rate limit exceeded. Please try again later.\r\n"))
  64. return
  65. }
  66. // Perform SSH handshake
  67. sshConn, chans, reqs, err := ssh.NewServerConn(netConn, config)
  68. if err != nil {
  69. // Handshake failed - continue accepting others
  70. return
  71. }
  72. defer sshConn.Close()
  73. go ssh.DiscardRequests(reqs)
  74. // Handle channels (sessions)
  75. for newChannel := range chans {
  76. if newChannel.ChannelType() != "session" {
  77. newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
  78. continue
  79. }
  80. channel, requests, err := newChannel.Accept()
  81. if err != nil {
  82. // Channel error - continue
  83. continue
  84. }
  85. go s.handleSession(channel, requests)
  86. }
  87. }
  88. func (s *SSHServer) handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
  89. defer channel.Close()
  90. // Handle session requests
  91. go func() {
  92. for req := range requests {
  93. switch req.Type {
  94. case "shell", "pty-req":
  95. req.Reply(true, nil)
  96. default:
  97. req.Reply(false, nil)
  98. }
  99. }
  100. }()
  101. fmt.Fprintf(channel, "Welcome to ch.at\r\n")
  102. fmt.Fprintf(channel, "Type your message and press Enter. Type 'exit' to quit.\r\n")
  103. fmt.Fprintf(channel, "> ")
  104. // Read line by line
  105. var input strings.Builder
  106. buf := make([]byte, 1024)
  107. for {
  108. n, err := channel.Read(buf)
  109. if err != nil {
  110. if err != io.EOF {
  111. // Read error - exit session
  112. }
  113. return
  114. }
  115. data := string(buf[:n])
  116. for _, ch := range data {
  117. if ch == '\n' || ch == '\r' {
  118. if input.Len() > 0 {
  119. query := strings.TrimSpace(input.String())
  120. input.Reset()
  121. if query == "exit" {
  122. fmt.Fprintf(channel, "Goodbye!\r\n")
  123. return
  124. }
  125. // Get LLM response with streaming
  126. ctx := context.Background()
  127. stream, err := getLLMResponseStream(ctx, query)
  128. if err != nil {
  129. fmt.Fprintf(channel, "Error: %v\r\n", err)
  130. fmt.Fprintf(channel, "> ")
  131. continue
  132. }
  133. // Stream response as it arrives
  134. for chunk := range stream {
  135. fmt.Fprint(channel, chunk)
  136. if f, ok := channel.(interface{ Flush() }); ok {
  137. f.Flush()
  138. }
  139. }
  140. fmt.Fprintf(channel, "\r\n> ")
  141. }
  142. } else {
  143. input.WriteRune(ch)
  144. }
  145. }
  146. }
  147. }
  148. // getOrCreateHostKey loads existing key or generates new one
  149. func getOrCreateHostKey() (ssh.Signer, error) {
  150. keyPath := "ssh_host_key"
  151. // Try to load existing key
  152. if keyData, err := os.ReadFile(keyPath); err == nil {
  153. return ssh.ParsePrivateKey(keyData)
  154. }
  155. // Generate new ephemeral key (more private but less convenient)
  156. // Users will see "host key changed" warnings on each restart
  157. key, err := rsa.GenerateKey(rand.Reader, 2048)
  158. if err != nil {
  159. return nil, err
  160. }
  161. // Optionally save for convenience (comment out for max privacy)
  162. keyData := pem.EncodeToMemory(&pem.Block{
  163. Type: "RSA PRIVATE KEY",
  164. Bytes: x509.MarshalPKCS1PrivateKey(key),
  165. })
  166. if err := os.WriteFile(keyPath, keyData, 0600); err != nil {
  167. // Couldn't save host key - continue anyway
  168. }
  169. return ssh.NewSignerFromKey(key)
  170. }