diff --git a/cmd/cli/main.go b/cmd/cli/main.go index b3e370c..ba7192b 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -6,9 +6,9 @@ "os" "strings" + "github.com/h44z/wg-portal/internal/core" "github.com/h44z/wg-portal/internal/persistence" - "github.com/h44z/wg-portal/internal/portal" "github.com/pkg/errors" "github.com/urfave/cli/v2" @@ -19,7 +19,7 @@ interfaceFlag = "interface" ) -var backend portal.Backend +var backend core.Backend var globalFlags = []cli.Flag{ &cli.StringFlag{ @@ -167,7 +167,7 @@ return errors.WithMessagef(err, "failed to initialize persistent store") } - backend, err = portal.NewPersistentBackend(database) + backend, err = core.NewPersistentBackend(database) if err != nil { return errors.WithMessagef(err, "backend failed to initialize") } diff --git a/cmd/wg-portal/common/config.go b/cmd/wg-portal/common/config.go index 3929549..e3178d8 100644 --- a/cmd/wg-portal/common/config.go +++ b/cmd/wg-portal/common/config.go @@ -3,115 +3,13 @@ import ( "os" - "github.com/pkg/errors" - - "github.com/go-ldap/ldap/v3" + "github.com/h44z/wg-portal/internal/authentication" + "github.com/h44z/wg-portal/internal/core" "github.com/h44z/wg-portal/internal/persistence" - "github.com/h44z/wg-portal/internal/portal" + "github.com/pkg/errors" "gopkg.in/yaml.v3" ) -type BaseFields struct { - UserIdentifier string `yaml:"user_identifier"` - Email string `yaml:"email"` - Firstname string `yaml:"firstname"` - Lastname string `yaml:"lastname"` - Phone string `yaml:"phone"` - Department string `yaml:"department"` -} - -type OauthFields struct { - BaseFields `yaml:",inline"` - IsAdmin string `yaml:"is_admin"` -} - -type LdapFields struct { - BaseFields `yaml:",inline"` - GroupMembership string `yaml:"memberof"` -} - -type LdapProvider struct { - URL string `yaml:"url"` - StartTLS bool `yaml:"start_tls"` - CertValidation bool `yaml:"cert_validation"` - BaseDN string `yaml:"base_dn"` - BindUser string `yaml:"bind_user"` - BindPass string `yaml:"bind_pass"` - - FieldMap LdapFields `yaml:"field_map"` - - LoginFilter string `yaml:"login_filter"` // {{login_identifier}} gets replaced with the login email address - AdminGroupDN string `yaml:"admin_group"` // Members of this group receive admin rights in WG-Portal - adminGroupDN *ldap.DN `yaml:"-"` - - Synchronize bool `yaml:"synchronize"` - - // If DeleteMissing is false, missing users will be deactivated - DeleteMissing bool `yaml:"delete_missing"` - SyncFilter string `yaml:"sync_filter"` - - // If RegistrationEnabled is set to true, wg-portal will create new users that do not exist in the database. - RegistrationEnabled bool `yaml:"registration_enabled"` -} - -type OpenIDConnectProvider struct { - // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. - ProviderName string `yaml:"provider_name"` - - // DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed. - DisplayName string `yaml:"display_name"` - - BaseUrl string `yaml:"base_url"` - - // ClientID is the application's ID. - ClientID string `yaml:"client_id"` - - // ClientSecret is the application's secret. - ClientSecret string `yaml:"client_secret"` - - // ExtraScopes specifies optional requested permissions. - ExtraScopes []string `yaml:"extra_scopes"` - - // FieldMap is used to map the names of the user-info endpoint fields to wg-portal fields - FieldMap OauthFields `yaml:"field_map"` - - // If RegistrationEnabled is set to true, missing users will be created in the database - RegistrationEnabled bool `yaml:"registration_enabled"` -} - -type OAuthProvider struct { - // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. - ProviderName string `yaml:"provider_name"` - - // DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed. - DisplayName string `yaml:"display_name"` - - BaseUrl string `yaml:"base_url"` - - // ClientID is the application's ID. - ClientID string `yaml:"client_id"` - - // ClientSecret is the application's secret. - ClientSecret string `yaml:"client_secret"` - - AuthURL string `yaml:"auth_url"` - TokenURL string `yaml:"token_url"` - UserInfoURL string `yaml:"user_info_url"` - - // RedirectURL is the URL to redirect users going through - // the OAuth flow, after the resource owner's URLs. - RedirectURL string `yaml:"redirect_url"` - - // Scope specifies optional requested permissions. - Scopes []string `yaml:"scopes"` - - // FieldMap is used to map the names of the user-info endpoint fields to wg-portal fields - FieldMap OauthFields `yaml:"field_map"` - - // If RegistrationEnabled is set to true, wg-portal will create new users that do not exist in the database. - RegistrationEnabled bool `yaml:"registration_enabled"` -} - type Config struct { Core struct { GinDebug bool `yaml:"gin_debug"` @@ -136,12 +34,12 @@ } `yaml:"core"` Auth struct { - OpenIDConnect []OpenIDConnectProvider `yaml:"oidc"` - OAuth []OAuthProvider `yaml:"oauth"` - Ldap []LdapProvider `yaml:"ldap"` + OpenIDConnect []authentication.OpenIDConnectProvider `yaml:"oidc"` + OAuth []authentication.OAuthProvider `yaml:"oauth"` + Ldap []authentication.LdapProvider `yaml:"ldap"` } `yaml:"auth"` - Mail portal.MailConfig `yaml:"email"` + Mail core.MailConfig `yaml:"email"` Database persistence.DatabaseConfig `yaml:"database"` } diff --git a/cmd/wg-portal/common/ldap.go b/cmd/wg-portal/common/ldap.go deleted file mode 100644 index a3c6494..0000000 --- a/cmd/wg-portal/common/ldap.go +++ /dev/null @@ -1,294 +0,0 @@ -package common - -import ( - "context" - "crypto/tls" - "strings" - - "github.com/pkg/errors" - - "github.com/go-ldap/ldap/v3" - - "github.com/h44z/wg-portal/internal/persistence" - "github.com/h44z/wg-portal/internal/user" -) - -type LdapAuthenticator interface { - user.Authenticator - GetAllUserInfos(ctx context.Context) ([]map[string]interface{}, error) - GetUserInfo(ctx context.Context, username persistence.UserIdentifier) (map[string]interface{}, error) - ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) - RegistrationEnabled() bool - SynchronizationEnabled() bool -} - -type ldapAuthenticator struct { - cfg *LdapProvider -} - -func NewLdapAuthenticator(_ context.Context, cfg *LdapProvider) (*ldapAuthenticator, error) { - var authenticator = &ldapAuthenticator{} - - authenticator.cfg = cfg - - dn, err := ldap.ParseDN(cfg.AdminGroupDN) - if err != nil { - return nil, errors.WithMessage(err, "failed to parse admin group DN") - } - authenticator.cfg.FieldMap = getLdapFieldMapping(cfg.FieldMap) - authenticator.cfg.adminGroupDN = dn - - return authenticator, nil -} - -func (l *ldapAuthenticator) RegistrationEnabled() bool { - return l.cfg.RegistrationEnabled -} - -func (l *ldapAuthenticator) SynchronizationEnabled() bool { - return l.cfg.Synchronize -} - -func (l *ldapAuthenticator) PlaintextAuthentication(userId persistence.UserIdentifier, plainPassword string) error { - conn, err := l.connect() - if err != nil { - return errors.WithMessage(err, "failed to setup connection") - } - defer l.disconnect(conn) - - attrs := []string{"dn"} - - loginFilter := strings.Replace(l.cfg.LoginFilter, "{{login_identifier}}", string(userId), -1) - searchRequest := ldap.NewSearchRequest( - l.cfg.BaseDN, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 20, false, // 20 second time limit - loginFilter, attrs, nil, - ) - - sr, err := conn.Search(searchRequest) - if err != nil { - return errors.Wrapf(err, "failed to search in ldap") - } - - if len(sr.Entries) == 0 { - return errors.New("user not found") - } - - if len(sr.Entries) > 1 { - return errors.New("no unique user found") - } - - // Bind as the user to verify their password - userDN := sr.Entries[0].DN - err = conn.Bind(userDN, plainPassword) - if err != nil { - return errors.Wrapf(err, "invalid credentials") - } - _ = conn.Unbind() - - return nil -} - -func (l *ldapAuthenticator) HashedAuthentication(_ persistence.UserIdentifier, _ string) error { - // TODO: is this possible? - return errors.New("unimplemented") -} - -func (l *ldapAuthenticator) GetUserInfo(_ context.Context, userId persistence.UserIdentifier) (map[string]interface{}, error) { - conn, err := l.connect() - if err != nil { - return nil, errors.WithMessage(err, "failed to setup connection") - } - defer l.disconnect(conn) - - attrs := l.getLdapSearchAttributes() - - loginFilter := strings.Replace(l.cfg.LoginFilter, "{{login_identifier}}", string(userId), -1) - searchRequest := ldap.NewSearchRequest( - l.cfg.BaseDN, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 20, false, // 20 second time limit - loginFilter, attrs, nil, - ) - - sr, err := conn.Search(searchRequest) - if err != nil { - return nil, errors.Wrapf(err, "failed to search in ldap") - } - - if len(sr.Entries) == 0 { - return nil, errors.New("user not found") - } - - if len(sr.Entries) > 1 { - return nil, errors.New("no unique user found") - } - - users := l.convertLdapEntries(sr) - - return users[0], nil -} - -func (l *ldapAuthenticator) GetAllUserInfos(_ context.Context) ([]map[string]interface{}, error) { - conn, err := l.connect() - if err != nil { - return nil, errors.WithMessage(err, "failed to setup connection") - } - defer l.disconnect(conn) - - attrs := l.getLdapSearchAttributes() - - searchRequest := ldap.NewSearchRequest( - l.cfg.BaseDN, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 20, false, // 20 second time limit - l.cfg.SyncFilter, attrs, nil, - ) - - sr, err := conn.Search(searchRequest) - if err != nil { - return nil, errors.Wrapf(err, "failed to search in ldap") - } - - users := l.convertLdapEntries(sr) - - return users, nil -} - -func (l *ldapAuthenticator) convertLdapEntries(sr *ldap.SearchResult) []map[string]interface{} { - users := make([]map[string]interface{}, len(sr.Entries)) - - fieldMap := l.cfg.FieldMap - for i, entry := range sr.Entries { - userData := make(map[string]interface{}) - userData[fieldMap.UserIdentifier] = entry.DN - userData[fieldMap.Email] = entry.GetAttributeValue(fieldMap.Email) - userData[fieldMap.Firstname] = entry.GetAttributeValue(fieldMap.Firstname) - userData[fieldMap.Lastname] = entry.GetAttributeValue(fieldMap.Lastname) - userData[fieldMap.Phone] = entry.GetAttributeValue(fieldMap.Phone) - userData[fieldMap.Department] = entry.GetAttributeValue(fieldMap.Department) - userData[fieldMap.GroupMembership] = entry.GetRawAttributeValues(fieldMap.GroupMembership) - - users[i] = userData - } - return users -} - -func (l *ldapAuthenticator) getLdapSearchAttributes() []string { - fieldMap := l.cfg.FieldMap - attrs := []string{"dn", fieldMap.UserIdentifier} - if fieldMap.Email != "" { - attrs = append(attrs, fieldMap.Email) - } - if fieldMap.Firstname != "" { - attrs = append(attrs, fieldMap.Firstname) - } - if fieldMap.Lastname != "" { - attrs = append(attrs, fieldMap.Lastname) - } - if fieldMap.Phone != "" { - attrs = append(attrs, fieldMap.Phone) - } - if fieldMap.Department != "" { - attrs = append(attrs, fieldMap.Department) - } - if fieldMap.GroupMembership != "" { - attrs = append(attrs, fieldMap.GroupMembership) - } - - return uniqueStringSlice(attrs) -} - -func (l ldapAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { - isAdmin, err := userIsInAdminGroup(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.adminGroupDN) - if err != nil { - return nil, errors.WithMessage(err, "failed to check admin group") - } - userInfo := &AuthenticatorUserInfo{ - Identifier: persistence.UserIdentifier(mapDefaultString(raw, l.cfg.FieldMap.UserIdentifier, "")), - Email: mapDefaultString(raw, l.cfg.FieldMap.Email, ""), - Firstname: mapDefaultString(raw, l.cfg.FieldMap.Firstname, ""), - Lastname: mapDefaultString(raw, l.cfg.FieldMap.Lastname, ""), - Phone: mapDefaultString(raw, l.cfg.FieldMap.Phone, ""), - Department: mapDefaultString(raw, l.cfg.FieldMap.Department, ""), - IsAdmin: isAdmin, - } - - return userInfo, nil -} - -func (l *ldapAuthenticator) connect() (*ldap.Conn, error) { - tlsConfig := &tls.Config{InsecureSkipVerify: !l.cfg.CertValidation} - conn, err := ldap.DialURL(l.cfg.URL, ldap.DialWithTLSConfig(tlsConfig)) - if err != nil { - return nil, errors.Wrap(err, "failed to connect to LDAP") - } - - if l.cfg.StartTLS { // Reconnect with TLS - if err = conn.StartTLS(tlsConfig); err != nil { - return nil, errors.Wrap(err, "failed to start TLS on connection") - } - } - - if err = conn.Bind(l.cfg.BindUser, l.cfg.BindPass); err != nil { - return nil, errors.Wrap(err, "failed to bind to LDAP") - } - - return conn, nil -} - -func (l *ldapAuthenticator) disconnect(conn *ldap.Conn) { - if conn != nil { - conn.Close() - } -} - -// userIsInAdminGroup checks if the groupData array contains the admin group DN -func userIsInAdminGroup(groupData [][]byte, adminGroupDN *ldap.DN) (bool, error) { - for _, group := range groupData { - dn, err := ldap.ParseDN(string(group)) - if err != nil { - return false, errors.WithMessage(err, "failed to parse group DN") - } - if adminGroupDN.Equal(dn) { - return true, nil - } - } - - return false, nil -} - -func getLdapFieldMapping(f LdapFields) LdapFields { - defaultMap := LdapFields{ - BaseFields: BaseFields{ - UserIdentifier: "mail", - Email: "mail", - Firstname: "givenName", - Lastname: "sn", - Phone: "telephoneNumber", - Department: "department", - }, - GroupMembership: "memberOf", - } - if f.UserIdentifier != "" { - defaultMap.UserIdentifier = f.UserIdentifier - } - if f.Email != "" { - defaultMap.Email = f.Email - } - if f.Firstname != "" { - defaultMap.Firstname = f.Firstname - } - if f.Lastname != "" { - defaultMap.Lastname = f.Lastname - } - if f.Phone != "" { - defaultMap.Phone = f.Phone - } - if f.Department != "" { - defaultMap.Department = f.Department - } - if f.GroupMembership != "" { - defaultMap.GroupMembership = f.GroupMembership - } - - return defaultMap -} diff --git a/cmd/wg-portal/common/ldap_test.go b/cmd/wg-portal/common/ldap_test.go deleted file mode 100644 index 213d575..0000000 --- a/cmd/wg-portal/common/ldap_test.go +++ /dev/null @@ -1,95 +0,0 @@ -package common - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/go-ldap/ldap/v3" -) - -func Test_getLdapFieldMapping(t *testing.T) { - defaultFields := LdapFields{ - BaseFields: BaseFields{ - UserIdentifier: "mail", - Email: "mail", - Firstname: "givenName", - Lastname: "sn", - Phone: "telephoneNumber", - Department: "department", - }, - GroupMembership: "memberOf", - } - - got := getLdapFieldMapping(LdapFields{}) - assert.Equal(t, defaultFields, got) - - customFields := LdapFields{ - BaseFields: BaseFields{ - UserIdentifier: "field_uid", - Email: "field_email", - Firstname: "field_fn", - Lastname: "field_ln", - Phone: "field_phone", - Department: "field_dep", - }, - GroupMembership: "field_member", - } - - got = getLdapFieldMapping(customFields) - assert.Equal(t, customFields, got) -} - -func Test_userIsInAdminGroup(t *testing.T) { - adminDN, _ := ldap.ParseDN("CN=admin,OU=groups,DC=TEST,DC=COM") - - tests := []struct { - name string - groupData [][]byte - want bool - wantErr bool - }{ - { - name: "NoGroups", - groupData: nil, - want: false, - wantErr: false, - }, - { - name: "WrongGroups", - groupData: [][]byte{[]byte("cn=wrong,dc=group"), []byte("CN=wrong2,OU=groups,DC=TEST,DC=COM")}, - want: false, - wantErr: false, - }, - { - name: "CorrectGroups", - groupData: [][]byte{[]byte("CN=admin,OU=groups,DC=TEST,DC=COM")}, - want: true, - wantErr: false, - }, - { - name: "CorrectGroupsCase", - groupData: [][]byte{[]byte("cn=admin,OU=groups,dc=TEST,DC=COM")}, - want: true, - wantErr: false, - }, - { - name: "WrongDN", - groupData: [][]byte{[]byte("i_am_invalid")}, - want: false, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := userIsInAdminGroup(tt.groupData, adminDN) - if (err != nil) != tt.wantErr { - t.Errorf("userIsInAdminGroup() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("userIsInAdminGroup() got = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/cmd/wg-portal/common/oauth.go b/cmd/wg-portal/common/oauth.go deleted file mode 100644 index 57a39bc..0000000 --- a/cmd/wg-portal/common/oauth.go +++ /dev/null @@ -1,259 +0,0 @@ -package common - -import ( - "context" - "encoding/json" - "io/ioutil" - "net/http" - "strconv" - "time" - - "github.com/coreos/go-oidc/v3/oidc" - "github.com/h44z/wg-portal/internal/persistence" - "github.com/pkg/errors" - "golang.org/x/oauth2" -) - -type AuthenticatorType string - -const ( - AuthenticatorTypeOAuth AuthenticatorType = "oauth" - AuthenticatorTypeOidc AuthenticatorType = "oidc" -) - -type AuthenticatorUserInfo struct { - Identifier persistence.UserIdentifier - Email string - Firstname string - Lastname string - Phone string - Department string - IsAdmin bool -} - -type Authenticator interface { - GetType() AuthenticatorType - AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string - Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) - GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]interface{}, error) - ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) - RegistrationEnabled() bool -} - -type plainOauthAuthenticator struct { - name string - cfg *oauth2.Config - userInfoEndpoint string - client *http.Client - userInfoMapping OauthFields - registrationEnabled bool -} - -func NewPlainOauthAuthenticator(_ context.Context, callbackUrl string, cfg *OAuthProvider) (*plainOauthAuthenticator, error) { - var authenticator = &plainOauthAuthenticator{} - - authenticator.name = cfg.ProviderName - authenticator.client = &http.Client{ - Timeout: time.Second * 10, - } - authenticator.cfg = &oauth2.Config{ - ClientID: cfg.ClientID, - ClientSecret: cfg.ClientSecret, - Endpoint: oauth2.Endpoint{ - AuthURL: cfg.AuthURL, - TokenURL: cfg.TokenURL, - AuthStyle: oauth2.AuthStyleAutoDetect, - }, - RedirectURL: callbackUrl, - Scopes: cfg.Scopes, - } - authenticator.userInfoEndpoint = cfg.UserInfoURL - authenticator.userInfoMapping = getOauthFieldMapping(cfg.FieldMap) - authenticator.registrationEnabled = cfg.RegistrationEnabled - - return authenticator, nil -} - -func (p plainOauthAuthenticator) RegistrationEnabled() bool { - return p.registrationEnabled -} - -func (p plainOauthAuthenticator) GetType() AuthenticatorType { - return AuthenticatorTypeOAuth -} - -func (p plainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { - return p.cfg.AuthCodeURL(state, opts...) -} - -func (p plainOauthAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - return p.cfg.Exchange(ctx, code, opts...) -} - -func (p plainOauthAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, _ string) (map[string]interface{}, error) { - req, err := http.NewRequest("GET", p.userInfoEndpoint, nil) - if err != nil { - return nil, errors.WithMessage(err, "failed to create user info get request") - } - req.Header.Add("Authorization", "Bearer "+token.AccessToken) - req.WithContext(ctx) - - response, err := p.client.Do(req) - if err != nil { - return nil, errors.WithMessage(err, "failed to get user info") - } - defer response.Body.Close() - contents, err := ioutil.ReadAll(response.Body) - if err != nil { - return nil, errors.WithMessage(err, "failed to read response body") - } - - var userFields map[string]interface{} - err = json.Unmarshal(contents, &userFields) - if err != nil { - return nil, errors.WithMessage(err, "failed to parse user info") - } - - return userFields, nil -} - -func (p plainOauthAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { - isAdmin, _ := strconv.ParseBool(mapDefaultString(raw, p.userInfoMapping.IsAdmin, "")) - userInfo := &AuthenticatorUserInfo{ - Identifier: persistence.UserIdentifier(mapDefaultString(raw, p.userInfoMapping.UserIdentifier, "")), - Email: mapDefaultString(raw, p.userInfoMapping.Email, ""), - Firstname: mapDefaultString(raw, p.userInfoMapping.Firstname, ""), - Lastname: mapDefaultString(raw, p.userInfoMapping.Lastname, ""), - Phone: mapDefaultString(raw, p.userInfoMapping.Phone, ""), - Department: mapDefaultString(raw, p.userInfoMapping.Department, ""), - IsAdmin: isAdmin, - } - - return userInfo, nil -} - -type oidcAuthenticator struct { - name string - provider *oidc.Provider - verifier *oidc.IDTokenVerifier - cfg *oauth2.Config - userInfoMapping OauthFields - registrationEnabled bool -} - -func NewOidcAuthenticator(ctx context.Context, callbackUrl string, cfg *OpenIDConnectProvider) (*oidcAuthenticator, error) { - var err error - var authenticator = &oidcAuthenticator{} - - authenticator.name = cfg.ProviderName - authenticator.provider, err = oidc.NewProvider(ctx, cfg.BaseUrl) - if err != nil { - return nil, errors.WithMessage(err, "failed to create new oidc provider") - } - authenticator.verifier = authenticator.provider.Verifier(&oidc.Config{ - ClientID: cfg.ClientID, - }) - - scopes := []string{oidc.ScopeOpenID} - scopes = append(scopes, cfg.ExtraScopes...) - authenticator.cfg = &oauth2.Config{ - ClientID: cfg.ClientID, - ClientSecret: cfg.ClientSecret, - Endpoint: authenticator.provider.Endpoint(), - RedirectURL: callbackUrl, - Scopes: scopes, - } - authenticator.userInfoMapping = getOauthFieldMapping(cfg.FieldMap) - authenticator.registrationEnabled = cfg.RegistrationEnabled - - return authenticator, nil -} - -func (o oidcAuthenticator) RegistrationEnabled() bool { - return o.registrationEnabled -} - -func (o oidcAuthenticator) GetType() AuthenticatorType { - return AuthenticatorTypeOidc -} - -func (o oidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { - return o.cfg.AuthCodeURL(state, opts...) -} - -func (o oidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { - return o.cfg.Exchange(ctx, code, opts...) -} - -func (o oidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]interface{}, error) { - rawIDToken, ok := token.Extra("id_token").(string) - if !ok { - return nil, errors.New("token does not contain id_token") - } - idToken, err := o.verifier.Verify(ctx, rawIDToken) - if err != nil { - return nil, errors.WithMessage(err, "failed to validate id_token") - } - if idToken.Nonce != nonce { - return nil, errors.New("nonce mismatch") - } - - var tokenFields map[string]interface{} - if err = idToken.Claims(&tokenFields); err != nil { - return nil, errors.WithMessage(err, "failed to parse extra claims") - } - - return tokenFields, nil -} - -func (o oidcAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { - isAdmin, _ := strconv.ParseBool(mapDefaultString(raw, o.userInfoMapping.IsAdmin, "")) - userInfo := &AuthenticatorUserInfo{ - Identifier: persistence.UserIdentifier(mapDefaultString(raw, o.userInfoMapping.UserIdentifier, "")), - Email: mapDefaultString(raw, o.userInfoMapping.Email, ""), - Firstname: mapDefaultString(raw, o.userInfoMapping.Firstname, ""), - Lastname: mapDefaultString(raw, o.userInfoMapping.Lastname, ""), - Phone: mapDefaultString(raw, o.userInfoMapping.Phone, ""), - Department: mapDefaultString(raw, o.userInfoMapping.Department, ""), - IsAdmin: isAdmin, - } - - return userInfo, nil -} - -func getOauthFieldMapping(f OauthFields) OauthFields { - defaultMap := OauthFields{ - BaseFields: BaseFields{ - UserIdentifier: "sub", - Email: "email", - Firstname: "given_name", - Lastname: "family_name", - Phone: "phone", - Department: "department", - }, - IsAdmin: "admin_flag", - } - if f.UserIdentifier != "" { - defaultMap.UserIdentifier = f.UserIdentifier - } - if f.Email != "" { - defaultMap.Email = f.Email - } - if f.Firstname != "" { - defaultMap.Firstname = f.Firstname - } - if f.Lastname != "" { - defaultMap.Lastname = f.Lastname - } - if f.Phone != "" { - defaultMap.Phone = f.Phone - } - if f.Department != "" { - defaultMap.Department = f.Department - } - if f.IsAdmin != "" { - defaultMap.IsAdmin = f.IsAdmin - } - - return defaultMap -} diff --git a/cmd/wg-portal/common/oauth_test.go b/cmd/wg-portal/common/oauth_test.go deleted file mode 100644 index f6e6a3a..0000000 --- a/cmd/wg-portal/common/oauth_test.go +++ /dev/null @@ -1,39 +0,0 @@ -package common - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_getOauthFieldMapping(t *testing.T) { - defaultFields := OauthFields{ - BaseFields: BaseFields{ - UserIdentifier: "sub", - Email: "email", - Firstname: "given_name", - Lastname: "family_name", - Phone: "phone", - Department: "department", - }, - IsAdmin: "admin_flag", - } - - got := getOauthFieldMapping(OauthFields{}) - assert.Equal(t, defaultFields, got) - - customFields := OauthFields{ - BaseFields: BaseFields{ - UserIdentifier: "field_uid", - Email: "field_email", - Firstname: "field_fn", - Lastname: "field_ln", - Phone: "field_phone", - Department: "field_dep", - }, - IsAdmin: "field_admin", - } - - got = getOauthFieldMapping(customFields) - assert.Equal(t, customFields, got) -} diff --git a/cmd/wg-portal/common/utils.go b/cmd/wg-portal/common/utils.go deleted file mode 100644 index 9ce6737..0000000 --- a/cmd/wg-portal/common/utils.go +++ /dev/null @@ -1,35 +0,0 @@ -package common - -import "fmt" - -// mapDefaultString returns the string value for the given key or a default value -func mapDefaultString(m map[string]interface{}, key string, dflt string) string { - if m == nil { - return dflt - } - if tmp, ok := m[key]; !ok { - return dflt - } else { - switch v := tmp.(type) { - case string: - return v - case nil: - return dflt - default: - return fmt.Sprintf("%v", v) - } - } -} - -// uniqueStringSlice removes duplicates in the given string slice -func uniqueStringSlice(slice []string) []string { - keys := make(map[string]struct{}) - uniqueSlice := make([]string, 0, len(slice)) - for _, entry := range slice { - if _, exists := keys[entry]; !exists { - keys[entry] = struct{}{} - uniqueSlice = append(uniqueSlice, entry) - } - } - return uniqueSlice -} diff --git a/cmd/wg-portal/common/utils_test.go b/cmd/wg-portal/common/utils_test.go deleted file mode 100644 index e254288..0000000 --- a/cmd/wg-portal/common/utils_test.go +++ /dev/null @@ -1,108 +0,0 @@ -package common - -import ( - "reflect" - "testing" -) - -func Test_mapDefaultString(t *testing.T) { - type args struct { - m map[string]interface{} - key string - defaultValue string - } - tests := []struct { - name string - args args - want string - }{ - { - name: "match", - args: args{ - m: map[string]interface{}{"hello": "world"}, - key: "hello", - defaultValue: "", - }, - want: "world", - }, { - name: "no_match", - args: args{ - m: map[string]interface{}{"hello": "world"}, - key: "hi", - defaultValue: "", - }, - want: "", - }, { - name: "nil_value", - args: args{ - m: map[string]interface{}{"hello": nil}, - key: "hello", - defaultValue: "", - }, - want: "", - }, { - name: "default_nil_value", - args: args{ - m: map[string]interface{}{"hello": nil}, - key: "hello", - defaultValue: "world", - }, - want: "world", - }, { - name: "nil_map", - args: args{ - m: nil, - key: "hi", - defaultValue: "world", - }, - want: "world", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := mapDefaultString(tt.args.m, tt.args.key, tt.args.defaultValue); got != tt.want { - t.Errorf("mapDefaultString() = %v, want %v", got, tt.want) - } - }) - } - -} - -func Test_uniqueStringSlice(t *testing.T) { - type args struct { - slice []string - } - tests := []struct { - name string - args args - want []string - }{ - { - name: "Empty", - args: args{}, - want: []string{}, - }, - { - name: "Single", - args: args{slice: []string{"1"}}, - want: []string{"1"}, - }, - { - name: "Normal", - args: args{slice: []string{"1", "2", "3"}}, - want: []string{"1", "2", "3"}, - }, - { - name: "Duplicate", - args: args{slice: []string{"1", "2", "2"}}, - want: []string{"1", "2"}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := uniqueStringSlice(tt.args.slice); !reflect.DeepEqual(got, tt.want) { - t.Errorf("UniqueStringSlice() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/cmd/wg-portal/restapi/handler.go b/cmd/wg-portal/restapi/handler.go index 10e26df..9ff239d 100644 --- a/cmd/wg-portal/restapi/handler.go +++ b/cmd/wg-portal/restapi/handler.go @@ -3,16 +3,16 @@ import ( "github.com/gin-gonic/gin" "github.com/h44z/wg-portal/cmd/wg-portal/common" - "github.com/h44z/wg-portal/internal/portal" + "github.com/h44z/wg-portal/internal/core" ) type Handler struct { config *common.Config - backend portal.Backend + backend core.Backend } -func NewHandler(config *common.Config, backend portal.Backend) (*Handler, error) { +func NewHandler(config *common.Config, backend core.Backend) (*Handler, error) { h := &Handler{ config: config, backend: backend, diff --git a/cmd/wg-portal/server.go b/cmd/wg-portal/server.go index ea3e674..20ad143 100644 --- a/cmd/wg-portal/server.go +++ b/cmd/wg-portal/server.go @@ -17,8 +17,8 @@ "github.com/h44z/wg-portal/cmd/wg-portal/common" "github.com/h44z/wg-portal/cmd/wg-portal/restapi" "github.com/h44z/wg-portal/cmd/wg-portal/ui" + "github.com/h44z/wg-portal/internal/core" "github.com/h44z/wg-portal/internal/persistence" - "github.com/h44z/wg-portal/internal/portal" "github.com/pkg/errors" "github.com/sirupsen/logrus" ginlogrus "github.com/toorop/gin-logrus" @@ -32,7 +32,7 @@ config *common.Config server *gin.Engine - backend portal.Backend + backend core.Backend } func NewServer(config *common.Config) (*server, error) { @@ -47,7 +47,7 @@ } // Portal Backend - s.backend, err = portal.NewPersistentBackend(database) + s.backend, err = core.NewPersistentBackend(database) if err != nil { return nil, errors.WithMessagef(err, "backend failed to initialize") } diff --git a/cmd/wg-portal/ui/handler.go b/cmd/wg-portal/ui/handler.go index de99677..1c45358 100644 --- a/cmd/wg-portal/ui/handler.go +++ b/cmd/wg-portal/ui/handler.go @@ -8,11 +8,13 @@ "strings" "time" + "github.com/h44z/wg-portal/internal/authentication" + "github.com/h44z/wg-portal/internal/core" + "github.com/h44z/wg-portal/internal/persistence" "github.com/gin-gonic/gin" "github.com/h44z/wg-portal/cmd/wg-portal/common" - "github.com/h44z/wg-portal/internal/portal" "github.com/pkg/errors" csrf "github.com/utrack/gin-csrf" ) @@ -21,18 +23,18 @@ config *common.Config session SessionStore - backend portal.Backend - oauthAuthenticators map[string]common.Authenticator - ldapAuthenticators map[string]common.LdapAuthenticator + backend core.Backend + oauthAuthenticators map[string]authentication.Authenticator + ldapAuthenticators map[string]authentication.LdapAuthenticator } -func NewHandler(config *common.Config, backend portal.Backend) (*handler, error) { +func NewHandler(config *common.Config, backend core.Backend) (*handler, error) { h := &handler{ config: config, backend: backend, session: GinSessionStore{sessionIdentifier: "wgPortalSession"}, - oauthAuthenticators: make(map[string]common.Authenticator), - ldapAuthenticators: make(map[string]common.LdapAuthenticator), + oauthAuthenticators: make(map[string]authentication.Authenticator), + ldapAuthenticators: make(map[string]authentication.LdapAuthenticator), } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -62,7 +64,7 @@ redirectUrl := *extUrl redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", providerId, "/callback") - authenticator, err := common.NewOidcAuthenticator(ctx, redirectUrl.String(), providerCfg) + authenticator, err := authentication.NewOidcAuthenticator(ctx, redirectUrl.String(), providerCfg) if err != nil { return errors.WithMessagef(err, "failed to setup oidc authentication provider %s", providerCfg.ProviderName) } @@ -79,7 +81,7 @@ redirectUrl := *extUrl redirectUrl.Path = path.Join(redirectUrl.Path, "/auth/login/", providerId, "/callback") - authenticator, err := common.NewPlainOauthAuthenticator(ctx, redirectUrl.String(), providerCfg) + authenticator, err := authentication.NewPlainOauthAuthenticator(ctx, redirectUrl.String(), providerCfg) if err != nil { return errors.WithMessagef(err, "failed to setup oauth authentication provider %s", providerId) } @@ -93,7 +95,7 @@ return errors.Errorf("auth provider with name %s is already registerd", providerId) } - authenticator, err := common.NewLdapAuthenticator(ctx, providerCfg) + authenticator, err := authentication.NewLdapAuthenticator(ctx, providerCfg) if err != nil { return errors.WithMessagef(err, "failed to setup ldap authentication provider %s", providerId) } diff --git a/cmd/wg-portal/ui/pages_core.go b/cmd/wg-portal/ui/pages_core.go index 4733c90..44b2789 100644 --- a/cmd/wg-portal/ui/pages_core.go +++ b/cmd/wg-portal/ui/pages_core.go @@ -10,9 +10,10 @@ "strings" "time" + "github.com/h44z/wg-portal/internal/authentication" + "github.com/coreos/go-oidc/v3/oidc" "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/cmd/wg-portal/common" "github.com/h44z/wg-portal/internal" "github.com/h44z/wg-portal/internal/persistence" "github.com/pkg/errors" @@ -169,9 +170,9 @@ var authCodeUrl string switch authenticator.GetType() { - case common.AuthenticatorTypeOAuth: + case authentication.AuthenticatorTypeOAuth: authCodeUrl = authenticator.AuthCodeURL(state) - case common.AuthenticatorTypeOidc: + case authentication.AuthenticatorTypeOidc: nonce, err := randString(16) if err != nil { h.redirectWithFlash(c, "/auth/login", FlashData{Message: err.Error(), Type: "danger"}) @@ -326,7 +327,7 @@ return nil, errors.Errorf("no configuration for authenticator id %s", id) } -func (h *handler) prepareUserSession(userInfo *common.AuthenticatorUserInfo, providerId string) (SessionData, error) { +func (h *handler) prepareUserSession(userInfo *authentication.AuthenticatorUserInfo, providerId string) (SessionData, error) { session := h.session.DefaultSessionData() authenticatorCfg, err := h.getAuthenticatorConfig(providerId) if err != nil { @@ -334,9 +335,9 @@ } registrationEnabled := false switch cfg := authenticatorCfg.(type) { - case common.OAuthProvider: + case authentication.OAuthProvider: registrationEnabled = cfg.RegistrationEnabled - case common.OpenIDConnectProvider: + case authentication.OpenIDConnectProvider: registrationEnabled = cfg.RegistrationEnabled } @@ -363,7 +364,7 @@ return session, nil } -func (h *handler) registerOauthUser(userInfo *common.AuthenticatorUserInfo) (*persistence.User, error) { +func (h *handler) registerOauthUser(userInfo *authentication.AuthenticatorUserInfo) (*persistence.User, error) { user := &persistence.User{ Identifier: userInfo.Identifier, Email: userInfo.Email, diff --git a/internal/authentication/config.go b/internal/authentication/config.go new file mode 100644 index 0000000..a9b9e34 --- /dev/null +++ b/internal/authentication/config.go @@ -0,0 +1,106 @@ +package authentication + +import ( + "github.com/go-ldap/ldap/v3" +) + +type BaseFields struct { + UserIdentifier string `yaml:"user_identifier"` + Email string `yaml:"email"` + Firstname string `yaml:"firstname"` + Lastname string `yaml:"lastname"` + Phone string `yaml:"phone"` + Department string `yaml:"department"` +} + +type OauthFields struct { + BaseFields `yaml:",inline"` + IsAdmin string `yaml:"is_admin"` +} + +type LdapFields struct { + BaseFields `yaml:",inline"` + GroupMembership string `yaml:"memberof"` +} + +type LdapProvider struct { + URL string `yaml:"url"` + StartTLS bool `yaml:"start_tls"` + CertValidation bool `yaml:"cert_validation"` + BaseDN string `yaml:"base_dn"` + BindUser string `yaml:"bind_user"` + BindPass string `yaml:"bind_pass"` + + FieldMap LdapFields `yaml:"field_map"` + + LoginFilter string `yaml:"login_filter"` // {{login_identifier}} gets replaced with the login email address + AdminGroupDN string `yaml:"admin_group"` // Members of this group receive admin rights in WG-Portal + adminGroupDN *ldap.DN `yaml:"-"` + + Synchronize bool `yaml:"synchronize"` + + // If DeleteMissing is false, missing users will be deactivated + DeleteMissing bool `yaml:"delete_missing"` + SyncFilter string `yaml:"sync_filter"` + + // If RegistrationEnabled is set to true, wg-portal will create new users that do not exist in the database. + RegistrationEnabled bool `yaml:"registration_enabled"` +} + +type OpenIDConnectProvider struct { + // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. + ProviderName string `yaml:"provider_name"` + + // DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed. + DisplayName string `yaml:"display_name"` + + BaseUrl string `yaml:"base_url"` + + // ClientID is the application's ID. + ClientID string `yaml:"client_id"` + + // ClientSecret is the application's secret. + ClientSecret string `yaml:"client_secret"` + + // ExtraScopes specifies optional requested permissions. + ExtraScopes []string `yaml:"extra_scopes"` + + // FieldMap is used to map the names of the user-info endpoint fields to wg-portal fields + FieldMap OauthFields `yaml:"field_map"` + + // If RegistrationEnabled is set to true, missing users will be created in the database + RegistrationEnabled bool `yaml:"registration_enabled"` +} + +type OAuthProvider struct { + // ProviderName is an internal name that is used to distinguish oauth endpoints. It must not contain spaces or special characters. + ProviderName string `yaml:"provider_name"` + + // DisplayName is shown to the user on the login page. If it is empty, ProviderName will be displayed. + DisplayName string `yaml:"display_name"` + + BaseUrl string `yaml:"base_url"` + + // ClientID is the application's ID. + ClientID string `yaml:"client_id"` + + // ClientSecret is the application's secret. + ClientSecret string `yaml:"client_secret"` + + AuthURL string `yaml:"auth_url"` + TokenURL string `yaml:"token_url"` + UserInfoURL string `yaml:"user_info_url"` + + // RedirectURL is the URL to redirect users going through + // the OAuth flow, after the resource owner's URLs. + RedirectURL string `yaml:"redirect_url"` + + // Scope specifies optional requested permissions. + Scopes []string `yaml:"scopes"` + + // FieldMap is used to map the names of the user-info endpoint fields to wg-portal fields + FieldMap OauthFields `yaml:"field_map"` + + // If RegistrationEnabled is set to true, wg-portal will create new users that do not exist in the database. + RegistrationEnabled bool `yaml:"registration_enabled"` +} diff --git a/internal/authentication/ldap.go b/internal/authentication/ldap.go new file mode 100644 index 0000000..8287eaf --- /dev/null +++ b/internal/authentication/ldap.go @@ -0,0 +1,294 @@ +package authentication + +import ( + "context" + "crypto/tls" + "strings" + + "github.com/pkg/errors" + + "github.com/go-ldap/ldap/v3" + + "github.com/h44z/wg-portal/internal/persistence" + "github.com/h44z/wg-portal/internal/user" +) + +type LdapAuthenticator interface { + user.Authenticator + GetAllUserInfos(ctx context.Context) ([]map[string]interface{}, error) + GetUserInfo(ctx context.Context, username persistence.UserIdentifier) (map[string]interface{}, error) + ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) + RegistrationEnabled() bool + SynchronizationEnabled() bool +} + +type ldapAuthenticator struct { + cfg *LdapProvider +} + +func NewLdapAuthenticator(_ context.Context, cfg *LdapProvider) (*ldapAuthenticator, error) { + var authenticator = &ldapAuthenticator{} + + authenticator.cfg = cfg + + dn, err := ldap.ParseDN(cfg.AdminGroupDN) + if err != nil { + return nil, errors.WithMessage(err, "failed to parse admin group DN") + } + authenticator.cfg.FieldMap = getLdapFieldMapping(cfg.FieldMap) + authenticator.cfg.adminGroupDN = dn + + return authenticator, nil +} + +func (l *ldapAuthenticator) RegistrationEnabled() bool { + return l.cfg.RegistrationEnabled +} + +func (l *ldapAuthenticator) SynchronizationEnabled() bool { + return l.cfg.Synchronize +} + +func (l *ldapAuthenticator) PlaintextAuthentication(userId persistence.UserIdentifier, plainPassword string) error { + conn, err := l.connect() + if err != nil { + return errors.WithMessage(err, "failed to setup connection") + } + defer l.disconnect(conn) + + attrs := []string{"dn"} + + loginFilter := strings.Replace(l.cfg.LoginFilter, "{{login_identifier}}", string(userId), -1) + searchRequest := ldap.NewSearchRequest( + l.cfg.BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 20, false, // 20 second time limit + loginFilter, attrs, nil, + ) + + sr, err := conn.Search(searchRequest) + if err != nil { + return errors.Wrapf(err, "failed to search in ldap") + } + + if len(sr.Entries) == 0 { + return errors.New("user not found") + } + + if len(sr.Entries) > 1 { + return errors.New("no unique user found") + } + + // Bind as the user to verify their password + userDN := sr.Entries[0].DN + err = conn.Bind(userDN, plainPassword) + if err != nil { + return errors.Wrapf(err, "invalid credentials") + } + _ = conn.Unbind() + + return nil +} + +func (l *ldapAuthenticator) HashedAuthentication(_ persistence.UserIdentifier, _ string) error { + // TODO: is this possible? + return errors.New("unimplemented") +} + +func (l *ldapAuthenticator) GetUserInfo(_ context.Context, userId persistence.UserIdentifier) (map[string]interface{}, error) { + conn, err := l.connect() + if err != nil { + return nil, errors.WithMessage(err, "failed to setup connection") + } + defer l.disconnect(conn) + + attrs := l.getLdapSearchAttributes() + + loginFilter := strings.Replace(l.cfg.LoginFilter, "{{login_identifier}}", string(userId), -1) + searchRequest := ldap.NewSearchRequest( + l.cfg.BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 20, false, // 20 second time limit + loginFilter, attrs, nil, + ) + + sr, err := conn.Search(searchRequest) + if err != nil { + return nil, errors.Wrapf(err, "failed to search in ldap") + } + + if len(sr.Entries) == 0 { + return nil, errors.New("user not found") + } + + if len(sr.Entries) > 1 { + return nil, errors.New("no unique user found") + } + + users := l.convertLdapEntries(sr) + + return users[0], nil +} + +func (l *ldapAuthenticator) GetAllUserInfos(_ context.Context) ([]map[string]interface{}, error) { + conn, err := l.connect() + if err != nil { + return nil, errors.WithMessage(err, "failed to setup connection") + } + defer l.disconnect(conn) + + attrs := l.getLdapSearchAttributes() + + searchRequest := ldap.NewSearchRequest( + l.cfg.BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 20, false, // 20 second time limit + l.cfg.SyncFilter, attrs, nil, + ) + + sr, err := conn.Search(searchRequest) + if err != nil { + return nil, errors.Wrapf(err, "failed to search in ldap") + } + + users := l.convertLdapEntries(sr) + + return users, nil +} + +func (l *ldapAuthenticator) convertLdapEntries(sr *ldap.SearchResult) []map[string]interface{} { + users := make([]map[string]interface{}, len(sr.Entries)) + + fieldMap := l.cfg.FieldMap + for i, entry := range sr.Entries { + userData := make(map[string]interface{}) + userData[fieldMap.UserIdentifier] = entry.DN + userData[fieldMap.Email] = entry.GetAttributeValue(fieldMap.Email) + userData[fieldMap.Firstname] = entry.GetAttributeValue(fieldMap.Firstname) + userData[fieldMap.Lastname] = entry.GetAttributeValue(fieldMap.Lastname) + userData[fieldMap.Phone] = entry.GetAttributeValue(fieldMap.Phone) + userData[fieldMap.Department] = entry.GetAttributeValue(fieldMap.Department) + userData[fieldMap.GroupMembership] = entry.GetRawAttributeValues(fieldMap.GroupMembership) + + users[i] = userData + } + return users +} + +func (l *ldapAuthenticator) getLdapSearchAttributes() []string { + fieldMap := l.cfg.FieldMap + attrs := []string{"dn", fieldMap.UserIdentifier} + if fieldMap.Email != "" { + attrs = append(attrs, fieldMap.Email) + } + if fieldMap.Firstname != "" { + attrs = append(attrs, fieldMap.Firstname) + } + if fieldMap.Lastname != "" { + attrs = append(attrs, fieldMap.Lastname) + } + if fieldMap.Phone != "" { + attrs = append(attrs, fieldMap.Phone) + } + if fieldMap.Department != "" { + attrs = append(attrs, fieldMap.Department) + } + if fieldMap.GroupMembership != "" { + attrs = append(attrs, fieldMap.GroupMembership) + } + + return uniqueStringSlice(attrs) +} + +func (l ldapAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { + isAdmin, err := userIsInAdminGroup(raw[l.cfg.FieldMap.GroupMembership].([][]byte), l.cfg.adminGroupDN) + if err != nil { + return nil, errors.WithMessage(err, "failed to check admin group") + } + userInfo := &AuthenticatorUserInfo{ + Identifier: persistence.UserIdentifier(mapDefaultString(raw, l.cfg.FieldMap.UserIdentifier, "")), + Email: mapDefaultString(raw, l.cfg.FieldMap.Email, ""), + Firstname: mapDefaultString(raw, l.cfg.FieldMap.Firstname, ""), + Lastname: mapDefaultString(raw, l.cfg.FieldMap.Lastname, ""), + Phone: mapDefaultString(raw, l.cfg.FieldMap.Phone, ""), + Department: mapDefaultString(raw, l.cfg.FieldMap.Department, ""), + IsAdmin: isAdmin, + } + + return userInfo, nil +} + +func (l *ldapAuthenticator) connect() (*ldap.Conn, error) { + tlsConfig := &tls.Config{InsecureSkipVerify: !l.cfg.CertValidation} + conn, err := ldap.DialURL(l.cfg.URL, ldap.DialWithTLSConfig(tlsConfig)) + if err != nil { + return nil, errors.Wrap(err, "failed to connect to LDAP") + } + + if l.cfg.StartTLS { // Reconnect with TLS + if err = conn.StartTLS(tlsConfig); err != nil { + return nil, errors.Wrap(err, "failed to start TLS on connection") + } + } + + if err = conn.Bind(l.cfg.BindUser, l.cfg.BindPass); err != nil { + return nil, errors.Wrap(err, "failed to bind to LDAP") + } + + return conn, nil +} + +func (l *ldapAuthenticator) disconnect(conn *ldap.Conn) { + if conn != nil { + conn.Close() + } +} + +// userIsInAdminGroup checks if the groupData array contains the admin group DN +func userIsInAdminGroup(groupData [][]byte, adminGroupDN *ldap.DN) (bool, error) { + for _, group := range groupData { + dn, err := ldap.ParseDN(string(group)) + if err != nil { + return false, errors.WithMessage(err, "failed to parse group DN") + } + if adminGroupDN.Equal(dn) { + return true, nil + } + } + + return false, nil +} + +func getLdapFieldMapping(f LdapFields) LdapFields { + defaultMap := LdapFields{ + BaseFields: BaseFields{ + UserIdentifier: "mail", + Email: "mail", + Firstname: "givenName", + Lastname: "sn", + Phone: "telephoneNumber", + Department: "department", + }, + GroupMembership: "memberOf", + } + if f.UserIdentifier != "" { + defaultMap.UserIdentifier = f.UserIdentifier + } + if f.Email != "" { + defaultMap.Email = f.Email + } + if f.Firstname != "" { + defaultMap.Firstname = f.Firstname + } + if f.Lastname != "" { + defaultMap.Lastname = f.Lastname + } + if f.Phone != "" { + defaultMap.Phone = f.Phone + } + if f.Department != "" { + defaultMap.Department = f.Department + } + if f.GroupMembership != "" { + defaultMap.GroupMembership = f.GroupMembership + } + + return defaultMap +} diff --git a/internal/authentication/ldap_test.go b/internal/authentication/ldap_test.go new file mode 100644 index 0000000..4aa8fd2 --- /dev/null +++ b/internal/authentication/ldap_test.go @@ -0,0 +1,95 @@ +package authentication + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/go-ldap/ldap/v3" +) + +func Test_getLdapFieldMapping(t *testing.T) { + defaultFields := LdapFields{ + BaseFields: BaseFields{ + UserIdentifier: "mail", + Email: "mail", + Firstname: "givenName", + Lastname: "sn", + Phone: "telephoneNumber", + Department: "department", + }, + GroupMembership: "memberOf", + } + + got := getLdapFieldMapping(LdapFields{}) + assert.Equal(t, defaultFields, got) + + customFields := LdapFields{ + BaseFields: BaseFields{ + UserIdentifier: "field_uid", + Email: "field_email", + Firstname: "field_fn", + Lastname: "field_ln", + Phone: "field_phone", + Department: "field_dep", + }, + GroupMembership: "field_member", + } + + got = getLdapFieldMapping(customFields) + assert.Equal(t, customFields, got) +} + +func Test_userIsInAdminGroup(t *testing.T) { + adminDN, _ := ldap.ParseDN("CN=admin,OU=groups,DC=TEST,DC=COM") + + tests := []struct { + name string + groupData [][]byte + want bool + wantErr bool + }{ + { + name: "NoGroups", + groupData: nil, + want: false, + wantErr: false, + }, + { + name: "WrongGroups", + groupData: [][]byte{[]byte("cn=wrong,dc=group"), []byte("CN=wrong2,OU=groups,DC=TEST,DC=COM")}, + want: false, + wantErr: false, + }, + { + name: "CorrectGroups", + groupData: [][]byte{[]byte("CN=admin,OU=groups,DC=TEST,DC=COM")}, + want: true, + wantErr: false, + }, + { + name: "CorrectGroupsCase", + groupData: [][]byte{[]byte("cn=admin,OU=groups,dc=TEST,DC=COM")}, + want: true, + wantErr: false, + }, + { + name: "WrongDN", + groupData: [][]byte{[]byte("i_am_invalid")}, + want: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := userIsInAdminGroup(tt.groupData, adminDN) + if (err != nil) != tt.wantErr { + t.Errorf("userIsInAdminGroup() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("userIsInAdminGroup() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/authentication/oauth.go b/internal/authentication/oauth.go new file mode 100644 index 0000000..79d2e75 --- /dev/null +++ b/internal/authentication/oauth.go @@ -0,0 +1,259 @@ +package authentication + +import ( + "context" + "encoding/json" + "io/ioutil" + "net/http" + "strconv" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/h44z/wg-portal/internal/persistence" + "github.com/pkg/errors" + "golang.org/x/oauth2" +) + +type AuthenticatorType string + +const ( + AuthenticatorTypeOAuth AuthenticatorType = "oauth" + AuthenticatorTypeOidc AuthenticatorType = "oidc" +) + +type AuthenticatorUserInfo struct { + Identifier persistence.UserIdentifier + Email string + Firstname string + Lastname string + Phone string + Department string + IsAdmin bool +} + +type Authenticator interface { + GetType() AuthenticatorType + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]interface{}, error) + ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) + RegistrationEnabled() bool +} + +type plainOauthAuthenticator struct { + name string + cfg *oauth2.Config + userInfoEndpoint string + client *http.Client + userInfoMapping OauthFields + registrationEnabled bool +} + +func NewPlainOauthAuthenticator(_ context.Context, callbackUrl string, cfg *OAuthProvider) (*plainOauthAuthenticator, error) { + var authenticator = &plainOauthAuthenticator{} + + authenticator.name = cfg.ProviderName + authenticator.client = &http.Client{ + Timeout: time.Second * 10, + } + authenticator.cfg = &oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + Endpoint: oauth2.Endpoint{ + AuthURL: cfg.AuthURL, + TokenURL: cfg.TokenURL, + AuthStyle: oauth2.AuthStyleAutoDetect, + }, + RedirectURL: callbackUrl, + Scopes: cfg.Scopes, + } + authenticator.userInfoEndpoint = cfg.UserInfoURL + authenticator.userInfoMapping = getOauthFieldMapping(cfg.FieldMap) + authenticator.registrationEnabled = cfg.RegistrationEnabled + + return authenticator, nil +} + +func (p plainOauthAuthenticator) RegistrationEnabled() bool { + return p.registrationEnabled +} + +func (p plainOauthAuthenticator) GetType() AuthenticatorType { + return AuthenticatorTypeOAuth +} + +func (p plainOauthAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + return p.cfg.AuthCodeURL(state, opts...) +} + +func (p plainOauthAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return p.cfg.Exchange(ctx, code, opts...) +} + +func (p plainOauthAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, _ string) (map[string]interface{}, error) { + req, err := http.NewRequest("GET", p.userInfoEndpoint, nil) + if err != nil { + return nil, errors.WithMessage(err, "failed to create user info get request") + } + req.Header.Add("Authorization", "Bearer "+token.AccessToken) + req.WithContext(ctx) + + response, err := p.client.Do(req) + if err != nil { + return nil, errors.WithMessage(err, "failed to get user info") + } + defer response.Body.Close() + contents, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, errors.WithMessage(err, "failed to read response body") + } + + var userFields map[string]interface{} + err = json.Unmarshal(contents, &userFields) + if err != nil { + return nil, errors.WithMessage(err, "failed to parse user info") + } + + return userFields, nil +} + +func (p plainOauthAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { + isAdmin, _ := strconv.ParseBool(mapDefaultString(raw, p.userInfoMapping.IsAdmin, "")) + userInfo := &AuthenticatorUserInfo{ + Identifier: persistence.UserIdentifier(mapDefaultString(raw, p.userInfoMapping.UserIdentifier, "")), + Email: mapDefaultString(raw, p.userInfoMapping.Email, ""), + Firstname: mapDefaultString(raw, p.userInfoMapping.Firstname, ""), + Lastname: mapDefaultString(raw, p.userInfoMapping.Lastname, ""), + Phone: mapDefaultString(raw, p.userInfoMapping.Phone, ""), + Department: mapDefaultString(raw, p.userInfoMapping.Department, ""), + IsAdmin: isAdmin, + } + + return userInfo, nil +} + +type oidcAuthenticator struct { + name string + provider *oidc.Provider + verifier *oidc.IDTokenVerifier + cfg *oauth2.Config + userInfoMapping OauthFields + registrationEnabled bool +} + +func NewOidcAuthenticator(ctx context.Context, callbackUrl string, cfg *OpenIDConnectProvider) (*oidcAuthenticator, error) { + var err error + var authenticator = &oidcAuthenticator{} + + authenticator.name = cfg.ProviderName + authenticator.provider, err = oidc.NewProvider(ctx, cfg.BaseUrl) + if err != nil { + return nil, errors.WithMessage(err, "failed to create new oidc provider") + } + authenticator.verifier = authenticator.provider.Verifier(&oidc.Config{ + ClientID: cfg.ClientID, + }) + + scopes := []string{oidc.ScopeOpenID} + scopes = append(scopes, cfg.ExtraScopes...) + authenticator.cfg = &oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + Endpoint: authenticator.provider.Endpoint(), + RedirectURL: callbackUrl, + Scopes: scopes, + } + authenticator.userInfoMapping = getOauthFieldMapping(cfg.FieldMap) + authenticator.registrationEnabled = cfg.RegistrationEnabled + + return authenticator, nil +} + +func (o oidcAuthenticator) RegistrationEnabled() bool { + return o.registrationEnabled +} + +func (o oidcAuthenticator) GetType() AuthenticatorType { + return AuthenticatorTypeOidc +} + +func (o oidcAuthenticator) AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string { + return o.cfg.AuthCodeURL(state, opts...) +} + +func (o oidcAuthenticator) Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return o.cfg.Exchange(ctx, code, opts...) +} + +func (o oidcAuthenticator) GetUserInfo(ctx context.Context, token *oauth2.Token, nonce string) (map[string]interface{}, error) { + rawIDToken, ok := token.Extra("id_token").(string) + if !ok { + return nil, errors.New("token does not contain id_token") + } + idToken, err := o.verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, errors.WithMessage(err, "failed to validate id_token") + } + if idToken.Nonce != nonce { + return nil, errors.New("nonce mismatch") + } + + var tokenFields map[string]interface{} + if err = idToken.Claims(&tokenFields); err != nil { + return nil, errors.WithMessage(err, "failed to parse extra claims") + } + + return tokenFields, nil +} + +func (o oidcAuthenticator) ParseUserInfo(raw map[string]interface{}) (*AuthenticatorUserInfo, error) { + isAdmin, _ := strconv.ParseBool(mapDefaultString(raw, o.userInfoMapping.IsAdmin, "")) + userInfo := &AuthenticatorUserInfo{ + Identifier: persistence.UserIdentifier(mapDefaultString(raw, o.userInfoMapping.UserIdentifier, "")), + Email: mapDefaultString(raw, o.userInfoMapping.Email, ""), + Firstname: mapDefaultString(raw, o.userInfoMapping.Firstname, ""), + Lastname: mapDefaultString(raw, o.userInfoMapping.Lastname, ""), + Phone: mapDefaultString(raw, o.userInfoMapping.Phone, ""), + Department: mapDefaultString(raw, o.userInfoMapping.Department, ""), + IsAdmin: isAdmin, + } + + return userInfo, nil +} + +func getOauthFieldMapping(f OauthFields) OauthFields { + defaultMap := OauthFields{ + BaseFields: BaseFields{ + UserIdentifier: "sub", + Email: "email", + Firstname: "given_name", + Lastname: "family_name", + Phone: "phone", + Department: "department", + }, + IsAdmin: "admin_flag", + } + if f.UserIdentifier != "" { + defaultMap.UserIdentifier = f.UserIdentifier + } + if f.Email != "" { + defaultMap.Email = f.Email + } + if f.Firstname != "" { + defaultMap.Firstname = f.Firstname + } + if f.Lastname != "" { + defaultMap.Lastname = f.Lastname + } + if f.Phone != "" { + defaultMap.Phone = f.Phone + } + if f.Department != "" { + defaultMap.Department = f.Department + } + if f.IsAdmin != "" { + defaultMap.IsAdmin = f.IsAdmin + } + + return defaultMap +} diff --git a/internal/authentication/oauth_test.go b/internal/authentication/oauth_test.go new file mode 100644 index 0000000..d2b7b7d --- /dev/null +++ b/internal/authentication/oauth_test.go @@ -0,0 +1,39 @@ +package authentication + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_getOauthFieldMapping(t *testing.T) { + defaultFields := OauthFields{ + BaseFields: BaseFields{ + UserIdentifier: "sub", + Email: "email", + Firstname: "given_name", + Lastname: "family_name", + Phone: "phone", + Department: "department", + }, + IsAdmin: "admin_flag", + } + + got := getOauthFieldMapping(OauthFields{}) + assert.Equal(t, defaultFields, got) + + customFields := OauthFields{ + BaseFields: BaseFields{ + UserIdentifier: "field_uid", + Email: "field_email", + Firstname: "field_fn", + Lastname: "field_ln", + Phone: "field_phone", + Department: "field_dep", + }, + IsAdmin: "field_admin", + } + + got = getOauthFieldMapping(customFields) + assert.Equal(t, customFields, got) +} diff --git a/internal/authentication/utils.go b/internal/authentication/utils.go new file mode 100644 index 0000000..50874df --- /dev/null +++ b/internal/authentication/utils.go @@ -0,0 +1,35 @@ +package authentication + +import "fmt" + +// mapDefaultString returns the string value for the given key or a default value +func mapDefaultString(m map[string]interface{}, key string, dflt string) string { + if m == nil { + return dflt + } + if tmp, ok := m[key]; !ok { + return dflt + } else { + switch v := tmp.(type) { + case string: + return v + case nil: + return dflt + default: + return fmt.Sprintf("%v", v) + } + } +} + +// uniqueStringSlice removes duplicates in the given string slice +func uniqueStringSlice(slice []string) []string { + keys := make(map[string]struct{}) + uniqueSlice := make([]string, 0, len(slice)) + for _, entry := range slice { + if _, exists := keys[entry]; !exists { + keys[entry] = struct{}{} + uniqueSlice = append(uniqueSlice, entry) + } + } + return uniqueSlice +} diff --git a/internal/authentication/utils_test.go b/internal/authentication/utils_test.go new file mode 100644 index 0000000..bca9779 --- /dev/null +++ b/internal/authentication/utils_test.go @@ -0,0 +1,108 @@ +package authentication + +import ( + "reflect" + "testing" +) + +func Test_mapDefaultString(t *testing.T) { + type args struct { + m map[string]interface{} + key string + defaultValue string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "match", + args: args{ + m: map[string]interface{}{"hello": "world"}, + key: "hello", + defaultValue: "", + }, + want: "world", + }, { + name: "no_match", + args: args{ + m: map[string]interface{}{"hello": "world"}, + key: "hi", + defaultValue: "", + }, + want: "", + }, { + name: "nil_value", + args: args{ + m: map[string]interface{}{"hello": nil}, + key: "hello", + defaultValue: "", + }, + want: "", + }, { + name: "default_nil_value", + args: args{ + m: map[string]interface{}{"hello": nil}, + key: "hello", + defaultValue: "world", + }, + want: "world", + }, { + name: "nil_map", + args: args{ + m: nil, + key: "hi", + defaultValue: "world", + }, + want: "world", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := mapDefaultString(tt.args.m, tt.args.key, tt.args.defaultValue); got != tt.want { + t.Errorf("mapDefaultString() = %v, want %v", got, tt.want) + } + }) + } + +} + +func Test_uniqueStringSlice(t *testing.T) { + type args struct { + slice []string + } + tests := []struct { + name string + args args + want []string + }{ + { + name: "Empty", + args: args{}, + want: []string{}, + }, + { + name: "Single", + args: args{slice: []string{"1"}}, + want: []string{"1"}, + }, + { + name: "Normal", + args: args{slice: []string{"1", "2", "3"}}, + want: []string{"1", "2", "3"}, + }, + { + name: "Duplicate", + args: args{slice: []string{"1", "2", "2"}}, + want: []string{"1", "2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := uniqueStringSlice(tt.args.slice); !reflect.DeepEqual(got, tt.want) { + t.Errorf("uniqueStringSlice() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/core/backend.go b/internal/core/backend.go new file mode 100644 index 0000000..ad8140f --- /dev/null +++ b/internal/core/backend.go @@ -0,0 +1,94 @@ +package core + +import ( + "github.com/h44z/wg-portal/internal/lowlevel" + "github.com/h44z/wg-portal/internal/persistence" + "github.com/h44z/wg-portal/internal/user" + "github.com/h44z/wg-portal/internal/wireguard" + "github.com/pkg/errors" + "golang.zx2c4.com/wireguard/wgctrl" +) + +// Backend combines the user manager and WireGuard manager. It also provides some additional functions. +type Backend interface { + user.Manager + wireguard.Manager + + ImportInterfaceById(identifier persistence.InterfaceIdentifier) error + PrepareFreshPeer(identifier persistence.InterfaceIdentifier) (*persistence.PeerConfig, error) + GetPeersForUser(identifier persistence.UserIdentifier) ([]*persistence.PeerConfig, error) +} + +// type alias +type UserManager = user.Manager +type WireGuardManager = wireguard.Manager + +type PersistentBackend struct { + UserManager + WireGuardManager +} + +func NewPersistentBackend(db *persistence.Database) (*PersistentBackend, error) { + wg, err := wgctrl.New() + if err != nil { + return nil, errors.WithMessage(err, "failed to get wgctrl handle") + } + + nl := &lowlevel.NetlinkManager{} + + wgm, err := wireguard.NewPersistentManager(wg, nl, db) + if err != nil { + return nil, errors.WithMessage(err, "failed to setup WireGuard manager") + } + + um, err := user.NewPersistentManager(db) + if err != nil { + return nil, errors.WithMessage(err, "failed to setup user manager") + } + + b := &PersistentBackend{ + UserManager: um, + WireGuardManager: wgm, + } + + return b, nil +} + +// ImportInterfaceById imports an interface. The given interface identifier must be available as importable interface. +func (b *PersistentBackend) ImportInterfaceById(identifier persistence.InterfaceIdentifier) error { + importable, err := b.GetImportableInterfaces() + if err != nil { + return errors.WithMessage(err, "failed to get importable interfaces") + } + + var interfaceConfig *wireguard.ImportableInterface + var peers []*persistence.PeerConfig + for cfg, peerList := range importable { + if cfg.Identifier == identifier { + interfaceConfig = cfg + peers = peerList + break + } + } + + if interfaceConfig == nil { + return errors.New("the given interface is not importable") + } + + err = b.WireGuardManager.ImportInterface(interfaceConfig, peers) + if err != nil { + return errors.WithMessagef(err, "failed to import interface") + } + + return nil +} + +// PrepareFreshPeer creates a new persistence.PeerConfig with prefilled keys and IP addresses. +func (b *PersistentBackend) PrepareFreshPeer(identifier persistence.InterfaceIdentifier) (*persistence.PeerConfig, error) { + return nil, nil // TODO: implement +} + +// GetPeersForUser returns all peers for the given user. +func (b *PersistentBackend) GetPeersForUser(identifier persistence.UserIdentifier) ([]*persistence.PeerConfig, error) { + return nil, nil // TODO: implement +} diff --git a/internal/core/doc.go b/internal/core/doc.go new file mode 100644 index 0000000..ea7a9f3 --- /dev/null +++ b/internal/core/doc.go @@ -0,0 +1,6 @@ +package core + +/* +Package portal manages the business logic of WireGuard Portal. +It combines and handles access to other packages like wireguard, user, lowlevel. +*/ diff --git a/internal/core/mail.go b/internal/core/mail.go new file mode 100644 index 0000000..a5be9e3 --- /dev/null +++ b/internal/core/mail.go @@ -0,0 +1,29 @@ +package core + +type MailEncryption string + +const ( + MailEncryptionNone MailEncryption = "none" + MailEncryptionTLS MailEncryption = "tls" + MailEncryptionStartTLS MailEncryption = "starttls" +) + +type MailAuthType string + +const ( + MailAuthPlain MailAuthType = "plain" + MailAuthLogin MailAuthType = "login" + MailAuthCramMD5 MailAuthType = "crammd5" +) + +type MailConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Encryption MailEncryption `yaml:"encryption"` + CertValidation bool `yaml:"cert_validation"` + Username string `yaml:"user"` + Password string `yaml:"pass"` + AuthType MailAuthType `yaml:"auth"` + MailFrom string `yaml:"mail_from"` + IncludeSensitiveData bool `yaml:"include_sensitive_data"` +} diff --git a/internal/core/web.go b/internal/core/web.go new file mode 100644 index 0000000..9a8bc95 --- /dev/null +++ b/internal/core/web.go @@ -0,0 +1 @@ +package core diff --git a/internal/persistence/models.go b/internal/persistence/models.go index e645295..c132832 100644 --- a/internal/persistence/models.go +++ b/internal/persistence/models.go @@ -121,6 +121,7 @@ DisplayName string // a nice display name/ description for the peer Identifier PeerIdentifier `gorm:"primaryKey"` // peer unique identifier UserIdentifier UserIdentifier `gorm:"index"` // the owner + Temporary bool `gorm:"temporary"` // is this a temporary peer (only prepared, but never saved) // Interface settings for the peer, used to generate the [interface] section in the peer config file Interface *PeerInterfaceConfig `gorm:"embedded"` diff --git a/internal/portal/api.go b/internal/portal/api.go deleted file mode 100644 index 598d685..0000000 --- a/internal/portal/api.go +++ /dev/null @@ -1,9 +0,0 @@ -package portal - -import "github.com/h44z/wg-portal/internal/wireguard" - -var man wireguard.Manager - -func init() { - man, _ = wireguard.NewPersistentManager(nil, nil, nil) -} diff --git a/internal/portal/backend.go b/internal/portal/backend.go deleted file mode 100644 index c8d03b7..0000000 --- a/internal/portal/backend.go +++ /dev/null @@ -1,94 +0,0 @@ -package portal - -import ( - "github.com/h44z/wg-portal/internal/lowlevel" - "github.com/h44z/wg-portal/internal/persistence" - "github.com/h44z/wg-portal/internal/user" - "github.com/h44z/wg-portal/internal/wireguard" - "github.com/pkg/errors" - "golang.zx2c4.com/wireguard/wgctrl" -) - -// Backend combines the user manager and WireGuard manager. It also provides some additional functions. -type Backend interface { - user.Manager - wireguard.Manager - - ImportInterfaceById(identifier persistence.InterfaceIdentifier) error - PrepareFreshPeer(identifier persistence.InterfaceIdentifier) (*persistence.PeerConfig, error) - GetPeersForUser(identifier persistence.UserIdentifier) ([]*persistence.PeerConfig, error) -} - -// type alias -type UserManager = user.Manager -type WireGuardManager = wireguard.Manager - -type PersistentBackend struct { - UserManager - WireGuardManager -} - -func NewPersistentBackend(db *persistence.Database) (*PersistentBackend, error) { - wg, err := wgctrl.New() - if err != nil { - return nil, errors.WithMessage(err, "failed to get wgctrl handle") - } - - nl := &lowlevel.NetlinkManager{} - - wgm, err := wireguard.NewPersistentManager(wg, nl, db) - if err != nil { - return nil, errors.WithMessage(err, "failed to setup WireGuard manager") - } - - um, err := user.NewPersistentManager(db) - if err != nil { - return nil, errors.WithMessage(err, "failed to setup user manager") - } - - b := &PersistentBackend{ - UserManager: um, - WireGuardManager: wgm, - } - - return b, nil -} - -// ImportInterfaceById imports an interface. The given interface identifier must be available as importable interface. -func (b *PersistentBackend) ImportInterfaceById(identifier persistence.InterfaceIdentifier) error { - importable, err := b.GetImportableInterfaces() - if err != nil { - return errors.WithMessage(err, "failed to get importable interfaces") - } - - var interfaceConfig *wireguard.ImportableInterface - var peers []*persistence.PeerConfig - for cfg, peerList := range importable { - if cfg.Identifier == identifier { - interfaceConfig = cfg - peers = peerList - break - } - } - - if interfaceConfig == nil { - return errors.New("the given interface is not importable") - } - - err = b.WireGuardManager.ImportInterface(interfaceConfig, peers) - if err != nil { - return errors.WithMessagef(err, "failed to import interface") - } - - return nil -} - -// PrepareFreshPeer creates a new persistence.PeerConfig with prefilled keys and IP addresses. -func (b *PersistentBackend) PrepareFreshPeer(identifier persistence.InterfaceIdentifier) (*persistence.PeerConfig, error) { - return nil, nil // TODO: implement -} - -// GetPeersForUser returns all peers for the given user. -func (b *PersistentBackend) GetPeersForUser(identifier persistence.UserIdentifier) ([]*persistence.PeerConfig, error) { - return nil, nil // TODO: implement -} diff --git a/internal/portal/doc.go b/internal/portal/doc.go deleted file mode 100644 index 7ffbecf..0000000 --- a/internal/portal/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -package portal - -/* -Package portal manages the business logic of WireGuard Portal. -It combines and handles access to other packages like wireguard, user, lowlevel. -*/ diff --git a/internal/portal/mail.go b/internal/portal/mail.go deleted file mode 100644 index e9d155b..0000000 --- a/internal/portal/mail.go +++ /dev/null @@ -1,29 +0,0 @@ -package portal - -type MailEncryption string - -const ( - MailEncryptionNone MailEncryption = "none" - MailEncryptionTLS MailEncryption = "tls" - MailEncryptionStartTLS MailEncryption = "starttls" -) - -type MailAuthType string - -const ( - MailAuthPlain MailAuthType = "plain" - MailAuthLogin MailAuthType = "login" - MailAuthCramMD5 MailAuthType = "crammd5" -) - -type MailConfig struct { - Host string `yaml:"host"` - Port int `yaml:"port"` - Encryption MailEncryption `yaml:"encryption"` - CertValidation bool `yaml:"cert_validation"` - Username string `yaml:"user"` - Password string `yaml:"pass"` - AuthType MailAuthType `yaml:"auth"` - MailFrom string `yaml:"mail_from"` - IncludeSensitiveData bool `yaml:"include_sensitive_data"` -} diff --git a/internal/portal/web.go b/internal/portal/web.go deleted file mode 100644 index 8d7996b..0000000 --- a/internal/portal/web.go +++ /dev/null @@ -1 +0,0 @@ -package portal