util.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. package main
  2. import (
  3. "net"
  4. "sync"
  5. "time"
  6. )
  // RateLimiter is a simple in-memory, per-client sliding-window rate limiter.
  //
  // Each client key may record at most limit requests within any trailing span
  // of length window; requests beyond that are rejected and are not recorded.
  // A background goroutine started by NewRateLimiter periodically discards
  // expired history so that idle clients do not accumulate in memory. Call
  // Stop to end that goroutine.
  //
  // Allow and Stop are safe for concurrent use. The zero value is not usable;
  // construct instances with NewRateLimiter.
  //
  // To disable: Remove NewRateLimiter calls from protocol files.
  type RateLimiter struct {
  	limit  int           // maximum requests per client within one window
  	window time.Duration // length of the sliding time window

  	mu       sync.Mutex             // guards requests
  	requests map[string][]time.Time // accepted request times per client key

  	stopCh   chan struct{} // closed by Stop to end the cleanup goroutine
  	stopOnce sync.Once     // ensures stopCh is closed at most once
  }

  // NewRateLimiter returns a limiter that admits up to limit requests per
  // client within any trailing window, and starts its background cleanup
  // goroutine.
  //
  // A limit of zero or less rejects every request. With a positive limit and a
  // window of zero or less, recorded requests expire immediately, so nothing is
  // throttled. Call Stop when the limiter is no longer needed to release the
  // goroutine.
  func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
  	r := &RateLimiter{
  		limit:    limit,
  		window:   window,
  		requests: make(map[string][]time.Time),
  		stopCh:   make(chan struct{}),
  	}
  	go r.cleanup()
  	return r
  }

  // Allow reports whether a request from addr is within the rate limit, and
  // records the request if it is.
  //
  // addr may be "host:port" or a bare host; the port is ignored so that all
  // connections from one host share a single budget. Rejected requests are not
  // recorded, so they do not extend the time a client stays throttled.
  func (r *RateLimiter) Allow(addr string) bool {
  	// SplitHostPort fails for input without a port (for example a bare IP).
  	// That is expected here, so the error only selects the fallback of keying
  	// on addr unchanged. The same fallback covers an empty host part.
  	key := addr
  	if host, _, err := net.SplitHostPort(addr); err == nil && host != "" {
  		key = host
  	}

  	r.mu.Lock()
  	defer r.mu.Unlock()

  	now := time.Now()
  	recent := r.pruneLocked(key, now.Add(-r.window))
  	if len(recent) >= r.limit {
  		return false
  	}
  	r.requests[key] = append(recent, now)
  	return true
  }

  // pruneLocked drops the timestamps recorded for key that are not after
  // cutoff and returns the survivors. Filtering reuses the existing backing
  // array, so the result is always written back to the map; a key left with no
  // timestamps is removed from the map entirely. The caller must hold r.mu.
  func (r *RateLimiter) pruneLocked(key string, cutoff time.Time) []time.Time {
  	stamps := r.requests[key]
  	kept := stamps[:0]
  	for _, ts := range stamps {
  		if ts.After(cutoff) {
  			kept = append(kept, ts)
  		}
  	}
  	if len(kept) == 0 {
  		delete(r.requests, key)
  		return nil
  	}
  	r.requests[key] = kept
  	return kept
  }

  // cleanup runs until Stop is called, sweeping expired history once per
  // window so that the map does not grow without bound.
  func (r *RateLimiter) cleanup() {
  	// time.NewTicker panics on a non-positive interval, and a panic in this
  	// goroutine would crash the whole process. Fall back to a fixed period so
  	// a limiter built with a zero or negative window is still swept.
  	interval := r.window
  	if interval <= 0 {
  		interval = time.Minute
  	}
  	ticker := time.NewTicker(interval)
  	defer ticker.Stop()

  	for {
  		select {
  		case <-ticker.C:
  			r.sweep()
  		case <-r.stopCh:
  			return
  		}
  	}
  }

  // sweep prunes every client's history and removes clients that have no
  // requests inside the current window.
  func (r *RateLimiter) sweep() {
  	r.mu.Lock()
  	defer r.mu.Unlock()

  	cutoff := time.Now().Add(-r.window)
  	// Deleting or overwriting existing keys while ranging over a map is
  	// permitted in Go, which is all pruneLocked does.
  	for key := range r.requests {
  		r.pruneLocked(key, cutoff)
  	}
  }

  // Stop ends the background cleanup goroutine. It may be called more than
  // once and from multiple goroutines; only the first call has any effect.
  // Allow keeps working after Stop, but history of idle clients is no longer
  // swept in the background.
  func (r *RateLimiter) Stop() {
  	r.stopOnce.Do(func() { close(r.stopCh) })
  }
  94. // Add other shared utilities here as needed
  95. // Each should be self-contained and optional
  96. // Global rate limiter instance
  97. var rateLimiter = NewRateLimiter(100, time.Minute) // 100 requests per minute per IP