util.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. package util
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "math/rand"
  11. "net/netip"
  12. "os"
  13. "regexp"
  14. "strconv"
  15. "strings"
  16. "sync"
  17. "time"
  18. "github.com/gabriel-vasile/mimetype"
  19. "golang.org/x/term"
  20. "golang.org/x/time/rate"
  21. )
  22. const (
  23. randomStringCharset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
  24. randomStringLowerCaseCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
  25. )
  26. var (
  27. random = rand.New(rand.NewSource(time.Now().UnixNano()))
  28. randomMutex = sync.Mutex{}
  29. sizeStrRegex = regexp.MustCompile(`(?i)^(\d+)([gmkb])?$`)
  30. errInvalidPriority = errors.New("invalid priority")
  31. noQuotesRegex = regexp.MustCompile(`^[-_./:@a-zA-Z0-9]+$`)
  32. )
  33. // Errors for UnmarshalJSON and UnmarshalJSONWithLimit functions
  34. var (
  35. ErrUnmarshalJSON = errors.New("unmarshalling JSON failed")
  36. ErrTooLargeJSON = errors.New("too large JSON")
  37. )
  38. // FileExists checks if a file exists, and returns true if it does
  39. func FileExists(filename string) bool {
  40. stat, _ := os.Stat(filename)
  41. return stat != nil
  42. }
  43. // Contains returns true if needle is contained in haystack
  44. func Contains[T comparable](haystack []T, needle T) bool {
  45. for _, s := range haystack {
  46. if s == needle {
  47. return true
  48. }
  49. }
  50. return false
  51. }
  52. // ContainsIP returns true if any one of the of prefixes contains the ip.
  53. func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
  54. for _, s := range haystack {
  55. if s.Contains(needle) {
  56. return true
  57. }
  58. }
  59. return false
  60. }
  61. // ContainsAll returns true if all needles are contained in haystack
  62. func ContainsAll[T comparable](haystack []T, needles []T) bool {
  63. for _, needle := range needles {
  64. if !Contains(haystack, needle) {
  65. return false
  66. }
  67. }
  68. return true
  69. }
  70. // SplitNoEmpty splits a string using strings.Split, but filters out empty strings
  71. func SplitNoEmpty(s string, sep string) []string {
  72. res := make([]string, 0)
  73. for _, r := range strings.Split(s, sep) {
  74. if r != "" {
  75. res = append(res, r)
  76. }
  77. }
  78. return res
  79. }
  80. // SplitKV splits a string into a key/value pair using a separator, and trimming space. If the separator
  81. // is not found, key is empty.
  82. func SplitKV(s string, sep string) (key string, value string) {
  83. kv := strings.SplitN(strings.TrimSpace(s), sep, 2)
  84. if len(kv) == 2 {
  85. return strings.TrimSpace(kv[0]), strings.TrimSpace(kv[1])
  86. }
  87. return "", strings.TrimSpace(kv[0])
  88. }
  89. // Map applies a function to each element of a slice and returns a new slice with the results
  90. // Example: Map([]int{1, 2, 3}, func(i int) int { return i * 2 }) -> []int{2, 4, 6}
  91. func Map[T any, U any](slice []T, f func(T) U) []U {
  92. result := make([]U, len(slice))
  93. for i, v := range slice {
  94. result[i] = f(v)
  95. }
  96. return result
  97. }
  98. // Filter returns a new slice containing only the elements of the original slice for which the
  99. // given function returns true.
  100. func Filter[T any](slice []T, f func(T) bool) []T {
  101. result := make([]T, 0)
  102. for _, v := range slice {
  103. if f(v) {
  104. result = append(result, v)
  105. }
  106. }
  107. return result
  108. }
  109. // Find returns the first element in the slice that satisfies the given function, and a boolean indicating
  110. // whether such an element was found. If no element is found, it returns the zero value of T and false.
  111. func Find[T any](slice []T, f func(T) bool) (T, bool) {
  112. for _, v := range slice {
  113. if f(v) {
  114. return v, true
  115. }
  116. }
  117. var zero T
  118. return zero, false
  119. }
  120. // RandomString returns a random string with a given length
  121. func RandomString(length int) string {
  122. return RandomStringPrefix("", length)
  123. }
  124. // RandomStringPrefix returns a random string with a given length, with a prefix
  125. func RandomStringPrefix(prefix string, length int) string {
  126. return randomStringPrefixWithCharset(prefix, length, randomStringCharset)
  127. }
  128. // RandomLowerStringPrefix returns a random lowercase-only string with a given length, with a prefix
  129. func RandomLowerStringPrefix(prefix string, length int) string {
  130. return randomStringPrefixWithCharset(prefix, length, randomStringLowerCaseCharset)
  131. }
  132. func randomStringPrefixWithCharset(prefix string, length int, charset string) string {
  133. randomMutex.Lock() // Who would have thought that random.Intn() is not thread-safe?!
  134. defer randomMutex.Unlock()
  135. b := make([]byte, length-len(prefix))
  136. for i := range b {
  137. b[i] = charset[random.Intn(len(charset))]
  138. }
  139. return prefix + string(b)
  140. }
  141. // ValidRandomString returns true if the given string matches the format created by RandomString
  142. func ValidRandomString(s string, length int) bool {
  143. if len(s) != length {
  144. return false
  145. }
  146. for _, c := range strings.Split(s, "") {
  147. if !strings.Contains(randomStringCharset, c) {
  148. return false
  149. }
  150. }
  151. return true
  152. }
  153. // ParsePriority parses a priority string into its equivalent integer value
  154. func ParsePriority(priority string) (int, error) {
  155. p := strings.TrimSpace(strings.ToLower(priority))
  156. switch p {
  157. case "":
  158. return 0, nil
  159. case "1", "min":
  160. return 1, nil
  161. case "2", "low":
  162. return 2, nil
  163. case "3", "default":
  164. return 3, nil
  165. case "4", "high":
  166. return 4, nil
  167. case "5", "max", "urgent":
  168. return 5, nil
  169. default:
  170. return 0, errInvalidPriority
  171. }
  172. }
  173. // PriorityString converts a priority number to a string
  174. func PriorityString(priority int) (string, error) {
  175. switch priority {
  176. case 0:
  177. return "default", nil
  178. case 1:
  179. return "min", nil
  180. case 2:
  181. return "low", nil
  182. case 3:
  183. return "default", nil
  184. case 4:
  185. return "high", nil
  186. case 5:
  187. return "max", nil
  188. default:
  189. return "", errInvalidPriority
  190. }
  191. }
  192. // ShortTopicURL shortens the topic URL to be human-friendly, removing the http:// or https://
  193. func ShortTopicURL(s string) string {
  194. return strings.TrimPrefix(strings.TrimPrefix(s, "https://"), "http://")
  195. }
  196. // DetectContentType probes the byte array b and returns mime type and file extension.
  197. // The filename is only used to override certain special cases.
  198. func DetectContentType(b []byte, filename string) (mimeType string, ext string) {
  199. if strings.HasSuffix(strings.ToLower(filename), ".apk") {
  200. return "application/vnd.android.package-archive", ".apk"
  201. }
  202. m := mimetype.Detect(b)
  203. mimeType, ext = m.String(), m.Extension()
  204. if ext == "" {
  205. ext = ".bin"
  206. }
  207. return
  208. }
  209. // ParseSize parses a size string like 2K or 2M into bytes. If no unit is found, e.g. 123, bytes is assumed.
  210. func ParseSize(s string) (int64, error) {
  211. matches := sizeStrRegex.FindStringSubmatch(s)
  212. if matches == nil {
  213. return -1, fmt.Errorf("invalid size %s", s)
  214. }
  215. value, err := strconv.Atoi(matches[1])
  216. if err != nil {
  217. return -1, fmt.Errorf("cannot convert number %s", matches[1])
  218. }
  219. switch strings.ToUpper(matches[2]) {
  220. case "T":
  221. return int64(value) * 1024 * 1024 * 1024 * 1024, nil
  222. case "G":
  223. return int64(value) * 1024 * 1024 * 1024, nil
  224. case "M":
  225. return int64(value) * 1024 * 1024, nil
  226. case "K":
  227. return int64(value) * 1024, nil
  228. default:
  229. return int64(value), nil
  230. }
  231. }
  232. // FormatSize formats the size in a way that it can be parsed by ParseSize.
  233. // It does not include decimal places. Uneven sizes are rounded down.
  234. func FormatSize(b int64) string {
  235. const unit = 1024
  236. if b < unit {
  237. return fmt.Sprintf("%d", b)
  238. }
  239. div, exp := int64(unit), 0
  240. for n := b / unit; n >= unit; n /= unit {
  241. div *= unit
  242. exp++
  243. }
  244. return fmt.Sprintf("%d%c", int(math.Floor(float64(b)/float64(div))), "KMGT"[exp])
  245. }
  246. // FormatSizeHuman formats bytes into a human-readable notation, e.g. 2.1 MB
  247. func FormatSizeHuman(b int64) string {
  248. const unit = 1024
  249. if b < unit {
  250. return fmt.Sprintf("%d bytes", b)
  251. }
  252. div, exp := int64(unit), 0
  253. for n := b / unit; n >= unit; n /= unit {
  254. div *= unit
  255. exp++
  256. }
  257. return fmt.Sprintf("%.1f %cB", float64(b)/float64(div), "KMGT"[exp])
  258. }
  259. // ReadPassword will read a password from STDIN. If the terminal supports it, it will not print the
  260. // input characters to the screen. If not, it'll just read using normal readline semantics (useful for testing).
  261. func ReadPassword(in io.Reader) ([]byte, error) {
  262. // If in is a file and a character device (a TTY), use term.ReadPassword
  263. if f, ok := in.(*os.File); ok {
  264. stat, err := f.Stat()
  265. if err != nil {
  266. return nil, err
  267. }
  268. if (stat.Mode() & os.ModeCharDevice) == os.ModeCharDevice {
  269. password, err := term.ReadPassword(int(f.Fd())) // This is always going to be 0
  270. if err != nil {
  271. return nil, err
  272. } else if len(password) == 0 {
  273. return nil, errors.New("password cannot be empty")
  274. }
  275. return password, nil
  276. }
  277. }
  278. // Fallback: Manually read util \n if found, see #69 for details why this is so manual
  279. password := make([]byte, 0)
  280. buf := make([]byte, 1)
  281. for {
  282. _, err := in.Read(buf)
  283. if err == io.EOF || buf[0] == '\n' {
  284. break
  285. } else if err != nil {
  286. return nil, err
  287. } else if len(password) > 10240 {
  288. return nil, errors.New("passwords this long are not supported")
  289. }
  290. password = append(password, buf[0])
  291. }
  292. if len(password) == 0 {
  293. return nil, errors.New("password cannot be empty")
  294. }
  295. return password, nil
  296. }
  297. // BasicAuth encodes the Authorization header value for basic auth
  298. func BasicAuth(user, pass string) string {
  299. return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", user, pass))))
  300. }
  301. // BearerAuth encodes the Authorization header value for a bearer/token auth
  302. func BearerAuth(token string) string {
  303. return fmt.Sprintf("Bearer %s", token)
  304. }
  305. // MaybeMarshalJSON returns a JSON string of the given object, or "<cannot serialize>" if serialization failed.
  306. // This is useful for logging purposes where a failure doesn't matter that much.
  307. func MaybeMarshalJSON(v any) string {
  308. jsonBytes, err := json.MarshalIndent(v, "", " ")
  309. if err != nil {
  310. return "<cannot serialize>"
  311. }
  312. if len(jsonBytes) > 5000 {
  313. return string(jsonBytes)[:5000]
  314. }
  315. return string(jsonBytes)
  316. }
  317. // QuoteCommand combines a command array to a string, quoting arguments that need quoting.
  318. // This function is naive, and sometimes wrong. It is only meant for lo pretty-printing a command.
  319. //
  320. // Warning: Never use this function with the intent to run the resulting command.
  321. //
  322. // Example:
  323. //
  324. // []string{"ls", "-al", "Document Folder"} -> ls -al "Document Folder"
  325. func QuoteCommand(command []string) string {
  326. var quoted []string
  327. for _, c := range command {
  328. if noQuotesRegex.MatchString(c) {
  329. quoted = append(quoted, c)
  330. } else {
  331. quoted = append(quoted, fmt.Sprintf(`"%s"`, c))
  332. }
  333. }
  334. return strings.Join(quoted, " ")
  335. }
  336. // UnmarshalJSON reads the given io.ReadCloser into a struct
  337. func UnmarshalJSON[T any](body io.ReadCloser) (*T, error) {
  338. var obj T
  339. if err := json.NewDecoder(body).Decode(&obj); err != nil {
  340. return nil, ErrUnmarshalJSON
  341. }
  342. return &obj, nil
  343. }
  344. // UnmarshalJSONWithLimit reads the given io.ReadCloser into a struct, but only until limit is reached
  345. func UnmarshalJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, error) {
  346. defer r.Close()
  347. p, err := Peek(r, limit)
  348. if err != nil {
  349. return nil, err
  350. } else if p.LimitReached {
  351. return nil, ErrTooLargeJSON
  352. }
  353. var obj T
  354. if len(bytes.TrimSpace(p.PeekedBytes)) == 0 && allowEmpty {
  355. return &obj, nil
  356. } else if err := json.NewDecoder(p).Decode(&obj); err != nil {
  357. return nil, ErrUnmarshalJSON
  358. }
  359. return &obj, nil
  360. }
  361. // Retry executes function f until if succeeds, and then returns t. If f fails, it sleeps
  362. // and tries again. The sleep durations are passed as the after params.
  363. func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error) {
  364. for _, delay := range after {
  365. if t, err = f(); err == nil {
  366. return t, nil
  367. }
  368. time.Sleep(delay)
  369. }
  370. return nil, err
  371. }
  372. // MinMax returns value if it is between min and max, or either
  373. // min or max if it is out of range
  374. func MinMax[T int | int64](value, min, max T) T {
  375. if value < min {
  376. return min
  377. } else if value > max {
  378. return max
  379. }
  380. return value
  381. }
  382. // Max returns the maximum value of the two given values
  383. func Max[T int | int64 | rate.Limit](a, b T) T {
  384. if a > b {
  385. return a
  386. }
  387. return b
  388. }
  389. // String turns a string into a pointer of a string
  390. func String(v string) *string {
  391. return &v
  392. }
  393. // Int turns an int into a pointer of an int
  394. func Int(v int) *int {
  395. return &v
  396. }
  397. // Time turns a time.Time into a pointer
  398. func Time(v time.Time) *time.Time {
  399. return &v
  400. }