dns.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. package main
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "strings"
  7. )
  // DNSServer answers DNS queries over UDP with LLM-generated text carried in
  // TXT records.
  type DNSServer struct {
  	// port is the UDP port that Start binds on every local address.
  	port int
  }
  11. func NewDNSServer(port int) *DNSServer {
  12. return &DNSServer{port: port}
  13. }
  14. func (s *DNSServer) Start() error {
  15. addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", s.port))
  16. if err != nil {
  17. return err
  18. }
  19. conn, err := net.ListenUDP("udp", addr)
  20. if err != nil {
  21. return err
  22. }
  23. defer conn.Close()
  24. fmt.Printf("DNS server listening on :%d\n", s.port)
  25. buf := make([]byte, 512) // DNS messages are typically small
  26. for {
  27. n, clientAddr, err := conn.ReadFromUDP(buf)
  28. if err != nil {
  29. // Read error - continue
  30. continue
  31. }
  32. go s.handleQuery(conn, clientAddr, buf[:n])
  33. }
  34. }
  35. func (s *DNSServer) handleQuery(conn *net.UDPConn, addr *net.UDPAddr, query []byte) {
  36. // Validate minimum DNS packet size
  37. if len(query) < 12 {
  38. return
  39. }
  40. // Rate limiting
  41. if !rateLimiter.Allow(addr.String()) {
  42. return // Silently drop - DNS doesn't have error responses for rate limits
  43. }
  44. // Validate DNS header flags (must be a query, not response)
  45. if query[2]&0x80 != 0 {
  46. return // It's a response, not a query
  47. }
  48. // Extract question from query
  49. question := extractQuestion(query)
  50. if question == "" {
  51. return
  52. }
  53. // Remove .ch.at suffix if present
  54. question = strings.TrimSuffix(question, ".ch.at")
  55. question = strings.TrimSuffix(question, ".")
  56. // Convert DNS format to readable (replace - with space)
  57. prompt := strings.ReplaceAll(question, "-", " ")
  58. // Get LLM response
  59. ctx := context.Background()
  60. response, err := getLLMResponse(ctx, prompt)
  61. if err != nil {
  62. response = "Error: " + err.Error()
  63. }
  64. // Build DNS response with chunked TXT records
  65. reply := buildDNSResponse(query, response)
  66. // Ensure response fits in UDP packet (RFC recommends 512 bytes)
  67. if len(reply) > 512 {
  68. // Truncate and set TC bit
  69. reply = reply[:512]
  70. reply[2] |= 0x02 // Set TC (truncation) bit
  71. }
  72. conn.WriteToUDP(reply, addr)
  73. }
  // extractQuestion returns the name in the first question of the DNS message
  // query as dot-joined labels (e.g. "foo.ch.at"), without a trailing dot.
  //
  // It returns "" when the message is shorter than a DNS header, when the name
  // is malformed (a label longer than 63 octets, a compression pointer or other
  // reserved label type, a name over 255 octets on the wire, or no root
  // terminator), or when the 4 octets of QTYPE and QCLASS after the name are
  // missing. Label bytes are copied verbatim; no character validation is done.
  func extractQuestion(query []byte) string {
  	const (
  		headerLen   = 12
  		maxLabelLen = 63
  		maxNameLen  = 255 // wire-format octets, including the root terminator
  		maxLabels   = 128 // defensive cap on iterations (DoS guard)
  	)
  	if len(query) < headerLen {
  		return ""
  	}
  	pos := headerLen
  	nameLen := 1 // the terminating zero-length root label
  	var labels []string
  	for len(labels) < maxLabels {
  		if pos >= len(query) {
  			return "" // ran off the end before the root label
  		}
  		length := int(query[pos])
  		pos++ // pos now indexes the label's first byte (or QTYPE after root)
  		if length == 0 {
  			// Root label reached: QTYPE (2 octets) and QCLASS (2 octets) follow.
  			if pos+4 > len(query) {
  				return ""
  			}
  			return strings.Join(labels, ".")
  		}
  		// Values above 63 carry the reserved top bits: 0xC0 is a compression
  		// pointer (rejected for simplicity and security), 0x40/0x80 are
  		// undefined label types.
  		if length > maxLabelLen {
  			return ""
  		}
  		nameLen += length + 1
  		if nameLen > maxNameLen {
  			return ""
  		}
  		if pos+length > len(query) {
  			return ""
  		}
  		labels = append(labels, string(query[pos:pos+length]))
  		pos += length
  	}
  	return ""
  }
  120. func buildDNSResponse(query []byte, answer string) []byte {
  121. resp := make([]byte, len(query))
  122. copy(resp, query)
  123. // Set response flags (QR=1, AA=1)
  124. resp[2] = 0x81
  125. resp[3] = 0x80
  126. // Set answer count to 1
  127. resp[7] = 1
  128. // Skip to end of question section
  129. pos := 12
  130. for pos < len(resp) {
  131. if resp[pos] == 0 {
  132. pos += 5 // Skip null terminator + type + class
  133. break
  134. }
  135. pos++
  136. }
  137. // Add answer section
  138. // Pointer to question name
  139. resp = append(resp, 0xc0, 0x0c)
  140. // Type TXT (16), Class IN (1)
  141. resp = append(resp, 0x00, 0x10, 0x00, 0x01)
  142. // TTL (0)
  143. resp = append(resp, 0x00, 0x00, 0x00, 0x00)
  144. // Build TXT record data with chunking
  145. txtData := buildTXTData(answer)
  146. // Data length
  147. resp = append(resp, byte(len(txtData)>>8), byte(len(txtData)))
  148. // TXT data
  149. resp = append(resp, txtData...)
  150. return resp
  151. }
  // buildTXTData encodes text as TXT RDATA: a sequence of <character-string>s,
  // each a one-byte length followed by at most 255 bytes of text.
  //
  // Empty text yields a single empty string ([]byte{0}), because TXT RDATA
  // must hold at least one <character-string> (RFC 1035 §3.3.14). Chunk
  // boundaries are moved back (by at most 3 bytes) so a UTF-8 character is not
  // split between two strings; bytes that are not valid UTF-8 are split at
  // 255 as before.
  func buildTXTData(text string) []byte {
  	const (
  		maxChunk = 255
  		utf8Max  = 4 // longest UTF-8 encoding of a single character
  	)
  	if text == "" {
  		return []byte{0}
  	}
  	data := make([]byte, 0, len(text)+len(text)/maxChunk+1)
  	for len(text) > 0 {
  		n := len(text)
  		if n > maxChunk {
  			n = maxChunk
  			// Find the nearest character start at or before the cut; UTF-8
  			// continuation bytes look like 10xxxxxx.
  			for k := maxChunk; k > maxChunk-utf8Max; k-- {
  				if text[k]&0xC0 != 0x80 {
  					n = k
  					break
  				}
  			}
  		}
  		data = append(data, byte(n))
  		data = append(data, text[:n]...)
  		text = text[n:]
  	}
  	return data
  }