diff --git a/cmd/wg-portal/common/config.go b/cmd/wg-portal/common/config.go index bd466aa..ed2edaa 100644 --- a/cmd/wg-portal/common/config.go +++ b/cmd/wg-portal/common/config.go @@ -5,6 +5,14 @@ "github.com/h44z/wg-portal/internal/portal" ) +type OauthFields struct { + UserIdentifier string + Email string + Firstname string + Lastname string + IsAdmin string +} + type OpenIDConnectProvider struct { // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. ProviderName string @@ -21,6 +29,8 @@ ClientSecret string Scopes []string + + FieldMap OauthFields } type OAuthProvider struct { @@ -48,6 +58,9 @@ // Scope specifies optional requested permissions. Scopes []string + + // Fielmap contains + FieldMap OauthFields } type Config struct { diff --git a/cmd/wg-portal/common/session.go b/cmd/wg-portal/common/session.go index f9ce742..999325e 100644 --- a/cmd/wg-portal/common/session.go +++ b/cmd/wg-portal/common/session.go @@ -12,6 +12,8 @@ } type SessionData struct { + OauthState string // oauth state + OidcNonce string // oidc id token nonce LoggedIn bool IsAdmin bool UserIdentifier persistence.UserIdentifier diff --git a/cmd/wg-portal/ui/handler.go b/cmd/wg-portal/ui/handler.go index 15decdb..1486f2c 100644 --- a/cmd/wg-portal/ui/handler.go +++ b/cmd/wg-portal/ui/handler.go @@ -7,7 +7,7 @@ "golang.org/x/oauth2" - "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" "github.com/h44z/wg-portal/cmd/wg-portal/common" @@ -30,7 +30,8 @@ backend portal.Backend authProviderNames map[string]AuthProviderType oidcProviders map[string]*oidc.Provider - oauthConfigs map[string]oauth2.Config + oidcVerifiers map[string]*oidc.IDTokenVerifier + oauthConfigs map[string]*oauth2.Config } func NewHandler(config *common.Config, backend portal.Backend) (*Handler, error) { @@ -39,48 +40,61 @@ backend: backend, authProviderNames: make(map[string]AuthProviderType), oidcProviders: make(map[string]*oidc.Provider), - oauthConfigs: make(map[string]oauth2.Config), + oidcVerifiers: make(map[string]*oidc.IDTokenVerifier), + oauthConfigs: make(map[string]*oauth2.Config), } - extUrl, err := url.Parse(config.Core.ExternalUrl) + err := h.setupAuthProviders() if err != nil { - return nil, errors.WithMessage(err, "failed to parse external url") + return nil, errors.WithMessage(err, "failed to setup authentication providers") + } + + return h, nil +} + +func (h *Handler) setupAuthProviders() error { + extUrl, err := url.Parse(h.config.Core.ExternalUrl) + if err != nil { + return errors.WithMessage(err, "failed to parse external url") } for _, provider := range h.config.Auth.OpenIDConnect { if _, exists := h.authProviderNames[provider.ProviderName]; exists { - return nil, errors.Errorf("auth provider with name %s is already registerd", provider.ProviderName) + return errors.Errorf("auth provider with name %s is already registerd", provider.ProviderName) } h.authProviderNames[provider.ProviderName] = AuthProviderTypeOpenIDConnect var err error h.oidcProviders[provider.ProviderName], err = oidc.NewProvider(context.Background(), provider.BaseUrl) if err != nil { - return nil, errors.WithMessagef(err, "failed to setup oidc provider %s", provider.ProviderName) + return errors.WithMessagef(err, "failed to setup oidc provider %s", provider.ProviderName) } + h.oidcVerifiers[provider.ProviderName] = h.oidcProviders[provider.ProviderName].Verifier(&oidc.Config{ + ClientID: provider.ClientID, + }) - redirecUrl := *extUrl - redirecUrl.Path = path.Join(redirecUrl.Path, "/auth/login/", provider.ProviderName, "/callback") + redirectUrl := *extUrl + redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", provider.ProviderName, "/callback") scopes := []string{oidc.ScopeOpenID} scopes = append(scopes, provider.Scopes...) - h.oauthConfigs[provider.ProviderName] = oauth2.Config{ + h.oauthConfigs[provider.ProviderName] = &oauth2.Config{ ClientID: provider.ClientID, ClientSecret: provider.ClientSecret, Endpoint: h.oidcProviders[provider.ProviderName].Endpoint(), - RedirectURL: redirecUrl.String(), + RedirectURL: redirectUrl.String(), Scopes: scopes, } } for _, provider := range h.config.Auth.OAuth { if _, exists := h.authProviderNames[provider.ProviderName]; exists { - return nil, errors.Errorf("auth provider with name %s is already registerd", provider.ProviderName) + return errors.Errorf("auth provider with name %s is already registerd", provider.ProviderName) } h.authProviderNames[provider.ProviderName] = AuthProviderTypeOAuth // TODO } - return h, nil + return nil } func (h *Handler) RegisterRoutes(g *gin.Engine) { diff --git a/cmd/wg-portal/ui/pages_core.go b/cmd/wg-portal/ui/pages_core.go index 5b4a8d5..6155294 100644 --- a/cmd/wg-portal/ui/pages_core.go +++ b/cmd/wg-portal/ui/pages_core.go @@ -1,11 +1,18 @@ package ui import ( + "crypto/rand" + "encoding/base64" "html/template" + "io" "net/http" "strings" "time" + "github.com/h44z/wg-portal/internal/persistence" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/gin-gonic/gin" "github.com/h44z/wg-portal/internal" csrf "github.com/utrack/gin-csrf" @@ -119,23 +126,102 @@ } func (h *Handler) GetLoginOauth(c *gin.Context) { - currentSession := GetSessionData(c) - if currentSession.LoggedIn { - c.Redirect(http.StatusSeeOther, "/") // already logged in - } - provider := c.Param("provider") if _, ok := h.authProviderNames[provider]; !ok { c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") return } + currentSession := GetSessionData(c) + if currentSession.LoggedIn { + c.Redirect(http.StatusSeeOther, "/") // already logged in + } + + // Prepare authentication flow, set state cookies + state, err := randString(16) + if err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=randsrcunavailable") + return + } + currentSession.OauthState = state + switch h.authProviderNames[provider] { case AuthProviderTypeOAuth: + c.Redirect(http.StatusFound, h.oauthConfigs[provider].AuthCodeURL(state)) + return case AuthProviderTypeOpenIDConnect: + nonce, err := randString(16) + if err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=randsrcunavailable") + return + } + currentSession.OidcNonce = nonce + + c.Redirect(http.StatusFound, h.oauthConfigs[provider].AuthCodeURL(state, oidc.Nonce(nonce))) + return } } func (h *Handler) GetLoginOauthCallback(c *gin.Context) { - //code := c.PostForm("code") + provider := c.Param("provider") + if _, ok := h.authProviderNames[provider]; !ok { + c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidprovider") + return + } + + currentSession := GetSessionData(c) + ctx := c.Request.Context() + + if state := c.Query("state"); state != currentSession.OauthState { + c.Redirect(http.StatusSeeOther, "/auth/login?err=invalidstate") + return + } + + oauth2Token, err := h.oauthConfigs[provider].Exchange(ctx, c.Query("code")) + if err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=tokenexchange") + return + } + + switch h.authProviderNames[provider] { + case AuthProviderTypeOAuth: + // TODO + case AuthProviderTypeOpenIDConnect: + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + c.Redirect(http.StatusSeeOther, "/auth/login?err=missingidtoken") + return + } + idToken, err := h.oidcVerifiers[provider].Verify(ctx, rawIDToken) + if err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=idtokeninvalid") + return + } + if idToken.Nonce != currentSession.OidcNonce { + c.Redirect(http.StatusSeeOther, "/auth/login?err=idtokennonce") + return + } + + // TODO: check if user exists in db, if not, maybe create? (if registration is allowed) + + currentSession.LoggedIn = true + currentSession.UserIdentifier = persistence.UserIdentifier(idToken.Subject) + + var extraFields map[string]interface{} + if err = idToken.Claims(&extraFields); err != nil { + c.Redirect(http.StatusSeeOther, "/auth/login?err=claimsparsing") + return + } + + // TODO: use FieldMap to get extra fields + //currentSession.Email = extraFields[mappedName] + } +} + +func randString(nByte int) (string, error) { + b := make([]byte, nByte) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil } diff --git a/go.mod b/go.mod index 82d7416..7c4e63a 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,12 @@ go 1.16 require ( - github.com/coreos/go-oidc v2.2.1+incompatible + github.com/coreos/go-oidc/v3 v3.1.0 github.com/gin-contrib/sessions v0.0.3 github.com/gin-gonic/gin v1.7.4 github.com/kr/text v0.2.0 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pkg/errors v0.9.1 - github.com/pquerna/cachecontrol v0.1.0 // indirect github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.7.0 github.com/toorop/gin-logrus v0.0.0-20210225092905-2c785434f26f @@ -17,7 +16,7 @@ github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca github.com/vishvananda/netlink v1.1.0 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/oauth2 v0.0.0-20211005180243-6b3c2da341f1 // indirect + golang.org/x/oauth2 v0.0.0-20211005180243-6b3c2da341f1 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210506160403-92e472f520a5 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect