diff --git a/.gitignore b/.gitignore index 8ab3cac..3435ea1 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ ssh.key .testCoverage.txt wg_portal.db +sqlite.db go.sum swagger.json swagger.yaml diff --git a/cmd/wg-portal/assets/tpl/login.html b/cmd/wg-portal/assets/tpl/login.html index 138d80b..810002c 100644 --- a/cmd/wg-portal/assets/tpl/login.html +++ b/cmd/wg-portal/assets/tpl/login.html @@ -30,7 +30,6 @@
-
@@ -62,11 +61,9 @@
- {{ if eq .HasError true }} -
@@ -76,7 +73,6 @@
- {{template "prt_flashes.html" .}} {{template "prt_footer.html" .}} diff --git a/cmd/wg-portal/common/config.go b/cmd/wg-portal/common/config.go index 6e89b4b..b9b2dd6 100644 --- a/cmd/wg-portal/common/config.go +++ b/cmd/wg-portal/common/config.go @@ -1,18 +1,21 @@ package common import ( + "os" + "github.com/h44z/wg-portal/internal/persistence" "github.com/h44z/wg-portal/internal/portal" + "gopkg.in/yaml.v3" ) type OauthFields struct { - UserIdentifier string - Email string - Firstname string - Lastname string - Phone string - Department string - IsAdmin string + UserIdentifier string `yaml:"user_identifier"` + Email string `yaml:"email"` + Firstname string `yaml:"firstname"` + Lastname string `yaml:"lastname"` + Phone string `yaml:"phone"` + Department string `yaml:"department"` + IsAdmin string `yaml:"is_admin"` } type LdapAuthProvider struct { @@ -20,75 +23,80 @@ type OpenIDConnectProvider struct { // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. - ProviderName string + ProviderName string `yaml:"provider_name"` // DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed. - DisplayName string + DisplayName string `yaml:"display_name"` - BaseUrl string + BaseUrl string `yaml:"base_url"` // ClientID is the application's ID. - ClientID string + ClientID string `yaml:"client_id"` // ClientSecret is the application's secret. - ClientSecret string + ClientSecret string `yaml:"client_secret"` - ExtraScopes []string + ExtraScopes []string `yaml:"extra_scopes"` - FieldMap OauthFields + FieldMap OauthFields `yaml:"field_map"` + + RegistrationEnabled bool `yaml:"registration_enabled"` } type OAuthProvider struct { // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. - ProviderName string + ProviderName string `yaml:"provider_name"` // DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed. - DisplayName string + DisplayName string `yaml:"display_name"` - BaseUrl string + BaseUrl string `yaml:"base_url"` // ClientID is the application's ID. - ClientID string + ClientID string `yaml:"client_id"` // ClientSecret is the application's secret. - ClientSecret string + ClientSecret string `yaml:"client_secret"` - AuthURL string - TokenURL string - UserInfoURL string + AuthURL string `yaml:"auth_url"` + TokenURL string `yaml:"token_url"` + UserInfoURL string `yaml:"user_info_url"` // RedirectURL is the URL to redirect users going through // the OAuth flow, after the resource owner's URLs. - RedirectURL string + RedirectURL string `yaml:"redirect_url"` // Scope specifies optional requested permissions. - Scopes []string + Scopes []string `yaml:"scopes"` // Fielmap contains - FieldMap OauthFields + FieldMap OauthFields `yaml:"field_map"` + + // If RegistrationEnabled is set to true, wg-portal will create new users that do not exist in the database. + RegistrationEnabled bool `yaml:"registration_enabled"` } type Config struct { Core struct { - GinDebug bool `yaml:"ginDebug" envconfig:"GIN_DEBUG"` - LogLevel string `yaml:"logLevel" envconfig:"LOG_LEVEL"` + GinDebug bool `yaml:"ginDebug"` + LogLevel string `yaml:"logLevel"` - ListeningAddress string `yaml:"listeningAddress" envconfig:"LISTENING_ADDRESS"` - SessionSecret string `yaml:"sessionSecret" envconfig:"SESSION_SECRET"` + ListeningAddress string `yaml:"listeningAddress"` + SessionSecret string `yaml:"sessionSecret"` - ExternalUrl string `yaml:"externalUrl" envconfig:"EXTERNAL_URL"` - Title string `yaml:"title" envconfig:"WEBSITE_TITLE"` - CompanyName string `yaml:"company" envconfig:"COMPANY_NAME"` + ExternalUrl string `yaml:"externalUrl"` + Title string `yaml:"title"` + CompanyName string `yaml:"company"` // TODO: check... - AdminUser string `yaml:"adminUser" envconfig:"ADMIN_USER"` // must be an email address - AdminPassword string `yaml:"adminPass" envconfig:"ADMIN_PASS"` + AdminUser string `yaml:"adminUser"` // must be an email address + AdminPassword string `yaml:"adminPass"` - EditableKeys bool `yaml:"editableKeys" envconfig:"EDITABLE_KEYS"` - CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"` - SelfProvisioningAllowed bool `yaml:"selfProvisioning" envconfig:"SELF_PROVISIONING"` - LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"` - LogoUrl string `yaml:"logoUrl" envconfig:"LOGO_URL"` + EditableKeys bool `yaml:"editableKeys"` + CreateDefaultPeer bool `yaml:"createDefaultPeer"` + SelfProvisioningAllowed bool `yaml:"selfProvisioning"` + LdapEnabled bool `yaml:"ldapEnabled"` + LogoUrl string `yaml:"logoUrl"` } `yaml:"core"` Auth struct { @@ -100,3 +108,19 @@ Mail portal.MailConfig `yaml:"email"` Database persistence.DatabaseConfig `yaml:"database"` } + +func LoadConfigFile(cfg interface{}, filename string) error { + f, err := os.Open(filename) + if err != nil { + return err + } + defer f.Close() + + decoder := yaml.NewDecoder(f) + err = decoder.Decode(cfg) + if err != nil { + return err + } + + return nil +} diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index ef86fea..acdf4da 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -55,17 +55,11 @@ cfg.Core.CompanyName = "Test Company" cfg.Core.LogoUrl = "/img/header-logo.png" - cfg.Auth.OpenIDConnect = []common.OpenIDConnectProvider{ - { - ProviderName: "google", - DisplayName: "Login with
Google", - BaseUrl: "https://accounts.google.com", - ClientID: "XXXX.apps.googleusercontent.com", - ClientSecret: "XXXX", - ExtraScopes: []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}, - }, + err := common.LoadConfigFile(&cfg, "config.yml") + if err != nil { + logrus.Errorf("failed to load config file: %v", err) + return } - // TODO: load config srv, err := NewServer(cfg) if err != nil { diff --git a/cmd/wg-portal/ui/handler.go b/cmd/wg-portal/ui/handler.go index 9a0ed86..5f300ef 100644 --- a/cmd/wg-portal/ui/handler.go +++ b/cmd/wg-portal/ui/handler.go @@ -103,7 +103,7 @@ auth.POST("/login", h.handleLoginPost()) auth.GET("/login/:provider", h.handleLoginGetOauth()) auth.GET("/login/:provider/callback", h.handleLoginGetOauthCallback()) - //auth.GET("/logout", s.GetLogout) + auth.GET("/logout", h.handleLogoutGet()) // Admin routes diff --git a/cmd/wg-portal/ui/pages_core.go b/cmd/wg-portal/ui/pages_core.go index c9131b0..38067ea 100644 --- a/cmd/wg-portal/ui/pages_core.go +++ b/cmd/wg-portal/ui/pages_core.go @@ -3,7 +3,6 @@ import ( "crypto/rand" "encoding/base64" - "fmt" "html/template" "io" "net/http" @@ -49,6 +48,20 @@ Url string } +func (h *handler) handleLogoutGet() gin.HandlerFunc { + return func(c *gin.Context) { + currentSession := h.session.GetData(c) + + if !currentSession.LoggedIn { // Not logged in + c.Redirect(http.StatusSeeOther, "/") + return + } + + h.session.DestroyData(c) + c.Redirect(http.StatusSeeOther, "/") + } +} + func (h *handler) handleLoginGet() gin.HandlerFunc { return func(c *gin.Context) { currentSession := h.session.GetData(c) @@ -56,20 +69,6 @@ c.Redirect(http.StatusSeeOther, "/") // already logged in } - deepLink := c.DefaultQuery("dl", "") - authError := c.DefaultQuery("err", "") - errMsg := "Unknown error occurred, try again!" - switch authError { - case "missingdata": - errMsg = "Invalid login data retrieved, please fill out all fields and try again!" - case "authfail": - errMsg = "Authentication failed!" - case "loginreq": - errMsg = "Login required!" - case "tokenexchange": - errMsg = "Invalid OAuth token!" - } - authProviders := make([]LoginProviderInfo, 0, len(h.config.Auth.OAuth)+len(h.config.Auth.OpenIDConnect)) for _, provider := range h.config.Auth.OpenIDConnect { providerId := strings.ToLower(provider.ProviderName) @@ -95,9 +94,7 @@ } c.HTML(http.StatusOK, "login.html", gin.H{ - "HasError": authError != "", - "Message": errMsg, - "DeepLink": deepLink, + "Alerts": h.session.GetFlashes(c), "Static": h.getStaticData(), "Csrf": csrf.GetToken(c), "LoginProviders": authProviders, @@ -114,7 +111,6 @@ username := strings.ToLower(c.PostForm("username")) password := c.PostForm("password") - deepLink := c.PostForm("_dl") // Validate form input if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" { @@ -131,7 +127,12 @@ "Csrf": csrf.GetToken(c), })*/ - c.Redirect(http.StatusSeeOther, deepLink) + nextUrl := "/" + if currentSession.DeeplLink != "" { + nextUrl = currentSession.DeeplLink + } + + c.Redirect(http.StatusSeeOther, nextUrl) } } @@ -140,7 +141,7 @@ providerId := c.Param("provider") if _, ok := h.oauthAuthenticators[providerId]; !ok { - c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") + h.redirectWithFlash(c, "/auth/login", FlashData{Message: "Invalid login provider", Type: "danger"}) return } @@ -152,7 +153,7 @@ // Prepare authentication flow, set state cookies state, err := randString(16) if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=randsrcunavailable") + h.redirectWithFlash(c, "/auth/login", FlashData{Message: err.Error(), Type: "danger"}) return } currentSession.OauthState = state @@ -166,7 +167,7 @@ case common.AuthenticatorTypeOidc: nonce, err := randString(16) if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=randsrcunavailable") + h.redirectWithFlash(c, "/auth/login", FlashData{Message: err.Error(), Type: "danger"}) return } currentSession.OidcNonce = nonce @@ -184,7 +185,7 @@ return func(c *gin.Context) { providerId := c.Param("provider") if _, ok := h.oauthAuthenticators[providerId]; !ok { - c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") + h.redirectWithFlash(c, "/auth/login", FlashData{Message: "Invalid login provider", Type: "danger"}) return } @@ -192,7 +193,7 @@ ctx := c.Request.Context() if state := c.Query("state"); state != currentSession.OauthState { - c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidstate") + h.redirectWithFlash(c, "/auth/login", FlashData{Message: "Invalid OAuth state", Type: "danger"}) return } @@ -200,19 +201,36 @@ oauthCode := c.Query("code") oauth2Token, err := authenticator.Exchange(ctx, oauthCode) if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=tokenexchange") + h.redirectWithFlash(c, "/auth/login", FlashData{Message: err.Error(), Type: "danger"}) return } rawUserInfo, err := authenticator.GetUserInfo(c.Request.Context(), oauth2Token, currentSession.OidcNonce) if err != nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=userinfofetch") + h.redirectWithFlash(c, "/auth/login", FlashData{Message: err.Error(), Type: "danger"}) return } userInfo, err := authenticator.ParseUserInfo(rawUserInfo) + if err != nil { + h.redirectWithFlash(c, "/auth/login", FlashData{Message: err.Error(), Type: "danger"}) + return + } - fmt.Println(userInfo) // TODO: implement login/registration process + sessionData, err := h.prepareUserSession(userInfo, providerId) + if err != nil { + h.redirectWithFlash(c, "/auth/login", FlashData{Message: err.Error(), Type: "danger"}) + return + } + + h.session.SetData(c, sessionData) + + nextUrl := "/" + if currentSession.DeeplLink != "" { + nextUrl = currentSession.DeeplLink + } + + c.Redirect(http.StatusSeeOther, nextUrl) } } @@ -226,6 +244,80 @@ return nil, nil } +func (h *handler) getAuthenticatorConfig(id string) (interface{}, error) { + for i := range h.config.Auth.OpenIDConnect { + if h.config.Auth.OpenIDConnect[i].ProviderName == id { + return h.config.Auth.OpenIDConnect[i], nil + } + } + + for i := range h.config.Auth.OAuth { + if h.config.Auth.OAuth[i].ProviderName == id { + return h.config.Auth.OAuth[i], nil + } + } + + return nil, errors.Errorf("no configuration for authenticator id %s", id) +} + +func (h *handler) prepareUserSession(userInfo *common.AuthenticatorUserInfo, providerId string) (SessionData, error) { + session := h.session.DefaultSessionData() + authenticatorCfg, err := h.getAuthenticatorConfig(providerId) + if err != nil { + return session, errors.WithMessagef(err, "failed to find auth provider config for %s", providerId) + } + registrationEnabled := false + switch cfg := authenticatorCfg.(type) { + case common.OAuthProvider: + registrationEnabled = cfg.RegistrationEnabled + case common.OpenIDConnectProvider: + registrationEnabled = cfg.RegistrationEnabled + } + + // Search user in backend + user, err := h.backend.GetUser(userInfo.Identifier) + switch { + case err != nil && registrationEnabled: + user, err = h.registerOauthUser(userInfo) + if err != nil { + return session, errors.WithMessage(err, "failed to register user") + } + case err != nil: + return session, errors.WithMessage(err, "registration disabled, cannot create missing user") + } + + // Set session data for user + session.LoggedIn = true + session.UserIdentifier = user.Identifier + session.IsAdmin = user.IsAdmin + session.Firstname = user.Firstname + session.Lastname = user.Lastname + session.Email = user.Email + + return session, nil +} + +func (h *handler) registerOauthUser(userInfo *common.AuthenticatorUserInfo) (*persistence.User, error) { + user := &persistence.User{ + Identifier: userInfo.Identifier, + Email: userInfo.Email, + Source: persistence.UserSourceOauth, + IsAdmin: userInfo.IsAdmin, + Firstname: userInfo.Firstname, + Lastname: userInfo.Lastname, + Phone: userInfo.Phone, + Department: userInfo.Department, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + err := h.backend.CreateUser(user) + if err != nil { + return nil, errors.WithMessage(err, "failed to create new user") + } + + return user, nil +} + func randString(nByte int) (string, error) { b := make([]byte, nByte) if _, err := io.ReadFull(rand.Reader, b); err != nil { @@ -233,3 +325,8 @@ } return base64.RawURLEncoding.EncodeToString(b), nil } + +func (h *handler) redirectWithFlash(c *gin.Context, url string, flash FlashData) { + h.session.SetFlashes(c, flash) + c.Redirect(http.StatusSeeOther, url) +} diff --git a/cmd/wg-portal/ui/session.go b/cmd/wg-portal/ui/session.go index bdd4bf1..ac1966c 100644 --- a/cmd/wg-portal/ui/session.go +++ b/cmd/wg-portal/ui/session.go @@ -15,15 +15,18 @@ } type SessionData struct { - AuthBackend string - OauthState string // oauth state - OidcNonce string // oidc id token nonce - LoggedIn bool - IsAdmin bool - UserIdentifier persistence.UserIdentifier - Firstname string - Lastname string - Email string + DeeplLink string // deep link, used to redirect after a successful login + + OauthState string // oauth state + OidcNonce string // oidc id token nonce + + LoggedIn bool + IsAdmin bool + UserIdentifier persistence.UserIdentifier + Firstname string + Lastname string + Email string + InterfaceIdentifier persistence.InterfaceIdentifier SortedBy map[string]string @@ -36,19 +39,20 @@ } type FlashData struct { - HasAlert bool - Message string - Type string + Message string + Type string // flash type, for example: danger, success, warning, info, primary } type SessionStore interface { + DefaultSessionData() SessionData + GetData(c *gin.Context) SessionData SetData(c *gin.Context, data SessionData) GetFlashes(c *gin.Context) []FlashData SetFlashes(c *gin.Context, flashes ...FlashData) - RemoveData(c *gin.Context) + DestroyData(c *gin.Context) RemoveFlashes(c *gin.Context) } @@ -65,17 +69,7 @@ sessionData = rawSessionData.(SessionData) } else { // init a new default session - sessionData = SessionData{ - Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, - SortedBy: map[string]string{"peers": "handshake", "userpeers": "id", "users": "email"}, - SortDirection: map[string]string{"peers": "desc", "userpeers": "asc", "users": "asc"}, - Email: "", - Firstname: "", - Lastname: "", - InterfaceIdentifier: "", - IsAdmin: false, - LoggedIn: false, - } + sessionData = g.DefaultSessionData() session.Set(g.sessionIdentifier, sessionData) if err := session.Save(); err != nil { panic(fmt.Sprintf("failed to store session: %v", err)) @@ -85,6 +79,20 @@ return sessionData } +func (g GinSessionStore) DefaultSessionData() SessionData { + return SessionData{ + Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, + SortedBy: map[string]string{"peers": "handshake", "userpeers": "id", "users": "email"}, + SortDirection: map[string]string{"peers": "desc", "userpeers": "asc", "users": "asc"}, + Email: "", + Firstname: "", + Lastname: "", + InterfaceIdentifier: "", + IsAdmin: false, + LoggedIn: false, + } +} + func (g GinSessionStore) SetData(c *gin.Context, data SessionData) { session := sessions.Default(c) session.Set(g.sessionIdentifier, data) @@ -118,7 +126,7 @@ } } -func (g GinSessionStore) RemoveData(c *gin.Context) { +func (g GinSessionStore) DestroyData(c *gin.Context) { session := sessions.Default(c) session.Delete(g.sessionIdentifier) if err := session.Save(); err != nil { diff --git a/internal/persistence/models.go b/internal/persistence/models.go index 2de00af..c892b3f 100644 --- a/internal/persistence/models.go +++ b/internal/persistence/models.go @@ -125,9 +125,9 @@ type UserSource string const ( - UserSourceLdap UserSource = "ldap" // LDAP / ActiveDirectory - UserSourceDatabase UserSource = "db" // sqlite / mysql database - UserSourceOIDC UserSource = "oidc" // open id connect, TODO: implement + UserSourceLdap UserSource = "ldap" // LDAP / ActiveDirectory + UserSourceDatabase UserSource = "db" // sqlite / mysql database + UserSourceOauth UserSource = "oauth" // oauth / open id connect ) type PrivateString string diff --git a/internal/portal/backend.go b/internal/portal/backend.go index bc4f1a9..c8d03b7 100644 --- a/internal/portal/backend.go +++ b/internal/portal/backend.go @@ -54,7 +54,7 @@ return b, nil } -// ImportInterface imports an interface. The given interface identifier must be available as importable interface. +// ImportInterfaceById imports an interface. The given interface identifier must be available as importable interface. func (b *PersistentBackend) ImportInterfaceById(identifier persistence.InterfaceIdentifier) error { importable, err := b.GetImportableInterfaces() if err != nil { diff --git a/internal/user/manager.go b/internal/user/manager.go index 68646c1..41793f7 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -9,6 +9,7 @@ ) type Loader interface { + GetActiveUser(id persistence.UserIdentifier) (*persistence.User, error) GetUser(id persistence.UserIdentifier) (*persistence.User, error) GetActiveUsers() ([]*persistence.User, error) GetAllUsers() ([]*persistence.User, error) @@ -56,6 +57,10 @@ users: make(map[persistence.UserIdentifier]*persistence.User), } + if err := mgr.initializeFromStore(); err != nil { + return nil, errors.WithMessage(err, "failed to initialize manager from store") + } + return mgr, nil } @@ -67,6 +72,17 @@ return nil, errors.New("no such user exists") } + return p.users[id], nil +} + +func (p *PersistentManager) GetActiveUser(id persistence.UserIdentifier) (*persistence.User, error) { + p.mux.RLock() + defer p.mux.RUnlock() + + if !p.userExists(id) { + return nil, errors.New("no such user exists") + } + if !p.userIsEnabled(id) { return nil, errors.New("user is disabled") }