diff --git a/cmd/wg-portal/common/config.go b/cmd/wg-portal/common/config.go index 82fe1cb..6e89b4b 100644 --- a/cmd/wg-portal/common/config.go +++ b/cmd/wg-portal/common/config.go @@ -10,6 +10,8 @@ Email string Firstname string Lastname string + Phone string + Department string IsAdmin string } diff --git a/cmd/wg-portal/common/oauth.go b/cmd/wg-portal/common/oauth.go index ede5389..2f4a681 100644 --- a/cmd/wg-portal/common/oauth.go +++ b/cmd/wg-portal/common/oauth.go @@ -3,13 +3,15 @@ import ( "context" "encoding/json" + "fmt" "io/ioutil" "net/http" + "strconv" "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/h44z/wg-portal/internal/persistence" "github.com/pkg/errors" - "golang.org/x/oauth2" ) @@ -21,6 +23,13 @@ ) type AuthenticatorUserInfo struct { + Identifier persistence.UserIdentifier + Email string + Firstname string + Lastname string + Phone string + Department string + IsAdmin bool } type Authenticator interface { @@ -36,7 +45,7 @@ cfg *oauth2.Config userInfoEndpoint string client *http.Client - userInfoMapping map[string]string + userInfoMapping OauthFields } func NewPlainOauthAuthenticator(_ context.Context, callbackUrl string, cfg *OAuthProvider) (*plainOauthAuthenticator, error) { @@ -58,6 +67,7 @@ Scopes: cfg.Scopes, } authenticator.userInfoEndpoint = cfg.UserInfoURL + authenticator.userInfoMapping = getOauthFieldMapping(cfg.FieldMap) return authenticator, nil } @@ -102,7 +112,18 @@ } func (p plainOauthAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { - return nil, nil // TODO: implement + isAdmin, _ := strconv.ParseBool(mapDefaultString(raw, p.userInfoMapping.IsAdmin, "")) + userInfo := &AuthenticatorUserInfo{ + Identifier: persistence.UserIdentifier(mapDefaultString(raw, p.userInfoMapping.UserIdentifier, "")), + Email: mapDefaultString(raw, p.userInfoMapping.Email, ""), + Firstname: mapDefaultString(raw, p.userInfoMapping.Firstname, ""), + Lastname: mapDefaultString(raw, p.userInfoMapping.Lastname, ""), + Phone: mapDefaultString(raw, p.userInfoMapping.Phone, ""), + Department: mapDefaultString(raw, p.userInfoMapping.Department, ""), + IsAdmin: isAdmin, + } + + return userInfo, nil } type oidcAuthenticator struct { @@ -110,7 +131,7 @@ provider *oidc.Provider verifier *oidc.IDTokenVerifier cfg *oauth2.Config - userInfoMapping map[string]string + userInfoMapping OauthFields } func NewOidcAuthenticator(ctx context.Context, callbackUrl string, cfg *OpenIDConnectProvider) (*oidcAuthenticator, error) { @@ -135,6 +156,7 @@ RedirectURL: callbackUrl, Scopes: scopes, } + authenticator.userInfoMapping = getOauthFieldMapping(cfg.FieldMap) return authenticator, nil } @@ -173,5 +195,59 @@ } func (o oidcAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { - return nil, nil // TODO: implement + isAdmin, _ := strconv.ParseBool(mapDefaultString(raw, o.userInfoMapping.IsAdmin, "")) + userInfo := &AuthenticatorUserInfo{ + Identifier: persistence.UserIdentifier(mapDefaultString(raw, o.userInfoMapping.UserIdentifier, "")), + Email: mapDefaultString(raw, o.userInfoMapping.Email, ""), + Firstname: mapDefaultString(raw, o.userInfoMapping.Firstname, ""), + Lastname: mapDefaultString(raw, o.userInfoMapping.Lastname, ""), + Phone: mapDefaultString(raw, o.userInfoMapping.Phone, ""), + Department: mapDefaultString(raw, o.userInfoMapping.Department, ""), + IsAdmin: isAdmin, + } + + return userInfo, nil +} + +func getOauthFieldMapping(f OauthFields) OauthFields { + defaultMap := OauthFields{ + UserIdentifier: "sub", + Email: "email", + Firstname: "given_name", + Lastname: "family_name", + Phone: "phone", + Department: "department", + IsAdmin: "admin_flag", + } + switch { + case f.UserIdentifier != "": + defaultMap.UserIdentifier = f.UserIdentifier + case f.Email != "": + defaultMap.Email = f.Email + case f.Firstname != "": + defaultMap.Firstname = f.Firstname + case f.Lastname != "": + defaultMap.Lastname = f.Lastname + case f.Phone != "": + defaultMap.Phone = f.Phone + case f.Department != "": + defaultMap.Department = f.Department + case f.IsAdmin != "": + defaultMap.IsAdmin = f.IsAdmin + } + + return defaultMap +} + +func mapDefaultString(m map[string]interface{}, key string, dflt string) string { + if tmp, ok := m[key]; !ok { + return dflt + } else { + switch v := tmp.(type) { + case string: + return v + default: + return fmt.Sprintf("%v", v) + } + } } diff --git a/cmd/wg-portal/common/session.go b/cmd/wg-portal/common/session.go deleted file mode 100644 index 22e0162..0000000 --- a/cmd/wg-portal/common/session.go +++ /dev/null @@ -1,39 +0,0 @@ -package common - -import ( - "encoding/gob" - - "github.com/h44z/wg-portal/internal/persistence" -) - -func init() { - gob.Register(SessionData{}) - gob.Register(FlashData{}) -} - -type SessionData struct { - AuthBackend string - OauthState string // oauth state - OidcNonce string // oidc id token nonce - LoggedIn bool - IsAdmin bool - UserIdentifier persistence.UserIdentifier - Firstname string - Lastname string - Email string - InterfaceIdentifier persistence.InterfaceIdentifier - - SortedBy map[string]string - SortDirection map[string]string - Search map[string]string - - AlertData string - AlertType string - FormData interface{} -} - -type FlashData struct { - HasAlert bool - Message string - Type string -} diff --git a/cmd/wg-portal/server.go b/cmd/wg-portal/server.go index 76dcf02..c3d176f 100644 --- a/cmd/wg-portal/server.go +++ b/cmd/wg-portal/server.go @@ -52,7 +52,7 @@ return nil, errors.WithMessagef(err, "backend failed to initialize") } - // Web Handler + // Web handler err = s.setupGin() if err != nil { return nil, errors.WithMessagef(err, "backend failed to initialize") @@ -75,7 +75,7 @@ } func (s *server) setupGin() error { - // Web Handler + // Web handler gin.SetMode(gin.ReleaseMode) gin.DefaultWriter = ioutil.Discard s.server = gin.New() diff --git a/cmd/wg-portal/ui/handler.go b/cmd/wg-portal/ui/handler.go index d497a37..9a0ed86 100644 --- a/cmd/wg-portal/ui/handler.go +++ b/cmd/wg-portal/ui/handler.go @@ -7,26 +7,26 @@ "strings" "time" - "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/h44z/wg-portal/cmd/wg-portal/common" "github.com/h44z/wg-portal/internal/portal" "github.com/pkg/errors" - "github.com/sirupsen/logrus" csrf "github.com/utrack/gin-csrf" ) -type Handler struct { +type handler struct { config *common.Config + session SessionStore backend portal.Backend oauthAuthenticators map[string]common.Authenticator } -func NewHandler(config *common.Config, backend portal.Backend) (*Handler, error) { - h := &Handler{ +func NewHandler(config *common.Config, backend portal.Backend) (*handler, error) { + h := &handler{ config: config, backend: backend, + session: GinSessionStore{sessionIdentifier: "wgPortalSession"}, oauthAuthenticators: make(map[string]common.Authenticator), } @@ -40,7 +40,7 @@ return h, nil } -func (h *Handler) setupAuthProviders(ctx context.Context) error { +func (h *handler) setupAuthProviders(ctx context.Context) error { extUrl, err := url.Parse(h.config.Core.ExternalUrl) if err != nil { return errors.WithMessage(err, "failed to parse external url") @@ -84,7 +84,7 @@ return nil } -func (h *Handler) RegisterRoutes(g *gin.Engine) { +func (h *handler) RegisterRoutes(g *gin.Engine) { csrfMiddleware := csrf.Middleware(csrf.Options{ Secret: h.config.Core.SessionSecret, ErrorFunc: func(c *gin.Context) { @@ -94,15 +94,15 @@ }) // Entrypoint - g.GET("/", h.GetIndex) + g.GET("/", h.GetIndex()) // Auth routes auth := g.Group("/auth") auth.Use(csrfMiddleware) - auth.GET("/login", h.GetLogin) - auth.POST("/login", h.PostLogin) - auth.GET("/login/:provider", h.GetLoginOauth) - auth.GET("/login/:provider/callback", h.GetLoginOauthCallback) + auth.GET("/login", h.handleLoginGet()) + auth.POST("/login", h.handleLoginPost()) + auth.GET("/login/:provider", h.handleLoginGetOauth()) + auth.GET("/login/:provider/callback", h.handleLoginGetOauthCallback()) //auth.GET("/logout", s.GetLogout) // Admin routes @@ -114,8 +114,6 @@ // -- // -const SessionIdentifier = "wgPortalSession" - type StaticData struct { WebsiteTitle string WebsiteLogo string @@ -123,67 +121,3 @@ Year int Version string } - -func GetSessionData(c *gin.Context) common.SessionData { - session := sessions.Default(c) - rawSessionData := session.Get(SessionIdentifier) - - var sessionData common.SessionData - if rawSessionData != nil { - sessionData = rawSessionData.(common.SessionData) - } else { - // init a new default session - sessionData = common.SessionData{ - Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, - SortedBy: map[string]string{"peers": "handshake", "userpeers": "id", "users": "email"}, - SortDirection: map[string]string{"peers": "desc", "userpeers": "asc", "users": "asc"}, - Email: "", - Firstname: "", - Lastname: "", - InterfaceIdentifier: "", - IsAdmin: false, - LoggedIn: false, - } - session.Set(SessionIdentifier, sessionData) - if err := session.Save(); err != nil { - logrus.Errorf("failed to store session: %v", err) - } - } - - return sessionData -} - -func GetFlashes(c *gin.Context) []common.FlashData { - session := sessions.Default(c) - flashes := session.Flashes() - if err := session.Save(); err != nil { - logrus.Errorf("failed to store session after setting flash: %v", err) - } - - flashData := make([]common.FlashData, len(flashes)) - for i := range flashes { - flashData[i] = flashes[i].(common.FlashData) - } - - return flashData -} - -func UpdateSessionData(c *gin.Context, data common.SessionData) error { - session := sessions.Default(c) - session.Set(SessionIdentifier, data) - if err := session.Save(); err != nil { - logrus.Errorf("failed to store session: %v", err) - return errors.Wrap(err, "failed to store session") - } - return nil -} - -func DestroySessionData(c *gin.Context) error { - session := sessions.Default(c) - session.Delete(SessionIdentifier) - if err := session.Save(); err != nil { - logrus.Errorf("failed to destroy session: %v", err) - return errors.Wrap(err, "failed to destroy session") - } - return nil -} diff --git a/cmd/wg-portal/ui/pages_core.go b/cmd/wg-portal/ui/pages_core.go index 4c2da1c..c9131b0 100644 --- a/cmd/wg-portal/ui/pages_core.go +++ b/cmd/wg-portal/ui/pages_core.go @@ -19,7 +19,7 @@ csrf "github.com/utrack/gin-csrf" ) -func (h *Handler) getStaticData() StaticData { +func (h *handler) getStaticData() StaticData { return StaticData{ WebsiteTitle: h.config.Core.Title, WebsiteLogo: h.config.Core.LogoUrl, @@ -29,17 +29,19 @@ } } -func (h *Handler) GetIndex(c *gin.Context) { - currentSession := GetSessionData(c) +func (h *handler) GetIndex() gin.HandlerFunc { + return func(c *gin.Context) { + currentSession := h.session.GetData(c) - c.HTML(http.StatusOK, "index.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": h.getStaticData(), - "Interface": nil, // TODO: load interface specified in the session - "InterfaceNames": map[string]string{"wgX": "wgX descr"}, - }) + c.HTML(http.StatusOK, "index.html", gin.H{ + "Route": c.Request.URL.Path, + "Alerts": h.session.GetFlashes(c), + "Session": currentSession, + "Static": h.getStaticData(), + "Interface": nil, // TODO: load interface specified in the session + "InterfaceNames": map[string]string{"wgX": "wgX descr"}, + }) + } } type LoginProviderInfo struct { @@ -47,168 +49,174 @@ Url string } -func (h *Handler) GetLogin(c *gin.Context) { - currentSession := GetSessionData(c) - if currentSession.LoggedIn { - c.Redirect(http.StatusSeeOther, "/") // already logged in - } - - deepLink := c.DefaultQuery("dl", "") - authError := c.DefaultQuery("err", "") - errMsg := "Unknown error occurred, try again!" - switch authError { - case "missingdata": - errMsg = "Invalid login data retrieved, please fill out all fields and try again!" - case "authfail": - errMsg = "Authentication failed!" - case "loginreq": - errMsg = "Login required!" - } - - authProviders := make([]LoginProviderInfo, 0, len(h.config.Auth.OAuth)+len(h.config.Auth.OpenIDConnect)) - for _, provider := range h.config.Auth.OpenIDConnect { - providerId := strings.ToLower(provider.ProviderName) - providerName := provider.DisplayName - if providerName == "" { - providerName = provider.ProviderName +func (h *handler) handleLoginGet() gin.HandlerFunc { + return func(c *gin.Context) { + currentSession := h.session.GetData(c) + if currentSession.LoggedIn { + c.Redirect(http.StatusSeeOther, "/") // already logged in } - authProviders = append(authProviders, LoginProviderInfo{ - Name: template.HTML(providerName), - Url: "/auth/login/" + providerId, + + deepLink := c.DefaultQuery("dl", "") + authError := c.DefaultQuery("err", "") + errMsg := "Unknown error occurred, try again!" + switch authError { + case "missingdata": + errMsg = "Invalid login data retrieved, please fill out all fields and try again!" + case "authfail": + errMsg = "Authentication failed!" + case "loginreq": + errMsg = "Login required!" + case "tokenexchange": + errMsg = "Invalid OAuth token!" + } + + authProviders := make([]LoginProviderInfo, 0, len(h.config.Auth.OAuth)+len(h.config.Auth.OpenIDConnect)) + for _, provider := range h.config.Auth.OpenIDConnect { + providerId := strings.ToLower(provider.ProviderName) + providerName := provider.DisplayName + if providerName == "" { + providerName = provider.ProviderName + } + authProviders = append(authProviders, LoginProviderInfo{ + Name: template.HTML(providerName), + Url: "/auth/login/" + providerId, + }) + } + for _, provider := range h.config.Auth.OAuth { + providerId := strings.ToLower(provider.ProviderName) + providerName := provider.DisplayName + if providerName == "" { + providerName = provider.ProviderName + } + authProviders = append(authProviders, LoginProviderInfo{ + Name: template.HTML(providerName), + Url: "/auth/login/" + providerId, + }) + } + + c.HTML(http.StatusOK, "login.html", gin.H{ + "HasError": authError != "", + "Message": errMsg, + "DeepLink": deepLink, + "Static": h.getStaticData(), + "Csrf": csrf.GetToken(c), + "LoginProviders": authProviders, }) } - for _, provider := range h.config.Auth.OAuth { - providerId := strings.ToLower(provider.ProviderName) - providerName := provider.DisplayName - if providerName == "" { - providerName = provider.ProviderName +} + +func (h *handler) handleLoginPost() gin.HandlerFunc { + return func(c *gin.Context) { + currentSession := h.session.GetData(c) + if currentSession.LoggedIn { + c.Redirect(http.StatusSeeOther, "/") // already logged in } - authProviders = append(authProviders, LoginProviderInfo{ - Name: template.HTML(providerName), - Url: "/auth/login/" + providerId, - }) - } - c.HTML(http.StatusOK, "login.html", gin.H{ - "HasError": authError != "", - "Message": errMsg, - "DeepLink": deepLink, - "Static": h.getStaticData(), - "Csrf": csrf.GetToken(c), - "LoginProviders": authProviders, - }) + username := strings.ToLower(c.PostForm("username")) + password := c.PostForm("password") + deepLink := c.PostForm("_dl") + + // Validate form input + if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" { + c.Redirect(http.StatusSeeOther, "/auth/login?err=missingdata") + return + } + + // TODO: implement db authentication + /*c.HTML(http.StatusOK, "login.html", gin.H{ + "HasError": authError != "", + "Message": errMsg, + "DeepLink": deepLink, + "Static": h.getStaticData(), + "Csrf": csrf.GetToken(c), + })*/ + + c.Redirect(http.StatusSeeOther, deepLink) + } } -func (h *Handler) PostLogin(c *gin.Context) { - currentSession := GetSessionData(c) - if currentSession.LoggedIn { - c.Redirect(http.StatusSeeOther, "/") // already logged in - } +func (h *handler) handleLoginGetOauth() gin.HandlerFunc { + return func(c *gin.Context) { - username := strings.ToLower(c.PostForm("username")) - password := c.PostForm("password") - deepLink := c.PostForm("_dl") + providerId := c.Param("provider") + if _, ok := h.oauthAuthenticators[providerId]; !ok { + c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") + return + } - // Validate form input - if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" { - c.Redirect(http.StatusSeeOther, "/auth/login?err=missingdata") - return - } + currentSession := h.session.GetData(c) + if currentSession.LoggedIn { + c.Redirect(http.StatusSeeOther, "/") // already logged in + } - // TODO: implement db authentication - /*c.HTML(http.StatusOK, "login.html", gin.H{ - "HasError": authError != "", - "Message": errMsg, - "DeepLink": deepLink, - "Static": h.getStaticData(), - "Csrf": csrf.GetToken(c), - })*/ - - c.Redirect(http.StatusSeeOther, deepLink) -} - -func (h *Handler) GetLoginOauth(c *gin.Context) { - providerId := c.Param("provider") - if _, ok := h.oauthAuthenticators[providerId]; !ok { - c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") - return - } - - currentSession := GetSessionData(c) - if currentSession.LoggedIn { - c.Redirect(http.StatusSeeOther, "/") // already logged in - } - - // Prepare authentication flow, set state cookies - state, err := randString(16) - if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=randsrcunavailable") - return - } - currentSession.OauthState = state - - authenticator := h.oauthAuthenticators[providerId] - - var authCodeUrl string - switch authenticator.GetType() { - case common.AuthenticatorTypeOAuth: - authCodeUrl = authenticator.AuthCodeURL(state) - case common.AuthenticatorTypeOidc: - nonce, err := randString(16) + // Prepare authentication flow, set state cookies + state, err := randString(16) if err != nil { c.Redirect(http.StatusSeeOther, "/auth/login?err=randsrcunavailable") return } - currentSession.OidcNonce = nonce + currentSession.OauthState = state - authCodeUrl = authenticator.AuthCodeURL(state, oidc.Nonce(nonce)) + authenticator := h.oauthAuthenticators[providerId] + + var authCodeUrl string + switch authenticator.GetType() { + case common.AuthenticatorTypeOAuth: + authCodeUrl = authenticator.AuthCodeURL(state) + case common.AuthenticatorTypeOidc: + nonce, err := randString(16) + if err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=randsrcunavailable") + return + } + currentSession.OidcNonce = nonce + + authCodeUrl = authenticator.AuthCodeURL(state, oidc.Nonce(nonce)) + } + + h.session.SetData(c, currentSession) + + c.Redirect(http.StatusFound, authCodeUrl) } - - err = UpdateSessionData(c, currentSession) - if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=sessionerror") - return - } - - c.Redirect(http.StatusFound, authCodeUrl) - } -func (h *Handler) GetLoginOauthCallback(c *gin.Context) { - providerId := c.Param("provider") - if _, ok := h.oauthAuthenticators[providerId]; !ok { - c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") - return +func (h *handler) handleLoginGetOauthCallback() gin.HandlerFunc { + return func(c *gin.Context) { + providerId := c.Param("provider") + if _, ok := h.oauthAuthenticators[providerId]; !ok { + c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") + return + } + + currentSession := h.session.GetData(c) + ctx := c.Request.Context() + + if state := c.Query("state"); state != currentSession.OauthState { + c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidstate") + return + } + + authenticator := h.oauthAuthenticators[providerId] + oauthCode := c.Query("code") + oauth2Token, err := authenticator.Exchange(ctx, oauthCode) + if err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=tokenexchange") + return + } + + rawUserInfo, err := authenticator.GetUserInfo(c.Request.Context(), oauth2Token, currentSession.OidcNonce) + if err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=userinfofetch") + return + } + + userInfo, err := authenticator.ParseUserInfo(rawUserInfo) + + fmt.Println(userInfo) // TODO: implement login/registration process } - - currentSession := GetSessionData(c) - ctx := c.Request.Context() - - if state := c.Query("state"); state != currentSession.OauthState { - c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidstate") - return - } - - authenticator := h.oauthAuthenticators[providerId] - oauthCode := c.Query("code") - oauth2Token, err := authenticator.Exchange(ctx, oauthCode) - if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=tokenexchange") - return - } - - rawUserInfo, err := authenticator.GetUserInfo(c.Request.Context(), oauth2Token, currentSession.OidcNonce) - if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=userinfofetch") - return - } - - userInfo, err := authenticator.ParseUserInfo(rawUserInfo) - - fmt.Println(userInfo) // TODO: implement login/registration process } -func (h *Handler) passwordAuthentication(username, password string) (*persistence.User, error) { +func (h *handler) passwordAuthentication(username, password string) (*persistence.User, error) { err := h.backend.PlaintextAuthentication(persistence.UserIdentifier(username), password) if err != nil { return nil, errors.WithMessage(err, "failed to authenticate") diff --git a/cmd/wg-portal/ui/session.go b/cmd/wg-portal/ui/session.go new file mode 100644 index 0000000..bdd4bf1 --- /dev/null +++ b/cmd/wg-portal/ui/session.go @@ -0,0 +1,135 @@ +package ui + +import ( + "encoding/gob" + "fmt" + + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/h44z/wg-portal/internal/persistence" +) + +func init() { + gob.Register(SessionData{}) + gob.Register(FlashData{}) +} + +type SessionData struct { + AuthBackend string + OauthState string // oauth state + OidcNonce string // oidc id token nonce + LoggedIn bool + IsAdmin bool + UserIdentifier persistence.UserIdentifier + Firstname string + Lastname string + Email string + InterfaceIdentifier persistence.InterfaceIdentifier + + SortedBy map[string]string + SortDirection map[string]string + Search map[string]string + + AlertData string + AlertType string + FormData interface{} +} + +type FlashData struct { + HasAlert bool + Message string + Type string +} + +type SessionStore interface { + GetData(c *gin.Context) SessionData + SetData(c *gin.Context, data SessionData) + + GetFlashes(c *gin.Context) []FlashData + SetFlashes(c *gin.Context, flashes ...FlashData) + + RemoveData(c *gin.Context) + RemoveFlashes(c *gin.Context) +} + +type GinSessionStore struct { + sessionIdentifier string +} + +func (g GinSessionStore) GetData(c *gin.Context) SessionData { + session := sessions.Default(c) + rawSessionData := session.Get(g.sessionIdentifier) + + var sessionData SessionData + if rawSessionData != nil { + sessionData = rawSessionData.(SessionData) + } else { + // init a new default session + sessionData = SessionData{ + Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, + SortedBy: map[string]string{"peers": "handshake", "userpeers": "id", "users": "email"}, + SortDirection: map[string]string{"peers": "desc", "userpeers": "asc", "users": "asc"}, + Email: "", + Firstname: "", + Lastname: "", + InterfaceIdentifier: "", + IsAdmin: false, + LoggedIn: false, + } + session.Set(g.sessionIdentifier, sessionData) + if err := session.Save(); err != nil { + panic(fmt.Sprintf("failed to store session: %v", err)) + } + } + + return sessionData +} + +func (g GinSessionStore) SetData(c *gin.Context, data SessionData) { + session := sessions.Default(c) + session.Set(g.sessionIdentifier, data) + if err := session.Save(); err != nil { + panic(fmt.Sprintf("failed to store session: %v", err)) + } +} + +func (g GinSessionStore) GetFlashes(c *gin.Context) []FlashData { + session := sessions.Default(c) + flashes := session.Flashes() + if err := session.Save(); err != nil { + panic(fmt.Sprintf("failed to store session: %v", err)) + } + + flashData := make([]FlashData, len(flashes)) + for i := range flashes { + flashData[i] = flashes[i].(FlashData) + } + + return flashData +} + +func (g GinSessionStore) SetFlashes(c *gin.Context, flashes ...FlashData) { + session := sessions.Default(c) + for i := range flashes { + session.AddFlash(flashes[i]) + } + if err := session.Save(); err != nil { + panic(fmt.Sprintf("failed to store session: %v", err)) + } +} + +func (g GinSessionStore) RemoveData(c *gin.Context) { + session := sessions.Default(c) + session.Delete(g.sessionIdentifier) + if err := session.Save(); err != nil { + panic(fmt.Sprintf("failed to store session: %v", err)) + } +} + +func (g GinSessionStore) RemoveFlashes(c *gin.Context) { + session := sessions.Default(c) + _ = session.Flashes() // Clear flashes + if err := session.Save(); err != nil { + panic(fmt.Sprintf("failed to store session: %v", err)) + } +}