diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index d814b2f..057d27a 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -12,7 +12,7 @@ "github.com/sirupsen/logrus" ) -var Version string = "unknown (local build)" +var Version = "unknown (local build)" func main() { _ = setupLogger(logrus.StandardLogger()) @@ -20,7 +20,7 @@ c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - logrus.Infof("Starting WireGuard Portal Server [%s]...", Version) + logrus.Infof("starting WireGuard Portal Server [%s]...", Version) // Context for clean shutdown ctx, cancel := context.WithCancel(context.Background()) @@ -28,7 +28,7 @@ service := server.Server{} if err := service.Setup(ctx); err != nil { - logrus.Fatalf("Setup failed: %v", err) + logrus.Fatalf("setup failed: %v", err) } // Attach signal handlers to context @@ -44,10 +44,10 @@ <-ctx.Done() // Wait until the context gets canceled // Give goroutines some time to stop gracefully - logrus.Info("Stopping WireGuard Portal Server...") + logrus.Info("stopping WireGuard Portal Server...") time.Sleep(2 * time.Second) - logrus.Infof("Stopped WireGuard Portal Server...") + logrus.Infof("stopped WireGuard Portal Server...") logrus.Exit(0) } diff --git a/internal/authentication/provider.go b/internal/authentication/provider.go index b8420c4..a1e57d1 100644 --- a/internal/authentication/provider.go +++ b/internal/authentication/provider.go @@ -4,10 +4,10 @@ "github.com/gin-gonic/gin" ) +// AuthContext contains all information that the AuthProvider needs to perform the authentication. type AuthContext struct { - Provider AuthProvider Username string // email or username - Password string // optional for OIDC + Password string Callback string // callback for OIDC } @@ -18,6 +18,7 @@ AuthProviderTypeOauth AuthProviderType = "oauth" ) +// AuthProvider is a interface that can be implemented by different authentication providers like LDAP, OAUTH, ... type AuthProvider interface { GetName() string GetType() AuthProviderType diff --git a/internal/authentication/providers/ldap/provider.go b/internal/authentication/providers/ldap/provider.go index 9fa6e83..7e6dc55 100644 --- a/internal/authentication/providers/ldap/provider.go +++ b/internal/authentication/providers/ldap/provider.go @@ -13,7 +13,7 @@ "github.com/pkg/errors" ) -// Provider provide login with password method +// Provider implements a password login method for an LDAP backend. type Provider struct { config *ldapconfig.Config } diff --git a/internal/authentication/providers/password/provider.go b/internal/authentication/providers/password/provider.go index 4a6d95d..ca63ee5 100644 --- a/internal/authentication/providers/password/provider.go +++ b/internal/authentication/providers/password/provider.go @@ -14,7 +14,7 @@ "gorm.io/gorm" ) -// Provider provide login with password method +// Provider implements a password login method for a database backend. type Provider struct { db *gorm.DB } diff --git a/internal/authentication/user.go b/internal/authentication/user.go index e4b0af4..a5afcfc 100644 --- a/internal/authentication/user.go +++ b/internal/authentication/user.go @@ -1,5 +1,6 @@ package authentication +// User represents the data that can be retrieved from authentication backends. type User struct { Email string IsAdmin bool diff --git a/internal/common/configuration.go b/internal/common/configuration.go index 8e8a1cf..2c9f373 100644 --- a/internal/common/configuration.go +++ b/internal/common/configuration.go @@ -1,7 +1,6 @@ package common import ( - "errors" "os" "reflect" "runtime" @@ -10,13 +9,14 @@ "github.com/h44z/wg-portal/internal/users" "github.com/h44z/wg-portal/internal/wireguard" "github.com/kelseyhightower/envconfig" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "gopkg.in/yaml.v3" ) var ErrInvalidSpecification = errors.New("specification must be a struct pointer") -// LoadConfigFile parses yaml files. It uses to yaml annotation to store the data in a struct. +// loadConfigFile parses yaml files. It uses yaml annotation to store the data in a struct. func loadConfigFile(cfg interface{}, filename string) error { s := reflect.ValueOf(cfg) @@ -30,24 +30,24 @@ f, err := os.Open(filename) if err != nil { - return err + return errors.Wrapf(err, "failed to open config file %s", filename) } defer f.Close() decoder := yaml.NewDecoder(f) err = decoder.Decode(cfg) if err != nil { - return err + return errors.Wrapf(err, "failed to decode config file %s", filename) } return nil } -// LoadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct. +// loadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct. func loadConfigEnv(cfg interface{}) error { err := envconfig.Process("", cfg) if err != nil { - return err + return errors.Wrap(err, "failed to process environment config") } return nil @@ -124,7 +124,7 @@ } if cfg.WG.ManageIPAddresses && runtime.GOOS != "linux" { - logrus.Warnf("Managing IP addresses only works on linux! Feature disabled.") + logrus.Warnf("managing IP addresses only works on linux, feature disabled...") cfg.WG.ManageIPAddresses = false } diff --git a/internal/common/email.go b/internal/common/email.go index a70e6d2..e25da1d 100644 --- a/internal/common/email.go +++ b/internal/common/email.go @@ -26,7 +26,7 @@ Embedded bool } -// SendEmailWithAttachments sends a mail with attachments. +// SendEmailWithAttachments sends a mail with optional attachments. func SendEmailWithAttachments(cfg MailConfig, sender, replyTo, subject, body string, htmlBody string, receivers []string, attachments []MailAttachment) error { e := email.NewEmail() diff --git a/internal/common/util.go b/internal/common/util.go index dbdf60d..fc974b1 100644 --- a/internal/common/util.go +++ b/internal/common/util.go @@ -40,6 +40,8 @@ return ip.To4() == nil } +// ParseStringList converts a comma separated string into a list of strings. +// It also trims spaces from each element of the list. func ParseStringList(lst string) []string { tokens := strings.Split(lst, ",") validatedTokens := make([]string, 0, len(tokens)) @@ -53,6 +55,7 @@ return validatedTokens } +// ListToString converts a list of strings into a comma separated string. func ListToString(lst []string) string { return strings.Join(lst, ", ") } diff --git a/internal/ldap/config.go b/internal/ldap/config.go index bb31b20..22caa9f 100644 --- a/internal/ldap/config.go +++ b/internal/ldap/config.go @@ -23,5 +23,5 @@ GroupMemberAttribute string `yaml:"attrGroups" envconfig:"LDAP_ATTR_GROUPS"` DisabledAttribute string `yaml:"attrDisabled" envconfig:"LDAP_ATTR_DISABLED"` - AdminLdapGroup string `yaml:"adminGroup" envconfig:"LDAP_ADMIN_GROUP"` + AdminLdapGroup string `yaml:"adminGroup" envconfig:"LDAP_ADMIN_GROUP"` // Members of this group receive admin rights in WG-Portal } diff --git a/internal/ldap/ldap.go b/internal/ldap/ldap.go index f6f038c..04a8d06 100644 --- a/internal/ldap/ldap.go +++ b/internal/ldap/ldap.go @@ -18,20 +18,20 @@ func Open(cfg *Config) (*ldap.Conn, error) { conn, err := ldap.DialURL(cfg.URL) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to connect to LDAP") } if cfg.StartTLS { // Reconnect with TLS err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to star TLS on connection") } } err = conn.Bind(cfg.BindUser, cfg.BindPass) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to bind to LDAP") } return conn, nil diff --git a/internal/server/auth.go b/internal/server/auth.go index d5e3ed8..0e1cc72 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -3,14 +3,13 @@ import ( "sort" - "github.com/h44z/wg-portal/internal/authentication" - "github.com/gin-gonic/gin" + "github.com/h44z/wg-portal/internal/authentication" "github.com/h44z/wg-portal/internal/users" "github.com/sirupsen/logrus" ) -// Auth auth struct +// AuthManager keeps track of available authentication providers. type AuthManager struct { Server *Server Group *gin.RouterGroup // basic group for all providers (/auth) @@ -38,7 +37,7 @@ auth.RegisterProvider(provider) } -// GetProvider get provider with name +// GetProvider get provider by name func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider { for _, provider := range auth.providers { if provider.GetName() == name { @@ -48,15 +47,23 @@ return nil } -// GetProviders return registered providers +// GetProviders return registered providers. +// Returned providers are ordered by provider priority. func (auth *AuthManager) GetProviders() (providers []authentication.AuthProvider) { for _, provider := range auth.providers { providers = append(providers, provider) } + + // order by priority + sort.SliceStable(providers, func(i, j int) bool { + return providers[i].GetPriority() < providers[j].GetPriority() + }) + return } -// GetProviders return registered providers +// GetProvidersForType return registered providers for the given type. +// Returned providers are ordered by provider priority. func (auth *AuthManager) GetProvidersForType(typ authentication.AuthProviderType) (providers []authentication.AuthProvider) { for _, provider := range auth.providers { if provider.GetType() == typ { diff --git a/internal/server/core.go b/internal/server/core.go deleted file mode 100644 index c095efc..0000000 --- a/internal/server/core.go +++ /dev/null @@ -1,312 +0,0 @@ -package server - -import ( - "context" - "encoding/gob" - "html/template" - "io/fs" - "io/ioutil" - "math/rand" - "net/http" - "net/url" - "os" - "path/filepath" - "time" - - "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/memstore" - "github.com/gin-gonic/gin" - wg_portal "github.com/h44z/wg-portal" - ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap" - passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password" - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/users" - "github.com/h44z/wg-portal/internal/wireguard" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - ginlogrus "github.com/toorop/gin-logrus" -) - -const SessionIdentifier = "wgPortalSession" - -func init() { - gob.Register(SessionData{}) - gob.Register(FlashData{}) - gob.Register(Peer{}) - gob.Register(Device{}) - gob.Register(LdapCreateForm{}) - gob.Register(users.User{}) -} - -type SessionData struct { - LoggedIn bool - IsAdmin bool - Firstname string - Lastname string - Email string - - SortedBy map[string]string - SortDirection map[string]string - Search map[string]string - - AlertData string - AlertType string - FormData interface{} -} - -type FlashData struct { - HasAlert bool - Message string - Type string -} - -type StaticData struct { - WebsiteTitle string - WebsiteLogo string - CompanyName string - Year int -} - -type Server struct { - ctx context.Context - config *common.Config - server *gin.Engine - mailTpl *template.Template - auth *AuthManager - - users *users.Manager - wg *wireguard.Manager - peers *PeerManager -} - -func (s *Server) Setup(ctx context.Context) error { - var err error - - dir := s.getExecutableDirectory() - rDir, _ := filepath.Abs(filepath.Dir(os.Args[0])) - logrus.Infof("Real working directory: %s", rDir) - logrus.Infof("Current working directory: %s", dir) - - // Init rand - rand.Seed(time.Now().UnixNano()) - - s.config = common.NewConfig() - s.ctx = ctx - - // Setup http server - gin.SetMode(gin.DebugMode) - gin.DefaultWriter = ioutil.Discard - s.server = gin.New() - s.server.Use(ginlogrus.Logger(logrus.StandardLogger()), gin.Recovery()) - s.server.SetFuncMap(template.FuncMap{ - "formatBytes": common.ByteCountSI, - "urlEncode": url.QueryEscape, - }) - - // Setup templates - templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wg_portal.Templates, "assets/tpl/*.html")) - s.server.SetHTMLTemplate(templates) - s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte("secret")))) // TODO: change key? - - // Serve static files - s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/css")))) - s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/js")))) - s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/img")))) - s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/fonts")))) - - // Setup all routes - SetupRoutes(s) - - // Setup user database (also needed for database authentication) - s.users, err = users.NewManager(&s.config.Database) - if err != nil { - return errors.WithMessage(err, "user-manager initialization failed") - } - - // Setup auth manager - s.auth = NewAuthManager(s) - pwProvider, err := passwordprovider.New(&s.config.Database) - if err != nil { - return errors.WithMessage(err, "password provider initialization failed") - } - if err = pwProvider.InitializeAdmin(s.config.Core.AdminUser, s.config.Core.AdminPassword); err != nil { - return errors.WithMessage(err, "admin initialization failed") - } - s.auth.RegisterProvider(pwProvider) - - if s.config.Core.LdapEnabled { - ldapProvider, err := ldapprovider.New(&s.config.LDAP) - if err != nil { - s.config.Core.LdapEnabled = false - logrus.Warnf("failed to setup LDAP connection, LDAP features disabled") - } - s.auth.RegisterProviderWithoutError(ldapProvider, err) - } - - // Setup WireGuard stuff - s.wg = &wireguard.Manager{Cfg: &s.config.WG} - if err = s.wg.Init(); err != nil { - return errors.WithMessage(err, "unable to initialize WireGuard manager") - } - - // Setup peer manager - if s.peers, err = NewPeerManager(s.config, s.wg, s.users); err != nil { - return errors.WithMessage(err, "unable to setup peer manager") - } - if err = s.peers.InitFromCurrentInterface(); err != nil { - return errors.WithMessage(err, "unable to initialize peer manager") - } - if err = s.RestoreWireGuardInterface(); err != nil { - return errors.WithMessage(err, "unable to restore WireGuard state") - } - - // Setup mail template - s.mailTpl, err = template.New("email.html").ParseFS(wg_portal.Templates, "assets/tpl/email.html") - if err != nil { - return errors.Wrap(err, "unable to pare mail template") - } - - logrus.Infof("Setup of service completed!") - return nil -} - -func (s *Server) Run() { - // Start ldap sync - if s.config.Core.LdapEnabled { - go s.SyncLdapWithUserDatabase() - } - - // Run web service - srv := &http.Server{ - Addr: s.config.Core.ListeningAddress, - Handler: s.server, - } - - go func() { - if err := srv.ListenAndServe(); err != nil { - logrus.Debugf("web service on %s exited: %v", s.config.Core.ListeningAddress, err) - } - }() - - <-s.ctx.Done() - - logrus.Debug("web service shutting down...") - - shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - _ = srv.Shutdown(shutdownCtx) - -} - -func (s *Server) getExecutableDirectory() string { - dir, err := filepath.Abs(filepath.Dir(os.Args[0])) - if err != nil { - logrus.Errorf("Failed to get executable directory: %v", err) - } - - if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) { - return "." // assets directory not found -> we are developing in goland =) - } - - return dir -} - -func (s *Server) getStaticData() StaticData { - return StaticData{ - WebsiteTitle: s.config.Core.Title, - WebsiteLogo: "/img/header-logo.png", - CompanyName: s.config.Core.CompanyName, - Year: time.Now().Year(), - } -} - -func GetSessionData(c *gin.Context) SessionData { - session := sessions.Default(c) - rawSessionData := session.Get(SessionIdentifier) - - var sessionData SessionData - if rawSessionData != nil { - sessionData = rawSessionData.(SessionData) - } else { - sessionData = SessionData{ - Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, - SortedBy: map[string]string{"peers": "mail", "userpeers": "mail", "users": "email"}, - SortDirection: map[string]string{"peers": "asc", "userpeers": "asc", "users": "asc"}, - Email: "", - Firstname: "", - Lastname: "", - IsAdmin: false, - LoggedIn: false, - } - session.Set(SessionIdentifier, sessionData) - if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session: %v", err) - } - } - - return sessionData -} - -func GetFlashes(c *gin.Context) []FlashData { - session := sessions.Default(c) - flashes := session.Flashes() - if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session after setting flash: %v", err) - } - - flashData := make([]FlashData, len(flashes)) - for i := range flashes { - flashData[i] = flashes[i].(FlashData) - } - - return flashData -} - -func UpdateSessionData(c *gin.Context, data SessionData) error { - session := sessions.Default(c) - session.Set(SessionIdentifier, data) - if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session: %v", err) - return err - } - return nil -} - -func DestroySessionData(c *gin.Context) error { - session := sessions.Default(c) - session.Delete(SessionIdentifier) - if err := session.Save(); err != nil { - logrus.Errorf("Failed to destroy session: %v", err) - return err - } - return nil -} - -func SetFlashMessage(c *gin.Context, message, typ string) { - session := sessions.Default(c) - session.AddFlash(FlashData{ - Message: message, - Type: typ, - }) - if err := session.Save(); err != nil { - logrus.Errorf("Failed to store session after setting flash: %v", err) - } -} - -func (s SessionData) GetSortIcon(table, field string) string { - if s.SortedBy[table] != field { - return "fa-sort" - } - if s.SortDirection[table] == "asc" { - return "fa-sort-alpha-down" - } else { - return "fa-sort-alpha-up" - } -} - -func fsMust(f fs.FS, err error) fs.FS { - if err != nil { - panic(err) - } - return f -} diff --git a/internal/server/handlers_auth.go b/internal/server/handlers_auth.go index e3c7f72..bb70b95 100644 --- a/internal/server/handlers_auth.go +++ b/internal/server/handlers_auth.go @@ -8,7 +8,6 @@ "github.com/h44z/wg-portal/internal/authentication" "github.com/h44z/wg-portal/internal/users" "github.com/sirupsen/logrus" - "gorm.io/gorm" ) func (s *Server) GetLogin(c *gin.Context) { @@ -85,10 +84,6 @@ loginProvider = provider // create new user in the database (or reactivate him) - if user, err = s.users.GetOrCreateUserUnscoped(email); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to create new user") - return - } userData, err := loginProvider.GetUserModel(&authentication.AuthContext{ Username: email, }) @@ -96,23 +91,25 @@ s.GetHandleError(c, http.StatusInternalServerError, "login error", err.Error()) return } - user.Firstname = userData.Firstname - user.Lastname = userData.Lastname - user.Email = userData.Email - user.Phone = userData.Phone - user.IsAdmin = userData.IsAdmin - user.Source = users.UserSource(loginProvider.GetName()) - user.DeletedAt = gorm.DeletedAt{} // reset deleted flag - if err = s.users.UpdateUser(user); err != nil { + if err := s.CreateUser(users.User{ + Email: userData.Email, + Source: users.UserSource(loginProvider.GetName()), + IsAdmin: userData.IsAdmin, + Firstname: userData.Firstname, + Lastname: userData.Lastname, + Phone: userData.Phone, + }); err != nil { s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data") return } + + user = s.users.GetUser(username) break } } // Check if user is authenticated - if email == "" || loginProvider == nil { + if email == "" || loginProvider == nil || user == nil { c.Redirect(http.StatusSeeOther, "/auth/login?err=authfail") return } @@ -126,17 +123,9 @@ sessionData.Lastname = user.Lastname // Check if user already has a peer setup, if not create one - if s.config.Core.CreateDefaultPeer { - peers := s.peers.GetPeersByMail(sessionData.Email) - if len(peers) == 0 { // Create vpn peer - err := s.CreatePeer(Peer{ - Identifier: sessionData.Firstname + " " + sessionData.Lastname + " (Default)", - Email: sessionData.Email, - CreatedBy: sessionData.Email, - UpdatedBy: sessionData.Email, - }) - logrus.Errorf("Failed to automatically create vpn peer for %s: %v", sessionData.Email, err) - } + if err := s.CreateUserDefaultPeer(user.Email); err != nil { + // Not a fatal error, just log it... + logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err) } if err := UpdateSessionData(c, sessionData); err != nil { diff --git a/internal/server/handlers_common.go b/internal/server/handlers_common.go index ba1d0ec..0ff0885 100644 --- a/internal/server/handlers_common.go +++ b/internal/server/handlers_common.go @@ -4,6 +4,8 @@ "net/http" "strconv" + "github.com/pkg/errors" + "github.com/gin-gonic/gin" ) @@ -145,7 +147,7 @@ currentSession.FormData = formData if err := UpdateSessionData(c, currentSession); err != nil { - return err + return errors.WithMessage(err, "failed to update form in session") } return nil @@ -158,13 +160,13 @@ if currentSession.FormData == nil || c.Query("formerr") == "" { user, err := s.PrepareNewPeer() if err != nil { - return currentSession, err + return currentSession, errors.WithMessage(err, "failed to prepare new peer") } currentSession.FormData = user } if err := UpdateSessionData(c, currentSession); err != nil { - return currentSession, err + return currentSession, errors.WithMessage(err, "failed to update peer form in session") } return currentSession, nil @@ -179,7 +181,7 @@ } if err := UpdateSessionData(c, currentSession); err != nil { - return currentSession, err + return currentSession, errors.WithMessage(err, "failed to set form in session") } return currentSession, nil diff --git a/internal/server/handlers_user.go b/internal/server/handlers_user.go index c0a4400..c96a37c 100644 --- a/internal/server/handlers_user.go +++ b/internal/server/handlers_user.go @@ -7,7 +7,6 @@ "github.com/gin-gonic/gin" "github.com/h44z/wg-portal/internal/users" - "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) @@ -141,31 +140,7 @@ } formUser.IsAdmin = c.PostForm("isadmin") == "true" - // Update peers - if disabled != currentUser.DeletedAt.Valid { - if disabled { - // disable all peers for the given user - for _, peer := range s.peers.GetPeersByMail(currentUser.Email) { - now := time.Now() - peer.DeactivatedAt = &now - if err := s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update deactivated peer %s: %v", peer.PublicKey, err) - } - } - } else { - // enable all peers for the given user - for _, peer := range s.peers.GetPeersByMail(currentUser.Email) { - now := time.Now() - peer.DeactivatedAt = nil - if err := s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update activated peer %s: %v", peer.PublicKey, err) - } - } - } - } - - // Update in database - if err := s.users.UpdateUser(&formUser); err != nil { + if err := s.UpdateUser(formUser); err != nil { _ = s.updateFormInSession(c, formUser) SetFlashMessage(c, "failed to update user: "+err.Error(), "danger") c.Redirect(http.StatusSeeOther, "/admin/users/edit?pkey="+urlEncodedKey+"&formerr=update") @@ -242,28 +217,14 @@ } formUser.IsAdmin = c.PostForm("isadmin") == "true" formUser.Source = users.UserSourceDatabase - if err := s.users.CreateUser(&formUser); err != nil { - formUser.CreatedAt = time.Time{} // reset created time + + if err := s.CreateUser(formUser); err != nil { _ = s.updateFormInSession(c, formUser) SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create") return } - // Check if user already has a peer setup, if not create one - if s.config.Core.CreateDefaultPeer { - peers := s.peers.GetPeersByMail(formUser.Email) - if len(peers) == 0 { // Create vpn peer - err := s.CreatePeer(Peer{ - Identifier: formUser.Firstname + " " + formUser.Lastname + " (Default)", - Email: formUser.Email, - CreatedBy: formUser.Email, - UpdatedBy: formUser.Email, - }) - logrus.Errorf("Failed to automatically create vpn peer for %s: %v", formUser.Email, err) - } - } - SetFlashMessage(c, "user created successfully", "success") c.Redirect(http.StatusSeeOther, "/admin/users/") } diff --git a/internal/server/helper.go b/internal/server/helper.go deleted file mode 100644 index 0b59aa4..0000000 --- a/internal/server/helper.go +++ /dev/null @@ -1,201 +0,0 @@ -package server - -import ( - "crypto/md5" - "fmt" - "io/ioutil" - "syscall" - "time" - - "github.com/h44z/wg-portal/internal/common" - "github.com/pkg/errors" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -func (s *Server) PrepareNewPeer() (Peer, error) { - device := s.peers.GetDevice() - - peer := Peer{} - peer.IsNew = true - peer.AllowedIPsStr = device.AllowedIPsStr - peer.IPs = make([]string, len(device.IPs)) - for i := range device.IPs { - freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) - if err != nil { - return Peer{}, err - } - peer.IPs[i] = freeIP - } - peer.IPsStr = common.ListToString(peer.IPs) - psk, err := wgtypes.GenerateKey() - if err != nil { - return Peer{}, err - } - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return Peer{}, err - } - peer.PresharedKey = psk.String() - peer.PrivateKey = key.String() - peer.PublicKey = key.PublicKey().String() - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) - - return peer, nil -} - -func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool) error { - user, err := s.users.GetOrCreateUser(email) - if err != nil { - return errors.WithMessagef(err, "failed to load/create related user %s", email) - } - - device := s.peers.GetDevice() - peer := Peer{} - peer.User = user - peer.AllowedIPsStr = device.AllowedIPsStr - peer.IPs = make([]string, len(device.IPs)) - for i := range device.IPs { - freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) - if err != nil { - return err - } - peer.IPs[i] = freeIP - } - peer.IPsStr = common.ListToString(peer.IPs) - psk, err := wgtypes.GenerateKey() - if err != nil { - return err - } - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return err - } - peer.PresharedKey = psk.String() - peer.PrivateKey = key.String() - peer.PublicKey = key.PublicKey().String() - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) - peer.Email = email - peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix) - now := time.Now() - if disabled { - peer.DeactivatedAt = &now - } - - return s.CreatePeer(peer) -} - -func (s *Server) CreatePeer(peer Peer) error { - device := s.peers.GetDevice() - peer.AllowedIPsStr = device.AllowedIPsStr - if peer.IPs == nil || len(peer.IPs) == 0 { - peer.IPs = make([]string, len(device.IPs)) - for i := range device.IPs { - freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) - if err != nil { - return err - } - peer.IPs[i] = freeIP - } - peer.IPsStr = common.ListToString(peer.IPs) - } - if peer.PrivateKey == "" { // if private key is empty create a new one - psk, err := wgtypes.GenerateKey() - if err != nil { - return err - } - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return err - } - peer.PresharedKey = psk.String() - peer.PrivateKey = key.String() - peer.PublicKey = key.PublicKey().String() - } - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) - - // Create WireGuard interface - if peer.DeactivatedAt == nil { - if err := s.wg.AddPeer(peer.GetConfig()); err != nil { - return err - } - } - - // Create in database - if err := s.peers.CreatePeer(peer); err != nil { - return err - } - - return s.WriteWireGuardConfigFile() -} - -func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error { - currentPeer := s.peers.GetPeerByKey(peer.PublicKey) - - // Update WireGuard device - var err error - switch { - case peer.DeactivatedAt == &updateTime: - err = s.wg.RemovePeer(peer.PublicKey) - case peer.DeactivatedAt == nil && currentPeer.Peer != nil: - err = s.wg.UpdatePeer(peer.GetConfig()) - case peer.DeactivatedAt == nil && currentPeer.Peer == nil: - err = s.wg.AddPeer(peer.GetConfig()) - } - if err != nil { - return err - } - - // Update in database - if err := s.peers.UpdatePeer(peer); err != nil { - return err - } - - return s.WriteWireGuardConfigFile() -} - -func (s *Server) DeletePeer(peer Peer) error { - // Delete WireGuard peer - if err := s.wg.RemovePeer(peer.PublicKey); err != nil { - return err - } - - // Delete in database - if err := s.peers.DeletePeer(peer); err != nil { - return err - } - - return s.WriteWireGuardConfigFile() -} - -func (s *Server) RestoreWireGuardInterface() error { - activePeers := s.peers.GetActivePeers() - - for i := range activePeers { - if activePeers[i].Peer == nil { - if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil { - return err - } - } - } - - return nil -} - -func (s *Server) WriteWireGuardConfigFile() error { - if s.config.WG.WireGuardConfig == "" { - return nil // writing disabled - } - if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil { - return err - } - - device := s.peers.GetDevice() - cfg, err := device.GetConfigFile(s.peers.GetActivePeers()) - if err != nil { - return err - } - if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil { - return err - } - return nil -} diff --git a/internal/server/peermanager.go b/internal/server/peermanager.go index e26785f..ad95fd1 100644 --- a/internal/server/peermanager.go +++ b/internal/server/peermanager.go @@ -124,7 +124,7 @@ func (p Peer) GetConfigFile(device Device) ([]byte, error) { tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse client template") } var tplBuff bytes.Buffer @@ -137,7 +137,7 @@ Server: device, }) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to execute client template") } return tplBuff.Bytes(), nil @@ -149,7 +149,7 @@ logrus.WithFields(logrus.Fields{ "err": err, }).Error("failed to create qrcode") - return nil, err + return nil, errors.Wrap(err, "failed to encode qrcode") } return png, nil } @@ -247,7 +247,7 @@ func (d Device) GetConfigFile(peers []Peer) ([]byte, error) { tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.DeviceCfgTpl) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse server template") } var tplBuff bytes.Buffer @@ -260,7 +260,7 @@ Server: d, }) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to execute server template") } return tplBuff.Bytes(), nil @@ -582,7 +582,7 @@ res := u.db.Create(&peer) if res.Error != nil { logrus.Errorf("failed to create peer: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to create peer") } return nil @@ -596,7 +596,7 @@ res := u.db.Save(&peer) if res.Error != nil { logrus.Errorf("failed to update peer: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to update peer") } return nil @@ -606,7 +606,7 @@ res := u.db.Delete(&peer) if res.Error != nil { logrus.Errorf("failed to delete peer: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to delete peer") } return nil @@ -621,7 +621,7 @@ res := u.db.Save(&device) if res.Error != nil { logrus.Errorf("failed to update device: %v", res.Error) - return res.Error + return errors.Wrap(res.Error, "failed to update device") } return nil @@ -637,7 +637,7 @@ } ip, _, err := net.ParseCIDR(cidr) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse cidr") } reservedIps = append(reservedIps, ip.String()) } @@ -650,7 +650,7 @@ } ip, _, err := net.ParseCIDR(cidr) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to parse cidr") } reservedIps = append(reservedIps, ip.String()) @@ -691,11 +691,11 @@ func (u *PeerManager) GetAvailableIp(cidr string) (string, error) { reserved, err := u.GetAllReservedIps() if err != nil { - return "", err + return "", errors.WithMessage(err, "failed to get all reserved IP addresses") } ip, ipnet, err := net.ParseCIDR(cidr) if err != nil { - return "", err + return "", errors.Wrap(err, "failed to parse cidr") } // this two addresses are not usable diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..096c7e0 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,312 @@ +package server + +import ( + "context" + "encoding/gob" + "html/template" + "io/fs" + "io/ioutil" + "math/rand" + "net/http" + "net/url" + "os" + "path/filepath" + "time" + + "github.com/gin-contrib/sessions" + "github.com/gin-contrib/sessions/memstore" + "github.com/gin-gonic/gin" + wg_portal "github.com/h44z/wg-portal" + ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap" + passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password" + "github.com/h44z/wg-portal/internal/common" + "github.com/h44z/wg-portal/internal/users" + "github.com/h44z/wg-portal/internal/wireguard" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + ginlogrus "github.com/toorop/gin-logrus" +) + +const SessionIdentifier = "wgPortalSession" + +func init() { + gob.Register(SessionData{}) + gob.Register(FlashData{}) + gob.Register(Peer{}) + gob.Register(Device{}) + gob.Register(LdapCreateForm{}) + gob.Register(users.User{}) +} + +type SessionData struct { + LoggedIn bool + IsAdmin bool + Firstname string + Lastname string + Email string + + SortedBy map[string]string + SortDirection map[string]string + Search map[string]string + + AlertData string + AlertType string + FormData interface{} +} + +type FlashData struct { + HasAlert bool + Message string + Type string +} + +type StaticData struct { + WebsiteTitle string + WebsiteLogo string + CompanyName string + Year int +} + +type Server struct { + ctx context.Context + config *common.Config + server *gin.Engine + mailTpl *template.Template + auth *AuthManager + + users *users.Manager + wg *wireguard.Manager + peers *PeerManager +} + +func (s *Server) Setup(ctx context.Context) error { + var err error + + dir := s.getExecutableDirectory() + rDir, _ := filepath.Abs(filepath.Dir(os.Args[0])) + logrus.Infof("real working directory: %s", rDir) + logrus.Infof("current working directory: %s", dir) + + // Init rand + rand.Seed(time.Now().UnixNano()) + + s.config = common.NewConfig() + s.ctx = ctx + + // Setup http server + gin.SetMode(gin.DebugMode) + gin.DefaultWriter = ioutil.Discard + s.server = gin.New() + s.server.Use(ginlogrus.Logger(logrus.StandardLogger()), gin.Recovery()) + s.server.SetFuncMap(template.FuncMap{ + "formatBytes": common.ByteCountSI, + "urlEncode": url.QueryEscape, + }) + + // Setup templates + templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wg_portal.Templates, "assets/tpl/*.html")) + s.server.SetHTMLTemplate(templates) + s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte("secret")))) // TODO: change key? + + // Serve static files + s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/css")))) + s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/js")))) + s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/img")))) + s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/fonts")))) + + // Setup all routes + SetupRoutes(s) + + // Setup user database (also needed for database authentication) + s.users, err = users.NewManager(&s.config.Database) + if err != nil { + return errors.WithMessage(err, "user-manager initialization failed") + } + + // Setup auth manager + s.auth = NewAuthManager(s) + pwProvider, err := passwordprovider.New(&s.config.Database) + if err != nil { + return errors.WithMessage(err, "password provider initialization failed") + } + if err = pwProvider.InitializeAdmin(s.config.Core.AdminUser, s.config.Core.AdminPassword); err != nil { + return errors.WithMessage(err, "admin initialization failed") + } + s.auth.RegisterProvider(pwProvider) + + if s.config.Core.LdapEnabled { + ldapProvider, err := ldapprovider.New(&s.config.LDAP) + if err != nil { + s.config.Core.LdapEnabled = false + logrus.Warnf("failed to setup LDAP connection, LDAP features disabled") + } + s.auth.RegisterProviderWithoutError(ldapProvider, err) + } + + // Setup WireGuard stuff + s.wg = &wireguard.Manager{Cfg: &s.config.WG} + if err = s.wg.Init(); err != nil { + return errors.WithMessage(err, "unable to initialize WireGuard manager") + } + + // Setup peer manager + if s.peers, err = NewPeerManager(s.config, s.wg, s.users); err != nil { + return errors.WithMessage(err, "unable to setup peer manager") + } + if err = s.peers.InitFromCurrentInterface(); err != nil { + return errors.WithMessage(err, "unable to initialize peer manager") + } + if err = s.RestoreWireGuardInterface(); err != nil { + return errors.WithMessage(err, "unable to restore WireGuard state") + } + + // Setup mail template + s.mailTpl, err = template.New("email.html").ParseFS(wg_portal.Templates, "assets/tpl/email.html") + if err != nil { + return errors.Wrap(err, "unable to pare mail template") + } + + logrus.Infof("setup of service completed!") + return nil +} + +func (s *Server) Run() { + // Start ldap sync + if s.config.Core.LdapEnabled { + go s.SyncLdapWithUserDatabase() + } + + // Run web service + srv := &http.Server{ + Addr: s.config.Core.ListeningAddress, + Handler: s.server, + } + + go func() { + if err := srv.ListenAndServe(); err != nil { + logrus.Debugf("web service on %s exited: %v", s.config.Core.ListeningAddress, err) + } + }() + + <-s.ctx.Done() + + logrus.Debug("web service shutting down...") + + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = srv.Shutdown(shutdownCtx) + +} + +func (s *Server) getExecutableDirectory() string { + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + logrus.Errorf("failed to get executable directory: %v", err) + } + + if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) { + return "." // assets directory not found -> we are developing in goland =) + } + + return dir +} + +func (s *Server) getStaticData() StaticData { + return StaticData{ + WebsiteTitle: s.config.Core.Title, + WebsiteLogo: "/img/header-logo.png", + CompanyName: s.config.Core.CompanyName, + Year: time.Now().Year(), + } +} + +func GetSessionData(c *gin.Context) SessionData { + session := sessions.Default(c) + rawSessionData := session.Get(SessionIdentifier) + + var sessionData SessionData + if rawSessionData != nil { + sessionData = rawSessionData.(SessionData) + } else { + sessionData = SessionData{ + Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, + SortedBy: map[string]string{"peers": "mail", "userpeers": "mail", "users": "email"}, + SortDirection: map[string]string{"peers": "asc", "userpeers": "asc", "users": "asc"}, + Email: "", + Firstname: "", + Lastname: "", + IsAdmin: false, + LoggedIn: false, + } + session.Set(SessionIdentifier, sessionData) + if err := session.Save(); err != nil { + logrus.Errorf("failed to store session: %v", err) + } + } + + return sessionData +} + +func GetFlashes(c *gin.Context) []FlashData { + session := sessions.Default(c) + flashes := session.Flashes() + if err := session.Save(); err != nil { + logrus.Errorf("failed to store session after setting flash: %v", err) + } + + flashData := make([]FlashData, len(flashes)) + for i := range flashes { + flashData[i] = flashes[i].(FlashData) + } + + return flashData +} + +func UpdateSessionData(c *gin.Context, data SessionData) error { + session := sessions.Default(c) + session.Set(SessionIdentifier, data) + if err := session.Save(); err != nil { + logrus.Errorf("failed to store session: %v", err) + return errors.Wrap(err, "failed to store session") + } + return nil +} + +func DestroySessionData(c *gin.Context) error { + session := sessions.Default(c) + session.Delete(SessionIdentifier) + if err := session.Save(); err != nil { + logrus.Errorf("failed to destroy session: %v", err) + return errors.Wrap(err, "failed to destroy session") + } + return nil +} + +func SetFlashMessage(c *gin.Context, message, typ string) { + session := sessions.Default(c) + session.AddFlash(FlashData{ + Message: message, + Type: typ, + }) + if err := session.Save(); err != nil { + logrus.Errorf("failed to store session after setting flash: %v", err) + } +} + +func (s SessionData) GetSortIcon(table, field string) string { + if s.SortedBy[table] != field { + return "fa-sort" + } + if s.SortDirection[table] == "asc" { + return "fa-sort-alpha-down" + } else { + return "fa-sort-alpha-up" + } +} + +func fsMust(f fs.FS, err error) fs.FS { + if err != nil { + panic(err) + } + return f +} diff --git a/internal/server/server_helper.go b/internal/server/server_helper.go new file mode 100644 index 0000000..41fd7dc --- /dev/null +++ b/internal/server/server_helper.go @@ -0,0 +1,297 @@ +package server + +import ( + "crypto/md5" + "fmt" + "io/ioutil" + "syscall" + "time" + + "github.com/h44z/wg-portal/internal/common" + "github.com/h44z/wg-portal/internal/users" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "gorm.io/gorm" +) + +func (s *Server) PrepareNewPeer() (Peer, error) { + device := s.peers.GetDevice() + + peer := Peer{} + peer.IsNew = true + peer.AllowedIPsStr = device.AllowedIPsStr + peer.IPs = make([]string, len(device.IPs)) + for i := range device.IPs { + freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) + if err != nil { + return Peer{}, errors.WithMessage(err, "failed to get available IP addresses") + } + peer.IPs[i] = freeIP + } + peer.IPsStr = common.ListToString(peer.IPs) + psk, err := wgtypes.GenerateKey() + if err != nil { + return Peer{}, errors.Wrap(err, "failed to generate key") + } + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return Peer{}, errors.Wrap(err, "failed to generate private key") + } + peer.PresharedKey = psk.String() + peer.PrivateKey = key.String() + peer.PublicKey = key.PublicKey().String() + peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) + + return peer, nil +} + +func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool) error { + user, err := s.users.GetOrCreateUser(email) + if err != nil { + return errors.WithMessagef(err, "failed to load/create related user %s", email) + } + + device := s.peers.GetDevice() + peer := Peer{} + peer.User = user + peer.AllowedIPsStr = device.AllowedIPsStr + peer.IPs = make([]string, len(device.IPs)) + for i := range device.IPs { + freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) + if err != nil { + return errors.WithMessage(err, "failed to get available IP addresses") + } + peer.IPs[i] = freeIP + } + peer.IPsStr = common.ListToString(peer.IPs) + psk, err := wgtypes.GenerateKey() + if err != nil { + return errors.Wrap(err, "failed to generate key") + } + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return errors.Wrap(err, "failed to generate private key") + } + peer.PresharedKey = psk.String() + peer.PrivateKey = key.String() + peer.PublicKey = key.PublicKey().String() + peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) + peer.Email = email + peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix) + now := time.Now() + if disabled { + peer.DeactivatedAt = &now + } + + return s.CreatePeer(peer) +} + +func (s *Server) CreatePeer(peer Peer) error { + device := s.peers.GetDevice() + peer.AllowedIPsStr = device.AllowedIPsStr + if peer.IPs == nil || len(peer.IPs) == 0 { + peer.IPs = make([]string, len(device.IPs)) + for i := range device.IPs { + freeIP, err := s.peers.GetAvailableIp(device.IPs[i]) + if err != nil { + return errors.WithMessage(err, "failed to get available IP addresses") + } + peer.IPs[i] = freeIP + } + peer.IPsStr = common.ListToString(peer.IPs) + } + if peer.PrivateKey == "" { // if private key is empty create a new one + psk, err := wgtypes.GenerateKey() + if err != nil { + return errors.Wrap(err, "failed to generate key") + } + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + return errors.Wrap(err, "failed to generate private key") + } + peer.PresharedKey = psk.String() + peer.PrivateKey = key.String() + peer.PublicKey = key.PublicKey().String() + } + peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) + + // Create WireGuard interface + if peer.DeactivatedAt == nil { + if err := s.wg.AddPeer(peer.GetConfig()); err != nil { + return errors.WithMessage(err, "failed to add WireGuard peer") + } + } + + // Create in database + if err := s.peers.CreatePeer(peer); err != nil { + return errors.WithMessage(err, "failed to create peer") + } + + return s.WriteWireGuardConfigFile() +} + +func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error { + currentPeer := s.peers.GetPeerByKey(peer.PublicKey) + + // Update WireGuard device + var err error + switch { + case peer.DeactivatedAt == &updateTime: + err = s.wg.RemovePeer(peer.PublicKey) + case peer.DeactivatedAt == nil && currentPeer.Peer != nil: + err = s.wg.UpdatePeer(peer.GetConfig()) + case peer.DeactivatedAt == nil && currentPeer.Peer == nil: + err = s.wg.AddPeer(peer.GetConfig()) + } + if err != nil { + return errors.WithMessage(err, "failed to update WireGuard peer") + } + + // Update in database + if err := s.peers.UpdatePeer(peer); err != nil { + return errors.WithMessage(err, "failed to update peer") + } + + return s.WriteWireGuardConfigFile() +} + +func (s *Server) DeletePeer(peer Peer) error { + // Delete WireGuard peer + if err := s.wg.RemovePeer(peer.PublicKey); err != nil { + return errors.WithMessage(err, "failed to remove WireGuard peer") + } + + // Delete in database + if err := s.peers.DeletePeer(peer); err != nil { + return errors.WithMessage(err, "failed to remove peer") + } + + return s.WriteWireGuardConfigFile() +} + +func (s *Server) RestoreWireGuardInterface() error { + activePeers := s.peers.GetActivePeers() + + for i := range activePeers { + if activePeers[i].Peer == nil { + if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil { + return errors.WithMessage(err, "failed to add WireGuard peer") + } + } + } + + return nil +} + +func (s *Server) WriteWireGuardConfigFile() error { + if s.config.WG.WireGuardConfig == "" { + return nil // writing disabled + } + if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil { + return errors.Wrap(err, "failed to check WireGuard config access rights") + } + + device := s.peers.GetDevice() + cfg, err := device.GetConfigFile(s.peers.GetActivePeers()) + if err != nil { + return errors.WithMessage(err, "failed to get config file") + } + if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil { + return errors.Wrap(err, "failed to write WireGuard config file") + } + return nil +} + +func (s *Server) CreateUser(user users.User) error { + if user.Email == "" { + return errors.New("cannot create user with empty email address") + } + + // Check if user already exists, if so re-enable + if existingUser := s.users.GetUserUnscoped(user.Email); existingUser != nil { + user.DeletedAt = gorm.DeletedAt{} // reset deleted flag to enable that user again + return s.UpdateUser(user) + } + + // Create user in database + if err := s.users.CreateUser(&user); err != nil { + return errors.WithMessage(err, "failed to create user in manager") + } + + // Check if user already has a peer setup, if not, create one + return s.CreateUserDefaultPeer(user.Email) +} + +func (s *Server) UpdateUser(user users.User) error { + if user.DeletedAt.Valid { + return s.DeleteUser(user) + } + + currentUser := s.users.GetUserUnscoped(user.Email) + + // Update in database + if err := s.users.UpdateUser(&user); err != nil { + return errors.WithMessage(err, "failed to update user in manager") + } + + // If user was deleted (disabled), reactivate it's peers + if currentUser.DeletedAt.Valid { + for _, peer := range s.peers.GetPeersByMail(user.Email) { + now := time.Now() + peer.DeactivatedAt = nil + if err := s.UpdatePeer(peer, now); err != nil { + logrus.Errorf("failed to update (re)activated peer %s for %s: %v", peer.PublicKey, user.Email, err) + } + } + } + + return nil +} + +func (s *Server) DeleteUser(user users.User) error { + currentUser := s.users.GetUserUnscoped(user.Email) + + // Update in database + if err := s.users.DeleteUser(&user); err != nil { + return errors.WithMessage(err, "failed to delete user in manager") + } + + // If user was active, disable it's peers + if !currentUser.DeletedAt.Valid { + for _, peer := range s.peers.GetPeersByMail(user.Email) { + now := time.Now() + peer.DeactivatedAt = &now + if err := s.UpdatePeer(peer, now); err != nil { + logrus.Errorf("failed to update deactivated peer %s for %s: %v", peer.PublicKey, user.Email, err) + } + } + } + + return nil +} + +func (s *Server) CreateUserDefaultPeer(email string) error { + // Check if user is active, if not, quit + var existingUser *users.User + if existingUser = s.users.GetUser(email); existingUser == nil { + return nil + } + + // Check if user already has a peer setup, if not, create one + if s.config.Core.CreateDefaultPeer { + peers := s.peers.GetPeersByMail(email) + if len(peers) == 0 { // Create default vpn peer + if err := s.CreatePeer(Peer{ + Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)", + Email: existingUser.Email, + CreatedBy: existingUser.Email, + UpdatedBy: existingUser.Email, + }); err != nil { + return errors.WithMessagef(err, "failed to automatically create vpn peer for %s", email) + } + } + } + + return nil +} diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 976b841..f7f2a78 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -1,9 +1,10 @@ package wireguard import ( - "fmt" "sync" + "github.com/pkg/errors" + "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -18,7 +19,7 @@ var err error m.wg, err = wgctrl.New() if err != nil { - return fmt.Errorf("could not create WireGuard client: %w", err) + return errors.Wrap(err, "could not create WireGuard client") } return nil @@ -27,7 +28,7 @@ func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) { dev, err := m.wg.Device(m.Cfg.DeviceName) if err != nil { - return nil, fmt.Errorf("could not get WireGuard device: %w", err) + return nil, errors.Wrap(err, "could not get WireGuard device") } return dev, nil @@ -39,7 +40,7 @@ dev, err := m.wg.Device(m.Cfg.DeviceName) if err != nil { - return nil, fmt.Errorf("could not get WireGuard device: %w", err) + return nil, errors.Wrap(err, "could not get WireGuard device") } return dev.Peers, nil @@ -51,12 +52,12 @@ publicKey, err := wgtypes.ParseKey(pubKey) if err != nil { - return nil, fmt.Errorf("invalid public key: %w", err) + return nil, errors.Wrap(err, "invalid public key") } peers, err := m.GetPeerList() if err != nil { - return nil, fmt.Errorf("could not get WireGuard peers: %w", err) + return nil, errors.Wrap(err, "could not get WireGuard peers") } for _, peer := range peers { @@ -65,7 +66,7 @@ } } - return nil, fmt.Errorf("could not find WireGuard peer: %s", pubKey) + return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey) } func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error { @@ -74,7 +75,7 @@ err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) if err != nil { - return fmt.Errorf("could not configure WireGuard device: %w", err) + return errors.Wrap(err, "could not configure WireGuard device") } return nil @@ -87,7 +88,7 @@ cfg.UpdateOnly = true err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) if err != nil { - return fmt.Errorf("could not configure WireGuard device: %w", err) + return errors.Wrap(err, "could not configure WireGuard device") } return nil @@ -99,7 +100,7 @@ publicKey, err := wgtypes.ParseKey(pubKey) if err != nil { - return fmt.Errorf("invalid public key: %w", err) + return errors.Wrap(err, "invalid public key") } peer := wgtypes.PeerConfig{ @@ -109,7 +110,7 @@ err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}}) if err != nil { - return fmt.Errorf("could not configure WireGuard device: %w", err) + return errors.Wrap(err, "could not configure WireGuard device") } return nil diff --git a/internal/wireguard/net.go b/internal/wireguard/net.go index 92526f8..0b9e68b 100644 --- a/internal/wireguard/net.go +++ b/internal/wireguard/net.go @@ -4,6 +4,8 @@ "fmt" "net" + "github.com/pkg/errors" + "github.com/milosgajdos/tenus" ) @@ -12,18 +14,18 @@ func (m *Manager) GetIPAddress() ([]string, error) { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return nil, fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } // Get golang net.interface iface := wgInterface.NetInterface() if iface == nil { // Not sure if this check is really necessary - return nil, fmt.Errorf("could not retrieve WireGuard net.interface: %w", err) + return nil, errors.Wrap(err, "could not retrieve WireGuard net.interface") } addrs, err := iface.Addrs() if err != nil { - return nil, fmt.Errorf("could not retrieve WireGuard ip addresses: %w", err) + return nil, errors.Wrap(err, "could not retrieve WireGuard ip addresses") } ipAddresses := make([]string, 0, len(addrs)) @@ -53,22 +55,22 @@ func (m *Manager) SetIPAddress(cidrs []string) error { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } // First remove existing IP addresses existingIPs, err := m.GetIPAddress() if err != nil { - return err + return errors.Wrap(err, "could not retrieve IP addresses") } for _, cidr := range existingIPs { wgIp, wgIpNet, err := net.ParseCIDR(cidr) if err != nil { - return fmt.Errorf("unable to parse cidr %s: %w", cidr, err) + return errors.Wrapf(err, "unable to parse cidr %s", cidr) } if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil { - return fmt.Errorf("failed to unset ip %s: %w", cidr, err) + return errors.Wrapf(err, "failed to unset ip %s", cidr) } } @@ -76,11 +78,11 @@ for _, cidr := range cidrs { wgIp, wgIpNet, err := net.ParseCIDR(cidr) if err != nil { - return fmt.Errorf("unable to parse cidr %s: %w", cidr, err) + return errors.Wrapf(err, "unable to parse cidr %s", cidr) } if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil { - return fmt.Errorf("failed to set ip %s: %w", cidr, err) + return errors.Wrapf(err, "failed to set ip %s", cidr) } } @@ -90,13 +92,13 @@ func (m *Manager) GetMTU() (int, error) { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return 0, fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } // Get golang net.interface iface := wgInterface.NetInterface() if iface == nil { // Not sure if this check is really necessary - return 0, fmt.Errorf("could not retrieve WireGuard net.interface: %w", err) + return 0, errors.Wrap(err, "could not retrieve WireGuard net.interface") } return iface.MTU, nil @@ -105,7 +107,7 @@ func (m *Manager) SetMTU(mtu int) error { wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName) if err != nil { - return fmt.Errorf("could not retrieve WireGuard interface %s: %w", m.Cfg.DeviceName, err) + return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName) } if mtu == 0 { @@ -113,7 +115,7 @@ } if err := wgInterface.SetLinkMTU(mtu); err != nil { - return fmt.Errorf("could not set MTU on interface %s: %w", m.Cfg.DeviceName, err) + return errors.Wrapf(err, "could not set MTU on interface %s", m.Cfg.DeviceName) } return nil