ssh.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package main
import (
	"crypto/rand"
	"crypto/rsa"
	"fmt"
	"net"
	"strings"
	"time"

	"golang.org/x/crypto/ssh"
)
  10. func StartSSHServer(port int) error {
  11. // SSH server configuration
  12. config := &ssh.ServerConfig{
  13. NoClientAuth: true, // Anonymous access
  14. }
  15. // Get or create persistent host key
  16. privateKey, err := getOrCreateHostKey()
  17. if err != nil {
  18. return fmt.Errorf("failed to get host key: %v", err)
  19. }
  20. config.AddHostKey(privateKey)
  21. // Listen for connections
  22. listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
  23. if err != nil {
  24. return err
  25. }
  26. defer listener.Close()
  27. // Simple connection limiting
  28. sem := make(chan struct{}, 100) // Max 100 concurrent SSH connections
  29. for {
  30. conn, err := listener.Accept()
  31. if err != nil {
  32. // Connection error - continue accepting others
  33. continue
  34. }
  35. select {
  36. case sem <- struct{}{}:
  37. go func() {
  38. defer func() { <-sem }()
  39. handleConnection(conn, config)
  40. }()
  41. default:
  42. // Too many connections
  43. conn.Close()
  44. }
  45. }
  46. }
  47. func handleConnection(netConn net.Conn, config *ssh.ServerConfig) {
  48. defer netConn.Close()
  49. // Rate limiting
  50. if !rateLimitAllow(netConn.RemoteAddr().String()) {
  51. netConn.Write([]byte("Rate limit exceeded\r\n"))
  52. return
  53. }
  54. // Perform SSH handshake
  55. sshConn, chans, reqs, err := ssh.NewServerConn(netConn, config)
  56. if err != nil {
  57. // Handshake failed - continue accepting others
  58. return
  59. }
  60. defer sshConn.Close()
  61. go ssh.DiscardRequests(reqs)
  62. // Handle channels (sessions)
  63. for newChannel := range chans {
  64. if newChannel.ChannelType() != "session" {
  65. newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
  66. continue
  67. }
  68. channel, requests, err := newChannel.Accept()
  69. if err != nil {
  70. // Channel error - continue
  71. continue
  72. }
  73. go handleSession(channel, requests)
  74. }
  75. }
  76. func handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
  77. defer channel.Close()
  78. // Handle session requests
  79. go func() {
  80. for req := range requests {
  81. switch req.Type {
  82. case "shell", "pty-req":
  83. req.Reply(true, nil)
  84. default:
  85. req.Reply(false, nil)
  86. }
  87. }
  88. }()
  89. fmt.Fprintf(channel, "Welcome to ch.at\r\n")
  90. fmt.Fprintf(channel, "Type your message and press Enter.\r\n")
  91. fmt.Fprintf(channel, "Exit: type 'exit', Ctrl+C, or Ctrl+D\r\n")
  92. fmt.Fprintf(channel, "> ")
  93. // Read line by line
  94. var input strings.Builder
  95. buf := make([]byte, 1024)
  96. for {
  97. n, err := channel.Read(buf)
  98. if err != nil {
  99. // EOF (Ctrl+D) or other error - exit cleanly
  100. return
  101. }
  102. data := string(buf[:n])
  103. for _, ch := range data {
  104. if ch == 3 { // Ctrl+C
  105. fmt.Fprintf(channel, "^C\r\n")
  106. return
  107. } else if ch == '\n' || ch == '\r' {
  108. fmt.Fprintf(channel, "\r\n") // Echo newline
  109. if input.Len() > 0 {
  110. query := strings.TrimSpace(input.String())
  111. input.Reset()
  112. if query == "exit" {
  113. return
  114. }
  115. // Get LLM response with streaming
  116. ch := make(chan string)
  117. go func() {
  118. if _, err := LLM(query, ch); err != nil {
  119. fmt.Fprintf(channel, "Error: %s\r\n", err.Error())
  120. }
  121. }()
  122. // Stream response as it arrives
  123. for chunk := range ch {
  124. fmt.Fprint(channel, chunk)
  125. }
  126. fmt.Fprintf(channel, "\r\n> ")
  127. }
  128. } else if ch == '\b' || ch == 127 { // Backspace or Delete
  129. if input.Len() > 0 {
  130. // Remove last rune (UTF-8 safe)
  131. str := input.String()
  132. runes := []rune(str)
  133. input.Reset()
  134. input.WriteString(string(runes[:len(runes)-1]))
  135. // Move cursor back, overwrite with space, move back again
  136. fmt.Fprintf(channel, "\b \b")
  137. }
  138. } else {
  139. // Echo the character back to the user
  140. fmt.Fprintf(channel, "%c", ch)
  141. input.WriteRune(ch)
  142. }
  143. }
  144. }
  145. }
  146. // getOrCreateHostKey generates a new ephemeral host key
  147. func getOrCreateHostKey() (ssh.Signer, error) {
  148. // Generate new ephemeral key each time
  149. // Users will see "host key changed" warnings on each restart
  150. key, err := rsa.GenerateKey(rand.Reader, 2048)
  151. if err != nil {
  152. return nil, err
  153. }
  154. return ssh.NewSignerFromKey(key)
  155. }