User details
- {{if not $p.User}}
+ {{if not $peerUser}}
No user information available...
{{else}}
- - Firstname: {{$p.User.Firstname}}
- - Lastname: {{$p.User.Lastname}}
- - Phone: {{$p.User.Phone}}
- - Mail: {{$p.User.Email}}
+ - Firstname: {{$peerUser.Firstname}}
+ - Lastname: {{$peerUser.Lastname}}
+ - Phone: {{$peerUser.Phone}}
+ - Mail: {{$peerUser.Email}}
{{end}}
Traffic
diff --git a/internal/authentication/providers/password/provider.go b/internal/authentication/providers/password/provider.go
index c6f937d..fbe95c5 100644
--- a/internal/authentication/providers/password/provider.go
+++ b/internal/authentication/providers/password/provider.go
@@ -7,6 +7,8 @@
"strings"
"time"
+ "github.com/h44z/wg-portal/internal/common"
+
"github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/authentication"
"github.com/h44z/wg-portal/internal/users"
@@ -22,11 +24,11 @@
db *gorm.DB
}
-func New(cfg *users.Config) (*Provider, error) {
+func New(cfg *common.DatabaseConfig) (*Provider, error) {
p := &Provider{}
var err error
- p.db, err = users.GetDatabaseForConfig(cfg)
+ p.db, err = common.GetDatabaseForConfig(cfg)
if err != nil {
return nil, errors.Wrapf(err, "failed to setup authentication database %s", cfg.Database)
}
diff --git a/internal/common/configuration.go b/internal/common/configuration.go
deleted file mode 100644
index 3912b38..0000000
--- a/internal/common/configuration.go
+++ /dev/null
@@ -1,132 +0,0 @@
-package common
-
-import (
- "os"
- "reflect"
- "runtime"
-
- "github.com/h44z/wg-portal/internal/ldap"
- "github.com/h44z/wg-portal/internal/users"
- "github.com/h44z/wg-portal/internal/wireguard"
- "github.com/kelseyhightower/envconfig"
- "github.com/pkg/errors"
- "github.com/sirupsen/logrus"
- "gopkg.in/yaml.v3"
-)
-
-var ErrInvalidSpecification = errors.New("specification must be a struct pointer")
-
-// loadConfigFile parses yaml files. It uses yaml annotation to store the data in a struct.
-func loadConfigFile(cfg interface{}, filename string) error {
- s := reflect.ValueOf(cfg)
-
- if s.Kind() != reflect.Ptr {
- return ErrInvalidSpecification
- }
- s = s.Elem()
- if s.Kind() != reflect.Struct {
- return ErrInvalidSpecification
- }
-
- f, err := os.Open(filename)
- if err != nil {
- return errors.Wrapf(err, "failed to open config file %s", filename)
- }
- defer f.Close()
-
- decoder := yaml.NewDecoder(f)
- err = decoder.Decode(cfg)
- if err != nil {
- return errors.Wrapf(err, "failed to decode config file %s", filename)
- }
-
- return nil
-}
-
-// loadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct.
-func loadConfigEnv(cfg interface{}) error {
- err := envconfig.Process("", cfg)
- if err != nil {
- return errors.Wrap(err, "failed to process environment config")
- }
-
- return nil
-}
-
-type Config struct {
- Core struct {
- ListeningAddress string `yaml:"listeningAddress" envconfig:"LISTENING_ADDRESS"`
- ExternalUrl string `yaml:"externalUrl" envconfig:"EXTERNAL_URL"`
- Title string `yaml:"title" envconfig:"WEBSITE_TITLE"`
- CompanyName string `yaml:"company" envconfig:"COMPANY_NAME"`
- MailFrom string `yaml:"mailFrom" envconfig:"MAIL_FROM"`
- AdminUser string `yaml:"adminUser" envconfig:"ADMIN_USER"` // must be an email address
- AdminPassword string `yaml:"adminPass" envconfig:"ADMIN_PASS"`
- EditableKeys bool `yaml:"editableKeys" envconfig:"EDITABLE_KEYS"`
- CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"`
- LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"`
- } `yaml:"core"`
- Database users.Config `yaml:"database"`
- Email MailConfig `yaml:"email"`
- LDAP ldap.Config `yaml:"ldap"`
- WG wireguard.Config `yaml:"wg"`
-}
-
-func NewConfig() *Config {
- cfg := &Config{}
-
- // Default config
- cfg.Core.ListeningAddress = ":8123"
- cfg.Core.Title = "WireGuard VPN"
- cfg.Core.CompanyName = "WireGuard Portal"
- cfg.Core.ExternalUrl = "http://localhost:8123"
- cfg.Core.MailFrom = "WireGuard VPN
"
- cfg.Core.AdminUser = "admin@wgportal.local"
- cfg.Core.AdminPassword = "wgportal"
- cfg.Core.LdapEnabled = false
-
- cfg.Database.Typ = "sqlite"
- cfg.Database.Database = "data/wg_portal.db"
-
- cfg.LDAP.URL = "ldap://srv-ad01.company.local:389"
- cfg.LDAP.BaseDN = "DC=COMPANY,DC=LOCAL"
- cfg.LDAP.StartTLS = true
- cfg.LDAP.BindUser = "company\\\\ldap_wireguard"
- cfg.LDAP.BindPass = "SuperSecret"
- cfg.LDAP.Type = "AD"
- cfg.LDAP.UserClass = "organizationalPerson"
- cfg.LDAP.EmailAttribute = "mail"
- cfg.LDAP.FirstNameAttribute = "givenName"
- cfg.LDAP.LastNameAttribute = "sn"
- cfg.LDAP.PhoneAttribute = "telephoneNumber"
- cfg.LDAP.GroupMemberAttribute = "memberOf"
- cfg.LDAP.DisabledAttribute = "userAccountControl"
- cfg.LDAP.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL"
-
- cfg.WG.DeviceName = "wg0"
- cfg.WG.WireGuardConfig = "/etc/wireguard/wg0.conf"
- cfg.WG.ManageIPAddresses = true
- cfg.Email.Host = "127.0.0.1"
- cfg.Email.Port = 25
-
- // Load config from file and environment
- cfgFile, ok := os.LookupEnv("CONFIG_FILE")
- if !ok {
- cfgFile = "config.yml" // Default config file
- }
- err := loadConfigFile(cfg, cfgFile)
- if err != nil {
- logrus.Warnf("unable to load config.yml file: %v, using default configuration...", err)
- }
- err = loadConfigEnv(cfg)
- if err != nil {
- logrus.Warnf("unable to load environment config: %v", err)
- }
-
- if cfg.WG.ManageIPAddresses && runtime.GOOS != "linux" {
- logrus.Warnf("managing IP addresses only works on linux, feature disabled...")
- cfg.WG.ManageIPAddresses = false
- }
-
- return cfg
-}
diff --git a/internal/common/db.go b/internal/common/db.go
new file mode 100644
index 0000000..9261c58
--- /dev/null
+++ b/internal/common/db.go
@@ -0,0 +1,76 @@
+package common
+
+import (
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/pkg/errors"
+ "github.com/sirupsen/logrus"
+ "gorm.io/driver/mysql"
+ "gorm.io/driver/sqlite"
+ "gorm.io/gorm"
+ "gorm.io/gorm/logger"
+)
+
+type SupportedDatabase string
+
+const (
+ SupportedDatabaseMySQL SupportedDatabase = "mysql"
+ SupportedDatabaseSQLite SupportedDatabase = "sqlite"
+)
+
+type DatabaseConfig struct {
+ Typ SupportedDatabase `yaml:"typ" envconfig:"DATABASE_TYPE"` //mysql or sqlite
+ Host string `yaml:"host" envconfig:"DATABASE_HOST"`
+ Port int `yaml:"port" envconfig:"DATABASE_PORT"`
+ Database string `yaml:"database" envconfig:"DATABASE_NAME"` // On SQLite: the database file-path, otherwise the database name
+ User string `yaml:"user" envconfig:"DATABASE_USERNAME"`
+ Password string `yaml:"password" envconfig:"DATABASE_PASSWORD"`
+}
+
+func GetDatabaseForConfig(cfg *DatabaseConfig) (db *gorm.DB, err error) {
+ switch cfg.Typ {
+ case SupportedDatabaseSQLite:
+ if _, err = os.Stat(filepath.Dir(cfg.Database)); os.IsNotExist(err) {
+ if err = os.MkdirAll(filepath.Dir(cfg.Database), 0700); err != nil {
+ return
+ }
+ }
+ db, err = gorm.Open(sqlite.Open(cfg.Database), &gorm.Config{})
+ if err != nil {
+ return
+ }
+ case SupportedDatabaseMySQL:
+ connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
+ db, err = gorm.Open(mysql.Open(connectionString), &gorm.Config{})
+ if err != nil {
+ return
+ }
+
+ sqlDB, _ := db.DB()
+ sqlDB.SetConnMaxLifetime(time.Minute * 5)
+ sqlDB.SetMaxIdleConns(2)
+ sqlDB.SetMaxOpenConns(10)
+ err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to ping mysql authentication database")
+ }
+ }
+
+ // Enable Logger (logrus)
+ logCfg := logger.Config{
+ SlowThreshold: time.Second, // all slower than one second
+ Colorful: false,
+ LogLevel: logger.Silent, // default: log nothing
+ }
+
+ if logrus.StandardLogger().GetLevel() == logrus.TraceLevel {
+ logCfg.LogLevel = logger.Info
+ logCfg.SlowThreshold = 500 * time.Millisecond // all slower than half a second
+ }
+
+ db.Config.Logger = logger.New(logrus.StandardLogger(), logCfg)
+ return
+}
diff --git a/internal/common/util.go b/internal/common/util.go
index fc974b1..bbde700 100644
--- a/internal/common/util.go
+++ b/internal/common/util.go
@@ -60,6 +60,16 @@
return strings.Join(lst, ", ")
}
+// ListContains checks if a needle exists in the given list.
+func ListContains(lst []string, needle string) bool {
+ for _, entry := range lst {
+ if entry == needle {
+ return true
+ }
+ }
+ return false
+}
+
// https://yourbasic.org/golang/formatting-byte-size-to-human-readable-format/
func ByteCountSI(b int64) string {
const unit = 1000
diff --git a/internal/server/configuration.go b/internal/server/configuration.go
new file mode 100644
index 0000000..a88d964
--- /dev/null
+++ b/internal/server/configuration.go
@@ -0,0 +1,133 @@
+package server
+
+import (
+ "os"
+ "reflect"
+ "runtime"
+
+ "github.com/h44z/wg-portal/internal/common"
+ "github.com/h44z/wg-portal/internal/ldap"
+ "github.com/h44z/wg-portal/internal/wireguard"
+ "github.com/kelseyhightower/envconfig"
+ "github.com/pkg/errors"
+ "github.com/sirupsen/logrus"
+ "gopkg.in/yaml.v3"
+)
+
+var ErrInvalidSpecification = errors.New("specification must be a struct pointer")
+
+// loadConfigFile parses yaml files. It uses yaml annotation to store the data in a struct.
+func loadConfigFile(cfg interface{}, filename string) error {
+ s := reflect.ValueOf(cfg)
+
+ if s.Kind() != reflect.Ptr {
+ return ErrInvalidSpecification
+ }
+ s = s.Elem()
+ if s.Kind() != reflect.Struct {
+ return ErrInvalidSpecification
+ }
+
+ f, err := os.Open(filename)
+ if err != nil {
+ return errors.Wrapf(err, "failed to open config file %s", filename)
+ }
+ defer f.Close()
+
+ decoder := yaml.NewDecoder(f)
+ err = decoder.Decode(cfg)
+ if err != nil {
+ return errors.Wrapf(err, "failed to decode config file %s", filename)
+ }
+
+ return nil
+}
+
+// loadConfigEnv processes envconfig annotations and loads environment variables to the given configuration struct.
+func loadConfigEnv(cfg interface{}) error {
+ err := envconfig.Process("", cfg)
+ if err != nil {
+ return errors.Wrap(err, "failed to process environment config")
+ }
+
+ return nil
+}
+
+type Config struct {
+ Core struct {
+ ListeningAddress string `yaml:"listeningAddress" envconfig:"LISTENING_ADDRESS"`
+ ExternalUrl string `yaml:"externalUrl" envconfig:"EXTERNAL_URL"`
+ Title string `yaml:"title" envconfig:"WEBSITE_TITLE"`
+ CompanyName string `yaml:"company" envconfig:"COMPANY_NAME"`
+ MailFrom string `yaml:"mailFrom" envconfig:"MAIL_FROM"`
+ AdminUser string `yaml:"adminUser" envconfig:"ADMIN_USER"` // must be an email address
+ AdminPassword string `yaml:"adminPass" envconfig:"ADMIN_PASS"`
+ EditableKeys bool `yaml:"editableKeys" envconfig:"EDITABLE_KEYS"`
+ CreateDefaultPeer bool `yaml:"createDefaultPeer" envconfig:"CREATE_DEFAULT_PEER"`
+ LdapEnabled bool `yaml:"ldapEnabled" envconfig:"LDAP_ENABLED"`
+ } `yaml:"core"`
+ Database common.DatabaseConfig `yaml:"database"`
+ Email common.MailConfig `yaml:"email"`
+ LDAP ldap.Config `yaml:"ldap"`
+ WG wireguard.Config `yaml:"wg"`
+}
+
+func NewConfig() *Config {
+ cfg := &Config{}
+
+ // Default config
+ cfg.Core.ListeningAddress = ":8123"
+ cfg.Core.Title = "WireGuard VPN"
+ cfg.Core.CompanyName = "WireGuard Portal"
+ cfg.Core.ExternalUrl = "http://localhost:8123"
+ cfg.Core.MailFrom = "WireGuard VPN "
+ cfg.Core.AdminUser = "admin@wgportal.local"
+ cfg.Core.AdminPassword = "wgportal"
+ cfg.Core.LdapEnabled = false
+
+ cfg.Database.Typ = "sqlite"
+ cfg.Database.Database = "data/wg_portal.db"
+
+ cfg.LDAP.URL = "ldap://srv-ad01.company.local:389"
+ cfg.LDAP.BaseDN = "DC=COMPANY,DC=LOCAL"
+ cfg.LDAP.StartTLS = true
+ cfg.LDAP.BindUser = "company\\\\ldap_wireguard"
+ cfg.LDAP.BindPass = "SuperSecret"
+ cfg.LDAP.Type = "AD"
+ cfg.LDAP.UserClass = "organizationalPerson"
+ cfg.LDAP.EmailAttribute = "mail"
+ cfg.LDAP.FirstNameAttribute = "givenName"
+ cfg.LDAP.LastNameAttribute = "sn"
+ cfg.LDAP.PhoneAttribute = "telephoneNumber"
+ cfg.LDAP.GroupMemberAttribute = "memberOf"
+ cfg.LDAP.DisabledAttribute = "userAccountControl"
+ cfg.LDAP.AdminLdapGroup = "CN=WireGuardAdmins,OU=_O_IT,DC=COMPANY,DC=LOCAL"
+
+ cfg.WG.DeviceNames = []string{"wg0"}
+ cfg.WG.DefaultDeviceName = "wg0"
+ cfg.WG.ConfigDirectoryPath = "/etc/wireguard"
+ cfg.WG.ManageIPAddresses = true
+ cfg.Email.Host = "127.0.0.1"
+ cfg.Email.Port = 25
+
+ // Load config from file and environment
+ cfgFile, ok := os.LookupEnv("CONFIG_FILE")
+ if !ok {
+ cfgFile = "config.yml" // Default config file
+ }
+ err := loadConfigFile(cfg, cfgFile)
+ if err != nil {
+ logrus.Warnf("unable to load config.yml file: %v, using default configuration...", err)
+ }
+ err = loadConfigEnv(cfg)
+ if err != nil {
+ logrus.Warnf("unable to load environment config: %v", err)
+ }
+
+ if cfg.WG.ManageIPAddresses && runtime.GOOS != "linux" {
+ logrus.Warnf("managing IP addresses only works on linux, feature disabled...")
+ cfg.WG.ManageIPAddresses = false
+ }
+
+ return cfg
+}
diff --git a/internal/server/handlers_auth.go b/internal/server/handlers_auth.go
index bb70b95..72833cc 100644
--- a/internal/server/handlers_auth.go
+++ b/internal/server/handlers_auth.go
@@ -98,7 +98,7 @@
Firstname: userData.Firstname,
Lastname: userData.Lastname,
Phone: userData.Phone,
- }); err != nil {
+ }, s.wg.Cfg.DefaultDeviceName); err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "login error", "failed to update user data")
return
}
@@ -121,9 +121,10 @@
sessionData.Email = user.Email
sessionData.Firstname = user.Firstname
sessionData.Lastname = user.Lastname
+ sessionData.DeviceName = s.wg.Cfg.DeviceNames[0]
// Check if user already has a peer setup, if not create one
- if err := s.CreateUserDefaultPeer(user.Email); err != nil {
+ if err := s.CreateUserDefaultPeer(user.Email, s.wg.Cfg.DefaultDeviceName); err != nil {
// Not a fatal error, just log it...
logrus.Errorf("failed to automatically create vpn peer for %s: %v", sessionData.Email, err)
}
diff --git a/internal/server/handlers_common.go b/internal/server/handlers_common.go
index 0ff0885..054acb8 100644
--- a/internal/server/handlers_common.go
+++ b/internal/server/handlers_common.go
@@ -4,37 +4,42 @@
"net/http"
"strconv"
+ "github.com/h44z/wg-portal/internal/users"
+
+ "github.com/h44z/wg-portal/internal/common"
+
"github.com/pkg/errors"
"github.com/gin-gonic/gin"
)
func (s *Server) GetHandleError(c *gin.Context, code int, message, details string) {
+ currentSession := GetSessionData(c)
+
c.HTML(code, "error.html", gin.H{
"Data": gin.H{
"Code": strconv.Itoa(code),
"Message": message,
"Details": details,
},
- "Route": c.Request.URL.Path,
- "Session": GetSessionData(c),
- "Static": s.getStaticData(),
+ "Route": c.Request.URL.Path,
+ "Session": GetSessionData(c),
+ "Static": s.getStaticData(),
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
func (s *Server) GetIndex(c *gin.Context) {
- c.HTML(http.StatusOK, "index.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Device Device
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: GetSessionData(c),
- Static: s.getStaticData(),
- Device: s.peers.GetDevice(),
+ currentSession := GetSessionData(c)
+
+ c.HTML(http.StatusOK, "index.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -74,25 +79,35 @@
return
}
- device := s.peers.GetDevice()
- users := s.peers.GetFilteredAndSortedPeers(currentSession.SortedBy["peers"], currentSession.SortDirection["peers"], currentSession.Search["peers"])
+ deviceName := c.Query("device")
+ if deviceName != "" {
+ if !common.ListContains(s.wg.Cfg.DeviceNames, deviceName) {
+ s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "no such device")
+ return
+ }
+ currentSession.DeviceName = deviceName
- c.HTML(http.StatusOK, "admin_index.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Peers []Peer
- TotalPeers int
- Device Device
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- Peers: users,
- TotalPeers: len(s.peers.GetAllPeers()),
- Device: device,
+ if err := UpdateSessionData(c, currentSession); err != nil {
+ s.GetHandleError(c, http.StatusInternalServerError, "device selection error", "failed to save session")
+ return
+ }
+ c.Redirect(http.StatusSeeOther, "/admin/")
+ return
+ }
+
+ device := s.peers.GetDevice(currentSession.DeviceName)
+ users := s.peers.GetFilteredAndSortedPeers(currentSession.DeviceName, currentSession.SortedBy["peers"], currentSession.SortDirection["peers"], currentSession.Search["peers"])
+
+ c.HTML(http.StatusOK, "admin_index.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Peers": users,
+ "TotalPeers": len(s.peers.GetAllPeers(currentSession.DeviceName)),
+ "Users": s.users.GetUsers(),
+ "Device": device,
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -120,25 +135,18 @@
return
}
- device := s.peers.GetDevice()
- users := s.peers.GetSortedPeersForEmail(currentSession.SortedBy["userpeers"], currentSession.SortDirection["userpeers"], currentSession.Email)
+ peers := s.peers.GetSortedPeersForEmail(currentSession.SortedBy["userpeers"], currentSession.SortDirection["userpeers"], currentSession.Email)
- c.HTML(http.StatusOK, "user_index.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Peers []Peer
- TotalPeers int
- Device Device
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- Peers: users,
- TotalPeers: len(users),
- Device: device,
+ c.HTML(http.StatusOK, "user_index.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Peers": peers,
+ "TotalPeers": len(peers),
+ "Users": []users.User{*s.users.GetUser(currentSession.Email)},
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -158,7 +166,7 @@
// If session does not contain a peer form ignore update
// If url contains a formerr parameter reset the form
if currentSession.FormData == nil || c.Query("formerr") == "" {
- user, err := s.PrepareNewPeer()
+ user, err := s.PrepareNewPeer(currentSession.DeviceName)
if err != nil {
return currentSession, errors.WithMessage(err, "failed to prepare new peer")
}
diff --git a/internal/server/handlers_interface.go b/internal/server/handlers_interface.go
index 77a152e..d3f3655 100644
--- a/internal/server/handlers_interface.go
+++ b/internal/server/handlers_interface.go
@@ -4,44 +4,37 @@
"net/http"
"strings"
+ "github.com/h44z/wg-portal/internal/wireguard"
+
"github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/common"
)
func (s *Server) GetAdminEditInterface(c *gin.Context) {
- device := s.peers.GetDevice()
- users := s.peers.GetAllPeers()
-
+ currentSession := GetSessionData(c)
+ device := s.peers.GetDevice(currentSession.DeviceName)
currentSession, err := s.setFormInSession(c, device)
if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error())
return
}
- c.HTML(http.StatusOK, "admin_edit_interface.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Peers []Peer
- Device Device
- EditableKeys bool
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- Peers: users,
- Device: currentSession.FormData.(Device),
- EditableKeys: s.config.Core.EditableKeys,
+ c.HTML(http.StatusOK, "admin_edit_interface.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Device": currentSession.FormData.(wireguard.Device),
+ "EditableKeys": s.config.Core.EditableKeys,
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
func (s *Server) PostAdminEditInterface(c *gin.Context) {
currentSession := GetSessionData(c)
- var formDevice Device
+ var formDevice wireguard.Device
if currentSession.FormData != nil {
- formDevice = currentSession.FormData.(Device)
+ formDevice = currentSession.FormData.(wireguard.Device)
}
if err := c.ShouldBind(&formDevice); err != nil {
_ = s.updateFormInSession(c, formDevice)
@@ -76,7 +69,7 @@
}
// Update WireGuard config file
- err = s.WriteWireGuardConfigFile()
+ err = s.WriteWireGuardConfigFile(currentSession.DeviceName)
if err != nil {
_ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update WireGuard config-file: "+err.Error(), "danger")
@@ -86,12 +79,12 @@
// Update interface IP address
if s.config.WG.ManageIPAddresses {
- if err := s.wg.SetIPAddress(formDevice.IPs); err != nil {
+ if err := s.wg.SetIPAddress(currentSession.DeviceName, formDevice.IPs); err != nil {
_ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update ip address: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update")
}
- if err := s.wg.SetMTU(formDevice.Mtu); err != nil {
+ if err := s.wg.SetMTU(currentSession.DeviceName, formDevice.Mtu); err != nil {
_ = s.updateFormInSession(c, formDevice)
SetFlashMessage(c, "Failed to update MTU: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit?formerr=update")
@@ -106,9 +99,10 @@
}
func (s *Server) GetInterfaceConfig(c *gin.Context) {
- device := s.peers.GetDevice()
- users := s.peers.GetActivePeers()
- cfg, err := device.GetConfigFile(users)
+ currentSession := GetSessionData(c)
+ device := s.peers.GetDevice(currentSession.DeviceName)
+ peers := s.peers.GetActivePeers(device.DeviceName)
+ cfg, err := device.GetConfigFile(peers)
if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return
@@ -122,13 +116,14 @@
}
func (s *Server) GetApplyGlobalConfig(c *gin.Context) {
- device := s.peers.GetDevice()
- users := s.peers.GetAllPeers()
+ currentSession := GetSessionData(c)
+ device := s.peers.GetDevice(currentSession.DeviceName)
+ peers := s.peers.GetAllPeers(device.DeviceName)
- for _, user := range users {
- user.AllowedIPs = device.AllowedIPs
- user.AllowedIPsStr = device.AllowedIPsStr
- if err := s.peers.UpdatePeer(user); err != nil {
+ for _, peer := range peers {
+ peer.AllowedIPs = device.AllowedIPs
+ peer.AllowedIPsStr = device.AllowedIPsStr
+ if err := s.peers.UpdatePeer(peer); err != nil {
SetFlashMessage(c, err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/device/edit")
}
diff --git a/internal/server/handlers_peer.go b/internal/server/handlers_peer.go
index ae05236..9f2de38 100644
--- a/internal/server/handlers_peer.go
+++ b/internal/server/handlers_peer.go
@@ -8,9 +8,10 @@
"strings"
"time"
+ "github.com/h44z/wg-portal/internal/wireguard"
+
"github.com/gin-gonic/gin"
"github.com/h44z/wg-portal/internal/common"
- "github.com/h44z/wg-portal/internal/users"
"github.com/sirupsen/logrus"
"github.com/tatsushid/go-fastping"
)
@@ -21,7 +22,6 @@
}
func (s *Server) GetAdminEditPeer(c *gin.Context) {
- device := s.peers.GetDevice()
peer := s.peers.GetPeerByKey(c.Query("pkey"))
currentSession, err := s.setFormInSession(c, peer)
@@ -30,22 +30,15 @@
return
}
- c.HTML(http.StatusOK, "admin_edit_client.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Peer Peer
- Device Device
- EditableKeys bool
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- Peer: currentSession.FormData.(Peer),
- Device: device,
- EditableKeys: s.config.Core.EditableKeys,
+ c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Peer": currentSession.FormData.(wireguard.Peer),
+ "EditableKeys": s.config.Core.EditableKeys,
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -54,9 +47,9 @@
urlEncodedKey := url.QueryEscape(c.Query("pkey"))
currentSession := GetSessionData(c)
- var formPeer Peer
+ var formPeer wireguard.Peer
if currentSession.FormData != nil {
- formPeer = currentSession.FormData.(Peer)
+ formPeer = currentSession.FormData.(wireguard.Peer)
}
if err := c.ShouldBind(&formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer)
@@ -92,37 +85,28 @@
}
func (s *Server) GetAdminCreatePeer(c *gin.Context) {
- device := s.peers.GetDevice()
-
currentSession, err := s.setNewPeerFormInSession(c)
if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "Session error", err.Error())
return
}
- c.HTML(http.StatusOK, "admin_edit_client.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Peer Peer
- Device Device
- EditableKeys bool
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- Peer: currentSession.FormData.(Peer),
- Device: device,
- EditableKeys: s.config.Core.EditableKeys,
+ c.HTML(http.StatusOK, "admin_edit_client.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Peer": currentSession.FormData.(wireguard.Peer),
+ "EditableKeys": s.config.Core.EditableKeys,
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
func (s *Server) PostAdminCreatePeer(c *gin.Context) {
currentSession := GetSessionData(c)
- var formPeer Peer
+ var formPeer wireguard.Peer
if currentSession.FormData != nil {
- formPeer = currentSession.FormData.(Peer)
+ formPeer = currentSession.FormData.(wireguard.Peer)
}
if err := c.ShouldBind(&formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer)
@@ -143,7 +127,7 @@
formPeer.DeactivatedAt = &now
}
- if err := s.CreatePeer(formPeer); err != nil {
+ if err := s.CreatePeer(currentSession.DeviceName, formPeer); err != nil {
_ = s.updateFormInSession(c, formPeer)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/peer/create?formerr=create")
@@ -161,22 +145,15 @@
return
}
- c.HTML(http.StatusOK, "admin_create_clients.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Users []users.User
- FormData LdapCreateForm
- Device Device
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- Users: s.users.GetFilteredAndSortedUsers("lastname", "asc", ""),
- FormData: currentSession.FormData.(LdapCreateForm),
- Device: s.peers.GetDevice(),
+ c.HTML(http.StatusOK, "admin_create_clients.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Users": s.users.GetFilteredAndSortedUsers("lastname", "asc", ""),
+ "FormData": currentSession.FormData.(LdapCreateForm),
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -207,7 +184,7 @@
logrus.Infof("creating %d ldap peers", len(emails))
for i := range emails {
- if err := s.CreatePeerByEmail(emails[i], formData.Identifier, false); err != nil {
+ if err := s.CreatePeerByEmail(currentSession.DeviceName, emails[i], formData.Identifier, false); err != nil {
_ = s.updateFormInSession(c, formData)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/peer/createldap?formerr=create")
@@ -225,7 +202,7 @@
s.GetHandleError(c, http.StatusInternalServerError, "Deletion error", err.Error())
return
}
- SetFlashMessage(c, "user deleted successfully", "success")
+ SetFlashMessage(c, "peer deleted successfully", "success")
c.Redirect(http.StatusSeeOther, "/admin")
}
@@ -254,7 +231,7 @@
return
}
- cfg, err := user.GetConfigFile(s.peers.GetDevice())
+ cfg, err := user.GetConfigFile(s.peers.GetDevice(currentSession.DeviceName))
if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return
@@ -273,7 +250,7 @@
return
}
- cfg, err := user.GetConfigFile(s.peers.GetDevice())
+ cfg, err := user.GetConfigFile(s.peers.GetDevice(currentSession.DeviceName))
if err != nil {
s.GetHandleError(c, http.StatusInternalServerError, "ConfigFile error", err.Error())
return
@@ -286,7 +263,7 @@
// Apply mail template
var tplBuff bytes.Buffer
if err := s.mailTpl.Execute(&tplBuff, struct {
- Client Peer
+ Client wireguard.Peer
QrcodePngName string
PortalUrl string
}{
diff --git a/internal/server/handlers_user.go b/internal/server/handlers_user.go
index c96a37c..0bdfc95 100644
--- a/internal/server/handlers_user.go
+++ b/internal/server/handlers_user.go
@@ -49,22 +49,15 @@
dbUsers := s.users.GetFilteredAndSortedUsersUnscoped(currentSession.SortedBy["users"], currentSession.SortDirection["users"], currentSession.Search["users"])
- c.HTML(http.StatusOK, "admin_user_index.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- Users []users.User
- TotalUsers int
- Device Device
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- Users: dbUsers,
- TotalUsers: len(s.users.GetUsers()),
- Device: s.peers.GetDevice(),
+ c.HTML(http.StatusOK, "admin_user_index.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "Users": dbUsers,
+ "TotalUsers": len(s.users.GetUsers()),
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -77,21 +70,14 @@
return
}
- c.HTML(http.StatusOK, "admin_edit_user.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- User users.User
- Device Device
- Epoch time.Time
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- User: currentSession.FormData.(users.User),
- Device: s.peers.GetDevice(),
+ c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "User": currentSession.FormData.(users.User),
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -160,21 +146,14 @@
return
}
- c.HTML(http.StatusOK, "admin_edit_user.html", struct {
- Route string
- Alerts []FlashData
- Session SessionData
- Static StaticData
- User users.User
- Device Device
- Epoch time.Time
- }{
- Route: c.Request.URL.Path,
- Alerts: GetFlashes(c),
- Session: currentSession,
- Static: s.getStaticData(),
- User: currentSession.FormData.(users.User),
- Device: s.peers.GetDevice(),
+ c.HTML(http.StatusOK, "admin_edit_user.html", gin.H{
+ "Route": c.Request.URL.Path,
+ "Alerts": GetFlashes(c),
+ "Session": currentSession,
+ "Static": s.getStaticData(),
+ "User": currentSession.FormData.(users.User),
+ "Device": s.peers.GetDevice(currentSession.DeviceName),
+ "DeviceNames": s.wg.Cfg.DeviceNames,
})
}
@@ -218,7 +197,7 @@
formUser.IsAdmin = c.PostForm("isadmin") == "true"
formUser.Source = users.UserSourceDatabase
- if err := s.CreateUser(formUser); err != nil {
+ if err := s.CreateUser(formUser, currentSession.DeviceName); err != nil {
_ = s.updateFormInSession(c, formUser)
SetFlashMessage(c, "failed to add user: "+err.Error(), "danger")
c.Redirect(http.StatusSeeOther, "/admin/users/create?formerr=create")
diff --git a/internal/server/peermanager.go b/internal/server/peermanager.go
deleted file mode 100644
index ad95fd1..0000000
--- a/internal/server/peermanager.go
+++ /dev/null
@@ -1,724 +0,0 @@
-package server
-
-import (
- "bytes"
- "crypto/md5"
- "fmt"
- "net"
- "reflect"
- "regexp"
- "sort"
- "strings"
- "text/template"
- "time"
-
- "github.com/gin-gonic/gin/binding"
- "github.com/go-playground/validator/v10"
- "github.com/h44z/wg-portal/internal/common"
- "github.com/h44z/wg-portal/internal/users"
- "github.com/h44z/wg-portal/internal/wireguard"
- "github.com/pkg/errors"
- "github.com/sirupsen/logrus"
- "github.com/skip2/go-qrcode"
- "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "gorm.io/gorm"
-)
-
-//
-// CUSTOM VALIDATORS ----------------------------------------------------------------------------
-//
-var cidrList validator.Func = func(fl validator.FieldLevel) bool {
- cidrListStr := fl.Field().String()
-
- cidrList := common.ParseStringList(cidrListStr)
- for i := range cidrList {
- _, _, err := net.ParseCIDR(cidrList[i])
- if err != nil {
- return false
- }
- }
- return true
-}
-
-var ipList validator.Func = func(fl validator.FieldLevel) bool {
- ipListStr := fl.Field().String()
-
- ipList := common.ParseStringList(ipListStr)
- for i := range ipList {
- ip := net.ParseIP(ipList[i])
- if ip == nil {
- return false
- }
- }
- return true
-}
-
-func init() {
- if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
- _ = v.RegisterValidation("cidrlist", cidrList)
- _ = v.RegisterValidation("iplist", ipList)
- }
-}
-
-//
-// PEER ----------------------------------------------------------------------------------------
-//
-
-type Peer struct {
- Peer *wgtypes.Peer `gorm:"-"` // WireGuard peer
- User *users.User `gorm:"-"` // user reference for the peer
- Config string `gorm:"-"`
-
- UID string `form:"uid" binding:"alphanum"` // uid for html identification
- IsOnline bool `gorm:"-"`
- IsNew bool `gorm:"-"`
- Identifier string `form:"identifier" binding:"required,lt=64"` // Identifier AND Email make a WireGuard peer unique
- Email string `gorm:"index" form:"mail" binding:"required,email"`
- LastHandshake string `gorm:"-"`
- LastHandshakeTime string `gorm:"-"`
-
- IgnorePersistentKeepalive bool `form:"ignorekeepalive"`
- PresharedKey string `form:"presharedkey" binding:"omitempty,base64"`
- AllowedIPsStr string `form:"allowedip" binding:"cidrlist"`
- IPsStr string `form:"ip" binding:"cidrlist"`
- AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
- IPs []string `gorm:"-"` // The IPs of the client
- PrivateKey string `form:"privkey" binding:"omitempty,base64"`
- PublicKey string `gorm:"primaryKey" form:"pubkey" binding:"required,base64"`
-
- DeactivatedAt *time.Time
- CreatedBy string
- UpdatedBy string
- CreatedAt time.Time
- UpdatedAt time.Time
-}
-
-func (p Peer) GetConfig() wgtypes.PeerConfig {
- publicKey, _ := wgtypes.ParseKey(p.PublicKey)
- var presharedKey *wgtypes.Key
- if p.PresharedKey != "" {
- presharedKeyTmp, _ := wgtypes.ParseKey(p.PresharedKey)
- presharedKey = &presharedKeyTmp
- }
-
- cfg := wgtypes.PeerConfig{
- PublicKey: publicKey,
- Remove: false,
- UpdateOnly: false,
- PresharedKey: presharedKey,
- Endpoint: nil,
- PersistentKeepaliveInterval: nil,
- ReplaceAllowedIPs: true,
- AllowedIPs: make([]net.IPNet, len(p.IPs)),
- }
- for i, ip := range p.IPs {
- _, ipNet, err := net.ParseCIDR(ip)
- if err == nil {
- cfg.AllowedIPs[i] = *ipNet
- }
- }
-
- return cfg
-}
-
-func (p Peer) GetConfigFile(device Device) ([]byte, error) {
- tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl)
- if err != nil {
- return nil, errors.Wrap(err, "failed to parse client template")
- }
-
- var tplBuff bytes.Buffer
-
- err = tpl.Execute(&tplBuff, struct {
- Client Peer
- Server Device
- }{
- Client: p,
- Server: device,
- })
- if err != nil {
- return nil, errors.Wrap(err, "failed to execute client template")
- }
-
- return tplBuff.Bytes(), nil
-}
-
-func (p Peer) GetQRCode() ([]byte, error) {
- png, err := qrcode.Encode(p.Config, qrcode.Medium, 250)
- if err != nil {
- logrus.WithFields(logrus.Fields{
- "err": err,
- }).Error("failed to create qrcode")
- return nil, errors.Wrap(err, "failed to encode qrcode")
- }
- return png, nil
-}
-
-func (p Peer) IsValid() bool {
- if p.PublicKey == "" {
- return false
- }
-
- return true
-}
-
-func (p Peer) ToMap() map[string]string {
- out := make(map[string]string)
-
- v := reflect.ValueOf(p)
- if v.Kind() == reflect.Ptr {
- v = v.Elem()
- }
-
- typ := v.Type()
- for i := 0; i < v.NumField(); i++ {
- // gets us a StructField
- fi := typ.Field(i)
- if tagv := fi.Tag.Get("form"); tagv != "" {
- // set key of map to value in struct field
- out[tagv] = v.Field(i).String()
- }
- }
- return out
-}
-
-func (p Peer) GetConfigFileName() string {
- reg := regexp.MustCompile("[^a-zA-Z0-9_-]+")
- return reg.ReplaceAllString(strings.ReplaceAll(p.Identifier, " ", "-"), "") + ".conf"
-}
-
-//
-// DEVICE --------------------------------------------------------------------------------------
-//
-
-type Device struct {
- Interface *wgtypes.Device `gorm:"-"`
-
- DeviceName string `form:"device" gorm:"primaryKey" binding:"required,alphanum"`
- PrivateKey string `form:"privkey" binding:"required,base64"`
- PublicKey string `form:"pubkey" binding:"required,base64"`
- PersistentKeepalive int `form:"keepalive" binding:"gte=0"`
- ListenPort int `form:"port" binding:"required,gt=0"`
- Mtu int `form:"mtu" binding:"gte=0,lte=1500"`
- Endpoint string `form:"endpoint" binding:"required,hostname_port"`
- AllowedIPsStr string `form:"allowedip" binding:"cidrlist"`
- IPsStr string `form:"ip" binding:"required,cidrlist"`
- AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
- IPs []string `gorm:"-"` // The IPs of the client
- DNSStr string `form:"dns" binding:"iplist"`
- DNS []string `gorm:"-"` // The DNS servers of the client
- PreUp string `form:"preup"`
- PostUp string `form:"postup"`
- PreDown string `form:"predown"`
- PostDown string `form:"postdown"`
- CreatedAt time.Time
- UpdatedAt time.Time
-}
-
-func (d Device) IsValid() bool {
- if d.PublicKey == "" {
- return false
- }
- if len(d.IPs) == 0 {
- return false
- }
- if d.Endpoint == "" {
- return false
- }
-
- return true
-}
-
-func (d Device) GetConfig() wgtypes.Config {
- var privateKey *wgtypes.Key
- if d.PrivateKey != "" {
- pKey, _ := wgtypes.ParseKey(d.PrivateKey)
- privateKey = &pKey
- }
-
- cfg := wgtypes.Config{
- PrivateKey: privateKey,
- ListenPort: &d.ListenPort,
- }
-
- return cfg
-}
-
-func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
- tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.DeviceCfgTpl)
- if err != nil {
- return nil, errors.Wrap(err, "failed to parse server template")
- }
-
- var tplBuff bytes.Buffer
-
- err = tpl.Execute(&tplBuff, struct {
- Clients []Peer
- Server Device
- }{
- Clients: peers,
- Server: d,
- })
- if err != nil {
- return nil, errors.Wrap(err, "failed to execute server template")
- }
-
- return tplBuff.Bytes(), nil
-}
-
-//
-// PEER-MANAGER --------------------------------------------------------------------------------
-//
-
-type PeerManager struct {
- db *gorm.DB
- wg *wireguard.Manager
- users *users.Manager
-}
-
-func NewPeerManager(cfg *common.Config, wg *wireguard.Manager, userDB *users.Manager) (*PeerManager, error) {
- um := &PeerManager{wg: wg, users: userDB}
- var err error
- um.db, err = users.GetDatabaseForConfig(&cfg.Database)
- if err != nil {
- return nil, errors.WithMessage(err, "failed to open peer database")
- }
-
- err = um.db.AutoMigrate(&Peer{}, &Device{})
- if err != nil {
- return nil, errors.WithMessage(err, "failed to migrate peer database")
- }
-
- return um, nil
-}
-
-func (u *PeerManager) InitFromCurrentInterface() error {
- peers, err := u.wg.GetPeerList()
- if err != nil {
- return errors.Wrapf(err, "failed to get peer list")
- }
- device, err := u.wg.GetDeviceInfo()
- if err != nil {
- return errors.Wrapf(err, "failed to get device info")
- }
- var ipAddresses []string
- var mtu int
- if u.wg.Cfg.ManageIPAddresses {
- if ipAddresses, err = u.wg.GetIPAddress(); err != nil {
- return errors.Wrapf(err, "failed to get ip address")
- }
- if mtu, err = u.wg.GetMTU(); err != nil {
- return errors.Wrapf(err, "failed to get MTU")
- }
- }
-
- // Check if entries already exist in database, if not create them
- for _, peer := range peers {
- if err := u.validateOrCreatePeer(peer); err != nil {
- return errors.WithMessagef(err, "failed to validate peer %s", peer.PublicKey)
- }
- }
- if err := u.validateOrCreateDevice(*device, ipAddresses, mtu); err != nil {
- return errors.WithMessagef(err, "failed to validate device %s", device.Name)
- }
-
- return nil
-}
-
-func (u *PeerManager) validateOrCreatePeer(wgPeer wgtypes.Peer) error {
- peer := Peer{}
- u.db.Where("public_key = ?", wgPeer.PublicKey.String()).FirstOrInit(&peer)
-
- if peer.PublicKey == "" { // peer not found, create
- peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(wgPeer.PublicKey.String())))
- peer.PublicKey = wgPeer.PublicKey.String()
- peer.PrivateKey = "" // UNKNOWN
- if wgPeer.PresharedKey != (wgtypes.Key{}) {
- peer.PresharedKey = wgPeer.PresharedKey.String()
- }
- peer.Email = "autodetected@example.com"
- peer.Identifier = "Autodetected (" + peer.PublicKey[0:8] + ")"
- peer.UpdatedAt = time.Now()
- peer.CreatedAt = time.Now()
- peer.AllowedIPs = make([]string, 0) // UNKNOWN
- peer.IPs = make([]string, len(wgPeer.AllowedIPs))
- for i, ip := range wgPeer.AllowedIPs {
- peer.IPs[i] = ip.String()
- }
- peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
- peer.IPsStr = strings.Join(peer.IPs, ", ")
-
- res := u.db.Create(&peer)
- if res.Error != nil {
- return errors.Wrapf(res.Error, "failed to create autodetected peer %s", peer.PublicKey)
- }
- }
-
- return nil
-}
-
-func (u *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []string, mtu int) error {
- device := Device{}
- u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
-
- if device.PublicKey == "" { // device not found, create
- device.PublicKey = dev.PublicKey.String()
- device.PrivateKey = dev.PrivateKey.String()
- device.DeviceName = dev.Name
- device.ListenPort = dev.ListenPort
- device.Mtu = 0
- device.PersistentKeepalive = 16 // Default
- device.IPsStr = strings.Join(ipAddresses, ", ")
- if mtu == wireguard.DefaultMTU {
- mtu = 0
- }
- device.Mtu = mtu
-
- res := u.db.Create(&device)
- if res.Error != nil {
- return errors.Wrapf(res.Error, "failed to create autodetected device")
- }
- }
-
- return nil
-}
-
-func (u *PeerManager) populatePeerData(peer *Peer) {
- peer.AllowedIPs = strings.Split(peer.AllowedIPsStr, ", ")
- peer.IPs = strings.Split(peer.IPsStr, ", ")
- // Set config file
- tmpCfg, _ := peer.GetConfigFile(u.GetDevice())
- peer.Config = string(tmpCfg)
-
- // set data from WireGuard interface
- peer.Peer, _ = u.wg.GetPeer(peer.PublicKey)
- peer.LastHandshake = "never"
- peer.LastHandshakeTime = "Never connected, or user is disabled."
- if peer.Peer != nil {
- since := time.Since(peer.Peer.LastHandshakeTime)
- sinceSeconds := int(since.Round(time.Second).Seconds())
- sinceMinutes := int(sinceSeconds / 60)
- sinceSeconds -= sinceMinutes * 60
-
- if sinceMinutes > 2*10080 { // 2 weeks
- peer.LastHandshake = "a while ago"
- } else if sinceMinutes > 10080 { // 1 week
- peer.LastHandshake = "a week ago"
- } else {
- peer.LastHandshake = fmt.Sprintf("%02dm %02ds", sinceMinutes, sinceSeconds)
- }
- peer.LastHandshakeTime = peer.Peer.LastHandshakeTime.Format(time.UnixDate)
- }
- peer.IsOnline = false
-
- // set user data
- peer.User = u.users.GetUser(peer.Email)
-}
-
-func (u *PeerManager) populateDeviceData(device *Device) {
- device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
- device.IPs = strings.Split(device.IPsStr, ", ")
- device.DNS = strings.Split(device.DNSStr, ", ")
-
- // set data from WireGuard interface
- device.Interface, _ = u.wg.GetDeviceInfo()
-}
-
-func (u *PeerManager) GetAllPeers() []Peer {
- peers := make([]Peer, 0)
- u.db.Find(&peers)
-
- for i := range peers {
- u.populatePeerData(&peers[i])
- }
-
- return peers
-}
-
-func (u *PeerManager) GetActivePeers() []Peer {
- peers := make([]Peer, 0)
- u.db.Where("deactivated_at IS NULL").Find(&peers)
-
- for i := range peers {
- u.populatePeerData(&peers[i])
- }
-
- return peers
-}
-
-func (u *PeerManager) GetFilteredAndSortedPeers(sortKey, sortDirection, search string) []Peer {
- peers := make([]Peer, 0)
- u.db.Find(&peers)
-
- filteredPeers := make([]Peer, 0, len(peers))
- for i := range peers {
- u.populatePeerData(&peers[i])
-
- if search == "" ||
- strings.Contains(peers[i].Email, search) ||
- strings.Contains(peers[i].Identifier, search) ||
- strings.Contains(peers[i].PublicKey, search) {
- filteredPeers = append(filteredPeers, peers[i])
- }
- }
-
- sort.Slice(filteredPeers, func(i, j int) bool {
- var sortValueLeft string
- var sortValueRight string
-
- switch sortKey {
- case "id":
- sortValueLeft = filteredPeers[i].Identifier
- sortValueRight = filteredPeers[j].Identifier
- case "pubKey":
- sortValueLeft = filteredPeers[i].PublicKey
- sortValueRight = filteredPeers[j].PublicKey
- case "mail":
- sortValueLeft = filteredPeers[i].Email
- sortValueRight = filteredPeers[j].Email
- case "ip":
- sortValueLeft = filteredPeers[i].IPsStr
- sortValueRight = filteredPeers[j].IPsStr
- case "handshake":
- if filteredPeers[i].Peer == nil {
- return false
- } else if filteredPeers[j].Peer == nil {
- return true
- }
- sortValueLeft = filteredPeers[i].Peer.LastHandshakeTime.Format(time.RFC3339)
- sortValueRight = filteredPeers[j].Peer.LastHandshakeTime.Format(time.RFC3339)
- }
-
- if sortDirection == "asc" {
- return sortValueLeft < sortValueRight
- } else {
- return sortValueLeft > sortValueRight
- }
- })
-
- return filteredPeers
-}
-
-func (u *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer {
- peers := make([]Peer, 0)
- u.db.Where("email = ?", email).Find(&peers)
-
- for i := range peers {
- u.populatePeerData(&peers[i])
- }
-
- sort.Slice(peers, func(i, j int) bool {
- var sortValueLeft string
- var sortValueRight string
-
- switch sortKey {
- case "id":
- sortValueLeft = peers[i].Identifier
- sortValueRight = peers[j].Identifier
- case "pubKey":
- sortValueLeft = peers[i].PublicKey
- sortValueRight = peers[j].PublicKey
- case "mail":
- sortValueLeft = peers[i].Email
- sortValueRight = peers[j].Email
- case "ip":
- sortValueLeft = peers[i].IPsStr
- sortValueRight = peers[j].IPsStr
- case "handshake":
- if peers[i].Peer == nil {
- return true
- } else if peers[j].Peer == nil {
- return false
- }
- sortValueLeft = peers[i].Peer.LastHandshakeTime.Format(time.RFC3339)
- sortValueRight = peers[j].Peer.LastHandshakeTime.Format(time.RFC3339)
- }
-
- if sortDirection == "asc" {
- return sortValueLeft < sortValueRight
- } else {
- return sortValueLeft > sortValueRight
- }
- })
-
- return peers
-}
-
-func (u *PeerManager) GetDevice() Device {
- devices := make([]Device, 0, 1)
- u.db.Find(&devices)
-
- for i := range devices {
- u.populateDeviceData(&devices[i])
- }
-
- return devices[0] // use first device for now... more to come?
-}
-
-func (u *PeerManager) GetPeerByKey(publicKey string) Peer {
- peer := Peer{}
- u.db.Where("public_key = ?", publicKey).FirstOrInit(&peer)
- u.populatePeerData(&peer)
- return peer
-}
-
-func (u *PeerManager) GetPeersByMail(mail string) []Peer {
- var peers []Peer
- u.db.Where("email = ?", mail).Find(&peers)
- for i := range peers {
- u.populatePeerData(&peers[i])
- }
-
- return peers
-}
-
-func (u *PeerManager) CreatePeer(peer Peer) error {
- peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
- peer.UpdatedAt = time.Now()
- peer.CreatedAt = time.Now()
- peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
- peer.IPsStr = strings.Join(peer.IPs, ", ")
-
- res := u.db.Create(&peer)
- if res.Error != nil {
- logrus.Errorf("failed to create peer: %v", res.Error)
- return errors.Wrap(res.Error, "failed to create peer")
- }
-
- return nil
-}
-
-func (u *PeerManager) UpdatePeer(peer Peer) error {
- peer.UpdatedAt = time.Now()
- peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
- peer.IPsStr = strings.Join(peer.IPs, ", ")
-
- res := u.db.Save(&peer)
- if res.Error != nil {
- logrus.Errorf("failed to update peer: %v", res.Error)
- return errors.Wrap(res.Error, "failed to update peer")
- }
-
- return nil
-}
-
-func (u *PeerManager) DeletePeer(peer Peer) error {
- res := u.db.Delete(&peer)
- if res.Error != nil {
- logrus.Errorf("failed to delete peer: %v", res.Error)
- return errors.Wrap(res.Error, "failed to delete peer")
- }
-
- return nil
-}
-
-func (u *PeerManager) UpdateDevice(device Device) error {
- device.UpdatedAt = time.Now()
- device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ")
- device.IPsStr = strings.Join(device.IPs, ", ")
- device.DNSStr = strings.Join(device.DNS, ", ")
-
- res := u.db.Save(&device)
- if res.Error != nil {
- logrus.Errorf("failed to update device: %v", res.Error)
- return errors.Wrap(res.Error, "failed to update device")
- }
-
- return nil
-}
-
-func (u *PeerManager) GetAllReservedIps() ([]string, error) {
- reservedIps := make([]string, 0)
- peers := u.GetAllPeers()
- for _, user := range peers {
- for _, cidr := range user.IPs {
- if cidr == "" {
- continue
- }
- ip, _, err := net.ParseCIDR(cidr)
- if err != nil {
- return nil, errors.Wrap(err, "failed to parse cidr")
- }
- reservedIps = append(reservedIps, ip.String())
- }
- }
-
- device := u.GetDevice()
- for _, cidr := range device.IPs {
- if cidr == "" {
- continue
- }
- ip, _, err := net.ParseCIDR(cidr)
- if err != nil {
- return nil, errors.Wrap(err, "failed to parse cidr")
- }
-
- reservedIps = append(reservedIps, ip.String())
- }
-
- return reservedIps, nil
-}
-
-func (u *PeerManager) IsIPReserved(cidr string) bool {
- reserved, err := u.GetAllReservedIps()
- if err != nil {
- return true // in case something failed, assume the ip is reserved
- }
- ip, ipnet, err := net.ParseCIDR(cidr)
- if err != nil {
- return true
- }
-
- // this two addresses are not usable
- broadcastAddr := common.BroadcastAddr(ipnet).String()
- networkAddr := ipnet.IP.String()
- address := ip.String()
-
- if address == broadcastAddr || address == networkAddr {
- return true
- }
-
- for _, r := range reserved {
- if address == r {
- return true
- }
- }
-
- return false
-}
-
-// GetAvailableIp search for an available ip in cidr against a list of reserved ips
-func (u *PeerManager) GetAvailableIp(cidr string) (string, error) {
- reserved, err := u.GetAllReservedIps()
- if err != nil {
- return "", errors.WithMessage(err, "failed to get all reserved IP addresses")
- }
- ip, ipnet, err := net.ParseCIDR(cidr)
- if err != nil {
- return "", errors.Wrap(err, "failed to parse cidr")
- }
-
- // this two addresses are not usable
- broadcastAddr := common.BroadcastAddr(ipnet).String()
- networkAddr := ipnet.IP.String()
-
- for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); common.IncreaseIP(ip) {
- ok := true
- address := ip.String()
- for _, r := range reserved {
- if address == r {
- ok = false
- break
- }
- }
- if ok && address != networkAddr && address != broadcastAddr {
- netMask := "/32"
- if common.IsIPv6(address) {
- netMask = "/128"
- }
- return address + netMask, nil
- }
- }
-
- return "", errors.New("no more available address from cidr")
-}
diff --git a/internal/server/routes.go b/internal/server/routes.go
index e1d3a47..03cf998 100644
--- a/internal/server/routes.go
+++ b/internal/server/routes.go
@@ -4,14 +4,14 @@
"net/http"
"github.com/gin-gonic/gin"
- wg_portal "github.com/h44z/wg-portal"
+ wgportal "github.com/h44z/wg-portal"
)
func SetupRoutes(s *Server) {
// Startpage
s.server.GET("/", s.GetIndex)
s.server.GET("/favicon.ico", func(c *gin.Context) {
- file, _ := wg_portal.Statics.ReadFile("assets/img/favicon.ico")
+ file, _ := wgportal.Statics.ReadFile("assets/img/favicon.ico")
c.Data(
http.StatusOK,
"image/x-icon",
diff --git a/internal/server/server.go b/internal/server/server.go
index 8bb6f38..6b35a71 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -11,12 +11,15 @@
"net/url"
"os"
"path/filepath"
+ "strings"
"time"
+ "gorm.io/gorm"
+
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/memstore"
"github.com/gin-gonic/gin"
- wg_portal "github.com/h44z/wg-portal"
+ wgportal "github.com/h44z/wg-portal"
ldapprovider "github.com/h44z/wg-portal/internal/authentication/providers/ldap"
passwordprovider "github.com/h44z/wg-portal/internal/authentication/providers/password"
"github.com/h44z/wg-portal/internal/common"
@@ -32,18 +35,19 @@
func init() {
gob.Register(SessionData{})
gob.Register(FlashData{})
- gob.Register(Peer{})
- gob.Register(Device{})
+ gob.Register(wireguard.Peer{})
+ gob.Register(wireguard.Device{})
gob.Register(LdapCreateForm{})
gob.Register(users.User{})
}
type SessionData struct {
- LoggedIn bool
- IsAdmin bool
- Firstname string
- Lastname string
- Email string
+ LoggedIn bool
+ IsAdmin bool
+ Firstname string
+ Lastname string
+ Email string
+ DeviceName string
SortedBy map[string]string
SortDirection map[string]string
@@ -69,14 +73,15 @@
type Server struct {
ctx context.Context
- config *common.Config
+ config *Config
server *gin.Engine
mailTpl *template.Template
auth *AuthManager
+ db *gorm.DB
users *users.Manager
wg *wireguard.Manager
- peers *PeerManager
+ peers *wireguard.PeerManager
}
func (s *Server) Setup(ctx context.Context) error {
@@ -90,9 +95,15 @@
// Init rand
rand.Seed(time.Now().UnixNano())
- s.config = common.NewConfig()
+ s.config = NewConfig()
s.ctx = ctx
+ // Setup database connection
+ s.db, err = common.GetDatabaseForConfig(&s.config.Database)
+ if err != nil {
+ return errors.WithMessage(err, "database setup failed")
+ }
+
// Setup http server
gin.SetMode(gin.DebugMode)
gin.DefaultWriter = ioutil.Discard
@@ -104,24 +115,33 @@
s.server.SetFuncMap(template.FuncMap{
"formatBytes": common.ByteCountSI,
"urlEncode": url.QueryEscape,
+ "startsWith": strings.HasPrefix,
+ "userForEmail": func(users []users.User, email string) *users.User {
+ for i := range users {
+ if users[i].Email == email {
+ return &users[i]
+ }
+ }
+ return nil
+ },
})
// Setup templates
- templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wg_portal.Templates, "assets/tpl/*.html"))
+ templates := template.Must(template.New("").Funcs(s.server.FuncMap).ParseFS(wgportal.Templates, "assets/tpl/*.html"))
s.server.SetHTMLTemplate(templates)
s.server.Use(sessions.Sessions("authsession", memstore.NewStore([]byte("secret")))) // TODO: change key?
// Serve static files
- s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/css"))))
- s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/js"))))
- s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/img"))))
- s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wg_portal.Statics, "assets/fonts"))))
+ s.server.StaticFS("/css", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/css"))))
+ s.server.StaticFS("/js", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/js"))))
+ s.server.StaticFS("/img", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/img"))))
+ s.server.StaticFS("/fonts", http.FS(fsMust(fs.Sub(wgportal.Statics, "assets/fonts"))))
// Setup all routes
SetupRoutes(s)
// Setup user database (also needed for database authentication)
- s.users, err = users.NewManager(&s.config.Database)
+ s.users, err = users.NewManager(s.db)
if err != nil {
return errors.WithMessage(err, "user-manager initialization failed")
}
@@ -153,18 +173,21 @@
}
// Setup peer manager
- if s.peers, err = NewPeerManager(s.config, s.wg, s.users); err != nil {
+ if s.peers, err = wireguard.NewPeerManager(s.db, s.wg); err != nil {
return errors.WithMessage(err, "unable to setup peer manager")
}
- if err = s.peers.InitFromCurrentInterface(); err != nil {
- return errors.WithMessage(err, "unable to initialize peer manager")
+ if err = s.peers.InitFromPhysicalInterface(); err != nil {
+ return errors.WithMessagef(err, "unable to initialize peer manager")
}
- if err = s.RestoreWireGuardInterface(); err != nil {
- return errors.WithMessage(err, "unable to restore WireGuard state")
+
+ for _, deviceName := range s.wg.Cfg.DeviceNames {
+ if err = s.RestoreWireGuardInterface(deviceName); err != nil {
+ return errors.WithMessagef(err, "unable to restore WireGuard state for %s", deviceName)
+ }
}
// Setup mail template
- s.mailTpl, err = template.New("email.html").ParseFS(wg_portal.Templates, "assets/tpl/email.html")
+ s.mailTpl, err = template.New("email.html").ParseFS(wgportal.Templates, "assets/tpl/email.html")
if err != nil {
return errors.Wrap(err, "unable to pare mail template")
}
@@ -174,6 +197,8 @@
}
func (s *Server) Run() {
+ logrus.Infof("starting web service on %s", s.config.Core.ListeningAddress)
+
// Start ldap sync
if s.config.Core.LdapEnabled {
go s.SyncLdapWithUserDatabase()
@@ -238,6 +263,7 @@
Email: "",
Firstname: "",
Lastname: "",
+ DeviceName: "",
IsAdmin: false,
LoggedIn: false,
}
diff --git a/internal/server/server_helper.go b/internal/server/server_helper.go
index 41fd7dc..177a4fe 100644
--- a/internal/server/server_helper.go
+++ b/internal/server/server_helper.go
@@ -4,9 +4,12 @@
"crypto/md5"
"fmt"
"io/ioutil"
+ "path"
"syscall"
"time"
+ "github.com/h44z/wg-portal/internal/wireguard"
+
"github.com/h44z/wg-portal/internal/common"
"github.com/h44z/wg-portal/internal/users"
"github.com/pkg/errors"
@@ -15,28 +18,29 @@
"gorm.io/gorm"
)
-func (s *Server) PrepareNewPeer() (Peer, error) {
- device := s.peers.GetDevice()
+// PrepareNewPeer initiates a new peer for the given WireGuard device.
+func (s *Server) PrepareNewPeer(device string) (wireguard.Peer, error) {
+ dev := s.peers.GetDevice(device)
- peer := Peer{}
+ peer := wireguard.Peer{}
peer.IsNew = true
- peer.AllowedIPsStr = device.AllowedIPsStr
- peer.IPs = make([]string, len(device.IPs))
- for i := range device.IPs {
- freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
+ peer.AllowedIPsStr = dev.AllowedIPsStr
+ peer.IPs = make([]string, len(dev.IPs))
+ for i := range dev.IPs {
+ freeIP, err := s.peers.GetAvailableIp(device, dev.IPs[i])
if err != nil {
- return Peer{}, errors.WithMessage(err, "failed to get available IP addresses")
+ return wireguard.Peer{}, errors.WithMessage(err, "failed to get available IP addresses")
}
peer.IPs[i] = freeIP
}
peer.IPsStr = common.ListToString(peer.IPs)
psk, err := wgtypes.GenerateKey()
if err != nil {
- return Peer{}, errors.Wrap(err, "failed to generate key")
+ return wireguard.Peer{}, errors.Wrap(err, "failed to generate key")
}
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
- return Peer{}, errors.Wrap(err, "failed to generate private key")
+ return wireguard.Peer{}, errors.Wrap(err, "failed to generate private key")
}
peer.PresharedKey = psk.String()
peer.PrivateKey = key.String()
@@ -46,54 +50,39 @@
return peer, nil
}
-func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool) error {
+// CreatePeerByEmail creates a new peer for the given email. If no user with the specified email was found, a new one
+// will be created.
+func (s *Server) CreatePeerByEmail(device, email, identifierSuffix string, disabled bool) error {
user, err := s.users.GetOrCreateUser(email)
if err != nil {
return errors.WithMessagef(err, "failed to load/create related user %s", email)
}
- device := s.peers.GetDevice()
- peer := Peer{}
- peer.User = user
- peer.AllowedIPsStr = device.AllowedIPsStr
- peer.IPs = make([]string, len(device.IPs))
- for i := range device.IPs {
- freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
- if err != nil {
- return errors.WithMessage(err, "failed to get available IP addresses")
- }
- peer.IPs[i] = freeIP
- }
- peer.IPsStr = common.ListToString(peer.IPs)
- psk, err := wgtypes.GenerateKey()
+ peer, err := s.PrepareNewPeer(device)
if err != nil {
- return errors.Wrap(err, "failed to generate key")
+ return errors.WithMessage(err, "failed to prepare new peer")
}
- key, err := wgtypes.GeneratePrivateKey()
- if err != nil {
- return errors.Wrap(err, "failed to generate private key")
- }
- peer.PresharedKey = psk.String()
- peer.PrivateKey = key.String()
- peer.PublicKey = key.PublicKey().String()
- peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
peer.Email = email
peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix)
+
now := time.Now()
if disabled {
peer.DeactivatedAt = &now
}
- return s.CreatePeer(peer)
+ return s.CreatePeer(device, peer)
}
-func (s *Server) CreatePeer(peer Peer) error {
- device := s.peers.GetDevice()
- peer.AllowedIPsStr = device.AllowedIPsStr
+// CreatePeer creates the new peer in the database. If the peer has no assigned ip addresses, a new one will be assigned
+// automatically. Also, if the private key is empty, a new key-pair will be generated.
+// This function also configures the new peer on the physical WireGuard interface if the peer is not deactivated.
+func (s *Server) CreatePeer(device string, peer wireguard.Peer) error {
+ dev := s.peers.GetDevice(device)
+ peer.AllowedIPsStr = dev.AllowedIPsStr
if peer.IPs == nil || len(peer.IPs) == 0 {
- peer.IPs = make([]string, len(device.IPs))
- for i := range device.IPs {
- freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
+ peer.IPs = make([]string, len(dev.IPs))
+ for i := range dev.IPs {
+ freeIP, err := s.peers.GetAvailableIp(device, dev.IPs[i])
if err != nil {
return errors.WithMessage(err, "failed to get available IP addresses")
}
@@ -114,11 +103,12 @@
peer.PrivateKey = key.String()
peer.PublicKey = key.PublicKey().String()
}
+ peer.DeviceName = dev.DeviceName
peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
// Create WireGuard interface
if peer.DeactivatedAt == nil {
- if err := s.wg.AddPeer(peer.GetConfig()); err != nil {
+ if err := s.wg.AddPeer(device, peer.GetConfig()); err != nil {
return errors.WithMessage(err, "failed to add WireGuard peer")
}
}
@@ -128,21 +118,22 @@
return errors.WithMessage(err, "failed to create peer")
}
- return s.WriteWireGuardConfigFile()
+ return s.WriteWireGuardConfigFile(device)
}
-func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error {
+// UpdatePeer updates the physical WireGuard interface and the database.
+func (s *Server) UpdatePeer(peer wireguard.Peer, updateTime time.Time) error {
currentPeer := s.peers.GetPeerByKey(peer.PublicKey)
// Update WireGuard device
var err error
switch {
case peer.DeactivatedAt == &updateTime:
- err = s.wg.RemovePeer(peer.PublicKey)
+ err = s.wg.RemovePeer(peer.DeviceName, peer.PublicKey)
case peer.DeactivatedAt == nil && currentPeer.Peer != nil:
- err = s.wg.UpdatePeer(peer.GetConfig())
+ err = s.wg.UpdatePeer(peer.DeviceName, peer.GetConfig())
case peer.DeactivatedAt == nil && currentPeer.Peer == nil:
- err = s.wg.AddPeer(peer.GetConfig())
+ err = s.wg.AddPeer(peer.DeviceName, peer.GetConfig())
}
if err != nil {
return errors.WithMessage(err, "failed to update WireGuard peer")
@@ -153,12 +144,13 @@
return errors.WithMessage(err, "failed to update peer")
}
- return s.WriteWireGuardConfigFile()
+ return s.WriteWireGuardConfigFile(peer.DeviceName)
}
-func (s *Server) DeletePeer(peer Peer) error {
+// DeletePeer removes the peer from the physical WireGuard interface and the database.
+func (s *Server) DeletePeer(peer wireguard.Peer) error {
// Delete WireGuard peer
- if err := s.wg.RemovePeer(peer.PublicKey); err != nil {
+ if err := s.wg.RemovePeer(peer.DeviceName, peer.PublicKey); err != nil {
return errors.WithMessage(err, "failed to remove WireGuard peer")
}
@@ -167,15 +159,16 @@
return errors.WithMessage(err, "failed to remove peer")
}
- return s.WriteWireGuardConfigFile()
+ return s.WriteWireGuardConfigFile(peer.DeviceName)
}
-func (s *Server) RestoreWireGuardInterface() error {
- activePeers := s.peers.GetActivePeers()
+// RestoreWireGuardInterface restores the state of the physical WireGuard interface from the database.
+func (s *Server) RestoreWireGuardInterface(device string) error {
+ activePeers := s.peers.GetActivePeers(device)
for i := range activePeers {
if activePeers[i].Peer == nil {
- if err := s.wg.AddPeer(activePeers[i].GetConfig()); err != nil {
+ if err := s.wg.AddPeer(device, activePeers[i].GetConfig()); err != nil {
return errors.WithMessage(err, "failed to add WireGuard peer")
}
}
@@ -184,26 +177,29 @@
return nil
}
-func (s *Server) WriteWireGuardConfigFile() error {
- if s.config.WG.WireGuardConfig == "" {
+// WriteWireGuardConfigFile writes the configuration file for the physical WireGuard interface.
+func (s *Server) WriteWireGuardConfigFile(device string) error {
+ if s.config.WG.ConfigDirectoryPath == "" {
return nil // writing disabled
}
- if err := syscall.Access(s.config.WG.WireGuardConfig, syscall.O_RDWR); err != nil {
+ if err := syscall.Access(s.config.WG.ConfigDirectoryPath, syscall.O_RDWR); err != nil {
return errors.Wrap(err, "failed to check WireGuard config access rights")
}
- device := s.peers.GetDevice()
- cfg, err := device.GetConfigFile(s.peers.GetActivePeers())
+ dev := s.peers.GetDevice(device)
+ cfg, err := dev.GetConfigFile(s.peers.GetActivePeers(device))
if err != nil {
return errors.WithMessage(err, "failed to get config file")
}
- if err := ioutil.WriteFile(s.config.WG.WireGuardConfig, cfg, 0644); err != nil {
+ filePath := path.Join(s.config.WG.ConfigDirectoryPath, dev.DeviceName+".conf")
+ if err := ioutil.WriteFile(filePath, cfg, 0644); err != nil {
return errors.Wrap(err, "failed to write WireGuard config file")
}
return nil
}
-func (s *Server) CreateUser(user users.User) error {
+// CreateUser creates the user in the database and optionally adds a default WireGuard peer for the user.
+func (s *Server) CreateUser(user users.User, device string) error {
if user.Email == "" {
return errors.New("cannot create user with empty email address")
}
@@ -220,9 +216,11 @@
}
// Check if user already has a peer setup, if not, create one
- return s.CreateUserDefaultPeer(user.Email)
+ return s.CreateUserDefaultPeer(user.Email, device)
}
+// UpdateUser updates the user in the database. If the user is marked as deleted, it will get remove from the database.
+// Also, if the user is re-enabled, all it's linked WireGuard peers will be activated again.
func (s *Server) UpdateUser(user users.User) error {
if user.DeletedAt.Valid {
return s.DeleteUser(user)
@@ -249,6 +247,8 @@
return nil
}
+// DeleteUser removes the user from the database.
+// Also, if the user has linked WireGuard peers, they will be deactivated.
func (s *Server) DeleteUser(user users.User) error {
currentUser := s.users.GetUserUnscoped(user.Email)
@@ -271,7 +271,7 @@
return nil
}
-func (s *Server) CreateUserDefaultPeer(email string) error {
+func (s *Server) CreateUserDefaultPeer(email, device string) error {
// Check if user is active, if not, quit
var existingUser *users.User
if existingUser = s.users.GetUser(email); existingUser == nil {
@@ -282,7 +282,7 @@
if s.config.Core.CreateDefaultPeer {
peers := s.peers.GetPeersByMail(email)
if len(peers) == 0 { // Create default vpn peer
- if err := s.CreatePeer(Peer{
+ if err := s.CreatePeer(device, wireguard.Peer{
Identifier: existingUser.Firstname + " " + existingUser.Lastname + " (Default)",
Email: existingUser.Email,
CreatedBy: existingUser.Email,
diff --git a/internal/users/config.go b/internal/users/config.go
deleted file mode 100644
index 98e8954..0000000
--- a/internal/users/config.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package users
-
-type SupportedDatabase string
-
-const (
- SupportedDatabaseMySQL SupportedDatabase = "mysql"
- SupportedDatabaseSQLite SupportedDatabase = "sqlite"
-)
-
-type Config struct {
- Typ SupportedDatabase `yaml:"typ" envconfig:"DATABASE_TYPE"` //mysql or sqlite
- Host string `yaml:"host" envconfig:"DATABASE_HOST"`
- Port int `yaml:"port" envconfig:"DATABASE_PORT"`
- Database string `yaml:"database" envconfig:"DATABASE_NAME"` // On SQLite: the database file-path, otherwise the database name
- User string `yaml:"user" envconfig:"DATABASE_USERNAME"`
- Password string `yaml:"password" envconfig:"DATABASE_PASSWORD"`
-}
diff --git a/internal/users/manager.go b/internal/users/manager.go
index 1c06154..8cd9fe7 100644
--- a/internal/users/manager.go
+++ b/internal/users/manager.go
@@ -1,9 +1,6 @@
package users
import (
- "fmt"
- "os"
- "path/filepath"
"sort"
"strconv"
"strings"
@@ -11,69 +8,15 @@
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
- "gorm.io/driver/mysql"
- "gorm.io/driver/sqlite"
"gorm.io/gorm"
- "gorm.io/gorm/logger"
)
-func GetDatabaseForConfig(cfg *Config) (db *gorm.DB, err error) {
- switch cfg.Typ {
- case SupportedDatabaseSQLite:
- if _, err = os.Stat(filepath.Dir(cfg.Database)); os.IsNotExist(err) {
- if err = os.MkdirAll(filepath.Dir(cfg.Database), 0700); err != nil {
- return
- }
- }
- db, err = gorm.Open(sqlite.Open(cfg.Database), &gorm.Config{})
- if err != nil {
- return
- }
- case SupportedDatabaseMySQL:
- connectionString := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database)
- db, err = gorm.Open(mysql.Open(connectionString), &gorm.Config{})
- if err != nil {
- return
- }
-
- sqlDB, _ := db.DB()
- sqlDB.SetConnMaxLifetime(time.Minute * 5)
- sqlDB.SetMaxIdleConns(2)
- sqlDB.SetMaxOpenConns(10)
- err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible
- if err != nil {
- return nil, errors.Wrap(err, "failed to ping mysql authentication database")
- }
- }
-
- // Enable Logger (logrus)
- logCfg := logger.Config{
- SlowThreshold: time.Second, // all slower than one second
- Colorful: false,
- LogLevel: logger.Silent, // default: log nothing
- }
-
- if logrus.StandardLogger().GetLevel() == logrus.TraceLevel {
- logCfg.LogLevel = logger.Info
- logCfg.SlowThreshold = 500 * time.Millisecond // all slower than half a second
- }
-
- db.Config.Logger = logger.New(logrus.StandardLogger(), logCfg)
- return
-}
-
type Manager struct {
db *gorm.DB
}
-func NewManager(cfg *Config) (*Manager, error) {
- m := &Manager{}
-
- var err error
- m.db, err = GetDatabaseForConfig(cfg)
- if err != nil {
- return nil, errors.Wrapf(err, "failed to setup user database %s", cfg.Database)
- }
+func NewManager(db *gorm.DB) (*Manager, error) {
+ m := &Manager{db: db}
// check if old user table exists (from version <= 1.0.2), if so rename it to peers.
if m.db.Migrator().HasTable("users") && !m.db.Migrator().HasTable("peers") {
@@ -84,14 +27,11 @@
}
}
- return m, m.MigrateUserDB()
-}
-
-func (m Manager) MigrateUserDB() error {
if err := m.db.AutoMigrate(&User{}); err != nil {
- return errors.Wrap(err, "failed to migrate user database")
+ return nil, errors.Wrap(err, "failed to migrate user database")
}
- return nil
+
+ return m, nil
}
func (m Manager) GetUsers() []User {
diff --git a/internal/wireguard/config.go b/internal/wireguard/config.go
index 5b9d4c6..027d928 100644
--- a/internal/wireguard/config.go
+++ b/internal/wireguard/config.go
@@ -1,7 +1,8 @@
package wireguard
type Config struct {
- DeviceName string `yaml:"device" envconfig:"WG_DEVICE"`
- WireGuardConfig string `yaml:"configFile" envconfig:"WG_CONFIG_FILE"` // optional, if set, updates will be written to this file
- ManageIPAddresses bool `yaml:"manageIPAddresses" envconfig:"MANAGE_IPS"` // handle ip-address setup of interface
+ DeviceNames []string `yaml:"devices" envconfig:"WG_DEVICES"` // managed devices
+ DefaultDeviceName string `yaml:"devices" envconfig:"WG_DEFAULT_DEVICE"` // this device is used for auto-created peers
+ ConfigDirectoryPath string `yaml:"configDirectory" envconfig:"WG_CONFIG_PATH"` // optional, if set, updates will be written to this path, filename: .conf
+ ManageIPAddresses bool `yaml:"manageIPAddresses" envconfig:"MANAGE_IPS"` // handle ip-address setup of interface
}
diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go
index f7f2a78..6cd0de5 100644
--- a/internal/wireguard/manager.go
+++ b/internal/wireguard/manager.go
@@ -9,6 +9,7 @@
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
+// Manager offers a synchronized management interface to the real WireGuard interface.
type Manager struct {
Cfg *Config
wg *wgctrl.Client
@@ -25,8 +26,8 @@
return nil
}
-func (m *Manager) GetDeviceInfo() (*wgtypes.Device, error) {
- dev, err := m.wg.Device(m.Cfg.DeviceName)
+func (m *Manager) GetDeviceInfo(device string) (*wgtypes.Device, error) {
+ dev, err := m.wg.Device(device)
if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard device")
}
@@ -34,11 +35,11 @@
return dev, nil
}
-func (m *Manager) GetPeerList() ([]wgtypes.Peer, error) {
+func (m *Manager) GetPeerList(device string) ([]wgtypes.Peer, error) {
m.mux.RLock()
defer m.mux.RUnlock()
- dev, err := m.wg.Device(m.Cfg.DeviceName)
+ dev, err := m.wg.Device(device)
if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard device")
}
@@ -46,7 +47,7 @@
return dev.Peers, nil
}
-func (m *Manager) GetPeer(pubKey string) (*wgtypes.Peer, error) {
+func (m *Manager) GetPeer(device string, pubKey string) (*wgtypes.Peer, error) {
m.mux.RLock()
defer m.mux.RUnlock()
@@ -55,7 +56,7 @@
return nil, errors.Wrap(err, "invalid public key")
}
- peers, err := m.GetPeerList()
+ peers, err := m.GetPeerList(device)
if err != nil {
return nil, errors.Wrap(err, "could not get WireGuard peers")
}
@@ -69,11 +70,11 @@
return nil, errors.Errorf("could not find WireGuard peer: %s", pubKey)
}
-func (m *Manager) AddPeer(cfg wgtypes.PeerConfig) error {
+func (m *Manager) AddPeer(device string, cfg wgtypes.PeerConfig) error {
m.mux.Lock()
defer m.mux.Unlock()
- err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
+ err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
return errors.Wrap(err, "could not configure WireGuard device")
}
@@ -81,12 +82,12 @@
return nil
}
-func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error {
+func (m *Manager) UpdatePeer(device string, cfg wgtypes.PeerConfig) error {
m.mux.Lock()
defer m.mux.Unlock()
cfg.UpdateOnly = true
- err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
+ err := m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
if err != nil {
return errors.Wrap(err, "could not configure WireGuard device")
}
@@ -94,7 +95,7 @@
return nil
}
-func (m *Manager) RemovePeer(pubKey string) error {
+func (m *Manager) RemovePeer(device string, pubKey string) error {
m.mux.Lock()
defer m.mux.Unlock()
@@ -108,7 +109,7 @@
Remove: true,
}
- err = m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
+ err = m.wg.ConfigureDevice(device, wgtypes.Config{Peers: []wgtypes.PeerConfig{peer}})
if err != nil {
return errors.Wrap(err, "could not configure WireGuard device")
}
@@ -116,6 +117,6 @@
return nil
}
-func (m *Manager) UpdateDevice(name string, cfg wgtypes.Config) error {
- return m.wg.ConfigureDevice(name, cfg)
+func (m *Manager) UpdateDevice(device string, cfg wgtypes.Config) error {
+ return m.wg.ConfigureDevice(device, cfg)
}
diff --git a/internal/wireguard/manager_net.go b/internal/wireguard/manager_net.go
new file mode 100644
index 0000000..76ec646
--- /dev/null
+++ b/internal/wireguard/manager_net.go
@@ -0,0 +1,122 @@
+package wireguard
+
+import (
+ "fmt"
+ "net"
+
+ "github.com/pkg/errors"
+
+ "github.com/milosgajdos/tenus"
+)
+
+const DefaultMTU = 1420
+
+func (m *Manager) GetIPAddress(device string) ([]string, error) {
+ wgInterface, err := tenus.NewLinkFrom(device)
+ if err != nil {
+ return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
+ }
+
+ // Get golang net.interface
+ iface := wgInterface.NetInterface()
+ if iface == nil { // Not sure if this check is really necessary
+ return nil, errors.Wrap(err, "could not retrieve WireGuard net.interface")
+ }
+
+ addrs, err := iface.Addrs()
+ if err != nil {
+ return nil, errors.Wrap(err, "could not retrieve WireGuard ip addresses")
+ }
+
+ ipAddresses := make([]string, 0, len(addrs))
+ for _, addr := range addrs {
+ var ip net.IP
+ var mask net.IPMask
+ switch v := addr.(type) {
+ case *net.IPNet:
+ ip = v.IP
+ mask = v.Mask
+ case *net.IPAddr:
+ ip = v.IP
+ mask = ip.DefaultMask()
+ }
+ if ip == nil || mask == nil {
+ continue // something is wrong?
+ }
+
+ maskSize, _ := mask.Size()
+ cidr := fmt.Sprintf("%s/%d", ip.String(), maskSize)
+ ipAddresses = append(ipAddresses, cidr)
+ }
+
+ return ipAddresses, nil
+}
+
+func (m *Manager) SetIPAddress(device string, cidrs []string) error {
+ wgInterface, err := tenus.NewLinkFrom(device)
+ if err != nil {
+ return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
+ }
+
+ // First remove existing IP addresses
+ existingIPs, err := m.GetIPAddress(device)
+ if err != nil {
+ return errors.Wrap(err, "could not retrieve IP addresses")
+ }
+ for _, cidr := range existingIPs {
+ wgIp, wgIpNet, err := net.ParseCIDR(cidr)
+ if err != nil {
+ return errors.Wrapf(err, "unable to parse cidr %s", cidr)
+ }
+
+ if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil {
+ return errors.Wrapf(err, "failed to unset ip %s", cidr)
+ }
+ }
+
+ // Next set new IP addresses
+ for _, cidr := range cidrs {
+ wgIp, wgIpNet, err := net.ParseCIDR(cidr)
+ if err != nil {
+ return errors.Wrapf(err, "unable to parse cidr %s", cidr)
+ }
+
+ if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil {
+ return errors.Wrapf(err, "failed to set ip %s", cidr)
+ }
+ }
+
+ return nil
+}
+
+func (m *Manager) GetMTU(device string) (int, error) {
+ wgInterface, err := tenus.NewLinkFrom(device)
+ if err != nil {
+ return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
+ }
+
+ // Get golang net.interface
+ iface := wgInterface.NetInterface()
+ if iface == nil { // Not sure if this check is really necessary
+ return 0, errors.Wrap(err, "could not retrieve WireGuard net.interface")
+ }
+
+ return iface.MTU, nil
+}
+
+func (m *Manager) SetMTU(device string, mtu int) error {
+ wgInterface, err := tenus.NewLinkFrom(device)
+ if err != nil {
+ return errors.Wrapf(err, "could not retrieve WireGuard interface %s", device)
+ }
+
+ if mtu == 0 {
+ mtu = DefaultMTU
+ }
+
+ if err := wgInterface.SetLinkMTU(mtu); err != nil {
+ return errors.Wrapf(err, "could not set MTU on interface %s", device)
+ }
+
+ return nil
+}
diff --git a/internal/wireguard/net.go b/internal/wireguard/net.go
deleted file mode 100644
index 0b9e68b..0000000
--- a/internal/wireguard/net.go
+++ /dev/null
@@ -1,122 +0,0 @@
-package wireguard
-
-import (
- "fmt"
- "net"
-
- "github.com/pkg/errors"
-
- "github.com/milosgajdos/tenus"
-)
-
-const DefaultMTU = 1420
-
-func (m *Manager) GetIPAddress() ([]string, error) {
- wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
- if err != nil {
- return nil, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
- }
-
- // Get golang net.interface
- iface := wgInterface.NetInterface()
- if iface == nil { // Not sure if this check is really necessary
- return nil, errors.Wrap(err, "could not retrieve WireGuard net.interface")
- }
-
- addrs, err := iface.Addrs()
- if err != nil {
- return nil, errors.Wrap(err, "could not retrieve WireGuard ip addresses")
- }
-
- ipAddresses := make([]string, 0, len(addrs))
- for _, addr := range addrs {
- var ip net.IP
- var mask net.IPMask
- switch v := addr.(type) {
- case *net.IPNet:
- ip = v.IP
- mask = v.Mask
- case *net.IPAddr:
- ip = v.IP
- mask = ip.DefaultMask()
- }
- if ip == nil || mask == nil {
- continue // something is wrong?
- }
-
- maskSize, _ := mask.Size()
- cidr := fmt.Sprintf("%s/%d", ip.String(), maskSize)
- ipAddresses = append(ipAddresses, cidr)
- }
-
- return ipAddresses, nil
-}
-
-func (m *Manager) SetIPAddress(cidrs []string) error {
- wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
- if err != nil {
- return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
- }
-
- // First remove existing IP addresses
- existingIPs, err := m.GetIPAddress()
- if err != nil {
- return errors.Wrap(err, "could not retrieve IP addresses")
- }
- for _, cidr := range existingIPs {
- wgIp, wgIpNet, err := net.ParseCIDR(cidr)
- if err != nil {
- return errors.Wrapf(err, "unable to parse cidr %s", cidr)
- }
-
- if err := wgInterface.UnsetLinkIp(wgIp, wgIpNet); err != nil {
- return errors.Wrapf(err, "failed to unset ip %s", cidr)
- }
- }
-
- // Next set new IP addresses
- for _, cidr := range cidrs {
- wgIp, wgIpNet, err := net.ParseCIDR(cidr)
- if err != nil {
- return errors.Wrapf(err, "unable to parse cidr %s", cidr)
- }
-
- if err := wgInterface.SetLinkIp(wgIp, wgIpNet); err != nil {
- return errors.Wrapf(err, "failed to set ip %s", cidr)
- }
- }
-
- return nil
-}
-
-func (m *Manager) GetMTU() (int, error) {
- wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
- if err != nil {
- return 0, errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
- }
-
- // Get golang net.interface
- iface := wgInterface.NetInterface()
- if iface == nil { // Not sure if this check is really necessary
- return 0, errors.Wrap(err, "could not retrieve WireGuard net.interface")
- }
-
- return iface.MTU, nil
-}
-
-func (m *Manager) SetMTU(mtu int) error {
- wgInterface, err := tenus.NewLinkFrom(m.Cfg.DeviceName)
- if err != nil {
- return errors.Wrapf(err, "could not retrieve WireGuard interface %s", m.Cfg.DeviceName)
- }
-
- if mtu == 0 {
- mtu = DefaultMTU
- }
-
- if err := wgInterface.SetLinkMTU(mtu); err != nil {
- return errors.Wrapf(err, "could not set MTU on interface %s", m.Cfg.DeviceName)
- }
-
- return nil
-}
diff --git a/internal/wireguard/peermanager.go b/internal/wireguard/peermanager.go
new file mode 100644
index 0000000..099b491
--- /dev/null
+++ b/internal/wireguard/peermanager.go
@@ -0,0 +1,723 @@
+package wireguard
+
+import (
+ "bytes"
+ "crypto/md5"
+ "fmt"
+ "net"
+ "reflect"
+ "regexp"
+ "sort"
+ "strings"
+ "text/template"
+ "time"
+
+ "github.com/gin-gonic/gin/binding"
+ "github.com/go-playground/validator/v10"
+ "github.com/h44z/wg-portal/internal/common"
+ "github.com/pkg/errors"
+ "github.com/sirupsen/logrus"
+ "github.com/skip2/go-qrcode"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+ "gorm.io/gorm"
+)
+
+//
+// CUSTOM VALIDATORS ----------------------------------------------------------------------------
+//
+var cidrList validator.Func = func(fl validator.FieldLevel) bool {
+ cidrListStr := fl.Field().String()
+
+ cidrList := common.ParseStringList(cidrListStr)
+ for i := range cidrList {
+ _, _, err := net.ParseCIDR(cidrList[i])
+ if err != nil {
+ return false
+ }
+ }
+ return true
+}
+
+var ipList validator.Func = func(fl validator.FieldLevel) bool {
+ ipListStr := fl.Field().String()
+
+ ipList := common.ParseStringList(ipListStr)
+ for i := range ipList {
+ ip := net.ParseIP(ipList[i])
+ if ip == nil {
+ return false
+ }
+ }
+ return true
+}
+
+func init() {
+ if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
+ _ = v.RegisterValidation("cidrlist", cidrList)
+ _ = v.RegisterValidation("iplist", ipList)
+ }
+}
+
+//
+// PEER ----------------------------------------------------------------------------------------
+//
+
+type Peer struct {
+ Peer *wgtypes.Peer `gorm:"-"` // WireGuard peer
+ Config string `gorm:"-"`
+
+ UID string `form:"uid" binding:"alphanum"` // uid for html identification
+ IsOnline bool `gorm:"-"`
+ IsNew bool `gorm:"-"`
+ Identifier string `form:"identifier" binding:"required,lt=64"` // Identifier AND Email make a WireGuard peer unique
+ Email string `gorm:"index" form:"mail" binding:"required,email"`
+ LastHandshake string `gorm:"-"`
+ LastHandshakeTime string `gorm:"-"`
+
+ IgnorePersistentKeepalive bool `form:"ignorekeepalive"`
+ PresharedKey string `form:"presharedkey" binding:"omitempty,base64"`
+ AllowedIPsStr string `form:"allowedip" binding:"cidrlist"`
+ IPsStr string `form:"ip" binding:"cidrlist"`
+ AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
+ IPs []string `gorm:"-"` // The IPs of the client
+ PrivateKey string `form:"privkey" binding:"omitempty,base64"`
+ PublicKey string `gorm:"primaryKey" form:"pubkey" binding:"required,base64"`
+ DeviceName string `gorm:"index"`
+
+ DeactivatedAt *time.Time
+ CreatedBy string
+ UpdatedBy string
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+func (p Peer) GetConfig() wgtypes.PeerConfig {
+ publicKey, _ := wgtypes.ParseKey(p.PublicKey)
+ var presharedKey *wgtypes.Key
+ if p.PresharedKey != "" {
+ presharedKeyTmp, _ := wgtypes.ParseKey(p.PresharedKey)
+ presharedKey = &presharedKeyTmp
+ }
+
+ cfg := wgtypes.PeerConfig{
+ PublicKey: publicKey,
+ Remove: false,
+ UpdateOnly: false,
+ PresharedKey: presharedKey,
+ Endpoint: nil,
+ PersistentKeepaliveInterval: nil,
+ ReplaceAllowedIPs: true,
+ AllowedIPs: make([]net.IPNet, len(p.IPs)),
+ }
+ for i, ip := range p.IPs {
+ _, ipNet, err := net.ParseCIDR(ip)
+ if err == nil {
+ cfg.AllowedIPs[i] = *ipNet
+ }
+ }
+
+ return cfg
+}
+
+func (p Peer) GetConfigFile(device Device) ([]byte, error) {
+ tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(ClientCfgTpl)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to parse client template")
+ }
+
+ var tplBuff bytes.Buffer
+
+ err = tpl.Execute(&tplBuff, struct {
+ Client Peer
+ Server Device
+ }{
+ Client: p,
+ Server: device,
+ })
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to execute client template")
+ }
+
+ return tplBuff.Bytes(), nil
+}
+
+func (p Peer) GetQRCode() ([]byte, error) {
+ png, err := qrcode.Encode(p.Config, qrcode.Medium, 250)
+ if err != nil {
+ logrus.WithFields(logrus.Fields{
+ "err": err,
+ }).Error("failed to create qrcode")
+ return nil, errors.Wrap(err, "failed to encode qrcode")
+ }
+ return png, nil
+}
+
+func (p Peer) IsValid() bool {
+ if p.PublicKey == "" {
+ return false
+ }
+
+ return true
+}
+
+func (p Peer) ToMap() map[string]string {
+ out := make(map[string]string)
+
+ v := reflect.ValueOf(p)
+ if v.Kind() == reflect.Ptr {
+ v = v.Elem()
+ }
+
+ typ := v.Type()
+ for i := 0; i < v.NumField(); i++ {
+ // gets us a StructField
+ fi := typ.Field(i)
+ if tagv := fi.Tag.Get("form"); tagv != "" {
+ // set key of map to value in struct field
+ out[tagv] = v.Field(i).String()
+ }
+ }
+ return out
+}
+
+func (p Peer) GetConfigFileName() string {
+ reg := regexp.MustCompile("[^a-zA-Z0-9_-]+")
+ return reg.ReplaceAllString(strings.ReplaceAll(p.Identifier, " ", "-"), "") + ".conf"
+}
+
+//
+// DEVICE --------------------------------------------------------------------------------------
+//
+
+type Device struct {
+ Interface *wgtypes.Device `gorm:"-"`
+
+ DeviceName string `form:"device" gorm:"primaryKey" binding:"required,alphanum"`
+ PrivateKey string `form:"privkey" binding:"required,base64"`
+ PublicKey string `form:"pubkey" binding:"required,base64"`
+ PersistentKeepalive int `form:"keepalive" binding:"gte=0"`
+ ListenPort int `form:"port" binding:"required,gt=0"`
+ Mtu int `form:"mtu" binding:"gte=0,lte=1500"`
+ Endpoint string `form:"endpoint" binding:"required,hostname_port"`
+ AllowedIPsStr string `form:"allowedip" binding:"cidrlist"`
+ IPsStr string `form:"ip" binding:"required,cidrlist"`
+ AllowedIPs []string `gorm:"-"` // IPs that are used in the client config file
+ IPs []string `gorm:"-"` // The IPs of the client
+ DNSStr string `form:"dns" binding:"iplist"`
+ DNS []string `gorm:"-"` // The DNS servers of the client
+ PreUp string `form:"preup"`
+ PostUp string `form:"postup"`
+ PreDown string `form:"predown"`
+ PostDown string `form:"postdown"`
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+func (d Device) IsValid() bool {
+ if d.PublicKey == "" {
+ return false
+ }
+ if len(d.IPs) == 0 {
+ return false
+ }
+ if d.Endpoint == "" {
+ return false
+ }
+
+ return true
+}
+
+func (d Device) GetConfig() wgtypes.Config {
+ var privateKey *wgtypes.Key
+ if d.PrivateKey != "" {
+ pKey, _ := wgtypes.ParseKey(d.PrivateKey)
+ privateKey = &pKey
+ }
+
+ cfg := wgtypes.Config{
+ PrivateKey: privateKey,
+ ListenPort: &d.ListenPort,
+ }
+
+ return cfg
+}
+
+func (d Device) GetConfigFile(peers []Peer) ([]byte, error) {
+ tpl, err := template.New("server").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(DeviceCfgTpl)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to parse server template")
+ }
+
+ var tplBuff bytes.Buffer
+
+ err = tpl.Execute(&tplBuff, struct {
+ Clients []Peer
+ Server Device
+ }{
+ Clients: peers,
+ Server: d,
+ })
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to execute server template")
+ }
+
+ return tplBuff.Bytes(), nil
+}
+
+//
+// PEER-MANAGER --------------------------------------------------------------------------------
+//
+
+type PeerManager struct {
+ db *gorm.DB
+ wg *Manager
+}
+
+func NewPeerManager(db *gorm.DB, wg *Manager) (*PeerManager, error) {
+ um := &PeerManager{db: db, wg: wg}
+
+ if err := um.db.AutoMigrate(&Peer{}, &Device{}); err != nil {
+ return nil, errors.WithMessage(err, "failed to migrate peer database")
+ }
+
+ return um, nil
+}
+
+// InitFromPhysicalInterface read all WireGuard peers from the WireGuard interface configuration. If a peer does not
+// exist in the local database, it gets created.
+func (m *PeerManager) InitFromPhysicalInterface() error {
+ for _, deviceName := range m.wg.Cfg.DeviceNames {
+ peers, err := m.wg.GetPeerList(deviceName)
+ if err != nil {
+ return errors.Wrapf(err, "failed to get peer list for device %s", deviceName)
+ }
+ device, err := m.wg.GetDeviceInfo(deviceName)
+ if err != nil {
+ return errors.Wrapf(err, "failed to get device info for device %s", deviceName)
+ }
+ var ipAddresses []string
+ var mtu int
+ if m.wg.Cfg.ManageIPAddresses {
+ if ipAddresses, err = m.wg.GetIPAddress(deviceName); err != nil {
+ return errors.Wrapf(err, "failed to get ip address for device %s", deviceName)
+ }
+ if mtu, err = m.wg.GetMTU(deviceName); err != nil {
+ return errors.Wrapf(err, "failed to get MTU for device %s", deviceName)
+ }
+ }
+
+ // Check if entries already exist in database, if not create them
+ for _, peer := range peers {
+ if err := m.validateOrCreatePeer(deviceName, peer); err != nil {
+ return errors.WithMessagef(err, "failed to validate peer %s for device %s", peer.PublicKey, deviceName)
+ }
+ }
+ if err := m.validateOrCreateDevice(*device, ipAddresses, mtu); err != nil {
+ return errors.WithMessagef(err, "failed to validate device %s", device.Name)
+ }
+ }
+
+ return nil
+}
+
+// validateOrCreatePeer checks if the given WireGuard peer already exists in the database, if not, the peer entry will be created
+func (m *PeerManager) validateOrCreatePeer(device string, wgPeer wgtypes.Peer) error {
+ peer := Peer{}
+ m.db.Where("public_key = ?", wgPeer.PublicKey.String()).FirstOrInit(&peer)
+
+ if peer.PublicKey == "" { // peer not found, create
+ peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(wgPeer.PublicKey.String())))
+ peer.PublicKey = wgPeer.PublicKey.String()
+ peer.PrivateKey = "" // UNKNOWN
+ if wgPeer.PresharedKey != (wgtypes.Key{}) {
+ peer.PresharedKey = wgPeer.PresharedKey.String()
+ }
+ peer.Email = "autodetected@example.com"
+ peer.Identifier = "Autodetected (" + peer.PublicKey[0:8] + ")"
+ peer.UpdatedAt = time.Now()
+ peer.CreatedAt = time.Now()
+ peer.AllowedIPs = make([]string, 0) // UNKNOWN
+ peer.IPs = make([]string, len(wgPeer.AllowedIPs))
+ for i, ip := range wgPeer.AllowedIPs {
+ peer.IPs[i] = ip.String()
+ }
+ peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
+ peer.IPsStr = strings.Join(peer.IPs, ", ")
+ peer.DeviceName = device
+
+ res := m.db.Create(&peer)
+ if res.Error != nil {
+ return errors.Wrapf(res.Error, "failed to create autodetected peer %s", peer.PublicKey)
+ }
+ }
+
+ return nil
+}
+
+// validateOrCreateDevice checks if the given WireGuard device already exists in the database, if not, the peer entry will be created
+func (m *PeerManager) validateOrCreateDevice(dev wgtypes.Device, ipAddresses []string, mtu int) error {
+ device := Device{}
+ m.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
+
+ if device.PublicKey == "" { // device not found, create
+ device.PublicKey = dev.PublicKey.String()
+ device.PrivateKey = dev.PrivateKey.String()
+ device.DeviceName = dev.Name
+ device.ListenPort = dev.ListenPort
+ device.Mtu = 0
+ device.PersistentKeepalive = 16 // Default
+ device.IPsStr = strings.Join(ipAddresses, ", ")
+ if mtu == DefaultMTU {
+ mtu = 0
+ }
+ device.Mtu = mtu
+
+ res := m.db.Create(&device)
+ if res.Error != nil {
+ return errors.Wrapf(res.Error, "failed to create autodetected device")
+ }
+ }
+
+ return nil
+}
+
+// populatePeerData enriches the peer struct with WireGuard live data like last handshake, ...
+func (m *PeerManager) populatePeerData(peer *Peer) {
+ peer.AllowedIPs = strings.Split(peer.AllowedIPsStr, ", ")
+ peer.IPs = strings.Split(peer.IPsStr, ", ")
+ // Set config file
+ tmpCfg, _ := peer.GetConfigFile(m.GetDevice(peer.DeviceName))
+ peer.Config = string(tmpCfg)
+
+ // set data from WireGuard interface
+ peer.Peer, _ = m.wg.GetPeer(peer.DeviceName, peer.PublicKey)
+ peer.LastHandshake = "never"
+ peer.LastHandshakeTime = "Never connected, or user is disabled."
+ if peer.Peer != nil {
+ since := time.Since(peer.Peer.LastHandshakeTime)
+ sinceSeconds := int(since.Round(time.Second).Seconds())
+ sinceMinutes := sinceSeconds / 60
+ sinceSeconds -= sinceMinutes * 60
+
+ if sinceMinutes > 2*10080 { // 2 weeks
+ peer.LastHandshake = "a while ago"
+ } else if sinceMinutes > 10080 { // 1 week
+ peer.LastHandshake = "a week ago"
+ } else {
+ peer.LastHandshake = fmt.Sprintf("%02dm %02ds", sinceMinutes, sinceSeconds)
+ }
+ peer.LastHandshakeTime = peer.Peer.LastHandshakeTime.Format(time.UnixDate)
+ }
+ peer.IsOnline = false
+}
+
+// populateDeviceData enriches the device struct with WireGuard live data like interface information
+func (m *PeerManager) populateDeviceData(device *Device) {
+ device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
+ device.IPs = strings.Split(device.IPsStr, ", ")
+ device.DNS = strings.Split(device.DNSStr, ", ")
+
+ // set data from WireGuard interface
+ device.Interface, _ = m.wg.GetDeviceInfo(device.DeviceName)
+}
+
+func (m *PeerManager) GetAllPeers(device string) []Peer {
+ peers := make([]Peer, 0)
+ m.db.Where("device_name = ?", device).Find(&peers)
+
+ for i := range peers {
+ m.populatePeerData(&peers[i])
+ }
+
+ return peers
+}
+
+func (m *PeerManager) GetActivePeers(device string) []Peer {
+ peers := make([]Peer, 0)
+ m.db.Where("device_name = ? AND deactivated_at IS NULL", device).Find(&peers)
+
+ for i := range peers {
+ m.populatePeerData(&peers[i])
+ }
+
+ return peers
+}
+
+func (m *PeerManager) GetFilteredAndSortedPeers(device, sortKey, sortDirection, search string) []Peer {
+ peers := make([]Peer, 0)
+ m.db.Where("device_name = ?", device).Find(&peers)
+
+ filteredPeers := make([]Peer, 0, len(peers))
+ for i := range peers {
+ m.populatePeerData(&peers[i])
+
+ if search == "" ||
+ strings.Contains(peers[i].Email, search) ||
+ strings.Contains(peers[i].Identifier, search) ||
+ strings.Contains(peers[i].PublicKey, search) {
+ filteredPeers = append(filteredPeers, peers[i])
+ }
+ }
+
+ sort.Slice(filteredPeers, func(i, j int) bool {
+ var sortValueLeft string
+ var sortValueRight string
+
+ switch sortKey {
+ case "id":
+ sortValueLeft = filteredPeers[i].Identifier
+ sortValueRight = filteredPeers[j].Identifier
+ case "pubKey":
+ sortValueLeft = filteredPeers[i].PublicKey
+ sortValueRight = filteredPeers[j].PublicKey
+ case "mail":
+ sortValueLeft = filteredPeers[i].Email
+ sortValueRight = filteredPeers[j].Email
+ case "ip":
+ sortValueLeft = filteredPeers[i].IPsStr
+ sortValueRight = filteredPeers[j].IPsStr
+ case "handshake":
+ if filteredPeers[i].Peer == nil {
+ return false
+ } else if filteredPeers[j].Peer == nil {
+ return true
+ }
+ sortValueLeft = filteredPeers[i].Peer.LastHandshakeTime.Format(time.RFC3339)
+ sortValueRight = filteredPeers[j].Peer.LastHandshakeTime.Format(time.RFC3339)
+ }
+
+ if sortDirection == "asc" {
+ return sortValueLeft < sortValueRight
+ } else {
+ return sortValueLeft > sortValueRight
+ }
+ })
+
+ return filteredPeers
+}
+
+func (m *PeerManager) GetSortedPeersForEmail(sortKey, sortDirection, email string) []Peer {
+ peers := make([]Peer, 0)
+ m.db.Where("email = ?", email).Find(&peers)
+
+ for i := range peers {
+ m.populatePeerData(&peers[i])
+ }
+
+ sort.Slice(peers, func(i, j int) bool {
+ var sortValueLeft string
+ var sortValueRight string
+
+ switch sortKey {
+ case "id":
+ sortValueLeft = peers[i].Identifier
+ sortValueRight = peers[j].Identifier
+ case "pubKey":
+ sortValueLeft = peers[i].PublicKey
+ sortValueRight = peers[j].PublicKey
+ case "mail":
+ sortValueLeft = peers[i].Email
+ sortValueRight = peers[j].Email
+ case "ip":
+ sortValueLeft = peers[i].IPsStr
+ sortValueRight = peers[j].IPsStr
+ case "handshake":
+ if peers[i].Peer == nil {
+ return true
+ } else if peers[j].Peer == nil {
+ return false
+ }
+ sortValueLeft = peers[i].Peer.LastHandshakeTime.Format(time.RFC3339)
+ sortValueRight = peers[j].Peer.LastHandshakeTime.Format(time.RFC3339)
+ }
+
+ if sortDirection == "asc" {
+ return sortValueLeft < sortValueRight
+ } else {
+ return sortValueLeft > sortValueRight
+ }
+ })
+
+ return peers
+}
+
+func (m *PeerManager) GetDevice(device string) Device {
+ dev := Device{}
+
+ m.db.Where("device_name = ?", device).First(&dev)
+ m.populateDeviceData(&dev)
+
+ return dev
+}
+
+func (m *PeerManager) GetPeerByKey(publicKey string) Peer {
+ peer := Peer{}
+ m.db.Where("public_key = ?", publicKey).FirstOrInit(&peer)
+ m.populatePeerData(&peer)
+ return peer
+}
+
+func (m *PeerManager) GetPeersByMail(mail string) []Peer {
+ var peers []Peer
+ m.db.Where("email = ?", mail).Find(&peers)
+ for i := range peers {
+ m.populatePeerData(&peers[i])
+ }
+
+ return peers
+}
+
+// ---- Database helpers -----
+
+func (m *PeerManager) CreatePeer(peer Peer) error {
+ peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))
+ peer.UpdatedAt = time.Now()
+ peer.CreatedAt = time.Now()
+ peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
+ peer.IPsStr = strings.Join(peer.IPs, ", ")
+
+ res := m.db.Create(&peer)
+ if res.Error != nil {
+ logrus.Errorf("failed to create peer: %v", res.Error)
+ return errors.Wrap(res.Error, "failed to create peer")
+ }
+
+ return nil
+}
+
+func (m *PeerManager) UpdatePeer(peer Peer) error {
+ peer.UpdatedAt = time.Now()
+ peer.AllowedIPsStr = strings.Join(peer.AllowedIPs, ", ")
+ peer.IPsStr = strings.Join(peer.IPs, ", ")
+
+ res := m.db.Save(&peer)
+ if res.Error != nil {
+ logrus.Errorf("failed to update peer: %v", res.Error)
+ return errors.Wrap(res.Error, "failed to update peer")
+ }
+
+ return nil
+}
+
+func (m *PeerManager) DeletePeer(peer Peer) error {
+ res := m.db.Delete(&peer)
+ if res.Error != nil {
+ logrus.Errorf("failed to delete peer: %v", res.Error)
+ return errors.Wrap(res.Error, "failed to delete peer")
+ }
+
+ return nil
+}
+
+func (m *PeerManager) UpdateDevice(device Device) error {
+ device.UpdatedAt = time.Now()
+ device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ")
+ device.IPsStr = strings.Join(device.IPs, ", ")
+ device.DNSStr = strings.Join(device.DNS, ", ")
+
+ res := m.db.Save(&device)
+ if res.Error != nil {
+ logrus.Errorf("failed to update device: %v", res.Error)
+ return errors.Wrap(res.Error, "failed to update device")
+ }
+
+ return nil
+}
+
+// ---- IP helpers ----
+
+func (m *PeerManager) GetAllReservedIps(device string) ([]string, error) {
+ reservedIps := make([]string, 0)
+ peers := m.GetAllPeers(device)
+ for _, user := range peers {
+ for _, cidr := range user.IPs {
+ if cidr == "" {
+ continue
+ }
+ ip, _, err := net.ParseCIDR(cidr)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to parse cidr")
+ }
+ reservedIps = append(reservedIps, ip.String())
+ }
+ }
+
+ dev := m.GetDevice(device)
+ for _, cidr := range dev.IPs {
+ if cidr == "" {
+ continue
+ }
+ ip, _, err := net.ParseCIDR(cidr)
+ if err != nil {
+ return nil, errors.Wrap(err, "failed to parse cidr")
+ }
+
+ reservedIps = append(reservedIps, ip.String())
+ }
+
+ return reservedIps, nil
+}
+
+func (m *PeerManager) IsIPReserved(device string, cidr string) bool {
+ reserved, err := m.GetAllReservedIps(device)
+ if err != nil {
+ return true // in case something failed, assume the ip is reserved
+ }
+ ip, ipnet, err := net.ParseCIDR(cidr)
+ if err != nil {
+ return true
+ }
+
+ // this two addresses are not usable
+ broadcastAddr := common.BroadcastAddr(ipnet).String()
+ networkAddr := ipnet.IP.String()
+ address := ip.String()
+
+ if address == broadcastAddr || address == networkAddr {
+ return true
+ }
+
+ for _, r := range reserved {
+ if address == r {
+ return true
+ }
+ }
+
+ return false
+}
+
+// GetAvailableIp search for an available ip in cidr against a list of reserved ips
+func (m *PeerManager) GetAvailableIp(device string, cidr string) (string, error) {
+ reserved, err := m.GetAllReservedIps(device)
+ if err != nil {
+ return "", errors.WithMessagef(err, "failed to get all reserved IP addresses for %s", device)
+ }
+ ip, ipnet, err := net.ParseCIDR(cidr)
+ if err != nil {
+ return "", errors.Wrap(err, "failed to parse cidr")
+ }
+
+ // this two addresses are not usable
+ broadcastAddr := common.BroadcastAddr(ipnet).String()
+ networkAddr := ipnet.IP.String()
+
+ for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); common.IncreaseIP(ip) {
+ ok := true
+ address := ip.String()
+ for _, r := range reserved {
+ if address == r {
+ ok = false
+ break
+ }
+ }
+ if ok && address != networkAddr && address != broadcastAddr {
+ netMask := "/32"
+ if common.IsIPv6(address) {
+ netMask = "/128"
+ }
+ return address + netMask, nil
+ }
+ }
+
+ return "", errors.New("no more available address from cidr")
+}