file_cache.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package server
  2. import (
  3. "errors"
  4. "fmt"
  5. "heckel.io/ntfy/util"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "regexp"
  10. "sync"
  11. "time"
  12. )
  13. var (
  14. fileIDRegex = regexp.MustCompile(fmt.Sprintf(`^[-_A-Za-z0-9]{%d}$`, messageIDLength))
  15. errInvalidFileID = errors.New("invalid file ID")
  16. errFileExists = errors.New("file exists")
  17. )
  // fileCache stores files on disk, one file per id inside dir, and keeps a
  // running total of their size so it can be compared against a byte limit.
  type fileCache struct {
  	dir            string // directory holding one file per id
  	totalSizeLimit int64  // byte budget that Remaining measures against

  	mu               sync.Mutex // guards totalSizeCurrent
  	totalSizeCurrent int64      // bytes currently accounted for in dir
  }
  24. func newFileCache(dir string, totalSizeLimit int64) (*fileCache, error) {
  25. if err := os.MkdirAll(dir, 0700); err != nil {
  26. return nil, err
  27. }
  28. size, err := dirSize(dir)
  29. if err != nil {
  30. return nil, err
  31. }
  32. return &fileCache{
  33. dir: dir,
  34. totalSizeCurrent: size,
  35. totalSizeLimit: totalSizeLimit,
  36. }, nil
  37. }
  38. func (c *fileCache) Write(id string, in io.Reader, limiters ...util.Limiter) (int64, error) {
  39. if !fileIDRegex.MatchString(id) {
  40. return 0, errInvalidFileID
  41. }
  42. file := filepath.Join(c.dir, id)
  43. if _, err := os.Stat(file); err == nil {
  44. return 0, errFileExists
  45. }
  46. f, err := os.OpenFile(file, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
  47. if err != nil {
  48. return 0, err
  49. }
  50. defer f.Close()
  51. limiters = append(limiters, util.NewFixedLimiter(c.Remaining()))
  52. limitWriter := util.NewLimitWriter(f, limiters...)
  53. size, err := io.Copy(limitWriter, in)
  54. if err != nil {
  55. os.Remove(file)
  56. return 0, err
  57. }
  58. if err := f.Close(); err != nil {
  59. os.Remove(file)
  60. return 0, err
  61. }
  62. c.mu.Lock()
  63. c.totalSizeCurrent += size
  64. c.mu.Unlock()
  65. return size, nil
  66. }
  67. func (c *fileCache) Remove(ids ...string) error {
  68. for _, id := range ids {
  69. if !fileIDRegex.MatchString(id) {
  70. return errInvalidFileID
  71. }
  72. file := filepath.Join(c.dir, id)
  73. _ = os.Remove(file) // Best effort delete
  74. }
  75. size, err := dirSize(c.dir)
  76. if err != nil {
  77. return err
  78. }
  79. c.mu.Lock()
  80. c.totalSizeCurrent = size
  81. c.mu.Unlock()
  82. return nil
  83. }
  84. // Expired returns a list of file IDs for expired files
  85. func (c *fileCache) Expired(olderThan time.Time) ([]string, error) {
  86. entries, err := os.ReadDir(c.dir)
  87. if err != nil {
  88. return nil, err
  89. }
  90. var ids []string
  91. for _, e := range entries {
  92. info, err := e.Info()
  93. if err != nil {
  94. continue
  95. }
  96. if info.ModTime().Before(olderThan) && fileIDRegex.MatchString(e.Name()) {
  97. ids = append(ids, e.Name())
  98. }
  99. }
  100. return ids, nil
  101. }
  102. func (c *fileCache) Size() int64 {
  103. c.mu.Lock()
  104. defer c.mu.Unlock()
  105. return c.totalSizeCurrent
  106. }
  107. func (c *fileCache) Remaining() int64 {
  108. c.mu.Lock()
  109. defer c.mu.Unlock()
  110. remaining := c.totalSizeLimit - c.totalSizeCurrent
  111. if remaining < 0 {
  112. return 0
  113. }
  114. return remaining
  115. }
  // dirSize returns the combined size, in bytes, of the non-directory entries
  // directly inside dir. It does not recurse: subdirectories are skipped
  // entirely, because their own metadata size is not cached content.
  //
  // An entry that disappears between listing the directory and reading its
  // metadata is ignored, so a concurrent delete does not cause a spurious
  // error. Any other error from reading the directory or an entry's metadata
  // is returned.
  func dirSize(dir string) (int64, error) {
  	entries, err := os.ReadDir(dir)
  	if err != nil {
  		return 0, err
  	}
  	var total int64
  	for _, e := range entries {
  		if e.IsDir() {
  			continue
  		}
  		info, err := e.Info()
  		if errors.Is(err, os.ErrNotExist) {
  			continue // removed after ReadDir, e.g. by a concurrent delete
  		}
  		if err != nil {
  			return 0, err
  		}
  		total += info.Size()
  	}
  	return total, nil
  }