diff --git a/assets/css/custom.css b/assets/css/custom.css index 8e0baa9..0782b69 100644 --- a/assets/css/custom.css +++ b/assets/css/custom.css @@ -47,4 +47,8 @@ .navbar { padding: 0.5rem 1rem; +} + +.disabled-peer { + color: #d03131; } \ No newline at end of file diff --git a/assets/tpl/admin_edit_client.html b/assets/tpl/admin_edit_client.html new file mode 100644 index 0000000..9dc09d5 --- /dev/null +++ b/assets/tpl/admin_edit_client.html @@ -0,0 +1,97 @@ + + +
+ + +No LDAP user-information available...
+ {{else}} +{{$p.Config}}
diff --git a/internal/server/core.go b/internal/server/core.go
index 5417e48..9b7063a 100644
--- a/internal/server/core.go
+++ b/internal/server/core.go
@@ -90,12 +90,12 @@
}
// Setup user manager
- s.users = NewUserManager()
- if s.users == nil {
+ if s.users = NewUserManager(s.wg, s.ldapUsers); s.users == nil {
return errors.New("unable to setup user manager")
}
- s.users.InitWithDevice(s.wg.GetDeviceInfo())
- s.users.InitWithPeers(s.wg.GetPeerList())
+ if err := s.users.InitFromCurrentInterface(); err != nil {
+ return errors.New("unable to initialize user manager")
+ }
dir := s.getExecutableDirectory()
rDir, _ := filepath.Abs(filepath.Dir(os.Args[0]))
diff --git a/internal/server/handlers.go b/internal/server/handlers.go
index 36efa7e..7ae946c 100644
--- a/internal/server/handlers.go
+++ b/internal/server/handlers.go
@@ -2,7 +2,10 @@
import (
"net/http"
+ "net/url"
"strconv"
+ "strings"
+ "time"
"github.com/gin-gonic/gin"
)
@@ -32,24 +35,9 @@
}
func (s *Server) GetAdminIndex(c *gin.Context) {
- dev, err := s.wg.GetDeviceInfo()
- if err != nil {
- s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error())
- return
- }
- peers, err := s.wg.GetPeerList()
- if err != nil {
- s.HandleError(c, http.StatusInternalServerError, "WireGuard error", err.Error())
- return
- }
-
device := s.users.GetDevice()
- device.Interface = dev
+ users := s.users.GetAllUsers()
- users := make([]User, len(peers))
- for i, peer := range peers {
- users[i] = s.users.GetOrCreateUserForPeer(peer)
- }
c.HTML(http.StatusOK, "admin_index.html", struct {
Route string
Session SessionData
@@ -65,8 +53,332 @@
})
}
+func (s *Server) GetAdminEditInterface(c *gin.Context) {
+ device := s.users.GetDevice()
+ users := s.users.GetAllUsers()
+
+ c.HTML(http.StatusOK, "admin_edit_interface.html", struct {
+ Route string
+ Alerts AlertData
+ Session SessionData
+ Static StaticData
+ Peers []User
+ Device Device
+ }{
+ Route: c.Request.URL.Path,
+ Alerts: s.getAlertData(c),
+ Session: s.getSessionData(c),
+ Static: s.getStaticData(),
+ Peers: users,
+ Device: device,
+ })
+}
+
+func (s *Server) PostAdminEditInterface(c *gin.Context) {
+ device := s.users.GetDevice()
+ var err error
+
+ device.ListenPort, err = strconv.Atoi(c.PostForm("port"))
+ if err != nil {
+ s.setAlert(c, "invalid port: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ ipField := c.PostForm("ip")
+ ips := strings.Split(ipField, ",")
+ validatedIPs := make([]string, 0, len(ips))
+ for i := range ips {
+ ips[i] = strings.TrimSpace(ips[i])
+ if ips[i] != "" {
+ validatedIPs = append(validatedIPs, ips[i])
+ }
+ }
+ if len(validatedIPs) == 0 {
+ s.setAlert(c, "invalid ip address", "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+ device.IPs = validatedIPs
+
+ device.Endpoint = c.PostForm("endpoint")
+
+ dnsField := c.PostForm("dns")
+ dns := strings.Split(dnsField, ",")
+ validatedDNS := make([]string, 0, len(dns))
+ for i := range dns {
+ dns[i] = strings.TrimSpace(dns[i])
+ if dns[i] != "" {
+ validatedDNS = append(validatedDNS, dns[i])
+ }
+ }
+ device.DNS = validatedDNS
+
+ allowedIPField := c.PostForm("allowedip")
+ allowedIP := strings.Split(allowedIPField, ",")
+ validatedAllowedIP := make([]string, 0, len(allowedIP))
+ for i := range allowedIP {
+ allowedIP[i] = strings.TrimSpace(allowedIP[i])
+ if allowedIP[i] != "" {
+ validatedAllowedIP = append(validatedAllowedIP, allowedIP[i])
+ }
+ }
+ device.AllowedIPs = validatedAllowedIP
+
+ device.Mtu, err = strconv.Atoi(c.PostForm("mtu"))
+ if err != nil {
+ s.setAlert(c, "invalid MTU: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ device.PersistentKeepalive, err = strconv.Atoi(c.PostForm("keepalive"))
+ if err != nil {
+ s.setAlert(c, "invalid PersistentKeepalive: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ // Update WireGuard device
+ err = s.wg.UpdateDevice(device.DeviceName, device.GetDeviceConfig())
+ if err != nil {
+ s.setAlert(c, "failed to update device in WireGuard: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ // Update in database
+ err = s.users.UpdateDevice(device)
+ if err != nil {
+ s.setAlert(c, "failed to update device in database: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ s.setAlert(c, "changes applied successfully", "success")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+}
+
+func (s *Server) GetAdminEditPeer(c *gin.Context) {
+ device := s.users.GetDevice()
+ user := s.users.GetUserByKey(c.Query("pkey"))
+
+ c.HTML(http.StatusOK, "admin_edit_client.html", struct {
+ Route string
+ Alerts AlertData
+ Session SessionData
+ Static StaticData
+ Peer User
+ Device Device
+ }{
+ Route: c.Request.URL.Path,
+ Alerts: s.getAlertData(c),
+ Session: s.getSessionData(c),
+ Static: s.getStaticData(),
+ Peer: user,
+ Device: device,
+ })
+}
+
+func (s *Server) PostAdminEditPeer(c *gin.Context) {
+ user := s.users.GetUserByKey(c.Query("pkey"))
+ urlEncodedKey := url.QueryEscape(c.Query("pkey"))
+ var err error
+
+ user.Identifier = c.PostForm("identifier")
+ if user.Identifier == "" {
+ s.setAlert(c, "invalid identifier, must not be empty", "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+ return
+ }
+
+ user.Email = c.PostForm("mail")
+ if user.Email == "" {
+ s.setAlert(c, "invalid email, must not be empty", "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+ return
+ }
+
+ ipField := c.PostForm("ip")
+ ips := strings.Split(ipField, ",")
+ validatedIPs := make([]string, 0, len(ips))
+ for i := range ips {
+ ips[i] = strings.TrimSpace(ips[i])
+ if ips[i] != "" {
+ validatedIPs = append(validatedIPs, ips[i])
+ }
+ }
+ if len(validatedIPs) == 0 {
+ s.setAlert(c, "invalid ip address", "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+ return
+ }
+ user.IPs = validatedIPs
+
+ allowedIPField := c.PostForm("allowedip")
+ allowedIP := strings.Split(allowedIPField, ",")
+ validatedAllowedIP := make([]string, 0, len(allowedIP))
+ for i := range allowedIP {
+ allowedIP[i] = strings.TrimSpace(allowedIP[i])
+ if allowedIP[i] != "" {
+ validatedAllowedIP = append(validatedAllowedIP, allowedIP[i])
+ }
+ }
+ user.AllowedIPs = validatedAllowedIP
+
+ user.IgnorePersistentKeepalive = c.PostForm("ignorekeepalive") != ""
+ disabled := c.PostForm("isdisabled") != ""
+ now := time.Now()
+ if disabled && user.DeactivatedAt == nil {
+ user.DeactivatedAt = &now
+ } else if !disabled {
+ user.DeactivatedAt = nil
+ }
+
+ // Update WireGuard device
+ if user.DeactivatedAt == &now {
+ err = s.wg.RemovePeer(user.PublicKey)
+ if err != nil {
+ s.setAlert(c, "failed to remove peer in WireGuard: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+ return
+ }
+ } else if user.DeactivatedAt == nil && user.Peer != nil {
+ err = s.wg.UpdatePeer(user.GetPeerConfig())
+ if err != nil {
+ s.setAlert(c, "failed to update peer in WireGuard: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+ return
+ }
+ } else if user.DeactivatedAt == nil && user.Peer == nil {
+ err = s.wg.AddPeer(user.GetPeerConfig())
+ if err != nil {
+ s.setAlert(c, "failed to add peer in WireGuard: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+ return
+ }
+ }
+
+ // Update in database
+ err = s.users.UpdateUser(user)
+ if err != nil {
+ s.setAlert(c, "failed to update user in database: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+ return
+ }
+
+ s.setAlert(c, "changes applied successfully", "success")
+ c.Redirect(http.StatusSeeOther, "/admin/peer/edit?pkey="+urlEncodedKey)
+}
+
+func (s *Server) GetAdminCreatePeer(c *gin.Context) {
+ device := s.users.GetDevice()
+ user := s.users.GetUserByKey(c.Query("pkey"))
+
+ c.HTML(http.StatusOK, "admin_edit_client.html", struct {
+ Route string
+ Alerts AlertData
+ Session SessionData
+ Static StaticData
+ Peer User
+ Device Device
+ }{
+ Route: c.Request.URL.Path,
+ Alerts: s.getAlertData(c),
+ Session: s.getSessionData(c),
+ Static: s.getStaticData(),
+ Peer: user,
+ Device: device,
+ })
+}
+
+func (s *Server) PostAdminCreatePeer(c *gin.Context) {
+ device := s.users.GetDevice()
+ var err error
+
+ device.ListenPort, err = strconv.Atoi(c.PostForm("port"))
+ if err != nil {
+ s.setAlert(c, "invalid port: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ ipField := c.PostForm("ip")
+ ips := strings.Split(ipField, ",")
+ validatedIPs := make([]string, 0, len(ips))
+ for i := range ips {
+ ips[i] = strings.TrimSpace(ips[i])
+ if ips[i] != "" {
+ validatedIPs = append(validatedIPs, ips[i])
+ }
+ }
+ if len(validatedIPs) == 0 {
+ s.setAlert(c, "invalid ip address", "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+ device.IPs = validatedIPs
+
+ device.Endpoint = c.PostForm("endpoint")
+
+ dnsField := c.PostForm("dns")
+ dns := strings.Split(dnsField, ",")
+ validatedDNS := make([]string, 0, len(dns))
+ for i := range dns {
+ dns[i] = strings.TrimSpace(dns[i])
+ if dns[i] != "" {
+ validatedDNS = append(validatedDNS, dns[i])
+ }
+ }
+ device.DNS = validatedDNS
+
+ allowedIPField := c.PostForm("allowedip")
+ allowedIP := strings.Split(allowedIPField, ",")
+ validatedAllowedIP := make([]string, 0, len(allowedIP))
+ for i := range allowedIP {
+ allowedIP[i] = strings.TrimSpace(allowedIP[i])
+ if allowedIP[i] != "" {
+ validatedAllowedIP = append(validatedAllowedIP, allowedIP[i])
+ }
+ }
+ device.AllowedIPs = validatedAllowedIP
+
+ device.Mtu, err = strconv.Atoi(c.PostForm("mtu"))
+ if err != nil {
+ s.setAlert(c, "invalid MTU: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ device.PersistentKeepalive, err = strconv.Atoi(c.PostForm("keepalive"))
+ if err != nil {
+ s.setAlert(c, "invalid PersistentKeepalive: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ // Update WireGuard device
+ err = s.wg.UpdateDevice(device.DeviceName, device.GetDeviceConfig())
+ if err != nil {
+ s.setAlert(c, "failed to update device in WireGuard: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ // Update in database
+ err = s.users.UpdateDevice(device)
+ if err != nil {
+ s.setAlert(c, "failed to update device in database: "+err.Error(), "danger")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+ return
+ }
+
+ s.setAlert(c, "changes applied successfully", "success")
+ c.Redirect(http.StatusSeeOther, "/admin/device/edit")
+}
+
func (s *Server) GetUserQRCode(c *gin.Context) {
- user := s.users.GetUser(c.Param("pkey"))
+ user := s.users.GetUserByKey(c.Query("pkey"))
png, err := user.GetQRCode()
if err != nil {
s.HandleError(c, http.StatusInternalServerError, "QRCode error", err.Error())
diff --git a/internal/server/routes.go b/internal/server/routes.go
index 1fe28f3..9dc523e 100644
--- a/internal/server/routes.go
+++ b/internal/server/routes.go
@@ -20,6 +20,12 @@
admin := s.server.Group("/admin")
admin.Use(s.RequireAuthentication(s.config.AdminLdapGroup))
admin.GET("/", s.GetAdminIndex)
+ admin.GET("/device/edit", s.GetAdminEditInterface)
+ admin.POST("/device/edit", s.PostAdminEditInterface)
+ admin.GET("/peer/edit", s.GetAdminEditPeer)
+ admin.POST("/peer/edit", s.PostAdminEditPeer)
+ admin.GET("/peer/create", s.GetAdminCreatePeer)
+ admin.POST("/peer/create", s.PostAdminCreatePeer)
// User routes
user := s.server.Group("/user")
diff --git a/internal/server/usermanager.go b/internal/server/usermanager.go
index 7112c85..5308e6f 100644
--- a/internal/server/usermanager.go
+++ b/internal/server/usermanager.go
@@ -22,10 +22,14 @@
"gorm.io/gorm"
)
+//
+// USER ----------------------------------------------------------------------------------------
+//
+
type User struct {
- Peer wgtypes.Peer `gorm:"-"`
- User *ldap.UserCacheHolderEntry `gorm:"-"` // optional, it is still possible to have users without ldap
- Config string `gorm:"-"`
+ Peer *wgtypes.Peer `gorm:"-"`
+ LdapUser *ldap.UserCacheHolderEntry `gorm:"-"` // optional, it is still possible to have users without ldap
+ Config string `gorm:"-"`
UID string // uid for html identification
IsOnline bool `gorm:"-"`
@@ -48,6 +52,28 @@
UpdatedAt time.Time
}
+func (u User) GetClientConfigFile(device Device) ([]byte, error) {
+ tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl)
+ if err != nil {
+ return nil, err
+ }
+
+ var tplBuff bytes.Buffer
+
+ err = tpl.Execute(&tplBuff, struct {
+ Client User
+ Server Device
+ }{
+ Client: u,
+ Server: device,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return tplBuff.Bytes(), nil
+}
+
func (u User) GetPeerConfig() wgtypes.PeerConfig {
publicKey, _ := wgtypes.ParseKey(u.PublicKey)
var presharedKey *wgtypes.Key
@@ -87,6 +113,18 @@
return png, nil
}
+func (u User) IsValid() bool {
+ if u.PublicKey == "" {
+ return false
+ }
+
+ return true
+}
+
+//
+// DEVICE --------------------------------------------------------------------------------------
+//
+
type Device struct {
Interface *wgtypes.Device `gorm:"-"`
@@ -112,6 +150,9 @@
}
func (d Device) IsValid() bool {
+ if d.PublicKey == "" {
+ return false
+ }
if len(d.IPs) == 0 {
return false
}
@@ -122,12 +163,33 @@
return true
}
-type UserManager struct {
- db *gorm.DB
+func (d Device) GetDeviceConfig() wgtypes.Config {
+ var privateKey *wgtypes.Key
+ if d.PrivateKey != "" {
+ pKey, _ := wgtypes.ParseKey(d.PrivateKey)
+ privateKey = &pKey
+ }
+
+ cfg := wgtypes.Config{
+ PrivateKey: privateKey,
+ ListenPort: &d.ListenPort,
+ }
+
+ return cfg
}
-func NewUserManager() *UserManager {
- um := &UserManager{}
+//
+// USER-MANAGER --------------------------------------------------------------------------------
+//
+
+type UserManager struct {
+ db *gorm.DB
+ wg *wireguard.Manager
+ ldapUsers *ldap.SynchronizedUserCacheHolder
+}
+
+func NewUserManager(wg *wireguard.Manager, ldapUsers *ldap.SynchronizedUserCacheHolder) *UserManager {
+ um := &UserManager{wg: wg, ldapUsers: ldapUsers}
var err error
um.db, err = gorm.Open(sqlite.Open("wg_portal.db"), &gorm.Config{})
if err != nil {
@@ -144,52 +206,32 @@
return um
}
-func (u *UserManager) InitWithPeers(peers []wgtypes.Peer, err error) {
+func (u *UserManager) InitFromCurrentInterface() error {
+ peers, err := u.wg.GetPeerList()
if err != nil {
log.Errorf("failed to init user-manager from peers: %v", err)
- return
+ return err
}
- for _, peer := range peers {
- u.GetOrCreateUserForPeer(peer)
- }
-}
-
-func (u *UserManager) InitWithDevice(dev *wgtypes.Device, err error) {
+ device, err := u.wg.GetDeviceInfo()
if err != nil {
log.Errorf("failed to init user-manager from device: %v", err)
- return
- }
- u.GetOrCreateDevice(*dev)
-}
-
-func (u *UserManager) GetAllUsers() []User {
- users := make([]User, 0)
- u.db.Find(&users)
-
- for i := range users {
- users[i].AllowedIPs = strings.Split(users[i].AllowedIPsStr, ", ")
- users[i].IPs = strings.Split(users[i].IPsStr, ", ")
- tmpCfg, _ := u.GetPeerConfigFile(users[i])
- users[i].Config = string(tmpCfg)
+ return err
}
- return users
-}
-
-func (u *UserManager) GetDevice() Device {
- devices := make([]Device, 0, 1)
- u.db.Find(&devices)
-
- for i := range devices {
- devices[i].AllowedIPs = strings.Split(devices[i].AllowedIPsStr, ", ")
- devices[i].IPs = strings.Split(devices[i].IPsStr, ", ")
- devices[i].DNS = strings.Split(devices[i].DNSStr, ", ")
+ // Check if entries already exist in database, if not create them
+ for _, peer := range peers {
+ if err := u.validateOrCreateUserForPeer(peer); err != nil {
+ return err
+ }
+ }
+ if err := u.validateOrCreateDevice(*device); err != nil {
+ return err
}
- return devices[0]
+ return nil
}
-func (u *UserManager) GetOrCreateUserForPeer(peer wgtypes.Peer) User {
+func (u *UserManager) validateOrCreateUserForPeer(peer wgtypes.Peer) error {
user := User{}
u.db.Where("public_key = ?", peer.PublicKey.String()).FirstOrInit(&user)
@@ -215,25 +257,92 @@
res := u.db.Create(&user)
if res.Error != nil {
log.Errorf("failed to create autodetected peer: %v", res.Error)
+ return res.Error
}
}
- user.IPs = strings.Split(user.IPsStr, ", ")
+ return nil
+}
+
+func (u *UserManager) validateOrCreateDevice(dev wgtypes.Device) error {
+ device := Device{}
+ u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
+
+ if device.PublicKey == "" { // device not found, create
+ device.PublicKey = dev.PublicKey.String()
+ device.PrivateKey = dev.PrivateKey.String()
+ device.DeviceName = dev.Name
+ device.ListenPort = dev.ListenPort
+ device.Mtu = 0
+ device.PersistentKeepalive = 16 // Default
+
+ res := u.db.Create(&device)
+ if res.Error != nil {
+ log.Errorf("failed to create autodetected device: %v", res.Error)
+ return res.Error
+ }
+ }
+
+ return nil
+}
+
+func (u *UserManager) populateUserData(user *User) {
user.AllowedIPs = strings.Split(user.AllowedIPsStr, ", ")
- tmpCfg, _ := u.GetPeerConfigFile(user)
+ user.IPs = strings.Split(user.IPsStr, ", ")
+ // Set config file
+ tmpCfg, _ := user.GetClientConfigFile(u.GetDevice())
user.Config = string(tmpCfg)
+ // set data from WireGuard interface
+ user.Peer, _ = u.wg.GetPeer(user.PublicKey)
+ user.IsOnline = false // todo: calculate online status
+
+ // set ldap data
+ user.LdapUser = u.ldapUsers.GetUserData(u.ldapUsers.GetUserDNByMail(user.Email))
+}
+
+func (u *UserManager) populateDeviceData(device *Device) {
+ device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
+ device.IPs = strings.Split(device.IPsStr, ", ")
+ device.DNS = strings.Split(device.DNSStr, ", ")
+
+ // set data from WireGuard interface
+ device.Interface, _ = u.wg.GetDeviceInfo()
+}
+
+func (u *UserManager) GetAllUsers() []User {
+ users := make([]User, 0)
+ u.db.Find(&users)
+
+ for i := range users {
+ u.populateUserData(&users[i])
+ }
+
+ return users
+}
+
+func (u *UserManager) GetDevice() Device {
+ devices := make([]Device, 0, 1)
+ u.db.Find(&devices)
+
+ for i := range devices {
+ u.populateDeviceData(&devices[i])
+ }
+
+ return devices[0] // use first device for now... more to come?
+}
+
+func (u *UserManager) GetUserByKey(publicKey string) User {
+ user := User{}
+ u.db.Where("public_key = ?", publicKey).FirstOrInit(&user)
+ u.populateUserData(&user)
return user
}
-func (u *UserManager) GetUser(publicKey string) User {
+func (u *UserManager) GetUserByMail(mail string) User {
user := User{}
- u.db.Where("public_key = ?", publicKey).FirstOrInit(&user)
-
- user.IPs = strings.Split(user.IPsStr, ", ")
- user.AllowedIPs = strings.Split(user.AllowedIPsStr, ", ")
- tmpCfg, _ := u.GetPeerConfigFile(user)
- user.Config = string(tmpCfg)
+ u.db.Where("email = ?", mail).FirstOrInit(&user)
+ u.populateUserData(&user)
return user
}
@@ -268,6 +377,21 @@
return nil
}
+func (u *UserManager) UpdateDevice(device Device) error {
+ device.UpdatedAt = time.Now()
+ device.AllowedIPsStr = strings.Join(device.AllowedIPs, ", ")
+ device.IPsStr = strings.Join(device.IPs, ", ")
+ device.DNSStr = strings.Join(device.DNS, ", ")
+
+ res := u.db.Save(&device)
+ if res.Error != nil {
+ log.Errorf("failed to update device: %v", res.Error)
+ return res.Error
+ }
+
+ return nil
+}
+
func (u *UserManager) GetAllReservedIps() ([]string, error) {
reservedIps := make([]string, 0)
users := u.GetAllUsers()
@@ -328,50 +452,3 @@
return "", errors.New("no more available address from cidr")
}
-
-func (u *UserManager) GetOrCreateDevice(dev wgtypes.Device) Device {
- device := Device{}
- u.db.Where("device_name = ?", dev.Name).FirstOrInit(&device)
-
- if device.PublicKey == "" { // device not found, create
- device.PublicKey = dev.PublicKey.String()
- device.PrivateKey = dev.PrivateKey.String()
- device.DeviceName = dev.Name
- device.ListenPort = dev.ListenPort
- device.Mtu = 0
- device.PersistentKeepalive = 16 // Default
-
- res := u.db.Create(&device)
- if res.Error != nil {
- log.Errorf("failed to create autodetected device: %v", res.Error)
- }
- }
-
- device.IPs = strings.Split(device.IPsStr, ", ")
- device.AllowedIPs = strings.Split(device.AllowedIPsStr, ", ")
- device.DNS = strings.Split(device.DNSStr, ", ")
-
- return device
-}
-
-func (u *UserManager) GetPeerConfigFile(user User) ([]byte, error) {
- tpl, err := template.New("client").Funcs(template.FuncMap{"StringsJoin": strings.Join}).Parse(wireguard.ClientCfgTpl)
- if err != nil {
- return nil, err
- }
-
- var tplBuff bytes.Buffer
-
- err = tpl.Execute(&tplBuff, struct {
- Client User
- Server Device
- }{
- Client: user,
- Server: u.GetDevice(),
- })
- if err != nil {
- return nil, err
- }
-
- return tplBuff.Bytes(), nil
-}
diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go
index 0f22b3c..976b841 100644
--- a/internal/wireguard/manager.go
+++ b/internal/wireguard/manager.go
@@ -80,6 +80,19 @@
return nil
}
+func (m *Manager) UpdatePeer(cfg wgtypes.PeerConfig) error {
+ m.mux.Lock()
+ defer m.mux.Unlock()
+
+ cfg.UpdateOnly = true
+ err := m.wg.ConfigureDevice(m.Cfg.DeviceName, wgtypes.Config{Peers: []wgtypes.PeerConfig{cfg}})
+ if err != nil {
+ return fmt.Errorf("could not configure WireGuard device: %w", err)
+ }
+
+ return nil
+}
+
func (m *Manager) RemovePeer(pubKey string) error {
m.mux.Lock()
defer m.mux.Unlock()
@@ -101,3 +114,7 @@
return nil
}
+
+func (m *Manager) UpdateDevice(name string, cfg wgtypes.Config) error {
+ return m.wg.ConfigureDevice(name, cfg)
+}