diff --git a/internal/lowlevel/doc.go b/internal/lowlevel/doc.go new file mode 100644 index 0000000..057cf03 --- /dev/null +++ b/internal/lowlevel/doc.go @@ -0,0 +1,6 @@ +package lowlevel + +/** +This package contains wrappers for low level api's like netlink or the WireGuard control library. +Wrapping those external libraries makes mocking and testing code easier. +*/ diff --git a/internal/lowlevel/netlink.go b/internal/lowlevel/netlink.go new file mode 100644 index 0000000..b86d13e --- /dev/null +++ b/internal/lowlevel/netlink.go @@ -0,0 +1,63 @@ +package lowlevel + +import ( + "github.com/vishvananda/netlink" +) + +// A NetlinkClient is a type which can control a netlink device. +type NetlinkClient interface { + LinkAdd(link netlink.Link) error + LinkDel(link netlink.Link) error + LinkByName(name string) (netlink.Link, error) + LinkSetUp(link netlink.Link) error + LinkSetDown(link netlink.Link) error + LinkSetMTU(link netlink.Link, mtu int) error + AddrReplace(link netlink.Link, addr *netlink.Addr) error + AddrAdd(link netlink.Link, addr *netlink.Addr) error + AddrList(link netlink.Link) ([]netlink.Addr, error) +} + +type NetlinkManager struct { +} + +func (n NetlinkManager) LinkAdd(link netlink.Link) error { return netlink.LinkAdd(link) } + +func (n NetlinkManager) LinkDel(link netlink.Link) error { return netlink.LinkDel(link) } + +func (n NetlinkManager) LinkByName(name string) (netlink.Link, error) { + return netlink.LinkByName(name) +} + +func (n NetlinkManager) LinkSetUp(link netlink.Link) error { return netlink.LinkSetUp(link) } + +func (n NetlinkManager) LinkSetDown(link netlink.Link) error { return netlink.LinkSetDown(link) } + +func (n NetlinkManager) LinkSetMTU(link netlink.Link, mtu int) error { + return netlink.LinkSetMTU(link, mtu) +} + +func (n NetlinkManager) AddrReplace(link netlink.Link, addr *netlink.Addr) error { + return netlink.AddrReplace(link, addr) +} + +func (n NetlinkManager) AddrAdd(link netlink.Link, addr *netlink.Addr) error { + return netlink.AddrAdd(link, addr) +} + +func (n NetlinkManager) AddrList(link netlink.Link) ([]netlink.Addr, error) { + listIPv4, err := netlink.AddrList(link, netlink.FAMILY_V4) + if err != nil { + return nil, err + } + + listIPv6, err := netlink.AddrList(link, netlink.FAMILY_V6) + if err != nil { + return nil, err + } + + ipAddresses := make([]netlink.Addr, 0, len(listIPv4)+len(listIPv6)) + ipAddresses = append(ipAddresses, listIPv4...) + ipAddresses = append(ipAddresses, listIPv6...) + + return ipAddresses, nil +} diff --git a/internal/lowlevel/wgctrl.go b/internal/lowlevel/wgctrl.go new file mode 100644 index 0000000..ab6832c --- /dev/null +++ b/internal/lowlevel/wgctrl.go @@ -0,0 +1,15 @@ +package lowlevel + +import ( + "io" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// A WireGuardClient is a type which can control a WireGuard device. +type WireGuardClient interface { + io.Closer + Devices() ([]*wgtypes.Device, error) + Device(name string) (*wgtypes.Device, error) + ConfigureDevice(name string, cfg wgtypes.Config) error +} diff --git a/internal/lowlevel/wrappers.go b/internal/lowlevel/wrappers.go deleted file mode 100644 index 9b409ed..0000000 --- a/internal/lowlevel/wrappers.go +++ /dev/null @@ -1,74 +0,0 @@ -package lowlevel - -import ( - "io" - - "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -// A WireGuardClient is a type which can control a WireGuard device. -type WireGuardClient interface { - io.Closer - Devices() ([]*wgtypes.Device, error) - Device(name string) (*wgtypes.Device, error) - ConfigureDevice(name string, cfg wgtypes.Config) error -} - -// A NetlinkClient is a type which can control a netlink device. -type NetlinkClient interface { - LinkAdd(link netlink.Link) error - LinkDel(link netlink.Link) error - LinkByName(name string) (netlink.Link, error) - LinkSetUp(link netlink.Link) error - LinkSetDown(link netlink.Link) error - LinkSetMTU(link netlink.Link, mtu int) error - AddrReplace(link netlink.Link, addr *netlink.Addr) error - AddrAdd(link netlink.Link, addr *netlink.Addr) error - AddrList(link netlink.Link) ([]netlink.Addr, error) -} - -type NetlinkManager struct { -} - -func (n NetlinkManager) LinkAdd(link netlink.Link) error { return netlink.LinkAdd(link) } - -func (n NetlinkManager) LinkDel(link netlink.Link) error { return netlink.LinkDel(link) } - -func (n NetlinkManager) LinkByName(name string) (netlink.Link, error) { - return netlink.LinkByName(name) -} - -func (n NetlinkManager) LinkSetUp(link netlink.Link) error { return netlink.LinkSetUp(link) } - -func (n NetlinkManager) LinkSetDown(link netlink.Link) error { return netlink.LinkSetDown(link) } - -func (n NetlinkManager) LinkSetMTU(link netlink.Link, mtu int) error { - return netlink.LinkSetMTU(link, mtu) -} - -func (n NetlinkManager) AddrReplace(link netlink.Link, addr *netlink.Addr) error { - return netlink.AddrReplace(link, addr) -} - -func (n NetlinkManager) AddrAdd(link netlink.Link, addr *netlink.Addr) error { - return netlink.AddrAdd(link, addr) -} - -func (n NetlinkManager) AddrList(link netlink.Link) ([]netlink.Addr, error) { - listIPv4, err := netlink.AddrList(link, netlink.FAMILY_V4) - if err != nil { - return nil, err - } - - listIPv6, err := netlink.AddrList(link, netlink.FAMILY_V6) - if err != nil { - return nil, err - } - - ipAddresses := make([]netlink.Addr, 0, len(listIPv4)+len(listIPv6)) - ipAddresses = append(ipAddresses, listIPv4...) - ipAddresses = append(ipAddresses, listIPv6...) - - return ipAddresses, nil -} diff --git a/internal/persistence/database.go b/internal/persistence/database.go new file mode 100644 index 0000000..dc7cf83 --- /dev/null +++ b/internal/persistence/database.go @@ -0,0 +1 @@ +package persistence diff --git a/internal/persistence/ldap.go b/internal/persistence/ldap.go new file mode 100644 index 0000000..90c28b3 --- /dev/null +++ b/internal/persistence/ldap.go @@ -0,0 +1,4 @@ +package persistence + +type LdapLoader interface { +} diff --git a/internal/persistence/models.go b/internal/persistence/models.go new file mode 100644 index 0000000..acb58ce --- /dev/null +++ b/internal/persistence/models.go @@ -0,0 +1,158 @@ +package persistence + +import ( + "database/sql" + "time" + + "gorm.io/gorm" +) + +type BaseModel struct { + CreatedBy string + UpdatedBy string + CreatedAt time.Time + UpdatedAt time.Time + DisabledAt sql.NullTime +} + +type InterfaceIdentifier string +type PeerIdentifier string +type UserIdentifier string + +type KeyPair struct { + PrivateKey string + PublicKey string +} + +type PreSharedKey string + +type InterfaceType string + +const ( + InterfaceTypeServer InterfaceType = "server" + InterfaceTypeClient InterfaceType = "client" +) + +type InterfaceConfig struct { + BaseModel + + // WireGuard specific (for the [interface] section of the config file) + + Identifier InterfaceIdentifier // device name, for example: wg0 + KeyPair KeyPair // private/public Key of the server interface + ListenPort int // the listening port, for example: 51820 + + AddressStr string // the interface ip addresses, comma separated + DnsStr string // the dns server that should be set if the interface is up, comma separated + + Mtu int // the device MTU + FirewallMark int32 // a firewall mark + RoutingTable string // the routing table + + PreUp string // action that is executed before the device is up + PostUp string // action that is executed after the device is up + PreDown string // action that is executed before the device is down + PostDown string // action that is executed after the device is down + + SaveConfig bool // automatically persist config changes to the wgX.conf file + + // WG Portal specific + Enabled bool // flag that specifies if the interface is enabled (up) or nor (down) + DisplayName string // a nice display name/ description for the interface + Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient + DriverType string // the interface driver type (linux, software, ...) + + // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of + // the peer config + + PeerDefNetworkStr string // the default subnets from which peers will get their IP addresses, comma seperated + PeerDefDnsStr string // the default dns server for the peer + PeerDefEndpoint string // the default endpoint for the peer + PeerDefAllowedIPsStr string // the default allowed IP string for the peer + PeerDefMtu int // the default device MTU + PeerDefPersistentKeepalive int // the default persistent keep-alive Value + PeerDefFirewallMark int32 // default firewall mark + PeerDefRoutingTable string // the default routing table + + PeerDefPreUp string // default action that is executed before the device is up + PeerDefPostUp string // default action that is executed after the device is up + PeerDefPreDown string // default action that is executed before the device is down + PeerDefPostDown string // default action that is executed after the device is down +} + +type PeerInterfaceConfig struct { + Identifier InterfaceIdentifier // the interface identifier + AddressStr StringConfigOption // the interface ip addresses, comma separated + DnsStr StringConfigOption // the dns server that should be set if the interface is up, comma separated + Mtu IntConfigOption // the device MTU + FirewallMark Int32ConfigOption // a firewall mark + RoutingTable StringConfigOption // the routing table + + PreUp StringConfigOption // action that is executed before the device is up + PostUp StringConfigOption // action that is executed after the device is up + PreDown StringConfigOption // action that is executed before the device is down + PostDown StringConfigOption // action that is executed after the device is down +} + +type PeerConfig struct { + BaseModel + + // WireGuard specific (for the [peer] section of the config file) + + Endpoint StringConfigOption // the endpoint address + AllowedIPsStr StringConfigOption // all allowed ip subnets, comma seperated + ExtraAllowedIPsStr string // all allowed ip subnets on the server side, comma seperated + KeyPair KeyPair // private/public Key of the peer + PresharedKey string // the pre-shared Key of the peer + PersistentKeepalive IntConfigOption // the persistent keep-alive interval + + // WG Portal specific + + DisplayName string // a nice display name/ description for the peer + Identifier PeerIdentifier // peer unique identifier + UserIdentifier UserIdentifier // the owner + + // Interface settings for the peer, used to generate the [interface] section in the peer config file + PeerInterfaceConfig +} + +type UserSource string + +const ( + UserSourceLdap UserSource = "ldap" // LDAP / ActiveDirectory + UserSourceDatabase UserSource = "db" // sqlite / mysql database + UserSourceOIDC UserSource = "oidc" // open id connect, TODO: implement +) + +type PrivateString string + +func (PrivateString) MarshalJSON() ([]byte, error) { + return []byte(`""`), nil +} + +func (PrivateString) String() string { + return "" +} + +// User is the user model that gets linked to peer entries, by default an empty user model with only the email address is created +type User struct { + // required fields + Uid UserIdentifier `gorm:"primaryKey"` + Email string `form:"email" binding:"required,email"` + Source UserSource + IsAdmin bool + + // optional fields + Firstname string `form:"firstname" binding:"omitempty"` + Lastname string `form:"lastname" binding:"omitempty"` + Phone string `form:"phone" binding:"omitempty"` + Department string `form:"department" binding:"omitempty"` + + // optional, integrated password authentication + Password PrivateString `form:"password" binding:"omitempty"` + + // database internal fields + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index" json:",omitempty" swaggertype:"string"` +} diff --git a/internal/persistence/options.go b/internal/persistence/options.go new file mode 100644 index 0000000..683eae5 --- /dev/null +++ b/internal/persistence/options.go @@ -0,0 +1,81 @@ +package persistence + +// ConfigOption is an Overridable configuration option +type ConfigOption struct { + Value interface{} + Overridable bool +} + +type StringConfigOption struct { + ConfigOption +} + +func (o StringConfigOption) GetValue() string { + if o.Value == nil { + return "" + } + return o.Value.(string) +} + +func NewStringConfigOption(value string, overridable bool) StringConfigOption { + return StringConfigOption{ConfigOption{ + Value: value, + Overridable: overridable, + }} +} + +type IntConfigOption struct { + ConfigOption +} + +func (o IntConfigOption) GetValue() int { + if o.Value == nil { + return 0 + } + return o.Value.(int) +} + +func NewIntConfigOption(value int, overridable bool) IntConfigOption { + return IntConfigOption{ConfigOption{ + Value: value, + Overridable: overridable, + }} +} + +type Int32ConfigOption struct { + ConfigOption +} + +func (o Int32ConfigOption) GetValue() int32 { + if o.Value == nil { + return 0 + } + + return o.Value.(int32) +} + +func NewInt32ConfigOption(value int32, overridable bool) Int32ConfigOption { + return Int32ConfigOption{ConfigOption{ + Value: value, + Overridable: overridable, + }} +} + +type BoolConfigOption struct { + ConfigOption +} + +func (o BoolConfigOption) GetValue() bool { + if o.Value == nil { + return false + } + + return o.Value.(bool) +} + +func NewBoolConfigOption(value bool, overridable bool) BoolConfigOption { + return BoolConfigOption{ConfigOption{ + Value: value, + Overridable: overridable, + }} +} diff --git a/internal/persistence/users.go b/internal/persistence/users.go new file mode 100644 index 0000000..8f3d56a --- /dev/null +++ b/internal/persistence/users.go @@ -0,0 +1,19 @@ +package persistence + +import "gorm.io/gorm" + +type UserFilterCondition func(tx *gorm.DB) + +type UsersLoader interface { + GetUser(id UserIdentifier) (User, error) + GetUsers() ([]User, error) + GetUsersUnscoped() ([]User, error) + GetUsersFiltered(filter ...UserFilterCondition) ([]User, error) +} + +type Users interface { + UsersLoader + + SaveUser(user User) error + DeleteUser(identifier UserIdentifier) error +} diff --git a/internal/persistence/wireguard.go b/internal/persistence/wireguard.go new file mode 100644 index 0000000..9f320f1 --- /dev/null +++ b/internal/persistence/wireguard.go @@ -0,0 +1,14 @@ +package persistence + +type WireGuard interface { + GetAvailableInterfaces() ([]InterfaceIdentifier, error) + + GetAllInterfaces(interfaceIdentifiers ...InterfaceIdentifier) (map[InterfaceConfig][]PeerConfig, error) + GetInterface(identifier InterfaceIdentifier) (InterfaceConfig, []PeerConfig, error) + + SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error + SavePeer(peer PeerConfig, interfaceIdentifier InterfaceIdentifier) error + + DeleteInterface(identifier InterfaceIdentifier) error + DeletePeer(peer PeerIdentifier, interfaceIdentifier InterfaceIdentifier) error +} diff --git a/internal/portal/api.go b/internal/portal/api.go new file mode 100644 index 0000000..8d7996b --- /dev/null +++ b/internal/portal/api.go @@ -0,0 +1 @@ +package portal diff --git a/internal/portal/web.go b/internal/portal/web.go new file mode 100644 index 0000000..8d7996b --- /dev/null +++ b/internal/portal/web.go @@ -0,0 +1 @@ +package portal diff --git a/internal/user/authentication.go b/internal/user/authentication.go new file mode 100644 index 0000000..b200aa5 --- /dev/null +++ b/internal/user/authentication.go @@ -0,0 +1,4 @@ +package user + +type Authenticator interface { +} diff --git a/internal/user/manager.go b/internal/user/manager.go new file mode 100644 index 0000000..7ed2597 --- /dev/null +++ b/internal/user/manager.go @@ -0,0 +1,9 @@ +package user + +import ( + "github.com/h44z/wg-portal/internal/persistence" +) + +type Manager interface { + persistence.UsersLoader +} diff --git a/internal/wireguard/backend_db.go b/internal/wireguard/backend_db.go deleted file mode 100644 index c0eadeb..0000000 --- a/internal/wireguard/backend_db.go +++ /dev/null @@ -1,414 +0,0 @@ -package wireguard - -import ( - "database/sql" - "time" - - "gorm.io/gorm/clause" - - "github.com/pkg/errors" - "gorm.io/gorm" -) - -var DatabaseBackendName = "db" - -type DatabaseBackend struct { - db *gorm.DB -} - -func NewDatabaseBackend(db *gorm.DB) (*DatabaseBackend, error) { - backend := &DatabaseBackend{db: db} - - // Auto-Migrate Gorm models - err := db.AutoMigrate(&dbInterfaceConfig{}, &dbDefaultPeerConfig{}, &dbPeerConfig{}) - if err != nil { - return nil, errors.Wrap(err, "failed to migrate WireGuard database") - } - - return backend, nil -} - -func (d DatabaseBackend) Name() string { - return DatabaseBackendName -} - -func (d DatabaseBackend) SaveInterface(cfg InterfaceConfig, _ []PeerConfig) error { - iface, peerDefaults := convertInterface(cfg) - - if err := d.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&iface).Error; err != nil { - return errors.Wrapf(err, "failed to save interface %s to db", cfg.DeviceName) - } - if err := d.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&peerDefaults).Error; err != nil { - return errors.Wrapf(err, "failed to save peer defaults of %s to db", cfg.DeviceName) - } - - return nil -} - -func (d DatabaseBackend) SavePeer(cfg PeerConfig, iface InterfaceConfig) error { - peer := convertPeer(cfg, iface.DeviceName) - - if err := d.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&peer).Error; err != nil { - return errors.Wrapf(err, "failed to save peer %s to db", cfg.Uid) - } - - return nil -} - -func (d DatabaseBackend) DeleteInterface(cfg InterfaceConfig, _ []PeerConfig) error { - // Delete peers - if err := d.db.Where("device_name = ?", cfg.DeviceName).Delete(&dbPeerConfig{}).Error; err != nil { - return errors.Wrapf(err, "failed to delete peer for %s from db", cfg.DeviceName) - } - // Delete peer default settings - if err := d.db.Where("device_name = ?", cfg.DeviceName).Delete(&dbDefaultPeerConfig{}).Error; err != nil { - return errors.Wrapf(err, "failed to delete peer defaults for %s from db", cfg.DeviceName) - } - // Delete interface config - if err := d.db.Where("device_name = ?", cfg.DeviceName).Delete(&dbInterfaceConfig{}).Error; err != nil { - return errors.Wrapf(err, "failed to delete interface %s from db", cfg.DeviceName) - } - return nil -} - -func (d DatabaseBackend) DeletePeer(cfg PeerConfig, iface InterfaceConfig) error { - err := d.db.Where("device_name = ? AND uid = ?", iface.DeviceName, cfg.Uid).Delete(&dbPeerConfig{}).Error - if err != nil { - return errors.Wrapf(err, "failed to delete peer %s from db", cfg.Uid) - } - return nil -} - -func (d DatabaseBackend) Load(identifier DeviceIdentifier) (InterfaceConfig, []PeerConfig, error) { - var iface dbInterfaceConfig - var peerDefaults dbDefaultPeerConfig - var peers []dbPeerConfig - - if err := d.db.Where("device_name = ?", identifier).First(&iface).Error; err != nil { - return InterfaceConfig{}, nil, errors.Wrapf(err, "failed to load interface %s from db", identifier) - } - if err := d.db.Where("device_name = ?", identifier).First(&peerDefaults).Error; err != nil { - return InterfaceConfig{}, nil, errors.Wrapf(err, "failed to load peer defaults for %s from db", identifier) - } - if err := d.db.Where("device_name = ?", identifier).Find(&peers).Error; err != nil { - return InterfaceConfig{}, nil, errors.Wrapf(err, "failed to load peers for %s from db", identifier) - } - - interfaceConfig := InterfaceConfig{ - DeviceName: DeviceIdentifier(iface.DeviceName), - KeyPair: KeyPair{PrivateKey: iface.PrivateKey, PublicKey: iface.PublicKey}, - ListenPort: iface.ListenPort, - AddressStr: iface.AddressStr, - DnsStr: iface.DnsStr, - Mtu: iface.Mtu, - FirewallMark: int32(iface.FirewallMark), - RoutingTable: iface.RoutingTable, - PreUp: iface.PreUp, - PostUp: iface.PostUp, - PreDown: iface.PreDown, - PostDown: iface.PostDown, - SaveConfig: iface.SaveConfig, - Enabled: iface.Enabled, - DisplayName: iface.DisplayName, - Type: InterfaceType(iface.Type), - DriverType: iface.DriverType, - - PeerDefNetworkStr: peerDefaults.NetworkStr, - PeerDefDnsStr: peerDefaults.DnsStr, - PeerDefEndpoint: peerDefaults.Endpoint, - PeerDefAllowedIPsStr: peerDefaults.AllowedIPsStr, - PeerDefMtu: peerDefaults.Mtu, - PeerDefPersistentKeepalive: peerDefaults.PersistentKeepalive, - PeerDefFirewallMark: int32(peerDefaults.FirewallMark), - PeerDefRoutingTable: peerDefaults.RoutingTable, - PeerDefPreUp: peerDefaults.PreUp, - PeerDefPostUp: peerDefaults.PostUp, - PeerDefPreDown: peerDefaults.PreDown, - PeerDefPostDown: peerDefaults.PostDown, - - DisabledAt: nil, - BaseConfig: BaseConfig{ - CreatedAt: iface.CreatedAt, - UpdatedAt: iface.UpdatedAt, - CreatedBy: iface.CreatedBy, - UpdatedBy: iface.UpdatedBy, - }, - } - if iface.DisabledAt.Valid { - interfaceConfig.DisabledAt = &iface.DisabledAt.Time - } - - peerConfigs := make([]PeerConfig, len(peers)) - for i, peer := range peers { - peerConfigs[i] = PeerConfig{ - Endpoint: NewStringConfigOption(peer.Endpoint, peer.OvrEndpoint), - AllowedIPsStr: NewStringConfigOption(peer.AllowedIPsStr, peer.OvrAllowedIPsStr), - ExtraAllowedIPsStr: peer.ExtraAllowedIPsStr, - KeyPair: KeyPair{PrivateKey: peer.PrivateKey, PublicKey: peer.PublicKey}, - PresharedKey: peer.PresharedKey, - PersistentKeepalive: NewIntConfigOption(peer.PersistentKeepalive, peer.OvrPersistentKeepalive), - Identifier: peer.Identifier, - Uid: PeerIdentifier(peer.Uid), - AddressStr: NewStringConfigOption(peer.AddressStr, peer.OvrAddressStr), - DnsStr: NewStringConfigOption(peer.DnsStr, peer.OvrDnsStr), - Mtu: NewIntConfigOption(peer.Mtu, peer.OvrMtu), - FirewallMark: NewInt32ConfigOption(int32(peer.FirewallMark), peer.OvrFirewallMark), - RoutingTable: NewStringConfigOption(peer.RoutingTable, peer.OvrRoutingTable), - PreUp: NewStringConfigOption(peer.PreUp, peer.OvrPreUp), - PostUp: NewStringConfigOption(peer.PostUp, peer.OvrPostUp), - PreDown: NewStringConfigOption(peer.PreDown, peer.OvrPreDown), - PostDown: NewStringConfigOption(peer.PostDown, peer.OvrPostDown), - - DisabledAt: nil, - BaseConfig: BaseConfig{ - CreatedAt: iface.CreatedAt, - UpdatedAt: iface.UpdatedAt, - CreatedBy: iface.CreatedBy, - UpdatedBy: iface.UpdatedBy, - }, - } - - if peer.DisabledAt.Valid { - peerConfigs[i].DisabledAt = &peer.DisabledAt.Time - } - } - - return interfaceConfig, peerConfigs, nil -} - -func (d DatabaseBackend) LoadAll(interfaceIdentifiers ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { - result := make(map[InterfaceConfig][]PeerConfig) - for _, identifier := range interfaceIdentifiers { - iface, peers, err := d.Load(identifier) - if err != nil { - return nil, errors.Wrapf(err, "failed to load data for %s", identifier) - } - result[iface] = peers - } - - return result, nil -} - -func (d DatabaseBackend) GetAvailableInterfaces() ([]DeviceIdentifier, error) { - var iface []dbInterfaceConfig - if err := d.db.Find(&iface).Error; err != nil { - return nil, errors.Wrap(err, "failed to load interfaces from db") - } - - interfaces := make([]DeviceIdentifier, len(iface)) - for i := range iface { - interfaces[i] = DeviceIdentifier(iface[i].DeviceName) - } - - return interfaces, nil -} - -// -// --- Models -// - -type dbBaseModel struct { - CreatedBy string - UpdatedBy string - CreatedAt time.Time - UpdatedAt time.Time -} - -type dbInterfaceConfig struct { - dbBaseModel - DisabledAt sql.NullTime - - // WireGuard specific (for the [interface] section of the config file) - - DeviceName string `gorm:"primaryKey"` - PrivateKey string - PublicKey string - ListenPort int - - AddressStr string - DnsStr string - - Mtu int - FirewallMark int - RoutingTable string - - PreUp string - PostUp string - PreDown string - PostDown string - - SaveConfig bool - - // WG Portal specific - Enabled bool - DisplayName string - Type string - DriverType string - - // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of - // the peer config - - dbDefaultPeerConfig dbDefaultPeerConfig -} - -func (d dbInterfaceConfig) TableName() string { - return "interface" -} - -type dbDefaultPeerConfig struct { - dbBaseModel - - DeviceName string `gorm:"primaryKey"` // Foreign key - - NetworkStr string // the default subnets from which peers will get their IP addresses, comma seperated - DnsStr string // the default dns server for the peer - Endpoint string // the default endpoint for the peer - AllowedIPsStr string // the default allowed IP string for the peer - Mtu int // the default device MTU - PersistentKeepalive int // the default persistent keep-alive Value - FirewallMark int // default firewall mark - RoutingTable string // the default routing table - - PreUp string // default action that is executed before the device is up - PostUp string // default action that is executed after the device is up - PreDown string // default action that is executed before the device is down - PostDown string // default action that is executed after the device is down -} - -func (d dbDefaultPeerConfig) TableName() string { - return "peer_defaults" -} - -type dbPeerConfig struct { - dbBaseModel - DisabledAt sql.NullTime - - DeviceName string `gorm:"primaryKey"` - Endpoint string - OvrEndpoint bool - AllowedIPsStr string - OvrAllowedIPsStr bool - ExtraAllowedIPsStr string - PrivateKey string - PublicKey string - PresharedKey string - PersistentKeepalive int - OvrPersistentKeepalive bool - - // WG Portal specific - - Identifier string - Uid string `gorm:"primaryKey"` - - // Interface settings for the peer, used to generate the [interface] section in the peer config file - - AddressStr string - OvrAddressStr bool - DnsStr string - OvrDnsStr bool - Mtu int - OvrMtu bool - FirewallMark int - OvrFirewallMark bool - RoutingTable string - OvrRoutingTable bool - - PreUp string - OvrPreUp bool - PostUp string - OvrPostUp bool - PreDown string - OvrPreDown bool - PostDown string - OvrPostDown bool -} - -func (d dbPeerConfig) TableName() string { - return "peer" -} - -func convertPeer(peer PeerConfig, devName DeviceIdentifier) dbPeerConfig { - cfg := dbPeerConfig{ - DeviceName: string(devName), - Endpoint: peer.Endpoint.GetValue(), - OvrEndpoint: peer.Endpoint.Overridable, - AllowedIPsStr: peer.AllowedIPsStr.GetValue(), - OvrAllowedIPsStr: peer.AllowedIPsStr.Overridable, - ExtraAllowedIPsStr: peer.ExtraAllowedIPsStr, - PrivateKey: peer.KeyPair.PrivateKey, - PublicKey: peer.KeyPair.PublicKey, - PresharedKey: peer.PresharedKey, - PersistentKeepalive: peer.PersistentKeepalive.GetValue(), - OvrPersistentKeepalive: peer.PersistentKeepalive.Overridable, - Identifier: peer.Identifier, - Uid: string(peer.Uid), - AddressStr: peer.AddressStr.GetValue(), - OvrAddressStr: peer.AddressStr.Overridable, - DnsStr: peer.DnsStr.GetValue(), - OvrDnsStr: peer.DnsStr.Overridable, - Mtu: peer.Mtu.GetValue(), - OvrMtu: peer.Mtu.Overridable, - FirewallMark: int(peer.FirewallMark.GetValue()), - OvrFirewallMark: peer.FirewallMark.Overridable, - RoutingTable: peer.RoutingTable.GetValue(), - OvrRoutingTable: peer.RoutingTable.Overridable, - PreUp: peer.PreUp.GetValue(), - OvrPreUp: peer.PreUp.Overridable, - PostUp: peer.PostUp.GetValue(), - OvrPostUp: peer.PostUp.Overridable, - PreDown: peer.PreDown.GetValue(), - OvrPreDown: peer.PreDown.Overridable, - PostDown: peer.PostDown.GetValue(), - OvrPostDown: peer.PostDown.Overridable, - DisabledAt: sql.NullTime{Time: time.Time{}, Valid: peer.DisabledAt != nil}, - } - if peer.DisabledAt != nil { - cfg.DisabledAt.Time = *peer.DisabledAt - } - - return cfg -} - -func convertInterface(iface InterfaceConfig) (dbInterfaceConfig, dbDefaultPeerConfig) { - cfg := dbInterfaceConfig{ - DeviceName: string(iface.DeviceName), - PrivateKey: iface.KeyPair.PrivateKey, - PublicKey: iface.KeyPair.PublicKey, - ListenPort: iface.ListenPort, - AddressStr: iface.AddressStr, - DnsStr: iface.DnsStr, - Mtu: iface.Mtu, - FirewallMark: int(iface.FirewallMark), - RoutingTable: iface.RoutingTable, - PreUp: iface.PreUp, - PostUp: iface.PostUp, - PreDown: iface.PreDown, - PostDown: iface.PostDown, - SaveConfig: iface.SaveConfig, - Enabled: iface.Enabled, - DisplayName: iface.DisplayName, - Type: string(iface.Type), - DriverType: iface.DriverType, - DisabledAt: sql.NullTime{Time: time.Time{}, Valid: iface.DisabledAt != nil}, - } - if iface.DisabledAt != nil { - cfg.DisabledAt.Time = *iface.DisabledAt - } - peerDefaults := dbDefaultPeerConfig{ - DeviceName: string(iface.DeviceName), - NetworkStr: iface.PeerDefNetworkStr, - DnsStr: iface.PeerDefDnsStr, - Endpoint: iface.PeerDefEndpoint, - AllowedIPsStr: iface.PeerDefAllowedIPsStr, - Mtu: iface.PeerDefMtu, - PersistentKeepalive: iface.PeerDefPersistentKeepalive, - FirewallMark: int(iface.PeerDefFirewallMark), - RoutingTable: iface.PeerDefRoutingTable, - PreUp: iface.PeerDefPreUp, - PostUp: iface.PeerDefPostUp, - PreDown: iface.PeerDefPreDown, - PostDown: iface.PeerDefPostDown, - } - - return cfg, peerDefaults -} diff --git a/internal/wireguard/backend_db_test.go b/internal/wireguard/backend_db_test.go deleted file mode 100644 index b8b6bbc..0000000 --- a/internal/wireguard/backend_db_test.go +++ /dev/null @@ -1,422 +0,0 @@ -package wireguard - -import ( - "database/sql" - "database/sql/driver" - "reflect" - "testing" - "time" - - "github.com/pkg/errors" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/DATA-DOG/go-sqlmock" - "gorm.io/driver/mysql" - "gorm.io/gorm" -) - -type AnyTime struct{} - -// Match satisfies sqlmock.Argument interface -func (a AnyTime) Match(v driver.Value) bool { - _, ok := v.(time.Time) - return ok -} - -func getMockedGorm() (*gorm.DB, sqlmock.Sqlmock, error) { - // Default mock with regex matching (https://tienbm90.medium.com/unit-test-for-gorm-application-with-go-sqlmock-ecb5c369e570) - db, mock, err := sqlmock.New() - if err != nil { - return nil, nil, err - } - - gdb, err := gorm.Open(mysql.New(mysql.Config{ - Conn: db, - SkipInitializeWithVersion: true, - }), &gorm.Config{ - SkipDefaultTransaction: true, - }) // open gorm db - if err != nil { - return nil, nil, err - } - return gdb, mock, nil -} - -func TestDatabaseBackend_DeleteInterface(t *testing.T) { - db, mock, err := getMockedGorm() - require.NoError(t, err) - backend := &DatabaseBackend{db: db} - - type args struct { - iface InterfaceConfig - peers []PeerConfig - } - tests := []struct { - name string - mock func() - args args - wantErr bool - }{ - { - name: "Success", - mock: func() { - mock.ExpectExec("DELETE FROM `peer` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("DELETE FROM `peer_defaults` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("DELETE FROM `interface` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnResult(sqlmock.NewResult(1, 1)) - }, - args: args{ - iface: InterfaceConfig{DeviceName: "wg0"}, - peers: nil, - }, - wantErr: false, - }, - { - name: "Peer Delete Failure", - mock: func() { - mock.ExpectExec("DELETE FROM `peer` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnError(errors.New("peererr")) - }, - args: args{ - iface: InterfaceConfig{DeviceName: "wg0"}, - peers: nil, - }, - wantErr: true, - }, - { - name: "Peer Defaults Delete Failure", - mock: func() { - mock.ExpectExec("DELETE FROM `peer` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("DELETE FROM `peer_defaults` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnError(errors.New("defaultserr")) - }, - args: args{ - iface: InterfaceConfig{DeviceName: "wg0"}, - peers: nil, - }, - wantErr: true, - }, - { - name: "Interface Delete Failure", - mock: func() { - mock.ExpectExec("DELETE FROM `peer` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("DELETE FROM `peer_defaults` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("DELETE FROM `interface` WHERE device_name = \\?"). - WithArgs("wg0").WillReturnError(errors.New("ifaceerr")) - }, - args: args{ - iface: InterfaceConfig{DeviceName: "wg0"}, - peers: nil, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - if err := backend.DeleteInterface(tt.args.iface, tt.args.peers); (err != nil) != tt.wantErr { - t.Errorf("DeleteInterface() error = %v, wantErr %v", err, tt.wantErr) - } - assert.NoError(t, mock.ExpectationsWereMet()) - }) - } -} - -func TestDatabaseBackend_DeletePeer(t *testing.T) { - db, mock, err := getMockedGorm() - require.NoError(t, err) - backend := &DatabaseBackend{db: db} - - type args struct { - peer PeerConfig - iface InterfaceConfig - } - tests := []struct { - name string - mock func() - args args - wantErr bool - }{ - { - name: "Success", - mock: func() { - mock.ExpectExec("DELETE FROM `peer` WHERE device_name = \\? AND uid = \\?"). - WithArgs("wg0", "peer0").WillReturnResult(sqlmock.NewResult(1, 1)) - }, - args: args{ - peer: PeerConfig{Uid: "peer0"}, - iface: InterfaceConfig{DeviceName: "wg0"}, - }, - wantErr: false, - }, - { - name: "Peer Delete Failure", - mock: func() { - mock.ExpectExec("DELETE FROM `peer` WHERE device_name = \\? AND uid = \\?"). - WithArgs("wg0", "peer0").WillReturnError(errors.New("peererr")) - }, - args: args{ - peer: PeerConfig{Uid: "peer0"}, - iface: InterfaceConfig{DeviceName: "wg0"}, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - if err := backend.DeletePeer(tt.args.peer, tt.args.iface); (err != nil) != tt.wantErr { - t.Errorf("DeletePeer() error = %v, wantErr %v", err, tt.wantErr) - } - assert.NoError(t, mock.ExpectationsWereMet()) - }) - } -} - -func TestDatabaseBackend_Load(t *testing.T) { - type fields struct { - db *gorm.DB - } - type args struct { - identifier DeviceIdentifier - } - tests := []struct { - name string - fields fields - args args - want InterfaceConfig - want1 []PeerConfig - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := DatabaseBackend{ - db: tt.fields.db, - } - got, got1, err := d.Load(tt.args.identifier) - if (err != nil) != tt.wantErr { - t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Load() got = %v, want %v", got, tt.want) - } - if !reflect.DeepEqual(got1, tt.want1) { - t.Errorf("Load() got1 = %v, want %v", got1, tt.want1) - } - }) - } -} - -func TestDatabaseBackend_LoadAll(t *testing.T) { - type fields struct { - db *gorm.DB - } - type args struct { - ignored []DeviceIdentifier - } - tests := []struct { - name string - fields fields - args args - want map[InterfaceConfig][]PeerConfig - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := DatabaseBackend{ - db: tt.fields.db, - } - got, err := d.LoadAll(tt.args.ignored...) - if (err != nil) != tt.wantErr { - t.Errorf("LoadAll() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("LoadAll() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestDatabaseBackend_SaveInterface(t *testing.T) { - db, mock, err := getMockedGorm() - require.NoError(t, err) - backend := &DatabaseBackend{db: db} - - type args struct { - cfg InterfaceConfig - peers []PeerConfig - } - tests := []struct { - name string - mock func() - args args - wantErr bool - }{ - { - name: "Success Create", - mock: func() { - mock.ExpectExec("INSERT INTO `interface` .*"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("INSERT INTO `peer_defaults` .*"). - WillReturnResult(sqlmock.NewResult(1, 1)) - }, - args: args{ - cfg: InterfaceConfig{DeviceName: "wg0"}, - peers: nil, - }, - wantErr: false, - }, - { - name: "Error Interface", - mock: func() { - mock.ExpectExec("INSERT INTO `interface` .*"). - WillReturnError(errors.New("ifaceerr")) - }, - args: args{ - cfg: InterfaceConfig{DeviceName: "wg0"}, - peers: nil, - }, - wantErr: true, - }, - { - name: "Error Peer Defaults", - mock: func() { - mock.ExpectExec("INSERT INTO `interface` .*"). - WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("INSERT INTO `peer_defaults` .*"). - WillReturnError(errors.New("ifaceerr")) - }, - args: args{ - cfg: InterfaceConfig{DeviceName: "wg0"}, - peers: nil, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - if err := backend.SaveInterface(tt.args.cfg, tt.args.peers); (err != nil) != tt.wantErr { - t.Errorf("SaveInterface() error = %v, wantErr %v", err, tt.wantErr) - } - assert.NoError(t, mock.ExpectationsWereMet()) - }) - } -} - -func TestDatabaseBackend_SavePeer(t *testing.T) { - db, mock, err := getMockedGorm() - require.NoError(t, err) - backend := &DatabaseBackend{db: db} - - type args struct { - peer PeerConfig - iface InterfaceConfig - } - tests := []struct { - name string - mock func() - args args - wantErr bool - }{ - { - name: "Success Create", - mock: func() { - mock.ExpectExec("INSERT INTO `peer` .*"). - WillReturnResult(sqlmock.NewResult(1, 1)) - }, - args: args{ - peer: PeerConfig{Uid: "peer0"}, - iface: InterfaceConfig{DeviceName: "wg0"}, - }, - wantErr: false, - }, - { - name: "Error", - mock: func() { - mock.ExpectExec("INSERT INTO `peer` .*"). - WillReturnError(errors.New("peererr")) - }, - args: args{ - peer: PeerConfig{Uid: "peer0"}, - iface: InterfaceConfig{DeviceName: "wg0"}, - }, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.mock() - if err := backend.SavePeer(tt.args.peer, tt.args.iface); (err != nil) != tt.wantErr { - t.Errorf("SavePeer() error = %v, wantErr %v", err, tt.wantErr) - } - assert.NoError(t, mock.ExpectationsWereMet()) - }) - } -} - -func TestNewDatabaseBackend(t *testing.T) { - db, mock, err := getMockedGorm() - require.NoError(t, err) - - // Success - mock.ExpectExec("CREATE TABLE `interface` .*").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("CREATE TABLE `peer_defaults` .*").WillReturnResult(sqlmock.NewResult(1, 1)) - mock.ExpectExec("CREATE TABLE `peer` .*").WillReturnResult(sqlmock.NewResult(1, 1)) - backend, err := NewDatabaseBackend(db) - assert.NoError(t, err) - assert.NotNil(t, backend) - assert.NoError(t, mock.ExpectationsWereMet()) - - // Migration failure - mock.ExpectExec("CREATE TABLE `interface` .*").WillReturnError(errors.New("migerr")) - backend, err = NewDatabaseBackend(db) - assert.Error(t, err) - assert.Nil(t, backend) - assert.NoError(t, mock.ExpectationsWereMet()) -} - -func Test_convertInterface(t *testing.T) { - config, peerDefaultConfig := convertInterface(InterfaceConfig{}) - assert.Equal(t, dbInterfaceConfig{}, config) - assert.Equal(t, dbDefaultPeerConfig{}, peerDefaultConfig) - - now := time.Now() - config, peerDefaultConfig = convertInterface(InterfaceConfig{DisabledAt: &now}) - assert.Equal(t, dbInterfaceConfig{DisabledAt: sql.NullTime{Time: now, Valid: true}}, config) - assert.Equal(t, dbDefaultPeerConfig{}, peerDefaultConfig) -} - -func Test_convertPeer(t *testing.T) { - peer := convertPeer(PeerConfig{}, "wg0") - assert.Equal(t, dbPeerConfig{DeviceName: "wg0"}, peer) - - now := time.Now() - peer = convertPeer(PeerConfig{DisabledAt: &now}, "wg0") - assert.Equal(t, dbPeerConfig{DeviceName: "wg0", DisabledAt: sql.NullTime{Time: now, Valid: true}}, peer) -} - -func Test_dbDefaultPeerConfig_TableName(t *testing.T) { - assert.Equal(t, "peer_defaults", dbDefaultPeerConfig{}.TableName()) -} - -func Test_dbInterfaceConfig_TableName(t *testing.T) { - assert.Equal(t, "interface", dbInterfaceConfig{}.TableName()) -} - -func Test_dbPeerConfig_TableName(t *testing.T) { - assert.Equal(t, "peer", dbPeerConfig{}.TableName()) -} diff --git a/internal/wireguard/backend_file.go b/internal/wireguard/backend_file.go deleted file mode 100644 index 5726415..0000000 --- a/internal/wireguard/backend_file.go +++ /dev/null @@ -1,62 +0,0 @@ -package wireguard - -import ( - "io" - "os" - "path/filepath" - - "github.com/pkg/errors" -) - -type FileBackend struct { - configurationPath string - fileGenerator ConfigFileGenerator -} - -func NewFileBackend(configStoragePath string, fileGenerator ConfigFileGenerator) (*FileBackend, error) { - backend := &FileBackend{configurationPath: configStoragePath, fileGenerator: fileGenerator} - return backend, nil -} - -func (f FileBackend) Name() string { - return "file" -} - -func (f FileBackend) SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error { - configContents, err := f.fileGenerator.GetInterfaceConfig(cfg, peers) - if err != nil { - return errors.Wrapf(err, "failed to generate config file contents for %s", cfg.DeviceName) - } - - configFilePath := filepath.Join(f.configurationPath, string(cfg.DeviceName)+".conf") - configFile, err := os.OpenFile(configFilePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0640) - if err != nil { - return errors.Wrapf(err, "failed to create config file for %s", cfg.DeviceName) - } - defer configFile.Close() - - _, err = io.Copy(configFile, configContents) - if err != nil { - return errors.Wrapf(err, "failed to write config file for %s", cfg.DeviceName) - } - - return nil -} - -func (f FileBackend) SavePeer(_ PeerConfig, _ InterfaceConfig) error { - return nil // the file backend will only store changed interfaces -} - -func (f FileBackend) DeleteInterface(cfg InterfaceConfig, _ []PeerConfig) error { - configFilePath := filepath.Join(f.configurationPath, string(cfg.DeviceName)+".conf") - - err := os.Remove(configFilePath) - if err != nil { - return errors.Wrapf(err, "failed to delete config file for %s", cfg.DeviceName) - } - return nil -} - -func (f FileBackend) DeletePeer(_ PeerConfig, _ InterfaceConfig) error { - return nil // the file backend will only store changed interfaces -} diff --git a/internal/wireguard/backend_file_test.go b/internal/wireguard/backend_file_test.go deleted file mode 100644 index 4fab92c..0000000 --- a/internal/wireguard/backend_file_test.go +++ /dev/null @@ -1,208 +0,0 @@ -package wireguard - -import ( - "bytes" - "io" - "io/ioutil" - "os" - "path/filepath" - "reflect" - "strings" - "testing" - - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -type MockFileGenerator struct { - mock.Mock -} - -func (m *MockFileGenerator) GetInterfaceConfig(cfg InterfaceConfig, peers []PeerConfig) (io.Reader, error) { - args := m.Called(cfg, peers) - return args.Get(0).(io.Reader), args.Error(1) -} - -func (m *MockFileGenerator) GetPeerConfig(peer PeerConfig, iface InterfaceConfig) (io.Reader, error) { - args := m.Called(peer, iface) - return args.Get(0).(io.Reader), args.Error(1) -} - -func TestFileBackend_DeleteInterface(t *testing.T) { - // setup - tmpDir := os.TempDir() - tmpFile, err := ioutil.TempFile(tmpDir, "wg*.conf") - require.NoError(t, err) - defer os.Remove(tmpFile.Name()) - - f := FileBackend{ - configurationPath: tmpDir, - } - - // Successful delete - err = f.DeleteInterface(InterfaceConfig{ - DeviceName: DeviceIdentifier(strings.ReplaceAll(filepath.Base(tmpFile.Name()), ".conf", "")), - }, nil) - assert.NoError(t, err) - - // Unsuccessful delete - err = f.DeleteInterface(InterfaceConfig{ - DeviceName: DeviceIdentifier(strings.ReplaceAll(filepath.Base(tmpFile.Name()), ".conf", "")), - }, nil) - assert.Error(t, err) -} - -func TestFileBackend_DeletePeer(t *testing.T) { - assert.NoError(t, FileBackend{}.DeletePeer(PeerConfig{}, InterfaceConfig{})) -} - -func TestFileBackend_Load(t *testing.T) { - type fields struct { - configurationPath string - fileGenerator ConfigFileGenerator - } - type args struct { - identifier DeviceIdentifier - } - tests := []struct { - name string - fields fields - args args - want InterfaceConfig - want1 []PeerConfig - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - f := FileBackend{ - configurationPath: tt.fields.configurationPath, - fileGenerator: tt.fields.fileGenerator, - } - got, got1, err := f.Load(tt.args.identifier) - if (err != nil) != tt.wantErr { - t.Errorf("Load() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Load() got = %v, want %v", got, tt.want) - } - if !reflect.DeepEqual(got1, tt.want1) { - t.Errorf("Load() got1 = %v, want %v", got1, tt.want1) - } - }) - } -} - -func TestFileBackend_LoadAll(t *testing.T) { - type fields struct { - configurationPath string - fileGenerator ConfigFileGenerator - } - type args struct { - ignored []DeviceIdentifier - } - tests := []struct { - name string - fields fields - args args - want map[InterfaceConfig][]PeerConfig - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - f := FileBackend{ - configurationPath: tt.fields.configurationPath, - fileGenerator: tt.fields.fileGenerator, - } - got, err := f.LoadAll(tt.args.ignored...) - if (err != nil) != tt.wantErr { - t.Errorf("LoadAll() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("LoadAll() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestFileBackend_SaveInterface(t *testing.T) { - // setup - tmpDir := os.TempDir() - tmpFile, err := ioutil.TempFile(tmpDir, "wg*.conf") - require.NoError(t, err) - defer os.Remove(tmpFile.Name()) - deviceName := strings.ReplaceAll(filepath.Base(tmpFile.Name()), ".conf", "") - - type fields struct { - prepare func(m *mock.Mock) - } - type args struct { - cfg InterfaceConfig - peers []PeerConfig - } - tests := []struct { - name string - fields fields - args args - wantErr bool - }{ - { - name: "FileGeneratorError", - fields: fields{ - prepare: func(m *mock.Mock) { - m.On("GetInterfaceConfig", mock.Anything, mock.Anything). - Return(&bytes.Buffer{}, errors.New("generr")) - }, - }, - args: args{}, - wantErr: true, - }, - { - name: "Success", - fields: fields{ - prepare: func(m *mock.Mock) { - m.On("GetInterfaceConfig", mock.Anything, mock.Anything). - Return(bytes.NewBuffer([]byte("hello world")), nil) - }, - }, - args: args{ - cfg: InterfaceConfig{DeviceName: DeviceIdentifier(deviceName)}, - peers: nil, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - fg := new(MockFileGenerator) - f := FileBackend{ - configurationPath: tmpDir, - fileGenerator: fg, - } - tt.fields.prepare(&fg.Mock) - if err := f.SaveInterface(tt.args.cfg, tt.args.peers); (err != nil) != tt.wantErr { - t.Errorf("SaveInterface() error = %v, wantErr %v", err, tt.wantErr) - } - - fg.AssertExpectations(t) - }) - } -} - -func TestFileBackend_SavePeer(t *testing.T) { - assert.NoError(t, FileBackend{}.SavePeer(PeerConfig{}, InterfaceConfig{})) -} - -func TestNewFileBackend(t *testing.T) { - got, err := NewFileBackend("testing", nil) - assert.NoError(t, err) - assert.NotNil(t, got) -} diff --git a/internal/wireguard/configuration.go b/internal/wireguard/configuration.go deleted file mode 100644 index ed8e6fa..0000000 --- a/internal/wireguard/configuration.go +++ /dev/null @@ -1,209 +0,0 @@ -package wireguard - -import ( - "time" -) - -// ConfigOption is an Overridable configuration option -type ConfigOption struct { - Value interface{} - Overridable bool -} - -type StringConfigOption struct { - ConfigOption -} - -func (o StringConfigOption) GetValue() string { - if o.Value == nil { - return "" - } - return o.Value.(string) -} - -func NewStringConfigOption(value string, overridable bool) StringConfigOption { - return StringConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} - -type IntConfigOption struct { - ConfigOption -} - -func (o IntConfigOption) GetValue() int { - if o.Value == nil { - return 0 - } - return o.Value.(int) -} - -func NewIntConfigOption(value int, overridable bool) IntConfigOption { - return IntConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} - -type Int32ConfigOption struct { - ConfigOption -} - -func (o Int32ConfigOption) GetValue() int32 { - if o.Value == nil { - return 0 - } - - return o.Value.(int32) -} - -func NewInt32ConfigOption(value int32, overridable bool) Int32ConfigOption { - return Int32ConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} - -type BoolConfigOption struct { - ConfigOption -} - -func (o BoolConfigOption) GetValue() bool { - if o.Value == nil { - return false - } - - return o.Value.(bool) -} - -func NewBoolConfigOption(value bool, overridable bool) BoolConfigOption { - return BoolConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} - -type InterfaceType string - -const ( - InterfaceTypeServer InterfaceType = "server" - InterfaceTypeClient InterfaceType = "client" -) - -type DeviceIdentifier string -type PeerIdentifier string - -type BaseConfig struct { - CreatedBy string - UpdatedBy string - CreatedAt time.Time - UpdatedAt time.Time -} - -type InterfaceConfig struct { - BaseConfig - - // WireGuard specific (for the [interface] section of the config file) - - DeviceName DeviceIdentifier // device name, for example: wg0 - KeyPair KeyPair // private/public Key of the server interface - ListenPort int // the listening port, for example: 51820 - - AddressStr string // the interface ip addresses, comma separated - DnsStr string // the dns server that should be set if the interface is up, comma separated - - Mtu int // the device MTU - FirewallMark int32 // a firewall mark - RoutingTable string // the routing table - - PreUp string // action that is executed before the device is up - PostUp string // action that is executed after the device is up - PreDown string // action that is executed before the device is down - PostDown string // action that is executed after the device is down - - SaveConfig bool // automatically persist config changes to the wgX.conf file - - // WG Portal specific - Enabled bool // flag that specifies if the interface is enabled (up) or nor (down) - DisplayName string // a nice display name/ description for the interface - Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient - DriverType string // the interface driver type (linux, software, ...) - - // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of - // the peer config - - PeerDefNetworkStr string // the default subnets from which peers will get their IP addresses, comma seperated - PeerDefDnsStr string // the default dns server for the peer - PeerDefEndpoint string // the default endpoint for the peer - PeerDefAllowedIPsStr string // the default allowed IP string for the peer - PeerDefMtu int // the default device MTU - PeerDefPersistentKeepalive int // the default persistent keep-alive Value - PeerDefFirewallMark int32 // default firewall mark - PeerDefRoutingTable string // the default routing table - - PeerDefPreUp string // default action that is executed before the device is up - PeerDefPostUp string // default action that is executed after the device is up - PeerDefPreDown string // default action that is executed before the device is down - PeerDefPostDown string // default action that is executed after the device is down - - // Internal stats - - DisabledAt *time.Time -} - -type PeerConfig struct { - BaseConfig - - // WireGuard specific (for the [peer] section of the config file) - - Endpoint StringConfigOption // the endpoint address - AllowedIPsStr StringConfigOption // all allowed ip subnets, comma seperated - ExtraAllowedIPsStr string // all allowed ip subnets on the server side, comma seperated - KeyPair KeyPair // private/public Key of the peer - PresharedKey string // the pre-shared Key of the peer - PersistentKeepalive IntConfigOption // the persistent keep-alive interval - - // WG Portal specific - - Identifier string // a nice display name/ description for the peer - Uid PeerIdentifier // peer unique identifier - - // Interface settings for the peer, used to generate the [interface] section in the peer config file - - AddressStr StringConfigOption // the interface ip addresses, comma separated - DnsStr StringConfigOption // the dns server that should be set if the interface is up, comma separated - Mtu IntConfigOption // the device MTU - FirewallMark Int32ConfigOption // a firewall mark - RoutingTable StringConfigOption // the routing table - - PreUp StringConfigOption // action that is executed before the device is up - PostUp StringConfigOption // action that is executed after the device is up - PreDown StringConfigOption // action that is executed before the device is down - PostDown StringConfigOption // action that is executed after the device is down - - // Internal stats - - DisabledAt *time.Time -} - -type Name interface { - Name() string -} - -// ConfigWriter provides methods for updating persistent backends (like a database or a WireGuard configuration file) -type ConfigWriter interface { - Name - SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error - SavePeer(peer PeerConfig, cfg InterfaceConfig) error - DeleteInterface(cfg InterfaceConfig, peers []PeerConfig) error - DeletePeer(peer PeerConfig, cfg InterfaceConfig) error -} - -// ConfigLoader provides methods to load interface and peer configurations from a persistent backend. -type ConfigLoader interface { - Name - Load(identifier DeviceIdentifier) (InterfaceConfig, []PeerConfig, error) - LoadAll(interfaceIdentifiers ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) - GetAvailableInterfaces() ([]DeviceIdentifier, error) -} diff --git a/internal/wireguard/configuration_test.go b/internal/wireguard/configuration_test.go deleted file mode 100644 index 97ce95d..0000000 --- a/internal/wireguard/configuration_test.go +++ /dev/null @@ -1,259 +0,0 @@ -package wireguard - -import ( - "reflect" - "testing" -) - -func TestBoolConfigOption_GetValue(t *testing.T) { - type fields struct { - ConfigOption ConfigOption - } - tests := []struct { - name string - fields fields - want bool - }{ - { - name: "Empty", - fields: fields{}, - want: false, - }, - { - name: "True", - fields: fields{ConfigOption: ConfigOption{Value: true}}, - want: true, - }, - { - name: "False", - fields: fields{ConfigOption: ConfigOption{Value: false}}, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - o := BoolConfigOption{ - ConfigOption: tt.fields.ConfigOption, - } - if got := o.GetValue(); got != tt.want { - t.Errorf("GetValue() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestInt32ConfigOption_GetValue(t *testing.T) { - type fields struct { - ConfigOption ConfigOption - } - tests := []struct { - name string - fields fields - want int32 - }{ - { - name: "Empty", - fields: fields{}, - want: 0, - }, - { - name: "Leet", - fields: fields{ConfigOption: ConfigOption{Value: int32(1337)}}, - want: 1337, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - o := Int32ConfigOption{ - ConfigOption: tt.fields.ConfigOption, - } - if got := o.GetValue(); got != tt.want { - t.Errorf("GetValue() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestIntConfigOption_GetValue(t *testing.T) { - type fields struct { - ConfigOption ConfigOption - } - tests := []struct { - name string - fields fields - want int - }{ - { - name: "Empty", - fields: fields{}, - want: 0, - }, - { - name: "Leet", - fields: fields{ConfigOption: ConfigOption{Value: 1337}}, - want: 1337, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - o := IntConfigOption{ - ConfigOption: tt.fields.ConfigOption, - } - if got := o.GetValue(); got != tt.want { - t.Errorf("GetValue() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestStringConfigOption_GetValue(t *testing.T) { - type fields struct { - ConfigOption ConfigOption - } - tests := []struct { - name string - fields fields - want string - }{ - { - name: "Empty", - fields: fields{}, - want: "", - }, - { - name: "Leet", - fields: fields{ConfigOption: ConfigOption{Value: "leet"}}, - want: "leet", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - o := StringConfigOption{ - ConfigOption: tt.fields.ConfigOption, - } - if got := o.GetValue(); got != tt.want { - t.Errorf("GetValue() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestNewBoolConfigOption(t *testing.T) { - type args struct { - value bool - overridable bool - } - tests := []struct { - name string - args args - want BoolConfigOption - }{ - { - name: "Overridable", - args: args{value: false, overridable: true}, - want: BoolConfigOption{ConfigOption: ConfigOption{Value: false, Overridable: true}}, - }, - { - name: "Not Overridable", - args: args{value: true, overridable: false}, - want: BoolConfigOption{ConfigOption: ConfigOption{Value: true, Overridable: false}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NewBoolConfigOption(tt.args.value, tt.args.overridable); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewBoolConfigOption() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestNewInt32ConfigOption(t *testing.T) { - type args struct { - value int32 - overridable bool - } - tests := []struct { - name string - args args - want Int32ConfigOption - }{ - { - name: "Overridable", - args: args{value: 1337, overridable: true}, - want: Int32ConfigOption{ConfigOption: ConfigOption{Value: int32(1337), Overridable: true}}, - }, - { - name: "Not Overridable", - args: args{value: 1337, overridable: false}, - want: Int32ConfigOption{ConfigOption: ConfigOption{Value: int32(1337), Overridable: false}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NewInt32ConfigOption(tt.args.value, tt.args.overridable); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewInt32ConfigOption() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestNewIntConfigOption(t *testing.T) { - type args struct { - value int - overridable bool - } - tests := []struct { - name string - args args - want IntConfigOption - }{ - { - name: "Overridable", - args: args{value: 1337, overridable: true}, - want: IntConfigOption{ConfigOption: ConfigOption{Value: 1337, Overridable: true}}, - }, - { - name: "Not Overridable", - args: args{value: 1337, overridable: false}, - want: IntConfigOption{ConfigOption: ConfigOption{Value: 1337, Overridable: false}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NewIntConfigOption(tt.args.value, tt.args.overridable); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewIntConfigOption() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestNewStringConfigOption(t *testing.T) { - type args struct { - value string - overridable bool - } - tests := []struct { - name string - args args - want StringConfigOption - }{ - { - name: "Overridable", - args: args{value: "leet", overridable: true}, - want: StringConfigOption{ConfigOption: ConfigOption{Value: "leet", Overridable: true}}, - }, - { - name: "Not Overridable", - args: args{value: "leet", overridable: false}, - want: StringConfigOption{ConfigOption: ConfigOption{Value: "leet", Overridable: false}}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := NewStringConfigOption(tt.args.value, tt.args.overridable); !reflect.DeepEqual(got, tt.want) { - t.Errorf("NewStringConfigOption() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/wireguard/keys.go b/internal/wireguard/keys.go index caa8bb8..657a7b1 100644 --- a/internal/wireguard/keys.go +++ b/internal/wireguard/keys.go @@ -1,20 +1,20 @@ package wireguard -import "encoding/base64" +import ( + "encoding/base64" -type KeyPair struct { - PrivateKey string - PublicKey string -} + "github.com/h44z/wg-portal/internal/persistence" -type PreSharedKey string + "github.com/pkg/errors" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) -func (p KeyPair) GetPrivateKeyBytes() []byte { +func GetPrivateKeyBytes(p persistence.KeyPair) []byte { data, _ := base64.StdEncoding.DecodeString(p.PrivateKey) return data } -func (p KeyPair) GetPublicKeyBytes() []byte { +func GetPublicKeyBytes(p persistence.KeyPair) []byte { data, _ := base64.StdEncoding.DecodeString(p.PublicKey) return data } @@ -22,3 +22,26 @@ func KeyBytesToString(key []byte) string { return base64.StdEncoding.EncodeToString(key) } + +type WgCtrlKeyGenerator struct{} + +func (k WgCtrlKeyGenerator) GetFreshKeypair() (persistence.KeyPair, error) { + privateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return persistence.KeyPair{}, errors.Wrap(err, "failed to generate private Key") + } + + return persistence.KeyPair{ + PrivateKey: privateKey.String(), + PublicKey: privateKey.PublicKey().String(), + }, nil +} + +func (k WgCtrlKeyGenerator) GetPreSharedKey() (persistence.PreSharedKey, error) { + preSharedKey, err := wgtypes.GenerateKey() + if err != nil { + return "", errors.Wrap(err, "failed to generate pre-shared Key") + } + + return persistence.PreSharedKey(preSharedKey.String()), nil +} diff --git a/internal/wireguard/keys_test.go b/internal/wireguard/keys_test.go index 3275781..e8936a0 100644 --- a/internal/wireguard/keys_test.go +++ b/internal/wireguard/keys_test.go @@ -3,29 +3,46 @@ import ( "testing" + "github.com/h44z/wg-portal/internal/persistence" + "github.com/stretchr/testify/assert" ) -func TestKeyPair_GetPrivateKeyBytes(t *testing.T) { - kp := KeyPair{ +func TestGetPrivateKeyBytes(t *testing.T) { + kp := persistence.KeyPair{ PrivateKey: "aGVsbG8=", PublicKey: "d29ybGQ=", } - got := kp.GetPrivateKeyBytes() + got := GetPrivateKeyBytes(kp) assert.Equal(t, []byte("hello"), got) } -func TestKeyPair_GetPublicKeyBytes(t *testing.T) { - kp := KeyPair{ +func TestGetPublicKeyBytes(t *testing.T) { + kp := persistence.KeyPair{ PrivateKey: "aGVsbG8=", PublicKey: "d29ybGQ=", } - got := kp.GetPublicKeyBytes() + got := GetPublicKeyBytes(kp) assert.Equal(t, []byte("world"), got) } func TestKeyBytesToString(t *testing.T) { assert.Equal(t, "aGVsbG8=", KeyBytesToString([]byte("hello"))) } + +func TestWgCtrlKeyGenerator_GetFreshKeypair(t *testing.T) { + m := WgCtrlKeyGenerator{} + kp, err := m.GetFreshKeypair() + assert.NoError(t, err) + assert.NotEmpty(t, kp.PrivateKey) + assert.NotEmpty(t, kp.PublicKey) +} + +func TestWgCtrlKeyGenerator_GetPreSharedKey(t *testing.T) { + m := WgCtrlKeyGenerator{} + psk, err := m.GetPreSharedKey() + assert.NoError(t, err) + assert.NotEmpty(t, psk) +} diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 62d6df5..7901416 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -1,658 +1,86 @@ package wireguard import ( - "net" - "sort" - "strings" + "io" "sync" - "time" "github.com/h44z/wg-portal/internal/lowlevel" - "github.com/pkg/errors" - "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/h44z/wg-portal/internal/persistence" ) type KeyGenerator interface { - GetFreshKeypair() (KeyPair, error) - GetPreSharedKey() (PreSharedKey, error) + GetFreshKeypair() (persistence.KeyPair, error) + GetPreSharedKey() (persistence.PreSharedKey, error) } -// DeviceManager provides methods to create/update/delete physical WireGuard devices. -type DeviceManager interface { - GetDevices() ([]InterfaceConfig, error) - CreateDevice(device DeviceIdentifier) error - DeleteDevice(device DeviceIdentifier) error - UpdateDevice(device DeviceIdentifier, cfg InterfaceConfig) error +// InterfaceManager provides methods to create/update/delete physical WireGuard devices. +type InterfaceManager interface { + GetInterfaces() ([]persistence.InterfaceConfig, error) + CreateInterface(id persistence.InterfaceIdentifier) error + DeleteInterface(id persistence.InterfaceIdentifier) error + UpdateInterface(id persistence.InterfaceIdentifier, cfg persistence.InterfaceConfig) error +} + +type ImportableInterface struct { + persistence.InterfaceConfig + ImportLocation string + ImportType string +} + +type ImportManager interface { + GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) + ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) +} + +type ConfigFileGenerator interface { + GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) + GetPeerConfig(peer persistence.PeerConfig) (io.Reader, error) } type PeerManager interface { - GetPeers(device DeviceIdentifier) ([]PeerConfig, error) - SavePeers(device DeviceIdentifier, peers ...PeerConfig) error - RemovePeer(device DeviceIdentifier, peer PeerIdentifier) error + GetPeers(device persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) + SavePeers(device persistence.InterfaceIdentifier, peers ...persistence.PeerConfig) error + RemovePeer(device persistence.InterfaceIdentifier, peer persistence.PeerIdentifier) error } -type Opt func(svc *ManagementUtil) - type Manager interface { KeyGenerator - DeviceManager + InterfaceManager PeerManager + ImportManager + ConfigFileGenerator } -// ManagementUtil is a persistent management util for WireGuard configurations. -type ManagementUtil struct { +// +// -- Implementations +// + +type PersistentManager struct { + WgCtrlKeyGenerator + TemplateHandler + mux sync.RWMutex // mutex to synchronize access to maps - wg lowlevel.WireGuardClient // WireGuard interface handler - nl lowlevel.NetlinkClient // Network interface handler - cs ConfigStore // Persistent backend + // external api clients + wg lowlevel.WireGuardClient + nl lowlevel.NetlinkClient + + // persistent backend + store store // internal holder of interface configurations - interfaces map[DeviceIdentifier]InterfaceConfig - + interfaces map[persistence.InterfaceIdentifier]persistence.InterfaceConfig // internal holder of peer configurations - peers map[DeviceIdentifier]map[PeerIdentifier]PeerConfig + peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig } -func NewManagementUtil(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, cs ConfigStore, opts ...Opt) (*ManagementUtil, error) { - m := &ManagementUtil{ +func NewPersistentManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*PersistentManager, error) { + m := &PersistentManager{ mux: sync.RWMutex{}, wg: wg, nl: nl, - cs: cs, - } - - for _, opt := range opts { - opt(m) - } - - // initialize - err := m.initialize() - if err != nil { - return nil, errors.Wrap(err, "failed to initialize WireGuard manager") } return m, nil } - -func IgnoredInterfaces(ignored ...DeviceIdentifier) Opt { - return func(m *ManagementUtil) { - m.unmanagedInterfaces = ignored - } -} - -func (m *ManagementUtil) GetFreshKeypair() (KeyPair, error) { - privateKey, err := wgtypes.GeneratePrivateKey() - if err != nil { - return KeyPair{}, errors.Wrap(err, "failed to generate private Key") - } - - return KeyPair{ - PrivateKey: privateKey.String(), - PublicKey: privateKey.PublicKey().String(), - }, nil -} - -func (m *ManagementUtil) GetPreSharedKey() (PreSharedKey, error) { - preSharedKey, err := wgtypes.GenerateKey() - if err != nil { - return "", errors.Wrap(err, "failed to generate pre-shared Key") - } - - return PreSharedKey(preSharedKey.String()), nil -} - -func (m *ManagementUtil) GetDevices() ([]InterfaceConfig, error) { - interfaces := make([]InterfaceConfig, 0, len(m.interfaces)) - for _, iface := range interfaces { - interfaces = append(interfaces, iface) - } - // Order the interfaces by device name - sort.Slice(interfaces, func(i, j int) bool { - return interfaces[i].DeviceName < interfaces[j].DeviceName - }) - - return interfaces, nil -} - -func (m *ManagementUtil) CreateDevice(identifier DeviceIdentifier) error { - m.mux.Lock() - defer m.mux.Unlock() - if m.deviceExists(identifier) { - return errors.Errorf("device %s already exists", identifier) - } - - err := m.createWgDevice(identifier) - if err != nil { - return errors.Wrapf(err, "failed to create WireGuard interface %s", identifier) - } - - newInterface := InterfaceConfig{DeviceName: identifier} - m.interfaces[identifier] = newInterface - - err = m.persistInterface(identifier, false) - if err != nil { - return errors.Wrapf(err, "failed to persist created interface %s", identifier) - } - - return nil -} - -func (m *ManagementUtil) createWgDevice(identifier DeviceIdentifier) error { - link := &netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{ - Name: string(identifier), - }, - LinkType: "wireguard", - } - err := m.nl.LinkAdd(link) - if err != nil { - return errors.Wrapf(err, "failed to create WireGuard interface %s", identifier) - } - - if err := m.nl.LinkSetUp(link); err != nil { - return errors.Wrapf(err, "failed to enable WireGuard interface %s", identifier) - } - - return nil -} - -func (m *ManagementUtil) DeleteDevice(identifier DeviceIdentifier) error { - m.mux.Lock() - defer m.mux.Unlock() - if !m.deviceExists(identifier) { - return errors.Errorf("device %s does not exist", identifier) - } - err := m.nl.LinkDel(&netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{ - Name: string(identifier), - }, - LinkType: "wireguard", - }) - if err != nil { - return errors.Wrapf(err, "failed to delete WireGuard interface") - } - - err = m.persistInterface(identifier, true) - if err != nil { - return errors.Wrapf(err, "failed to persist deleted interface %s", identifier) - } - - delete(m.interfaces, identifier) - - return nil -} - -func (m *ManagementUtil) UpdateDevice(identifier DeviceIdentifier, cfg InterfaceConfig) error { - m.mux.Lock() - defer m.mux.Unlock() - if !m.deviceExists(identifier) { - return errors.Errorf("device %s does not exist", identifier) - } - cfg.DeviceName = identifier // ensure that the same device name is set - - // Update net-link attributes - link, err := m.nl.LinkByName(string(identifier)) - if err != nil { - return errors.Wrapf(err, "failed to open WireGuard interface") - } - if err := m.nl.LinkSetMTU(link, cfg.Mtu); err != nil { - return errors.Wrapf(err, "failed to set MTU") - } - addresses, err := parseIpAddressString(cfg.AddressStr) - for i := 0; i < len(addresses); i++ { - var err error - if i == 0 { - err = m.nl.AddrReplace(link, addresses[i]) - } else { - err = m.nl.AddrAdd(link, addresses[i]) - } - if err != nil { - return errors.Wrapf(err, "failed to set ip address %v", addresses[i]) - } - } - - // Update WireGuard attributes - pKey, _ := wgtypes.NewKey(cfg.KeyPair.GetPrivateKeyBytes()) - var fwMark *int - if cfg.FirewallMark != 0 { - *fwMark = int(cfg.FirewallMark) - } - err = m.wg.ConfigureDevice(string(identifier), wgtypes.Config{ - PrivateKey: &pKey, - ListenPort: &cfg.ListenPort, - FirewallMark: fwMark, - }) - if err != nil { - return errors.Wrapf(err, "failed to update WireGuard settings") - } - - // Update link state - if cfg.Enabled { - if err := m.nl.LinkSetUp(link); err != nil { - return errors.Wrapf(err, "failed to enable WireGuard interface") - } - } else { - if err := m.nl.LinkSetDown(link); err != nil { - return errors.Wrapf(err, "failed to disable WireGuard interface") - } - } - - m.interfaces[identifier] = cfg - - err = m.persistInterface(identifier, false) - if err != nil { - return errors.Wrapf(err, "failed to persist updated interface %s", identifier) - } - - return nil -} - -func (m *ManagementUtil) GetPeers(device DeviceIdentifier) ([]PeerConfig, error) { - m.mux.RLock() - defer m.mux.RUnlock() - if !m.deviceExists(device) { - return nil, errors.Errorf("device %s does not exist", device) - } - - peers := make([]PeerConfig, 0, len(m.peers[device])) - for _, config := range m.peers[device] { - peers = append(peers, config) - } - - return peers, nil -} - -func (m *ManagementUtil) SavePeers(device DeviceIdentifier, peers ...PeerConfig) error { - m.mux.Lock() - defer m.mux.Unlock() - if !m.deviceExists(device) { - return errors.Errorf("device %s does not exist", device) - } - - deviceConfig := m.interfaces[device] - - for _, peer := range peers { - wgPeer, err := getWireGuardPeerConfig(deviceConfig.Type, peer) - if err != nil { - return errors.Wrapf(err, "could not generate WireGuard peer configuration for %s", peer.Uid) - } - - err = m.wg.ConfigureDevice(string(device), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}}) - if err != nil { - return errors.Wrapf(err, "could not save peer %s to WireGuard device %s", peer.Uid, device) - } - - m.peers[device][peer.Uid] = peer - - err = m.persistPeer(peer.Uid, false) - if err != nil { - return errors.Wrapf(err, "failed to persist updated peer %s", peer.Uid) - } - } - - return nil -} - -func (m *ManagementUtil) RemovePeer(device DeviceIdentifier, peer PeerIdentifier) error { - m.mux.Lock() - defer m.mux.Unlock() - if !m.deviceExists(device) { - return errors.Errorf("device %s does not exist", device) - } - if !m.peerExists(peer) { - return errors.Errorf("peer %s does not exist", peer) - } - - peerConfig := m.peers[device][peer] - - publicKey, err := wgtypes.ParseKey(peerConfig.KeyPair.PublicKey) - if err != nil { - return errors.Wrapf(err, "invalid public key for peer %s", peer) - } - - wgPeer := wgtypes.PeerConfig{ - PublicKey: publicKey, - Remove: true, - } - - err = m.wg.ConfigureDevice(string(device), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}}) - if err != nil { - return errors.Wrapf(err, "could not remove peer %s from WireGuard device %s", peer, device) - } - - err = m.persistPeer(peer, true) - if err != nil { - return errors.Wrapf(err, "failed to persist deleted peer %s", peer) - } - - delete(m.peers[device], peer) - - return nil -} - -// TODO: implement/think about -func (m *ManagementUtil) loadFromBackend() error { - // Load all interfaces from the database - backendInterfaces, err := m.cs.GetAvailableInterfaces() - if err != nil { - return errors.Wrap(err, "failed to load backend interfaces") - } - - /*// Get a list of available WireGuard interfaces - wgInterfaces, err := m.wg.Devices() - if err != nil { - return errors.Wrap(err, "failed to load WireGuard interfaces") - } - - // Create missing WireGuard interfaces - for _, backendInterface := range backendInterfaces { - exists := false - for _, wgInterface := range wgInterfaces { - if string(backendInterface) == wgInterface.Name { - exists = true - break - } - } - if !exists { - err := m.createWgDevice(backendInterface) - if err != nil { - return errors.Wrapf(err, "failed to create WireGuard interface %s found in backend", backendInterface) - } - } - }*/ - - // Load config options from database backend, populate internal state maps - err = m.loadBackendInterfaces(DatabaseBackendName, backendInterfaces...) - if err != nil { - return errors.Wrap(err, "failed to load interface configurations from backend") - } - - // Load missing config options from current interfaces, populate internal state maps - err = m.loadWireGuardInterfaces() - if err != nil { - return errors.Wrap(err, "failed to load interface configurations from WireGuard") - } - - // Persists currently loaded configurations - // TODO - - // Apply configuration options from internal state maps to current interfaces - // TODO - - return nil -} - -func (m *ManagementUtil) loadBackendInterfaces(backend string, identifiers ...DeviceIdentifier) error { - for _, cl := range m.cl { - if cl.Name() != backend { - continue - } - ifaceAndPeers, err := cl.LoadAll(identifiers...) - if err != nil { - return errors.Wrapf(err, "failed to load interfaces from backend %s", cl.Name()) - } - - for iface, peers := range ifaceAndPeers { - m.interfaces[iface.DeviceName] = iface - for _, peer := range peers { - m.peers[iface.DeviceName][peer.Uid] = peer - } - } - } - return nil -} - -func (m *ManagementUtil) loadWireGuardInterfaces() error { - // Get a list of available WireGuard interfaces - wgInterfaces, err := m.wg.Devices() - if err != nil { - return errors.Wrap(err, "failed to load WireGuard interfaces") - } - - for _, iface := range wgInterfaces { - if m.interfaceIsIgnored(DeviceIdentifier(iface.Name)) { - continue - } - - devId := DeviceIdentifier(iface.Name) - if _, existing := m.interfaces[devId]; !existing { - m.interfaces[devId] = m.convertWireGuardInterface(*iface) - } - - for _, peer := range iface.Peers { - peerPublicKey := peer.PublicKey.String() - - // check if peer exists, compare public keys - existing := false - for _, existingPeer := range m.peers[devId] { - if existingPeer.KeyPair.PublicKey == peerPublicKey { - existing = true - break - } - } - - if !existing { - // Use the peers public key as UID - m.peers[devId][PeerIdentifier(peerPublicKey)] = m.convertWireGuardPeer(peer) - } - - } - } - return nil -} - -func (m *ManagementUtil) restoreBackendInterfaces() error { - return nil -} - -func (m *ManagementUtil) interfaceIsIgnored(name DeviceIdentifier) bool { - for _, iface := range m.unmanagedInterfaces { - if iface == name { - return true - } - } - return false -} - -// -// ---- Helpers -// - -func getWireGuardPeerConfig(deviceType InterfaceType, peer PeerConfig) (wgtypes.PeerConfig, error) { - publicKey, err := wgtypes.ParseKey(peer.KeyPair.PublicKey) - if err != nil { - return wgtypes.PeerConfig{}, errors.Wrapf(err, "invalid public key for peer %s", peer.Uid) - } - - var presharedKey *wgtypes.Key - if tmpPresharedKey, err := wgtypes.ParseKey(peer.PresharedKey); err == nil { - presharedKey = &tmpPresharedKey - } - - var endpoint *net.UDPAddr - if peer.Endpoint.Value != "" && deviceType == InterfaceTypeClient { - addr, err := net.ResolveUDPAddr("udp", peer.Endpoint.Value.(string)) - if err == nil { - endpoint = addr - } - } - - var keepAlive *time.Duration - if peer.PersistentKeepalive.Value != 0 { - keepAliveDuration := time.Duration(peer.PersistentKeepalive.Value.(int)) * time.Second - keepAlive = &keepAliveDuration - } - - allowedIPs := make([]net.IPNet, 0) - var peerAllowedIPs []*netlink.Addr - switch deviceType { - case InterfaceTypeClient: - peerAllowedIPs, err = parseIpAddressString(peer.AllowedIPsStr.GetValue()) - if err != nil { - return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse allowed IP's for peer %s", peer.Uid) - } - case InterfaceTypeServer: - peerAllowedIPs, err = parseIpAddressString(peer.AllowedIPsStr.GetValue()) - if err != nil { - return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse allowed IP's for peer %s", peer.Uid) - } - peerExtraAllowedIPs, err := parseIpAddressString(peer.ExtraAllowedIPsStr) - if err != nil { - return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse extra allowed IP's for peer %s", peer.Uid) - } - - peerAllowedIPs = append(peerAllowedIPs, peerExtraAllowedIPs...) - } - for _, ip := range peerAllowedIPs { - allowedIPs = append(allowedIPs, *ip.IPNet) - } - - wgPeer := wgtypes.PeerConfig{ - PublicKey: publicKey, - Remove: false, - UpdateOnly: true, - PresharedKey: presharedKey, - Endpoint: endpoint, - PersistentKeepaliveInterval: keepAlive, - ReplaceAllowedIPs: true, - AllowedIPs: allowedIPs, - } - - return wgPeer, nil -} - -func (m *ManagementUtil) deviceExists(identifier DeviceIdentifier) bool { - if _, ok := m.interfaces[identifier]; ok { - return true - } - return false -} - -func (m *ManagementUtil) peerExists(identifier PeerIdentifier) bool { - for _, peers := range m.peers { - if _, ok := peers[identifier]; ok { - return true - } - } - - return false -} - -func (m *ManagementUtil) persistInterface(identifier DeviceIdentifier, delete bool) error { - var err error - - device := m.interfaces[identifier] - peers := make([]PeerConfig, 0, len(m.peers[identifier])) - for _, config := range m.peers[identifier] { - peers = append(peers, config) - } - - if !delete { - err = m.cs.SaveInterface(device, peers) - } else { - err = m.cs.DeleteInterface(device.DeviceName) - } - - if err != nil { - return errors.Wrapf(err, "failed to persist interface %s", identifier) - } - - return nil -} - -func (m *ManagementUtil) persistPeer(identifier PeerIdentifier, delete bool) error { - var err error - - var device InterfaceConfig - var peer PeerConfig - for dev, peers := range m.peers { - if p, ok := peers[identifier]; ok { - device = m.interfaces[dev] - peer = p - break - } - } - - if !delete { - err = m.cs.SavePeer(peer, device.DeviceName) - } else { - err = m.cs.DeletePeer(peer.Uid, device.DeviceName) - } - - if err != nil { - return errors.Wrapf(err, "failed to persist peer %s", identifier) - } - - return nil -} - -func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) { - rawAddresses := strings.Split(addrStr, ",") - addresses := make([]*netlink.Addr, 0, len(rawAddresses)) - for i := range rawAddresses { - rawAddress := strings.TrimSpace(rawAddresses[i]) - if rawAddress == "" { - continue // skip empty string - } - address, err := netlink.ParseAddr(rawAddress) - if err != nil { - return nil, errors.Wrapf(err, "failed to parse IP address %s", rawAddress) - } - addresses = append(addresses, address) - } - - return addresses, nil -} - -func (m *ManagementUtil) convertWireGuardInterface(device wgtypes.Device) InterfaceConfig { - cfg := InterfaceConfig{ - DeviceName: DeviceIdentifier(device.Name), - KeyPair: KeyPair{PublicKey: device.PublicKey.String(), PrivateKey: device.PrivateKey.String()}, - ListenPort: device.ListenPort, - FirewallMark: int32(device.FirewallMark), - DriverType: device.Type.String(), - } - - link, err := m.nl.LinkByName(device.Name) - if err != nil || link.Attrs() == nil { - return cfg - } - cfg.Mtu = link.Attrs().MTU - - addresses, err := m.nl.AddrList(link) - if err != nil { - return cfg - } - addressesStr := make([]string, len(addresses)) - for i := range addresses { - addressesStr[i] = addresses[i].String() - } - cfg.AddressStr = strings.Join(addressesStr, ",") - - return cfg -} - -func (m *ManagementUtil) convertWireGuardPeer(peer wgtypes.Peer) PeerConfig { - cfg := PeerConfig{ - KeyPair: KeyPair{PublicKey: peer.PublicKey.String()}, - } - - if peer.Endpoint != nil { - cfg.Endpoint.Value = peer.Endpoint.String() - } - - if peer.PresharedKey != (wgtypes.Key{}) { - cfg.PresharedKey = peer.PresharedKey.String() - } - - ipAddresses := make([]string, len(peer.AllowedIPs)) // use allowed IP's as the peer IP's - for i, ip := range peer.AllowedIPs { - ipAddresses[i] = ip.String() - } - cfg.AddressStr.Value = strings.Join(ipAddresses, ",") - - return cfg -} diff --git a/internal/wireguard/manager_int_test.go b/internal/wireguard/manager_int_test.go deleted file mode 100644 index 366b850..0000000 --- a/internal/wireguard/manager_int_test.go +++ /dev/null @@ -1,101 +0,0 @@ -//go:build integration -// +build integration - -// In Goland you can use File-Nesting to enhance the project view: _int_test.go; _test.go - -// Run integrations tests as root! - -package wireguard - -import ( - "os/exec" - "testing" - - "github.com/h44z/wg-portal/internal/lowlevel" - - "golang.zx2c4.com/wireguard/wgctrl" - - "github.com/stretchr/testify/assert" - "github.com/vishvananda/netlink" -) - -func prepareTest(dev DeviceIdentifier) { - _ = netlink.LinkDel(&netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{ - Name: string(dev), - }, - LinkType: "wireguard", - }) -} - -func TestManagementUtil_CreateDevice(t *testing.T) { - devName := DeviceIdentifier("wg666") - prepareTest(devName) - m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig), nl: lowlevel.NetlinkManager{}} - - defer m.DeleteDevice(devName) - err := m.CreateDevice(devName) - assert.NoError(t, err) - - cmd := exec.Command("ip", "addr") - out, err := cmd.CombinedOutput() - assert.NoError(t, err) - assert.Contains(t, string(out), devName) -} - -func TestManagementUtil_DeleteDevice(t *testing.T) { - devName := DeviceIdentifier("wg667") - prepareTest(devName) - m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig), nl: lowlevel.NetlinkManager{}} - - err := m.CreateDevice(devName) - assert.NoError(t, err) - err = m.DeleteDevice(devName) - assert.NoError(t, err) - - cmd := exec.Command("ip", "addr") - out, err := cmd.CombinedOutput() - assert.NoError(t, err) - assert.NotContains(t, string(out), devName) -} - -func TestManagementUtil_deviceExists(t *testing.T) { - m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig)} - assert.False(t, m.deviceExists("test")) - - m = ManagementUtil{interfaces: map[DeviceIdentifier]InterfaceConfig{"test": {}}} - assert.True(t, m.deviceExists("test")) -} - -func TestManagementUtil_UpdateDevice(t *testing.T) { - devName := DeviceIdentifier("wg668") - prepareTest(devName) - wg, err := wgctrl.New() - if !assert.NoError(t, err) { - return - } - m := ManagementUtil{interfaces: make(map[DeviceIdentifier]InterfaceConfig), nl: lowlevel.NetlinkManager{}, wg: wg} - - defer m.DeleteDevice(devName) - err = m.CreateDevice(devName) - if !assert.NoError(t, err) { - return - } - - err = m.UpdateDevice(devName, InterfaceConfig{AddressStr: "123.123.123.123/24", Mtu: 1234}) - assert.NoError(t, err) - - cmd := exec.Command("ip", "addr") - out, err := cmd.CombinedOutput() - assert.NoError(t, err) - assert.Contains(t, string(out), "123.123.123.123") - - err = m.UpdateDevice(devName, InterfaceConfig{AddressStr: "123.123.123.123/24,fd9f:6666::10:6:6:1/64", Mtu: 1600}) - assert.NoError(t, err) - - cmd = exec.Command("ip", "addr") - out, err = cmd.CombinedOutput() - assert.NoError(t, err) - assert.Contains(t, string(out), "123.123.123.123") - assert.Contains(t, string(out), "fd9f:6666::10:6:6:1") -} diff --git a/internal/wireguard/manager_test.go b/internal/wireguard/manager_test.go deleted file mode 100644 index 1f4e446..0000000 --- a/internal/wireguard/manager_test.go +++ /dev/null @@ -1,239 +0,0 @@ -//go:build !integration -// +build !integration - -package wireguard - -import ( - "net" - "reflect" - "testing" - - "github.com/stretchr/testify/mock" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/stretchr/testify/assert" - "github.com/vishvananda/netlink" -) - -type MockWireGuardClient struct { - mock.Mock -} - -func (m *MockWireGuardClient) Close() error { - args := m.Called() - return args.Error(0) -} - -func (m *MockWireGuardClient) Devices() ([]*wgtypes.Device, error) { - args := m.Called() - return args.Get(0).([]*wgtypes.Device), args.Error(1) -} - -func (m *MockWireGuardClient) Device(name string) (*wgtypes.Device, error) { - args := m.Called(name) - return args.Get(0).(*wgtypes.Device), args.Error(1) -} - -func (m *MockWireGuardClient) ConfigureDevice(name string, cfg wgtypes.Config) error { - args := m.Called(name, cfg) - return args.Error(0) -} - -type MockNetlinkClient struct { - mock.Mock -} - -func (m *MockNetlinkClient) LinkAdd(link netlink.Link) error { - args := m.Called(link) - return args.Error(0) -} - -func (m *MockNetlinkClient) LinkDel(link netlink.Link) error { - args := m.Called(link) - return args.Error(0) -} - -func (m *MockNetlinkClient) LinkByName(name string) (netlink.Link, error) { - args := m.Called(name) - return args.Get(0).(netlink.Link), args.Error(1) -} - -func (m *MockNetlinkClient) LinkSetUp(link netlink.Link) error { - args := m.Called(link) - return args.Error(0) -} - -func (m *MockNetlinkClient) LinkSetDown(link netlink.Link) error { - args := m.Called(link) - return args.Error(0) -} - -func (m *MockNetlinkClient) LinkSetMTU(link netlink.Link, mtu int) error { - args := m.Called(link, mtu) - return args.Error(0) -} - -func (m *MockNetlinkClient) AddrReplace(link netlink.Link, addr *netlink.Addr) error { - args := m.Called(link, addr) - return args.Error(0) -} - -func (m *MockNetlinkClient) AddrAdd(link netlink.Link, addr *netlink.Addr) error { - args := m.Called(link, addr) - return args.Error(0) -} - -// -// ---------- Tests -// - -func TestManagementUtil_GetFreshKeypair(t *testing.T) { - m := ManagementUtil{} - kp, err := m.GetFreshKeypair() - assert.NoError(t, err) - assert.NotEmpty(t, kp.PrivateKey) - assert.NotEmpty(t, kp.PublicKey) -} - -func TestManagementUtil_GetPreSharedKey(t *testing.T) { - m := ManagementUtil{} - psk, err := m.GetPreSharedKey() - assert.NoError(t, err) - assert.NotEmpty(t, psk) -} - -func Test_parseIpAddressString(t *testing.T) { - type args struct { - addrStr string - } - var tests = []struct { - name string - args args - want []*netlink.Addr - wantErr bool - }{ - { - name: "Empty String", - args: args{}, - want: []*netlink.Addr{}, - wantErr: false, - }, - { - name: "Single IPv4", - args: args{addrStr: "123.123.123.123"}, - want: nil, - wantErr: true, - }, - { - name: "Malformed", - args: args{addrStr: "hello world"}, - want: nil, - wantErr: true, - }, - { - name: "Single IPv4 CIDR", - args: args{addrStr: "123.123.123.123/24"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IPv4(123, 123, 123, 123), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - }}, - wantErr: false, - }, - { - name: "Multiple IPv4 CIDR", - args: args{addrStr: "123.123.123.123/24, 200.201.202.203/16"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IPv4(123, 123, 123, 123), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - }, { - IPNet: &net.IPNet{ - IP: net.IPv4(200, 201, 202, 203), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - }}, - wantErr: false, - }, - { - name: "Single IPv6 CIDR", - args: args{addrStr: "fe80::1/64"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, - }, - }}, - wantErr: false, - }, - { - name: "Multiple IPv6 CIDR", - args: args{addrStr: "fe80::1/64 , 2130:d3ad::b33f/128"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, - }, - }, { - IPNet: &net.IPNet{ - IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, - }, - }}, - wantErr: false, - }, - { - name: "Mixed IPv4 and IPv6 CIDR", - args: args{addrStr: "200.201.202.203/16,2130:d3ad::b33f/128"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IPv4(200, 201, 202, 203), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - }, { - IPNet: &net.IPNet{ - IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, - }, - }}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseIpAddressString(tt.args.addrStr) - if (err != nil) != tt.wantErr { - t.Errorf("parseIpAddressString() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseIpAddressString() got = %v, want %v", got, tt.want) - } - }) - } -} - -func TestManagementUtil_UpdateDevice(t *testing.T) { - devName := DeviceIdentifier("wg668") - wg := new(MockWireGuardClient) - nl := new(MockNetlinkClient) - - // expectations - nl.On("LinkByName", string(devName)).Return(&netlink.GenericLink{}, nil) - nl.On("LinkSetMTU", mock.Anything, 1234).Return(nil) - nl.On("AddrReplace", mock.Anything, mock.Anything).Return(nil) - wg.On("ConfigureDevice", string(devName), mock.Anything).Return(nil) - nl.On("LinkSetDown", mock.Anything).Return(nil) - - m := ManagementUtil{interfaces: map[DeviceIdentifier]InterfaceConfig{devName: {}}, nl: nl, wg: wg} - - err := m.UpdateDevice(devName, InterfaceConfig{AddressStr: "123.123.123.123/24", Mtu: 1234}) - assert.NoError(t, err) - - // assert that the expectations were met - wg.AssertExpectations(t) - nl.AssertExpectations(t) -} diff --git a/internal/wireguard/persistence.go b/internal/wireguard/persistence.go index 1c31e6d..0431c74 100644 --- a/internal/wireguard/persistence.go +++ b/internal/wireguard/persistence.go @@ -1,14 +1,18 @@ package wireguard -// ConfigStore provides an interface for interacting with different configuration storage repositories. -type ConfigStore interface { - GetAvailableInterfaces() ([]DeviceIdentifier, error) - GetAllInterfaces(interfaceIdentifiers ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) - GetInterface(identifier DeviceIdentifier) (InterfaceConfig, []PeerConfig, error) +import ( + "github.com/h44z/wg-portal/internal/persistence" +) - SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error - SavePeer(peer PeerConfig, interfaceIdentifier DeviceIdentifier) error +type store interface { + GetAvailableInterfaces() ([]persistence.InterfaceIdentifier, error) - DeleteInterface(identifier DeviceIdentifier) error - DeletePeer(peer PeerIdentifier, interfaceIdentifier DeviceIdentifier) error + GetAllInterfaces(interfaceIdentifiers ...persistence.InterfaceIdentifier) (map[persistence.InterfaceConfig][]persistence.PeerConfig, error) + GetInterface(identifier persistence.InterfaceIdentifier) (persistence.InterfaceConfig, []persistence.PeerConfig, error) + + SaveInterface(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) error + SavePeer(peer persistence.PeerConfig, interfaceIdentifier persistence.InterfaceIdentifier) error + + DeleteInterface(identifier persistence.InterfaceIdentifier) error + DeletePeer(peer persistence.PeerIdentifier, interfaceIdentifier persistence.InterfaceIdentifier) error } diff --git a/internal/wireguard/template.go b/internal/wireguard/template.go index abea343..5ebe02d 100644 --- a/internal/wireguard/template.go +++ b/internal/wireguard/template.go @@ -6,21 +6,13 @@ "io" "text/template" + "github.com/h44z/wg-portal/internal/persistence" "github.com/pkg/errors" ) //go:embed tpl_files/* var TemplateFiles embed.FS -type ConfigFileGenerator interface { - GetInterfaceConfig(cfg InterfaceConfig, peers []PeerConfig) (io.Reader, error) - GetPeerConfig(peer PeerConfig, iface InterfaceConfig) (io.Reader, error) -} - -type ConfigFileParser interface { - ParseConfig(fileContents io.Reader) (InterfaceConfig, []PeerConfig, error) -} - type TemplateHandler struct { templates *template.Template } @@ -38,7 +30,7 @@ return handler, nil } -func (c TemplateHandler) GetInterfaceConfig(cfg InterfaceConfig, peers []PeerConfig) (io.Reader, error) { +func (c TemplateHandler) GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) { var tplBuff bytes.Buffer err := c.templates.ExecuteTemplate(&tplBuff, "interface.tpl", map[string]interface{}{ @@ -49,24 +41,24 @@ }, }) if err != nil { - return nil, errors.Wrapf(err, "failed to execute interface template for %s", cfg.DeviceName) + return nil, errors.Wrapf(err, "failed to execute interface template for %s", cfg.Identifier) } return &tplBuff, nil } -func (c TemplateHandler) GetPeerConfig(peer PeerConfig, iface InterfaceConfig) (io.Reader, error) { +func (c TemplateHandler) GetPeerConfig(peer persistence.PeerConfig) (io.Reader, error) { var tplBuff bytes.Buffer err := c.templates.ExecuteTemplate(&tplBuff, "peer.tpl", map[string]interface{}{ "Peer": peer, - "Interface": iface, + "Interface": peer.PeerInterfaceConfig, "Portal": map[string]interface{}{ "Version": "unknown", }, }) if err != nil { - return nil, errors.Wrapf(err, "failed to execute peer template for %s", peer.Uid) + return nil, errors.Wrapf(err, "failed to execute peer template for %s", peer.Identifier) } return &tplBuff, nil diff --git a/internal/wireguard/template_test.go b/internal/wireguard/template_test.go index 9478ca4..d855bb4 100644 --- a/internal/wireguard/template_test.go +++ b/internal/wireguard/template_test.go @@ -6,6 +6,8 @@ "reflect" "testing" + "github.com/h44z/wg-portal/internal/persistence" + "github.com/stretchr/testify/assert" ) @@ -17,8 +19,8 @@ func TestTemplateHandler_GetInterfaceConfig(t *testing.T) { type args struct { - cfg InterfaceConfig - peers []PeerConfig + cfg persistence.InterfaceConfig + peers []persistence.PeerConfig } tests := []struct { name string @@ -75,8 +77,8 @@ func TestTemplateHandler_GetPeerConfig(t *testing.T) { type args struct { - peer PeerConfig - iface InterfaceConfig + peer persistence.PeerConfig + iface persistence.InterfaceConfig } tests := []struct { name string @@ -116,7 +118,7 @@ c, _ := NewTemplateHandler() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := c.GetPeerConfig(tt.args.peer, tt.args.iface) + got, err := c.GetPeerConfig(tt.args.peer) if (err != nil) != tt.wantErr { t.Errorf("GetPeerConfig() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/internal/wireguard/test_helpers_test.go b/internal/wireguard/test_helpers_test.go new file mode 100644 index 0000000..686d6c5 --- /dev/null +++ b/internal/wireguard/test_helpers_test.go @@ -0,0 +1,132 @@ +package wireguard + +import ( + "github.com/h44z/wg-portal/internal/persistence" + "github.com/stretchr/testify/mock" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// +// -- WireGuard mock +// + +type MockWireGuardClient struct { + mock.Mock +} + +func (m *MockWireGuardClient) Close() error { + args := m.Called() + return args.Error(0) +} + +func (m *MockWireGuardClient) Devices() ([]*wgtypes.Device, error) { + args := m.Called() + return args.Get(0).([]*wgtypes.Device), args.Error(1) +} + +func (m *MockWireGuardClient) Device(name string) (*wgtypes.Device, error) { + args := m.Called(name) + return args.Get(0).(*wgtypes.Device), args.Error(1) +} + +func (m *MockWireGuardClient) ConfigureDevice(name string, cfg wgtypes.Config) error { + args := m.Called(name, cfg) + return args.Error(0) +} + +// +// -- Netlink mock +// + +type MockNetlinkClient struct { + mock.Mock +} + +func (m *MockNetlinkClient) LinkAdd(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkDel(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkByName(name string) (netlink.Link, error) { + args := m.Called(name) + return args.Get(0).(netlink.Link), args.Error(1) +} + +func (m *MockNetlinkClient) LinkSetUp(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkSetDown(link netlink.Link) error { + args := m.Called(link) + return args.Error(0) +} + +func (m *MockNetlinkClient) LinkSetMTU(link netlink.Link, mtu int) error { + args := m.Called(link, mtu) + return args.Error(0) +} + +func (m *MockNetlinkClient) AddrReplace(link netlink.Link, addr *netlink.Addr) error { + args := m.Called(link, addr) + return args.Error(0) +} + +func (m *MockNetlinkClient) AddrAdd(link netlink.Link, addr *netlink.Addr) error { + args := m.Called(link, addr) + return args.Error(0) +} + +func (m *MockNetlinkClient) AddrList(link netlink.Link) ([]netlink.Addr, error) { + args := m.Called(link) + return args.Get(0).([]netlink.Addr), args.Error(1) +} + +// +// -- WireGuard Store mock +// + +type MockWireGuardStore struct { + mock.Mock +} + +func (w *MockWireGuardStore) GetAvailableInterfaces() ([]persistence.InterfaceIdentifier, error) { + args := w.Called() + return args.Get(0).([]persistence.InterfaceIdentifier), args.Error(1) +} + +func (w *MockWireGuardStore) GetAllInterfaces(interfaceIdentifiers ...persistence.InterfaceIdentifier) (map[persistence.InterfaceConfig][]persistence.PeerConfig, error) { + args := w.Called(interfaceIdentifiers) + return args.Get(0).(map[persistence.InterfaceConfig][]persistence.PeerConfig), args.Error(1) +} + +func (w *MockWireGuardStore) GetInterface(identifier persistence.InterfaceIdentifier) (persistence.InterfaceConfig, []persistence.PeerConfig, error) { + args := w.Called(identifier) + return args.Get(0).(persistence.InterfaceConfig), args.Get(1).([]persistence.PeerConfig), args.Error(2) +} + +func (w *MockWireGuardStore) SaveInterface(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) error { + args := w.Called(cfg, peers) + return args.Error(0) +} + +func (w *MockWireGuardStore) SavePeer(peer persistence.PeerConfig, interfaceIdentifier persistence.InterfaceIdentifier) error { + args := w.Called(peer, interfaceIdentifier) + return args.Error(0) +} + +func (w *MockWireGuardStore) DeleteInterface(identifier persistence.InterfaceIdentifier) error { + args := w.Called(identifier) + return args.Error(0) +} + +func (w *MockWireGuardStore) DeletePeer(peer persistence.PeerIdentifier, interfaceIdentifier persistence.InterfaceIdentifier) error { + args := w.Called(peer, interfaceIdentifier) + return args.Error(0) +} diff --git a/internal/wireguard/tpl_files/interface.tpl b/internal/wireguard/tpl_files/interface.tpl index 2abb63d..a47542a 100644 --- a/internal/wireguard/tpl_files/interface.tpl +++ b/internal/wireguard/tpl_files/interface.tpl @@ -5,7 +5,7 @@ # Lines starting with the -WGP- tag are used by the WireGuard Portal configuration parser. [Interface] -# -WGP- Interface: {{ .Interface.DeviceName }} | Updated: {{ .Interface.UpdatedAt }} | Created: {{ .Interface.CreatedAt }} +# -WGP- Interface: {{ .Interface.Identifier }} | Updated: {{ .Interface.UpdatedAt }} | Created: {{ .Interface.CreatedAt }} # -WGP- Display name: {{ .Interface.DisplayName }} # -WGP- Interface mode: {{ .Interface.Type }} # -WGP- PublicKey = {{ .Interface.KeyPair.PublicKey }} diff --git a/internal/wireguard/tpl_files/peer.tpl b/internal/wireguard/tpl_files/peer.tpl index 2278023..cd44f64 100644 --- a/internal/wireguard/tpl_files/peer.tpl +++ b/internal/wireguard/tpl_files/peer.tpl @@ -5,8 +5,8 @@ # Lines starting with the -WGP- tag are used by the WireGuard Portal configuration parser. [Interface] -# -WGP- Peer: {{.Peer.Uid}} | Updated: {{.Peer.UpdatedAt}} | Created: {{.Peer.CreatedAt}} -# -WGP- Display name: {{ .Peer.Identifier }} +# -WGP- Peer: {{.Peer.Identifier}} | Updated: {{.Peer.UpdatedAt}} | Created: {{.Peer.CreatedAt}} +# -WGP- Display name: {{ .Peer.DisplayName }} # -WGP- PublicKey: {{ .Peer.KeyPair.PublicKey }} {{- if eq $.Interface.Type "server"}} # -WGP- Peer type: client diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go new file mode 100644 index 0000000..232b3a9 --- /dev/null +++ b/internal/wireguard/wireguard.go @@ -0,0 +1,433 @@ +package wireguard + +import ( + "net" + "sort" + "strings" + "sync" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/vishvananda/netlink" + + "github.com/h44z/wg-portal/internal/lowlevel" + "github.com/h44z/wg-portal/internal/persistence" + "github.com/pkg/errors" +) + +type WgCtrlManager struct { + mux sync.RWMutex // mutex to synchronize access to maps and external api clients + + // external api clients + wg lowlevel.WireGuardClient + nl lowlevel.NetlinkClient + + // optional persistent backend + store store + + // internal holder of interface configurations + interfaces map[persistence.InterfaceIdentifier]persistence.InterfaceConfig + // internal holder of peer configurations + peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig +} + +func (m *WgCtrlManager) GetInterfaces() ([]persistence.InterfaceConfig, error) { + m.mux.RLock() + defer m.mux.RUnlock() + interfaces := make([]persistence.InterfaceConfig, 0, len(m.interfaces)) + for _, iface := range interfaces { + interfaces = append(interfaces, iface) + } + // Order the interfaces by device name + sort.Slice(interfaces, func(i, j int) bool { + return interfaces[i].Identifier < interfaces[j].Identifier + }) + + return interfaces, nil +} + +func (m *WgCtrlManager) CreateInterface(id persistence.InterfaceIdentifier) error { + m.mux.Lock() + defer m.mux.Unlock() + if m.deviceExists(id) { + return errors.New("device already exists") + } + + err := m.createLowLevelInterface(id) + if err != nil { + return errors.WithMessage(err, "failed to create low level interface") + } + + newInterface := persistence.InterfaceConfig{Identifier: id} + m.interfaces[id] = newInterface + + err = m.persistInterface(id, false) + if err != nil { + return errors.WithMessage(err, "failed to persist created interface") + } + + return nil +} + +func (m *WgCtrlManager) DeleteInterface(id persistence.InterfaceIdentifier) error { + m.mux.Lock() + defer m.mux.Unlock() + + if !m.deviceExists(id) { + return errors.New("interface does not exist") + } + + err := m.nl.LinkDel(&netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(id), + }, + LinkType: "wireguard", + }) + if err != nil { + return errors.WithMessage(err, "failed to delete low level interface") + } + + err = m.persistInterface(id, true) + if err != nil { + return errors.WithMessage(err, "failed to persist deleted interface") + } + + delete(m.interfaces, id) + + return nil +} + +func (m *WgCtrlManager) UpdateInterface(id persistence.InterfaceIdentifier, cfg persistence.InterfaceConfig) error { + m.mux.Lock() + defer m.mux.Unlock() + if !m.deviceExists(id) { + return errors.New("interface does not exist") + } + cfg.Identifier = id // ensure that the same device name is set + + // Update net-link attributes + link, err := m.nl.LinkByName(string(id)) + if err != nil { + return errors.WithMessage(err, "failed to open low level interface") + } + if err := m.nl.LinkSetMTU(link, cfg.Mtu); err != nil { + return errors.WithMessage(err, "failed to set MTU") + } + addresses, err := parseIpAddressString(cfg.AddressStr) + for i := 0; i < len(addresses); i++ { + var err error + if i == 0 { + err = m.nl.AddrReplace(link, addresses[i]) + } else { + err = m.nl.AddrAdd(link, addresses[i]) + } + if err != nil { + return errors.WithMessage(err, "failed to set ip address") + } + } + + // Update WireGuard attributes + pKey, err := wgtypes.NewKey(GetPrivateKeyBytes(cfg.KeyPair)) + if err != nil { + return errors.WithMessage(err, "failed to parse private key bytes") + } + + var fwMark *int + if cfg.FirewallMark != 0 { + *fwMark = int(cfg.FirewallMark) + } + err = m.wg.ConfigureDevice(string(id), wgtypes.Config{ + PrivateKey: &pKey, + ListenPort: &cfg.ListenPort, + FirewallMark: fwMark, + }) + if err != nil { + return errors.WithMessage(err, "failed to update WireGuard settings") + } + + // Update link state + if cfg.Enabled { + if err := m.nl.LinkSetUp(link); err != nil { + return errors.WithMessage(err, "failed to enable low level interface") + } + } else { + if err := m.nl.LinkSetDown(link); err != nil { + return errors.WithMessage(err, "failed to disable low level interface") + } + } + + // update internal map + m.interfaces[id] = cfg + + err = m.persistInterface(id, false) + if err != nil { + return errors.WithMessage(err, "failed to persist updated interface") + } + + return nil +} + +func (m *WgCtrlManager) GetPeers(interfaceId persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) { + m.mux.RLock() + defer m.mux.RUnlock() + if !m.deviceExists(interfaceId) { + return nil, errors.New("device does not exist") + } + + peers := make([]persistence.PeerConfig, 0, len(m.peers[interfaceId])) + for _, config := range m.peers[interfaceId] { + peers = append(peers, config) + } + + return peers, nil +} + +func (m *WgCtrlManager) SavePeers(peers ...persistence.PeerConfig) error { + m.mux.Lock() + defer m.mux.Unlock() + + for _, peer := range peers { + deviceId := peer.PeerInterfaceConfig.Identifier + if !m.deviceExists(deviceId) { + return errors.Errorf("device does not exist") + } + deviceConfig := m.interfaces[deviceId] + + wgPeer, err := getWireGuardPeerConfig(deviceConfig.Type, peer) + if err != nil { + return errors.WithMessagef(err, "could not generate WireGuard peer configuration for %s", peer.Identifier) + } + + err = m.wg.ConfigureDevice(string(deviceId), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}}) + if err != nil { + return errors.Wrapf(err, "could not save peer %s to WireGuard device %s", peer.Identifier, deviceId) + } + + m.peers[deviceId][peer.Identifier] = peer + + err = m.persistPeer(peer.Identifier, false) + if err != nil { + return errors.Wrapf(err, "failed to persist updated peer %s", peer.Identifier) + } + } + + return nil +} + +func (m *WgCtrlManager) RemovePeer(id persistence.PeerIdentifier) error { + m.mux.Lock() + defer m.mux.Unlock() + + if !m.peerExists(id) { + return errors.Errorf("peer does not exist") + } + + peer, _ := m.getPeer(id) + deviceId := peer.PeerInterfaceConfig.Identifier + + publicKey, err := wgtypes.ParseKey(peer.KeyPair.PublicKey) + if err != nil { + return errors.WithMessage(err, "invalid public key") + } + + wgPeer := wgtypes.PeerConfig{ + PublicKey: publicKey, + Remove: true, + } + + err = m.wg.ConfigureDevice(string(deviceId), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}}) + if err != nil { + return errors.WithMessage(err, "could not remove peer from WireGuard interface") + } + + err = m.persistPeer(id, true) + if err != nil { + return errors.WithMessage(err, "failed to persist deleted peer") + } + + delete(m.peers[deviceId], id) + + return nil +} + +// +// -- Helpers +// + +func (m *WgCtrlManager) createLowLevelInterface(id persistence.InterfaceIdentifier) error { + link := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(id), + }, + LinkType: "wireguard", + } + err := m.nl.LinkAdd(link) + if err != nil { + return errors.Wrapf(err, "failed to create netlink interface") + } + + if err := m.nl.LinkSetUp(link); err != nil { + return errors.Wrapf(err, "failed to enable netlink interface") + } + + return nil +} + +func (m *WgCtrlManager) deviceExists(id persistence.InterfaceIdentifier) bool { + if _, ok := m.interfaces[id]; ok { + return true + } + return false +} + +func (m *WgCtrlManager) persistInterface(id persistence.InterfaceIdentifier, delete bool) error { + if m.store == nil { + return nil // nothing to do + } + + device := m.interfaces[id] + peers := make([]persistence.PeerConfig, 0, len(m.peers[id])) + for _, config := range m.peers[id] { + peers = append(peers, config) + } + + var err error + if delete { + err = m.store.DeleteInterface(id) + } else { + err = m.store.SaveInterface(device, peers) + } + if err != nil { + return errors.Wrapf(err, "failed to persist interface") + } + + return nil +} + +func (m *WgCtrlManager) peerExists(id persistence.PeerIdentifier) bool { + for _, peers := range m.peers { + if _, ok := peers[id]; ok { + return true + } + } + + return false +} + +func (m *WgCtrlManager) persistPeer(id persistence.PeerIdentifier, delete bool) error { + if m.store == nil { + return nil // nothing to do + } + + var peer persistence.PeerConfig + for _, peers := range m.peers { + if p, ok := peers[id]; ok { + peer = p + break + } + } + + var err error + if delete { + err = m.store.DeletePeer(id, peer.PeerInterfaceConfig.Identifier) + } else { + err = m.store.SavePeer(peer, peer.PeerInterfaceConfig.Identifier) + } + if err != nil { + return errors.Wrapf(err, "failed to persist peer %s", id) + } + + return nil +} + +func (m *WgCtrlManager) getPeer(id persistence.PeerIdentifier) (persistence.PeerConfig, error) { + for _, peers := range m.peers { + if _, ok := peers[id]; ok { + return peers[id], nil + } + } + + return persistence.PeerConfig{}, errors.New("peer not found") +} + +func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) { + rawAddresses := strings.Split(addrStr, ",") + addresses := make([]*netlink.Addr, 0, len(rawAddresses)) + for i := range rawAddresses { + rawAddress := strings.TrimSpace(rawAddresses[i]) + if rawAddress == "" { + continue // skip empty string + } + address, err := netlink.ParseAddr(rawAddress) + if err != nil { + return nil, errors.Wrapf(err, "failed to parse IP address %s", rawAddress) + } + addresses = append(addresses, address) + } + + return addresses, nil +} + +func getWireGuardPeerConfig(devType persistence.InterfaceType, cfg persistence.PeerConfig) (wgtypes.PeerConfig, error) { + publicKey, err := wgtypes.ParseKey(cfg.KeyPair.PublicKey) + if err != nil { + return wgtypes.PeerConfig{}, errors.WithMessage(err, "invalid public key for peer") + } + + var presharedKey *wgtypes.Key + if tmpPresharedKey, err := wgtypes.ParseKey(cfg.PresharedKey); err == nil { + presharedKey = &tmpPresharedKey + } + + var endpoint *net.UDPAddr + if cfg.Endpoint.Value != "" && devType == persistence.InterfaceTypeClient { + addr, err := net.ResolveUDPAddr("udp", cfg.Endpoint.Value.(string)) + if err == nil { + endpoint = addr + } + } + + var keepAlive *time.Duration + if cfg.PersistentKeepalive.Value != 0 { + keepAliveDuration := time.Duration(cfg.PersistentKeepalive.Value.(int)) * time.Second + keepAlive = &keepAliveDuration + } + + allowedIPs := make([]net.IPNet, 0) + var peerAllowedIPs []*netlink.Addr + switch devType { + case persistence.InterfaceTypeClient: + peerAllowedIPs, err = parseIpAddressString(cfg.AllowedIPsStr.GetValue()) + if err != nil { + return wgtypes.PeerConfig{}, errors.WithMessage(err, "failed to parse allowed IP's") + } + case persistence.InterfaceTypeServer: + peerAllowedIPs, err = parseIpAddressString(cfg.AllowedIPsStr.GetValue()) + if err != nil { + return wgtypes.PeerConfig{}, errors.WithMessage(err, "failed to parse allowed IP's") + } + peerExtraAllowedIPs, err := parseIpAddressString(cfg.ExtraAllowedIPsStr) + if err != nil { + return wgtypes.PeerConfig{}, errors.WithMessage(err, "failed to parse extra allowed IP's") + } + + peerAllowedIPs = append(peerAllowedIPs, peerExtraAllowedIPs...) + } + for _, ip := range peerAllowedIPs { + allowedIPs = append(allowedIPs, *ip.IPNet) + } + + wgPeer := wgtypes.PeerConfig{ + PublicKey: publicKey, + Remove: false, + UpdateOnly: true, + PresharedKey: presharedKey, + Endpoint: endpoint, + PersistentKeepaliveInterval: keepAlive, + ReplaceAllowedIPs: true, + AllowedIPs: allowedIPs, + } + + return wgPeer, nil +} diff --git a/internal/wireguard/wireguard_test.go b/internal/wireguard/wireguard_test.go new file mode 100644 index 0000000..d8c196b --- /dev/null +++ b/internal/wireguard/wireguard_test.go @@ -0,0 +1,328 @@ +//go:build !integration +// +build !integration + +package wireguard + +import ( + "net" + "reflect" + "sync" + "testing" + + "github.com/h44z/wg-portal/internal/persistence" + "github.com/pkg/errors" + "github.com/stretchr/testify/mock" + "github.com/vishvananda/netlink" +) + +func TestWgCtrlManager_CreateInterface(t *testing.T) { + tests := []struct { + name string + manager *WgCtrlManager + mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) + args persistence.InterfaceIdentifier + wantErr bool + }{ + { + name: "AlreadyExisting", + manager: &WgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, + args: persistence.InterfaceIdentifier("wg0"), + wantErr: true, + }, + { + name: "LinkAddFailure", + manager: &WgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: nil, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkAdd", mock.Anything).Return(errors.New("failure")) + }, + args: persistence.InterfaceIdentifier("wg0"), + wantErr: true, + }, + { + name: "LinkSetupFailure", + manager: &WgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: nil, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkAdd", mock.Anything).Return(nil) + nl.On("LinkSetUp", mock.Anything).Return(errors.New("failure")) + }, + args: persistence.InterfaceIdentifier("wg0"), + wantErr: true, + }, + { + name: "PersistenceFailure", + manager: &WgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: make(map[persistence.InterfaceIdentifier]persistence.InterfaceConfig), + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkAdd", mock.Anything).Return(nil) + nl.On("LinkSetUp", mock.Anything).Return(nil) + st.On("SaveInterface", mock.Anything, mock.Anything).Return(errors.New("failure")) + }, + args: persistence.InterfaceIdentifier("wg0"), + wantErr: true, + }, + { + name: "Success", + manager: &WgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: make(map[persistence.InterfaceIdentifier]persistence.InterfaceConfig), + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkAdd", mock.Anything).Return(nil) + nl.On("LinkSetUp", mock.Anything).Return(nil) + st.On("SaveInterface", mock.Anything, mock.Anything).Return(nil) + }, + args: persistence.InterfaceIdentifier("wg0"), + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup( + tt.manager.wg.(*MockWireGuardClient), + tt.manager.nl.(*MockNetlinkClient), + tt.manager.store.(*MockWireGuardStore), + ) + if err := tt.manager.CreateInterface(tt.args); (err != nil) != tt.wantErr { + t.Errorf("CreateInterface() error = %v, wantErr %v", err, tt.wantErr) + } + tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) + tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) + tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) + }) + } +} + +func TestWgCtrlManager_DeleteInterface(t *testing.T) { + tests := []struct { + name string + manager *WgCtrlManager + mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) + args persistence.InterfaceIdentifier + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup( + tt.manager.wg.(*MockWireGuardClient), + tt.manager.nl.(*MockNetlinkClient), + tt.manager.store.(*MockWireGuardStore), + ) + if err := tt.manager.DeleteInterface(tt.args); (err != nil) != tt.wantErr { + t.Errorf("DeleteInterface() error = %v, wantErr %v", err, tt.wantErr) + } + tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) + tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) + tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) + }) + } +} + +func TestWgCtrlManager_GetInterfaces(t *testing.T) { + tests := []struct { + name string + manager *WgCtrlManager + mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) + want []persistence.InterfaceConfig + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup( + tt.manager.wg.(*MockWireGuardClient), + tt.manager.nl.(*MockNetlinkClient), + tt.manager.store.(*MockWireGuardStore), + ) + got, err := tt.manager.GetInterfaces() + if (err != nil) != tt.wantErr { + t.Errorf("GetInterfaces() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetInterfaces() got = %v, want %v", got, tt.want) + } + tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) + tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) + tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) + }) + } +} + +func TestWgCtrlManager_UpdateInterface(t *testing.T) { + type args struct { + id persistence.InterfaceIdentifier + cfg persistence.InterfaceConfig + } + tests := []struct { + name string + manager *WgCtrlManager + mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) + args args + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup( + tt.manager.wg.(*MockWireGuardClient), + tt.manager.nl.(*MockNetlinkClient), + tt.manager.store.(*MockWireGuardStore), + ) + if err := tt.manager.UpdateInterface(tt.args.id, tt.args.cfg); (err != nil) != tt.wantErr { + t.Errorf("UpdateInterface() error = %v, wantErr %v", err, tt.wantErr) + } + tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) + tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) + tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) + }) + } +} + +func Test_parseIpAddressString(t *testing.T) { + type args struct { + addrStr string + } + var tests = []struct { + name string + args args + want []*netlink.Addr + wantErr bool + }{ + { + name: "Empty String", + args: args{}, + want: []*netlink.Addr{}, + wantErr: false, + }, + { + name: "Single IPv4", + args: args{addrStr: "123.123.123.123"}, + want: nil, + wantErr: true, + }, + { + name: "Malformed", + args: args{addrStr: "hello world"}, + want: nil, + wantErr: true, + }, + { + name: "Single IPv4 CIDR", + args: args{addrStr: "123.123.123.123/24"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(123, 123, 123, 123), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }}, + wantErr: false, + }, + { + name: "Multiple IPv4 CIDR", + args: args{addrStr: "123.123.123.123/24, 200.201.202.203/16"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(123, 123, 123, 123), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, { + IPNet: &net.IPNet{ + IP: net.IPv4(200, 201, 202, 203), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }}, + wantErr: false, + }, + { + name: "Single IPv6 CIDR", + args: args{addrStr: "fe80::1/64"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }}, + wantErr: false, + }, + { + name: "Multiple IPv6 CIDR", + args: args{addrStr: "fe80::1/64 , 2130:d3ad::b33f/128"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }, { + IPNet: &net.IPNet{ + IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + }}, + wantErr: false, + }, + { + name: "Mixed IPv4 and IPv6 CIDR", + args: args{addrStr: "200.201.202.203/16,2130:d3ad::b33f/128"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(200, 201, 202, 203), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }, { + IPNet: &net.IPNet{ + IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + }}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseIpAddressString(tt.args.addrStr) + if (err != nil) != tt.wantErr { + t.Errorf("parseIpAddressString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseIpAddressString() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tmp/lowlevel/doc.go b/tmp/lowlevel/doc.go deleted file mode 100644 index 057cf03..0000000 --- a/tmp/lowlevel/doc.go +++ /dev/null @@ -1,6 +0,0 @@ -package lowlevel - -/** -This package contains wrappers for low level api's like netlink or the WireGuard control library. -Wrapping those external libraries makes mocking and testing code easier. -*/ diff --git a/tmp/lowlevel/netlink.go b/tmp/lowlevel/netlink.go deleted file mode 100644 index b86d13e..0000000 --- a/tmp/lowlevel/netlink.go +++ /dev/null @@ -1,63 +0,0 @@ -package lowlevel - -import ( - "github.com/vishvananda/netlink" -) - -// A NetlinkClient is a type which can control a netlink device. -type NetlinkClient interface { - LinkAdd(link netlink.Link) error - LinkDel(link netlink.Link) error - LinkByName(name string) (netlink.Link, error) - LinkSetUp(link netlink.Link) error - LinkSetDown(link netlink.Link) error - LinkSetMTU(link netlink.Link, mtu int) error - AddrReplace(link netlink.Link, addr *netlink.Addr) error - AddrAdd(link netlink.Link, addr *netlink.Addr) error - AddrList(link netlink.Link) ([]netlink.Addr, error) -} - -type NetlinkManager struct { -} - -func (n NetlinkManager) LinkAdd(link netlink.Link) error { return netlink.LinkAdd(link) } - -func (n NetlinkManager) LinkDel(link netlink.Link) error { return netlink.LinkDel(link) } - -func (n NetlinkManager) LinkByName(name string) (netlink.Link, error) { - return netlink.LinkByName(name) -} - -func (n NetlinkManager) LinkSetUp(link netlink.Link) error { return netlink.LinkSetUp(link) } - -func (n NetlinkManager) LinkSetDown(link netlink.Link) error { return netlink.LinkSetDown(link) } - -func (n NetlinkManager) LinkSetMTU(link netlink.Link, mtu int) error { - return netlink.LinkSetMTU(link, mtu) -} - -func (n NetlinkManager) AddrReplace(link netlink.Link, addr *netlink.Addr) error { - return netlink.AddrReplace(link, addr) -} - -func (n NetlinkManager) AddrAdd(link netlink.Link, addr *netlink.Addr) error { - return netlink.AddrAdd(link, addr) -} - -func (n NetlinkManager) AddrList(link netlink.Link) ([]netlink.Addr, error) { - listIPv4, err := netlink.AddrList(link, netlink.FAMILY_V4) - if err != nil { - return nil, err - } - - listIPv6, err := netlink.AddrList(link, netlink.FAMILY_V6) - if err != nil { - return nil, err - } - - ipAddresses := make([]netlink.Addr, 0, len(listIPv4)+len(listIPv6)) - ipAddresses = append(ipAddresses, listIPv4...) - ipAddresses = append(ipAddresses, listIPv6...) - - return ipAddresses, nil -} diff --git a/tmp/lowlevel/wgctrl.go b/tmp/lowlevel/wgctrl.go deleted file mode 100644 index ab6832c..0000000 --- a/tmp/lowlevel/wgctrl.go +++ /dev/null @@ -1,15 +0,0 @@ -package lowlevel - -import ( - "io" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -// A WireGuardClient is a type which can control a WireGuard device. -type WireGuardClient interface { - io.Closer - Devices() ([]*wgtypes.Device, error) - Device(name string) (*wgtypes.Device, error) - ConfigureDevice(name string, cfg wgtypes.Config) error -} diff --git a/tmp/persistence/database.go b/tmp/persistence/database.go deleted file mode 100644 index dc7cf83..0000000 --- a/tmp/persistence/database.go +++ /dev/null @@ -1 +0,0 @@ -package persistence diff --git a/tmp/persistence/ldap.go b/tmp/persistence/ldap.go deleted file mode 100644 index 90c28b3..0000000 --- a/tmp/persistence/ldap.go +++ /dev/null @@ -1,4 +0,0 @@ -package persistence - -type LdapLoader interface { -} diff --git a/tmp/persistence/models.go b/tmp/persistence/models.go deleted file mode 100644 index 3b37212..0000000 --- a/tmp/persistence/models.go +++ /dev/null @@ -1,157 +0,0 @@ -package persistence - -import ( - "database/sql" - "time" - - "gorm.io/gorm" -) - -type BaseModel struct { - CreatedBy string - UpdatedBy string - CreatedAt time.Time - UpdatedAt time.Time - DisabledAt sql.NullTime -} - -type InterfaceIdentifier string -type PeerIdentifier string -type UserIdentifier string - -type KeyPair struct { - PrivateKey string - PublicKey string -} - -type PreSharedKey string - -type InterfaceType string - -const ( - InterfaceTypeServer InterfaceType = "server" - InterfaceTypeClient InterfaceType = "client" -) - -type InterfaceConfig struct { - BaseModel - - // WireGuard specific (for the [interface] section of the config file) - - Identifier InterfaceIdentifier // device name, for example: wg0 - KeyPair KeyPair // private/public Key of the server interface - ListenPort int // the listening port, for example: 51820 - - AddressStr string // the interface ip addresses, comma separated - DnsStr string // the dns server that should be set if the interface is up, comma separated - - Mtu int // the device MTU - FirewallMark int32 // a firewall mark - RoutingTable string // the routing table - - PreUp string // action that is executed before the device is up - PostUp string // action that is executed after the device is up - PreDown string // action that is executed before the device is down - PostDown string // action that is executed after the device is down - - SaveConfig bool // automatically persist config changes to the wgX.conf file - - // WG Portal specific - Enabled bool // flag that specifies if the interface is enabled (up) or nor (down) - DisplayName string // a nice display name/ description for the interface - Type InterfaceType // the interface type, either InterfaceTypeServer or InterfaceTypeClient - DriverType string // the interface driver type (linux, software, ...) - - // Default settings for the peer, used for new peers, those settings will be published to ConfigOption options of - // the peer config - - PeerDefNetworkStr string // the default subnets from which peers will get their IP addresses, comma seperated - PeerDefDnsStr string // the default dns server for the peer - PeerDefEndpoint string // the default endpoint for the peer - PeerDefAllowedIPsStr string // the default allowed IP string for the peer - PeerDefMtu int // the default device MTU - PeerDefPersistentKeepalive int // the default persistent keep-alive Value - PeerDefFirewallMark int32 // default firewall mark - PeerDefRoutingTable string // the default routing table - - PeerDefPreUp string // default action that is executed before the device is up - PeerDefPostUp string // default action that is executed after the device is up - PeerDefPreDown string // default action that is executed before the device is down - PeerDefPostDown string // default action that is executed after the device is down -} - -type PeerInterfaceConfig struct { - AddressStr StringConfigOption // the interface ip addresses, comma separated - DnsStr StringConfigOption // the dns server that should be set if the interface is up, comma separated - Mtu IntConfigOption // the device MTU - FirewallMark Int32ConfigOption // a firewall mark - RoutingTable StringConfigOption // the routing table - - PreUp StringConfigOption // action that is executed before the device is up - PostUp StringConfigOption // action that is executed after the device is up - PreDown StringConfigOption // action that is executed before the device is down - PostDown StringConfigOption // action that is executed after the device is down -} - -type PeerConfig struct { - BaseModel - - // WireGuard specific (for the [peer] section of the config file) - - Endpoint StringConfigOption // the endpoint address - AllowedIPsStr StringConfigOption // all allowed ip subnets, comma seperated - ExtraAllowedIPsStr string // all allowed ip subnets on the server side, comma seperated - KeyPair KeyPair // private/public Key of the peer - PresharedKey string // the pre-shared Key of the peer - PersistentKeepalive IntConfigOption // the persistent keep-alive interval - - // WG Portal specific - - DisplayName string // a nice display name/ description for the peer - Identifier PeerIdentifier // peer unique identifier - UserIdentifier UserIdentifier // the owner - - // Interface settings for the peer, used to generate the [interface] section in the peer config file - PeerInterfaceConfig -} - -type UserSource string - -const ( - UserSourceLdap UserSource = "ldap" // LDAP / ActiveDirectory - UserSourceDatabase UserSource = "db" // sqlite / mysql database - UserSourceOIDC UserSource = "oidc" // open id connect, TODO: implement -) - -type PrivateString string - -func (PrivateString) MarshalJSON() ([]byte, error) { - return []byte(`""`), nil -} - -func (PrivateString) String() string { - return "" -} - -// User is the user model that gets linked to peer entries, by default an empty user model with only the email address is created -type User struct { - // required fields - Uid UserIdentifier `gorm:"primaryKey"` - Email string `form:"email" binding:"required,email"` - Source UserSource - IsAdmin bool - - // optional fields - Firstname string `form:"firstname" binding:"omitempty"` - Lastname string `form:"lastname" binding:"omitempty"` - Phone string `form:"phone" binding:"omitempty"` - Department string `form:"department" binding:"omitempty"` - - // optional, integrated password authentication - Password PrivateString `form:"password" binding:"omitempty"` - - // database internal fields - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index" json:",omitempty" swaggertype:"string"` -} diff --git a/tmp/persistence/options.go b/tmp/persistence/options.go deleted file mode 100644 index 683eae5..0000000 --- a/tmp/persistence/options.go +++ /dev/null @@ -1,81 +0,0 @@ -package persistence - -// ConfigOption is an Overridable configuration option -type ConfigOption struct { - Value interface{} - Overridable bool -} - -type StringConfigOption struct { - ConfigOption -} - -func (o StringConfigOption) GetValue() string { - if o.Value == nil { - return "" - } - return o.Value.(string) -} - -func NewStringConfigOption(value string, overridable bool) StringConfigOption { - return StringConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} - -type IntConfigOption struct { - ConfigOption -} - -func (o IntConfigOption) GetValue() int { - if o.Value == nil { - return 0 - } - return o.Value.(int) -} - -func NewIntConfigOption(value int, overridable bool) IntConfigOption { - return IntConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} - -type Int32ConfigOption struct { - ConfigOption -} - -func (o Int32ConfigOption) GetValue() int32 { - if o.Value == nil { - return 0 - } - - return o.Value.(int32) -} - -func NewInt32ConfigOption(value int32, overridable bool) Int32ConfigOption { - return Int32ConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} - -type BoolConfigOption struct { - ConfigOption -} - -func (o BoolConfigOption) GetValue() bool { - if o.Value == nil { - return false - } - - return o.Value.(bool) -} - -func NewBoolConfigOption(value bool, overridable bool) BoolConfigOption { - return BoolConfigOption{ConfigOption{ - Value: value, - Overridable: overridable, - }} -} diff --git a/tmp/persistence/users.go b/tmp/persistence/users.go deleted file mode 100644 index 8f3d56a..0000000 --- a/tmp/persistence/users.go +++ /dev/null @@ -1,19 +0,0 @@ -package persistence - -import "gorm.io/gorm" - -type UserFilterCondition func(tx *gorm.DB) - -type UsersLoader interface { - GetUser(id UserIdentifier) (User, error) - GetUsers() ([]User, error) - GetUsersUnscoped() ([]User, error) - GetUsersFiltered(filter ...UserFilterCondition) ([]User, error) -} - -type Users interface { - UsersLoader - - SaveUser(user User) error - DeleteUser(identifier UserIdentifier) error -} diff --git a/tmp/persistence/wireguard.go b/tmp/persistence/wireguard.go deleted file mode 100644 index 9f320f1..0000000 --- a/tmp/persistence/wireguard.go +++ /dev/null @@ -1,14 +0,0 @@ -package persistence - -type WireGuard interface { - GetAvailableInterfaces() ([]InterfaceIdentifier, error) - - GetAllInterfaces(interfaceIdentifiers ...InterfaceIdentifier) (map[InterfaceConfig][]PeerConfig, error) - GetInterface(identifier InterfaceIdentifier) (InterfaceConfig, []PeerConfig, error) - - SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error - SavePeer(peer PeerConfig, interfaceIdentifier InterfaceIdentifier) error - - DeleteInterface(identifier InterfaceIdentifier) error - DeletePeer(peer PeerIdentifier, interfaceIdentifier InterfaceIdentifier) error -} diff --git a/tmp/portal/api.go b/tmp/portal/api.go deleted file mode 100644 index 8d7996b..0000000 --- a/tmp/portal/api.go +++ /dev/null @@ -1 +0,0 @@ -package portal diff --git a/tmp/portal/web.go b/tmp/portal/web.go deleted file mode 100644 index 8d7996b..0000000 --- a/tmp/portal/web.go +++ /dev/null @@ -1 +0,0 @@ -package portal diff --git a/tmp/user/authentication.go b/tmp/user/authentication.go deleted file mode 100644 index b200aa5..0000000 --- a/tmp/user/authentication.go +++ /dev/null @@ -1,4 +0,0 @@ -package user - -type Authenticator interface { -} diff --git a/tmp/user/manager.go b/tmp/user/manager.go deleted file mode 100644 index 1bfd1fc..0000000 --- a/tmp/user/manager.go +++ /dev/null @@ -1,7 +0,0 @@ -package user - -import "github.com/h44z/wg-portal/tmp/persistence" - -type Manager interface { - persistence.UsersLoader -} diff --git a/tmp/wireguard/manager.go b/tmp/wireguard/manager.go deleted file mode 100644 index 0df314e..0000000 --- a/tmp/wireguard/manager.go +++ /dev/null @@ -1,50 +0,0 @@ -package wireguard - -import ( - "io" - - "github.com/h44z/wg-portal/tmp/persistence" -) - -type KeyGenerator interface { - GetFreshKeypair() (persistence.KeyPair, error) - GetPreSharedKey() (persistence.PreSharedKey, error) -} - -// InterfaceManager provides methods to create/update/delete physical WireGuard devices. -type InterfaceManager interface { - GetInterfaces() ([]persistence.InterfaceConfig, error) - CreateInterface(id persistence.InterfaceIdentifier) error - DeleteInterface(id persistence.InterfaceIdentifier) error - UpdateInterface(id persistence.InterfaceIdentifier, cfg persistence.InterfaceConfig) error -} - -type ImportableInterface struct { - persistence.InterfaceConfig - ImportLocation string - ImportType string -} - -type ImportManager interface { - GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) - ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) -} - -type ConfigFileGenerator interface { - GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) - GetPeerConfig(peer persistence.PeerConfig) (io.Reader, error) -} - -type PeerManager interface { - GetPeers(device persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) - SavePeers(device persistence.InterfaceIdentifier, peers ...persistence.PeerConfig) error - RemovePeer(device persistence.InterfaceIdentifier, peer persistence.PeerIdentifier) error -} - -type Manager interface { - KeyGenerator - InterfaceManager - PeerManager - ImportManager - ConfigFileGenerator -} diff --git a/tmp/wireguard/persistence.go b/tmp/wireguard/persistence.go deleted file mode 100644 index 0b280a9..0000000 --- a/tmp/wireguard/persistence.go +++ /dev/null @@ -1 +0,0 @@ -package wireguard diff --git a/tmp/wireguard/template.go b/tmp/wireguard/template.go deleted file mode 100644 index e021d08..0000000 --- a/tmp/wireguard/template.go +++ /dev/null @@ -1,66 +0,0 @@ -package wireguard - -import ( - "bytes" - "embed" - "io" - "text/template" - - "github.com/h44z/wg-portal/tmp/persistence" - - "github.com/pkg/errors" -) - -//go:embed tpl_files/* -var TemplateFiles embed.FS - -type TemplateHandler struct { - templates *template.Template -} - -func NewTemplateHandler() (*TemplateHandler, error) { - templateCache, err := template.New("WireGuard").ParseFS(TemplateFiles, "tpl_files/*.tpl") - if err != nil { - return nil, errors.Wrapf(err, "failed to parse template files") - } - - handler := &TemplateHandler{ - templates: templateCache, - } - - return handler, nil -} - -func (c TemplateHandler) GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) { - var tplBuff bytes.Buffer - - err := c.templates.ExecuteTemplate(&tplBuff, "interface.tpl", map[string]interface{}{ - "Interface": cfg, - "Peers": peers, - "Portal": map[string]interface{}{ - "Version": "unknown", - }, - }) - if err != nil { - return nil, errors.Wrapf(err, "failed to execute interface template for %s", cfg.Identifier) - } - - return &tplBuff, nil -} - -func (c TemplateHandler) GetPeerConfig(peer persistence.PeerConfig, iface persistence.InterfaceConfig) (io.Reader, error) { - var tplBuff bytes.Buffer - - err := c.templates.ExecuteTemplate(&tplBuff, "peer.tpl", map[string]interface{}{ - "Peer": peer, - "Interface": iface, - "Portal": map[string]interface{}{ - "Version": "unknown", - }, - }) - if err != nil { - return nil, errors.Wrapf(err, "failed to execute peer template for %s", peer.Uid) - } - - return &tplBuff, nil -} diff --git a/tmp/wireguard/tpl_files/interface.tpl b/tmp/wireguard/tpl_files/interface.tpl deleted file mode 100644 index 2abb63d..0000000 --- a/tmp/wireguard/tpl_files/interface.tpl +++ /dev/null @@ -1,82 +0,0 @@ -# AUTOGENERATED FILE - DO NOT EDIT -# This file uses wg-quick format. See https://man7.org/linux/man-pages/man8/wg-quick.8.html#CONFIGURATION - -# -WGP- WIREGUARD PORTAL CONFIGURATION FILE, version {{ .Portal.Version }} -# Lines starting with the -WGP- tag are used by the WireGuard Portal configuration parser. - -[Interface] -# -WGP- Interface: {{ .Interface.DeviceName }} | Updated: {{ .Interface.UpdatedAt }} | Created: {{ .Interface.CreatedAt }} -# -WGP- Display name: {{ .Interface.DisplayName }} -# -WGP- Interface mode: {{ .Interface.Type }} -# -WGP- PublicKey = {{ .Interface.KeyPair.PublicKey }} - -# Core settings -PrivateKey = {{ .Interface.KeyPair.PrivateKey }} -Address = {{ .Interface.AddressStr }} - -# Misc. settings (optional) -{{- if ne .Interface.ListenPort 0}} -ListenPort = {{ .Interface.ListenPort }} -{{- end}} -{{- if ne .Interface.Mtu 0}} -MTU = {{.Interface.Mtu}} -{{- end}} -{{- if and (ne .Interface.DnsStr "") (eq $.Interface.Type "client")}} -DNS = {{ .Interface.DnsStr }} -{{- end}} -{{- if ne .Interface.FirewallMark 0}} -FwMark = {{.Interface.FirewallMark}} -{{- end}} -{{- if ne .Interface.RoutingTable ""}} -Table = {{.Interface.RoutingTable}} -{{- end}} -{{- if .Interface.SaveConfig}} -SaveConfig = true -{{- end}} - -# Interface hooks (optional) -{{- if .Interface.PreUp}} -PreUp = {{ .Interface.PreUp }} -{{- end}} -{{- if .Interface.PostUp}} -PostUp = {{ .Interface.PostUp }} -{{- end}} -{{- if .Interface.PreDown}} -PreDown = {{ .Interface.PreDown }} -{{- end}} -{{- if .Interface.PostDown}} -PostDown = {{ .Interface.PostDown }} -{{- end}} - -# -# Peers -# - -{{range .Peers}} -{{- if not .DisabledAt}} -[Peer] -# -WGP- Peer: {{.Uid}} | Updated: {{.UpdatedAt}} | Created: {{.CreatedAt}} -# -WGP- Display name: {{ .Identifier }} -{{- if .KeyPair.PrivateKey}} -# -WGP- PrivateKey: {{.KeyPair.PrivateKey}} -{{- end}} -PublicKey = {{ .KeyPair.PublicKey }} -{{- if .PresharedKey}} -PresharedKey = {{ .PresharedKey }} -{{- end}} -{{- if eq $.Interface.Type "server"}} -AllowedIPs = {{ .AddressStr }}{{if ne .ExtraAllowedIPsStr ""}}, {{ .ExtraAllowedIPsStr }}{{end}} -{{- end}} -{{- if eq $.Interface.Type "client"}} -{{- if .AllowedIPsStr}} -AllowedIPs = {{ .AllowedIPsStr }} -{{- end}} -{{- end}} -{{- if and (ne .Endpoint "") (eq $.Interface.Type "client")}} -Endpoint = {{ .Endpoint }} -{{- end}} -{{- if ne .PersistentKeepalive 0}} -PersistentKeepalive = {{ .PersistentKeepalive }} -{{- end}} -{{- end}} -{{end}} \ No newline at end of file diff --git a/tmp/wireguard/tpl_files/peer.tpl b/tmp/wireguard/tpl_files/peer.tpl deleted file mode 100644 index 2278023..0000000 --- a/tmp/wireguard/tpl_files/peer.tpl +++ /dev/null @@ -1,60 +0,0 @@ -# AUTOGENERATED FILE - DO NOT EDIT -# This file uses wg-quick format. See https://man7.org/linux/man-pages/man8/wg-quick.8.html#CONFIGURATION - -# -WGP- WIREGUARD PORTAL CONFIGURATION FILE, version {{ .Portal.Version }} -# Lines starting with the -WGP- tag are used by the WireGuard Portal configuration parser. - -[Interface] -# -WGP- Peer: {{.Peer.Uid}} | Updated: {{.Peer.UpdatedAt}} | Created: {{.Peer.CreatedAt}} -# -WGP- Display name: {{ .Peer.Identifier }} -# -WGP- PublicKey: {{ .Peer.KeyPair.PublicKey }} -{{- if eq $.Interface.Type "server"}} -# -WGP- Peer type: client -{{else}} -# -WGP- Peer type: server -{{- end}} - -# Core settings -PrivateKey = {{ .Peer.KeyPair.PrivateKey }} -Address = {{ .Peer.AddressStr.GetValue }} - -# Misc. settings (optional) -{{- if .Peer.DnsStr.GetValue}} -DNS = {{ .Peer.DnsStr.GetValue }} -{{- end}} -{{- if ne .Peer.Mtu.GetValue 0}} -MTU = {{ .Peer.Mtu.GetValue }} -{{- end}} -{{- if ne .Peer.FirewallMark.GetValue 0}} -FwMark = {{ .Peer.FirewallMark.GetValue }} -{{- end}} -{{- if ne .Peer.RoutingTable.GetValue ""}} -Table = {{ .Peer.RoutingTable.GetValue }} -{{- end}} - -# Interface hooks (optional) -{{- if .Peer.PreUp.GetValue}} -PreUp = {{ .Peer.PreUp.GetValue }} -{{- end}} -{{- if .Peer.PostUp.GetValue}} -PostUp = {{ .Peer.PostUp.GetValue }} -{{- end}} -{{- if .Peer.PreDown.GetValue}} -PreDown = {{ .Peer.PreDown.GetValue }} -{{- end}} -{{- if .Peer.PostDown.GetValue}} -PostDown = {{ .Peer.PostDown.GetValue }} -{{- end}} - -[Peer] -PublicKey = {{ .Interface.KeyPair.PublicKey }} -Endpoint = {{ .Peer.Endpoint.GetValue }} -{{- if .Peer.AllowedIPsStr.GetValue}} -AllowedIPs = {{ .Peer.AllowedIPsStr.GetValue }} -{{- end}} -{{- if .Peer.PresharedKey}} -PresharedKey = {{ .Peer.PresharedKey }} -{{- end}} -{{- if ne .Peer.PersistentKeepalive.GetValue 0}} -PersistentKeepalive = {{ .Peer.PersistentKeepalive.GetValue }} -{{- end}} \ No newline at end of file