diff --git a/internal/wireguard/test_helpers_test.go b/internal/wireguard/test_helpers_test.go index dbb3d4b..296addb 100644 --- a/internal/wireguard/test_helpers_test.go +++ b/internal/wireguard/test_helpers_test.go @@ -55,7 +55,11 @@ func (m *MockNetlinkClient) LinkByName(name string) (netlink.Link, error) { args := m.Called(name) - return args.Get(0).(netlink.Link), args.Error(1) + if args.Get(0) != nil { + return args.Get(0).(netlink.Link), args.Error(1) + } else { + return nil, args.Error(1) + } } func (m *MockNetlinkClient) LinkSetUp(link netlink.Link) error { diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 1c95fda..e8cb338 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -140,6 +140,9 @@ return errors.WithMessage(err, "failed to set MTU") } addresses, err := parseIpAddressString(cfg.AddressStr) + if err != nil { + return errors.WithMessage(err, "failed to parse ip address") + } for i := 0; i < len(addresses); i++ { var err error if i == 0 { @@ -558,8 +561,8 @@ } var keepAlive *time.Duration - if cfg.PersistentKeepalive.Value != 0 { - keepAliveDuration := time.Duration(cfg.PersistentKeepalive.Value.(int)) * time.Second + if cfg.PersistentKeepalive.GetValue() != 0 { + keepAliveDuration := time.Duration(cfg.PersistentKeepalive.GetValue()) * time.Second keepAlive = &keepAliveDuration } diff --git a/internal/wireguard/wireguard_test.go b/internal/wireguard/wireguard_test.go index f5994e7..ba141ae 100644 --- a/internal/wireguard/wireguard_test.go +++ b/internal/wireguard/wireguard_test.go @@ -8,11 +8,63 @@ "sync" "testing" + "github.com/stretchr/testify/assert" + + "github.com/vishvananda/netlink" + "github.com/h44z/wg-portal/internal/persistence" "github.com/pkg/errors" "github.com/stretchr/testify/mock" ) +func TestWgCtrlManager_GetInterfaces(t *testing.T) { + tests := []struct { + name string + manager *wgCtrlManager + want []*persistence.InterfaceConfig + wantErr bool + }{ + { + name: "NoInterface", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{}, + }, + want: []*persistence.InterfaceConfig{}, + wantErr: false, + }, + { + name: "Normal", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{ + "wg0": {Identifier: "wg0"}, + "wg2": {Identifier: "wg2"}, + "wg1": {Identifier: "wg1"}, + }, + }, + want: []*persistence.InterfaceConfig{ + {Identifier: "wg0"}, + {Identifier: "wg1"}, + {Identifier: "wg2"}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.manager.GetInterfaces() + if (err != nil) != tt.wantErr { + t.Errorf("GetInterfaces() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetInterfaces() got = %v, want %v", got, tt.want) + } + }) + } +} + func TestWgCtrlManager_CreateInterface(t *testing.T) { tests := []struct { name string @@ -81,7 +133,7 @@ mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { nl.On("LinkAdd", mock.Anything).Return(nil) nl.On("LinkSetUp", mock.Anything).Return(nil) - st.On("SaveInterface", mock.Anything, mock.Anything).Return(errors.New("failure")) + st.On("SaveInterface", mock.Anything).Return(errors.New("failure")) }, args: persistence.InterfaceIdentifier("wg0"), wantErr: true, @@ -99,7 +151,7 @@ mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { nl.On("LinkAdd", mock.Anything).Return(nil) nl.On("LinkSetUp", mock.Anything).Return(nil) - st.On("SaveInterface", mock.Anything, mock.Anything).Return(nil) + st.On("SaveInterface", mock.Anything).Return(nil) }, args: persistence.InterfaceIdentifier("wg0"), wantErr: false, @@ -130,7 +182,93 @@ args persistence.InterfaceIdentifier wantErr bool }{ - // TODO: Add test cases. + { + name: "NonExisting", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, + args: "wg0", + wantErr: true, + }, + { + name: "LowLevelFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkDel", mock.Anything).Return(errors.New("failure")) + }, + args: "wg0", + wantErr: true, + }, + { + name: "PersistenceFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkDel", mock.Anything).Return(nil) + st.On("DeleteInterface", mock.Anything).Return(errors.New("failure")) + }, + args: "wg0", + wantErr: true, + }, + { + name: "PeerPersistenceFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": {"peer0": {Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkDel", mock.Anything).Return(nil) + st.On("DeleteInterface", persistence.InterfaceIdentifier("wg0")).Return(nil) + st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + }, + args: "wg0", + wantErr: true, + }, + { + name: "Success", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": {"peer0": {Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkDel", mock.Anything).Return(nil) + st.On("DeleteInterface", persistence.InterfaceIdentifier("wg0")).Return(nil) + st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(nil) + }, + args: "wg0", + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -149,38 +287,6 @@ } } -func TestWgCtrlManager_GetInterfaces(t *testing.T) { - tests := []struct { - name string - manager *wgCtrlManager - mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) - want []persistence.InterfaceConfig - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.mockSetup( - tt.manager.wg.(*MockWireGuardClient), - tt.manager.nl.(*MockNetlinkClient), - tt.manager.store.(*MockWireGuardStore), - ) - got, err := tt.manager.GetInterfaces() - if (err != nil) != tt.wantErr { - t.Errorf("GetInterfaces() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("GetInterfaces() got = %v, want %v", got, tt.want) - } - tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) - tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) - tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) - }) - } -} - func TestWgCtrlManager_UpdateInterface(t *testing.T) { type args struct { id persistence.InterfaceIdentifier @@ -193,7 +299,107 @@ args args wantErr bool }{ - // TODO: Add test cases. + { + name: "NonExistent", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, + args: args{ + id: "wg0", + }, + wantErr: true, + }, + { + name: "NonExistentLowLevel", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + nl.On("LinkByName", "wg0").Return(nil, errors.New("failure")) + }, + args: args{ + id: "wg0", + cfg: &persistence.InterfaceConfig{}, + }, + wantErr: true, + }, + { + name: "SuccessEnabled", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + virtLink := &netlink.GenericLink{LinkType: "wireguard"} + nl.On("LinkByName", "wg0").Return(virtLink, nil) + nl.On("LinkSetMTU", virtLink, 234).Return(nil) + nl.On("AddrReplace", virtLink, mock.MatchedBy(func(addr *netlink.Addr) bool { + return addr.String() == "1.2.3.4/24" + })).Return(nil) + nl.On("AddrAdd", virtLink, mock.MatchedBy(func(addr *netlink.Addr) bool { + return addr.String() == "10.0.0.2/24" + })).Return(nil) + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) + nl.On("LinkSetUp", virtLink).Return(nil) + st.On("SaveInterface", mock.Anything).Return(nil) + }, + args: args{ + id: "wg0", + cfg: &persistence.InterfaceConfig{ + Mtu: 234, AddressStr: "10.0.0.2/24,1.2.3.4/24", Enabled: true, + KeyPair: persistence.KeyPair{PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + }, + }, + wantErr: false, + }, + { + name: "SuccessDisabled", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + virtLink := &netlink.GenericLink{LinkType: "wireguard"} + nl.On("LinkByName", "wg0").Return(virtLink, nil) + nl.On("LinkSetMTU", virtLink, 234).Return(nil) + nl.On("AddrReplace", virtLink, mock.MatchedBy(func(addr *netlink.Addr) bool { + return addr.String() == "1.2.3.4/24" + })).Return(nil) + nl.On("AddrAdd", virtLink, mock.MatchedBy(func(addr *netlink.Addr) bool { + return addr.String() == "10.0.0.2/24" + })).Return(nil) + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) + nl.On("LinkSetDown", virtLink).Return(nil) + st.On("SaveInterface", mock.Anything).Return(nil) + }, + args: args{ + id: "wg0", + cfg: &persistence.InterfaceConfig{ + Mtu: 234, AddressStr: "10.0.0.2/24,1.2.3.4/24", Enabled: false, + KeyPair: persistence.KeyPair{PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + }, + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -211,3 +417,381 @@ }) } } + +func TestWgCtrlManager_ApplyDefaultConfigs(t *testing.T) { + type args struct { + id persistence.InterfaceIdentifier + } + tests := []struct { + name string + manager *wgCtrlManager + mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) + args args + wantErr bool + }{ + { + name: "NoInterface", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, + args: args{ + id: "wg0", + }, + wantErr: true, + }, + { + name: "PersistenceFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {Identifier: "wg0"}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Identifier: "peer0", Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}, + }, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + }, + args: args{ + id: "wg0", + }, + wantErr: true, + }, + { + name: "Success", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {Identifier: "wg0"}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": {Identifier: "peer0", Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}, + }, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(nil) + }, + args: args{ + id: "wg0", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup( + tt.manager.wg.(*MockWireGuardClient), + tt.manager.nl.(*MockNetlinkClient), + tt.manager.store.(*MockWireGuardStore), + ) + if err := tt.manager.ApplyDefaultConfigs(tt.args.id); (err != nil) != tt.wantErr { + t.Errorf("ApplyDefaultConfigs() error = %v, wantErr %v", err, tt.wantErr) + } + tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) + tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) + tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) + }) + } +} + +func TestWgCtrlManager_GetPeers(t *testing.T) { + tests := []struct { + name string + manager *wgCtrlManager + interfaceId persistence.InterfaceIdentifier + want []*persistence.PeerConfig + wantErr bool + }{ + { + name: "NoInterface", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{}, + }, + interfaceId: "wg0", + want: nil, + wantErr: true, + }, + { + name: "Normal", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": { + "peer0": &persistence.PeerConfig{Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}, + "peer1": &persistence.PeerConfig{Interface: &persistence.PeerInterfaceConfig{Identifier: "wg1"}}, + }, + }, + }, + interfaceId: "wg0", + want: []*persistence.PeerConfig{ + {Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}, + {Interface: &persistence.PeerInterfaceConfig{Identifier: "wg1"}}, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.manager.GetPeers(tt.interfaceId) + if (err != nil) != tt.wantErr { + t.Errorf("GetPeers() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !assert.Equal(t, got, tt.want) { + t.Errorf("GetPeers() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestWgCtrlManager_SavePeers(t *testing.T) { + tests := []struct { + name string + manager *wgCtrlManager + mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) + args []*persistence.PeerConfig + wantErr bool + }{ + { + name: "NoInterface", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, + args: []*persistence.PeerConfig{{Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, + wantErr: true, + }, + { + name: "ConfigGenerationFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, + args: []*persistence.PeerConfig{{Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, + wantErr: true, + }, + { + name: "WireGuardFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: nil, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(errors.New("failure")) + }, + args: []*persistence.PeerConfig{ + { + KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + }, + }, + wantErr: true, + }, + { + name: "PersistenceFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": {}, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) + st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + }, + args: []*persistence.PeerConfig{ + { + KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + }, + }, + wantErr: true, + }, + { + name: "Success", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": {}, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) + st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(nil) + }, + args: []*persistence.PeerConfig{ + { + KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup( + tt.manager.wg.(*MockWireGuardClient), + tt.manager.nl.(*MockNetlinkClient), + tt.manager.store.(*MockWireGuardStore), + ) + if err := tt.manager.SavePeers(tt.args...); (err != nil) != tt.wantErr { + t.Errorf("SavePeers() error = %v, wantErr %v", err, tt.wantErr) + } + tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) + tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) + tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) + }) + } +} + +func TestWgCtrlManager_RemovePeer(t *testing.T) { + tests := []struct { + name string + manager *wgCtrlManager + mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) + args persistence.PeerIdentifier + wantErr bool + }{ + { + name: "NoPeer", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{}, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, + args: "peer0", + wantErr: true, + }, + { + name: "WireGuardFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": {"peer0": { + KeyPair: persistence.KeyPair{ + PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", + PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", + }, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + }}, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(errors.New("failure")) + }, + args: "peer0", + wantErr: true, + }, + { + name: "PersistenceFailure", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": {"peer0": { + KeyPair: persistence.KeyPair{ + PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", + PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", + }, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + }}, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) + st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + }, + args: "peer0", + wantErr: true, + }, + { + name: "Success", + manager: &wgCtrlManager{ + mux: sync.RWMutex{}, + wg: &MockWireGuardClient{}, + nl: &MockNetlinkClient{}, + store: &MockWireGuardStore{}, + interfaces: map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig{"wg0": {}}, + peers: map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig{ + "wg0": {"peer0": { + KeyPair: persistence.KeyPair{ + PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", + PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", + }, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + }}, + }, + }, + mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { + wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) + st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(nil) + }, + args: "peer0", + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.mockSetup( + tt.manager.wg.(*MockWireGuardClient), + tt.manager.nl.(*MockNetlinkClient), + tt.manager.store.(*MockWireGuardStore), + ) + if err := tt.manager.RemovePeer(tt.args); (err != nil) != tt.wantErr { + t.Errorf("RemovePeer() error = %v, wantErr %v", err, tt.wantErr) + } + tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) + tt.manager.nl.(*MockNetlinkClient).AssertExpectations(t) + tt.manager.store.(*MockWireGuardStore).AssertExpectations(t) + }) + } +}