diff --git a/backend/blog.db b/backend/blog.db index 409e56d..2accc7a 100644 Binary files a/backend/blog.db and b/backend/blog.db differ diff --git a/backend/go.mod b/backend/go.mod index 85ab8af..adc5788 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -6,6 +6,8 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 + go-simpler.org/env v0.12.0 + golang.org/x/crypto v0.37.0 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 ) @@ -14,6 +16,5 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect - go-simpler.org/env v0.12.0 // indirect - golang.org/x/text v0.14.0 // indirect + golang.org/x/text v0.24.0 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index 6fbdc29..d93ac9a 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -12,8 +12,10 @@ github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= go-simpler.org/env v0.12.0 h1:kt/lBts0J1kjWJAnB740goNdvwNxt5emhYngL0Fzufs= go-simpler.org/env v0.12.0/go.mod h1:cc/5Md9JCUM7LVLtN0HYjPTDcI3Q8TDaPlNTAlDU+WI= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= diff --git a/backend/internal/auth/controller.go b/backend/internal/auth/controller.go index 2c28ba5..12eb08c 100644 --- a/backend/internal/auth/controller.go +++ b/backend/internal/auth/controller.go @@ -2,38 +2,115 @@ package auth import ( "encoding/json" + "errors" + "fmt" + "log" "net/http" + "slices" + + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" ) -func (s *Service) Login(w http.ResponseWriter, r *http.Request) { - login := model.Login{} - if err := json.NewDecoder(r.Body).Decode(&login); err != nil { +func (s *Service) Signup(w http.ResponseWriter, r *http.Request) { + var err error + var login Login + user := model.NewUser() + + if err = json.NewDecoder(r.Body).Decode(&login); err != nil { w.WriteHeader(http.StatusUnauthorized) return } - if login.Name == s.cfg.AdminName && login.Password == s.cfg.AdminPassword { - token, err := createJWT([]byte(s.cfg.Secret)) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) + if len([]byte(login.Password)) > 72 { + fmt.Fprint(w, "Password to long, max 72 bytes") + w.WriteHeader(http.StatusBadRequest) + return + } + if user.Password, err = bcrypt.GenerateFromPassword([]byte(login.Password), 6); err != nil { + log.Println("Error: ", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + user.Name = login.Name + user.Role = s.cfg.DefaultRole + + err = s.db.Save(&user).Error + if err != nil { + if errors.Is(err, gorm.ErrCheckConstraintViolated) { + fmt.Fprint(w, "Username is already in use") + w.WriteHeader(http.StatusBadRequest) return } - err = json.NewEncoder(w).Encode(&model.LoginResponse{ - Token: token, - }) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - return - } - w.WriteHeader(http.StatusOK) - return + log.Printf("Error: %v", err) + w.WriteHeader(http.StatusInternalServerError) } } -func (s *Service) Authenticated(next http.HandlerFunc) http.Handler { +func (s *Service) Login(w http.ResponseWriter, r *http.Request) { + var login Login + var user model.User + + if err := json.NewDecoder(r.Body).Decode(&login); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + if err := s.db.First(&user).Error; err != nil { + fmt.Fprint(w, "user not found") + w.WriteHeader(http.StatusBadRequest) + } + if err := bcrypt.CompareHashAndPassword(user.Password, []byte(login.Password)); err != nil { + fmt.Fprint(w, "Invalid Password") + w.WriteHeader(http.StatusBadRequest) + } + + token, err := s.createJWT(&user) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + res, err := json.Marshal(&LoginResponse{ + Token: token, + }) + if err != nil { + log.Println("Error: ", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Write(res) +} + +func (s *Service) Logout(w http.ResponseWriter, r *http.Request) { + token, err := extractToken(r) + if err != nil { + log.Printf("Error while extracting token: %s", err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + claims, err := s.validateJWT(token) + if err != nil { + fmt.Fprint(w, "Invalid token") + w.WriteHeader(http.StatusBadRequest) + return + } + + if err = s.db.Save(&model.InvalidJWT{JWT: token, ValidUntil: claims.ExpiresAt.Time}).Error; err != nil { + log.Printf("Error while saving logout token: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (s *Service) Authenticated(next http.HandlerFunc, roles ...model.Role) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Our middleware logic goes here... token, err := extractToken(r) @@ -42,12 +119,18 @@ func (s *Service) Authenticated(next http.HandlerFunc) http.Handler { return } - err = validateJWT(token, []byte(s.cfg.Secret)) + claims, err := s.validateJWT(token) if err != nil { w.WriteHeader(http.StatusUnauthorized) return } + // if roles specified check if satisfied + if len(roles) > 0 && !slices.Contains(roles, claims.Role) { + w.WriteHeader(http.StatusForbidden) + return + } + r = writeToContext(r, &claims) next(w, r) }) } diff --git a/backend/internal/auth/ctx.go b/backend/internal/auth/ctx.go new file mode 100644 index 0000000..f61b246 --- /dev/null +++ b/backend/internal/auth/ctx.go @@ -0,0 +1,22 @@ +package auth + +import ( + "context" + "net/http" +) + +type authkey int + +const keyClaims = iota + +func writeToContext(r *http.Request, claims *Claims) *http.Request { + ctx := context.WithValue(r.Context(), claims, claims) + return r.WithContext(ctx) +} + +// ExtractClaims extracts user claims from given context. If no claims in context ok = false. +func ExtractClaims(ctx context.Context) (claims *Claims, ok bool) { + val := ctx.Value(keyClaims) + claims, ok = val.(*Claims) + return +} diff --git a/backend/internal/auth/jwt.go b/backend/internal/auth/jwt.go index daf55f1..162ee8e 100644 --- a/backend/internal/auth/jwt.go +++ b/backend/internal/auth/jwt.go @@ -3,35 +3,61 @@ package auth import ( "errors" "fmt" + "log" "net/http" "strings" "time" - "github.com/golang-jwt/jwt/v5" + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" + jwt "github.com/golang-jwt/jwt/v5" ) -func createJWT(secret []byte) (token string, err error) { - return jwt.NewWithClaims(jwt.SigningMethodHS512, jwt.MapClaims{ - "exp": time.Now().Add(time.Hour * 24).Unix(), - }).SignedString(secret) +var ErrJWTInvalid = errors.New("JWT not valid") + +func (s *Service) createJWT(user *model.User) (token string, err error) { + claims := &Claims{ + Role: user.Role, + UserID: user.ID, + RegisteredClaims: jwt.RegisteredClaims{ + Subject: user.UUID.String(), + ExpiresAt: &jwt.NumericDate{ + Time: time.Now().Add(s.cfg.ValidDuration), + }, + }, + } + return jwt.NewWithClaims(jwt.SigningMethodHS512, claims).SignedString([]byte(s.cfg.Secret)) } -func validateJWT(tokenString string, secret []byte) (err error) { - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) { +func (s *Service) validateJWT(tokenString string) (claims Claims, err error) { + _, err = jwt.ParseWithClaims(tokenString, &claims, func(token *jwt.Token) (any, error) { // Don't forget to validate the alg is what you expect: if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - return secret, nil + return []byte(s.cfg.Secret), nil }) if err != nil { return } - if date, err := token.Claims.GetExpirationTime(); err == nil && date.After(time.Now()) { - return nil + log.Println(claims) + if claims.ExpiresAt.Before(time.Now()) { + err = ErrJWTInvalid + return } - return errors.New("JWT not valid") + + var invalidated bool + err = s.db.Model(&model.InvalidJWT{}). + Select("count(*) > 0"). + Where("jwt = ?", tokenString). + Find(&invalidated). + Error + + if invalidated || err != nil { + err = ErrJWTInvalid + return + } + return } func extractToken(r *http.Request) (token string, err error) { diff --git a/backend/internal/auth/jwt_test.go b/backend/internal/auth/jwt_test.go new file mode 100644 index 0000000..a022949 --- /dev/null +++ b/backend/internal/auth/jwt_test.go @@ -0,0 +1,76 @@ +package auth + +import ( + "log" + "testing" + "time" + + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" + "github.com/google/uuid" + "gorm.io/driver/sqlite" + "gorm.io/gorm" +) + +func testDB() (db *gorm.DB) { + db, err := gorm.Open(sqlite.Open(":memory:")) + if err != nil { + log.Panic(err) + } + + db.AutoMigrate(&model.User{}, &model.InvalidJWT{}) + + return +} + +func TestService_JWT(t *testing.T) { + t.Parallel() + tests := []struct { + user model.User + }{ + { + user: model.User{ + ID: 0, + Name: "Hans de Admin", + Role: model.RoleAdmin, + UUID: uuid.MustParse("9d8973b7-2005-4ca6-a4bf-7bae5aad2916"), + }, + }, + { + user: model.User{ + ID: 1, + Name: "Ueli de User", + Role: model.RoleUser, + UUID: uuid.MustParse("e1b7099f-a3be-4d77-b33f-389e27123187"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.user.Name, func(t *testing.T) { + t.Parallel() + s := New(&Config{ + Secret: "asdf", + ValidDuration: time.Hour, + AdminName: "adsf", + AdminPassword: "adsf", + }, testDB()) + + jwt, err := s.createJWT(&tt.user) + if err != nil { + t.Errorf("Error while creating JWT: %v", err) + } + + claims, err := s.validateJWT(jwt) + if err != nil { + t.Errorf("Error while creating JWT: %v", err) + } + + if claims.Subject != tt.user.UUID.String() { + t.Error("Subject does not match") + } + if claims.Role != tt.user.Role { + t.Error("Roles did not match") + } + }) + } +} diff --git a/backend/internal/auth/resource.go b/backend/internal/auth/resource.go index 0788183..c7698f0 100644 --- a/backend/internal/auth/resource.go +++ b/backend/internal/auth/resource.go @@ -1,9 +1,14 @@ package auth import ( + "log" "time" + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" + jwt "github.com/golang-jwt/jwt/v5" + "golang.org/x/crypto/bcrypt" "gorm.io/gorm" + "gorm.io/gorm/clause" ) type Config struct { @@ -11,6 +16,7 @@ type Config struct { ValidDuration time.Duration `env:"VALID_DURATION"` AdminName string `env:"ADMIN_NAME"` AdminPassword string `env:"ADMIN_PASSWORD"` + DefaultRole model.Role `env:"DEFAULT_ROLE"` } type Service struct { @@ -19,8 +25,33 @@ type Service struct { } func New(cfg *Config, db *gorm.DB) *Service { + user := model.NewUser() + var err error + if user.Password, err = bcrypt.GenerateFromPassword([]byte(cfg.AdminName), 6); err != nil { + log.Fatalf("Error while creating default user: %v", err) + } + user.Name = cfg.AdminName + user.Role = model.RoleAdmin + + // add default user + _ = db.Clauses(clause.OnConflict{DoNothing: true}).Save(&user).Error + return &Service{ cfg, db, } } + +type Claims struct { + Role model.Role `json:"rl"` + UserID uint `json:"uid"` + jwt.RegisteredClaims +} +type Login struct { + Name string `json:"name"` + Password string `json:"Password"` +} + +type LoginResponse struct { + Token string `json:"token"` +} diff --git a/backend/internal/initialize/inject.go b/backend/internal/initialize/inject.go index 7b901c5..e660ffd 100644 --- a/backend/internal/initialize/inject.go +++ b/backend/internal/initialize/inject.go @@ -4,25 +4,32 @@ import ( "net/http" "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/auth" - "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/blog" "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/config" "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/posts" "git.schreifuchs.ch/schreifuchs/ng-blog/backend/pkg/cors" "github.com/gorilla/mux" ) func CreateMux(cfg *config.Config) (r *mux.Router) { db := model.Init() - blg := blog.New(db) + blg := posts.New(db) auth := auth.New(&cfg.Auth, db) r = mux.NewRouter() r.Use(cors.HandlerForOrigin("*")) + + // auth r.HandleFunc("/login", auth.Login).Methods("POST") - r.Handle("/posts", auth.Authenticated(blg.SavePost)).Methods("POST") - r.Handle("/posts", auth.Authenticated(blg.SavePost)).Methods("PUT") - r.Handle("/posts/{postID}", auth.Authenticated(blg.DeletePost)).Methods("DELETE") + r.HandleFunc("/signup", auth.Signup).Methods("POST") + r.Handle("logout", auth.Authenticated(auth.Logout)).Methods("POST") + + // Posts + r.Handle("/posts", auth.Authenticated(blg.SavePost, model.RoleUser, model.RoleAdmin)).Methods("POST") + r.Handle("/posts", auth.Authenticated(blg.SavePost, model.RoleUser, model.RoleAdmin)).Methods("PUT") + r.Handle("/posts/{postID}", auth.Authenticated(blg.DeletePost, model.RoleUser, model.RoleAdmin)).Methods("DELETE") r.Handle("/posts", http.HandlerFunc(blg.GetAllPosts)).Methods("GET") + r.Methods("OPTIONS").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // The CORS middleware should set up the headers for you w.WriteHeader(http.StatusNoContent) diff --git a/backend/internal/model/auth.go b/backend/internal/model/auth.go index e4b6f33..3d2b624 100644 --- a/backend/internal/model/auth.go +++ b/backend/internal/model/auth.go @@ -1,9 +1,34 @@ package model -type Login struct { - Name string `json:"name"` - Password string `json:"Password"` +import ( + "time" + + "github.com/google/uuid" +) + +type Role string + +const ( + RoleAdmin Role = " admin" + RoleUser Role = "user" + RoleGuest Role = "guest" +) + +type InvalidJWT struct { + JWT string `gorm:"primarykey"` + ValidUntil time.Time } -type LoginResponse struct { - Token string `json:"token"` + +type User struct { + ID uint `gorm:"primarykey" json:"-"` + UUID uuid.UUID `gorm:"type:uuid" json:"uuid"` + Name string `json:"name" gorm:"unique"` + Role Role `json:"role"` + Password []byte `json:"-"` +} + +func NewUser() User { + return User{ + UUID: uuid.New(), + } } diff --git a/backend/internal/model/blog.go b/backend/internal/model/blog.go index 7203f58..64a6c5b 100644 --- a/backend/internal/model/blog.go +++ b/backend/internal/model/blog.go @@ -11,9 +11,11 @@ type Post struct { TLDR string `json:"tldr"` Content string `json:"content"` Comments []Comment + UserID uint `gorm:"->;<-:create"` } type Comment struct { ID uint PostID uint Content string `json:"content"` + UserID uint `gorm:"->;<-:create"` } diff --git a/backend/internal/model/init.go b/backend/internal/model/init.go index c25463f..a427a22 100644 --- a/backend/internal/model/init.go +++ b/backend/internal/model/init.go @@ -12,7 +12,7 @@ func Init() *gorm.DB { if err != nil { log.Panic(err) } - db.AutoMigrate(&Post{}, &Comment{}) + db.AutoMigrate(&Post{}, &Comment{}, &User{}, &InvalidJWT{}) db.Save(&Post{ ID: 1, diff --git a/backend/internal/blog/controller.go b/backend/internal/posts/controller.go similarity index 58% rename from backend/internal/blog/controller.go rename to backend/internal/posts/controller.go index 9332f8c..2b7eb9c 100644 --- a/backend/internal/blog/controller.go +++ b/backend/internal/posts/controller.go @@ -1,43 +1,66 @@ -package blog +package posts import ( "encoding/json" "fmt" + "log" "net/http" "strconv" + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/auth" "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" "github.com/gorilla/mux" ) func (s Service) SavePost(w http.ResponseWriter, r *http.Request) { - var post model.Post - if err := json.NewDecoder(r.Body).Decode(&post); err != nil { - w.WriteHeader(http.StatusBadRequest) - fmt.Fprint(w, err.Error()) + claims, ok := auth.ExtractClaims(r.Context()) + if !ok { + log.Println("Err could not ExtractClaims") + w.WriteHeader(http.StatusInternalServerError) return } - if err := s.db.Save(&post).Error; err != nil { - w.WriteHeader(http.StatusInternalServerError) + var post model.Post + if err := json.NewDecoder(r.Body).Decode(&post); err != nil { fmt.Fprint(w, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return } - json.NewEncoder(w).Encode(&post) - w.WriteHeader(http.StatusOK) + post.UserID = claims.UserID + + if err := s.db.Save(&post).Error; err != nil { + fmt.Fprint(w, err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + res, err := json.Marshal(&post) + if err != nil { + fmt.Fprint(w, err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Write(res) } func (s Service) GetAllPosts(w http.ResponseWriter, r *http.Request) { var posts []model.Post if err := s.db.Preload("Comments").Order("created_at DESC").Find(&posts).Error; err != nil { - w.WriteHeader(http.StatusInternalServerError) fmt.Fprint(w, err.Error()) + w.WriteHeader(http.StatusInternalServerError) return } - json.NewEncoder(w).Encode(&posts) - w.WriteHeader(http.StatusOK) + res, err := json.Marshal(&posts) + if err != nil { + fmt.Fprint(w, err.Error()) + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Write(res) } func (s Service) DeletePost(w http.ResponseWriter, r *http.Request) { @@ -51,11 +74,17 @@ func (s Service) DeletePost(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotFound) return } - - err = s.db.Delete(&model.Post{}, id).Error - if err != nil { + claims, ok := auth.ExtractClaims(r.Context()) + if !ok { + log.Println("Err could not ExtractClaims") w.WriteHeader(http.StatusInternalServerError) + return + } + + err = s.db.Where("user_id = ?", claims.UserID).Delete(&model.Post{}, id).Error + if err != nil { fmt.Fprint(w, err.Error()) + w.WriteHeader(http.StatusInternalServerError) } w.WriteHeader(http.StatusNoContent) } diff --git a/backend/internal/blog/resource.go b/backend/internal/posts/resource.go similarity index 89% rename from backend/internal/blog/resource.go rename to backend/internal/posts/resource.go index a23841a..0580309 100644 --- a/backend/internal/blog/resource.go +++ b/backend/internal/posts/resource.go @@ -1,4 +1,4 @@ -package blog +package posts import "gorm.io/gorm" diff --git a/backend/internal/users/password.go b/backend/internal/users/password.go new file mode 100644 index 0000000..94af198 --- /dev/null +++ b/backend/internal/users/password.go @@ -0,0 +1,51 @@ +package users + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/auth" + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" + "golang.org/x/crypto/bcrypt" +) + +func (s Service) ChangePassword(w http.ResponseWriter, r *http.Request) { + var err error + var req Password + user := model.NewUser() + + if err = json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + + if claims, ok := auth.ExtractClaims(r.Context()); !ok { + log.Println("Error: was not able to extract Claims") + w.WriteHeader(http.StatusInternalServerError) + } else { + user.ID = claims.UserID + } + + if len([]byte(req.Password)) > 72 { + fmt.Fprint(w, "Password to long, max 72 bytes") + w.WriteHeader(http.StatusBadRequest) + return + } + + if user.Password, err = bcrypt.GenerateFromPassword([]byte(req.Password), 6); err != nil { + log.Println("Error: ", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + err = s.db.Model(&user). + Where("id = ?", user.ID). + Update("password", user.Password). + Error + if err != nil { + log.Printf("Error: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } + w.WriteHeader(http.StatusOK) +} diff --git a/backend/internal/users/resource.go b/backend/internal/users/resource.go new file mode 100644 index 0000000..23a2d43 --- /dev/null +++ b/backend/internal/users/resource.go @@ -0,0 +1,19 @@ +package users + +import ( + "gorm.io/gorm" +) + +type Service struct { + db *gorm.DB +} + +func New(db *gorm.DB) *Service { + return &Service{ + db: db, + } +} + +type Password struct { + Password string `json:"password"` +} diff --git a/backend/internal/users/roles.go b/backend/internal/users/roles.go new file mode 100644 index 0000000..5f74886 --- /dev/null +++ b/backend/internal/users/roles.go @@ -0,0 +1,104 @@ +package users + +import ( + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/auth" + "git.schreifuchs.ch/schreifuchs/ng-blog/backend/internal/model" + "github.com/google/uuid" + "github.com/gorilla/mux" + "gorm.io/gorm" +) + +func (s *Service) GetUsers(w http.ResponseWriter, r *http.Request) { + var users []model.User + + err := s.db.Find(&users).Error + if err != nil { + log.Printf("Error while getting users: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } + + res, err := json.Marshal(&users) + if err != nil { + log.Printf("Error while marshaling users: %v", err) + w.WriteHeader(http.StatusInternalServerError) + } + + w.Write(res) +} + +func (s *Service) SetUserRole(w http.ResponseWriter, r *http.Request) { + var role model.Role + userUUIDstr, ok := mux.Vars(r)["userUUID"] + if !ok { + w.WriteHeader(http.StatusNotFound) + return + } + userUUID, err := uuid.Parse(userUUIDstr) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + if err := json.NewDecoder(r.Body).Decode(&role); err != nil { + fmt.Fprint(w, err.Error()) + w.WriteHeader(http.StatusBadRequest) + return + } + + err = s.db.Model(&model.User{}). + Where("uuid = ?", userUUID). + Update("role", role). + Error + if err != nil { + log.Printf("Error while update user role: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +func (s *Service) DeleteUser(w http.ResponseWriter, r *http.Request) { + claims, ok := auth.ExtractClaims(r.Context()) + if !ok { + log.Println("Error while extracting claims") + w.WriteHeader(http.StatusInternalServerError) + return + } + + userUUIDstr, ok := mux.Vars(r)["userUUID"] + if !ok { + w.WriteHeader(http.StatusNotFound) + return + } + userUUID, err := uuid.Parse(userUUIDstr) + if err != nil { + w.WriteHeader(http.StatusNotFound) + return + } + + if claims.Role != model.RoleAdmin && userUUIDstr != claims.Subject { + w.WriteHeader(http.StatusForbidden) + return + } + + if err = s.db.Where("uuid = ?", userUUID).Delete(&model.User{}).Error; err != nil { + + if errors.Is(err, gorm.ErrCheckConstraintViolated) { + fmt.Fprint(w, "Username is already in use") + w.WriteHeader(http.StatusBadRequest) + return + } + + log.Printf("Error: %v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusNoContent) +}