limit.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. package util
  2. import (
  3. "errors"
  4. "golang.org/x/time/rate"
  5. "io"
  6. "sync"
  7. "time"
  8. )
  // ErrLimitReached is the sentinel error handed back by the Limiter implementations
  // in this file and by LimitWriter once a configured limit has been reached.
  var ErrLimitReached = errors.New("limit reached")
  // Limiter describes a rate limiting mechanism. The limit may, for example, be
  // time-based or a fixed total.
  type Limiter interface {
  	// Allow adds n to the limiter's internal value. It returns ErrLimitReached
  	// if the limit has been reached, and nil otherwise.
  	Allow(n int64) error
  }
  // FixedLimiter accumulates values up to a well-defined, fixed limit. Once adding a
  // value would exceed that limit, ErrLimitReached is returned instead. Its methods
  // lock the embedded mutex, so a FixedLimiter may be used by multiple goroutines.
  type FixedLimiter struct {
  	mu    sync.Mutex // guards value
  	value int64      // running total accepted so far
  	limit int64      // upper bound that value must not exceed
  }
  23. // NewFixedLimiter creates a new Limiter
  24. func NewFixedLimiter(limit int64) *FixedLimiter {
  25. return NewFixedLimiterWithValue(limit, 0)
  26. }
  27. // NewFixedLimiterWithValue creates a new Limiter and sets the initial value
  28. func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
  29. return &FixedLimiter{
  30. limit: limit,
  31. value: value,
  32. }
  33. }
  34. // Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
  35. // exceeded after adding n, ErrLimitReached is returned.
  36. func (l *FixedLimiter) Allow(n int64) error {
  37. l.mu.Lock()
  38. defer l.mu.Unlock()
  39. if l.value+n > l.limit {
  40. return ErrLimitReached
  41. }
  42. l.value += n
  43. return nil
  44. }
  45. // Value returns the current limiter value
  46. func (l *FixedLimiter) Value() int64 {
  47. l.mu.Lock()
  48. defer l.mu.Unlock()
  49. return l.value
  50. }
  51. // Reset sets the limiter's value back to zero
  52. func (l *FixedLimiter) Reset() {
  53. l.mu.Lock()
  54. defer l.mu.Unlock()
  55. l.value = 0
  56. }
  57. // RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit.
  58. type RateLimiter struct {
  59. limiter *rate.Limiter
  60. }
  61. // NewRateLimiter creates a new RateLimiter
  62. func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
  63. return &RateLimiter{
  64. limiter: rate.NewLimiter(r, b),
  65. }
  66. }
  67. // NewBytesLimiter creates a RateLimiter that is meant to be used for a bytes-per-interval limit,
  68. // e.g. 250 MB per day. And example of the underlying idea can be found here: https://go.dev/play/p/0ljgzIZQ6dJ
  69. func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter {
  70. return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes)
  71. }
  72. // Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
  73. // exceeded after adding n, ErrLimitReached is returned.
  74. func (l *RateLimiter) Allow(n int64) error {
  75. if n <= 0 {
  76. return nil // No-op. Can't take back bytes you're written!
  77. }
  78. if !l.limiter.AllowN(time.Now(), int(n)) {
  79. return ErrLimitReached
  80. }
  81. return nil
  82. }
  83. // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
  84. // writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
  85. // Each limiter's value is increased with every write.
  86. type LimitWriter struct {
  87. w io.Writer
  88. written int64
  89. limiters []Limiter
  90. mu sync.Mutex
  91. }
  92. // NewLimitWriter creates a new LimitWriter
  93. func NewLimitWriter(w io.Writer, limiters ...Limiter) *LimitWriter {
  94. return &LimitWriter{
  95. w: w,
  96. limiters: limiters,
  97. }
  98. }
  99. // Write passes through all writes to the underlying writer until any of the given limiter's limit is reached
  100. func (w *LimitWriter) Write(p []byte) (n int, err error) {
  101. w.mu.Lock()
  102. defer w.mu.Unlock()
  103. for i := 0; i < len(w.limiters); i++ {
  104. if err := w.limiters[i].Allow(int64(len(p))); err != nil {
  105. for j := i - 1; j >= 0; j-- {
  106. w.limiters[j].Allow(-int64(len(p))) // Revert limiters limits if allowed
  107. }
  108. return 0, ErrLimitReached
  109. }
  110. }
  111. n, err = w.w.Write(p)
  112. w.written += int64(n)
  113. return
  114. }