Jelajahi Sumber

Move to package

Philipp Heckel 4 tahun lalu
induk
melakukan
f388fd9c90
4 mengubah file dengan 59 tambahan dan 50 penghapusan
  1. 36 0
      auth/auth.go
  2. 14 14
      auth/auth_sqlite.go
  3. 0 28
      server/auth.go
  4. 9 8
      server/server.go

+ 36 - 0
auth/auth.go

@@ -0,0 +1,36 @@
+package auth
+
+import "errors"
+
+// auth is a generic interface to implement password-based authentication and authorization
+type Auth interface {
+	Authenticate(user, pass string) (*User, error)
+	Authorize(user *User, topic string, perm Permission) error
+}
+
+type User struct {
+	Name string
+	Role Role
+}
+
+type Permission int
+
+const (
+	PermissionRead  = Permission(1)
+	PermissionWrite = Permission(2)
+)
+
+type Role string
+
+const (
+	RoleAdmin = Role("admin")
+	RoleUser  = Role("user")
+	RoleNone  = Role("none")
+)
+
+var Everyone = &User{
+	Name: "",
+	Role: RoleNone,
+}
+
+var ErrUnauthorized = errors.New("unauthorized")

+ 14 - 14
server/auth_sqlite.go → auth/auth_sqlite.go

