limit.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package util
  2. import (
  3. "errors"
  4. "io"
  5. "sync"
  6. )
  7. // ErrLimitReached is the error returned by the Limiter and LimitWriter when the predefined limit has been reached
  8. var ErrLimitReached = errors.New("limit reached")
  9. // Limiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached
  10. // ErrLimitReached will be returned. Limiter may be used by multiple goroutines.
  11. type Limiter struct {
  12. value int64
  13. limit int64
  14. mu sync.Mutex
  15. }
  16. // NewLimiter creates a new Limiter
  17. func NewLimiter(limit int64) *Limiter {
  18. return &Limiter{
  19. limit: limit,
  20. }
  21. }
  22. // Add adds n to the limiters internal value, but only if the limit has not been reached. If the limit would be
  23. // exceeded after adding n, ErrLimitReached is returned.
  24. func (l *Limiter) Add(n int64) error {
  25. l.mu.Lock()
  26. defer l.mu.Unlock()
  27. if l.limit == 0 {
  28. l.value += n
  29. return nil
  30. } else if l.value+n <= l.limit {
  31. l.value += n
  32. return nil
  33. } else {
  34. return ErrLimitReached
  35. }
  36. }
  37. // Sub subtracts a value from the limiters internal value
  38. func (l *Limiter) Sub(n int64) {
  39. l.Add(-n)
  40. }
  41. // Set sets the value of the limiter to n. This function ignores the limit. It is meant to set the value
  42. // based on reality.
  43. func (l *Limiter) Set(n int64) {
  44. l.mu.Lock()
  45. l.value = n
  46. l.mu.Unlock()
  47. }
  48. // Value returns the internal value of the limiter
  49. func (l *Limiter) Value() int64 {
  50. l.mu.Lock()
  51. defer l.mu.Unlock()
  52. return l.value
  53. }
  54. // Limit returns the defined limit
  55. func (l *Limiter) Limit() int64 {
  56. return l.limit
  57. }
  58. // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
  59. // writer w until any of the limiter's limit is reached, at which point a Write will return ErrLimitReached.
  60. // Each limiter's value is increased with every write.
  61. type LimitWriter struct {
  62. w io.Writer
  63. written int64
  64. limiters []*Limiter
  65. mu sync.Mutex
  66. }
  67. // NewLimitWriter creates a new LimitWriter
  68. func NewLimitWriter(w io.Writer, limiters ...*Limiter) *LimitWriter {
  69. return &LimitWriter{
  70. w: w,
  71. limiters: limiters,
  72. }
  73. }
  74. // Write passes through all writes to the underlying writer until any of the given limiter's limit is reached
  75. func (w *LimitWriter) Write(p []byte) (n int, err error) {
  76. w.mu.Lock()
  77. defer w.mu.Unlock()
  78. for i := 0; i < len(w.limiters); i++ {
  79. if err := w.limiters[i].Add(int64(len(p))); err != nil {
  80. for j := i - 1; j >= 0; j-- {
  81. w.limiters[j].Sub(int64(len(p)))
  82. }
  83. return 0, ErrLimitReached
  84. }
  85. }
  86. n, err = w.w.Write(p)
  87. w.written += int64(n)
  88. return
  89. }