// ssh.go: SSH front end for ch.at. Sessions are anonymous; each line a user
// submits is sent to the LLM and the reply is streamed back.

  1. package main
  import (
  	"bufio"
  	"crypto/rand"
  	"crypto/rsa"
  	"fmt"
  	"net"
  	"strings"
  	"time"

  	"golang.org/x/crypto/ssh"
  )
  10. func StartSSHServer(port int) error {
  11. // SSH server configuration
  12. config := &ssh.ServerConfig{
  13. NoClientAuth: true, // Anonymous access
  14. }
  15. // Get or create persistent host key
  16. privateKey, err := getOrCreateHostKey()
  17. if err != nil {
  18. return fmt.Errorf("failed to get host key: %v", err)
  19. }
  20. config.AddHostKey(privateKey)
  21. // Listen for connections
  22. listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
  23. if err != nil {
  24. return err
  25. }
  26. defer listener.Close()
  27. fmt.Printf("SSH server listening on :%d\n", port)
  28. // Simple connection limiting
  29. sem := make(chan struct{}, 100) // Max 100 concurrent SSH connections
  30. for {
  31. conn, err := listener.Accept()
  32. if err != nil {
  33. // Connection error - continue accepting others
  34. continue
  35. }
  36. select {
  37. case sem <- struct{}{}:
  38. go func() {
  39. defer func() { <-sem }()
  40. handleConnection(conn, config)
  41. }()
  42. default:
  43. // Too many connections
  44. conn.Close()
  45. }
  46. }
  47. }
  48. func handleConnection(netConn net.Conn, config *ssh.ServerConfig) {
  49. defer netConn.Close()
  50. // Rate limiting
  51. if !rateLimitAllow(netConn.RemoteAddr().String()) {
  52. netConn.Write([]byte("Rate limit exceeded\r\n"))
  53. return
  54. }
  55. // Perform SSH handshake
  56. sshConn, chans, reqs, err := ssh.NewServerConn(netConn, config)
  57. if err != nil {
  58. // Handshake failed - continue accepting others
  59. return
  60. }
  61. defer sshConn.Close()
  62. go ssh.DiscardRequests(reqs)
  63. // Handle channels (sessions)
  64. for newChannel := range chans {
  65. if newChannel.ChannelType() != "session" {
  66. newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
  67. continue
  68. }
  69. channel, requests, err := newChannel.Accept()
  70. if err != nil {
  71. // Channel error - continue
  72. continue
  73. }
  74. go handleSession(channel, requests)
  75. }
  76. }
  77. func handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
  78. defer channel.Close()
  79. // Handle session requests
  80. go func() {
  81. for req := range requests {
  82. switch req.Type {
  83. case "shell", "pty-req":
  84. req.Reply(true, nil)
  85. default:
  86. req.Reply(false, nil)
  87. }
  88. }
  89. }()
  90. fmt.Fprintf(channel, "Welcome to ch.at\r\n")
  91. fmt.Fprintf(channel, "Type your message and press Enter.\r\n")
  92. fmt.Fprintf(channel, "Exit: type 'exit', Ctrl+C, or Ctrl+D\r\n")
  93. fmt.Fprintf(channel, "> ")
  94. // Read line by line
  95. var input strings.Builder
  96. buf := make([]byte, 1024)
  97. for {
  98. n, err := channel.Read(buf)
  99. if err != nil {
  100. // EOF (Ctrl+D) or other error - exit cleanly
  101. return
  102. }
  103. data := string(buf[:n])
  104. for _, ch := range data {
  105. if ch == 3 { // Ctrl+C
  106. fmt.Fprintf(channel, "^C\r\n")
  107. return
  108. } else if ch == '\n' || ch == '\r' {
  109. fmt.Fprintf(channel, "\r\n") // Echo newline
  110. if input.Len() > 0 {
  111. query := strings.TrimSpace(input.String())
  112. input.Reset()
  113. if query == "exit" {
  114. return
  115. }
  116. // Get LLM response with streaming
  117. ch := make(chan string)
  118. go func() {
  119. if _, err := LLM(query, ch); err != nil {
  120. fmt.Fprintf(channel, "Error: %s\r\n", err.Error())
  121. }
  122. }()
  123. // Stream response as it arrives
  124. for chunk := range ch {
  125. fmt.Fprint(channel, chunk)
  126. }
  127. fmt.Fprintf(channel, "\r\n> ")
  128. }
  129. } else if ch == '\b' || ch == 127 { // Backspace or Delete
  130. if input.Len() > 0 {
  131. // Remove last rune (UTF-8 safe)
  132. str := input.String()
  133. runes := []rune(str)
  134. input.Reset()
  135. input.WriteString(string(runes[:len(runes)-1]))
  136. // Move cursor back, overwrite with space, move back again
  137. fmt.Fprintf(channel, "\b \b")
  138. }
  139. } else {
  140. // Echo the character back to the user
  141. fmt.Fprintf(channel, "%c", ch)
  142. input.WriteRune(ch)
  143. }
  144. }
  145. }
  146. }
  147. // getOrCreateHostKey generates a new ephemeral host key
  148. func getOrCreateHostKey() (ssh.Signer, error) {
  149. // Generate new ephemeral key each time
  150. // Users will see "host key changed" warnings on each restart
  151. key, err := rsa.GenerateKey(rand.Reader, 2048)
  152. if err != nil {
  153. return nil, err
  154. }
  155. return ssh.NewSignerFromKey(key)
  156. }