diff --git a/Makefile b/Makefile index c8f8483..356907b 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,9 @@ test: dep $(GOCMD) test $(MODULENAME)/... -v -count=1 +test-integration: dep + $(GOCMD) test -tags=integration $(MODULENAME)/... -v -count=1 + clean: $(GOCMD) clean $(GOFILES) rm -rf .testCoverage.txt diff --git a/cmd/wg-portal/main.go b/cmd/wg-portal/main.go index b776558..7905807 100644 --- a/cmd/wg-portal/main.go +++ b/cmd/wg-portal/main.go @@ -1,102 +1,5 @@ package main -import ( - "context" - "io/ioutil" - "os" - "os/signal" - "syscall" - "time" - - "git.prolicht.digital/pub/healthcheck" - "github.com/h44z/wg-portal/internal/server" - "github.com/sirupsen/logrus" -) - func main() { - _ = setupLogger(logrus.StandardLogger()) - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - - logrus.Infof("starting WireGuard Portal Server [%s]...", server.Version) - - // Context for clean shutdown - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // start health check service on port 11223 - healthcheck.New(healthcheck.WithContext(ctx)).Start() - - service := server.Server{} - if err := service.Setup(ctx); err != nil { - logrus.Fatalf("setup failed: %v", err) - } - - // Attach signal handlers to context - go func() { - osCall := <-c - logrus.Tracef("received system call: %v", osCall) - cancel() // cancel the context - }() - - // Start main process in background - go service.Run() - - <-ctx.Done() // Wait until the context gets canceled - - // Give goroutines some time to stop gracefully - logrus.Info("stopping WireGuard Portal Server...") - time.Sleep(2 * time.Second) - - logrus.Infof("stopped WireGuard Portal Server...") - logrus.Exit(0) -} - -func setupLogger(logger *logrus.Logger) error { - // Check environment variables for logrus settings - level, ok := os.LookupEnv("LOG_LEVEL") - if !ok { - level = "debug" // Default logrus level - } - - useJSON, ok := os.LookupEnv("LOG_JSON") - if !ok { - useJSON = "false" // Default use human readable logging - } - - useColor, ok := os.LookupEnv("LOG_COLOR") - if !ok { - useColor = "true" - } - - switch level { - case "off": - logger.SetOutput(ioutil.Discard) - case "info": - logger.SetLevel(logrus.InfoLevel) - case "debug": - logger.SetLevel(logrus.DebugLevel) - case "trace": - logger.SetLevel(logrus.TraceLevel) - } - - var formatter logrus.Formatter - if useJSON == "false" { - f := new(logrus.TextFormatter) - f.TimestampFormat = "2006-01-02 15:04:05" - f.FullTimestamp = true - if useColor == "true" { - f.ForceColors = true - } - formatter = f - } else { - f := new(logrus.JSONFormatter) - f.TimestampFormat = "2006-01-02 15:04:05" - formatter = f - } - - logger.SetFormatter(formatter) - - return nil } diff --git a/go.mod b/go.mod index f89343d..bfb1ac0 100644 --- a/go.mod +++ b/go.mod @@ -3,34 +3,14 @@ go 1.16 require ( - git.prolicht.digital/pub/healthcheck v1.0.1 - github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 - github.com/evanphx/json-patch v0.5.2 - github.com/gin-contrib/sessions v0.0.3 - github.com/gin-gonic/gin v1.7.2 - github.com/go-ldap/ldap/v3 v3.3.0 - github.com/go-openapi/spec v0.20.3 // indirect - github.com/go-openapi/swag v0.19.15 // indirect - github.com/go-playground/validator/v10 v10.8.0 - github.com/gorilla/sessions v1.2.1 // indirect - github.com/kelseyhightower/envconfig v1.4.0 - github.com/mailru/easyjson v0.7.7 // indirect - github.com/milosgajdos/tenus v0.0.3 + github.com/kr/text v0.2.0 // indirect + github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pkg/errors v0.9.1 - github.com/sirupsen/logrus v1.8.1 - github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/swaggo/gin-swagger v1.3.0 - github.com/swaggo/swag v1.7.0 - github.com/tatsushid/go-fastping v0.0.0-20160109021039-d7bb493dee3e - github.com/toorop/gin-logrus v0.0.0-20210225092905-2c785434f26f - github.com/utrack/gin-csrf v0.0.0-20190424104817-40fb8d2c8fca - github.com/xhit/go-simple-mail/v2 v2.10.0 - golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 - golang.org/x/tools v0.1.0 // indirect - golang.zx2c4.com/wireguard v0.0.20200121 // indirect + github.com/stretchr/testify v1.6.1 + github.com/vishvananda/netlink v1.1.0 + golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect + golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210506160403-92e472f520a5 - gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b - gorm.io/driver/mysql v1.1.1 - gorm.io/driver/sqlite v1.1.4 - gorm.io/gorm v1.21.12 + gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/internal/authentication/provider.go b/internal/authentication/provider.go deleted file mode 100644 index a1e57d1..0000000 --- a/internal/authentication/provider.go +++ /dev/null @@ -1,32 +0,0 @@ -package authentication - -import ( - "github.com/gin-gonic/gin" -) - -// AuthContext contains all information that the AuthProvider needs to perform the authentication. -type AuthContext struct { - Username string // email or username - Password string - Callback string // callback for OIDC -} - -type AuthProviderType string - -const ( - AuthProviderTypePassword AuthProviderType = "password" - AuthProviderTypeOauth AuthProviderType = "oauth" -) - -// AuthProvider is a interface that can be implemented by different authentication providers like LDAP, OAUTH, ... -type AuthProvider interface { - GetName() string - GetType() AuthProviderType - GetPriority() int // lower number = higher priority - - Login(*AuthContext) (string, error) - Logout(*AuthContext) error - GetUserModel(*AuthContext) (*User, error) - - SetupRoutes(routes *gin.RouterGroup) -} diff --git a/internal/authentication/providers/ldap/provider.go b/internal/authentication/providers/ldap/provider.go deleted file mode 100644 index 97b9a05..0000000 --- a/internal/authentication/providers/ldap/provider.go +++ /dev/null @@ -1,183 +0,0 @@ -package ldap - -import ( - "crypto/tls" - "strings" - - "github.com/gin-gonic/gin" - "github.com/go-ldap/ldap/v3" - "github.com/h44z/wg-portal/internal/authentication" - ldapconfig "github.com/h44z/wg-portal/internal/ldap" - "github.com/h44z/wg-portal/internal/users" - "github.com/pkg/errors" -) - -// Provider implements a password login method for an LDAP backend. -type Provider struct { - config *ldapconfig.Config -} - -func New(cfg *ldapconfig.Config) (*Provider, error) { - p := &Provider{ - config: cfg, - } - - // test ldap connectivity - client, err := p.open() - if err != nil { - return nil, errors.Wrap(err, "unable to open ldap connection") - } - defer p.close(client) - - return p, nil -} - -// GetName return provider name -func (Provider) GetName() string { - return string(users.UserSourceLdap) -} - -// GetType return provider type -func (Provider) GetType() authentication.AuthProviderType { - return authentication.AuthProviderTypePassword -} - -// GetPriority return provider priority -func (Provider) GetPriority() int { - return 1 // LDAP password provider -} - -func (provider Provider) SetupRoutes(routes *gin.RouterGroup) { - // nothing todo here -} - -func (provider Provider) Login(ctx *authentication.AuthContext) (string, error) { - username := strings.ToLower(ctx.Username) - password := ctx.Password - - // Validate input - if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" { - return "", errors.New("empty username or password") - } - - client, err := provider.open() - if err != nil { - return "", errors.Wrap(err, "unable to open ldap connection") - } - defer provider.close(client) - - // Search for the given username - attrs := []string{"dn", provider.config.EmailAttribute} - loginFilter := strings.Replace(provider.config.LoginFilter, "{{login_identifier}}", username, -1) - searchRequest := ldap.NewSearchRequest( - provider.config.BaseDN, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, - loginFilter, - attrs, - nil, - ) - - sr, err := client.Search(searchRequest) - if err != nil { - return "", errors.Wrap(err, "unable to find user in ldap") - } - - if len(sr.Entries) != 1 { - return "", errors.Errorf("invalid amount of ldap entries (%d)", len(sr.Entries)) - } - - // Bind as the user to verify their password - userDN := sr.Entries[0].DN - err = client.Bind(userDN, password) - if err != nil { - return "", errors.Wrapf(err, "invalid credentials") - } - - return sr.Entries[0].GetAttributeValue(provider.config.EmailAttribute), nil -} - -func (provider Provider) Logout(context *authentication.AuthContext) error { - return nil // nothing todo here -} - -func (provider Provider) GetUserModel(ctx *authentication.AuthContext) (*authentication.User, error) { - username := strings.ToLower(ctx.Username) - - // Validate input - if strings.Trim(username, " ") == "" { - return nil, errors.New("empty username") - } - - client, err := provider.open() - if err != nil { - return nil, errors.Wrap(err, "unable to open ldap connection") - } - defer provider.close(client) - - // Search for the given username - attrs := []string{"dn", provider.config.EmailAttribute, provider.config.FirstNameAttribute, provider.config.LastNameAttribute, - provider.config.PhoneAttribute, provider.config.GroupMemberAttribute} - loginFilter := strings.Replace(provider.config.LoginFilter, "{{login_identifier}}", username, -1) - searchRequest := ldap.NewSearchRequest( - provider.config.BaseDN, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, - loginFilter, - attrs, - nil, - ) - - sr, err := client.Search(searchRequest) - if err != nil { - return nil, errors.Wrap(err, "unable to find user in ldap") - } - - if len(sr.Entries) != 1 { - return nil, errors.Wrapf(err, "invalid amount of ldap entries (%d)", len(sr.Entries)) - } - - user := &authentication.User{ - Firstname: sr.Entries[0].GetAttributeValue(provider.config.FirstNameAttribute), - Lastname: sr.Entries[0].GetAttributeValue(provider.config.LastNameAttribute), - Email: sr.Entries[0].GetAttributeValue(provider.config.EmailAttribute), - Phone: sr.Entries[0].GetAttributeValue(provider.config.PhoneAttribute), - IsAdmin: false, - } - - for _, group := range sr.Entries[0].GetAttributeValues(provider.config.GroupMemberAttribute) { - if group == provider.config.AdminLdapGroup { - user.IsAdmin = true - break - } - } - - return user, nil -} - -func (provider Provider) open() (*ldap.Conn, error) { - tlsConfig := &tls.Config{InsecureSkipVerify: !provider.config.CertValidation} - conn, err := ldap.DialURL(provider.config.URL, ldap.DialWithTLSConfig(tlsConfig)) - if err != nil { - return nil, err - } - - if provider.config.StartTLS { - // Reconnect with TLS - err = conn.StartTLS(tlsConfig) - if err != nil { - return nil, err - } - } - - err = conn.Bind(provider.config.BindUser, provider.config.BindPass) - if err != nil { - return nil, err - } - - return conn, nil -} - -func (provider Provider) close(conn *ldap.Conn) { - if conn != nil { - conn.Close() - } -} diff --git a/internal/authentication/providers/password/provider.go b/internal/authentication/providers/password/provider.go deleted file mode 100644 index e185ac1..0000000 --- a/internal/authentication/providers/password/provider.go +++ /dev/null @@ -1,195 +0,0 @@ -package password - -import ( - "fmt" - "math/rand" - "regexp" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/authentication" - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/users" - "github.com/pkg/errors" - "golang.org/x/crypto/bcrypt" - "gorm.io/gorm" -) - -var emailRegex = regexp.MustCompile("^[a-zA-Z0-9.!#$%&'*+\\/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$") - -// Provider implements a password login method for a database backend. -type Provider struct { - db *gorm.DB -} - -func New(cfg *common.DatabaseConfig) (*Provider, error) { - p := &Provider{} - - var err error - p.db, err = common.GetDatabaseForConfig(cfg) - if err != nil { - return nil, errors.Wrapf(err, "failed to setup authentication database %s", cfg.Database) - } - - return p, nil -} - -// GetName return provider name -func (Provider) GetName() string { - return string(users.UserSourceDatabase) -} - -// GetType return provider type -func (Provider) GetType() authentication.AuthProviderType { - return authentication.AuthProviderTypePassword -} - -// GetPriority return provider priority -func (Provider) GetPriority() int { - return 0 // DB password provider = highest prio -} - -func (provider Provider) SetupRoutes(routes *gin.RouterGroup) { - // nothing todo here -} - -func (provider Provider) Login(ctx *authentication.AuthContext) (string, error) { - username := strings.ToLower(ctx.Username) - password := ctx.Password - - // Validate input - if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" { - return "", errors.New("empty username or password") - } - - // Authenticate against the users database - user := users.User{} - provider.db.Where("email = ?", username).First(&user) - - if user.Email == "" { - return "", errors.New("invalid username") - } - - // Compare the stored hashed password, with the hashed version of the password that was received - if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { - return "", errors.New("invalid password") - } - - return user.Email, nil -} - -func (provider Provider) Logout(context *authentication.AuthContext) error { - return nil // nothing todo here -} - -func (provider Provider) GetUserModel(ctx *authentication.AuthContext) (*authentication.User, error) { - username := strings.ToLower(ctx.Username) - - // Validate input - if strings.Trim(username, " ") == "" { - return nil, errors.New("empty username") - } - - // Fetch usermodel from users database - user := users.User{} - provider.db.Where("email = ?", username).First(&user) - if user.Email != username { - return nil, errors.New("invalid or disabled username") - } - - return &authentication.User{ - Email: user.Email, - IsAdmin: user.IsAdmin, - Firstname: user.Firstname, - Lastname: user.Lastname, - Phone: user.Phone, - }, nil -} - -func (provider Provider) InitializeAdmin(email, password string) error { - email = strings.ToLower(email) - if !emailRegex.MatchString(email) { - return errors.New("admin username must be an email address") - } - - admin := users.User{} - provider.db.Unscoped().Where("email = ?", email).FirstOrInit(&admin) - - // newly created admin - if admin.Email != email { - // For security reasons a random admin password will be generated if the default one is still in use! - if password == "wgportal" { - password = generateRandomPassword() - - fmt.Println("#############################################") - fmt.Println("Administrator credentials:") - fmt.Println(" Email: ", email) - fmt.Println(" Password: ", password) - fmt.Println() - fmt.Println("This information will only be displayed once!") - fmt.Println("#############################################") - } - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "failed to hash admin password") - } - - admin.Email = email - admin.Password = users.PrivateString(hashedPassword) - admin.Firstname = "WireGuard" - admin.Lastname = "Administrator" - admin.CreatedAt = time.Now() - admin.UpdatedAt = time.Now() - admin.IsAdmin = true - admin.Source = users.UserSourceDatabase - - res := provider.db.Create(admin) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to create admin %s", admin.Email) - } - } - - // update/reactivate - if !admin.IsAdmin || admin.DeletedAt.Valid { - // For security reasons a random admin password will be generated if the default one is still in use! - if password == "wgportal" { - password = generateRandomPassword() - - fmt.Println("#############################################") - fmt.Println("Administrator credentials:") - fmt.Println(" Email: ", email) - fmt.Println(" Password: ", password) - fmt.Println() - fmt.Println("This information will only be displayed once!") - fmt.Println("#############################################") - } - - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "failed to hash admin password") - } - - admin.Password = users.PrivateString(hashedPassword) - admin.IsAdmin = true - admin.UpdatedAt = time.Now() - - res := provider.db.Save(admin) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to update admin %s", admin.Email) - } - } - - return nil -} - -func generateRandomPassword() string { - rand.Seed(time.Now().Unix()) - var randPassword strings.Builder - charSet := "abcdedfghijklmnopqrstABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!#$" - for i := 0; i < 12; i++ { - random := rand.Intn(len(charSet)) - randPassword.WriteString(string(charSet[random])) - } - return randPassword.String() -} diff --git a/internal/authentication/user.go b/internal/authentication/user.go deleted file mode 100644 index a5afcfc..0000000 --- a/internal/authentication/user.go +++ /dev/null @@ -1,12 +0,0 @@ -package authentication - -// User represents the data that can be retrieved from authentication backends. -type User struct { - Email string - IsAdmin bool - - // optional fields - Firstname string - Lastname string - Phone string -} diff --git a/internal/common/db.go b/internal/common/db.go deleted file mode 100644 index d8122be..0000000 --- a/internal/common/db.go +++ /dev/null @@ -1,160 +0,0 @@ -package common - -import ( - "fmt" - "os" - "path/filepath" - "sort" - "time" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "gorm.io/driver/mysql" - "gorm.io/driver/sqlite" - "gorm.io/gorm" - "gorm.io/gorm/logger" -) - -func init() { - migrations = append(migrations, Migration{ - version: "1.0.7", - migrateFn: func(db *gorm.DB) error { - if err := db.Exec("UPDATE users SET email = LOWER(email)").Error; err != nil { - return errors.Wrap(err, "failed to convert user emails to lower case") - } - if err := db.Exec("UPDATE peers SET email = LOWER(email)").Error; err != nil { - return errors.Wrap(err, "failed to convert peer emails to lower case") - } - logrus.Infof("upgraded database format to version 1.0.7") - return nil - }, - }) - migrations = append(migrations, Migration{ - version: "1.0.8", - migrateFn: func(db *gorm.DB) error { - logrus.Infof("upgraded database format to version 1.0.8") - return nil - }, - }) -} - -type SupportedDatabase string - -const ( - SupportedDatabaseMySQL SupportedDatabase = "mysql" - SupportedDatabaseSQLite SupportedDatabase = "sqlite" -) - -type DatabaseConfig struct { - Typ SupportedDatabase `yaml:"typ" envconfig:"DATABASE_TYPE"` //mysql or sqlite - Host string `yaml:"host" envconfig:"DATABASE_HOST"` - Port int `yaml:"port" envconfig:"DATABASE_PORT"` - Database string `yaml:"database" envconfig:"DATABASE_NAME"` // On SQLite: the database file-path, otherwise the database name - User string `yaml:"user" envconfig:"DATABASE_USERNAME"` - Password string `yaml:"password" envconfig:"DATABASE_PASSWORD"` -} - -func GetDatabaseForConfig(cfg *DatabaseConfig) (db *gorm.DB, err error) { - switch cfg.Typ { - case SupportedDatabaseSQLite: - if _, err = os.Stat(filepath.Dir(cfg.Database)); os.IsNotExist(err) { - if err = os.MkdirAll(filepath.Dir(cfg.Database), 0700); err != nil { - return - } - } - db, err = gorm.Open(sqlite.Open(cfg.Database), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true}) - if err != nil { - return - } - case SupportedDatabaseMySQL: - connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database) - db, err = gorm.Open(mysql.Open(connectionString), &gorm.Config{}) - if err != nil { - return - } - - sqlDB, _ := db.DB() - sqlDB.SetConnMaxLifetime(time.Minute * 5) - sqlDB.SetMaxIdleConns(2) - sqlDB.SetMaxOpenConns(10) - err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible - if err != nil { - return nil, errors.Wrap(err, "failed to ping mysql authentication database") - } - } - - // Enable Logger (logrus) - logCfg := logger.Config{ - SlowThreshold: time.Second, // all slower than one second - Colorful: false, - LogLevel: logger.Silent, // default: log nothing - } - - if logrus.StandardLogger().GetLevel() == logrus.TraceLevel { - logCfg.LogLevel = logger.Info - logCfg.SlowThreshold = 500 * time.Millisecond // all slower than half a second - } - - db.Config.Logger = logger.New(logrus.StandardLogger(), logCfg) - return -} - -type DatabaseMigrationInfo struct { - Version string `gorm:"primaryKey"` - Applied time.Time -} - -type Migration struct { - version string - migrateFn func(db *gorm.DB) error -} - -var migrations []Migration - -func MigrateDatabase(db *gorm.DB, version string) error { - if err := db.AutoMigrate(&DatabaseMigrationInfo{}); err != nil { - return errors.Wrap(err, "failed to migrate version database") - } - - existingMigration := DatabaseMigrationInfo{} - db.Where("version = ?", version).FirstOrInit(&existingMigration) - - if existingMigration.Version == "" { - lastVersion := DatabaseMigrationInfo{} - db.Order("applied desc, version desc").FirstOrInit(&lastVersion) - - if lastVersion.Version == "" { - // fresh database, no migrations to apply - res := db.Create(&DatabaseMigrationInfo{ - Version: version, - Applied: time.Now(), - }) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to write version %s to database", version) - } - return nil - } - - sort.Slice(migrations, func(i, j int) bool { - return migrations[i].version < migrations[j].version - }) - - for _, migration := range migrations { - if migration.version > lastVersion.Version { - if err := migration.migrateFn(db); err != nil { - return errors.Wrapf(err, "failed to migrate to version %s", migration.version) - } - - res := db.Create(&DatabaseMigrationInfo{ - Version: migration.version, - Applied: time.Now(), - }) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to write version %s to database", migration.version) - } - } - } - } - - return nil -} diff --git a/internal/common/email.go b/internal/common/email.go deleted file mode 100644 index c72f667..0000000 --- a/internal/common/email.go +++ /dev/null @@ -1,117 +0,0 @@ -package common - -import ( - "crypto/tls" - "io" - "io/ioutil" - "time" - - "github.com/pkg/errors" - mail "github.com/xhit/go-simple-mail/v2" -) - -type MailEncryption string - -const ( - MailEncryptionNone MailEncryption = "none" - MailEncryptionTLS MailEncryption = "tls" - MailEncryptionStartTLS MailEncryption = "starttls" -) - -type MailAuthType string - -const ( - MailAuthPlain MailAuthType = "plain" - MailAuthLogin MailAuthType = "login" - MailAuthCramMD5 MailAuthType = "crammd5" -) - -type MailConfig struct { - Host string `yaml:"host" envconfig:"EMAIL_HOST"` - Port int `yaml:"port" envconfig:"EMAIL_PORT"` - TLS bool `yaml:"tls" envconfig:"EMAIL_TLS"` // Deprecated, use MailConfig.Encryption instead. - Encryption MailEncryption `yaml:"encryption" envconfig:"EMAIL_ENCRYPTION"` - CertValidation bool `yaml:"certcheck" envconfig:"EMAIL_CERT_VALIDATION"` - Username string `yaml:"user" envconfig:"EMAIL_USERNAME"` - Password string `yaml:"pass" envconfig:"EMAIL_PASSWORD"` - AuthType MailAuthType `yaml:"auth" envconfig:"EMAIL_AUTHTYPE"` -} - -type MailAttachment struct { - Name string - ContentType string - Data io.Reader - Embedded bool -} - -// SendEmailWithAttachments sends a mail with optional attachments. -func SendEmailWithAttachments(cfg MailConfig, sender, replyTo, subject, body, htmlBody string, receivers []string, attachments []MailAttachment) error { - srv := mail.NewSMTPClient() - - srv.ConnectTimeout = 30 * time.Second - srv.SendTimeout = 30 * time.Second - srv.Host = cfg.Host - srv.Port = cfg.Port - srv.Username = cfg.Username - srv.Password = cfg.Password - - // TODO: remove this once the deprecated MailConfig.TLS config option has been removed - if cfg.TLS { - cfg.Encryption = MailEncryptionStartTLS - } - switch cfg.Encryption { - case MailEncryptionTLS: - srv.Encryption = mail.EncryptionSSLTLS - case MailEncryptionStartTLS: - srv.Encryption = mail.EncryptionSTARTTLS - default: // MailEncryptionNone - srv.Encryption = mail.EncryptionNone - } - srv.TLSConfig = &tls.Config{ServerName: srv.Host, InsecureSkipVerify: !cfg.CertValidation} - switch cfg.AuthType { - case MailAuthPlain: - srv.Authentication = mail.AuthPlain - case MailAuthLogin: - srv.Authentication = mail.AuthLogin - case MailAuthCramMD5: - srv.Authentication = mail.AuthCRAMMD5 - } - - client, err := srv.Connect() - if err != nil { - return errors.Wrap(err, "failed to connect via SMTP") - } - - if replyTo == "" { - replyTo = sender - } - - email := mail.NewMSG() - email.SetFrom(sender). - AddTo(receivers...). - SetReplyTo(replyTo). - SetSubject(subject) - - email.SetBody(mail.TextHTML, htmlBody) - email.AddAlternative(mail.TextPlain, body) - - for _, attachment := range attachments { - attachmentData, err := ioutil.ReadAll(attachment.Data) - if err != nil { - return errors.Wrapf(err, "failed to read attachment data for %s", attachment.Name) - } - - if attachment.Embedded { - email.AddInlineData(attachmentData, attachment.Name, attachment.ContentType) - } else { - email.AddAttachmentData(attachmentData, attachment.Name, attachment.ContentType) - } - } - - // Call Send and pass the client - err = email.Send(client) - if err != nil { - return errors.Wrapf(err, "failed to send email") - } - return nil -} diff --git a/internal/common/util.go b/internal/common/util.go deleted file mode 100644 index bbde700..0000000 --- a/internal/common/util.go +++ /dev/null @@ -1,86 +0,0 @@ -package common - -import ( - "fmt" - "net" - "strings" -) - -// BroadcastAddr returns the last address in the given network, or the broadcast address. -func BroadcastAddr(n *net.IPNet) net.IP { - // The golang net package doesn't make it easy to calculate the broadcast address. :( - var broadcast net.IP - if len(n.IP) == 4 { - broadcast = net.ParseIP("0.0.0.0").To4() - } else { - broadcast = net.ParseIP("::") - } - for i := 0; i < len(n.IP); i++ { - broadcast[i] = n.IP[i] | ^n.Mask[i] - } - return broadcast -} - -// http://play.golang.org/p/m8TNTtygK0 -func IncreaseIP(ip net.IP) { - for j := len(ip) - 1; j >= 0; j-- { - ip[j]++ - if ip[j] > 0 { - break - } - } -} - -// IsIPv6 check if given ip is IPv6 -func IsIPv6(address string) bool { - ip := net.ParseIP(address) - if ip == nil { - return false - } - return ip.To4() == nil -} - -// ParseStringList converts a comma separated string into a list of strings. -// It also trims spaces from each element of the list. -func ParseStringList(lst string) []string { - tokens := strings.Split(lst, ",") - validatedTokens := make([]string, 0, len(tokens)) - for i := range tokens { - tokens[i] = strings.TrimSpace(tokens[i]) - if tokens[i] != "" { - validatedTokens = append(validatedTokens, tokens[i]) - } - } - - return validatedTokens -} - -// ListToString converts a list of strings into a comma separated string. -func ListToString(lst []string) string { - return strings.Join(lst, ", ") -} - -// ListContains checks if a needle exists in the given list. -func ListContains(lst []string, needle string) bool { - for _, entry := range lst { - if entry == needle { - return true - } - } - return false -} - -// https://yourbasic.org/golang/formatting-byte-size-to-human-readable-format/ -func ByteCountSI(b int64) string { - const unit = 1000 - if b < unit { - return fmt.Sprintf("%d B", b) - } - div, exp := int64(unit), 0 - for n := b / unit; n >= unit; n /= unit { - div *= unit - exp++ - } - return fmt.Sprintf("%.1f %cB", - float64(b)/float64(div), "kMGTPE"[exp]) -} diff --git a/internal/ldap/config.go b/internal/ldap/config.go deleted file mode 100644 index 4d581be..0000000 --- a/internal/ldap/config.go +++ /dev/null @@ -1,27 +0,0 @@ -package ldap - -type Type string - -const ( - TypeActiveDirectory Type = "AD" - TypeOpenLDAP Type = "OpenLDAP" -) - -type Config struct { - URL string `yaml:"url" envconfig:"LDAP_URL"` - StartTLS bool `yaml:"startTLS" envconfig:"LDAP_STARTTLS"` - CertValidation bool `yaml:"certcheck" envconfig:"LDAP_CERT_VALIDATION"` - BaseDN string `yaml:"dn" envconfig:"LDAP_BASEDN"` - BindUser string `yaml:"user" envconfig:"LDAP_USER"` - BindPass string `yaml:"pass" envconfig:"LDAP_PASSWORD"` - - EmailAttribute string `yaml:"attrEmail" envconfig:"LDAP_ATTR_EMAIL"` - FirstNameAttribute string `yaml:"attrFirstname" envconfig:"LDAP_ATTR_FIRSTNAME"` - LastNameAttribute string `yaml:"attrLastname" envconfig:"LDAP_ATTR_LASTNAME"` - PhoneAttribute string `yaml:"attrPhone" envconfig:"LDAP_ATTR_PHONE"` - GroupMemberAttribute string `yaml:"attrGroups" envconfig:"LDAP_ATTR_GROUPS"` - - LoginFilter string `yaml:"loginFilter" envconfig:"LDAP_LOGIN_FILTER"` // {{login_identifier}} gets replaced with the login email address - SyncFilter string `yaml:"syncFilter" envconfig:"LDAP_SYNC_FILTER"` - AdminLdapGroup string `yaml:"adminGroup" envconfig:"LDAP_ADMIN_GROUP"` // Members of this group receive admin rights in WG-Portal -} diff --git a/internal/ldap/ldap.go b/internal/ldap/ldap.go deleted file mode 100644 index 38af07b..0000000 --- a/internal/ldap/ldap.go +++ /dev/null @@ -1,84 +0,0 @@ -package ldap - -import ( - "crypto/tls" - - "github.com/go-ldap/ldap/v3" - "github.com/pkg/errors" -) - -type RawLdapData struct { - DN string - Attributes map[string]string - RawAttributes map[string][][]byte -} - -func Open(cfg *Config) (*ldap.Conn, error) { - tlsConfig := &tls.Config{InsecureSkipVerify: !cfg.CertValidation} - conn, err := ldap.DialURL(cfg.URL, ldap.DialWithTLSConfig(tlsConfig)) - if err != nil { - return nil, errors.Wrap(err, "failed to connect to LDAP") - } - - if cfg.StartTLS { - // Reconnect with TLS - err = conn.StartTLS(tlsConfig) - if err != nil { - return nil, errors.Wrap(err, "failed to star TLS on connection") - } - } - - err = conn.Bind(cfg.BindUser, cfg.BindPass) - if err != nil { - return nil, errors.Wrap(err, "failed to bind to LDAP") - } - - return conn, nil -} - -func Close(conn *ldap.Conn) { - if conn != nil { - conn.Close() - } -} - -func FindAllUsers(cfg *Config) ([]RawLdapData, error) { - client, err := Open(cfg) - if err != nil { - return nil, errors.WithMessage(err, "failed to open ldap connection") - } - defer Close(client) - - // Search all users - attrs := []string{"dn", cfg.EmailAttribute, cfg.EmailAttribute, cfg.FirstNameAttribute, cfg.LastNameAttribute, - cfg.PhoneAttribute, cfg.GroupMemberAttribute} - searchRequest := ldap.NewSearchRequest( - cfg.BaseDN, - ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, - cfg.SyncFilter, attrs, nil, - ) - - sr, err := client.Search(searchRequest) - if err != nil { - return nil, errors.Wrapf(err, "failed to search in ldap") - } - - tmpData := make([]RawLdapData, 0, len(sr.Entries)) - - for _, entry := range sr.Entries { - tmp := RawLdapData{ - DN: entry.DN, - Attributes: make(map[string]string, len(attrs)), - RawAttributes: make(map[string][][]byte, len(attrs)), - } - - for _, field := range attrs { - tmp.Attributes[field] = entry.GetAttributeValue(field) - tmp.RawAttributes[field] = entry.GetRawAttributeValues(field) - } - - tmpData = append(tmpData, tmp) - } - - return tmpData, nil -} diff --git a/internal/server/api.go b/internal/server/api.go deleted file mode 100644 index aa38105..0000000 --- a/internal/server/api.go +++ /dev/null @@ -1,926 +0,0 @@ -package server - -// go get -u github.com/swaggo/swag/cmd/swag -// run: swag init --parseDependency --parseInternal --generalInfo api.go -// in the internal/server folder -import ( - "encoding/json" - "net/http" - "strings" - "time" - - jsonpatch "github.com/evanphx/json-patch" - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/users" - "github.com/h44z/wg-portal/internal/wireguard" -) - -// @title WireGuard Portal API -// @version 1.0 -// @description WireGuard Portal API for managing users and peers. - -// @license.name MIT -// @license.url https://github.com/h44z/wg-portal/blob/master/LICENSE.txt - -// @contact.name WireGuard Portal Project -// @contact.url https://github.com/h44z/wg-portal - -// @securityDefinitions.basic ApiBasicAuth -// @in header -// @name Authorization -// @scope.admin Admin access required - -// @securityDefinitions.basic GeneralBasicAuth -// @in header -// @name Authorization -// @scope.user User access required - -// @BasePath /api/v1 - -// ApiServer is a simple wrapper struct so that we can have fresh member function names. -type ApiServer struct { - s *Server -} - -type ApiError struct { - Message string -} - -// GetUsers godoc -// @Tags Users -// @Summary Retrieves all users -// @Produce json -// @Success 200 {object} []users.User -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /backend/users [get] -// @Security ApiBasicAuth -func (s *ApiServer) GetUsers(c *gin.Context) { - allUsers := s.s.users.GetUsersUnscoped() - - c.JSON(http.StatusOK, allUsers) -} - -// GetUser godoc -// @Tags Users -// @Summary Retrieves user based on given Email -// @Produce json -// @Param email query string true "User Email" -// @Success 200 {object} users.User -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /backend/user [get] -// @Security ApiBasicAuth -func (s *ApiServer) GetUser(c *gin.Context) { - email := strings.ToLower(strings.TrimSpace(c.Query("email"))) - if email == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "email parameter must be specified"}) - return - } - - user := s.s.users.GetUserUnscoped(email) - if user == nil { - c.JSON(http.StatusNotFound, ApiError{Message: "user not found"}) - return - } - c.JSON(http.StatusOK, user) -} - -// PostUser godoc -// @Tags Users -// @Summary Creates a new user based on the given user model -// @Accept json -// @Produce json -// @Param user body users.User true "User Model" -// @Success 200 {object} users.User -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/users [post] -// @Security ApiBasicAuth -func (s *ApiServer) PostUser(c *gin.Context) { - newUser := users.User{} - if err := c.BindJSON(&newUser); err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - if user := s.s.users.GetUserUnscoped(newUser.Email); user != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: "user already exists"}) - return - } - - if err := s.s.CreateUser(newUser, s.s.wg.Cfg.GetDefaultDeviceName()); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - user := s.s.users.GetUserUnscoped(newUser.Email) - if user == nil { - c.JSON(http.StatusNotFound, ApiError{Message: "user not found"}) - return - } - c.JSON(http.StatusOK, user) -} - -// PutUser godoc -// @Tags Users -// @Summary Updates a user based on the given user model -// @Accept json -// @Produce json -// @Param email query string true "User Email" -// @Param user body users.User true "User Model" -// @Success 200 {object} users.User -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/user [put] -// @Security ApiBasicAuth -func (s *ApiServer) PutUser(c *gin.Context) { - email := strings.ToLower(strings.TrimSpace(c.Query("email"))) - if email == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "email parameter must be specified"}) - return - } - - updateUser := users.User{} - if err := c.BindJSON(&updateUser); err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - // Changing email address is not allowed - if email != updateUser.Email { - c.JSON(http.StatusBadRequest, ApiError{Message: "email parameter must match the model email address"}) - return - } - - if user := s.s.users.GetUserUnscoped(email); user == nil { - c.JSON(http.StatusNotFound, ApiError{Message: "user does not exist"}) - return - } - - if err := s.s.UpdateUser(updateUser); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - user := s.s.users.GetUserUnscoped(email) - if user == nil { - c.JSON(http.StatusNotFound, ApiError{Message: "user not found"}) - return - } - c.JSON(http.StatusOK, user) -} - -// PatchUser godoc -// @Tags Users -// @Summary Updates a user based on the given partial user model -// @Accept json -// @Produce json -// @Param email query string true "User Email" -// @Param user body users.User true "User Model" -// @Success 200 {object} users.User -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/user [patch] -// @Security ApiBasicAuth -func (s *ApiServer) PatchUser(c *gin.Context) { - email := strings.ToLower(strings.TrimSpace(c.Query("email"))) - if email == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "email parameter must be specified"}) - return - } - - patch, err := c.GetRawData() - if err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - user := s.s.users.GetUserUnscoped(email) - if user == nil { - c.JSON(http.StatusNotFound, ApiError{Message: "user does not exist"}) - return - } - userData, err := json.Marshal(user) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - mergedUserData, err := jsonpatch.MergePatch(userData, patch) - var mergedUser users.User - err = json.Unmarshal(mergedUserData, &mergedUser) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - // CHanging email address is not allowed - if email != mergedUser.Email { - c.JSON(http.StatusBadRequest, ApiError{Message: "email parameter must match the model email address"}) - return - } - - if err := s.s.UpdateUser(mergedUser); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - user = s.s.users.GetUserUnscoped(email) - if user == nil { - c.JSON(http.StatusNotFound, ApiError{Message: "user not found"}) - return - } - c.JSON(http.StatusOK, user) -} - -// DeleteUser godoc -// @Tags Users -// @Summary Deletes the specified user -// @Produce json -// @Param email query string true "User Email" -// @Success 204 "No content" -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/user [delete] -// @Security ApiBasicAuth -func (s *ApiServer) DeleteUser(c *gin.Context) { - email := strings.ToLower(strings.TrimSpace(c.Query("email"))) - if email == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "email parameter must be specified"}) - return - } - - var user *users.User - if user = s.s.users.GetUserUnscoped(email); user == nil { - c.JSON(http.StatusNotFound, ApiError{Message: "user does not exist"}) - return - } - - if err := s.s.DeleteUser(*user); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - c.Status(http.StatusNoContent) -} - -// GetPeers godoc -// @Tags Peers -// @Summary Retrieves all peers for the given interface -// @Produce json -// @Param device query string true "Device Name" -// @Success 200 {object} []wireguard.Peer -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /backend/peers [get] -// @Security ApiBasicAuth -func (s *ApiServer) GetPeers(c *gin.Context) { - deviceName := strings.ToLower(strings.TrimSpace(c.Query("device"))) - if deviceName == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "device parameter must be specified"}) - return - } - - // validate device name - if !common.ListContains(s.s.config.WG.DeviceNames, deviceName) { - c.JSON(http.StatusNotFound, ApiError{Message: "unknown device"}) - return - } - - peers := s.s.peers.GetAllPeers(deviceName) - c.JSON(http.StatusOK, peers) -} - -// GetPeer godoc -// @Tags Peers -// @Summary Retrieves the peer for the given public key -// @Produce json -// @Param pkey query string true "Public Key (Base 64)" -// @Success 200 {object} wireguard.Peer -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /backend/peer [get] -// @Security ApiBasicAuth -func (s *ApiServer) GetPeer(c *gin.Context) { - pkey := c.Query("pkey") - if pkey == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "pkey parameter must be specified"}) - return - } - - peer := s.s.peers.GetPeerByKey(pkey) - if !peer.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer does not exist"}) - return - } - c.JSON(http.StatusOK, peer) -} - -// PostPeer godoc -// @Tags Peers -// @Summary Creates a new peer based on the given peer model -// @Accept json -// @Produce json -// @Param device query string true "Device Name" -// @Param peer body wireguard.Peer true "Peer Model" -// @Success 200 {object} wireguard.Peer -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/peers [post] -// @Security ApiBasicAuth -func (s *ApiServer) PostPeer(c *gin.Context) { - deviceName := strings.ToLower(strings.TrimSpace(c.Query("device"))) - if deviceName == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "device parameter must be specified"}) - return - } - - // validate device name - if !common.ListContains(s.s.config.WG.DeviceNames, deviceName) { - c.JSON(http.StatusNotFound, ApiError{Message: "unknown device"}) - return - } - - newPeer := wireguard.Peer{} - if err := c.BindJSON(&newPeer); err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - if peer := s.s.peers.GetPeerByKey(newPeer.PublicKey); peer.IsValid() { - c.JSON(http.StatusBadRequest, ApiError{Message: "peer already exists"}) - return - } - - if err := s.s.CreatePeer(deviceName, newPeer); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - peer := s.s.peers.GetPeerByKey(newPeer.PublicKey) - if !peer.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer not found"}) - return - } - c.JSON(http.StatusOK, peer) -} - -// PutPeer godoc -// @Tags Peers -// @Summary Updates the given peer based on the given peer model -// @Accept json -// @Produce json -// @Param pkey query string true "Public Key" -// @Param peer body wireguard.Peer true "Peer Model" -// @Success 200 {object} wireguard.Peer -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/peer [put] -// @Security ApiBasicAuth -func (s *ApiServer) PutPeer(c *gin.Context) { - updatePeer := wireguard.Peer{} - if err := c.BindJSON(&updatePeer); err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - pkey := c.Query("pkey") - if pkey == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "pkey parameter must be specified"}) - return - } - - if peer := s.s.peers.GetPeerByKey(pkey); !peer.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer does not exist"}) - return - } - - // Changing public key is not allowed - if pkey != updatePeer.PublicKey { - c.JSON(http.StatusBadRequest, ApiError{Message: "pkey parameter must match the model public key"}) - return - } - - now := time.Now() - if updatePeer.DeactivatedAt != nil { - updatePeer.DeactivatedAt = &now - } - if err := s.s.UpdatePeer(updatePeer, now); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - peer := s.s.peers.GetPeerByKey(updatePeer.PublicKey) - if !peer.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer not found"}) - return - } - c.JSON(http.StatusOK, peer) -} - -// PatchPeer godoc -// @Tags Peers -// @Summary Updates the given peer based on the given partial peer model -// @Accept json -// @Produce json -// @Param pkey query string true "Public Key" -// @Param peer body wireguard.Peer true "Peer Model" -// @Success 200 {object} wireguard.Peer -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/peer [patch] -// @Security ApiBasicAuth -func (s *ApiServer) PatchPeer(c *gin.Context) { - patch, err := c.GetRawData() - if err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - pkey := c.Query("pkey") - if pkey == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "pkey parameter must be specified"}) - return - } - - peer := s.s.peers.GetPeerByKey(pkey) - if !peer.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer does not exist"}) - return - } - - peerData, err := json.Marshal(peer) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - mergedPeerData, err := jsonpatch.MergePatch(peerData, patch) - var mergedPeer wireguard.Peer - err = json.Unmarshal(mergedPeerData, &mergedPeer) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - if !mergedPeer.IsValid() { - c.JSON(http.StatusBadRequest, ApiError{Message: "invalid peer model"}) - return - } - - // Changing public key is not allowed - if pkey != mergedPeer.PublicKey { - c.JSON(http.StatusBadRequest, ApiError{Message: "pkey parameter must match the model public key"}) - return - } - - now := time.Now() - if mergedPeer.DeactivatedAt != nil { - mergedPeer.DeactivatedAt = &now - } - if err := s.s.UpdatePeer(mergedPeer, now); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - peer = s.s.peers.GetPeerByKey(mergedPeer.PublicKey) - if !peer.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer not found"}) - return - } - c.JSON(http.StatusOK, peer) -} - -// DeletePeer godoc -// @Tags Peers -// @Summary Updates the given peer based on the given partial peer model -// @Produce json -// @Param pkey query string true "Public Key" -// @Success 202 "No Content" -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/peer [delete] -// @Security ApiBasicAuth -func (s *ApiServer) DeletePeer(c *gin.Context) { - pkey := c.Query("pkey") - if pkey == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "pkey parameter must be specified"}) - return - } - - peer := s.s.peers.GetPeerByKey(pkey) - if peer.PublicKey == "" { - c.JSON(http.StatusNotFound, ApiError{Message: "peer does not exist"}) - return - } - - if err := s.s.DeletePeer(peer); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - c.Status(http.StatusNoContent) -} - -// GetDevices godoc -// @Tags Interface -// @Summary Get all devices -// @Produce json -// @Success 200 {object} []wireguard.Device -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /backend/devices [get] -// @Security ApiBasicAuth -func (s *ApiServer) GetDevices(c *gin.Context) { - var devices []wireguard.Device - for _, deviceName := range s.s.config.WG.DeviceNames { - device := s.s.peers.GetDevice(deviceName) - if !device.IsValid() { - continue - } - devices = append(devices, device) - } - - c.JSON(http.StatusOK, devices) -} - -// GetDevice godoc -// @Tags Interface -// @Summary Get the given device -// @Produce json -// @Param device query string true "Device Name" -// @Success 200 {object} wireguard.Device -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /backend/device [get] -// @Security ApiBasicAuth -func (s *ApiServer) GetDevice(c *gin.Context) { - deviceName := strings.ToLower(strings.TrimSpace(c.Query("device"))) - if deviceName == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "device parameter must be specified"}) - return - } - - // validate device name - if !common.ListContains(s.s.config.WG.DeviceNames, deviceName) { - c.JSON(http.StatusNotFound, ApiError{Message: "unknown device"}) - return - } - - device := s.s.peers.GetDevice(deviceName) - if !device.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "device not found"}) - return - } - - c.JSON(http.StatusOK, device) -} - -// PutDevice godoc -// @Tags Interface -// @Summary Updates the given device based on the given device model (UNIMPLEMENTED) -// @Accept json -// @Produce json -// @Param device query string true "Device Name" -// @Param body body wireguard.Device true "Device Model" -// @Success 200 {object} wireguard.Device -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/device [put] -// @Security ApiBasicAuth -func (s *ApiServer) PutDevice(c *gin.Context) { - updateDevice := wireguard.Device{} - if err := c.BindJSON(&updateDevice); err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - deviceName := strings.ToLower(strings.TrimSpace(c.Query("device"))) - if deviceName == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "device parameter must be specified"}) - return - } - - // validate device name - if !common.ListContains(s.s.config.WG.DeviceNames, deviceName) { - c.JSON(http.StatusNotFound, ApiError{Message: "unknown device"}) - return - } - - device := s.s.peers.GetDevice(deviceName) - if !device.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer not found"}) - return - } - - // Changing device name is not allowed - if deviceName != updateDevice.DeviceName { - c.JSON(http.StatusBadRequest, ApiError{Message: "device parameter must match the model device name"}) - return - } - - // TODO: implement - - c.JSON(http.StatusNotImplemented, device) -} - -// PatchDevice godoc -// @Tags Interface -// @Summary Updates the given device based on the given partial device model (UNIMPLEMENTED) -// @Accept json -// @Produce json -// @Param device query string true "Device Name" -// @Param body body wireguard.Device true "Device Model" -// @Success 200 {object} wireguard.Device -// @Failure 400 {object} ApiError -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Failure 500 {object} ApiError -// @Router /backend/device [patch] -// @Security ApiBasicAuth -func (s *ApiServer) PatchDevice(c *gin.Context) { - patch, err := c.GetRawData() - if err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - deviceName := strings.ToLower(strings.TrimSpace(c.Query("device"))) - if deviceName == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "device parameter must be specified"}) - return - } - - // validate device name - if !common.ListContains(s.s.config.WG.DeviceNames, deviceName) { - c.JSON(http.StatusNotFound, ApiError{Message: "unknown device"}) - return - } - - device := s.s.peers.GetDevice(deviceName) - if !device.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer not found"}) - return - } - - deviceData, err := json.Marshal(device) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - mergedDeviceData, err := jsonpatch.MergePatch(deviceData, patch) - var mergedDevice wireguard.Device - err = json.Unmarshal(mergedDeviceData, &mergedDevice) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - if !mergedDevice.IsValid() { - c.JSON(http.StatusBadRequest, ApiError{Message: "invalid device model"}) - return - } - - // Changing device name is not allowed - if deviceName != mergedDevice.DeviceName { - c.JSON(http.StatusBadRequest, ApiError{Message: "device parameter must match the model device name"}) - return - } - - // TODO: implement - - c.JSON(http.StatusNotImplemented, device) -} - -type PeerDeploymentInformation struct { - PublicKey string - Identifier string - Device string - DeviceIdentifier string -} - -// GetPeerDeploymentInformation godoc -// @Tags Provisioning -// @Summary Retrieves all active peers for the given email address -// @Produce json -// @Param email query string true "Email Address" -// @Success 200 {object} []PeerDeploymentInformation "All active WireGuard peers" -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /provisioning/peers [get] -// @Security GeneralBasicAuth -func (s *ApiServer) GetPeerDeploymentInformation(c *gin.Context) { - email := c.Query("email") - if email == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "email parameter must be specified"}) - return - } - - // Get authenticated user to check permissions - username, _, _ := c.Request.BasicAuth() - user := s.s.users.GetUser(username) - - if !user.IsAdmin && user.Email != email { - c.JSON(http.StatusForbidden, ApiError{Message: "not enough permissions to access this resource"}) - return - } - - peers := s.s.peers.GetPeersByMail(email) - result := make([]PeerDeploymentInformation, 0, len(peers)) - for i := range peers { - if peers[i].DeactivatedAt != nil { - continue // skip deactivated peers - } - - device := s.s.peers.GetDevice(peers[i].DeviceName) - if device.Type != wireguard.DeviceTypeServer { - continue // Skip peers on non-server devices - } - - result = append(result, PeerDeploymentInformation{ - PublicKey: peers[i].PublicKey, - Identifier: peers[i].Identifier, - Device: device.DeviceName, - DeviceIdentifier: device.DisplayName, - }) - } - - c.JSON(http.StatusOK, result) -} - -// GetPeerDeploymentConfig godoc -// @Tags Provisioning -// @Summary Retrieves the peer config for the given public key -// @Produce plain -// @Param pkey query string true "Public Key (Base 64)" -// @Success 200 {object} string "The WireGuard configuration file" -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /provisioning/peer [get] -// @Security GeneralBasicAuth -func (s *ApiServer) GetPeerDeploymentConfig(c *gin.Context) { - pkey := c.Query("pkey") - if pkey == "" { - c.JSON(http.StatusBadRequest, ApiError{Message: "pkey parameter must be specified"}) - return - } - - peer := s.s.peers.GetPeerByKey(pkey) - if !peer.IsValid() { - c.JSON(http.StatusNotFound, ApiError{Message: "peer does not exist"}) - return - } - - // Get authenticated user to check permissions - username, _, _ := c.Request.BasicAuth() - user := s.s.users.GetUser(username) - - if !user.IsAdmin && user.Email != peer.Email { - c.JSON(http.StatusForbidden, ApiError{Message: "not enough permissions to access this resource"}) - return - } - - device := s.s.peers.GetDevice(peer.DeviceName) - config, err := peer.GetConfigFile(device) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - c.Data(http.StatusOK, "text/plain", config) -} - -type ProvisioningRequest struct { - // DeviceName is optional, if not specified, the configured default device will be used. - DeviceName string `json:",omitempty"` - Identifier string `binding:"required"` - Email string `binding:"required"` - - // Client specific and optional settings - - AllowedIPsStr string `binding:"cidrlist" json:",omitempty"` - PersistentKeepalive int `binding:"gte=0" json:",omitempty"` - DNSStr string `binding:"iplist" json:",omitempty"` - Mtu int `binding:"gte=0,lte=1500" json:",omitempty"` -} - -// PostPeerDeploymentConfig godoc -// @Tags Provisioning -// @Summary Creates the requested peer config and returns the config file -// @Accept json -// @Produce plain -// @Param body body ProvisioningRequest true "Provisioning Request Model" -// @Success 200 {object} string "The WireGuard configuration file" -// @Failure 401 {object} ApiError -// @Failure 403 {object} ApiError -// @Failure 404 {object} ApiError -// @Router /provisioning/peers [post] -// @Security GeneralBasicAuth -func (s *ApiServer) PostPeerDeploymentConfig(c *gin.Context) { - req := ProvisioningRequest{} - if err := c.BindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, ApiError{Message: err.Error()}) - return - } - - // Get authenticated user to check permissions - username, _, _ := c.Request.BasicAuth() - user := s.s.users.GetUser(username) - - if !user.IsAdmin && !s.s.config.Core.SelfProvisioningAllowed { - c.JSON(http.StatusForbidden, ApiError{Message: "peer provisioning service disabled"}) - return - } - - if !user.IsAdmin && user.Email != req.Email { - c.JSON(http.StatusForbidden, ApiError{Message: "not enough permissions to access this resource"}) - return - } - - deviceName := req.DeviceName - if deviceName == "" || !common.ListContains(s.s.config.WG.DeviceNames, deviceName) { - deviceName = s.s.config.WG.GetDefaultDeviceName() - } - device := s.s.peers.GetDevice(deviceName) - if device.Type != wireguard.DeviceTypeServer { - c.JSON(http.StatusForbidden, ApiError{Message: "invalid device, provisioning disabled"}) - return - } - - // check if private/public keys are set, if so check database for existing entries - peer, err := s.s.PrepareNewPeer(deviceName) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - peer.Email = req.Email - peer.Identifier = req.Identifier - - if req.AllowedIPsStr != "" { - peer.AllowedIPsStr = req.AllowedIPsStr - } - if req.PersistentKeepalive != 0 { - peer.PersistentKeepalive = req.PersistentKeepalive - } - if req.DNSStr != "" { - peer.DNSStr = req.DNSStr - } - if req.Mtu != 0 { - peer.Mtu = req.Mtu - } - - if err := s.s.CreatePeer(deviceName, peer); err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - config, err := peer.GetConfigFile(device) - if err != nil { - c.JSON(http.StatusInternalServerError, ApiError{Message: err.Error()}) - return - } - - c.Data(http.StatusOK, "text/plain", config) -} diff --git a/internal/server/auth.go b/internal/server/auth.go deleted file mode 100644 index 0e1cc72..0000000 --- a/internal/server/auth.go +++ /dev/null @@ -1,90 +0,0 @@ -package server - -import ( - "sort" - - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/authentication" - "github.com/h44z/wg-portal/internal/users" - "github.com/sirupsen/logrus" -) - -// AuthManager keeps track of available authentication providers. -type AuthManager struct { - Server *Server - Group *gin.RouterGroup // basic group for all providers (/auth) - providers []authentication.AuthProvider - UserManager *users.Manager -} - -// RegisterProvider register auth provider -func (auth *AuthManager) RegisterProvider(provider authentication.AuthProvider) { - name := provider.GetName() - if auth.GetProvider(name) != nil { - logrus.Warnf("auth provider %v already registered", name) - } - - provider.SetupRoutes(auth.Group) - auth.providers = append(auth.providers, provider) -} - -// RegisterProviderWithoutError register auth provider if err is nil -func (auth *AuthManager) RegisterProviderWithoutError(provider authentication.AuthProvider, err error) { - if err != nil { - logrus.Errorf("skipping provider registration: %v", err) - return - } - auth.RegisterProvider(provider) -} - -// GetProvider get provider by name -func (auth *AuthManager) GetProvider(name string) authentication.AuthProvider { - for _, provider := range auth.providers { - if provider.GetName() == name { - return provider - } - } - return nil -} - -// GetProviders return registered providers. -// Returned providers are ordered by provider priority. -func (auth *AuthManager) GetProviders() (providers []authentication.AuthProvider) { - for _, provider := range auth.providers { - providers = append(providers, provider) - } - - // order by priority - sort.SliceStable(providers, func(i, j int) bool { - return providers[i].GetPriority() < providers[j].GetPriority() - }) - - return -} - -// GetProvidersForType return registered providers for the given type. -// Returned providers are ordered by provider priority. -func (auth *AuthManager) GetProvidersForType(typ authentication.AuthProviderType) (providers []authentication.AuthProvider) { - for _, provider := range auth.providers { - if provider.GetType() == typ { - providers = append(providers, provider) - } - } - - // order by priority - sort.SliceStable(providers, func(i, j int) bool { - return providers[i].GetPriority() < providers[j].GetPriority() - }) - - return -} - -func NewAuthManager(server *Server) *AuthManager { - m := &AuthManager{ - Server: server, - } - - m.Group = m.Server.server.Group("/auth") - - return m -} diff --git a/internal/server/configuration.go b/internal/server/configuration.go deleted file mode 100644 index 7442a4a..0000000 --- a/internal/server/configuration.go +++ /dev/null @@ -1,138 +0,0 @@ -package server - -import ( - "os" - "reflect" - "runtime" - - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/ldap" - "github.com/h44z/wg-portal/internal/wireguard" - "github.com/kelseyhightower/envconfig" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "gopkg.in/yaml.v3" -) - -var ErrInvalidSpecification = errors.New("specification must be a struct pointer") - -// loadConfigFile parses yaml files. It uses yaml annotation to store the data in a struct. -func loadConfigFile(cfg interface{}, filename string) error { - s := reflect.ValueOf(cfg) - - if s.Kind() != reflect.Ptr { - return ErrInvalidSpecification - } - s = s.Elem() - if s.Kind() != reflect.Struct { - return ErrInvalidSpecification - } - - f, err := os.Open(filename) - if err != nil { - return errors.Wrapf(err, "failed to open config file %s", filename) - } - defer f.Close() - - decoder := yaml.NewDecoder(f) - err = decoder.Decode(cfg) - if err != nil { - return errors.Wrapf(err, "failed to decode config file %s", filename) - } - - return nil -} - -// loadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct. -func loadConfigEnv(cfg interface{}) error { - err := envconfig.Process("", cfg) - if err != nil { - return errors.Wrap(err, "failed to process environment config") - } - - return nil -} - -type Config struct { - Core struct { - ListeningAddress string `yaml:"listeningAddress" envconfig:"LISTENING_ADDRESS"` - ExternalUrl string `yaml:"externalUrl" envconfig:"EXTERNAL_URL"` - Title string `yaml:"title" envconfig:"WEBSITE_TITLE"` - CompanyName string `yaml:"company" envconfig:"COMPANY_NAME"` - MailFrom string `yaml:"mailFrom" envconfig:"MAIL_FROM"` - AdminUser string `yaml:"adminUser" envconfig:"ADMIN_USER"` // must be an email address - AdminPassword string `yaml:"adminPass" envconfig:"ADMIN_PASS"` - EditableKeys bool `yaml:"editableKeys" envconfig:"EDITABLE_KEYS"` - CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"` - SelfProvisioningAllowed bool `yaml:"selfProvisioning" envconfig:"SELF_PROVISIONING"` - LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"` - SessionSecret string `yaml:"sessionSecret" envconfig:"SESSION_SECRET"` - } `yaml:"core"` - Database common.DatabaseConfig `yaml:"database"` - Email common.MailConfig `yaml:"email"` - LDAP ldap.Config `yaml:"ldap"` - WG wireguard.Config `yaml:"wg"` -} - -func NewConfig() *Config { - cfg := &Config{} - - // Default config - cfg.Core.ListeningAddress = ":8123" - cfg.Core.Title = "WireGuard VPN" - cfg.Core.CompanyName = "WireGuard Portal" - cfg.Core.ExternalUrl = "http://localhost:8123" - cfg.Core.MailFrom = "WireGuard VPN " - cfg.Core.AdminUser = "admin@wgportal.local" - cfg.Core.AdminPassword = "wgportal" - cfg.Core.LdapEnabled = false - cfg.Core.EditableKeys = true - cfg.Core.SessionSecret = "secret" - - cfg.Database.Typ = "sqlite" - cfg.Database.Database = "data/wg_portal.db" - - cfg.LDAP.URL = "ldap://srv-ad01.company.local:389" - cfg.LDAP.BaseDN = "DC=COMPANY,DC=LOCAL" - cfg.LDAP.StartTLS = true - cfg.LDAP.BindUser = "company\\\\ldap_wireguard" - cfg.LDAP.BindPass = "SuperSecret" - cfg.LDAP.EmailAttribute = "mail" - cfg.LDAP.FirstNameAttribute = "givenName" - cfg.LDAP.LastNameAttribute = "sn" - cfg.LDAP.PhoneAttribute = "telephoneNumber" - cfg.LDAP.GroupMemberAttribute = "memberOf" - cfg.LDAP.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL" - cfg.LDAP.LoginFilter = "(&(objectClass=organizationalPerson)(mail={{login_identifier}})(!userAccountControl:1.2.840.113556.1.4.803:=2))" - cfg.LDAP.SyncFilter = "(&(objectClass=organizationalPerson)(!userAccountControl:1.2.840.113556.1.4.803:=2)(mail=*))" - - cfg.WG.DeviceNames = []string{"wg0"} - cfg.WG.DefaultDeviceName = "wg0" - cfg.WG.ConfigDirectoryPath = "/etc/wireguard" - cfg.WG.ManageIPAddresses = true - cfg.Email.Host = "127.0.0.1" - cfg.Email.Port = 25 - cfg.Email.Encryption = common.MailEncryptionNone - cfg.Email.AuthType = common.MailAuthPlain - - // Load config from file and environment - cfgFile, ok := os.LookupEnv("CONFIG_FILE") - if !ok { - cfgFile = "config.yml" // Default config file - } - err := loadConfigFile(cfg, cfgFile) - if err != nil { - logrus.Warnf("unable to load config.yml file: %v, using default configuration...", err) - } - err = loadConfigEnv(cfg) - if err != nil { - logrus.Warnf("unable to load environment config: %v", err) - } - - if cfg.WG.ManageIPAddresses && runtime.GOOS != "linux" { - logrus.Warnf("managing IP addresses only works on linux, feature disabled...") - cfg.WG.ManageIPAddresses = false - } - - return cfg -} diff --git a/internal/server/docs/docs.go b/internal/server/docs/docs.go deleted file mode 100644 index 6eb3e36..0000000 --- a/internal/server/docs/docs.go +++ /dev/null @@ -1,1531 +0,0 @@ -// GENERATED BY THE COMMAND ABOVE; DO NOT EDIT -// This file was generated by swaggo/swag - -package docs - -import ( - "bytes" - "encoding/json" - "strings" - - "github.com/alecthomas/template" - "github.com/swaggo/swag" -) - -var doc = `{ - "schemes": {{ marshal .Schemes }}, - "swagger": "2.0", - "info": { - "description": "{{.Description}}", - "title": "{{.Title}}", - "contact": { - "name": "WireGuard Portal Project", - "url": "https://github.com/h44z/wg-portal" - }, - "license": { - "name": "MIT", - "url": "https://github.com/h44z/wg-portal/blob/master/LICENSE.txt" - }, - "version": "{{.Version}}" - }, - "host": "{{.Host}}", - "basePath": "{{.BasePath}}", - "paths": { - "/backend/device": { - "get": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Interface" - ], - "summary": "Get the given device", - "parameters": [ - { - "type": "string", - "description": "Device Name", - "name": "device", - "in": "query", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/wireguard.Device" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "put": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Interface" - ], - "summary": "Updates the given device based on the given device model (UNIMPLEMENTED)", - "parameters": [ - { - "type": "string", - "description": "Device Name", - "name": "device", - "in": "query", - "required": true - }, - { - "description": "Device Model", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/wireguard.Device" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/wireguard.Device" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "patch": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Interface" - ], - "summary": "Updates the given device based on the given partial device model (UNIMPLEMENTED)", - "parameters": [ - { - "type": "string", - "description": "Device Name", - "name": "device", - "in": "query", - "required": true - }, - { - "description": "Device Model", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/wireguard.Device" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/wireguard.Device" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - }, - "/backend/devices": { - "get": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Interface" - ], - "summary": "Get all devices", - "responses": { - "200": { - "description": "OK", - "schema": { - "type": "array", - "items": { - "$ref": "#/definitions/wireguard.Device" - } - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - }, - "/backend/peer": { - "get": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Peers" - ], - "summary": "Retrieves the peer for the given public key", - "parameters": [ - { - "type": "string", - "description": "Public Key (Base 64)", - "name": "pkey", - "in": "query", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/wireguard.Peer" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "put": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Peers" - ], - "summary": "Updates the given peer based on the given peer model", - "parameters": [ - { - "type": "string", - "description": "Public Key", - "name": "pkey", - "in": "query", - "required": true - }, - { - "description": "Peer Model", - "name": "peer", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/wireguard.Peer" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/wireguard.Peer" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "delete": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Peers" - ], - "summary": "Updates the given peer based on the given partial peer model", - "parameters": [ - { - "type": "string", - "description": "Public Key", - "name": "pkey", - "in": "query", - "required": true - } - ], - "responses": { - "202": { - "description": "No Content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "patch": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Peers" - ], - "summary": "Updates the given peer based on the given partial peer model", - "parameters": [ - { - "type": "string", - "description": "Public Key", - "name": "pkey", - "in": "query", - "required": true - }, - { - "description": "Peer Model", - "name": "peer", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/wireguard.Peer" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/wireguard.Peer" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - }, - "/backend/peers": { - "get": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Peers" - ], - "summary": "Retrieves all peers for the given interface", - "parameters": [ - { - "type": "string", - "description": "Device Name", - "name": "device", - "in": "query", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "type": "array", - "items": { - "$ref": "#/definitions/wireguard.Peer" - } - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "post": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Peers" - ], - "summary": "Creates a new peer based on the given peer model", - "parameters": [ - { - "type": "string", - "description": "Device Name", - "name": "device", - "in": "query", - "required": true - }, - { - "description": "Peer Model", - "name": "peer", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/wireguard.Peer" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/wireguard.Peer" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - }, - "/backend/user": { - "get": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Users" - ], - "summary": "Retrieves user based on given Email", - "parameters": [ - { - "type": "string", - "description": "User Email", - "name": "email", - "in": "query", - "required": true - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/users.User" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "put": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Users" - ], - "summary": "Updates a user based on the given user model", - "parameters": [ - { - "type": "string", - "description": "User Email", - "name": "email", - "in": "query", - "required": true - }, - { - "description": "User Model", - "name": "user", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/users.User" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/users.User" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "delete": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Users" - ], - "summary": "Deletes the specified user", - "parameters": [ - { - "type": "string", - "description": "User Email", - "name": "email", - "in": "query", - "required": true - } - ], - "responses": { - "204": { - "description": "No content" - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "patch": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Users" - ], - "summary": "Updates a user based on the given partial user model", - "parameters": [ - { - "type": "string", - "description": "User Email", - "name": "email", - "in": "query", - "required": true - }, - { - "description": "User Model", - "name": "user", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/users.User" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/users.User" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - }, - "/backend/users": { - "get": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Users" - ], - "summary": "Retrieves all users", - "responses": { - "200": { - "description": "OK", - "schema": { - "type": "array", - "items": { - "$ref": "#/definitions/users.User" - } - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "post": { - "security": [ - { - "ApiBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "application/json" - ], - "tags": [ - "Users" - ], - "summary": "Creates a new user based on the given user model", - "parameters": [ - { - "description": "User Model", - "name": "user", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/users.User" - } - } - ], - "responses": { - "200": { - "description": "OK", - "schema": { - "$ref": "#/definitions/users.User" - } - }, - "400": { - "description": "Bad Request", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "500": { - "description": "Internal Server Error", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - }, - "/provisioning/peer": { - "get": { - "security": [ - { - "GeneralBasicAuth": [] - } - ], - "produces": [ - "text/plain" - ], - "tags": [ - "Provisioning" - ], - "summary": "Retrieves the peer config for the given public key", - "parameters": [ - { - "type": "string", - "description": "Public Key (Base 64)", - "name": "pkey", - "in": "query", - "required": true - } - ], - "responses": { - "200": { - "description": "The WireGuard configuration file", - "schema": { - "type": "string" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - }, - "/provisioning/peers": { - "get": { - "security": [ - { - "GeneralBasicAuth": [] - } - ], - "produces": [ - "application/json" - ], - "tags": [ - "Provisioning" - ], - "summary": "Retrieves all active peers for the given email address", - "parameters": [ - { - "type": "string", - "description": "Email Address", - "name": "email", - "in": "query", - "required": true - } - ], - "responses": { - "200": { - "description": "All active WireGuard peers", - "schema": { - "type": "array", - "items": { - "$ref": "#/definitions/server.PeerDeploymentInformation" - } - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - }, - "post": { - "security": [ - { - "GeneralBasicAuth": [] - } - ], - "consumes": [ - "application/json" - ], - "produces": [ - "text/plain" - ], - "tags": [ - "Provisioning" - ], - "summary": "Creates the requested peer config and returns the config file", - "parameters": [ - { - "description": "Provisioning Request Model", - "name": "body", - "in": "body", - "required": true, - "schema": { - "$ref": "#/definitions/server.ProvisioningRequest" - } - } - ], - "responses": { - "200": { - "description": "The WireGuard configuration file", - "schema": { - "type": "string" - } - }, - "401": { - "description": "Unauthorized", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "403": { - "description": "Forbidden", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - }, - "404": { - "description": "Not Found", - "schema": { - "$ref": "#/definitions/server.ApiError" - } - } - } - } - } - }, - "definitions": { - "gorm.DeletedAt": { - "type": "object", - "properties": { - "time": { - "type": "string" - }, - "valid": { - "description": "Valid is true if Time is not NULL", - "type": "boolean" - } - } - }, - "server.ApiError": { - "type": "object", - "properties": { - "message": { - "type": "string" - } - } - }, - "server.PeerDeploymentInformation": { - "type": "object", - "properties": { - "device": { - "type": "string" - }, - "deviceIdentifier": { - "type": "string" - }, - "identifier": { - "type": "string" - }, - "publicKey": { - "type": "string" - } - } - }, - "server.ProvisioningRequest": { - "type": "object", - "required": [ - "email", - "identifier" - ], - "properties": { - "allowedIPsStr": { - "type": "string" - }, - "deviceName": { - "description": "DeviceName is optional, if not specified, the configured default device will be used.", - "type": "string" - }, - "dnsstr": { - "type": "string" - }, - "email": { - "type": "string" - }, - "identifier": { - "type": "string" - }, - "mtu": { - "type": "integer" - }, - "persistentKeepalive": { - "type": "integer" - } - } - }, - "users.User": { - "type": "object", - "required": [ - "email", - "firstname", - "lastname" - ], - "properties": { - "createdAt": { - "description": "database internal fields", - "type": "string" - }, - "deletedAt": { - "$ref": "#/definitions/gorm.DeletedAt" - }, - "email": { - "description": "required fields", - "type": "string" - }, - "firstname": { - "description": "optional fields", - "type": "string" - }, - "isAdmin": { - "type": "boolean" - }, - "lastname": { - "type": "string" - }, - "password": { - "description": "optional, integrated password authentication", - "type": "string" - }, - "phone": { - "type": "string" - }, - "source": { - "type": "string" - }, - "updatedAt": { - "type": "string" - } - } - }, - "wireguard.Device": { - "type": "object", - "required": [ - "deviceName", - "ipsStr", - "privateKey", - "publicKey", - "type" - ], - "properties": { - "createdAt": { - "type": "string" - }, - "defaultAllowedIPsStr": { - "description": "comma separated list of IPs that are used in the client config file", - "type": "string" - }, - "defaultEndpoint": { - "description": "Settings that are applied to all peer by default", - "type": "string" - }, - "defaultPersistentKeepalive": { - "type": "integer" - }, - "deviceName": { - "type": "string" - }, - "displayName": { - "type": "string" - }, - "dnsstr": { - "description": "comma separated list of the DNS servers of the client, wg-quick addition", - "type": "string" - }, - "firewallMark": { - "type": "integer" - }, - "ipsStr": { - "description": "comma separated list of the IPs of the client, wg-quick addition", - "type": "string" - }, - "listenPort": { - "type": "integer" - }, - "mtu": { - "description": "the interface MTU, wg-quick addition", - "type": "integer" - }, - "postDown": { - "description": "post down script, wg-quick addition", - "type": "string" - }, - "postUp": { - "description": "post up script, wg-quick addition", - "type": "string" - }, - "preDown": { - "description": "pre down script, wg-quick addition", - "type": "string" - }, - "preUp": { - "description": "pre up script, wg-quick addition", - "type": "string" - }, - "privateKey": { - "description": "Core WireGuard Settings (Interface section)", - "type": "string" - }, - "publicKey": { - "description": "Misc. WireGuard Settings", - "type": "string" - }, - "routingTable": { - "description": "the routing table, wg-quick addition", - "type": "string" - }, - "saveConfig": { - "description": "if set to ` + "`" + `true', the configuration is saved from the current state of the interface upon shutdown, wg-quick addition", - "type": "boolean" - }, - "type": { - "type": "string" - }, - "updatedAt": { - "type": "string" - } - } - }, - "wireguard.Peer": { - "type": "object", - "required": [ - "deviceName", - "email", - "identifier", - "publicKey" - ], - "properties": { - "allowedIPsStr": { - "description": "a comma separated list of IPs that are used in the client config file", - "type": "string" - }, - "createdAt": { - "type": "string" - }, - "createdBy": { - "type": "string" - }, - "deactivatedAt": { - "type": "string" - }, - "deviceName": { - "type": "string" - }, - "dnsstr": { - "description": "comma separated list of the DNS servers for the client", - "type": "string" - }, - "email": { - "type": "string" - }, - "endpoint": { - "type": "string" - }, - "identifier": { - "description": "Identifier AND Email make a WireGuard peer unique", - "type": "string" - }, - "ignoreGlobalSettings": { - "type": "boolean" - }, - "ipsStr": { - "description": "a comma separated list of IPs of the client", - "type": "string" - }, - "mtu": { - "description": "Global Device Settings (can be ignored, only make sense if device is in server mode)", - "type": "integer" - }, - "persistentKeepalive": { - "type": "integer" - }, - "presharedKey": { - "type": "string" - }, - "privateKey": { - "description": "Misc. WireGuard Settings", - "type": "string" - }, - "publicKey": { - "description": "Core WireGuard Settings", - "type": "string" - }, - "updatedAt": { - "type": "string" - }, - "updatedBy": { - "type": "string" - } - } - } - }, - "securityDefinitions": { - "ApiBasicAuth": { - "type": "basic" - }, - "GeneralBasicAuth": { - "type": "basic" - } - } -}` - -type swaggerInfo struct { - Version string - Host string - BasePath string - Schemes []string - Title string - Description string -} - -// SwaggerInfo holds exported Swagger Info so clients can modify it -var SwaggerInfo = swaggerInfo{ - Version: "1.0", - Host: "", - BasePath: "/api/v1", - Schemes: []string{}, - Title: "WireGuard Portal API", - Description: "WireGuard Portal API for managing users and peers.", -} - -type s struct{} - -func (s *s) ReadDoc() string { - sInfo := SwaggerInfo - sInfo.Description = strings.Replace(sInfo.Description, "\n", "\\n", -1) - - t, err := template.New("swagger_info").Funcs(template.FuncMap{ - "marshal": func(v interface{}) string { - a, _ := json.Marshal(v) - return string(a) - }, - }).Parse(doc) - if err != nil { - return doc - } - - var tpl bytes.Buffer - if err := t.Execute(&tpl, sInfo); err != nil { - return doc - } - - return tpl.String() -} - -func init() { - swag.Register(swag.Name, &s{}) -} diff --git a/internal/server/handlers_auth.go b/internal/server/handlers_auth.go deleted file mode 100644 index f1b4c5d..0000000 --- a/internal/server/handlers_auth.go +++ /dev/null @@ -1,151 +0,0 @@ -package server - -import ( - "net/http" - "strings" - - "github.com/pkg/errors" - - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/authentication" - "github.com/h44z/wg-portal/internal/users" - "github.com/sirupsen/logrus" - csrf "github.com/utrack/gin-csrf" -) - -func (s *Server) GetLogin(c *gin.Context) { - currentSession := GetSessionData(c) - if currentSession.LoggedIn { - c.Redirect(http.StatusSeeOther, "/") // already logged in - } - - authError := c.DefaultQuery("err", "") - errMsg := "Unknown error occurred, try again!" - switch authError { - case "missingdata": - errMsg = "Invalid login data retrieved, please fill out all fields and try again!" - case "authfail": - errMsg = "Authentication failed!" - case "loginreq": - errMsg = "Login required!" - } - - c.HTML(http.StatusOK, "login.html", gin.H{ - "error": authError != "", - "message": errMsg, - "static": s.getStaticData(), - "Csrf": csrf.GetToken(c), - }) -} - -func (s *Server) PostLogin(c *gin.Context) { - currentSession := GetSessionData(c) - if currentSession.LoggedIn { - // already logged in - c.Redirect(http.StatusSeeOther, "/") - return - } - - username := strings.ToLower(c.PostForm("username")) - password := c.PostForm("password") - - // Validate form input - if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" { - c.Redirect(http.StatusSeeOther, "/auth/login?err=missingdata") - return - } - - // Check all available auth backends - user, err := s.checkAuthentication(username, password) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "login error", err.Error()) - return - } - - // Check if user is authenticated - if user == nil { - c.Redirect(http.StatusSeeOther, "/auth/login?err=authfail") - return - } - - // Set authenticated session - sessionData := GetSessionData(c) - sessionData.LoggedIn = true - sessionData.IsAdmin = user.IsAdmin - sessionData.Email = user.Email - sessionData.Firstname = user.Firstname - sessionData.Lastname = user.Lastname - sessionData.DeviceName = s.wg.Cfg.DeviceNames[0] - - // Check if user already has a peer setup, if not create one - if err := s.CreateUserDefaultPeer(user.Email, s.wg.Cfg.GetDefaultDeviceName()); err != nil { - // Not a fatal error, just log it... - logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err) - } - - if err := UpdateSessionData(c, sessionData); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to save session") - return - } - c.Redirect(http.StatusSeeOther, "/") -} - -func (s *Server) GetLogout(c *gin.Context) { - currentSession := GetSessionData(c) - - if !currentSession.LoggedIn { // Not logged in - c.Redirect(http.StatusSeeOther, "/") - return - } - - if err := DestroySessionData(c); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "logout error", "failed to destroy session") - return - } - c.Redirect(http.StatusSeeOther, "/") -} - -func (s *Server) checkAuthentication(username, password string) (*users.User, error) { - var user *users.User - - // Check all available auth backends - for _, provider := range s.auth.GetProvidersForType(authentication.AuthProviderTypePassword) { - // try to log in to the given provider - authEmail, err := provider.Login(&authentication.AuthContext{ - Username: username, - Password: password, - }) - if err != nil { - continue - } - - // Login succeeded - user = s.users.GetUser(authEmail) - if user != nil { - break // user exists, nothing more to do... - } - - // create new user in the database (or reactivate him) - userData, err := provider.GetUserModel(&authentication.AuthContext{ - Username: username, - }) - if err != nil { - return nil, errors.Wrap(err, "failed to get user model") - } - if err := s.CreateUser(users.User{ - Email: userData.Email, - Source: users.UserSource(provider.GetName()), - IsAdmin: userData.IsAdmin, - Firstname: userData.Firstname, - Lastname: userData.Lastname, - Phone: userData.Phone, - }, s.wg.Cfg.GetDefaultDeviceName()); err != nil { - return nil, errors.Wrap(err, "failed to update user data") - } - - user = s.users.GetUser(authEmail) - break - } - - return user, nil -} diff --git a/internal/server/handlers_common.go b/internal/server/handlers_common.go deleted file mode 100644 index ab851bb..0000000 --- a/internal/server/handlers_common.go +++ /dev/null @@ -1,201 +0,0 @@ -package server - -import ( - "net/http" - "strconv" - - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/users" - "github.com/pkg/errors" -) - -func (s *Server) GetHandleError(c *gin.Context, code int, message, details string) { - currentSession := GetSessionData(c) - - c.HTML(code, "error.html", gin.H{ - "Data": gin.H{ - "Code": strconv.Itoa(code), - "Message": message, - "Details": details, - }, - "Route": c.Request.URL.Path, - "Session": GetSessionData(c), - "Static": s.getStaticData(), - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - }) -} - -func (s *Server) GetIndex(c *gin.Context) { - currentSession := GetSessionData(c) - - c.HTML(http.StatusOK, "index.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - }) -} - -func (s *Server) GetAdminIndex(c *gin.Context) { - currentSession := GetSessionData(c) - - sort := c.Query("sort") - if sort != "" { - if currentSession.SortedBy["peers"] != sort { - currentSession.SortedBy["peers"] = sort - currentSession.SortDirection["peers"] = "asc" - } else { - if currentSession.SortDirection["peers"] == "asc" { - currentSession.SortDirection["peers"] = "desc" - } else { - currentSession.SortDirection["peers"] = "asc" - } - } - - if err := UpdateSessionData(c, currentSession); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "sort error", "failed to save session") - return - } - c.Redirect(http.StatusSeeOther, "/admin/") - return - } - - search, searching := c.GetQuery("search") - if searching { - currentSession.Search["peers"] = search - - if err := UpdateSessionData(c, currentSession); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "search error", "failed to save session") - return - } - c.Redirect(http.StatusSeeOther, "/admin/") - return - } - - deviceName := c.Query("device") - if deviceName != "" { - if !common.ListContains(s.wg.Cfg.DeviceNames, deviceName) { - s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "no such device") - return - } - currentSession.DeviceName = deviceName - - if err := UpdateSessionData(c, currentSession); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "failed to save session") - return - } - c.Redirect(http.StatusSeeOther, "/admin/") - return - } - - device := s.peers.GetDevice(currentSession.DeviceName) - users := s.peers.GetFilteredAndSortedPeers(currentSession.DeviceName, currentSession.SortedBy["peers"], currentSession.SortDirection["peers"], currentSession.Search["peers"]) - - c.HTML(http.StatusOK, "admin_index.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Peers": users, - "TotalPeers": len(s.peers.GetAllPeers(currentSession.DeviceName)), - "Users": s.users.GetUsers(), - "Device": device, - "DeviceNames": s.GetDeviceNames(), - }) -} - -func (s *Server) GetUserIndex(c *gin.Context) { - currentSession := GetSessionData(c) - - sort := c.Query("sort") - if sort != "" { - if currentSession.SortedBy["userpeers"] != sort { - currentSession.SortedBy["userpeers"] = sort - currentSession.SortDirection["userpeers"] = "asc" - } else { - if currentSession.SortDirection["userpeers"] == "asc" { - currentSession.SortDirection["userpeers"] = "desc" - } else { - currentSession.SortDirection["userpeers"] = "asc" - } - } - - if err := UpdateSessionData(c, currentSession); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "sort error", "failed to save session") - return - } - c.Redirect(http.StatusSeeOther, "/admin") - return - } - - peers := s.peers.GetSortedPeersForEmail(currentSession.SortedBy["userpeers"], currentSession.SortDirection["userpeers"], currentSession.Email) - - c.HTML(http.StatusOK, "user_index.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Peers": peers, - "TotalPeers": len(peers), - "Users": []users.User{*s.users.GetUser(currentSession.Email)}, - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - }) -} - -func (s *Server) updateFormInSession(c *gin.Context, formData interface{}) error { - currentSession := GetSessionData(c) - currentSession.FormData = formData - - if err := UpdateSessionData(c, currentSession); err != nil { - return errors.WithMessage(err, "failed to update form in session") - } - - return nil -} - -func (s *Server) setNewPeerFormInSession(c *gin.Context) (SessionData, error) { - currentSession := GetSessionData(c) - - // If session does not contain a peer form ignore update - // If url contains a formerr parameter reset the form - if currentSession.FormData == nil || c.Query("formerr") == "" { - user, err := s.PrepareNewPeer(currentSession.DeviceName) - if err != nil { - return currentSession, errors.WithMessage(err, "failed to prepare new peer") - } - currentSession.FormData = user - } - - if err := UpdateSessionData(c, currentSession); err != nil { - return currentSession, errors.WithMessage(err, "failed to update peer form in session") - } - - return currentSession, nil -} - -func (s *Server) setFormInSession(c *gin.Context, formData interface{}) (SessionData, error) { - currentSession := GetSessionData(c) - // If session does not contain a form ignore update - // If url contains a formerr parameter reset the form - if currentSession.FormData == nil || c.Query("formerr") == "" { - currentSession.FormData = formData - } - - if err := UpdateSessionData(c, currentSession); err != nil { - return currentSession, errors.WithMessage(err, "failed to set form in session") - } - - return currentSession, nil -} - -func (s *Server) isUserStillValid(email string) bool { - if s.users.GetUser(email) == nil { - return false - } - return true -} diff --git a/internal/server/handlers_interface.go b/internal/server/handlers_interface.go deleted file mode 100644 index f22dec5..0000000 --- a/internal/server/handlers_interface.go +++ /dev/null @@ -1,177 +0,0 @@ -package server - -import ( - "fmt" - "net/http" - "strings" - - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/wireguard" - csrf "github.com/utrack/gin-csrf" -) - -func (s *Server) GetAdminEditInterface(c *gin.Context) { - currentSession := GetSessionData(c) - device := s.peers.GetDevice(currentSession.DeviceName) - currentSession, err := s.setFormInSession(c, device) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) - return - } - - c.HTML(http.StatusOK, "admin_edit_interface.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Device": currentSession.FormData.(wireguard.Device), - "EditableKeys": s.config.Core.EditableKeys, - "DeviceNames": s.GetDeviceNames(), - "Csrf": csrf.GetToken(c), - }) -} - -func (s *Server) PostAdminEditInterface(c *gin.Context) { - currentSession := GetSessionData(c) - var formDevice wireguard.Device - if currentSession.FormData != nil { - formDevice = currentSession.FormData.(wireguard.Device) - } - if err := c.ShouldBind(&formDevice); err != nil { - _ = s.updateFormInSession(c, formDevice) - SetFlashMessage(c, err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=bind") - return - } - // Clean list input - formDevice.IPsStr = common.ListToString(common.ParseStringList(formDevice.IPsStr)) - formDevice.DefaultAllowedIPsStr = common.ListToString(common.ParseStringList(formDevice.DefaultAllowedIPsStr)) - formDevice.DNSStr = common.ListToString(common.ParseStringList(formDevice.DNSStr)) - - // Clean interface parameters based on interface type - switch formDevice.Type { - case wireguard.DeviceTypeClient: - formDevice.ListenPort = 0 - formDevice.DefaultEndpoint = "" - formDevice.DefaultAllowedIPsStr = "" - formDevice.DefaultPersistentKeepalive = 0 - formDevice.SaveConfig = false - case wireguard.DeviceTypeServer: - } - - // Update WireGuard device - err := s.wg.UpdateDevice(formDevice.DeviceName, formDevice.GetConfig()) - if err != nil { - _ = s.updateFormInSession(c, formDevice) - SetFlashMessage(c, "Failed to update device in WireGuard: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=wg") - return - } - - // Update in database - err = s.peers.UpdateDevice(formDevice) - if err != nil { - _ = s.updateFormInSession(c, formDevice) - SetFlashMessage(c, "Failed to update device in database: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") - return - } - - // Update WireGuard config file - err = s.WriteWireGuardConfigFile(currentSession.DeviceName) - if err != nil { - _ = s.updateFormInSession(c, formDevice) - SetFlashMessage(c, "Failed to update WireGuard config-file: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") - return - } - - // Update interface IP address - if s.config.WG.ManageIPAddresses { - if err := s.wg.SetIPAddress(currentSession.DeviceName, formDevice.GetIPAddresses()); err != nil { - _ = s.updateFormInSession(c, formDevice) - SetFlashMessage(c, "Failed to update ip address: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") - } - if err := s.wg.SetMTU(currentSession.DeviceName, formDevice.Mtu); err != nil { - _ = s.updateFormInSession(c, formDevice) - SetFlashMessage(c, "Failed to update MTU: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update") - } - } - - SetFlashMessage(c, "Changes applied successfully!", "success") - if !s.config.WG.ManageIPAddresses { - SetFlashMessage(c, "WireGuard must be restarted to apply ip changes.", "warning") - } - c.Redirect(http.StatusSeeOther, "/admin/device/edit") -} - -func (s *Server) GetInterfaceConfig(c *gin.Context) { - currentSession := GetSessionData(c) - device := s.peers.GetDevice(currentSession.DeviceName) - peers := s.peers.GetActivePeers(device.DeviceName) - cfg, err := device.GetConfigFile(peers) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) - return - } - - filename := strings.ToLower(device.DeviceName) + ".conf" - - c.Header("Content-Disposition", "attachment; filename="+filename) - c.Data(http.StatusOK, "application/config", cfg) - return -} - -func (s *Server) GetSaveConfig(c *gin.Context) { - currentSession := GetSessionData(c) - - err := s.WriteWireGuardConfigFile(currentSession.DeviceName) - if err != nil { - SetFlashMessage(c, "Failed to save WireGuard config-file: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/") - return - } - - SetFlashMessage(c, "Updated WireGuard config-file", "success") - c.Redirect(http.StatusSeeOther, "/admin/") - return -} - -func (s *Server) GetApplyGlobalConfig(c *gin.Context) { - currentSession := GetSessionData(c) - device := s.peers.GetDevice(currentSession.DeviceName) - peers := s.peers.GetAllPeers(device.DeviceName) - - if device.Type == wireguard.DeviceTypeClient { - SetFlashMessage(c, "Cannot apply global configuration while interface is in client mode.", "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit") - return - } - - updateCounter := 0 - for _, peer := range peers { - if peer.IgnoreGlobalSettings { - continue - } - - peer.AllowedIPsStr = device.DefaultAllowedIPsStr - peer.Endpoint = device.DefaultEndpoint - peer.PersistentKeepalive = device.DefaultPersistentKeepalive - peer.DNSStr = device.DNSStr - peer.Mtu = device.Mtu - - if err := s.peers.UpdatePeer(peer); err != nil { - SetFlashMessage(c, err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/device/edit") - return - } - updateCounter++ - } - - SetFlashMessage(c, fmt.Sprintf("Global configuration updated for %d clients.", updateCounter), "success") - c.Redirect(http.StatusSeeOther, "/admin/device/edit") - return -} diff --git a/internal/server/handlers_peer.go b/internal/server/handlers_peer.go deleted file mode 100644 index 9c8b30d..0000000 --- a/internal/server/handlers_peer.go +++ /dev/null @@ -1,394 +0,0 @@ -package server - -import ( - "bytes" - "net" - "net/http" - "net/url" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/users" - "github.com/h44z/wg-portal/internal/wireguard" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/tatsushid/go-fastping" - csrf "github.com/utrack/gin-csrf" -) - -type LdapCreateForm struct { - Emails string `form:"email" binding:"required"` - Identifier string `form:"identifier" binding:"required,lte=20"` -} - -func (s *Server) GetAdminEditPeer(c *gin.Context) { - peer := s.peers.GetPeerByKey(c.Query("pkey")) - - currentSession, err := s.setFormInSession(c, peer) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) - return - } - - c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Peer": currentSession.FormData.(wireguard.Peer), - "EditableKeys": s.config.Core.EditableKeys, - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - "AdminEmail": s.config.Core.AdminUser, - "Csrf": csrf.GetToken(c), - }) -} - -func (s *Server) PostAdminEditPeer(c *gin.Context) { - currentPeer := s.peers.GetPeerByKey(c.Query("pkey")) - urlEncodedKey := url.QueryEscape(c.Query("pkey")) - - currentSession := GetSessionData(c) - var formPeer wireguard.Peer - if currentSession.FormData != nil { - formPeer = currentSession.FormData.(wireguard.Peer) - } - if err := c.ShouldBind(&formPeer); err != nil { - _ = s.updateFormInSession(c, formPeer) - SetFlashMessage(c, "failed to bind form data: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey+"&formerr=bind") - return - } - - // Clean list input - formPeer.IPsStr = common.ListToString(common.ParseStringList(formPeer.IPsStr)) - formPeer.AllowedIPsStr = common.ListToString(common.ParseStringList(formPeer.AllowedIPsStr)) - formPeer.AllowedIPsSrvStr = common.ListToString(common.ParseStringList(formPeer.AllowedIPsSrvStr)) - - disabled := c.PostForm("isdisabled") != "" - now := time.Now() - if disabled && currentPeer.DeactivatedAt == nil { - formPeer.DeactivatedAt = &now - } else if !disabled { - formPeer.DeactivatedAt = nil - } - - // Update in database - if err := s.UpdatePeer(formPeer, now); err != nil { - _ = s.updateFormInSession(c, formPeer) - SetFlashMessage(c, "failed to update user: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey+"&formerr=update") - return - } - - SetFlashMessage(c, "changes applied successfully", "success") - c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey) -} - -func (s *Server) GetAdminCreatePeer(c *gin.Context) { - currentSession, err := s.setNewPeerFormInSession(c) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) - return - } - c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Peer": currentSession.FormData.(wireguard.Peer), - "EditableKeys": s.config.Core.EditableKeys, - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - "AdminEmail": s.config.Core.AdminUser, - "Csrf": csrf.GetToken(c), - }) -} - -func (s *Server) PostAdminCreatePeer(c *gin.Context) { - currentSession := GetSessionData(c) - var formPeer wireguard.Peer - if currentSession.FormData != nil { - formPeer = currentSession.FormData.(wireguard.Peer) - } - if err := c.ShouldBind(&formPeer); err != nil { - _ = s.updateFormInSession(c, formPeer) - SetFlashMessage(c, "failed to bind form data: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/peer/create?formerr=bind") - return - } - - // Clean list input - formPeer.IPsStr = common.ListToString(common.ParseStringList(formPeer.IPsStr)) - formPeer.AllowedIPsStr = common.ListToString(common.ParseStringList(formPeer.AllowedIPsStr)) - formPeer.AllowedIPsSrvStr = common.ListToString(common.ParseStringList(formPeer.AllowedIPsSrvStr)) - - disabled := c.PostForm("isdisabled") != "" - now := time.Now() - if disabled { - formPeer.DeactivatedAt = &now - } - - if err := s.CreatePeer(currentSession.DeviceName, formPeer); err != nil { - _ = s.updateFormInSession(c, formPeer) - SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/peer/create?formerr=create") - return - } - - SetFlashMessage(c, "client created successfully", "success") - c.Redirect(http.StatusSeeOther, "/admin") -} - -func (s *Server) GetAdminCreateLdapPeers(c *gin.Context) { - currentSession, err := s.setFormInSession(c, LdapCreateForm{Identifier: "Default"}) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) - return - } - - c.HTML(http.StatusOK, "admin_create_clients.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Users": s.users.GetFilteredAndSortedUsers("lastname", "asc", ""), - "FormData": currentSession.FormData.(LdapCreateForm), - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - "Csrf": csrf.GetToken(c), - }) -} - -func (s *Server) PostAdminCreateLdapPeers(c *gin.Context) { - currentSession := GetSessionData(c) - var formData LdapCreateForm - if currentSession.FormData != nil { - formData = currentSession.FormData.(LdapCreateForm) - } - if err := c.ShouldBind(&formData); err != nil { - _ = s.updateFormInSession(c, formData) - SetFlashMessage(c, "failed to bind form data: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=bind") - return - } - - emails := common.ParseStringList(formData.Emails) - for i := range emails { - // TODO: also check email addr for validity? - if !strings.ContainsRune(emails[i], '@') { - _ = s.updateFormInSession(c, formData) - SetFlashMessage(c, "invalid email address: "+emails[i], "danger") - c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=mail") - return - } - } - - logrus.Infof("creating %d ldap peers", len(emails)) - - for i := range emails { - if err := s.CreatePeerByEmail(currentSession.DeviceName, emails[i], formData.Identifier, false); err != nil { - _ = s.updateFormInSession(c, formData) - SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=create") - return - } - } - - SetFlashMessage(c, "client(s) created successfully", "success") - c.Redirect(http.StatusSeeOther, "/admin/peer/createldap") -} - -func (s *Server) GetAdminDeletePeer(c *gin.Context) { - currentPeer := s.peers.GetPeerByKey(c.Query("pkey")) - if err := s.DeletePeer(currentPeer); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Deletion error", err.Error()) - return - } - SetFlashMessage(c, "peer deleted successfully", "success") - c.Redirect(http.StatusSeeOther, "/admin") -} - -func (s *Server) GetPeerQRCode(c *gin.Context) { - peer := s.peers.GetPeerByKey(c.Query("pkey")) - currentSession := GetSessionData(c) - if !currentSession.IsAdmin && peer.Email != currentSession.Email { - s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") - return - } - - png, err := peer.GetQRCode() - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "QRCode error", err.Error()) - return - } - c.Data(http.StatusOK, "image/png", png) - return -} - -func (s *Server) GetPeerConfig(c *gin.Context) { - peer := s.peers.GetPeerByKey(c.Query("pkey")) - currentSession := GetSessionData(c) - if !currentSession.IsAdmin && peer.Email != currentSession.Email { - s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") - return - } - - cfg, err := peer.GetConfigFile(s.peers.GetDevice(currentSession.DeviceName)) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error()) - return - } - - c.Header("Content-Disposition", "attachment; filename="+peer.GetConfigFileName()) - c.Data(http.StatusOK, "application/config", cfg) - return -} - -func (s *Server) GetPeerConfigMail(c *gin.Context) { - peer := s.peers.GetPeerByKey(c.Query("pkey")) - currentSession := GetSessionData(c) - if !currentSession.IsAdmin && peer.Email != currentSession.Email { - s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") - return - } - - if err := s.sendPeerConfigMail(peer); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Email error", err.Error()) - return - } - - SetFlashMessage(c, "mail sent successfully", "success") - if strings.HasPrefix(c.Request.URL.Path, "/user") { - c.Redirect(http.StatusSeeOther, "/user/profile") - } else { - c.Redirect(http.StatusSeeOther, "/admin") - } -} - -func (s *Server) GetPeerStatus(c *gin.Context) { - peer := s.peers.GetPeerByKey(c.Query("pkey")) - currentSession := GetSessionData(c) - if !currentSession.IsAdmin && peer.Email != currentSession.Email { - s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") - return - } - - if peer.Peer == nil { // no peer means disabled - c.JSON(http.StatusOK, false) - return - } - - isOnline := false - ping := make(chan bool) - defer close(ping) - for _, cidr := range peer.GetIPAddresses() { - ip, _, _ := net.ParseCIDR(cidr) - var ra *net.IPAddr - if common.IsIPv6(ip.String()) { - ra, _ = net.ResolveIPAddr("ip6:ipv6-icmp", ip.String()) - } else { - - ra, _ = net.ResolveIPAddr("ip4:icmp", ip.String()) - } - - p := fastping.NewPinger() - p.AddIPAddr(ra) - p.OnRecv = func(addr *net.IPAddr, rtt time.Duration) { - ping <- true - p.Stop() - } - p.OnIdle = func() { - ping <- false - p.Stop() - } - p.MaxRTT = 500 * time.Millisecond - p.RunLoop() - - if <-ping { - isOnline = true - break - } - } - - c.JSON(http.StatusOK, isOnline) - return -} - -func (s *Server) GetAdminSendEmails(c *gin.Context) { - currentSession := GetSessionData(c) - if !currentSession.IsAdmin { - s.GetHandleError(c, http.StatusUnauthorized, "No permissions", "You don't have permissions to view this resource!") - return - } - - peers := s.peers.GetActivePeers(currentSession.DeviceName) - for _, peer := range peers { - if err := s.sendPeerConfigMail(peer); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Email error", err.Error()) - return - } - } - - SetFlashMessage(c, "emails sent successfully", "success") - c.Redirect(http.StatusSeeOther, "/admin") -} - -func (s *Server) sendPeerConfigMail(peer wireguard.Peer) error { - user := s.users.GetUser(peer.Email) - - cfg, err := peer.GetConfigFile(s.peers.GetDevice(peer.DeviceName)) - if err != nil { - return errors.Wrap(err, "failed to get config file") - } - png, err := peer.GetQRCode() - if err != nil { - return errors.Wrap(err, "failed to get qr-code") - } - // Apply mail template - qrcodeFileName := "wireguard-qrcode.png" - var tplBuff bytes.Buffer - if err := s.mailTpl.Execute(&tplBuff, struct { - Peer wireguard.Peer - User *users.User - QrcodePngName string - PortalUrl string - }{ - Peer: peer, - User: user, - QrcodePngName: qrcodeFileName, - PortalUrl: s.config.Core.ExternalUrl, - }); err != nil { - return errors.Wrap(err, "failed to execute mail template") - } - - // Send mail - attachments := []common.MailAttachment{ - { - Name: peer.GetConfigFileName(), - ContentType: "application/config", - Data: bytes.NewReader(cfg), - }, - { - Name: qrcodeFileName, - ContentType: "image/png", - Data: bytes.NewReader(png), - Embedded: true, - }, - { - Name: qrcodeFileName, - ContentType: "image/png", - Data: bytes.NewReader(png), - }, - } - - if err := common.SendEmailWithAttachments(s.config.Email, s.config.Core.MailFrom, "", "WireGuard VPN Configuration", - "Your mail client does not support HTML. Please find the configuration attached to this mail.", tplBuff.String(), - []string{peer.Email}, attachments); err != nil { - return errors.Wrap(err, "failed to send email") - } - - return nil -} diff --git a/internal/server/handlers_user.go b/internal/server/handlers_user.go deleted file mode 100644 index c0816f5..0000000 --- a/internal/server/handlers_user.go +++ /dev/null @@ -1,192 +0,0 @@ -package server - -import ( - "net/http" - "net/url" - "time" - - "github.com/gin-gonic/gin" - "github.com/h44z/wg-portal/internal/users" - csrf "github.com/utrack/gin-csrf" - "gorm.io/gorm" -) - -func (s *Server) GetAdminUsersIndex(c *gin.Context) { - currentSession := GetSessionData(c) - - sort := c.Query("sort") - if sort != "" { - if currentSession.SortedBy["users"] != sort { - currentSession.SortedBy["users"] = sort - currentSession.SortDirection["users"] = "asc" - } else { - if currentSession.SortDirection["users"] == "asc" { - currentSession.SortDirection["users"] = "desc" - } else { - currentSession.SortDirection["users"] = "asc" - } - } - - if err := UpdateSessionData(c, currentSession); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "sort error", "failed to save session") - return - } - c.Redirect(http.StatusSeeOther, "/admin/users/") - return - } - - search, searching := c.GetQuery("search") - if searching { - currentSession.Search["users"] = search - - if err := UpdateSessionData(c, currentSession); err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "search error", "failed to save session") - return - } - c.Redirect(http.StatusSeeOther, "/admin/users/") - return - } - - dbUsers := s.users.GetFilteredAndSortedUsersUnscoped(currentSession.SortedBy["users"], currentSession.SortDirection["users"], currentSession.Search["users"]) - - c.HTML(http.StatusOK, "admin_user_index.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "Users": dbUsers, - "TotalUsers": len(s.users.GetUsers()), - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - }) -} - -func (s *Server) GetAdminUsersEdit(c *gin.Context) { - user := s.users.GetUserUnscoped(c.Query("pkey")) - - currentSession, err := s.setFormInSession(c, *user) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) - return - } - - c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "User": currentSession.FormData.(users.User), - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - "Epoch": time.Time{}, - "Csrf": csrf.GetToken(c), - }) -} - -func (s *Server) PostAdminUsersEdit(c *gin.Context) { - currentUser := s.users.GetUserUnscoped(c.Query("pkey")) - if currentUser == nil { - SetFlashMessage(c, "invalid user", "danger") - c.Redirect(http.StatusSeeOther, "/admin/users/") - return - } - urlEncodedKey := url.QueryEscape(c.Query("pkey")) - - currentSession := GetSessionData(c) - var formUser users.User - if currentSession.FormData != nil { - formUser = currentSession.FormData.(users.User) - } - if err := c.ShouldBind(&formUser); err != nil { - _ = s.updateFormInSession(c, formUser) - SetFlashMessage(c, "failed to bind form data: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/users/edit?pkey="+urlEncodedKey+"&formerr=bind") - return - } - - disabled := c.PostForm("isdisabled") != "" - if disabled { - formUser.DeletedAt = gorm.DeletedAt{ - Time: time.Now(), - Valid: true, - } - } else { - formUser.DeletedAt = gorm.DeletedAt{} - } - formUser.IsAdmin = c.PostForm("isadmin") == "true" - - if err := s.UpdateUser(formUser); err != nil { - _ = s.updateFormInSession(c, formUser) - SetFlashMessage(c, "failed to update user: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/users/edit?pkey="+urlEncodedKey+"&formerr=update") - return - } - - SetFlashMessage(c, "changes applied successfully", "success") - c.Redirect(http.StatusSeeOther, "/admin/users/edit?pkey="+urlEncodedKey) -} - -func (s *Server) GetAdminUsersCreate(c *gin.Context) { - user := users.User{} - - currentSession, err := s.setFormInSession(c, user) - if err != nil { - s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error()) - return - } - - c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{ - "Route": c.Request.URL.Path, - "Alerts": GetFlashes(c), - "Session": currentSession, - "Static": s.getStaticData(), - "User": currentSession.FormData.(users.User), - "Device": s.peers.GetDevice(currentSession.DeviceName), - "DeviceNames": s.GetDeviceNames(), - "Epoch": time.Time{}, - "Csrf": csrf.GetToken(c), - }) -} - -func (s *Server) PostAdminUsersCreate(c *gin.Context) { - currentSession := GetSessionData(c) - var formUser users.User - if currentSession.FormData != nil { - formUser = currentSession.FormData.(users.User) - } - if err := c.ShouldBind(&formUser); err != nil { - _ = s.updateFormInSession(c, formUser) - SetFlashMessage(c, "failed to bind form data: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=bind") - return - } - - if formUser.Password == "" { - _ = s.updateFormInSession(c, formUser) - SetFlashMessage(c, "invalid password", "danger") - c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create") - return - } - - disabled := c.PostForm("isdisabled") != "" - if disabled { - formUser.DeletedAt = gorm.DeletedAt{ - Time: time.Now(), - Valid: true, - } - } else { - formUser.DeletedAt = gorm.DeletedAt{} - } - formUser.IsAdmin = c.PostForm("isadmin") == "true" - formUser.Source = users.UserSourceDatabase - - if err := s.CreateUser(formUser, currentSession.DeviceName); err != nil { - _ = s.updateFormInSession(c, formUser) - SetFlashMessage(c, "failed to add user: "+err.Error(), "danger") - c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create") - return - } - - SetFlashMessage(c, "user created successfully", "success") - c.Redirect(http.StatusSeeOther, "/admin/users/") -} diff --git a/internal/server/ldapsync.go b/internal/server/ldapsync.go deleted file mode 100644 index b45bc46..0000000 --- a/internal/server/ldapsync.go +++ /dev/null @@ -1,163 +0,0 @@ -package server - -import ( - "strings" - "time" - - "github.com/h44z/wg-portal/internal/ldap" - "github.com/h44z/wg-portal/internal/users" - "github.com/sirupsen/logrus" - "gorm.io/gorm" -) - -func (s *Server) SyncLdapWithUserDatabase() { - logrus.Info("starting ldap user synchronization...") - running := true - for running { - // Select blocks until one of the cases happens - select { - case <-time.After(1 * time.Minute): - // Sleep for 1 minute - case <-s.ctx.Done(): - logrus.Trace("ldap-sync shutting down (context ended)...") - running = false - continue - } - - // Main work here - logrus.Trace("syncing ldap users to database...") - ldapUsers, err := ldap.FindAllUsers(&s.config.LDAP) - if err != nil { - logrus.Errorf("failed to fetch users from ldap: %v", err) - continue - } - logrus.Tracef("found %d users in ldap", len(ldapUsers)) - - // Update existing LDAP users - s.updateLdapUsers(ldapUsers) - - // Disable missing LDAP users - s.disableMissingLdapUsers(ldapUsers) - } - logrus.Info("ldap user synchronization stopped") -} - -func (s Server) userChangedInLdap(user *users.User, ldapData *ldap.RawLdapData) bool { - if user.Firstname != ldapData.Attributes[s.config.LDAP.FirstNameAttribute] { - return true - } - if user.Lastname != ldapData.Attributes[s.config.LDAP.LastNameAttribute] { - return true - } - if user.Email != strings.ToLower(ldapData.Attributes[s.config.LDAP.EmailAttribute]) { - return true - } - if user.Phone != ldapData.Attributes[s.config.LDAP.PhoneAttribute] { - return true - } - if user.Source != users.UserSourceLdap { - return true - } - - if user.DeletedAt.Valid { - return true - } - - ldapAdmin := false - for _, group := range ldapData.RawAttributes[s.config.LDAP.GroupMemberAttribute] { - if string(group) == s.config.LDAP.AdminLdapGroup { - ldapAdmin = true - break - } - } - if user.IsAdmin != ldapAdmin { - return true - } - - return false -} - -func (s *Server) disableMissingLdapUsers(ldapUsers []ldap.RawLdapData) { - // Disable missing LDAP users - activeUsers := s.users.GetUsers() - for i := range activeUsers { - if activeUsers[i].Source != users.UserSourceLdap { - continue - } - - existsInLDAP := false - for j := range ldapUsers { - if activeUsers[i].Email == strings.ToLower(ldapUsers[j].Attributes[s.config.LDAP.EmailAttribute]) { - existsInLDAP = true - break - } - } - - if existsInLDAP { - continue - } - - // disable all peers for the given user - for _, peer := range s.peers.GetPeersByMail(activeUsers[i].Email) { - now := time.Now() - peer.DeactivatedAt = &now - if err := s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update deactivated peer %s: %v", peer.PublicKey, err) - } - } - - if err := s.users.DeleteUser(&activeUsers[i]); err != nil { - logrus.Errorf("failed to delete deactivated user %s in database: %v", activeUsers[i].Email, err) - } - } -} - -func (s *Server) updateLdapUsers(ldapUsers []ldap.RawLdapData) { - for i := range ldapUsers { - if ldapUsers[i].Attributes[s.config.LDAP.EmailAttribute] == "" { - logrus.Tracef("skipping sync of %s, empty email attribute", ldapUsers[i].DN) - continue - } - - user, err := s.users.GetOrCreateUserUnscoped(ldapUsers[i].Attributes[s.config.LDAP.EmailAttribute]) - if err != nil { - logrus.Errorf("failed to get/create user %s in database: %v", ldapUsers[i].Attributes[s.config.LDAP.EmailAttribute], err) - } - - // re-enable LDAP user if the user was disabled - if user.DeletedAt.Valid { - // enable all peers for the given user - for _, peer := range s.peers.GetPeersByMail(user.Email) { - now := time.Now() - peer.DeactivatedAt = nil - if err = s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update activated peer %s: %v", peer.PublicKey, err) - } - } - } - - // Sync attributes from ldap - if s.userChangedInLdap(user, &ldapUsers[i]) { - logrus.Debugf("updating ldap user %s", user.Email) - user.Firstname = ldapUsers[i].Attributes[s.config.LDAP.FirstNameAttribute] - user.Lastname = ldapUsers[i].Attributes[s.config.LDAP.LastNameAttribute] - user.Email = ldapUsers[i].Attributes[s.config.LDAP.EmailAttribute] - user.Phone = ldapUsers[i].Attributes[s.config.LDAP.PhoneAttribute] - user.IsAdmin = false - user.Source = users.UserSourceLdap - user.DeletedAt = gorm.DeletedAt{} // Not deleted - - for _, group := range ldapUsers[i].RawAttributes[s.config.LDAP.GroupMemberAttribute] { - if string(group) == s.config.LDAP.AdminLdapGroup { - user.IsAdmin = true - break - } - } - - if err = s.users.UpdateUser(user); err != nil { - logrus.Errorf("failed to update ldap user %s in database: %v", user.Email, err) - continue - } - } - } -} diff --git a/internal/server/routes.go b/internal/server/routes.go deleted file mode 100644 index 5d10776..0000000 --- a/internal/server/routes.go +++ /dev/null @@ -1,207 +0,0 @@ -package server - -import ( - "net/http" - "strings" - - "github.com/gin-gonic/gin" - wgportal "github.com/h44z/wg-portal" - _ "github.com/h44z/wg-portal/internal/server/docs" // docs is generated by Swag CLI, you have to import it. - ginSwagger "github.com/swaggo/gin-swagger" - "github.com/swaggo/gin-swagger/swaggerFiles" - csrf "github.com/utrack/gin-csrf" -) - -func SetupRoutes(s *Server) { - csrfMiddleware := csrf.Middleware(csrf.Options{ - Secret: s.config.Core.SessionSecret, - ErrorFunc: func(c *gin.Context) { - c.String(400, "CSRF token mismatch") - c.Abort() - }, - }) - - // Startpage - s.server.GET("/", s.GetIndex) - s.server.GET("/favicon.ico", func(c *gin.Context) { - file, _ := wgportal.Statics.ReadFile("assets/img/favicon.ico") - c.Data( - http.StatusOK, - "image/x-icon", - file, - ) - }) - - // Auth routes - auth := s.server.Group("/auth") - auth.Use(csrfMiddleware) - auth.GET("/login", s.GetLogin) - auth.POST("/login", s.PostLogin) - auth.GET("/logout", s.GetLogout) - - // Admin routes - admin := s.server.Group("/admin") - admin.Use(csrfMiddleware) - admin.Use(s.RequireAuthentication("admin")) - admin.GET("/", s.GetAdminIndex) - admin.GET("/device/edit", s.GetAdminEditInterface) - admin.POST("/device/edit", s.PostAdminEditInterface) - admin.GET("/device/download", s.GetInterfaceConfig) - admin.GET("/device/write", s.GetSaveConfig) - admin.GET("/device/applyglobals", s.GetApplyGlobalConfig) - admin.GET("/peer/edit", s.GetAdminEditPeer) - admin.POST("/peer/edit", s.PostAdminEditPeer) - admin.GET("/peer/create", s.GetAdminCreatePeer) - admin.POST("/peer/create", s.PostAdminCreatePeer) - admin.GET("/peer/createldap", s.GetAdminCreateLdapPeers) - admin.POST("/peer/createldap", s.PostAdminCreateLdapPeers) - admin.GET("/peer/delete", s.GetAdminDeletePeer) - admin.GET("/peer/download", s.GetPeerConfig) - admin.GET("/peer/email", s.GetPeerConfigMail) - admin.GET("/peer/emailall", s.GetAdminSendEmails) - - admin.GET("/users/", s.GetAdminUsersIndex) - admin.GET("/users/create", s.GetAdminUsersCreate) - admin.POST("/users/create", s.PostAdminUsersCreate) - admin.GET("/users/edit", s.GetAdminUsersEdit) - admin.POST("/users/edit", s.PostAdminUsersEdit) - - // User routes - user := s.server.Group("/user") - user.Use(csrfMiddleware) - user.Use(s.RequireAuthentication("")) // empty scope = all logged in users - user.GET("/qrcode", s.GetPeerQRCode) - user.GET("/profile", s.GetUserIndex) - user.GET("/download", s.GetPeerConfig) - user.GET("/email", s.GetPeerConfigMail) - user.GET("/status", s.GetPeerStatus) -} - -func SetupApiRoutes(s *Server) { - api := ApiServer{s: s} - - // Admin authenticated routes - apiV1Backend := s.server.Group("/api/v1/backend") - apiV1Backend.Use(s.RequireApiAuthentication("admin")) - - apiV1Backend.GET("/users", api.GetUsers) - apiV1Backend.POST("/users", api.PostUser) - apiV1Backend.GET("/user", api.GetUser) - apiV1Backend.PUT("/user", api.PutUser) - apiV1Backend.PATCH("/user", api.PatchUser) - apiV1Backend.DELETE("/user", api.DeleteUser) - - apiV1Backend.GET("/peers", api.GetPeers) - apiV1Backend.POST("/peers", api.PostPeer) - apiV1Backend.GET("/peer", api.GetPeer) - apiV1Backend.PUT("/peer", api.PutPeer) - apiV1Backend.PATCH("/peer", api.PatchPeer) - apiV1Backend.DELETE("/peer", api.DeletePeer) - - apiV1Backend.GET("/devices", api.GetDevices) - apiV1Backend.GET("/device", api.GetDevice) - apiV1Backend.PUT("/device", api.PutDevice) - apiV1Backend.PATCH("/device", api.PatchDevice) - - // Simple authenticated routes - apiV1Deployment := s.server.Group("/api/v1/provisioning") - apiV1Deployment.Use(s.RequireApiAuthentication("")) - - apiV1Deployment.GET("/peers", api.GetPeerDeploymentInformation) - apiV1Deployment.GET("/peer", api.GetPeerDeploymentConfig) - apiV1Deployment.POST("/peers", api.PostPeerDeploymentConfig) - - // Swagger doc/ui - s.server.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) -} - -func (s *Server) RequireAuthentication(scope string) gin.HandlerFunc { - return func(c *gin.Context) { - session := GetSessionData(c) - - if !session.LoggedIn { - // Abort the request with the appropriate error code - c.Abort() - c.Redirect(http.StatusSeeOther, "/auth/login?err=loginreq") - return - } - - if scope == "admin" && !session.IsAdmin { - // Abort the request with the appropriate error code - c.Abort() - s.GetHandleError(c, http.StatusUnauthorized, "unauthorized", "not enough permissions") - return - } - - // default case if some random scope was set... - if scope != "" && !session.IsAdmin { - // Abort the request with the appropriate error code - c.Abort() - s.GetHandleError(c, http.StatusUnauthorized, "unauthorized", "not enough permissions") - return - } - - // Check if logged-in user is still valid - if !s.isUserStillValid(session.Email) { - _ = DestroySessionData(c) - c.Abort() - s.GetHandleError(c, http.StatusUnauthorized, "unauthorized", "session no longer available") - return - } - - // Continue down the chain to handler etc - c.Next() - } -} - -func (s *Server) RequireApiAuthentication(scope string) gin.HandlerFunc { - return func(c *gin.Context) { - username, password, hasAuth := c.Request.BasicAuth() - if !hasAuth { - c.Abort() - c.JSON(http.StatusUnauthorized, ApiError{Message: "unauthorized"}) - return - } - - // Validate form input - if strings.Trim(username, " ") == "" || strings.Trim(password, " ") == "" { - c.Abort() - c.JSON(http.StatusUnauthorized, ApiError{Message: "unauthorized"}) - return - } - - // Check all available auth backends - user, err := s.checkAuthentication(username, password) - if err != nil { - c.Abort() - c.JSON(http.StatusInternalServerError, ApiError{Message: "login error"}) - return - } - - // Check if user is authenticated - if user == nil { - c.Abort() - c.JSON(http.StatusUnauthorized, ApiError{Message: "unauthorized"}) - return - } - - // Check admin scope - if scope == "admin" && !user.IsAdmin { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusForbidden, ApiError{Message: "unauthorized"}) - return - } - - // default case if some random scope was set... - if scope != "" && !user.IsAdmin { - // Abort the request with the appropriate error code - c.Abort() - c.JSON(http.StatusForbidden, ApiError{Message: "unauthorized"}) - return - } - - // Continue down the chain to handler etc - c.Next() - } -} diff --git a/internal/server/server.go b/internal/server/server.go deleted file mode 100644 index 635819e..0000000 --- a/internal/server/server.go +++ /dev/null @@ -1,344 +0,0 @@ -package server - -import ( - "context" - "encoding/gob" - "html/template" - "io/fs" - "io/ioutil" - "math/rand" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" - "time" - - "github.com/gin-contrib/sessions" - "github.com/gin-contrib/sessions/memstore" - "github.com/gin-gonic/gin" - wgportal "github.com/h44z/wg-portal" - ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap" - passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password" - "github.com/h44z/wg-portal/internal/common" - "github.com/h44z/wg-portal/internal/users" - "github.com/h44z/wg-portal/internal/wireguard" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - ginlogrus "github.com/toorop/gin-logrus" - "gorm.io/gorm" -) - -const SessionIdentifier = "wgPortalSession" - -func init() { - gob.Register(SessionData{}) - gob.Register(FlashData{}) - gob.Register(wireguard.Peer{}) - gob.Register(wireguard.Device{}) - gob.Register(LdapCreateForm{}) - gob.Register(users.User{}) -} - -type SessionData struct { - LoggedIn bool - IsAdmin bool - Firstname string - Lastname string - Email string - DeviceName string - - SortedBy map[string]string - SortDirection map[string]string - Search map[string]string - - AlertData string - AlertType string - FormData interface{} -} - -type FlashData struct { - HasAlert bool - Message string - Type string -} - -type StaticData struct { - WebsiteTitle string - WebsiteLogo string - CompanyName string - Year int - Version string -} - -type Server struct { - ctx context.Context - config *Config - server *gin.Engine - mailTpl *template.Template - auth *AuthManager - - db *gorm.DB - users *users.Manager - wg *wireguard.Manager - peers *wireguard.PeerManager -} - -func (s *Server) Setup(ctx context.Context) error { - var err error - - dir := s.getExecutableDirectory() - rDir, _ := filepath.Abs(filepath.Dir(os.Args[0])) - logrus.Infof("real working directory: %s", rDir) - logrus.Infof("current working directory: %s", dir) - - // Init rand - rand.Seed(time.Now().UnixNano()) - - s.config = NewConfig() - s.ctx = ctx - - // Setup database connection - s.db, err = common.GetDatabaseForConfig(&s.config.Database) - if err != nil { - return errors.WithMessage(err, "database setup failed") - } - err = common.MigrateDatabase(s.db, DatabaseVersion) - if err != nil { - return errors.WithMessage(err, "database migration failed") - } - - // Setup http server - gin.SetMode(gin.DebugMode) - gin.DefaultWriter = ioutil.Discard - s.server = gin.New() - if logrus.GetLevel() == logrus.TraceLevel { - s.server.Use(ginlogrus.Logger(logrus.StandardLogger())) - } - s.server.Use(gin.Recovery()) - s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte(s.config.Core.SessionSecret)))) - s.server.SetFuncMap(template.FuncMap{ - "formatBytes": common.ByteCountSI, - "urlEncode": url.QueryEscape, - "startsWith": strings.HasPrefix, - "userForEmail": func(users []users.User, email string) *users.User { - for i := range users { - if users[i].Email == email { - return &users[i] - } - } - return nil - }, - }) - - // Setup templates - templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wgportal.Templates, "assets/tpl/*.html")) - s.server.SetHTMLTemplate(templates) - - // Serve static files - s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/css")))) - s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/js")))) - s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/img")))) - s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/fonts")))) - - // Setup all routes - SetupRoutes(s) - SetupApiRoutes(s) - - // Setup user database (also needed for database authentication) - s.users, err = users.NewManager(s.db) - if err != nil { - return errors.WithMessage(err, "user-manager initialization failed") - } - - // Setup auth manager - s.auth = NewAuthManager(s) - pwProvider, err := passwordprovider.New(&s.config.Database) - if err != nil { - return errors.WithMessage(err, "password provider initialization failed") - } - if err = pwProvider.InitializeAdmin(s.config.Core.AdminUser, s.config.Core.AdminPassword); err != nil { - return errors.WithMessage(err, "admin initialization failed") - } - s.auth.RegisterProvider(pwProvider) - - if s.config.Core.LdapEnabled { - ldapProvider, err := ldapprovider.New(&s.config.LDAP) - if err != nil { - s.config.Core.LdapEnabled = false - logrus.Warnf("failed to setup LDAP connection, LDAP features disabled") - } - s.auth.RegisterProviderWithoutError(ldapProvider, err) - } - - // Setup WireGuard stuff - s.wg = &wireguard.Manager{Cfg: &s.config.WG} - if err = s.wg.Init(); err != nil { - return errors.WithMessage(err, "unable to initialize WireGuard manager") - } - - // Setup peer manager - if s.peers, err = wireguard.NewPeerManager(s.db, s.wg); err != nil { - return errors.WithMessage(err, "unable to setup peer manager") - } - - for _, deviceName := range s.wg.Cfg.DeviceNames { - if err = s.RestoreWireGuardInterface(deviceName); err != nil { - return errors.WithMessagef(err, "unable to restore WireGuard state for %s", deviceName) - } - } - - // Setup mail template - s.mailTpl, err = template.New("email.html").ParseFS(wgportal.Templates, "assets/tpl/email.html") - if err != nil { - return errors.Wrap(err, "unable to pare mail template") - } - - logrus.Infof("setup of service completed!") - return nil -} - -func (s *Server) Run() { - logrus.Infof("starting web service on %s", s.config.Core.ListeningAddress) - - // Start ldap sync - if s.config.Core.LdapEnabled { - go s.SyncLdapWithUserDatabase() - } - - // Run web service - srv := &http.Server{ - Addr: s.config.Core.ListeningAddress, - Handler: s.server, - } - - go func() { - if err := srv.ListenAndServe(); err != nil { - logrus.Debugf("web service on %s exited: %v", s.config.Core.ListeningAddress, err) - } - }() - - <-s.ctx.Done() - - logrus.Debug("web service shutting down...") - - shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - _ = srv.Shutdown(shutdownCtx) - -} - -func (s *Server) getExecutableDirectory() string { - dir, err := filepath.Abs(filepath.Dir(os.Args[0])) - if err != nil { - logrus.Errorf("failed to get executable directory: %v", err) - } - - if _, err := os.Stat(filepath.Join(dir, "assets")); os.IsNotExist(err) { - return "." // assets directory not found -> we are developing in goland =) - } - - return dir -} - -func (s *Server) getStaticData() StaticData { - return StaticData{ - WebsiteTitle: s.config.Core.Title, - WebsiteLogo: "/img/header-logo.png", - CompanyName: s.config.Core.CompanyName, - Year: time.Now().Year(), - Version: Version, - } -} - -func GetSessionData(c *gin.Context) SessionData { - session := sessions.Default(c) - rawSessionData := session.Get(SessionIdentifier) - - var sessionData SessionData - if rawSessionData != nil { - sessionData = rawSessionData.(SessionData) - } else { - sessionData = SessionData{ - Search: map[string]string{"peers": "", "userpeers": "", "users": ""}, - SortedBy: map[string]string{"peers": "handshake", "userpeers": "id", "users": "email"}, - SortDirection: map[string]string{"peers": "desc", "userpeers": "asc", "users": "asc"}, - Email: "", - Firstname: "", - Lastname: "", - DeviceName: "", - IsAdmin: false, - LoggedIn: false, - } - session.Set(SessionIdentifier, sessionData) - if err := session.Save(); err != nil { - logrus.Errorf("failed to store session: %v", err) - } - } - - return sessionData -} - -func GetFlashes(c *gin.Context) []FlashData { - session := sessions.Default(c) - flashes := session.Flashes() - if err := session.Save(); err != nil { - logrus.Errorf("failed to store session after setting flash: %v", err) - } - - flashData := make([]FlashData, len(flashes)) - for i := range flashes { - flashData[i] = flashes[i].(FlashData) - } - - return flashData -} - -func UpdateSessionData(c *gin.Context, data SessionData) error { - session := sessions.Default(c) - session.Set(SessionIdentifier, data) - if err := session.Save(); err != nil { - logrus.Errorf("failed to store session: %v", err) - return errors.Wrap(err, "failed to store session") - } - return nil -} - -func DestroySessionData(c *gin.Context) error { - session := sessions.Default(c) - session.Delete(SessionIdentifier) - if err := session.Save(); err != nil { - logrus.Errorf("failed to destroy session: %v", err) - return errors.Wrap(err, "failed to destroy session") - } - return nil -} - -func SetFlashMessage(c *gin.Context, message, typ string) { - session := sessions.Default(c) - session.AddFlash(FlashData{ - Message: message, - Type: typ, - }) - if err := session.Save(); err != nil { - logrus.Errorf("failed to store session after setting flash: %v", err) - } -} - -func (s SessionData) GetSortIcon(table, field string) string { - if s.SortedBy[table] != field { - return "fa-sort" - } - if s.SortDirection[table] == "asc" { - return "fa-sort-alpha-down" - } else { - return "fa-sort-alpha-up" - } -} - -func fsMust(f fs.FS, err error) fs.FS { - if err != nil { - panic(err) - } - return f -} diff --git a/internal/server/server_helper.go b/internal/server/server_helper.go deleted file mode 100644 index 3cb90db..0000000 --- a/internal/server/server_helper.go +++ /dev/null @@ -1,358 +0,0 @@ -package server - -import ( - "crypto/md5" - "fmt" - "io/ioutil" - "path" - "syscall" - "time" - - "github.com/h44z/wg-portal/internal/users" - "github.com/h44z/wg-portal/internal/wireguard" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "golang.org/x/crypto/bcrypt" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "gorm.io/gorm" -) - -// PrepareNewPeer initiates a new peer for the given WireGuard device. -func (s *Server) PrepareNewPeer(device string) (wireguard.Peer, error) { - dev := s.peers.GetDevice(device) - deviceIPs := dev.GetIPAddresses() - - peer := wireguard.Peer{} - peer.IsNew = true - - switch dev.Type { - case wireguard.DeviceTypeServer: - peerIPs := make([]string, len(deviceIPs)) - for i := range deviceIPs { - freeIP, err := s.peers.GetAvailableIp(device, deviceIPs[i]) - if err != nil { - return wireguard.Peer{}, errors.WithMessage(err, "failed to get available IP addresses") - } - peerIPs[i] = freeIP - } - peer.SetIPAddresses(peerIPs...) - psk, err := wgtypes.GenerateKey() - if err != nil { - return wireguard.Peer{}, errors.Wrap(err, "failed to generate key") - } - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return wireguard.Peer{}, errors.Wrap(err, "failed to generate private key") - } - peer.PresharedKey = psk.String() - peer.PrivateKey = key.String() - peer.PublicKey = key.PublicKey().String() - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) - peer.Endpoint = dev.DefaultEndpoint - peer.DNSStr = dev.DNSStr - peer.PersistentKeepalive = dev.DefaultPersistentKeepalive - peer.AllowedIPsStr = dev.DefaultAllowedIPsStr - peer.Mtu = dev.Mtu - peer.DeviceName = device - case wireguard.DeviceTypeClient: - peer.UID = "newendpoint" - } - - return peer, nil -} - -// CreatePeerByEmail creates a new peer for the given email. -func (s *Server) CreatePeerByEmail(device, email, identifierSuffix string, disabled bool) error { - user := s.users.GetUser(email) - - peer, err := s.PrepareNewPeer(device) - if err != nil { - return errors.WithMessage(err, "failed to prepare new peer") - } - peer.Email = email - if user != nil { - peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix) - } else { - peer.Identifier = fmt.Sprintf("%s (%s)", email, identifierSuffix) - } - now := time.Now() - if disabled { - peer.DeactivatedAt = &now - } - - return s.CreatePeer(device, peer) -} - -// CreatePeer creates the new peer in the database. If the peer has no assigned ip addresses, a new one will be assigned -// automatically. Also, if the private key is empty, a new key-pair will be generated. -// This function also configures the new peer on the physical WireGuard interface if the peer is not deactivated. -func (s *Server) CreatePeer(device string, peer wireguard.Peer) error { - dev := s.peers.GetDevice(device) - deviceIPs := dev.GetIPAddresses() - peerIPs := peer.GetIPAddresses() - - peer.AllowedIPsStr = dev.DefaultAllowedIPsStr - if len(peerIPs) == 0 && dev.Type == wireguard.DeviceTypeServer { - peerIPs = make([]string, len(deviceIPs)) - for i := range deviceIPs { - freeIP, err := s.peers.GetAvailableIp(device, deviceIPs[i]) - if err != nil { - return errors.WithMessage(err, "failed to get available IP addresses") - } - peerIPs[i] = freeIP - } - peer.SetIPAddresses(peerIPs...) - } - if peer.PrivateKey == "" && dev.Type == wireguard.DeviceTypeServer { // if private key is empty create a new one - psk, err := wgtypes.GenerateKey() - if err != nil { - return errors.Wrap(err, "failed to generate key") - } - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - return errors.Wrap(err, "failed to generate private key") - } - peer.PresharedKey = psk.String() - peer.PrivateKey = key.String() - peer.PublicKey = key.PublicKey().String() - } - peer.DeviceName = dev.DeviceName - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) - - // Create WireGuard interface - if peer.DeactivatedAt == nil { - if err := s.wg.AddPeer(device, peer.GetConfig(&dev)); err != nil { - return errors.WithMessage(err, "failed to add WireGuard peer") - } - } - - // Create in database - if err := s.peers.CreatePeer(peer); err != nil { - return errors.WithMessage(err, "failed to create peer") - } - - return s.WriteWireGuardConfigFile(device) -} - -// UpdatePeer updates the physical WireGuard interface and the database. -func (s *Server) UpdatePeer(peer wireguard.Peer, updateTime time.Time) error { - currentPeer := s.peers.GetPeerByKey(peer.PublicKey) - dev := s.peers.GetDevice(peer.DeviceName) - - // Update WireGuard device - var err error - switch { - case peer.DeactivatedAt != nil && *peer.DeactivatedAt == updateTime: - err = s.wg.RemovePeer(peer.DeviceName, peer.PublicKey) - case peer.DeactivatedAt == nil && currentPeer.Peer != nil: - err = s.wg.UpdatePeer(peer.DeviceName, peer.GetConfig(&dev)) - case peer.DeactivatedAt == nil && currentPeer.Peer == nil: - err = s.wg.AddPeer(peer.DeviceName, peer.GetConfig(&dev)) - } - if err != nil { - return errors.WithMessage(err, "failed to update WireGuard peer") - } - - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) - - // Update in database - if err := s.peers.UpdatePeer(peer); err != nil { - return errors.WithMessage(err, "failed to update peer") - } - - return s.WriteWireGuardConfigFile(peer.DeviceName) -} - -// DeletePeer removes the peer from the physical WireGuard interface and the database. -func (s *Server) DeletePeer(peer wireguard.Peer) error { - // Delete WireGuard peer - if err := s.wg.RemovePeer(peer.DeviceName, peer.PublicKey); err != nil { - return errors.WithMessage(err, "failed to remove WireGuard peer") - } - - // Delete in database - if err := s.peers.DeletePeer(peer); err != nil { - return errors.WithMessage(err, "failed to remove peer") - } - - return s.WriteWireGuardConfigFile(peer.DeviceName) -} - -// RestoreWireGuardInterface restores the state of the physical WireGuard interface from the database. -func (s *Server) RestoreWireGuardInterface(device string) error { - activePeers := s.peers.GetActivePeers(device) - dev := s.peers.GetDevice(device) - - for i := range activePeers { - if activePeers[i].Peer == nil { - if err := s.wg.AddPeer(device, activePeers[i].GetConfig(&dev)); err != nil { - return errors.WithMessage(err, "failed to add WireGuard peer") - } - } - } - - return nil -} - -// WriteWireGuardConfigFile writes the configuration file for the physical WireGuard interface. -func (s *Server) WriteWireGuardConfigFile(device string) error { - if s.config.WG.ConfigDirectoryPath == "" { - return nil // writing disabled - } - if err := syscall.Access(s.config.WG.ConfigDirectoryPath, syscall.O_RDWR); err != nil { - return errors.Wrap(err, "failed to check WireGuard config access rights") - } - - dev := s.peers.GetDevice(device) - cfg, err := dev.GetConfigFile(s.peers.GetActivePeers(device)) - if err != nil { - return errors.WithMessage(err, "failed to get config file") - } - filePath := path.Join(s.config.WG.ConfigDirectoryPath, dev.DeviceName+".conf") - if err := ioutil.WriteFile(filePath, cfg, 0644); err != nil { - return errors.Wrap(err, "failed to write WireGuard config file") - } - return nil -} - -// CreateUser creates the user in the database and optionally adds a default WireGuard peer for the user. -func (s *Server) CreateUser(user users.User, device string) error { - if user.Email == "" { - return errors.New("cannot create user with empty email address") - } - - // Check if user already exists, if so re-enable - if existingUser := s.users.GetUserUnscoped(user.Email); existingUser != nil { - user.DeletedAt = gorm.DeletedAt{} // reset deleted flag to enable that user again - return s.UpdateUser(user) - } - - // Hash user password (if set) - if user.Password != "" { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "unable to hash password") - } - user.Password = users.PrivateString(hashedPassword) - } - - // Create user in database - if err := s.users.CreateUser(&user); err != nil { - return errors.WithMessage(err, "failed to create user in manager") - } - - // Check if user already has a peer setup, if not, create one - return s.CreateUserDefaultPeer(user.Email, device) -} - -// UpdateUser updates the user in the database. If the user is marked as deleted, it will get remove from the database. -// Also, if the user is re-enabled, all it's linked WireGuard peers will be activated again. -func (s *Server) UpdateUser(user users.User) error { - if user.DeletedAt.Valid { - return s.DeleteUser(user) - } - - currentUser := s.users.GetUserUnscoped(user.Email) - - // Hash user password (if set) - if user.Password != "" { - hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) - if err != nil { - return errors.Wrap(err, "unable to hash password") - } - user.Password = users.PrivateString(hashedPassword) - } else { - user.Password = currentUser.Password // keep current password - } - - // Update in database - if err := s.users.UpdateUser(&user); err != nil { - return errors.WithMessage(err, "failed to update user in manager") - } - - // If user was deleted (disabled), reactivate it's peers - if currentUser.DeletedAt.Valid { - for _, peer := range s.peers.GetPeersByMail(user.Email) { - now := time.Now() - peer.DeactivatedAt = nil - if err := s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update (re)activated peer %s for %s: %v", peer.PublicKey, user.Email, err) - } - } - } - - return nil -} - -// DeleteUser removes the user from the database. -// Also, if the user has linked WireGuard peers, they will be deactivated. -func (s *Server) DeleteUser(user users.User) error { - currentUser := s.users.GetUserUnscoped(user.Email) - - // Update in database - if err := s.users.DeleteUser(&user); err != nil { - return errors.WithMessage(err, "failed to delete user in manager") - } - - // If user was active, disable it's peers - if !currentUser.DeletedAt.Valid { - for _, peer := range s.peers.GetPeersByMail(user.Email) { - now := time.Now() - peer.DeactivatedAt = &now - if err := s.UpdatePeer(peer, now); err != nil { - logrus.Errorf("failed to update deactivated peer %s for %s: %v", peer.PublicKey, user.Email, err) - } - } - } - - return nil -} - -func (s *Server) CreateUserDefaultPeer(email, device string) error { - // Check if automatic peer creation is enabled - if !s.config.Core.CreateDefaultPeer { - return nil - } - - // Check if user is active, if not, quit - var existingUser *users.User - if existingUser = s.users.GetUser(email); existingUser == nil { - return nil - } - - // Check if user already has a peer setup, if not, create one - peers := s.peers.GetPeersByMail(email) - if len(peers) != 0 { - return nil - } - - // Create default vpn peer - peer, err := s.PrepareNewPeer(device) - if err != nil { - return errors.WithMessage(err, "failed to prepare new peer") - } - peer.Email = email - if existingUser.Firstname != "" && existingUser.Lastname != "" { - peer.Identifier = fmt.Sprintf("%s %s (%s)", existingUser.Firstname, existingUser.Lastname, "Default") - } else { - peer.Identifier = fmt.Sprintf("%s (%s)", existingUser.Email, "Default") - } - peer.CreatedBy = existingUser.Email - peer.UpdatedBy = existingUser.Email - if err := s.CreatePeer(device, peer); err != nil { - return errors.WithMessagef(err, "failed to automatically create vpn peer for %s", email) - } - - return nil -} - -func (s *Server) GetDeviceNames() map[string]string { - devNames := make(map[string]string, len(s.wg.Cfg.DeviceNames)) - - for _, devName := range s.wg.Cfg.DeviceNames { - dev := s.peers.GetDevice(devName) - devNames[devName] = dev.DisplayName - } - - return devNames -} diff --git a/internal/server/version.go b/internal/server/version.go deleted file mode 100644 index a1b0048..0000000 --- a/internal/server/version.go +++ /dev/null @@ -1,4 +0,0 @@ -package server - -var Version = "testbuild" -var DatabaseVersion = "1.0.8" diff --git a/internal/users/manager.go b/internal/users/manager.go deleted file mode 100644 index 53fb5d4..0000000 --- a/internal/users/manager.go +++ /dev/null @@ -1,224 +0,0 @@ -package users - -import ( - "sort" - "strconv" - "strings" - "time" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "gorm.io/gorm" -) - -type Manager struct { - db *gorm.DB -} - -func NewManager(db *gorm.DB) (*Manager, error) { - m := &Manager{db: db} - - // check if old user table exists (from version <= 1.0.2), if so rename it to peers. - if m.db.Migrator().HasTable("users") && !m.db.Migrator().HasTable("peers") { - if err := m.db.Migrator().RenameTable("users", "peers"); err != nil { - return nil, errors.Wrapf(err, "failed to migrate old database structure") - } else { - logrus.Infof("upgraded database format from version v1.0.2") - } - } - - if err := m.db.AutoMigrate(&User{}); err != nil { - return nil, errors.Wrap(err, "failed to migrate user database") - } - - return m, nil -} - -func (m Manager) GetUsers() []User { - users := make([]User, 0) - m.db.Find(&users) - return users -} - -func (m Manager) GetUsersUnscoped() []User { - users := make([]User, 0) - m.db.Unscoped().Find(&users) - return users -} - -func (m Manager) UserExists(email string) bool { - return m.GetUser(email) != nil -} - -func (m Manager) GetUser(email string) *User { - email = strings.ToLower(email) - - user := User{} - m.db.Where("email = ?", email).First(&user) - - if user.Email != email { - return nil - } - - return &user -} - -func (m Manager) GetUserUnscoped(email string) *User { - email = strings.ToLower(email) - - user := User{} - m.db.Unscoped().Where("email = ?", email).First(&user) - - if user.Email != email { - return nil - } - - return &user -} - -func (m Manager) GetFilteredAndSortedUsers(sortKey, sortDirection, search string) []User { - users := make([]User, 0) - m.db.Find(&users) - - filteredUsers := filterUsers(users, search) - sortUsers(filteredUsers, sortKey, sortDirection) - - return filteredUsers -} - -func (m Manager) GetFilteredAndSortedUsersUnscoped(sortKey, sortDirection, search string) []User { - users := make([]User, 0) - m.db.Unscoped().Find(&users) - - filteredUsers := filterUsers(users, search) - sortUsers(filteredUsers, sortKey, sortDirection) - - return filteredUsers -} - -func (m Manager) GetOrCreateUser(email string) (*User, error) { - email = strings.ToLower(email) - - user := User{} - m.db.Where("email = ?", email).FirstOrInit(&user) - - if user.Email != email { - user.Email = email - user.CreatedAt = time.Now() - user.UpdatedAt = time.Now() - user.IsAdmin = false - user.Source = UserSourceDatabase - - res := m.db.Create(&user) - if res.Error != nil { - return nil, errors.Wrapf(res.Error, "failed to create user %s", email) - } - } - - return &user, nil -} - -func (m Manager) GetOrCreateUserUnscoped(email string) (*User, error) { - email = strings.ToLower(email) - - user := User{} - m.db.Unscoped().Where("email = ?", email).FirstOrInit(&user) - - if user.Email != email { - user.Email = email - user.CreatedAt = time.Now() - user.UpdatedAt = time.Now() - user.IsAdmin = false - user.Source = UserSourceDatabase - - res := m.db.Create(&user) - if res.Error != nil { - return nil, errors.Wrapf(res.Error, "failed to create user %s", email) - } - } - - return &user, nil -} - -func (m Manager) CreateUser(user *User) error { - user.Email = strings.ToLower(user.Email) - user.Source = UserSourceDatabase - res := m.db.Create(user) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to create user %s", user.Email) - } - - return nil -} - -func (m Manager) UpdateUser(user *User) error { - user.Email = strings.ToLower(user.Email) - res := m.db.Save(user) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to update user %s", user.Email) - } - - return nil -} - -func (m Manager) DeleteUser(user *User) error { - user.Email = strings.ToLower(user.Email) - res := m.db.Delete(user) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to update user %s", user.Email) - } - - return nil -} - -func sortUsers(users []User, key, direction string) { - sort.Slice(users, func(i, j int) bool { - var sortValueLeft string - var sortValueRight string - - switch key { - case "email": - sortValueLeft = users[i].Email - sortValueRight = users[j].Email - case "firstname": - sortValueLeft = users[i].Firstname - sortValueRight = users[j].Firstname - case "lastname": - sortValueLeft = users[i].Lastname - sortValueRight = users[j].Lastname - case "phone": - sortValueLeft = users[i].Phone - sortValueRight = users[j].Phone - case "source": - sortValueLeft = string(users[i].Source) - sortValueRight = string(users[j].Source) - case "admin": - sortValueLeft = strconv.FormatBool(users[i].IsAdmin) - sortValueRight = strconv.FormatBool(users[j].IsAdmin) - } - - if direction == "asc" { - return sortValueLeft < sortValueRight - } else { - return sortValueLeft > sortValueRight - } - }) -} - -func filterUsers(users []User, search string) []User { - if search == "" { - return users - } - - filteredUsers := make([]User, 0, len(users)) - for i := range users { - if strings.Contains(users[i].Email, strings.ToLower(search)) || - strings.Contains(users[i].Firstname, search) || - strings.Contains(users[i].Lastname, search) || - strings.Contains(string(users[i].Source), search) || - strings.Contains(users[i].Phone, search) { - filteredUsers = append(filteredUsers, users[i]) - } - } - return filteredUsers -} diff --git a/internal/users/user.go b/internal/users/user.go deleted file mode 100644 index f01fc87..0000000 --- a/internal/users/user.go +++ /dev/null @@ -1,46 +0,0 @@ -package users - -import ( - "time" - - "gorm.io/gorm" -) - -type UserSource string - -const ( - UserSourceLdap UserSource = "ldap" // LDAP / ActiveDirectory - UserSourceDatabase UserSource = "db" // sqlite / mysql database - UserSourceOIDC UserSource = "oidc" // open id connect, TODO: implement -) - -type PrivateString string - -func (PrivateString) MarshalJSON() ([]byte, error) { - return []byte(`""`), nil -} - -func (PrivateString) String() string { - return "" -} - -// User is the user model that gets linked to peer entries, by default an empty usermodel with only the email address is created -type User struct { - // required fields - Email string `gorm:"primaryKey" form:"email" binding:"required,email"` - Source UserSource - IsAdmin bool - - // optional fields - Firstname string `form:"firstname" binding:"required"` - Lastname string `form:"lastname" binding:"required"` - Phone string `form:"phone" binding:"omitempty"` - - // optional, integrated password authentication - Password PrivateString `form:"password" binding:"omitempty"` - - // database internal fields - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index" json:",omitempty"` -} diff --git a/internal/wireguard/config.go b/internal/wireguard/config.go deleted file mode 100644 index 2937013..0000000 --- a/internal/wireguard/config.go +++ /dev/null @@ -1,17 +0,0 @@ -package wireguard - -import "github.com/h44z/wg-portal/internal/common" - -type Config struct { - DeviceNames []string `yaml:"devices" envconfig:"WG_DEVICES"` // managed devices - DefaultDeviceName string `yaml:"defaultDevice" envconfig:"WG_DEFAULT_DEVICE"` // this device is used for auto-created peers, use GetDefaultDeviceName() to access this field - ConfigDirectoryPath string `yaml:"configDirectory" envconfig:"WG_CONFIG_PATH"` // optional, if set, updates will be written to this path, filename: .conf - ManageIPAddresses bool `yaml:"manageIPAddresses" envconfig:"MANAGE_IPS"` // handle ip-address setup of interface -} - -func (c Config) GetDefaultDeviceName() string { - if c.DefaultDeviceName == "" || !common.ListContains(c.DeviceNames, c.DefaultDeviceName) { - return c.DeviceNames[0] - } - return c.DefaultDeviceName -} diff --git a/internal/wireguard/configuration.go b/internal/wireguard/configuration.go new file mode 100644 index 0000000..a05b33a --- /dev/null +++ b/internal/wireguard/configuration.go @@ -0,0 +1,153 @@ +package wireguard + +import ( + "database/sql" + "time" +) + +// ConfigOption is an Overridable configuration option +type ConfigOption struct { + Value interface{} + Overridable bool +} + +type StringConfigOption struct { + ConfigOption +} + +func (o StringConfigOption) GetValue() string { + return o.Value.(string) +} + +type IntConfigOption struct { + ConfigOption +} + +func (o IntConfigOption) GetValue() int { + return o.Value.(int) +} + +type Int32ConfigOption struct { + ConfigOption +} + +func (o Int32ConfigOption) GetValue() int32 { + return o.Value.(int32) +} + +type BoolConfigOption struct { + ConfigOption +} + +func (o BoolConfigOption) GetValue() bool { + return o.Value.(bool) +} + +type InterfaceType string + +const ( + InterfaceTypeServer InterfaceType = "server" + InterfaceTypeClient InterfaceType = "client" +) + +type DeviceIdentifier string +type PeerIdentifier string + +type InterfaceConfig struct { + // WireGuard specific (for the [interface] section of the config file) + + DeviceName DeviceIdentifier // device name, for example: wg0 + KeyPair KeyPair // private/public Key of the server interface + ListenPort int // the listening port, for example: 51820 + + AddressStr string // the interface ip addresses, comma separated + Dns string // the dns server that should be set if the interface is up + + Mtu int // the device MTU + FirewallMark int32 // a firewall mark + RoutingTable string // the routing table + + PreUp string // action that is executed before the device is up + PostUp string // action that is executed after the device is up + PreDown string // action that is executed before the device is down + PostDown string // action that is executed after the device is down + + SaveConfig bool // automatically persist config changes to the wgX.conf file + + // WG Portal specific + Enabled bool // flag that specifies if the interface is enabled (up) or nor (down) + DisplayName string // a nice display name/ description for the interface + Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient + DriverType string // the interface driver type (linux, software, ...) + + // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of + // the peer config + + PeerDefNetworkStr string // the default subnets from which peers will get their IP addresses, comma seperated + PeerDefDns string // the default dns server for the peer + PeerDefEndpoint string // the default endpoint for the peer + PeerDefAllowedIPsString string // the default allowed IP string for the peer + PeerDefMtu int // the default device MTU + PeerDefPersistentKeepalive int // the default persistent keep-alive Value + PeerDefFirewallMark int32 // default firewall mark + PeerDefRoutingTable string // the default routing table + + PeerDefPreUp string // default action that is executed before the device is up + PeerDefPostUp string // default action that is executed after the device is up + PeerDefPreDown string // default action that is executed before the device is down + PeerDefPostDown string // default action that is executed after the device is down +} + +type PeerConfig struct { + // WireGuard specific (for the [peer] section of the config file) + + Endpoint StringConfigOption // the endpoint address + AllowedIPsString StringConfigOption // all allowed ip subnets, comma seperated + ExtraAllowedIPsString string // all allowed ip subnets on the server side, comma seperated + KeyPair KeyPair // private/public Key of the peer + PresharedKey string // the pre-shared Key of the peer + PersistentKeepalive IntConfigOption // the persistent keep-alive interval + + // WG Portal specific + + Identifier string // a nice display name/ description for the peer + Uid PeerIdentifier // peer unique identifier + + // Interface settings for the peer, used to generate the [interface] section in the peer config file + + AddressStr StringConfigOption // the interface ip addresses, comma separated + Dns StringConfigOption // the dns server that should be set if the interface is up + Mtu IntConfigOption // the device MTU + FirewallMark Int32ConfigOption // a firewall mark + RoutingTable StringConfigOption // the routing table + + PreUp StringConfigOption // action that is executed before the device is up + PostUp StringConfigOption // action that is executed after the device is up + PreDown StringConfigOption // action that is executed before the device is down + PostDown StringConfigOption // action that is executed after the device is down + + // Internal stats + + DeactivatedAt sql.NullTime + CreatedBy string + UpdatedBy string + CreatedAt time.Time + UpdatedAt time.Time +} + +type InterfaceConfigPersister interface { + PersistInterface(cfg InterfaceConfig) + LoadInterface(cfg InterfaceConfig) + DeleteInterface(cfg InterfaceConfig) +} + +type PeerConfigPersister interface { + PersistPeer(cfg PeerConfig) + LoadPeer(cfg PeerConfig) + DeletePeer(cfg PeerConfig) +} + +type ConfigPersister interface { + InterfaceConfigPersister + PeerConfigPersister +} diff --git a/internal/wireguard/keys.go b/internal/wireguard/keys.go new file mode 100644 index 0000000..caa8bb8 --- /dev/null +++ b/internal/wireguard/keys.go @@ -0,0 +1,24 @@ +package wireguard + +import "encoding/base64" + +type KeyPair struct { + PrivateKey string + PublicKey string +} + +type PreSharedKey string + +func (p KeyPair) GetPrivateKeyBytes() []byte { + data, _ := base64.StdEncoding.DecodeString(p.PrivateKey) + return data +} + +func (p KeyPair) GetPublicKeyBytes() []byte { + data, _ := base64.StdEncoding.DecodeString(p.PublicKey) + return data +} + +func KeyBytesToString(key []byte) string { + return base64.StdEncoding.EncodeToString(key) +} diff --git a/internal/wireguard/keys_test.go b/internal/wireguard/keys_test.go new file mode 100644 index 0000000..3275781 --- /dev/null +++ b/internal/wireguard/keys_test.go @@ -0,0 +1,31 @@ +package wireguard + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKeyPair_GetPrivateKeyBytes(t *testing.T) { + kp := KeyPair{ + PrivateKey: "aGVsbG8=", + PublicKey: "d29ybGQ=", + } + + got := kp.GetPrivateKeyBytes() + assert.Equal(t, []byte("hello"), got) +} + +func TestKeyPair_GetPublicKeyBytes(t *testing.T) { + kp := KeyPair{ + PrivateKey: "aGVsbG8=", + PublicKey: "d29ybGQ=", + } + + got := kp.GetPublicKeyBytes() + assert.Equal(t, []byte("world"), got) +} + +func TestKeyBytesToString(t *testing.T) { + assert.Equal(t, "aGVsbG8=", KeyBytesToString([]byte("hello"))) +} diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index c80c7ac..d387c90 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -1,121 +1,503 @@ package wireguard import ( - "sync" + "bufio" + "fmt" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "time" "github.com/pkg/errors" - "golang.zx2c4.com/wireguard/wgctrl" + "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// Manager offers a synchronized management interface to the real WireGuard interface. -type Manager struct { - Cfg *Config - wg *wgctrl.Client - mux sync.RWMutex +type KeyGenerator interface { + GetFreshKeypair() (KeyPair, error) + GetPreSharedKey() (PreSharedKey, error) } -func (m *Manager) Init() error { - var err error - m.wg, err = wgctrl.New() +// DeviceManager provides methods to create/update/delete physical WireGuard devices. +type DeviceManager interface { + CreateDevice(device DeviceIdentifier) error + DeleteDevice(device DeviceIdentifier) error + UpdateDevice(device DeviceIdentifier, cfg InterfaceConfig) error +} + +type PeerManager interface { + GetPeers(device DeviceIdentifier) ([]PeerConfig, error) + SavePeers(device DeviceIdentifier, peers ...PeerConfig) error + RemovePeer(device DeviceIdentifier, peer PeerIdentifier) error +} + +type Manager interface { + KeyGenerator + DeviceManager + PeerManager +} + +type ManagementUtil struct { + configPath string + + wg Client + nl NetlinkClient + cp ConfigPersister + + // internal holder of interface configurations + interfaces map[DeviceIdentifier]InterfaceConfig + // internal holder of peer configurations + peers map[DeviceIdentifier]map[PeerIdentifier]PeerConfig +} + +func (m ManagementUtil) GetFreshKeypair() (KeyPair, error) { + privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { - return errors.Wrap(err, "could not create WireGuard client") + return KeyPair{}, errors.Wrap(err, "failed to generate private Key") } + return KeyPair{ + PrivateKey: privateKey.String(), + PublicKey: privateKey.PublicKey().String(), + }, nil +} + +func (m ManagementUtil) GetPreSharedKey() (PreSharedKey, error) { + preSharedKey, err := wgtypes.GenerateKey() + if err != nil { + return "", errors.Wrap(err, "failed to generate pre-shared Key") + } + + return PreSharedKey(preSharedKey.String()), nil +} + +func (m *ManagementUtil) CreateDevice(identifier DeviceIdentifier) error { + if m.deviceExists(identifier) { + return errors.Errorf("device %s already exists", identifier) + } + link := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(identifier), + }, + LinkType: "wireguard", + } + err := m.nl.LinkAdd(link) + if err != nil { + return errors.Wrapf(err, "failed to create WireGuard interface") + } + + if err := m.nl.LinkSetUp(link); err != nil { + return errors.Wrapf(err, "failed to enable WireGuard interface") + } + + m.interfaces[identifier] = InterfaceConfig{DeviceName: identifier} + return nil } -func (m *Manager) GetDeviceInfo(device string) (*wgtypes.Device, error) { - dev, err := m.wg.Device(device) +func (m *ManagementUtil) DeleteDevice(identifier DeviceIdentifier) error { + if !m.deviceExists(identifier) { + return errors.Errorf("device %s does not exist", identifier) + } + err := m.nl.LinkDel(&netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(identifier), + }, + LinkType: "wireguard", + }) if err != nil { - return nil, errors.Wrap(err, "could not get WireGuard device") + return errors.Wrapf(err, "failed to delete WireGuard interface") } - return dev, nil + delete(m.interfaces, identifier) + + return nil } -func (m *Manager) GetPeerList(device string) ([]wgtypes.Peer, error) { - m.mux.RLock() - defer m.mux.RUnlock() - - dev, err := m.wg.Device(device) - if err != nil { - return nil, errors.Wrap(err, "could not get WireGuard device") +func (m *ManagementUtil) UpdateDevice(identifier DeviceIdentifier, cfg InterfaceConfig) error { + if !m.deviceExists(identifier) { + return errors.Errorf("device %s does not exist", identifier) } + cfg.DeviceName = identifier // ensure that the same device name is set - return dev.Peers, nil -} - -func (m *Manager) GetPeer(device string, pubKey string) (*wgtypes.Peer, error) { - m.mux.RLock() - defer m.mux.RUnlock() - - publicKey, err := wgtypes.ParseKey(pubKey) + // Update net-link attributes + link, err := m.nl.LinkByName(string(identifier)) if err != nil { - return nil, errors.Wrap(err, "invalid public key") + return errors.Wrapf(err, "failed to open WireGuard interface") } - - peers, err := m.GetPeerList(device) - if err != nil { - return nil, errors.Wrap(err, "could not get WireGuard peers") + if err := m.nl.LinkSetMTU(link, cfg.Mtu); err != nil { + return errors.Wrapf(err, "failed to set MTU") } - - for _, peer := range peers { - if peer.PublicKey == publicKey { - return &peer, nil + addresses, err := parseIpAddressString(cfg.AddressStr) + for i := 0; i < len(addresses); i++ { + var err error + if i == 0 { + err = m.nl.AddrReplace(link, addresses[i]) + } else { + err = m.nl.AddrAdd(link, addresses[i]) + } + if err != nil { + return errors.Wrapf(err, "failed to set ip address %v", addresses[i]) } } - return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey) + // Update WireGuard attributes + pKey, _ := wgtypes.NewKey(cfg.KeyPair.GetPrivateKeyBytes()) + var fwMark *int + if cfg.FirewallMark != 0 { + *fwMark = int(cfg.FirewallMark) + } + err = m.wg.ConfigureDevice(string(identifier), wgtypes.Config{ + PrivateKey: &pKey, + ListenPort: &cfg.ListenPort, + FirewallMark: fwMark, + }) + if err != nil { + return errors.Wrapf(err, "failed to update WireGuard settings") + } + + // Update link state + if cfg.Enabled { + if err := m.nl.LinkSetUp(link); err != nil { + return errors.Wrapf(err, "failed to enable WireGuard interface") + } + } else { + if err := m.nl.LinkSetDown(link); err != nil { + return errors.Wrapf(err, "failed to disable WireGuard interface") + } + } + + m.interfaces[identifier] = cfg + + return nil } -func (m *Manager) AddPeer(device string, cfg wgtypes.PeerConfig) error { - m.mux.Lock() - defer m.mux.Unlock() +func (m ManagementUtil) GetPeers(device DeviceIdentifier) ([]PeerConfig, error) { + if !m.deviceExists(device) { + return nil, errors.Errorf("device %s does not exist", device) + } - err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) - if err != nil { - return errors.Wrap(err, "could not configure WireGuard device") + peers := make([]PeerConfig, 0, len(m.peers[device])) + for _, config := range m.peers[device] { + peers = append(peers, config) + } + + return peers, nil +} + +func (m ManagementUtil) SavePeers(device DeviceIdentifier, peers ...PeerConfig) error { + if !m.deviceExists(device) { + return errors.Errorf("device %s does not exist", device) + } + + deviceConfig := m.interfaces[device] + + for _, peer := range peers { + wgPeer, err := getWireGuardPeerConfig(deviceConfig.Type, peer) + if err != nil { + return errors.Wrapf(err, "could not generate WireGuard peer configuration for %s", peer.Uid) + } + + err = m.wg.ConfigureDevice(string(device), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}}) + if err != nil { + return errors.Wrapf(err, "could not save peer %s to WireGuard device %s", peer.Uid, device) + } + + m.peers[device][peer.Uid] = peer } return nil } -func (m *Manager) UpdatePeer(device string, cfg wgtypes.PeerConfig) error { - m.mux.Lock() - defer m.mux.Unlock() - - cfg.UpdateOnly = true - err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}}) - if err != nil { - return errors.Wrap(err, "could not configure WireGuard device") +func (m ManagementUtil) RemovePeer(device DeviceIdentifier, peer PeerIdentifier) error { + if !m.deviceExists(device) { + return errors.Errorf("device %s does not exist", device) + } + if !m.peerExists(peer) { + return errors.Errorf("peer %s does not exist", peer) } - return nil -} + peerConfig := m.peers[device][peer] -func (m *Manager) RemovePeer(device string, pubKey string) error { - m.mux.Lock() - defer m.mux.Unlock() - - publicKey, err := wgtypes.ParseKey(pubKey) + publicKey, err := wgtypes.ParseKey(peerConfig.KeyPair.PublicKey) if err != nil { - return errors.Wrap(err, "invalid public key") + return errors.Wrapf(err, "invalid public key for peer %s", peer) } - peer := wgtypes.PeerConfig{ + wgPeer := wgtypes.PeerConfig{ PublicKey: publicKey, Remove: true, } - err = m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}}) + err = m.wg.ConfigureDevice(string(device), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}}) if err != nil { - return errors.Wrap(err, "could not configure WireGuard device") + return errors.Wrapf(err, "could not remove peer %s from WireGuard device %s", peer, device) } + delete(m.peers[device], peer) + return nil } -func (m *Manager) UpdateDevice(device string, cfg wgtypes.Config) error { - return m.wg.ConfigureDevice(device, cfg) +// +// ---- Helpers +// + +func getWireGuardPeerConfig(deviceType InterfaceType, peer PeerConfig) (wgtypes.PeerConfig, error) { + publicKey, err := wgtypes.ParseKey(peer.KeyPair.PublicKey) + if err != nil { + return wgtypes.PeerConfig{}, errors.Wrapf(err, "invalid public key for peer %s", peer.Uid) + } + + var presharedKey *wgtypes.Key + if tmpPresharedKey, err := wgtypes.ParseKey(peer.PresharedKey); err == nil { + presharedKey = &tmpPresharedKey + } + + var endpoint *net.UDPAddr + if peer.Endpoint.Value != "" && deviceType == InterfaceTypeClient { + addr, err := net.ResolveUDPAddr("udp", peer.Endpoint.Value.(string)) + if err == nil { + endpoint = addr + } + } + + var keepAlive *time.Duration + if peer.PersistentKeepalive.Value != 0 { + keepAliveDuration := time.Duration(peer.PersistentKeepalive.Value.(int)) * time.Second + keepAlive = &keepAliveDuration + } + + allowedIPs := make([]net.IPNet, 0) + var peerAllowedIPs []*netlink.Addr + switch deviceType { + case InterfaceTypeClient: + peerAllowedIPs, err = parseIpAddressString(peer.AllowedIPsString.GetValue()) + if err != nil { + return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse allowed IP's for peer %s", peer.Uid) + } + case InterfaceTypeServer: + peerAllowedIPs, err = parseIpAddressString(peer.AllowedIPsString.GetValue()) + if err != nil { + return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse allowed IP's for peer %s", peer.Uid) + } + peerExtraAllowedIPs, err := parseIpAddressString(peer.ExtraAllowedIPsString) + if err != nil { + return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse extra allowed IP's for peer %s", peer.Uid) + } + + peerAllowedIPs = append(peerAllowedIPs, peerExtraAllowedIPs...) + } + for _, ip := range peerAllowedIPs { + allowedIPs = append(allowedIPs, *ip.IPNet) + } + + wgPeer := wgtypes.PeerConfig{ + PublicKey: publicKey, + Remove: false, + UpdateOnly: true, + PresharedKey: presharedKey, + Endpoint: endpoint, + PersistentKeepaliveInterval: keepAlive, + ReplaceAllowedIPs: true, + AllowedIPs: allowedIPs, + } + + return wgPeer, nil +} + +func (m ManagementUtil) deviceExists(identifier DeviceIdentifier) bool { + if _, ok := m.interfaces[identifier]; ok { + return true + } + return false +} + +func (m ManagementUtil) peerExists(identifier PeerIdentifier) bool { + for _, peers := range m.peers { + if _, ok := peers[identifier]; ok { + return true + } + } + + return false +} + +// TODO: fix/implement +func (m ManagementUtil) loadExistingInterfaces() ([]InterfaceConfig, error) { + devices, err := m.wg.Devices() + if err != nil { + return nil, errors.Wrapf(err, "failed to get WireGuard device list") + } + + interfaces := make([]InterfaceConfig, len(devices)) + for i, device := range devices { + interfaces[i].DeviceName = DeviceIdentifier(device.Name) + interfaces[i].FirewallMark = int32(device.FirewallMark) + interfaces[i].KeyPair = KeyPair{ + PrivateKey: device.PrivateKey.String(), + PublicKey: device.PublicKey.String(), + } + interfaces[i].ListenPort = device.ListenPort + interfaces[i].DriverType = device.Type.String() + + parsedInterface, _, err := m.parseConfigFile(device.Name) + if err != nil { + continue + } + interfaces[i].Dns = parsedInterface.Dns + interfaces[i].DisplayName = parsedInterface.DisplayName + interfaces[i].PostDown = parsedInterface.PostDown + interfaces[i].PreDown = parsedInterface.PreDown + interfaces[i].PostUp = parsedInterface.PostUp + interfaces[i].PreUp = parsedInterface.PreUp + interfaces[i].AddressStr = parsedInterface.AddressStr + interfaces[i].RoutingTable = parsedInterface.RoutingTable + interfaces[i].Mtu = parsedInterface.Mtu + + fmt.Println(interfaces[i]) + } + + return interfaces, nil +} + +// parseConfigFile parses WireGuard configuration files (INI syntax) and some additional comments in the file +// TODO: fix/implement +func (m ManagementUtil) parseConfigFile(interfaceName string) (InterfaceConfig, []PeerConfig, error) { + configFile := filepath.Join(m.configPath, interfaceName+".conf") + + file, err := os.Open(configFile) + if err != nil { + return InterfaceConfig{}, nil, errors.Wrapf(err, "unable to open config file for interface %s", interfaceName) + } + scanner := bufio.NewScanner(file) + + peerSection := false + iface := InterfaceConfig{} + for scanner.Scan() { + line := scanner.Text() + line = strings.TrimSpace(line) + + switch { + case strings.HasPrefix(line, "#"): // A comment line + line = line[1:] + commentParts := strings.SplitN(line, "=", 1) + fmt.Println(commentParts, peerSection) + case strings.HasPrefix(line, "["): // Config section + line = strings.ToLower(line[1 : len(line)-1]) + switch line { + case "peer": + peerSection = true + case "interface": + peerSection = false + default: + return InterfaceConfig{}, nil, errors.Errorf("configuration file contains unsupported section %s", line) + } + default: //Config option + optionParts := strings.SplitN(line, "=", 1) + if len(optionParts) != 2 { + return InterfaceConfig{}, nil, errors.Errorf("configuration file contains invalid line %s", line) + } + option := strings.ToLower(strings.TrimSpace(optionParts[0])) + value := strings.TrimSpace(optionParts[1]) + peerOption := false + switch option { + // Interface + case "privatekey": + key, err := wgtypes.ParseKey(value) + if err != nil { + return InterfaceConfig{}, nil, errors.Wrapf(err, "interface section has no valid private Key") + } + iface.KeyPair = KeyPair{ + PrivateKey: key.String(), + PublicKey: key.PublicKey().String(), + } + case "address": + iface.AddressStr = value + case "listenport": + port, err := strconv.Atoi(value) + if err != nil { + return InterfaceConfig{}, nil, errors.Wrapf(err, "interface section has invalid listen port Value") + } + iface.ListenPort = port + case "postup": + iface.PostUp = value + case "postdown": + iface.PostDown = value + case "preup": + iface.PreUp = value + case "predown": + iface.PreDown = value + case "mtu": + mtu, err := strconv.Atoi(value) + if err != nil { + return InterfaceConfig{}, nil, errors.Wrapf(err, "interface section has invalid MTU Value") + } + iface.Mtu = mtu + case "dns": + iface.Dns = value + case "table": + iface.RoutingTable = value + case "fwmark": + fwMark, err := strconv.Atoi(value) + if err != nil { + return InterfaceConfig{}, nil, errors.Wrapf(err, "interface section has invalid fwmark Value") + } + iface.FirewallMark = int32(fwMark) + case "saveconfig": + saveConfig, err := strconv.ParseBool(value) + if err != nil { + return InterfaceConfig{}, nil, errors.Wrapf(err, "interface section has invalid save-config Value") + } + iface.SaveConfig = saveConfig + // Peer + case "endpoint": + peerOption = true + case "publickey": + peerOption = true + case "allowedips": + peerOption = true + case "persistentkeepalive": + peerOption = true + case "presharedkey": + peerOption = true + } + + if peerSection != peerOption { + return InterfaceConfig{}, nil, errors.Errorf("config section contains invalid option %s", option) + } + + fmt.Println(value) + } + if strings.HasPrefix(line, "#") { + fmt.Println("comment") + } + fmt.Println(line) + } + + if err := scanner.Err(); err != nil { + return InterfaceConfig{}, nil, errors.Wrapf(err, "unable to scan config file for interface %s", interfaceName) + } + + return InterfaceConfig{}, nil, nil +} + +func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) { + rawAddresses := strings.Split(addrStr, ",") + addresses := make([]*netlink.Addr, 0, len(rawAddresses)) + for i := range rawAddresses { + rawAddress := strings.TrimSpace(rawAddresses[i]) + if rawAddress == "" { + continue // skip empty string + } + address, err := netlink.ParseAddr(rawAddress) + if err != nil { + return nil, errors.Wrapf(err, "failed to parse IP address %s", rawAddress) + } + addresses = append(addresses, address) + } + + return addresses, nil } diff --git a/internal/wireguard/manager_int_test.go b/internal/wireguard/manager_int_test.go new file mode 100644 index 0000000..74fe4be --- /dev/null +++ b/internal/wireguard/manager_int_test.go @@ -0,0 +1,99 @@ +//go:build integration +// +build integration + +// In Goland you can use File-Nesting to enhance the project view: _int_test.go; _test.go + +// Run integrations tests as root! + +package wireguard + +import ( + "os/exec" + "testing" + + "golang.zx2c4.com/wireguard/wgctrl" + + "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" +) + +func prepareTest(dev DeviceIdentifier) { + _ = netlink.LinkDel(&netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(dev), + }, + LinkType: "wireguard", + }) +} + +func TestManagementUtil_CreateDevice(t *testing.T) { + devName := DeviceIdentifier("wg666") + prepareTest(devName) + m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig), nl: NetlinkManager{}} + + defer m.DeleteDevice(devName) + err := m.CreateDevice(devName) + assert.NoError(t, err) + + cmd := exec.Command("ip", "addr") + out, err := cmd.CombinedOutput() + assert.NoError(t, err) + assert.Contains(t, string(out), devName) +} + +func TestManagementUtil_DeleteDevice(t *testing.T) { + devName := DeviceIdentifier("wg667") + prepareTest(devName) + m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig), nl: NetlinkManager{}} + + err := m.CreateDevice(devName) + assert.NoError(t, err) + err = m.DeleteDevice(devName) + assert.NoError(t, err) + + cmd := exec.Command("ip", "addr") + out, err := cmd.CombinedOutput() + assert.NoError(t, err) + assert.NotContains(t, string(out), devName) +} + +func TestManagementUtil_deviceExists(t *testing.T) { + m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig)} + assert.False(t, m.deviceExists("test")) + + m = ManagementUtil{interfaces: map[DeviceIdentifier]InterfaceConfig{"test": {}}} + assert.True(t, m.deviceExists("test")) +} + +func TestManagementUtil_UpdateDevice(t *testing.T) { + devName := DeviceIdentifier("wg668") + prepareTest(devName) + wg, err := wgctrl.New() + if !assert.NoError(t, err) { + return + } + m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig), nl: NetlinkManager{}, wg: wg} + + defer m.DeleteDevice(devName) + err = m.CreateDevice(devName) + if !assert.NoError(t, err) { + return + } + + err = m.UpdateDevice(devName, InterfaceConfig{AddressStr: "123.123.123.123/24", Mtu: 1234}) + assert.NoError(t, err) + + cmd := exec.Command("ip", "addr") + out, err := cmd.CombinedOutput() + assert.NoError(t, err) + assert.Contains(t, string(out), "123.123.123.123") + + err = m.UpdateDevice(devName, InterfaceConfig{AddressStr: "123.123.123.123/24,fd9f:6666::10:6:6:1/64", Mtu: 1600}) + assert.NoError(t, err) + + cmd = exec.Command("ip", "addr") + out, err = cmd.CombinedOutput() + assert.NoError(t, err) + assert.Contains(t, string(out), "123.123.123.123") + assert.Contains(t, string(out), "fd9f:6666::10:6:6:1") +} diff --git a/internal/wireguard/manager_net.go b/internal/wireguard/manager_net.go deleted file mode 100644 index cc934a8..0000000 --- a/internal/wireguard/manager_net.go +++ /dev/null @@ -1,121 +0,0 @@ -package wireguard - -import ( - "fmt" - "net" - - "github.com/milosgajdos/tenus" - "github.com/pkg/errors" -) - -const DefaultMTU = 1420 - -func (m *Manager) GetIPAddress(device string) ([]string, error) { - wgInterface, err := tenus.NewLinkFrom(device) - if err != nil { - return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device) - } - - // Get golang net.interface - iface := wgInterface.NetInterface() - if iface == nil { // Not sure if this check is really necessary - return nil, errors.Wrap(err, "could not retrieve WireGuard net.interface") - } - - addrs, err := iface.Addrs() - if err != nil { - return nil, errors.Wrap(err, "could not retrieve WireGuard ip addresses") - } - - ipAddresses := make([]string, 0, len(addrs)) - for _, addr := range addrs { - var ip net.IP - var mask net.IPMask - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - mask = v.Mask - case *net.IPAddr: - ip = v.IP - mask = ip.DefaultMask() - } - if ip == nil || mask == nil { - continue // something is wrong? - } - - maskSize, _ := mask.Size() - cidr := fmt.Sprintf("%s/%d", ip.String(), maskSize) - ipAddresses = append(ipAddresses, cidr) - } - - return ipAddresses, nil -} - -func (m *Manager) SetIPAddress(device string, cidrs []string) error { - wgInterface, err := tenus.NewLinkFrom(device) - if err != nil { - return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device) - } - - // First remove existing IP addresses - existingIPs, err := m.GetIPAddress(device) - if err != nil { - return errors.Wrap(err, "could not retrieve IP addresses") - } - for _, cidr := range existingIPs { - wgIp, wgIpNet, err := net.ParseCIDR(cidr) - if err != nil { - return errors.Wrapf(err, "unable to parse cidr %s", cidr) - } - - if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil { - return errors.Wrapf(err, "failed to unset ip %s", cidr) - } - } - - // Next set new IP addresses - for _, cidr := range cidrs { - wgIp, wgIpNet, err := net.ParseCIDR(cidr) - if err != nil { - return errors.Wrapf(err, "unable to parse cidr %s", cidr) - } - - if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil { - return errors.Wrapf(err, "failed to set ip %s", cidr) - } - } - - return nil -} - -func (m *Manager) GetMTU(device string) (int, error) { - wgInterface, err := tenus.NewLinkFrom(device) - if err != nil { - return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device) - } - - // Get golang net.interface - iface := wgInterface.NetInterface() - if iface == nil { // Not sure if this check is really necessary - return 0, errors.Wrap(err, "could not retrieve WireGuard net.interface") - } - - return iface.MTU, nil -} - -func (m *Manager) SetMTU(device string, mtu int) error { - wgInterface, err := tenus.NewLinkFrom(device) - if err != nil { - return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device) - } - - if mtu == 0 { - mtu = DefaultMTU - } - - if err := wgInterface.SetLinkMTU(mtu); err != nil { - return errors.Wrapf(err, "could not set MTU on interface %s", device) - } - - return nil -} diff --git a/internal/wireguard/manager_test.go b/internal/wireguard/manager_test.go new file mode 100644 index 0000000..1f4e446 --- /dev/null +++ b/internal/wireguard/manager_test.go @@ -0,0 +1,239 @@ +//go:build !integration +// +build !integration + +package wireguard + +import ( + "net" + "reflect" + "testing" + + "github.com/stretchr/testify/mock" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" +) + +type MockWireGuardClient struct { + mock.Mock +} + +func (m *MockWireGuardClient) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockWireGuardClient) Devices() ([]*wgtypes.Device, error) { + args := m.Called() + return args.Get(0).([]*wgtypes.Device), args.Error(1) +} + +func (m *MockWireGuardClient) Device(name string) (*wgtypes.Device, error) { + args := m.Called(name) + return args.Get(0).(*wgtypes.Device), args.Error(1) +} + +func (m *MockWireGuardClient) ConfigureDevice(name string, cfg wgtypes.Config) error { + args := m.Called(name, cfg) + return args.Error(0) +} + +type MockNetlinkClient struct { + mock.Mock +} + +func (m *MockNetlinkClient) LinkAdd(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkDel(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkByName(name string) (netlink.Link, error) { + args := m.Called(name) + return args.Get(0).(netlink.Link), args.Error(1) +} + +func (m *MockNetlinkClient) LinkSetUp(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkSetDown(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkSetMTU(link netlink.Link, mtu int) error { + args := m.Called(link, mtu) + return args.Error(0) +} + +func (m *MockNetlinkClient) AddrReplace(link netlink.Link, addr *netlink.Addr) error { + args := m.Called(link, addr) + return args.Error(0) +} + +func (m *MockNetlinkClient) AddrAdd(link netlink.Link, addr *netlink.Addr) error { + args := m.Called(link, addr) + return args.Error(0) +} + +// +// ---------- Tests +// + +func TestManagementUtil_GetFreshKeypair(t *testing.T) { + m := ManagementUtil{} + kp, err := m.GetFreshKeypair() + assert.NoError(t, err) + assert.NotEmpty(t, kp.PrivateKey) + assert.NotEmpty(t, kp.PublicKey) +} + +func TestManagementUtil_GetPreSharedKey(t *testing.T) { + m := ManagementUtil{} + psk, err := m.GetPreSharedKey() + assert.NoError(t, err) + assert.NotEmpty(t, psk) +} + +func Test_parseIpAddressString(t *testing.T) { + type args struct { + addrStr string + } + var tests = []struct { + name string + args args + want []*netlink.Addr + wantErr bool + }{ + { + name: "Empty String", + args: args{}, + want: []*netlink.Addr{}, + wantErr: false, + }, + { + name: "Single IPv4", + args: args{addrStr: "123.123.123.123"}, + want: nil, + wantErr: true, + }, + { + name: "Malformed", + args: args{addrStr: "hello world"}, + want: nil, + wantErr: true, + }, + { + name: "Single IPv4 CIDR", + args: args{addrStr: "123.123.123.123/24"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(123, 123, 123, 123), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }}, + wantErr: false, + }, + { + name: "Multiple IPv4 CIDR", + args: args{addrStr: "123.123.123.123/24, 200.201.202.203/16"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(123, 123, 123, 123), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, { + IPNet: &net.IPNet{ + IP: net.IPv4(200, 201, 202, 203), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }}, + wantErr: false, + }, + { + name: "Single IPv6 CIDR", + args: args{addrStr: "fe80::1/64"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }}, + wantErr: false, + }, + { + name: "Multiple IPv6 CIDR", + args: args{addrStr: "fe80::1/64 , 2130:d3ad::b33f/128"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }, { + IPNet: &net.IPNet{ + IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + }}, + wantErr: false, + }, + { + name: "Mixed IPv4 and IPv6 CIDR", + args: args{addrStr: "200.201.202.203/16,2130:d3ad::b33f/128"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(200, 201, 202, 203), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }, { + IPNet: &net.IPNet{ + IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + }}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseIpAddressString(tt.args.addrStr) + if (err != nil) != tt.wantErr { + t.Errorf("parseIpAddressString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseIpAddressString() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestManagementUtil_UpdateDevice(t *testing.T) { + devName := DeviceIdentifier("wg668") + wg := new(MockWireGuardClient) + nl := new(MockNetlinkClient) + + // expectations + nl.On("LinkByName", string(devName)).Return(&netlink.GenericLink{}, nil) + nl.On("LinkSetMTU", mock.Anything, 1234).Return(nil) + nl.On("AddrReplace", mock.Anything, mock.Anything).Return(nil) + wg.On("ConfigureDevice", string(devName), mock.Anything).Return(nil) + nl.On("LinkSetDown", mock.Anything).Return(nil) + + m := ManagementUtil{interfaces: map[DeviceIdentifier]InterfaceConfig{devName: {}}, nl: nl, wg: wg} + + err := m.UpdateDevice(devName, InterfaceConfig{AddressStr: "123.123.123.123/24", Mtu: 1234}) + assert.NoError(t, err) + + // assert that the expectations were met + wg.AssertExpectations(t) + nl.AssertExpectations(t) +} diff --git a/internal/wireguard/peermanager.go b/internal/wireguard/peermanager.go deleted file mode 100644 index 1ca27e4..0000000 --- a/internal/wireguard/peermanager.go +++ /dev/null @@ -1,868 +0,0 @@ -package wireguard - -// WireGuard documentation: https://manpages.debian.org/unstable/wireguard-tools/wg.8.en.html - -import ( - "bytes" - "crypto/md5" - "fmt" - "net" - "regexp" - "sort" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/gin-gonic/gin/binding" - "github.com/go-playground/validator/v10" - "github.com/h44z/wg-portal/internal/common" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "github.com/skip2/go-qrcode" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "gorm.io/gorm" -) - -// -// CUSTOM VALIDATORS ---------------------------------------------------------------------------- -// -var cidrList validator.Func = func(fl validator.FieldLevel) bool { - cidrListStr := fl.Field().String() - - cidrList := common.ParseStringList(cidrListStr) - for i := range cidrList { - _, _, err := net.ParseCIDR(cidrList[i]) - if err != nil { - return false - } - } - return true -} - -var ipList validator.Func = func(fl validator.FieldLevel) bool { - ipListStr := fl.Field().String() - ipList := common.ParseStringList(ipListStr) - for i := range ipList { - ip := net.ParseIP(ipList[i]) - if ip == nil { - return false - } - } - return true -} - -func init() { - if v, ok := binding.Validator.Engine().(*validator.Validate); ok { - _ = v.RegisterValidation("cidrlist", cidrList) - _ = v.RegisterValidation("iplist", ipList) - } -} - -// -// PEER ---------------------------------------------------------------------------------------- -// - -type Peer struct { - Peer *wgtypes.Peer `gorm:"-" json:"-"` // WireGuard peer - Config string `gorm:"-" json:"-"` - - UID string `form:"uid" binding:"required,alphanum" json:"-"` // uid for html identification - DeviceName string `gorm:"index" form:"device" binding:"required"` - DeviceType DeviceType `gorm:"-" form:"devicetype" binding:"required,oneof=client server" json:"-"` - Identifier string `form:"identifier" binding:"required,max=64"` // Identifier AND Email make a WireGuard peer unique - Email string `gorm:"index" form:"mail" binding:"required,email"` - IgnoreGlobalSettings bool `form:"ignoreglobalsettings"` - - IsOnline bool `gorm:"-" json:"-"` - IsNew bool `gorm:"-" json:"-"` - LastHandshake string `gorm:"-" json:"-"` - LastHandshakeTime string `gorm:"-" json:"-"` - - // Core WireGuard Settings - PublicKey string `gorm:"primaryKey" form:"pubkey" binding:"required,base64"` // the public key of the peer itself - PresharedKey string `form:"presharedkey" binding:"omitempty,base64"` - AllowedIPsStr string `form:"allowedip" binding:"cidrlist"` // a comma separated list of IPs that are used in the client config file - AllowedIPsSrvStr string `form:"allowedipSrv" binding:"cidrlist"` // a comma separated list of IPs that are used in the server config file - Endpoint string `form:"endpoint" binding:"omitempty,hostname_port"` - PersistentKeepalive int `form:"keepalive" binding:"gte=0"` - - // Misc. WireGuard Settings - PrivateKey string `form:"privkey" binding:"omitempty,base64"` - IPsStr string `form:"ip" binding:"cidrlist,required_if=DeviceType server"` // a comma separated list of IPs of the client - DNSStr string `form:"dns" binding:"iplist"` // comma separated list of the DNS servers for the client - // Global Device Settings (can be ignored, only make sense if device is in server mode) - Mtu int `form:"mtu" binding:"gte=0,lte=1500"` - - DeactivatedAt *time.Time `json:",omitempty"` - CreatedBy string - UpdatedBy string - CreatedAt time.Time - UpdatedAt time.Time -} - -func (p *Peer) SetIPAddresses(addresses ...string) { - p.IPsStr = common.ListToString(addresses) -} - -func (p Peer) GetIPAddresses() []string { - return common.ParseStringList(p.IPsStr) -} - -func (p *Peer) SetDNSServers(addresses ...string) { - p.DNSStr = common.ListToString(addresses) -} - -func (p Peer) GetDNSServers() []string { - return common.ParseStringList(p.DNSStr) -} - -func (p *Peer) SetAllowedIPs(addresses ...string) { - p.AllowedIPsStr = common.ListToString(addresses) -} - -func (p Peer) GetAllowedIPs() []string { - return common.ParseStringList(p.AllowedIPsStr) -} - -func (p Peer) GetAllowedIPsSrv() []string { - return common.ParseStringList(p.AllowedIPsSrvStr) -} - -func (p Peer) GetConfig(dev *Device) wgtypes.PeerConfig { - publicKey, _ := wgtypes.ParseKey(p.PublicKey) - - var presharedKey *wgtypes.Key - if p.PresharedKey != "" { - presharedKeyTmp, _ := wgtypes.ParseKey(p.PresharedKey) - presharedKey = &presharedKeyTmp - } - - var endpoint *net.UDPAddr - if p.Endpoint != "" && dev.Type == DeviceTypeClient { - addr, err := net.ResolveUDPAddr("udp", p.Endpoint) - if err == nil { - endpoint = addr - } - } - - var keepAlive *time.Duration - if p.PersistentKeepalive != 0 { - keepAliveDuration := time.Duration(p.PersistentKeepalive) * time.Second - keepAlive = &keepAliveDuration - } - - allowedIPs := make([]net.IPNet, 0) - var peerAllowedIPs []string - switch dev.Type { - case DeviceTypeClient: - peerAllowedIPs = p.GetAllowedIPs() - case DeviceTypeServer: - peerAllowedIPs = p.GetIPAddresses() - peerAllowedIPs = append(peerAllowedIPs, p.GetAllowedIPsSrv()...) - } - for _, ip := range peerAllowedIPs { - _, ipNet, err := net.ParseCIDR(ip) - if err == nil { - allowedIPs = append(allowedIPs, *ipNet) - } - } - - cfg := wgtypes.PeerConfig{ - PublicKey: publicKey, - Remove: false, - UpdateOnly: false, - PresharedKey: presharedKey, - Endpoint: endpoint, - PersistentKeepaliveInterval: keepAlive, - ReplaceAllowedIPs: true, - AllowedIPs: allowedIPs, - } - - return cfg -} - -func (p Peer) GetConfigFile(device Device) ([]byte, error) { - var tplBuff bytes.Buffer - - err := templateCache.ExecuteTemplate(&tplBuff, "peer.tpl", gin.H{ - "Peer": p, - "Interface": device, - }) - if err != nil { - return nil, errors.Wrap(err, "failed to execute client template") - } - - return tplBuff.Bytes(), nil -} - -func (p Peer) GetQRCode() ([]byte, error) { - png, err := qrcode.Encode(p.Config, qrcode.Medium, 250) - if err == nil { - return png, nil - } - - if err.Error() != "content too long to encode" { - logrus.Errorf("failed to create qrcode: %v", err) - return nil, errors.Wrap(err, "failed to encode qrcode") - } - - png, err = qrcode.Encode(p.Config, qrcode.Low, 250) - if err != nil { - logrus.Errorf("failed to create qrcode: %v", err) - return nil, errors.Wrap(err, "failed to encode qrcode") - } - - return png, nil -} - -func (p Peer) IsValid() bool { - if p.PublicKey == "" { - return false - } - - return true -} - -func (p Peer) GetConfigFileName() string { - reg := regexp.MustCompile("[^a-zA-Z0-9_-]+") - return reg.ReplaceAllString(strings.ReplaceAll(p.Identifier, " ", "-"), "") + ".conf" -} - -// -// DEVICE -------------------------------------------------------------------------------------- -// - -type DeviceType string - -const ( - DeviceTypeServer DeviceType = "server" - DeviceTypeClient DeviceType = "client" -) - -type Device struct { - Interface *wgtypes.Device `gorm:"-" json:"-"` - Peers []Peer `gorm:"foreignKey:DeviceName" binding:"-" json:"-"` // linked WireGuard peers - - Type DeviceType `form:"devicetype" binding:"required,oneof=client server"` - DeviceName string `form:"device" gorm:"primaryKey" binding:"required,alphanum"` - DisplayName string `form:"displayname" binding:"omitempty,max=200"` - - // Core WireGuard Settings (Interface section) - PrivateKey string `form:"privkey" binding:"required,base64"` - ListenPort int `form:"port" binding:"required_if=Type server,omitempty,gt=0,lt=65535"` - FirewallMark int32 `form:"firewallmark" binding:"gte=0"` - // Misc. WireGuard Settings - PublicKey string `form:"pubkey" binding:"required,base64"` - Mtu int `form:"mtu" binding:"gte=0,lte=1500"` // the interface MTU, wg-quick addition - IPsStr string `form:"ip" binding:"required,cidrlist"` // comma separated list of the IPs of the client, wg-quick addition - DNSStr string `form:"dns" binding:"iplist"` // comma separated list of the DNS servers of the client, wg-quick addition - RoutingTable string `form:"routingtable"` // the routing table, wg-quick addition - PreUp string `form:"preup"` // pre up script, wg-quick addition - PostUp string `form:"postup"` // post up script, wg-quick addition - PreDown string `form:"predown"` // pre down script, wg-quick addition - PostDown string `form:"postdown"` // post down script, wg-quick addition - SaveConfig bool `form:"saveconfig"` // if set to `true', the configuration is saved from the current state of the interface upon shutdown, wg-quick addition - - // Settings that are applied to all peer by default - DefaultEndpoint string `form:"endpoint" binding:"required_if=Type server,omitempty,hostname_port"` - DefaultAllowedIPsStr string `form:"allowedip" binding:"cidrlist"` // comma separated list of IPs that are used in the client config file - DefaultPersistentKeepalive int `form:"keepalive" binding:"gte=0"` - - CreatedAt time.Time - UpdatedAt time.Time -} - -func (d Device) IsValid() bool { - switch d.Type { - case DeviceTypeServer: - if d.PublicKey == "" { - return false - } - if len(d.GetIPAddresses()) == 0 { - return false - } - if d.DefaultEndpoint == "" { - return false - } - case DeviceTypeClient: - if d.PublicKey == "" { - return false - } - if len(d.GetIPAddresses()) == 0 { - return false - } - } - - return true -} - -func (d *Device) SetIPAddresses(addresses ...string) { - d.IPsStr = common.ListToString(addresses) -} - -func (d Device) GetIPAddresses() []string { - return common.ParseStringList(d.IPsStr) -} - -func (d *Device) SetDNSServers(addresses ...string) { - d.DNSStr = common.ListToString(addresses) -} - -func (d Device) GetDNSServers() []string { - return common.ParseStringList(d.DNSStr) -} - -func (d *Device) SetDefaultAllowedIPs(addresses ...string) { - d.DefaultAllowedIPsStr = common.ListToString(addresses) -} - -func (d Device) GetDefaultAllowedIPs() []string { - return common.ParseStringList(d.DefaultAllowedIPsStr) -} - -func (d Device) GetConfig() wgtypes.Config { - var privateKey *wgtypes.Key - if d.PrivateKey != "" { - pKey, _ := wgtypes.ParseKey(d.PrivateKey) - privateKey = &pKey - } - - fwMark := int(d.FirewallMark) - - cfg := wgtypes.Config{ - PrivateKey: privateKey, - ListenPort: &d.ListenPort, - FirewallMark: &fwMark, - } - - return cfg -} - -func (d Device) GetConfigFile(peers []Peer) ([]byte, error) { - var tplBuff bytes.Buffer - - err := templateCache.ExecuteTemplate(&tplBuff, "interface.tpl", gin.H{ - "Peers": peers, - "Interface": d, - }) - if err != nil { - return nil, errors.Wrap(err, "failed to execute server template") - } - - return tplBuff.Bytes(), nil -} - -// -// PEER-MANAGER -------------------------------------------------------------------------------- -// - -type PeerManager struct { - db *gorm.DB - wg *Manager -} - -func NewPeerManager(db *gorm.DB, wg *Manager) (*PeerManager, error) { - pm := &PeerManager{db: db, wg: wg} - - // check if old device table exists (from version <= 1.0.3), if so migrate it. - if db.Migrator().HasColumn(&Device{}, "endpoint") { - if err := db.Migrator().RenameColumn(&Device{}, "endpoint", "default_endpoint"); err != nil { - return nil, errors.Wrapf(err, "failed to migrate old database structure for column endpoint") - } - } - if db.Migrator().HasColumn(&Device{}, "allowed_ips_str") { - if err := db.Migrator().RenameColumn(&Device{}, "allowed_ips_str", "default_allowed_ips_str"); err != nil { - return nil, errors.Wrapf(err, "failed to migrate old database structure for column allowed_ips_str") - } - } - if db.Migrator().HasColumn(&Device{}, "persistent_keepalive") { - if err := db.Migrator().RenameColumn(&Device{}, "persistent_keepalive", "default_persistent_keepalive"); err != nil { - return nil, errors.Wrapf(err, "failed to migrate old database structure for column persistent_keepalive") - } - } - - if err := pm.db.AutoMigrate(&Device{}, &Peer{}); err != nil { - return nil, errors.WithMessage(err, "failed to migrate peer database") - } - - if err := pm.initFromPhysicalInterface(); err != nil { - return nil, errors.WithMessagef(err, "unable to initialize peer manager") - } - - // check if peers without device name exist (from version <= 1.0.3), if so assign them to the default device. - peers := make([]Peer, 0) - pm.db.Find(&peers) - for i := range peers { - if peers[i].DeviceName == "" { - peers[i].DeviceName = wg.Cfg.GetDefaultDeviceName() - pm.db.Save(&peers[i]) - } - } - - // validate and update existing peers if needed - for _, deviceName := range wg.Cfg.DeviceNames { - dev := pm.GetDevice(deviceName) - peers := pm.GetAllPeers(deviceName) - for i := range peers { - if err := pm.fixPeerDefaultData(&peers[i], &dev); err != nil { - return nil, errors.WithMessagef(err, "unable to fix peers for interface %s", deviceName) - } - } - } - - return pm, nil -} - -// initFromPhysicalInterface read all WireGuard peers from the WireGuard interface configuration. If a peer does not -// exist in the local database, it gets created. -func (m *PeerManager) initFromPhysicalInterface() error { - for _, deviceName := range m.wg.Cfg.DeviceNames { - peers, err := m.wg.GetPeerList(deviceName) - if err != nil { - return errors.Wrapf(err, "failed to get peer list for device %s", deviceName) - } - device, err := m.wg.GetDeviceInfo(deviceName) - if err != nil { - return errors.Wrapf(err, "failed to get device info for device %s", deviceName) - } - var ipAddresses []string - var mtu int - if m.wg.Cfg.ManageIPAddresses { - if ipAddresses, err = m.wg.GetIPAddress(deviceName); err != nil { - return errors.Wrapf(err, "failed to get ip address for device %s", deviceName) - } - if mtu, err = m.wg.GetMTU(deviceName); err != nil { - return errors.Wrapf(err, "failed to get MTU for device %s", deviceName) - } - } - - // Check if device already exists in database, if not, create it - if err := m.validateOrCreateDevice(*device, ipAddresses, mtu); err != nil { - return errors.WithMessagef(err, "failed to validate device %s", device.Name) - } - - // Check if entries already exist in database, if not, create them - for _, peer := range peers { - if err := m.validateOrCreatePeer(deviceName, peer); err != nil { - return errors.WithMessagef(err, "failed to validate peer %s for device %s", peer.PublicKey, deviceName) - } - } - } - - return nil -} - -// validateOrCreatePeer checks if the given WireGuard peer already exists in the database, if not, the peer entry will be created -// assumption: server mode is used -func (m *PeerManager) validateOrCreatePeer(device string, wgPeer wgtypes.Peer) error { - peer := Peer{} - m.db.Where("public_key = ?", wgPeer.PublicKey.String()).FirstOrInit(&peer) - - dev := m.GetDevice(device) - - if peer.PublicKey == "" { // peer not found, create - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(wgPeer.PublicKey.String()))) - if dev.Type == DeviceTypeServer { - peer.PublicKey = wgPeer.PublicKey.String() - peer.Identifier = "Autodetected Client (" + peer.PublicKey[0:8] + ")" - } else if dev.Type == DeviceTypeClient { - peer.PublicKey = wgPeer.PublicKey.String() - if wgPeer.Endpoint != nil { - peer.Endpoint = wgPeer.Endpoint.String() - } - peer.Identifier = "Autodetected Endpoint (" + peer.PublicKey[0:8] + ")" - } - if wgPeer.PresharedKey != (wgtypes.Key{}) { - peer.PresharedKey = wgPeer.PresharedKey.String() - } - peer.Email = "autodetected@example.com" - peer.UpdatedAt = time.Now() - peer.CreatedAt = time.Now() - IPs := make([]string, len(wgPeer.AllowedIPs)) // use allowed IP's as the peer IP's - for i, ip := range wgPeer.AllowedIPs { - IPs[i] = ip.String() - } - peer.SetIPAddresses(IPs...) - peer.DeviceName = device - - res := m.db.Create(&peer) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to create autodetected peer %s", peer.PublicKey) - } - } - - if peer.DeviceName == "" { - peer.DeviceName = device - res := m.db.Save(&peer) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to update autodetected peer %s", peer.PublicKey) - } - } - - return nil -} - -// validateOrCreateDevice checks if the given WireGuard device already exists in the database, if not, the peer entry will be created -func (m *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []string, mtu int) error { - device := Device{} - m.db.Where("device_name = ?", dev.Name).FirstOrInit(&device) - - if device.PublicKey == "" { // device not found, create - device.Type = DeviceTypeServer // imported device, we assume that server mode is used - device.PublicKey = dev.PublicKey.String() - device.PrivateKey = dev.PrivateKey.String() - device.DeviceName = dev.Name - device.ListenPort = dev.ListenPort - device.FirewallMark = int32(dev.FirewallMark) - device.Mtu = 0 - device.DefaultPersistentKeepalive = 16 // Default - device.IPsStr = strings.Join(ipAddresses, ", ") - if mtu == DefaultMTU { - mtu = 0 - } - device.Mtu = mtu - - res := m.db.Create(&device) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to create autodetected device") - } - } - - if device.Type == "" { - device.Type = DeviceTypeServer // from version <= 1.0.3, only server mode devices were supported - - res := m.db.Save(&device) - if res.Error != nil { - return errors.Wrapf(res.Error, "failed to update autodetected device") - } - } - - return nil -} - -// populatePeerData enriches the peer struct with WireGuard live data like last handshake, ... -func (m *PeerManager) populatePeerData(peer *Peer) { - // Set config file - tmpCfg, _ := peer.GetConfigFile(m.GetDevice(peer.DeviceName)) - peer.Config = string(tmpCfg) - - // set data from WireGuard interface - peer.Peer, _ = m.wg.GetPeer(peer.DeviceName, peer.PublicKey) - peer.LastHandshake = "never" - peer.LastHandshakeTime = "Never connected, or user is disabled." - if peer.Peer != nil { - since := time.Since(peer.Peer.LastHandshakeTime) - sinceSeconds := int(since.Round(time.Second).Seconds()) - sinceMinutes := sinceSeconds / 60 - sinceSeconds -= sinceMinutes * 60 - - if sinceMinutes > 2*10080 { // 2 weeks - peer.LastHandshake = "a while ago" - } else if sinceMinutes > 10080 { // 1 week - peer.LastHandshake = "a week ago" - } else { - peer.LastHandshake = fmt.Sprintf("%02dm %02ds", sinceMinutes, sinceSeconds) - } - peer.LastHandshakeTime = peer.Peer.LastHandshakeTime.Format(time.UnixDate) - } - peer.IsOnline = false -} - -// fixPeerDefaultData tries to fill all required fields for the given peer -// also tries to migrate data if the database schema changed -func (m *PeerManager) fixPeerDefaultData(peer *Peer, device *Device) error { - updatePeer := false - - switch device.Type { - case DeviceTypeServer: - if peer.Endpoint == "" { - peer.Endpoint = device.DefaultEndpoint - updatePeer = true - } - case DeviceTypeClient: - } - - if updatePeer { - return m.UpdatePeer(*peer) - } - return nil -} - -// populateDeviceData enriches the device struct with WireGuard live data like interface information -func (m *PeerManager) populateDeviceData(device *Device) { - // set data from WireGuard interface - device.Interface, _ = m.wg.GetDeviceInfo(device.DeviceName) -} - -func (m *PeerManager) GetAllPeers(device string) []Peer { - peers := make([]Peer, 0) - m.db.Where("device_name = ?", device).Find(&peers) - - for i := range peers { - m.populatePeerData(&peers[i]) - } - - return peers -} - -func (m *PeerManager) GetActivePeers(device string) []Peer { - peers := make([]Peer, 0) - m.db.Where("device_name = ? AND deactivated_at IS NULL", device).Find(&peers) - - for i := range peers { - m.populatePeerData(&peers[i]) - } - - return peers -} - -func (m *PeerManager) GetFilteredAndSortedPeers(device, sortKey, sortDirection, search string) []Peer { - peers := make([]Peer, 0) - m.db.Where("device_name = ?", device).Find(&peers) - - filteredPeers := make([]Peer, 0, len(peers)) - for i := range peers { - m.populatePeerData(&peers[i]) - - if search == "" || - strings.Contains(peers[i].Email, strings.ToLower(search)) || - strings.Contains(peers[i].Identifier, search) || - strings.Contains(peers[i].PublicKey, search) { - filteredPeers = append(filteredPeers, peers[i]) - } - } - - sortPeers(sortKey, sortDirection, filteredPeers) - - return filteredPeers -} - -func (m *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer { - email = strings.ToLower(email) - peers := make([]Peer, 0) - m.db.Where("email = ?", email).Find(&peers) - - for i := range peers { - m.populatePeerData(&peers[i]) - } - - sortPeers(sortKey, sortDirection, peers) - - return peers -} - -func sortPeers(sortKey string, sortDirection string, peers []Peer) { - sort.Slice(peers, func(i, j int) bool { - var sortValueLeft string - var sortValueRight string - - switch sortKey { - case "id": - sortValueLeft = peers[i].Identifier - sortValueRight = peers[j].Identifier - case "pubKey": - sortValueLeft = peers[i].PublicKey - sortValueRight = peers[j].PublicKey - case "mail": - sortValueLeft = peers[i].Email - sortValueRight = peers[j].Email - case "ip": - sortValueLeft = peers[i].IPsStr - sortValueRight = peers[j].IPsStr - case "endpoint": - sortValueLeft = peers[i].Endpoint - sortValueRight = peers[j].Endpoint - case "handshake": - if peers[i].Peer == nil { - return true - } else if peers[j].Peer == nil { - return false - } - sortValueLeft = peers[i].Peer.LastHandshakeTime.Format(time.RFC3339) - sortValueRight = peers[j].Peer.LastHandshakeTime.Format(time.RFC3339) - } - - if sortDirection == "asc" { - return sortValueLeft < sortValueRight - } else { - return sortValueLeft > sortValueRight - } - }) -} - -func (m *PeerManager) GetDevice(device string) Device { - dev := Device{} - - m.db.Where("device_name = ?", device).First(&dev) - m.populateDeviceData(&dev) - - return dev -} - -func (m *PeerManager) GetPeerByKey(publicKey string) Peer { - peer := Peer{} - m.db.Where("public_key = ?", publicKey).FirstOrInit(&peer) - m.populatePeerData(&peer) - return peer -} - -func (m *PeerManager) GetPeersByMail(mail string) []Peer { - mail = strings.ToLower(mail) - var peers []Peer - m.db.Where("email = ?", mail).Find(&peers) - for i := range peers { - m.populatePeerData(&peers[i]) - } - - return peers -} - -// ---- Database helpers ----- - -func (m *PeerManager) CreatePeer(peer Peer) error { - peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey))) - peer.UpdatedAt = time.Now() - peer.CreatedAt = time.Now() - peer.Email = strings.ToLower(peer.Email) - - res := m.db.Create(&peer) - if res.Error != nil { - logrus.Errorf("failed to create peer: %v", res.Error) - return errors.Wrap(res.Error, "failed to create peer") - } - - return nil -} - -func (m *PeerManager) UpdatePeer(peer Peer) error { - peer.UpdatedAt = time.Now() - peer.Email = strings.ToLower(peer.Email) - - res := m.db.Save(&peer) - if res.Error != nil { - logrus.Errorf("failed to update peer: %v", res.Error) - return errors.Wrap(res.Error, "failed to update peer") - } - - return nil -} - -func (m *PeerManager) DeletePeer(peer Peer) error { - res := m.db.Delete(&peer) - if res.Error != nil { - logrus.Errorf("failed to delete peer: %v", res.Error) - return errors.Wrap(res.Error, "failed to delete peer") - } - - return nil -} - -func (m *PeerManager) UpdateDevice(device Device) error { - device.UpdatedAt = time.Now() - - res := m.db.Save(&device) - if res.Error != nil { - logrus.Errorf("failed to update device: %v", res.Error) - return errors.Wrap(res.Error, "failed to update device") - } - - return nil -} - -// ---- IP helpers ---- - -func (m *PeerManager) GetAllReservedIps(device string) ([]string, error) { - reservedIps := make([]string, 0) - peers := m.GetAllPeers(device) - for _, user := range peers { - for _, cidr := range user.GetIPAddresses() { - if cidr == "" { - continue - } - ip, _, err := net.ParseCIDR(cidr) - if err != nil { - return nil, errors.Wrap(err, "failed to parse cidr") - } - reservedIps = append(reservedIps, ip.String()) - } - } - - dev := m.GetDevice(device) - for _, cidr := range dev.GetIPAddresses() { - if cidr == "" { - continue - } - ip, _, err := net.ParseCIDR(cidr) - if err != nil { - return nil, errors.Wrap(err, "failed to parse cidr") - } - - reservedIps = append(reservedIps, ip.String()) - } - - return reservedIps, nil -} - -func (m *PeerManager) IsIPReserved(device string, cidr string) bool { - reserved, err := m.GetAllReservedIps(device) - if err != nil { - return true // in case something failed, assume the ip is reserved - } - ip, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - return true - } - - // this two addresses are not usable - broadcastAddr := common.BroadcastAddr(ipnet).String() - networkAddr := ipnet.IP.String() - address := ip.String() - - if address == broadcastAddr || address == networkAddr { - return true - } - - for _, r := range reserved { - if address == r { - return true - } - } - - return false -} - -// GetAvailableIp search for an available ip in cidr against a list of reserved ips -func (m *PeerManager) GetAvailableIp(device string, cidr string) (string, error) { - reserved, err := m.GetAllReservedIps(device) - if err != nil { - return "", errors.WithMessagef(err, "failed to get all reserved IP addresses for %s", device) - } - ip, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - return "", errors.Wrap(err, "failed to parse cidr") - } - - // this two addresses are not usable - broadcastAddr := common.BroadcastAddr(ipnet).String() - networkAddr := ipnet.IP.String() - - for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); common.IncreaseIP(ip) { - ok := true - address := ip.String() - for _, r := range reserved { - if address == r { - ok = false - break - } - } - if ok && address != networkAddr && address != broadcastAddr { - netMask := "/32" - if common.IsIPv6(address) { - netMask = "/128" - } - return address + netMask, nil - } - } - - return "", errors.New("no more available address from cidr") -} diff --git a/internal/wireguard/template.go b/internal/wireguard/template.go deleted file mode 100644 index d9c2e4a..0000000 --- a/internal/wireguard/template.go +++ /dev/null @@ -1,20 +0,0 @@ -package wireguard - -import ( - "embed" - "strings" - "text/template" -) - -//go:embed tpl/* -var Templates embed.FS - -var templateCache *template.Template - -func init() { - var err error - templateCache, err = template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).ParseFS(Templates, "tpl/*.tpl") - if err != nil { - panic(err) - } -} diff --git a/internal/wireguard/tpl/interface.tpl b/internal/wireguard/tpl/interface.tpl deleted file mode 100644 index 9662a5a..0000000 --- a/internal/wireguard/tpl/interface.tpl +++ /dev/null @@ -1,78 +0,0 @@ -# AUTOGENERATED FILE - DO NOT EDIT -# -WGP- Interface: {{ .Interface.DeviceName }} / Updated: {{ .Interface.UpdatedAt }} / Created: {{ .Interface.CreatedAt }} -# -WGP- Interface display name: {{ .Interface.DisplayName }} -# -WGP- Interface mode: {{ .Interface.Type }} -# -WGP- PublicKey = {{ .Interface.PublicKey }} - -[Interface] - -# Core settings -PrivateKey = {{ .Interface.PrivateKey }} -Address = {{ .Interface.IPsStr }} - -# Misc. settings (optional) -{{- if ne .Interface.ListenPort 0}} -ListenPort = {{ .Interface.ListenPort }} -{{- end}} -{{- if ne .Interface.Mtu 0}} -MTU = {{.Interface.Mtu}} -{{- end}} -{{- if and (ne .Interface.DNSStr "") (eq $.Interface.Type "client")}} -DNS = {{ .Interface.DNSStr }} -{{- end}} -{{- if ne .Interface.FirewallMark 0}} -FwMark = {{.Interface.FirewallMark}} -{{- end}} -{{- if ne .Interface.RoutingTable ""}} -Table = {{.Interface.RoutingTable}} -{{- end}} -{{- if .Interface.SaveConfig}} -SaveConfig = true -{{- end}} - -# Interface hooks (optional) -{{- if .Interface.PreUp}} -PreUp = {{ .Interface.PreUp }} -{{- end}} -{{- if .Interface.PostUp}} -PostUp = {{ .Interface.PostUp }} -{{- end}} -{{- if .Interface.PreDown}} -PreDown = {{ .Interface.PreDown }} -{{- end}} -{{- if .Interface.PostDown}} -PostDown = {{ .Interface.PostDown }} -{{- end}} - -# -# Peers -# - -{{range .Peers}} -{{- if not .DeactivatedAt}} -# -WGP- Peer: {{.Identifier}} / Updated: {{.UpdatedAt}} / Created: {{.CreatedAt}} -# -WGP- Peer email: {{.Email}} -{{- if .PrivateKey}} -# -WGP- PrivateKey: {{.PrivateKey}} -{{- end}} -[Peer] -PublicKey = {{ .PublicKey }} -{{- if .PresharedKey}} -PresharedKey = {{ .PresharedKey }} -{{- end}} -{{- if eq $.Interface.Type "server"}} -AllowedIPs = {{ .IPsStr }}{{if ne .AllowedIPsSrvStr ""}}, {{ .AllowedIPsSrvStr }}{{end}} -{{- end}} -{{- if eq $.Interface.Type "client"}} -{{- if .AllowedIPsStr}} -AllowedIPs = {{ .AllowedIPsStr }} -{{- end}} -{{- end}} -{{- if and (ne .Endpoint "") (eq $.Interface.Type "client")}} -Endpoint = {{ .Endpoint }} -{{- end}} -{{- if ne .PersistentKeepalive 0}} -PersistentKeepalive = {{ .PersistentKeepalive }} -{{- end}} -{{- end}} -{{end}} \ No newline at end of file diff --git a/internal/wireguard/tpl/peer.tpl b/internal/wireguard/tpl/peer.tpl deleted file mode 100644 index d6a85e4..0000000 --- a/internal/wireguard/tpl/peer.tpl +++ /dev/null @@ -1,30 +0,0 @@ -# AUTOGENERATED FILE - PROVIDED BY WIREGUARD PORTAL -# WireGuard configuration: {{ .Peer.Identifier }} -# -WGP- PublicKey: {{ .Peer.PublicKey }} - -[Interface] - -# Core settings -PrivateKey = {{ .Peer.PrivateKey }} -Address = {{ .Peer.IPsStr }} - -# Misc. settings (optional) -{{- if .Peer.DNSStr}} -DNS = {{ .Peer.DNSStr }} -{{- end}} -{{- if ne .Peer.Mtu 0}} -MTU = {{.Peer.Mtu}} -{{- end}} - -[Peer] -PublicKey = {{ .Interface.PublicKey }} -Endpoint = {{ .Peer.Endpoint }} -{{- if .Peer.AllowedIPsStr}} -AllowedIPs = {{ .Peer.AllowedIPsStr }} -{{- end}} -{{- if .Peer.PresharedKey}} -PresharedKey = {{ .Peer.PresharedKey }} -{{- end}} -{{- if ne .Peer.PersistentKeepalive 0}} -PersistentKeepalive = {{.Peer.PersistentKeepalive}} -{{- end}} \ No newline at end of file diff --git a/internal/wireguard/wrappers.go b/internal/wireguard/wrappers.go new file mode 100644 index 0000000..0ec5e47 --- /dev/null +++ b/internal/wireguard/wrappers.go @@ -0,0 +1,55 @@ +package wireguard + +import ( + "io" + + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// A Client is a type which can control a WireGuard device. +type Client interface { + io.Closer + Devices() ([]*wgtypes.Device, error) + Device(name string) (*wgtypes.Device, error) + ConfigureDevice(name string, cfg wgtypes.Config) error +} + +// A NetlinkClient is a type which can control a netlink device. +type NetlinkClient interface { + LinkAdd(link netlink.Link) error + LinkDel(link netlink.Link) error + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + LinkSetDown(link netlink.Link) error + LinkSetMTU(link netlink.Link, mtu int) error + AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrAdd(link netlink.Link, addr *netlink.Addr) error +} + +type NetlinkManager struct { +} + +func (n NetlinkManager) LinkAdd(link netlink.Link) error { return netlink.LinkAdd(link) } + +func (n NetlinkManager) LinkDel(link netlink.Link) error { return netlink.LinkDel(link) } + +func (n NetlinkManager) LinkByName(name string) (netlink.Link, error) { + return netlink.LinkByName(name) +} + +func (n NetlinkManager) LinkSetUp(link netlink.Link) error { return netlink.LinkSetUp(link) } + +func (n NetlinkManager) LinkSetDown(link netlink.Link) error { return netlink.LinkSetDown(link) } + +func (n NetlinkManager) LinkSetMTU(link netlink.Link, mtu int) error { + return netlink.LinkSetMTU(link, mtu) +} + +func (n NetlinkManager) AddrReplace(link netlink.Link, addr *netlink.Addr) error { + return netlink.AddrReplace(link, addr) +} + +func (n NetlinkManager) AddrAdd(link netlink.Link, addr *netlink.Addr) error { + return netlink.AddrAdd(link, addr) +}