From 13218116e32887a84f78d491c1ee85acf9336920 Mon Sep 17 00:00:00 2001 From: juls0730 <62722391+juls0730@users.noreply.github.com> Date: Wed, 4 Sep 2024 22:55:41 -0500 Subject: [PATCH] fix auth middleware --- main.go | 15 +++++++++++---- middleware/auth.go | 16 +++++++++++++--- models/user.go | 2 +- routes/auth.go | 14 ++++++++------ 4 files changed, 33 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index 17b22da..53537b8 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ package main import ( "context" "database/sql" + "filething/middleware" "filething/models" "filething/routes" "filething/ui" @@ -14,7 +15,7 @@ import ( "strings" "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + echoMiddleware "github.com/labstack/echo/v4/middleware" "github.com/uptrace/bun" "github.com/uptrace/bun/dialect/pgdialect" "github.com/uptrace/bun/driver/pgdriver" @@ -49,9 +50,9 @@ func main() { } }) - e.Use(middleware.Gzip()) - e.Use(middleware.CORS()) - e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ + e.Use(echoMiddleware.Gzip()) + e.Use(echoMiddleware.CORS()) + e.Use(echoMiddleware.CSRFWithConfig(echoMiddleware.CSRFConfig{ TokenLookup: "cookie:_csrf", CookiePath: "/", CookieSecure: true, @@ -63,6 +64,12 @@ func main() { { api.POST("/login", routes.LoginHandler) api.POST("/signup", routes.SignupHandler) + api.Use(middleware.SessionMiddleware(db)) + api.GET("/user", func(c echo.Context) error { + user := c.Get("user").(*models.User) + message := fmt.Sprintf("You are %s", user.ID) + return c.JSON(http.StatusOK, map[string]string{"message": message}) + }) api.GET("/hello", func(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{"message": "Hello, World!!!"}) }) diff --git a/middleware/auth.go b/middleware/auth.go index 78049a3..c18e342 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -4,8 +4,10 @@ import ( "context" "database/sql" "filething/models" + "fmt" "net/http" + "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/uptrace/bun" ) @@ -36,17 +38,25 @@ func SessionMiddleware(db *bun.DB) echo.MiddlewareFunc { sessionToken := cookie.Value // Query the session and user data from PostgreSQL - session := new(models.Session) - err = db.NewSelect().Model(session).Relation("User").WherePK(sessionToken).Scan(context.Background()) + sessionId, err := uuid.Parse(sessionToken) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Bad request") + } + + session := &models.Session{ + ID: sessionId, + } + err = db.NewSelect().Model(session).Relation("User").WherePK().Scan(context.Background()) if err != nil { + fmt.Println(err) if err == sql.ErrNoRows { return echo.NewHTTPError(http.StatusUnauthorized, "Invalid session token") } return echo.NewHTTPError(http.StatusInternalServerError, "Database error") } - user := session.User + user := &session.User // Store the user in the context c.Set(UserContextKey, user) diff --git a/models/user.go b/models/user.go index 2139e4f..58d5b2f 100644 --- a/models/user.go +++ b/models/user.go @@ -27,6 +27,6 @@ type User struct { type Session struct { bun.BaseModel `bun:"table:sessions,alias:u"` ID uuid.UUID `bun:",pk,type:uuid,default:uuid_generate_v4()"` - UserID uuid.UUID `bun:"user_id,notnull"` + UserID uuid.UUID `bun:"user_id,notnull,type:uuid"` User User `bun:"rel:belongs-to,join:user_id=id"` } diff --git a/routes/auth.go b/routes/auth.go index 28b7b45..e3c7d56 100644 --- a/routes/auth.go +++ b/routes/auth.go @@ -41,9 +41,10 @@ func LoginHandler(c echo.Context) error { } c.SetCookie(&http.Cookie{ - Name: "sessionToken", - Value: session.ID.String(), - Path: "/", + Name: "sessionToken", + Value: session.ID.String(), + SameSite: http.SameSiteStrictMode, + Path: "/", }) return c.JSON(http.StatusOK, map[string]string{"message": "Login successful!"}) @@ -109,9 +110,10 @@ func SignupHandler(c echo.Context) error { } c.SetCookie(&http.Cookie{ - Name: "sessionToken", - Value: session.ID.String(), - Path: "/", + Name: "sessionToken", + Value: session.ID.String(), + SameSite: http.SameSiteStrictMode, + Path: "/", }) return c.JSON(http.StatusOK, map[string]string{"message": "Signup successful!"})