@@ -1,4 +1,4 @@
-package server
+package auth
 
 
 import (
 import (
 	"database/sql"
 	"database/sql"
@@ -69,15 +69,15 @@ const (
 	`
 	`
 )
 )
 
 
-type sqliteAuth struct {
+type SQLiteAuth struct {
 	db           *sql.DB
 	db           *sql.DB
 	defaultRead  bool
 	defaultRead  bool
 	defaultWrite bool
 	defaultWrite bool
 }
 }
 
 
-var _ auth = (*sqliteAuth)(nil)
+var _ Auth = (*SQLiteAuth)(nil)
 
 
-func newSqliteAuth(filename string, defaultRead, defaultWrite bool) (*sqliteAuth, error) {
+func NewSQLiteAuth(filename string, defaultRead, defaultWrite bool) (*SQLiteAuth, error) {
 	db, err := sql.Open("sqlite3", filename)
 	db, err := sql.Open("sqlite3", filename)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -85,7 +85,7 @@ func newSqliteAuth(filename string, defaultRead, defaultWrite bool) (*sqliteAuth
 	if err := setupNewAuthDB(db); err != nil {
 	if err := setupNewAuthDB(db); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &sqliteAuth{
+	return &SQLiteAuth{
 		db:           db,
 		db:           db,
 		defaultRead:  defaultRead,
 		defaultRead:  defaultRead,
 		defaultWrite: defaultWrite,
 		defaultWrite: defaultWrite,
@@ -100,7 +100,7 @@ func setupNewAuthDB(db *sql.DB) error {
 	return nil
 	return nil
 }
 }
 
 
-func (a *sqliteAuth) Authenticate(username, password string) (*user, error) {
+func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
 	rows, err := a.db.Query(selectUserQuery, username)
 	rows, err := a.db.Query(selectUserQuery, username)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
@@ -117,14 +117,14 @@ func (a *sqliteAuth) Authenticate(username, password string) (*user, error) {
 	if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
 	if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
-	return &user{
+	return &User{
 		Name: username,
 		Name: username,
-		Role: role,
+		Role: Role(role),
 	}, nil
 	}, nil
 }
 }
 
 
-func (a *sqliteAuth) Authorize(user *user, topic string, perm int) error {
-	if user.Role == roleAdmin {
+func (a *SQLiteAuth) Authorize(user *User, topic string, perm Permission) error {
+	if user.Role == RoleAdmin {
 		return nil // Admin can do everything
 		return nil // Admin can do everything
 	}
 	}
 	// Select the read/write permissions for this user/topic combo. The query may return two
 	// Select the read/write permissions for this user/topic combo. The query may return two
@@ -147,11 +147,11 @@ func (a *sqliteAuth) Authorize(user *user, topic string, perm int) error {
 	return a.resolvePerms(read, write, perm)
 	return a.resolvePerms(read, write, perm)
 }
 }
 
 
-func (a *sqliteAuth) resolvePerms(read, write bool, perm int) error {
-	if perm == permRead && read {
+func (a *SQLiteAuth) resolvePerms(read, write bool, perm Permission) error {
+	if perm == PermissionRead && read {
 		return nil
 		return nil
-	} else if perm == permWrite && write {
+	} else if perm == PermissionWrite && write {
 		return nil
 		return nil
 	}
 	}
-	return errHTTPUnauthorized
+	return ErrUnauthorized
 }
 }

+ 0 - 28
server/auth.go

@@ -1,28 +0,0 @@
-package server
-
-// auth is a generic interface to implement password-based authentication and authorization
-type auth interface {
-	Authenticate(user, pass string) (*user, error)
-	Authorize(user *user, topic string, perm int) error
-}
-
-type user struct {
-	Name string
-	Role string
-}
-
-const (
-	permRead  = 1
-	permWrite = 2
-)
-
-const (
-	roleAdmin = "admin"
-	roleUser  = "user"
-	roleNone  = "none"
-)
-
-var everyone = &user{
-	Name: "",
-	Role: roleNone,
-}

+ 9 - 8
server/server.go

@@ -14,6 +14,7 @@ import (
 	"github.com/gorilla/websocket"
 	"github.com/gorilla/websocket"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
 	"google.golang.org/api/option"
 	"google.golang.org/api/option"
+	"heckel.io/ntfy/auth"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 	"html/template"
 	"html/template"
 	"io"
 	"io"
@@ -46,7 +47,7 @@ type Server struct {
 	firebase     subscriber
 	firebase     subscriber
 	mailer       mailer
 	mailer       mailer
 	messages     int64
 	messages     int64
-	auth         auth
+	auth         auth.Auth
 	cache        cache
 	cache        cache
 	fileCache    *fileCache
 	fileCache    *fileCache
 	closeChan    chan bool
 	closeChan    chan bool
@@ -141,9 +142,9 @@ func New(conf *Config) (*Server, error) {
 			return nil, err
 			return nil, err
 		}
 		}
 	}
 	}
-	var auth auth
+	var auther auth.Auth
 	if conf.AuthFile != "" {
 	if conf.AuthFile != "" {
-		auth, err = newSqliteAuth(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite)
+		auther, err = auth.NewSQLiteAuth(conf.AuthFile, conf.AuthDefaultRead, conf.AuthDefaultWrite)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -155,7 +156,7 @@ func New(conf *Config) (*Server, error) {
 		firebase:  firebaseSubscriber,
 		firebase:  firebaseSubscriber,
 		mailer:    mailer,
 		mailer:    mailer,
 		topics:    topics,
 		topics:    topics,
-		auth:      auth,
+		auth:      auther,
 		visitors:  make(map[string]*visitor),
 		visitors:  make(map[string]*visitor),
 	}, nil
 	}, nil
 }
 }
@@ -1117,14 +1118,14 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
 }
 }
 
 
 func (s *Server) authWrite(next handleFunc) handleFunc {
 func (s *Server) authWrite(next handleFunc) handleFunc {
-	return s.withAuth(next, permWrite)
+	return s.withAuth(next, auth.PermissionWrite)
 }
 }
 
 
 func (s *Server) authRead(next handleFunc) handleFunc {
 func (s *Server) authRead(next handleFunc) handleFunc {
-	return s.withAuth(next, permRead)
+	return s.withAuth(next, auth.PermissionRead)
 }
 }
 
 
-func (s *Server) withAuth(next handleFunc, perm int) handleFunc {
+func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
 		if s.auth == nil {
 		if s.auth == nil {
 			return next(w, r, v)
 			return next(w, r, v)
@@ -1133,7 +1134,7 @@ func (s *Server) withAuth(next handleFunc, perm int) handleFunc {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		user := everyone
+		user := auth.Everyone
 		username, password, ok := r.BasicAuth()
 		username, password, ok := r.BasicAuth()
 		if ok {
 		if ok {
 			if user, err = s.auth.Authenticate(username, password); err != nil {
 			if user, err = s.auth.Authenticate(username, password); err != nil {