diff --git a/internal/persistence/models.go b/internal/persistence/models.go index 2238f1d..1d4387b 100644 --- a/internal/persistence/models.go +++ b/internal/persistence/models.go @@ -81,12 +81,15 @@ } type PeerInterfaceConfig struct { - Identifier InterfaceIdentifier // the interface identifier - AddressStr StringConfigOption // the interface ip addresses, comma separated - DnsStr StringConfigOption // the dns server that should be set if the interface is up, comma separated - Mtu IntConfigOption // the device MTU - FirewallMark Int32ConfigOption // a firewall mark - RoutingTable StringConfigOption // the routing table + Identifier InterfaceIdentifier // the interface identifier + Type InterfaceType // the interface type + PublicKey string // the interface public key + + AddressStr StringConfigOption // the interface ip addresses, comma separated + DnsStr StringConfigOption // the dns server that should be set if the interface is up, comma separated + Mtu IntConfigOption // the device MTU + FirewallMark Int32ConfigOption // a firewall mark + RoutingTable StringConfigOption // the routing table PreUp StringConfigOption // action that is executed before the device is up PostUp StringConfigOption // action that is executed after the device is up @@ -113,7 +116,7 @@ UserIdentifier UserIdentifier // the owner // Interface settings for the peer, used to generate the [interface] section in the peer config file - PeerInterfaceConfig + Interface *PeerInterfaceConfig } type UserSource string diff --git a/internal/persistence/options.go b/internal/persistence/options.go index 683eae5..f8031d0 100644 --- a/internal/persistence/options.go +++ b/internal/persistence/options.go @@ -17,6 +17,18 @@ return o.Value.(string) } +func (o *StringConfigOption) SetValue(value string) { + o.Value = value +} + +func (o *StringConfigOption) TrySetValue(value string) bool { + if o.Overridable { + o.Value = value + return true + } + return false +} + func NewStringConfigOption(value string, overridable bool) StringConfigOption { return StringConfigOption{ConfigOption{ Value: value, @@ -35,6 +47,18 @@ return o.Value.(int) } +func (o *IntConfigOption) SetValue(value int) { + o.Value = value +} + +func (o *IntConfigOption) TrySetValue(value int) bool { + if o.Overridable { + o.Value = value + return true + } + return false +} + func NewIntConfigOption(value int, overridable bool) IntConfigOption { return IntConfigOption{ConfigOption{ Value: value, @@ -54,6 +78,18 @@ return o.Value.(int32) } +func (o *Int32ConfigOption) SetValue(value int32) { + o.Value = value +} + +func (o *Int32ConfigOption) TrySetValue(value int32) bool { + if o.Overridable { + o.Value = value + return true + } + return false +} + func NewInt32ConfigOption(value int32, overridable bool) Int32ConfigOption { return Int32ConfigOption{ConfigOption{ Value: value, @@ -73,6 +109,18 @@ return o.Value.(bool) } +func (o *BoolConfigOption) SetValue(value bool) { + o.Value = value +} + +func (o *BoolConfigOption) TrySetValue(value bool) bool { + if o.Overridable { + o.Value = value + return true + } + return false +} + func NewBoolConfigOption(value bool, overridable bool) BoolConfigOption { return BoolConfigOption{ConfigOption{ Value: value, diff --git a/internal/user/manager.go b/internal/user/manager.go index 8bd2dda..d3744b6 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -9,7 +9,7 @@ GetUser(id persistence.UserIdentifier) (persistence.User, error) GetActiveUsers() ([]persistence.User, error) GetAllUsers() ([]persistence.User, error) - GetFilteredUsers(filter ...filterCondition) ([]persistence.User, error) + GetFilteredUsers(filter ...persistence.DatabaseFilterCondition) ([]persistence.User, error) } type Updater interface { @@ -72,7 +72,7 @@ return p.store.GetUsersUnscoped() } -func (p *PersistentManager) GetFilteredUsers(filter ...filterCondition) ([]persistence.User, error) { +func (p *PersistentManager) GetFilteredUsers(filter ...persistence.DatabaseFilterCondition) ([]persistence.User, error) { return p.store.GetUsersFiltered(filter...) } diff --git a/internal/wireguard/keys.go b/internal/wireguard/keys.go index 50789d8..b95f25d 100644 --- a/internal/wireguard/keys.go +++ b/internal/wireguard/keys.go @@ -4,7 +4,6 @@ "encoding/base64" "github.com/h44z/wg-portal/internal/persistence" - "github.com/pkg/errors" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) diff --git a/internal/wireguard/keys_test.go b/internal/wireguard/keys_test.go index 4c828f9..ad5e64b 100644 --- a/internal/wireguard/keys_test.go +++ b/internal/wireguard/keys_test.go @@ -4,7 +4,6 @@ "testing" "github.com/h44z/wg-portal/internal/persistence" - "github.com/stretchr/testify/assert" ) diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index d16545e..6dc27d4 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -3,6 +3,8 @@ import ( "io" + "github.com/vishvananda/netlink" + "github.com/pkg/errors" "github.com/h44z/wg-portal/internal/lowlevel" @@ -17,10 +19,11 @@ // InterfaceManager provides methods to create/update/delete physical WireGuard devices. type InterfaceManager interface { - GetInterfaces() ([]persistence.InterfaceConfig, error) + GetInterfaces() ([]*persistence.InterfaceConfig, error) CreateInterface(id persistence.InterfaceIdentifier) error DeleteInterface(id persistence.InterfaceIdentifier) error - UpdateInterface(id persistence.InterfaceIdentifier, cfg persistence.InterfaceConfig) error + UpdateInterface(id persistence.InterfaceIdentifier, cfg *persistence.InterfaceConfig) error + ApplyDefaultConfigs(id persistence.InterfaceIdentifier) error } type ImportableInterface struct { @@ -30,29 +33,34 @@ } type ImportManager interface { - GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) - ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) error + GetImportableInterfaces() (map[*ImportableInterface][]*persistence.PeerConfig, error) + ImportInterface(cfg *ImportableInterface, peers []*persistence.PeerConfig) error } type ConfigFileGenerator interface { - GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) - GetPeerConfig(peer persistence.PeerConfig) (io.Reader, error) + GetInterfaceConfig(cfg *persistence.InterfaceConfig, peers []*persistence.PeerConfig) (io.Reader, error) + GetPeerConfig(peer *persistence.PeerConfig) (io.Reader, error) } type PeerManager interface { - GetPeers(device persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) - SavePeers(peers ...persistence.PeerConfig) error + GetPeers(device persistence.InterfaceIdentifier) ([]*persistence.PeerConfig, error) + SavePeers(peers ...*persistence.PeerConfig) error RemovePeer(peer persistence.PeerIdentifier) error } +type IpManager interface { + GetAllUsedIPs(device persistence.InterfaceIdentifier) ([]*netlink.Addr, error) + GetUsedIPs(device persistence.InterfaceIdentifier, subnetCidr string) ([]*netlink.Addr, error) + GetFreshIp(device persistence.InterfaceIdentifier, subnetCidr string, increment ...bool) (*netlink.Addr, error) +} + type Manager interface { KeyGenerator InterfaceManager PeerManager + IpManager ImportManager ConfigFileGenerator - - ApplyDefaultConfigs(device persistence.InterfaceIdentifier) error } // @@ -84,8 +92,3 @@ return m, nil } - -func (p *PersistentManager) ApplyDefaultConfigs(device persistence.InterfaceIdentifier) error { - // TODO: implement - return nil -} diff --git a/internal/wireguard/template.go b/internal/wireguard/template.go index 3718e36..510e91e 100644 --- a/internal/wireguard/template.go +++ b/internal/wireguard/template.go @@ -30,7 +30,7 @@ return handler, nil } -func (c templateHandler) GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) { +func (c templateHandler) GetInterfaceConfig(cfg *persistence.InterfaceConfig, peers []*persistence.PeerConfig) (io.Reader, error) { var tplBuff bytes.Buffer err := c.templates.ExecuteTemplate(&tplBuff, "interface.tpl", map[string]interface{}{ @@ -47,12 +47,11 @@ return &tplBuff, nil } -func (c templateHandler) GetPeerConfig(peer persistence.PeerConfig) (io.Reader, error) { +func (c templateHandler) GetPeerConfig(peer *persistence.PeerConfig) (io.Reader, error) { var tplBuff bytes.Buffer err := c.templates.ExecuteTemplate(&tplBuff, "peer.tpl", map[string]interface{}{ - "Peer": peer, - "Interface": peer.PeerInterfaceConfig, + "Peer": peer, "Portal": map[string]interface{}{ "Version": "unknown", }, diff --git a/internal/wireguard/template_test.go b/internal/wireguard/template_test.go index 5a8c86e..c2ddddd 100644 --- a/internal/wireguard/template_test.go +++ b/internal/wireguard/template_test.go @@ -19,8 +19,8 @@ func TestTemplateHandler_GetInterfaceConfig(t *testing.T) { type args struct { - cfg persistence.InterfaceConfig - peers []persistence.PeerConfig + cfg *persistence.InterfaceConfig + peers []*persistence.PeerConfig } tests := []struct { name string @@ -30,7 +30,9 @@ }{ { name: "All Empty", - args: args{}, + args: args{ + cfg: &persistence.InterfaceConfig{Identifier: "test0"}, + }, want: bytes.NewBuffer([]byte(`# AUTOGENERATED FILE - DO NOT EDIT # This file uses wg-quick format. See https://man7.org/linux/man-pages/man8/wg-quick.8.html#CONFIGURATION @@ -38,7 +40,7 @@ # Lines starting with the -WGP- tag are used by the WireGuard Portal configuration parser. [Interface] -# -WGP- Interface: | Updated: 0001-01-01 00:00:00 +0000 UTC | Created: 0001-01-01 00:00:00 +0000 UTC +# -WGP- Interface: test0 | Updated: 0001-01-01 00:00:00 +0000 UTC | Created: 0001-01-01 00:00:00 +0000 UTC # -WGP- Display name: # -WGP- Interface mode: # -WGP- PublicKey = @@ -77,8 +79,7 @@ func TestTemplateHandler_GetPeerConfig(t *testing.T) { type args struct { - peer persistence.PeerConfig - iface persistence.InterfaceConfig + peer *persistence.PeerConfig } tests := []struct { name string @@ -88,7 +89,9 @@ }{ { name: "All empty", - args: args{}, + args: args{ + peer: &persistence.PeerConfig{Identifier: "peer0", Interface: &persistence.PeerInterfaceConfig{}}, + }, want: bytes.NewBuffer([]byte(`# AUTOGENERATED FILE - DO NOT EDIT # This file uses wg-quick format. See https://man7.org/linux/man-pages/man8/wg-quick.8.html#CONFIGURATION @@ -96,7 +99,7 @@ # Lines starting with the -WGP- tag are used by the WireGuard Portal configuration parser. [Interface] -# -WGP- Peer: | Updated: 0001-01-01 00:00:00 +0000 UTC | Created: 0001-01-01 00:00:00 +0000 UTC +# -WGP- Peer: peer0 | Updated: 0001-01-01 00:00:00 +0000 UTC | Created: 0001-01-01 00:00:00 +0000 UTC # -WGP- Display name: # -WGP- PublicKey: # -WGP- Peer type: server diff --git a/internal/wireguard/test_helpers_test.go b/internal/wireguard/test_helpers_test.go index 686d6c5..dbb3d4b 100644 --- a/internal/wireguard/test_helpers_test.go +++ b/internal/wireguard/test_helpers_test.go @@ -111,8 +111,8 @@ return args.Get(0).(persistence.InterfaceConfig), args.Get(1).([]persistence.PeerConfig), args.Error(2) } -func (w *MockWireGuardStore) SaveInterface(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) error { - args := w.Called(cfg, peers) +func (w *MockWireGuardStore) SaveInterface(cfg persistence.InterfaceConfig) error { + args := w.Called(cfg) return args.Error(0) } diff --git a/internal/wireguard/tpl_files/peer.tpl b/internal/wireguard/tpl_files/peer.tpl index cd44f64..118eec9 100644 --- a/internal/wireguard/tpl_files/peer.tpl +++ b/internal/wireguard/tpl_files/peer.tpl @@ -8,7 +8,7 @@ # -WGP- Peer: {{.Peer.Identifier}} | Updated: {{.Peer.UpdatedAt}} | Created: {{.Peer.CreatedAt}} # -WGP- Display name: {{ .Peer.DisplayName }} # -WGP- PublicKey: {{ .Peer.KeyPair.PublicKey }} -{{- if eq $.Interface.Type "server"}} +{{- if eq .Peer.Interface.Type "server"}} # -WGP- Peer type: client {{else}} # -WGP- Peer type: server @@ -16,38 +16,38 @@ # Core settings PrivateKey = {{ .Peer.KeyPair.PrivateKey }} -Address = {{ .Peer.AddressStr.GetValue }} +Address = {{ .Peer.Interface.AddressStr.GetValue }} # Misc. settings (optional) -{{- if .Peer.DnsStr.GetValue}} -DNS = {{ .Peer.DnsStr.GetValue }} +{{- if .Peer.Interface.DnsStr.GetValue}} +DNS = {{ .Peer.Interface.DnsStr.GetValue }} {{- end}} -{{- if ne .Peer.Mtu.GetValue 0}} -MTU = {{ .Peer.Mtu.GetValue }} +{{- if ne .Peer.Interface.Mtu.GetValue 0}} +MTU = {{ .Peer.Interface.Mtu.GetValue }} {{- end}} -{{- if ne .Peer.FirewallMark.GetValue 0}} -FwMark = {{ .Peer.FirewallMark.GetValue }} +{{- if ne .Peer.Interface.FirewallMark.GetValue 0}} +FwMark = {{ .Peer.Interface.FirewallMark.GetValue }} {{- end}} -{{- if ne .Peer.RoutingTable.GetValue ""}} -Table = {{ .Peer.RoutingTable.GetValue }} +{{- if ne .Peer.Interface.RoutingTable.GetValue ""}} +Table = {{ .Peer.Interface.RoutingTable.GetValue }} {{- end}} # Interface hooks (optional) -{{- if .Peer.PreUp.GetValue}} -PreUp = {{ .Peer.PreUp.GetValue }} +{{- if .Peer.Interface.PreUp.GetValue}} +PreUp = {{ .Peer.Interface.PreUp.GetValue }} {{- end}} -{{- if .Peer.PostUp.GetValue}} -PostUp = {{ .Peer.PostUp.GetValue }} +{{- if .Peer.Interface.PostUp.GetValue}} +PostUp = {{ .Peer.Interface.PostUp.GetValue }} {{- end}} -{{- if .Peer.PreDown.GetValue}} -PreDown = {{ .Peer.PreDown.GetValue }} +{{- if .Peer.Interface.PreDown.GetValue}} +PreDown = {{ .Peer.Interface.PreDown.GetValue }} {{- end}} -{{- if .Peer.PostDown.GetValue}} -PostDown = {{ .Peer.PostDown.GetValue }} +{{- if .Peer.Interface.PostDown.GetValue}} +PostDown = {{ .Peer.Interface.PostDown.GetValue }} {{- end}} [Peer] -PublicKey = {{ .Interface.KeyPair.PublicKey }} +PublicKey = {{ .Peer.Interface.PublicKey }} Endpoint = {{ .Peer.Endpoint.GetValue }} {{- if .Peer.AllowedIPsStr.GetValue}} AllowedIPs = {{ .Peer.AllowedIPsStr.GetValue }} diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 52cd7bf..1c95fda 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -25,9 +25,9 @@ store store // internal holder of interface configurations - interfaces map[persistence.InterfaceIdentifier]persistence.InterfaceConfig + interfaces map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig // internal holder of peer configurations - peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig + peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig } func newWgCtrlManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*wgCtrlManager, error) { @@ -36,8 +36,8 @@ wg: wg, nl: nl, store: store, - interfaces: make(map[persistence.InterfaceIdentifier]persistence.InterfaceConfig), - peers: make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig), + interfaces: make(map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig), + peers: make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig), } if err := m.initializeFromStore(); err != nil { @@ -47,11 +47,11 @@ return m, nil } -func (m *wgCtrlManager) GetInterfaces() ([]persistence.InterfaceConfig, error) { +func (m *wgCtrlManager) GetInterfaces() ([]*persistence.InterfaceConfig, error) { m.mux.RLock() defer m.mux.RUnlock() - interfaces := make([]persistence.InterfaceConfig, 0, len(m.interfaces)) - for _, iface := range interfaces { + interfaces := make([]*persistence.InterfaceConfig, 0, len(m.interfaces)) + for _, iface := range m.interfaces { interfaces = append(interfaces, iface) } // Order the interfaces by device name @@ -74,9 +74,9 @@ return errors.WithMessage(err, "failed to create low level interface") } - newInterface := persistence.InterfaceConfig{Identifier: id} + newInterface := &persistence.InterfaceConfig{Identifier: id} m.interfaces[id] = newInterface - m.peers[id] = make(map[persistence.PeerIdentifier]persistence.PeerConfig) + m.peers[id] = make(map[persistence.PeerIdentifier]*persistence.PeerConfig) err = m.persistInterface(id, false) if err != nil { @@ -122,9 +122,10 @@ return nil } -func (m *wgCtrlManager) UpdateInterface(id persistence.InterfaceIdentifier, cfg persistence.InterfaceConfig) error { +func (m *wgCtrlManager) UpdateInterface(id persistence.InterfaceIdentifier, cfg *persistence.InterfaceConfig) error { m.mux.Lock() defer m.mux.Unlock() + if !m.deviceExists(id) { return errors.New("interface does not exist") } @@ -192,14 +193,51 @@ return nil } -func (m *wgCtrlManager) GetPeers(interfaceId persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) { +func (m *wgCtrlManager) ApplyDefaultConfigs(id persistence.InterfaceIdentifier) error { + m.mux.Lock() + defer m.mux.Unlock() + + if !m.deviceExists(id) { + return errors.New("device does not exist") + } + + cfg := m.interfaces[id] + + for p := range m.peers[id] { + m.peers[id][p].Endpoint.TrySetValue(cfg.PeerDefEndpoint) + m.peers[id][p].AllowedIPsStr.TrySetValue(cfg.PeerDefAllowedIPsStr) + + m.peers[id][p].Interface.Identifier = cfg.Identifier + m.peers[id][p].Interface.Type = cfg.Type + m.peers[id][p].Interface.PublicKey = cfg.KeyPair.PublicKey + + m.peers[id][p].Interface.DnsStr.TrySetValue(cfg.PeerDefDnsStr) + m.peers[id][p].Interface.Mtu.TrySetValue(cfg.PeerDefMtu) + m.peers[id][p].Interface.FirewallMark.TrySetValue(cfg.PeerDefFirewallMark) + m.peers[id][p].Interface.RoutingTable.TrySetValue(cfg.PeerDefRoutingTable) + + m.peers[id][p].Interface.PreUp.TrySetValue(cfg.PeerDefPreUp) + m.peers[id][p].Interface.PostUp.TrySetValue(cfg.PeerDefPostUp) + m.peers[id][p].Interface.PreDown.TrySetValue(cfg.PeerDefPreDown) + m.peers[id][p].Interface.PostDown.TrySetValue(cfg.PeerDefPostDown) + + err := m.persistPeer(m.peers[id][p].Identifier, false) + if err != nil { + return errors.Wrapf(err, "failed to persist peer defaults to %s", m.peers[id][p].Identifier) + } + } + + return nil +} + +func (m *wgCtrlManager) GetPeers(interfaceId persistence.InterfaceIdentifier) ([]*persistence.PeerConfig, error) { m.mux.RLock() defer m.mux.RUnlock() if !m.deviceExists(interfaceId) { return nil, errors.New("device does not exist") } - peers := make([]persistence.PeerConfig, 0, len(m.peers[interfaceId])) + peers := make([]*persistence.PeerConfig, 0, len(m.peers[interfaceId])) for _, config := range m.peers[interfaceId] { peers = append(peers, config) } @@ -207,12 +245,12 @@ return peers, nil } -func (m *wgCtrlManager) SavePeers(peers ...persistence.PeerConfig) error { +func (m *wgCtrlManager) SavePeers(peers ...*persistence.PeerConfig) error { m.mux.Lock() defer m.mux.Unlock() for _, peer := range peers { - deviceId := peer.PeerInterfaceConfig.Identifier + deviceId := peer.Interface.Identifier if !m.deviceExists(deviceId) { return errors.Errorf("device does not exist") } @@ -248,7 +286,7 @@ } peer, _ := m.getPeer(id) - deviceId := peer.PeerInterfaceConfig.Identifier + deviceId := peer.Interface.Identifier publicKey, err := wgtypes.ParseKey(peer.KeyPair.PublicKey) if err != nil { @@ -275,7 +313,7 @@ return nil } -func (m *wgCtrlManager) GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) { +func (m *wgCtrlManager) GetImportableInterfaces() (map[*ImportableInterface][]*persistence.PeerConfig, error) { devices, err := m.wg.Devices() if err != nil { return nil, errors.WithMessage(err, "failed to get WireGuard device list") @@ -284,7 +322,7 @@ m.mux.RLock() defer m.mux.RUnlock() - interfaces := make(map[ImportableInterface][]persistence.PeerConfig, len(devices)) + interfaces := make(map[*ImportableInterface][]*persistence.PeerConfig, len(devices)) for d, device := range devices { if _, exists := m.interfaces[persistence.InterfaceIdentifier(device.Name)]; exists { continue // interface already managed, skip @@ -295,7 +333,7 @@ return nil, errors.WithMessagef(err, "failed to convert WireGuard interface %s", device.Name) } - interfaces[cfg] = make([]persistence.PeerConfig, len(device.Peers)) + interfaces[cfg] = make([]*persistence.PeerConfig, len(device.Peers)) for p, peer := range device.Peers { peerCfg, err := m.convertWireGuardPeer(&device.Peers[p], cfg) @@ -311,7 +349,7 @@ return interfaces, nil } -func (m *wgCtrlManager) ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) error { +func (m *wgCtrlManager) ImportInterface(cfg *ImportableInterface, peers []*persistence.PeerConfig) error { m.mux.Lock() defer m.mux.Unlock() @@ -339,13 +377,15 @@ return errors.WithMessage(err, "failed to get all interfaces") } - for cfg, peers := range interfaces { - m.interfaces[cfg.Identifier] = cfg + for tmpCfg, tmpPeers := range interfaces { + cfg := tmpCfg + peers := tmpPeers + m.interfaces[cfg.Identifier] = &cfg if _, ok := m.peers[cfg.Identifier]; !ok { - m.peers[cfg.Identifier] = make(map[persistence.PeerIdentifier]persistence.PeerConfig) + m.peers[cfg.Identifier] = make(map[persistence.PeerIdentifier]*persistence.PeerConfig) } for _, peer := range peers { - m.peers[cfg.Identifier][peer.Identifier] = peer + m.peers[cfg.Identifier][peer.Identifier] = &peer } } @@ -387,7 +427,7 @@ if delete { err = m.store.DeleteInterface(id) } else { - err = m.store.SaveInterface(m.interfaces[id]) + err = m.store.SaveInterface(*m.interfaces[id]) } if err != nil { return errors.Wrapf(err, "failed to persist interface") @@ -411,7 +451,7 @@ return nil // nothing to do } - var peer persistence.PeerConfig + var peer *persistence.PeerConfig for _, peers := range m.peers { if p, ok := peers[id]; ok { peer = p @@ -421,9 +461,9 @@ var err error if delete { - err = m.store.DeletePeer(id, peer.PeerInterfaceConfig.Identifier) + err = m.store.DeletePeer(id, peer.Interface.Identifier) } else { - err = m.store.SavePeer(peer, peer.PeerInterfaceConfig.Identifier) + err = m.store.SavePeer(*peer, peer.Interface.Identifier) } if err != nil { return errors.Wrapf(err, "failed to persist peer %s", id) @@ -432,18 +472,18 @@ return nil } -func (m *wgCtrlManager) getPeer(id persistence.PeerIdentifier) (persistence.PeerConfig, error) { +func (m *wgCtrlManager) getPeer(id persistence.PeerIdentifier) (*persistence.PeerConfig, error) { for _, peers := range m.peers { if _, ok := peers[id]; ok { return peers[id], nil } } - return persistence.PeerConfig{}, errors.New("peer not found") + return nil, errors.New("peer not found") } -func (m *wgCtrlManager) convertWireGuardInterface(device *wgtypes.Device) (ImportableInterface, error) { - cfg := ImportableInterface{} +func (m *wgCtrlManager) convertWireGuardInterface(device *wgtypes.Device) (*ImportableInterface, error) { + cfg := &ImportableInterface{} cfg.Identifier = persistence.InterfaceIdentifier(device.Name) cfg.FirewallMark = int32(device.FirewallMark) @@ -456,20 +496,20 @@ lowLevelInterface, err := m.nl.LinkByName(device.Name) if err != nil { - return ImportableInterface{}, errors.WithMessagef(err, "failed to get low level interface for %s", device.Name) + return nil, errors.WithMessagef(err, "failed to get low level interface for %s", device.Name) } cfg.Mtu = lowLevelInterface.Attrs().MTU ipAddresses, err := m.nl.AddrList(lowLevelInterface) if err != nil { - return ImportableInterface{}, errors.WithMessagef(err, "failed to get low level addresses for %s", device.Name) + return nil, errors.WithMessagef(err, "failed to get low level addresses for %s", device.Name) } cfg.AddressStr = ipAddressesToString(ipAddresses) return cfg, nil } -func (m *wgCtrlManager) convertWireGuardPeer(peer *wgtypes.Peer, dev ImportableInterface) (persistence.PeerConfig, error) { - peerCfg := persistence.PeerConfig{} +func (m *wgCtrlManager) convertWireGuardPeer(peer *wgtypes.Peer, dev *ImportableInterface) (*persistence.PeerConfig, error) { + peerCfg := &persistence.PeerConfig{} peerCfg.Identifier = persistence.PeerIdentifier(peer.PublicKey.String()) peerCfg.KeyPair = persistence.KeyPair{ PublicKey: peer.PublicKey.String(), @@ -488,7 +528,7 @@ peerCfg.AllowedIPsStr = persistence.NewStringConfigOption(strings.Join(allowedIPs, ","), true) peerCfg.PersistentKeepalive = persistence.NewIntConfigOption(int(peer.PersistentKeepaliveInterval.Seconds()), true) - peerCfg.PeerInterfaceConfig = persistence.PeerInterfaceConfig{ + peerCfg.Interface = &persistence.PeerInterfaceConfig{ Identifier: dev.Identifier, AddressStr: persistence.NewStringConfigOption(dev.AddressStr, true), // todo: correct? DnsStr: persistence.NewStringConfigOption(dev.DnsStr, true), @@ -498,34 +538,7 @@ return peerCfg, nil } -func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) { - rawAddresses := strings.Split(addrStr, ",") - addresses := make([]*netlink.Addr, 0, len(rawAddresses)) - for i := range rawAddresses { - rawAddress := strings.TrimSpace(rawAddresses[i]) - if rawAddress == "" { - continue // skip empty string - } - address, err := netlink.ParseAddr(rawAddress) - if err != nil { - return nil, errors.Wrapf(err, "failed to parse IP address %s", rawAddress) - } - addresses = append(addresses, address) - } - - return addresses, nil -} - -func ipAddressesToString(addresses []netlink.Addr) string { - addressesStr := make([]string, len(addresses)) - for i := range addresses { - addressesStr[i] = addresses[i].String() - } - - return strings.Join(addressesStr, ",") -} - -func getWireGuardPeerConfig(devType persistence.InterfaceType, cfg persistence.PeerConfig) (wgtypes.PeerConfig, error) { +func getWireGuardPeerConfig(devType persistence.InterfaceType, cfg *persistence.PeerConfig) (wgtypes.PeerConfig, error) { publicKey, err := wgtypes.ParseKey(cfg.KeyPair.PublicKey) if err != nil { return wgtypes.PeerConfig{}, errors.WithMessage(err, "invalid public key for peer") diff --git a/internal/wireguard/wireguard_ip.go b/internal/wireguard/wireguard_ip.go new file mode 100644 index 0000000..95da701 --- /dev/null +++ b/internal/wireguard/wireguard_ip.go @@ -0,0 +1,215 @@ +package wireguard + +import ( + "bytes" + "net" + "sort" + "strings" + + "github.com/h44z/wg-portal/internal/persistence" + "github.com/pkg/errors" + "github.com/vishvananda/netlink" +) + +func (m *wgCtrlManager) GetAllUsedIPs(id persistence.InterfaceIdentifier) ([]*netlink.Addr, error) { + m.mux.RLock() + defer m.mux.RUnlock() + + if !m.deviceExists(id) { + return nil, errors.New("interface does not exist") + } + + var usedAddresses []*netlink.Addr + for _, peer := range m.peers[id] { + addresses, err := parseIpAddressString(peer.Interface.AddressStr.GetValue()) + if err != nil { + return nil, errors.WithMessagef(err, "unable to parse addresses of peer %s", peer.Identifier) + } + + usedAddresses = append(usedAddresses, addresses...) + } + + sort.Slice(usedAddresses, func(i, j int) bool { + return bytes.Compare(usedAddresses[i].IP, usedAddresses[j].IP) < 0 + }) + + return usedAddresses, nil +} + +func (m *wgCtrlManager) GetUsedIPs(id persistence.InterfaceIdentifier, subnetCidr string) ([]*netlink.Addr, error) { + m.mux.RLock() + defer m.mux.RUnlock() + + if !m.deviceExists(id) { + return nil, errors.New("interface does not exist") + } + + subnet, err := parseCIDR(subnetCidr) + if err != nil { + return nil, errors.WithMessagef(err, "unable to parse subnet addresses") + } + + var usedAddresses []*netlink.Addr + for _, peer := range m.peers[id] { + addresses, err := parseIpAddressString(peer.Interface.AddressStr.GetValue()) + if err != nil { + return nil, errors.WithMessagef(err, "unable to parse addresses of peer %s", peer.Identifier) + } + + for _, address := range addresses { + if subnet.Contains(address.IP) { + usedAddresses = append(usedAddresses, address) + } + } + } + + sort.Slice(usedAddresses, func(i, j int) bool { + return bytes.Compare(usedAddresses[i].IP, usedAddresses[j].IP) < 0 + }) + + return usedAddresses, nil +} + +func (m *wgCtrlManager) GetFreshIp(id persistence.InterfaceIdentifier, subnetCidr string, increment ...bool) (*netlink.Addr, error) { + m.mux.RLock() + defer m.mux.RUnlock() + + if !m.deviceExists(id) { + return nil, errors.New("interface does not exist") + } + + subnet, err := parseCIDR(subnetCidr) + if err != nil { + return nil, errors.WithMessagef(err, "unable to parse subnet addresses") + } + isV4 := isV4(subnet) + + usedIPs, err := m.GetUsedIPs(id, subnetCidr) // highest IP is at the end of the array + if err != nil { + return nil, errors.WithMessagef(err, "unable to load used IP addresses") + } + + // these two addresses are not usable + broadcastAddr := broadcastAddr(subnet) + networkAddr := subnet.IP + // start with the lowest IP and check all others + ip := &netlink.Addr{ + IPNet: &net.IPNet{IP: subnet.IP.Mask(subnet.Mask).To16(), Mask: subnet.Mask}, + } + if len(increment) != 0 && increment[0] == true && len(usedIPs) > 0 { + // start with the maximum used IP and check all above + ip = &netlink.Addr{ + IPNet: &net.IPNet{IP: make([]byte, 16), Mask: subnet.Mask}, + } + copy(ip.IP, usedIPs[len(usedIPs)-1].IP) + } + + for ; subnet.Contains(ip.IP); increaseIP(ip) { + if bytes.Compare(ip.IP, networkAddr) == 0 { + continue + } + if isV4 && bytes.Compare(ip.IP, broadcastAddr.IP) == 0 { + continue + } + + ok := true + for _, r := range usedIPs { + if bytes.Compare(ip.IP, r.IP) == 0 { + ok = false + break + } + } + + if ok { + return ip, nil + } + } + + return nil, errors.New("ip range exceeded") +} + +// http://play.golang.org/p/m8TNTtygK0 +func increaseIP(ip *netlink.Addr) { + for j := len(ip.IP) - 1; j >= 0; j-- { + ip.IP[j]++ + if ip.IP[j] > 0 { + break + } + } +} + +// BroadcastAddr returns the last address in the given network (for IPv6), or the broadcast address. +func broadcastAddr(n *netlink.Addr) *netlink.Addr { + // The golang net package doesn't make it easy to calculate the broadcast address. :( + var broadcast = net.IPv6zero + var mask = []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} // ensure that mask also has 16 bytes (also for IPv4) + if len(n.Mask) == 4 { + for i := 0; i < 4; i++ { + mask[12+i] = n.Mask[i] + } + } else { + for i := 0; i < 16; i++ { + mask[i] = n.Mask[i] + } + } + for i := 0; i < len(n.IP); i++ { + broadcast[i] = n.IP[i] | ^mask[i] + } + return &netlink.Addr{ + IPNet: &net.IPNet{IP: broadcast, Mask: n.Mask}, + } +} + +func isV4(n *netlink.Addr) bool { + if n.IP.To4() != nil { + return true + } + + return false +} + +func parseCIDR(cidr string) (*netlink.Addr, error) { + addr, err := netlink.ParseAddr(cidr) + if err != nil { + return nil, errors.WithMessagef(err, "failed to parse cidr") + } + + // Use the 16byte representation for all IP families. + if len(addr.IP) != 16 { + addr.IP = addr.IP.To16() + } + + return addr, nil +} + +func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) { + rawAddresses := strings.Split(addrStr, ",") + addresses := make([]*netlink.Addr, 0, len(rawAddresses)) + for i := range rawAddresses { + rawAddress := strings.TrimSpace(rawAddresses[i]) + if rawAddress == "" { + continue // skip empty string + } + address, err := parseCIDR(rawAddress) + if err != nil { + return nil, errors.Wrapf(err, "failed to parse IP address %s", rawAddress) + } + + addresses = append(addresses, address) + } + + sort.Slice(addresses, func(i, j int) bool { + return bytes.Compare(addresses[i].IP, addresses[j].IP) < 0 + }) + + return addresses, nil +} + +func ipAddressesToString(addresses []netlink.Addr) string { + addressesStr := make([]string, len(addresses)) + for i := range addresses { + addressesStr[i] = addresses[i].String() + } + + return strings.Join(addressesStr, ",") +} diff --git a/internal/wireguard/wireguard_ip_test.go b/internal/wireguard/wireguard_ip_test.go new file mode 100644 index 0000000..1a10580 --- /dev/null +++ b/internal/wireguard/wireguard_ip_test.go @@ -0,0 +1,719 @@ +package wireguard + +import ( + "net" + "reflect" + "testing" + + "github.com/h44z/wg-portal/internal/persistence" + "github.com/vishvananda/netlink" +) + +func ignoreNetlinkError(addr *netlink.Addr, _ error) *netlink.Addr { + return addr +} + +func Test_broadcastAddr(t *testing.T) { + tests := []struct { + name string + arg *netlink.Addr + want *netlink.Addr + }{ + { + name: "V4_0", + arg: ignoreNetlinkError(parseCIDR("10.0.0.0/24")), + want: ignoreNetlinkError(parseCIDR("10.0.0.255/24")), + }, + { + name: "V4_1", + arg: ignoreNetlinkError(parseCIDR("10.0.0.1/24")), + want: ignoreNetlinkError(parseCIDR("10.0.0.255/24")), + }, + { + name: "V4_2", + arg: ignoreNetlinkError(parseCIDR("10.0.0.255/24")), + want: ignoreNetlinkError(parseCIDR("10.0.0.255/24")), + }, + { + name: "V6_0", + arg: ignoreNetlinkError(parseCIDR("fe80::/64")), + want: ignoreNetlinkError(parseCIDR("fe80::ffff:ffff:ffff:ffff/64")), + }, + { + name: "V6_1", + arg: ignoreNetlinkError(parseCIDR("fe80::1:2:3/64")), + want: ignoreNetlinkError(parseCIDR("fe80::ffff:ffff:ffff:ffff/64")), + }, + { + name: "V6_2", + arg: ignoreNetlinkError(parseCIDR("fe80::ffff:ffff:ffff:ffff/64")), + want: ignoreNetlinkError(parseCIDR("fe80::ffff:ffff:ffff:ffff/64")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := broadcastAddr(tt.arg); got.String() != tt.want.String() { + t.Errorf("broadcastAddr() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_increaseIP(t *testing.T) { + tests := []struct { + name string + ip *netlink.Addr + want *netlink.Addr + }{ + { + name: "V4_1", + ip: ignoreNetlinkError(parseCIDR("10.0.0.0/24")), + want: ignoreNetlinkError(parseCIDR("10.0.0.1/24")), + }, + { + name: "V4_2", + ip: ignoreNetlinkError(parseCIDR("10.0.0.2/24")), + want: ignoreNetlinkError(parseCIDR("10.0.0.3/24")), + }, + { + name: "V4_3", + ip: ignoreNetlinkError(parseCIDR("10.0.0.254/24")), + want: ignoreNetlinkError(parseCIDR("10.0.0.255/24")), + }, + { + name: "V4_4", + ip: ignoreNetlinkError(parseCIDR("10.0.0.255/24")), + want: ignoreNetlinkError(parseCIDR("10.0.1.0/24")), + }, + { + name: "V4_5", + ip: ignoreNetlinkError(parseCIDR("10.0.0.5/32")), + want: ignoreNetlinkError(parseCIDR("10.0.0.6/32")), + }, + { + name: "V6_1", + ip: ignoreNetlinkError(parseCIDR("2001:db8::/64")), + want: ignoreNetlinkError(parseCIDR("2001:db8::1/64")), + }, + { + name: "V6_2", + ip: ignoreNetlinkError(parseCIDR("2001:db8::5/64")), + want: ignoreNetlinkError(parseCIDR("2001:db8::6/64")), + }, + { + name: "V6_3", + ip: ignoreNetlinkError(parseCIDR("2001:0db8:0000:0000:ffff:ffff:ffff:fffe/64")), + want: ignoreNetlinkError(parseCIDR("2001:0db8:0000:0000:ffff:ffff:ffff:ffff/64")), + }, + { + name: "V6_4", + ip: ignoreNetlinkError(parseCIDR("2001:0db8:0:0:ffff:ffff:ffff:ffff/64")), + want: ignoreNetlinkError(parseCIDR("2001:db8:0:1::/64")), + }, + { + name: "V6_5", + ip: ignoreNetlinkError(parseCIDR("2001:0db8:0:0:ffff:ffff:ffff:ffff/128")), + want: ignoreNetlinkError(parseCIDR("2001:0db8:0:1::/128")), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + increaseIP(tt.ip) + if !reflect.DeepEqual(tt.ip, tt.want) { + t.Errorf("increaseIP() got = %v, want %v", tt.ip, tt.want) + } + }) + } +} + +func Test_isV4(t *testing.T) { + tests := []struct { + name string + arg *netlink.Addr + want bool + }{ + { + name: "V4", + arg: ignoreNetlinkError(parseCIDR("10.0.0.1/24")), + want: true, + }, + { + name: "V4 network", + arg: ignoreNetlinkError(parseCIDR("10.0.0.0/24")), + want: true, + }, + { + name: "V6", + arg: ignoreNetlinkError(parseCIDR("fe80::/64")), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isV4(tt.arg); got != tt.want { + t.Errorf("isV4() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_wgCtrlManager_GetAllUsedIPs(t *testing.T) { + type args struct { + id persistence.InterfaceIdentifier + } + tests := []struct { + name string + mgr *wgCtrlManager + args args + want []*netlink.Addr + wantErr bool + }{ + { + name: "No Such Interface", + mgr: &wgCtrlManager{peers: make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig)}, + args: args{id: "wg0"}, + want: nil, + wantErr: true, + }, + { + name: "No Peers", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{"wg0": nil}}, + args: args{id: "wg0"}, + want: nil, + wantErr: false, + }, + { + name: "Wrong IP addresses", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("invalid", true)}}, + }, + }, + }, + args: args{id: "wg0"}, + want: nil, + wantErr: true, + }, + { + name: "Single IP addresses", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/24", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.3/24", true)}}, + }, + }, + }, + args: args{id: "wg0"}, + want: []*netlink.Addr{ + ignoreNetlinkError(parseCIDR("10.0.0.2/24")), + ignoreNetlinkError(parseCIDR("10.0.0.3/24")), + }, + wantErr: false, + }, + { + name: "Multiple IP addresses", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/24,684D:1111:222:3333:4444:5555:6:77/64", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("1.1.1.1/30,10.0.0.3/24,8.8.8.8/32", true)}}, + }, + }, + }, + args: args{id: "wg0"}, + want: []*netlink.Addr{ + ignoreNetlinkError(parseCIDR("1.1.1.1/30")), + ignoreNetlinkError(parseCIDR("8.8.8.8/32")), + ignoreNetlinkError(parseCIDR("10.0.0.2/24")), + ignoreNetlinkError(parseCIDR("10.0.0.3/24")), + ignoreNetlinkError(parseCIDR("684D:1111:222:3333:4444:5555:6:77/64")), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.mgr.GetAllUsedIPs(tt.args.id) + if (err != nil) != tt.wantErr { + t.Errorf("GetAllUsedIPs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetAllUsedIPs() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_wgCtrlManager_GetUsedIPs(t *testing.T) { + type args struct { + id persistence.InterfaceIdentifier + subnetCidr string + } + tests := []struct { + name string + mgr *wgCtrlManager + args args + want []*netlink.Addr + wantErr bool + }{ + { + name: "No Such Interface", + mgr: &wgCtrlManager{peers: make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig)}, + args: args{id: "wg0", subnetCidr: "10.0.0.0/24"}, + want: nil, + wantErr: true, + }, + { + name: "No Peers", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{"wg0": nil}}, + args: args{id: "wg0", subnetCidr: "10.0.0.0/24"}, + want: nil, + wantErr: false, + }, + { + name: "Wrong subnet", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{"wg0": nil}}, + args: args{id: "wg0", subnetCidr: "subnet"}, + want: nil, + wantErr: true, + }, + { + name: "Wrong IP addresses", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("invalid", true)}}, + }, + }, + }, + args: args{id: "wg0", subnetCidr: "10.0.0.0/24"}, + want: nil, + wantErr: true, + }, + { + name: "Single IP addresses V4", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/24", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.3/24", true)}}, + }, + }, + }, + args: args{id: "wg0", subnetCidr: "10.0.0.0/24"}, + want: []*netlink.Addr{ + ignoreNetlinkError(parseCIDR("10.0.0.2/24")), + ignoreNetlinkError(parseCIDR("10.0.0.3/24")), + }, + wantErr: false, + }, + { + name: "Single IP addresses V6", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::5/64", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::6/64", true)}}, + }, + }, + }, + args: args{id: "wg0", subnetCidr: "2001:db8::/64"}, + want: []*netlink.Addr{ + ignoreNetlinkError(parseCIDR("2001:db8::5/64")), + ignoreNetlinkError(parseCIDR("2001:db8::6/64")), + }, + wantErr: false, + }, + { + name: "Multiple IP addresses V4", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/24,684D:1111:222:3333:4444:5555:6:77/64", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("1.1.1.1/30,10.0.0.3/24,8.8.8.8/32", true)}}, + }, + }, + }, + args: args{id: "wg0", subnetCidr: "10.0.0.0/24"}, + want: []*netlink.Addr{ + ignoreNetlinkError(parseCIDR("10.0.0.2/24")), + ignoreNetlinkError(parseCIDR("10.0.0.3/24")), + }, + wantErr: false, + }, + { + name: "Multiple IP addresses V6", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/24,2001:db8::5/64", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::6/64", true)}}, + "peer2": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db9::6/64,2001:db8:0:0:100::6/64", true)}}, + }, + }, + }, + args: args{id: "wg0", subnetCidr: "2001:db8::/64"}, + want: []*netlink.Addr{ + ignoreNetlinkError(parseCIDR("2001:db8::5/64")), + ignoreNetlinkError(parseCIDR("2001:db8::6/64")), + ignoreNetlinkError(parseCIDR("2001:db8::100:0:0:6/64")), + }, + wantErr: false, + }, + { + name: "Sort Order", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.3/16,10.0.0.2/16,10.0.5.2/16", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.1/16,10.0.4.2/16,10.0.6.2/16,10.0.5.3/16", true)}}, + }, + }, + }, + args: args{id: "wg0", subnetCidr: "10.0.0.0/16"}, + want: []*netlink.Addr{ + ignoreNetlinkError(parseCIDR("10.0.0.1/16")), + ignoreNetlinkError(parseCIDR("10.0.0.2/16")), + ignoreNetlinkError(parseCIDR("10.0.0.3/16")), + ignoreNetlinkError(parseCIDR("10.0.4.2/16")), + ignoreNetlinkError(parseCIDR("10.0.5.2/16")), + ignoreNetlinkError(parseCIDR("10.0.5.3/16")), + ignoreNetlinkError(parseCIDR("10.0.6.2/16")), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.mgr.GetUsedIPs(tt.args.id, tt.args.subnetCidr) + if (err != nil) != tt.wantErr { + t.Errorf("GetUsedIPs() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetUsedIPs() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_wgCtrlManager_GetFreshIp(t *testing.T) { + type args struct { + id persistence.InterfaceIdentifier + subnetCidr string + increment []bool + } + tests := []struct { + name string + mgr *wgCtrlManager + args args + want *netlink.Addr + wantErr bool + }{ + { + name: "V4_1_noincrement", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/24", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.3/24", true)}}, + }, + }, + }, + args: args{ + id: "wg0", + subnetCidr: "10.0.0.0/24", + }, + want: ignoreNetlinkError(parseCIDR("10.0.0.1/24")), + wantErr: false, + }, + { + name: "V4_1_increment", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/24", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.3/24", true)}}, + }, + }, + }, + args: args{ + id: "wg0", + subnetCidr: "10.0.0.0/24", + increment: []bool{true}, + }, + want: ignoreNetlinkError(parseCIDR("10.0.0.4/24")), + wantErr: false, + }, + { + name: "V4_1_overflow", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("10.0.0.2/32", true)}}, + }, + }, + }, + args: args{ + id: "wg0", + subnetCidr: "10.0.0.2/32", + increment: []bool{true}, + }, + want: nil, + wantErr: true, + }, + { + name: "V6_1_noincrement", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::5/64", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::6/64", true)}}, + }, + }, + }, + args: args{ + id: "wg0", + subnetCidr: "2001:db8::/64", + }, + want: ignoreNetlinkError(parseCIDR("2001:db8::1/64")), + wantErr: false, + }, + { + name: "V6_1_increment", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::5/64", true)}}, + "peer1": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::6/64", true)}}, + }, + }, + }, + args: args{ + id: "wg0", + subnetCidr: "2001:db8::/64", + increment: []bool{true}, + }, + want: ignoreNetlinkError(parseCIDR("2001:db8::7/64")), + wantErr: false, + }, + { + name: "V6_1_overflow", + mgr: &wgCtrlManager{ + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Interface: &persistence.PeerInterfaceConfig{AddressStr: persistence.NewStringConfigOption("2001:db8::ffff/128", true)}}, + }, + }, + }, + args: args{ + id: "wg0", + subnetCidr: "2001:db8::/128", + increment: []bool{true}, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.mgr.GetFreshIp(tt.args.id, tt.args.subnetCidr, tt.args.increment...) + if (err != nil) != tt.wantErr { + t.Errorf("GetFreshIp() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetFreshIp() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseCIDR(t *testing.T) { + tests := []struct { + name string + cidr string + want *netlink.Addr + wantErr bool + }{ + { + name: "Valid V4", + cidr: "10.0.0.1/24", + want: &netlink.Addr{IPNet: &net.IPNet{ + IP: net.IPv4(10, 0, 0, 1), + Mask: net.IPv4Mask(255, 255, 255, 0)}, + }, + wantErr: false, + }, + { + name: "Inalid V4", + cidr: "10.0.0.1/64", + want: nil, + wantErr: true, + }, + { + name: "Valid V6", + cidr: "fe80::/128", + want: &netlink.Addr{IPNet: &net.IPNet{ + IP: []byte{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + Mask: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + }, + wantErr: false, + }, + { + name: "Inalid V6", + cidr: "10:0:0::/256", + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseCIDR(tt.cidr) + if (err != nil) != tt.wantErr { + t.Errorf("parseCIDR() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseCIDR() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_parseIpAddressString(t *testing.T) { + type args struct { + addrStr string + } + var tests = []struct { + name string + args args + want []*netlink.Addr + wantErr bool + }{ + { + name: "Empty String", + args: args{}, + want: []*netlink.Addr{}, + wantErr: false, + }, + { + name: "Single IPv4", + args: args{addrStr: "123.123.123.123"}, + want: nil, + wantErr: true, + }, + { + name: "Malformed", + args: args{addrStr: "hello world"}, + want: nil, + wantErr: true, + }, + { + name: "Single IPv4 CIDR", + args: args{addrStr: "123.123.123.123/24"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(123, 123, 123, 123), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }}, + wantErr: false, + }, + { + name: "Multiple IPv4 CIDR", + args: args{addrStr: "123.123.123.123/24, 200.201.202.203/16"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(123, 123, 123, 123), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + }, { + IPNet: &net.IPNet{ + IP: net.IPv4(200, 201, 202, 203), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }}, + wantErr: false, + }, + { + name: "Single IPv6 CIDR", + args: args{addrStr: "fe80::1/64"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }}, + wantErr: false, + }, + { + name: "Multiple IPv6 CIDR", + args: args{addrStr: "fe80::1/64 , 2130:d3ad::b33f/128"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + }, { + IPNet: &net.IPNet{ + IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + }}, + wantErr: false, + }, + { + name: "Mixed IPv4 and IPv6 CIDR", + args: args{addrStr: "200.201.202.203/16,2130:d3ad::b33f/128"}, + want: []*netlink.Addr{{ + IPNet: &net.IPNet{ + IP: net.IPv4(200, 201, 202, 203), + Mask: net.IPv4Mask(255, 255, 0, 0), + }, + }, { + IPNet: &net.IPNet{ + IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, + Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + }}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseIpAddressString(tt.args.addrStr) + if (err != nil) != tt.wantErr { + t.Errorf("parseIpAddressString() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("parseIpAddressString() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/wireguard/wireguard_test.go b/internal/wireguard/wireguard_test.go index f145496..f5994e7 100644 --- a/internal/wireguard/wireguard_test.go +++ b/internal/wireguard/wireguard_test.go @@ -4,7 +4,6 @@ package wireguard import ( - "net" "reflect" "sync" "testing" @@ -12,7 +11,6 @@ "github.com/h44z/wg-portal/internal/persistence" "github.com/pkg/errors" "github.com/stretchr/testify/mock" - "github.com/vishvananda/netlink" ) func TestWgCtrlManager_CreateInterface(t *testing.T) { @@ -30,7 +28,7 @@ wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, store: &MockWireGuardStore{}, - interfaces: map[persistence.InterfaceIdentifier]persistence.InterfaceConfig{"wg0": {}}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, peers: nil, }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, @@ -77,8 +75,8 @@ wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, store: &MockWireGuardStore{}, - interfaces: make(map[persistence.InterfaceIdentifier]persistence.InterfaceConfig), - peers: nil, + interfaces: make(map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig), + peers: make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig), }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { nl.On("LinkAdd", mock.Anything).Return(nil) @@ -95,8 +93,8 @@ wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, store: &MockWireGuardStore{}, - interfaces: make(map[persistence.InterfaceIdentifier]persistence.InterfaceConfig), - peers: nil, + interfaces: make(map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig), + peers: make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig), }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { nl.On("LinkAdd", mock.Anything).Return(nil) @@ -186,7 +184,7 @@ func TestWgCtrlManager_UpdateInterface(t *testing.T) { type args struct { id persistence.InterfaceIdentifier - cfg persistence.InterfaceConfig + cfg *persistence.InterfaceConfig } tests := []struct { name string @@ -213,116 +211,3 @@ }) } } - -func Test_parseIpAddressString(t *testing.T) { - type args struct { - addrStr string - } - var tests = []struct { - name string - args args - want []*netlink.Addr - wantErr bool - }{ - { - name: "Empty String", - args: args{}, - want: []*netlink.Addr{}, - wantErr: false, - }, - { - name: "Single IPv4", - args: args{addrStr: "123.123.123.123"}, - want: nil, - wantErr: true, - }, - { - name: "Malformed", - args: args{addrStr: "hello world"}, - want: nil, - wantErr: true, - }, - { - name: "Single IPv4 CIDR", - args: args{addrStr: "123.123.123.123/24"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IPv4(123, 123, 123, 123), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - }}, - wantErr: false, - }, - { - name: "Multiple IPv4 CIDR", - args: args{addrStr: "123.123.123.123/24, 200.201.202.203/16"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IPv4(123, 123, 123, 123), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - }, { - IPNet: &net.IPNet{ - IP: net.IPv4(200, 201, 202, 203), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - }}, - wantErr: false, - }, - { - name: "Single IPv6 CIDR", - args: args{addrStr: "fe80::1/64"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, - }, - }}, - wantErr: false, - }, - { - name: "Multiple IPv6 CIDR", - args: args{addrStr: "fe80::1/64 , 2130:d3ad::b33f/128"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, - }, - }, { - IPNet: &net.IPNet{ - IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, - }, - }}, - wantErr: false, - }, - { - name: "Mixed IPv4 and IPv6 CIDR", - args: args{addrStr: "200.201.202.203/16,2130:d3ad::b33f/128"}, - want: []*netlink.Addr{{ - IPNet: &net.IPNet{ - IP: net.IPv4(200, 201, 202, 203), - Mask: net.IPv4Mask(255, 255, 0, 0), - }, - }, { - IPNet: &net.IPNet{ - IP: net.IP{0x21, 0x30, 0xd3, 0xad, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xb3, 0x3f}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, - }, - }}, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseIpAddressString(tt.args.addrStr) - if (err != nil) != tt.wantErr { - t.Errorf("parseIpAddressString() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("parseIpAddressString() got = %v, want %v", got, tt.want) - } - }) - } -}