diff --git a/internal/persistence/models.go b/internal/persistence/models.go index 1d4387b..8341896 100644 --- a/internal/persistence/models.go +++ b/internal/persistence/models.go @@ -157,5 +157,5 @@ // database internal fields CreatedAt time.Time UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index" json:",omitempty"` + DeletedAt gorm.DeletedAt `gorm:"index" json:",omitempty"` // used as a "deactivated" flag } diff --git a/internal/persistence/users.go b/internal/persistence/users.go index a181852..146c948 100644 --- a/internal/persistence/users.go +++ b/internal/persistence/users.go @@ -6,22 +6,6 @@ "github.com/pkg/errors" ) -func (d *Database) GetUser(id UserIdentifier) (User, error) { - var user User - if err := d.db.First(&user, id).Error; err != nil { - return User{}, errors.WithMessagef(err, "unable to find user %s", id) - } - return user, nil -} - -func (d *Database) GetUsers() ([]User, error) { - var users []User - if err := d.db.Find(&users).Error; err != nil { - return nil, errors.WithMessagef(err, "unable to find users") - } - return users, nil -} - func (d *Database) GetUsersUnscoped() ([]User, error) { var users []User if err := d.db.Unscoped().Find(&users).Error; err != nil { @@ -30,18 +14,6 @@ return users, nil } -func (d *Database) GetUsersFiltered(filters ...DatabaseFilterCondition) ([]User, error) { - var users []User - tx := d.db - for _, filter := range filters { - tx = filter(tx) - } - if err := tx.Find(&users).Error; err != nil { - return nil, errors.WithMessagef(err, "unable to find filtered users") - } - return users, nil -} - func (d *Database) SaveUser(user User) error { create := user.Uid == "" now := time.Now() @@ -67,3 +39,33 @@ } return nil } + +// Extra functions, currently unused... + +func (d *Database) GetUser(id UserIdentifier) (User, error) { + var user User + if err := d.db.First(&user, id).Error; err != nil { + return User{}, errors.WithMessagef(err, "unable to find user %s", id) + } + return user, nil +} + +func (d *Database) GetUsers() ([]User, error) { + var users []User + if err := d.db.Find(&users).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find users") + } + return users, nil +} + +func (d *Database) GetUsersFiltered(filters ...DatabaseFilterCondition) ([]User, error) { + var users []User + tx := d.db + for _, filter := range filters { + tx = filter(tx) + } + if err := tx.Find(&users).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find filtered users") + } + return users, nil +} diff --git a/internal/persistence/wireguard.go b/internal/persistence/wireguard.go index 4a3e132..ee78270 100644 --- a/internal/persistence/wireguard.go +++ b/internal/persistence/wireguard.go @@ -8,7 +8,7 @@ func (d *Database) GetAvailableInterfaces() ([]InterfaceIdentifier, error) { var interfaces []InterfaceConfig if err := d.db.Select("identifier").Find(&interfaces).Error; err != nil { - return nil, errors.WithMessagef(err, "unable to find interfaces") + return nil, errors.WithMessage(err, "unable to find interfaces") } interfaceIds := make([]InterfaceIdentifier, len(interfaces)) @@ -22,7 +22,7 @@ func (d *Database) GetAllInterfaces(ids ...InterfaceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { var interfaces []InterfaceConfig if err := d.db.Where("interface IN ?", ids).Find(&interfaces).Error; err != nil { - return nil, errors.WithMessagef(err, "unable to find interfaces") + return nil, errors.WithMessage(err, "unable to find interfaces") } interfaceMap := make(map[InterfaceConfig][]PeerConfig, len(interfaces)) @@ -40,12 +40,12 @@ func (d *Database) GetInterface(id InterfaceIdentifier) (InterfaceConfig, []PeerConfig, error) { var iface InterfaceConfig if err := d.db.First(&iface, id).Error; err != nil { - return InterfaceConfig{}, nil, errors.WithMessagef(err, "unable to find interface %s", id) + return InterfaceConfig{}, nil, errors.WithMessage(err, "unable to find interface") } var peers []PeerConfig if err := d.db.Where("interface = ?", id).Find(&peers).Error; err != nil { - return InterfaceConfig{}, nil, errors.WithMessagef(err, "unable to find peers for %s", id) + return InterfaceConfig{}, nil, errors.WithMessage(err, "unable to find peers") } return iface, peers, nil @@ -58,14 +58,27 @@ return nil } -func (d *Database) SavePeer(peer PeerConfig, id InterfaceIdentifier) error { - return nil -} - func (d *Database) DeleteInterface(id InterfaceIdentifier) error { + if err := d.db.Delete(&InterfaceConfig{}, id).Error; err != nil { + return errors.WithMessage(err, "unable to delete interface") + } + return nil } -func (d *Database) DeletePeer(peerId PeerIdentifier, id InterfaceIdentifier) error { +func (d *Database) SavePeer(peer PeerConfig) error { + if err := d.db.Clauses(clause.OnConflict{ + UpdateAll: true, + }).Create(&peer).Error; err != nil { + return errors.WithMessage(err, "unable to save peer") + } + + return nil +} + +func (d *Database) DeletePeer(peerId PeerIdentifier) error { + if err := d.db.Delete(&PeerConfig{}, peerId).Error; err != nil { + return errors.WithMessage(err, "unable to delete peer") + } return nil } diff --git a/internal/user/authentication.go b/internal/user/authentication.go index a25bc4c..2e81a72 100644 --- a/internal/user/authentication.go +++ b/internal/user/authentication.go @@ -8,20 +8,8 @@ "golang.org/x/crypto/bcrypt" ) -type PasswordAuthenticator struct { - store store -} - -func NewPasswordAuthenticator(store store) (*PasswordAuthenticator, error) { - a := &PasswordAuthenticator{ - store: store, - } - - return a, nil -} - -func (p *PasswordAuthenticator) PlaintextAuthentication(userId persistence.UserIdentifier, plainPassword string) error { - user, err := p.store.GetUser(userId) +func (p *PersistentManager) PlaintextAuthentication(userId persistence.UserIdentifier, plainPassword string) error { + user, err := p.GetUser(userId) if err != nil { return errors.WithMessagef(err, "unable to load user %s", userId) } @@ -33,8 +21,8 @@ return nil } -func (p *PasswordAuthenticator) HashedAuthentication(userId persistence.UserIdentifier, hashedPassword string) error { - user, err := p.store.GetUser(userId) +func (p *PersistentManager) HashedAuthentication(userId persistence.UserIdentifier, hashedPassword string) error { + user, err := p.GetUser(userId) if err != nil { return errors.WithMessagef(err, "unable to load user %s", userId) } @@ -46,7 +34,7 @@ return nil } -func (p *PasswordAuthenticator) HashPassword(plain string) (string, error) { +func (p *PersistentManager) HashPassword(plain string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(plain), bcrypt.DefaultCost) if err != nil { return "", errors.WithMessage(err, "failed to hash password") diff --git a/internal/user/manager.go b/internal/user/manager.go index d3744b6..d50edbf 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -1,20 +1,23 @@ package user import ( + "sort" + "sync" + "github.com/h44z/wg-portal/internal/persistence" "github.com/pkg/errors" ) type Loader interface { - GetUser(id persistence.UserIdentifier) (persistence.User, error) - GetActiveUsers() ([]persistence.User, error) - GetAllUsers() ([]persistence.User, error) - GetFilteredUsers(filter ...persistence.DatabaseFilterCondition) ([]persistence.User, error) + GetUser(id persistence.UserIdentifier) (*persistence.User, error) + GetActiveUsers() ([]*persistence.User, error) + GetAllUsers() ([]*persistence.User, error) + GetFilteredUsers(filter Filter) ([]*persistence.User, error) } type Updater interface { - CreateUser(user persistence.User) error - UpdateUser(user persistence.User) error + CreateUser(user *persistence.User) error + UpdateUser(user *persistence.User) error DeleteUser(identifier persistence.UserIdentifier) error } @@ -27,6 +30,9 @@ HashPassword(plain string) (string, error) } +// Filter can be used to filter users. If this function returns true, the given user is included in the result. +type Filter func(user *persistence.User) bool + type Manager interface { Loader Updater @@ -35,10 +41,12 @@ } type PersistentManager struct { + mux sync.RWMutex // mutex to synchronize access to maps and external api clients + store store - authenticator Authenticator - hasher PasswordHasher + // internal holder of user objects + users map[persistence.UserIdentifier]*persistence.User } func NewPersistentManager(store store) (*PersistentManager, error) { @@ -46,44 +54,210 @@ return nil, errors.New("user manager requires a valid store object") } - pwa, err := NewPasswordAuthenticator(store) - if err != nil { - return nil, errors.WithMessage(err, "failed to initialize authenticator") - } - mgr := &PersistentManager{ - store: store, - authenticator: pwa, - hasher: pwa, + store: store, + + users: make(map[persistence.UserIdentifier]*persistence.User), } return mgr, nil } -func (p *PersistentManager) GetUser(id persistence.UserIdentifier) (persistence.User, error) { - return p.store.GetUser(id) +func (p *PersistentManager) GetUser(id persistence.UserIdentifier) (*persistence.User, error) { + p.mux.RLock() + defer p.mux.RUnlock() + + if !p.userExists(id) { + return nil, errors.New("no such user exists") + } + + if !p.userIsEnabled(id) { + return nil, errors.New("user is disabled") + } + + return p.users[id], nil } -func (p *PersistentManager) GetActiveUsers() ([]persistence.User, error) { - return p.store.GetUsers() +func (p *PersistentManager) GetActiveUsers() ([]*persistence.User, error) { + p.mux.RLock() + defer p.mux.RUnlock() + + users := make([]*persistence.User, 0, len(p.users)) + for _, user := range p.users { + if !user.DeletedAt.Valid { + users = append(users, user) + } + } + + // Order the users by uid + sort.Slice(users, func(i, j int) bool { + return users[i].Uid < users[j].Uid + }) + + return users, nil } -func (p *PersistentManager) GetAllUsers() ([]persistence.User, error) { - return p.store.GetUsersUnscoped() +func (p *PersistentManager) GetAllUsers() ([]*persistence.User, error) { + p.mux.RLock() + defer p.mux.RUnlock() + + users := make([]*persistence.User, 0, len(p.users)) + for _, user := range p.users { + users = append(users, user) + } + + // Order the users by uid + sort.Slice(users, func(i, j int) bool { + return users[i].Uid < users[j].Uid + }) + + return users, nil } -func (p *PersistentManager) GetFilteredUsers(filter ...persistence.DatabaseFilterCondition) ([]persistence.User, error) { - return p.store.GetUsersFiltered(filter...) +func (p *PersistentManager) GetFilteredUsers(filter Filter) ([]*persistence.User, error) { + p.mux.RLock() + defer p.mux.RUnlock() + + users := make([]*persistence.User, 0, len(p.users)) + for _, user := range p.users { + if filter == nil || filter(user) { + users = append(users, user) + } + } + + // Order the users by uid + sort.Slice(users, func(i, j int) bool { + return users[i].Uid < users[j].Uid + }) + + return users, nil } -func (p *PersistentManager) CreateUser(user persistence.User) error { - return p.store.SaveUser(user) +func (p *PersistentManager) CreateUser(user *persistence.User) error { + if err := p.checkUser(user); err != nil { + return errors.WithMessage(err, "user validation failed") + } + + p.mux.Lock() + defer p.mux.Unlock() + + if p.userExists(user.Uid) { + return errors.New("user already exists") + } + + p.users[user.Uid] = user + + err := p.persistUser(user.Uid, false) + if err != nil { + return errors.WithMessage(err, "failed to persist created user") + } + + return nil } -func (p *PersistentManager) UpdateUser(user persistence.User) error { - return p.store.SaveUser(user) +func (p *PersistentManager) UpdateUser(user *persistence.User) error { + if err := p.checkUser(user); err != nil { + return errors.WithMessage(err, "user validation failed") + } + + p.mux.Lock() + defer p.mux.Unlock() + + if !p.userExists(user.Uid) { + return errors.New("user does not exists") + } + + p.users[user.Uid] = user + + err := p.persistUser(user.Uid, false) + if err != nil { + return errors.WithMessage(err, "failed to persist updated user") + } + + return nil } -func (p *PersistentManager) DeleteUser(identifier persistence.UserIdentifier) error { - return p.store.DeleteUser(identifier) +func (p *PersistentManager) DeleteUser(id persistence.UserIdentifier) error { + p.mux.Lock() + defer p.mux.Unlock() + if !p.userExists(id) { + return errors.New("user does not exists") + } + + err := p.persistUser(id, true) + if err != nil { + return errors.WithMessage(err, "failed to persist deleted user") + } + + delete(p.users, id) + + return nil +} + +// +// -- Helpers +// + +func (p *PersistentManager) initializeFromStore() error { + if p.store == nil { + return nil // no store, nothing to do + } + + users, err := p.store.GetUsersUnscoped() + if err != nil { + return errors.WithMessage(err, "failed to get all users") + } + + for _, tmpUser := range users { + user := tmpUser + p.users[user.Uid] = &user + } + + return nil +} + +func (p *PersistentManager) userExists(id persistence.UserIdentifier) bool { + if _, ok := p.users[id]; ok { + return true + } + return false +} + +func (p *PersistentManager) userIsEnabled(id persistence.UserIdentifier) bool { + if user, ok := p.users[id]; ok && !user.DeletedAt.Valid { + return true + } + return false +} + +func (p *PersistentManager) persistUser(id persistence.UserIdentifier, delete bool) error { + if p.store == nil { + return nil // nothing to do + } + + var err error + if delete { + err = p.store.DeleteUser(id) + } else { + err = p.store.SaveUser(*p.users[id]) + } + if err != nil { + return errors.Wrapf(err, "failed to persist user") + } + + return nil +} + +func (p *PersistentManager) checkUser(user *persistence.User) error { + if user == nil { + return errors.New("user must not be nil") + } + if user.Uid == "" { + return errors.New("missing user identifier") + } + if user.Source == "" { + return errors.New("missing user source") + } + + return nil } diff --git a/internal/user/persistence.go b/internal/user/persistence.go index ccc273d..3b33b08 100644 --- a/internal/user/persistence.go +++ b/internal/user/persistence.go @@ -5,10 +5,7 @@ ) type store interface { - GetUser(id persistence.UserIdentifier) (persistence.User, error) - GetUsers() ([]persistence.User, error) GetUsersUnscoped() ([]persistence.User, error) - GetUsersFiltered(filters ...persistence.DatabaseFilterCondition) ([]persistence.User, error) SaveUser(user persistence.User) error DeleteUser(identifier persistence.UserIdentifier) error } diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 6dc27d4..df2b543 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -22,7 +22,7 @@ GetInterfaces() ([]*persistence.InterfaceConfig, error) CreateInterface(id persistence.InterfaceIdentifier) error DeleteInterface(id persistence.InterfaceIdentifier) error - UpdateInterface(id persistence.InterfaceIdentifier, cfg *persistence.InterfaceConfig) error + UpdateInterface(cfg *persistence.InterfaceConfig) error ApplyDefaultConfigs(id persistence.InterfaceIdentifier) error } diff --git a/internal/wireguard/persistence.go b/internal/wireguard/persistence.go index c1e1b7c..2a2c95a 100644 --- a/internal/wireguard/persistence.go +++ b/internal/wireguard/persistence.go @@ -11,8 +11,8 @@ GetInterface(identifier persistence.InterfaceIdentifier) (persistence.InterfaceConfig, []persistence.PeerConfig, error) SaveInterface(cfg persistence.InterfaceConfig) error - SavePeer(peer persistence.PeerConfig, interfaceIdentifier persistence.InterfaceIdentifier) error + SavePeer(peer persistence.PeerConfig) error DeleteInterface(identifier persistence.InterfaceIdentifier) error - DeletePeer(peer persistence.PeerIdentifier, interfaceIdentifier persistence.InterfaceIdentifier) error + DeletePeer(peer persistence.PeerIdentifier) error } diff --git a/internal/wireguard/test_helpers_test.go b/internal/wireguard/test_helpers_test.go index 296addb..9a5cf47 100644 --- a/internal/wireguard/test_helpers_test.go +++ b/internal/wireguard/test_helpers_test.go @@ -120,8 +120,8 @@ return args.Error(0) } -func (w *MockWireGuardStore) SavePeer(peer persistence.PeerConfig, interfaceIdentifier persistence.InterfaceIdentifier) error { - args := w.Called(peer, interfaceIdentifier) +func (w *MockWireGuardStore) SavePeer(peer persistence.PeerConfig) error { + args := w.Called(peer) return args.Error(0) } @@ -130,7 +130,7 @@ return args.Error(0) } -func (w *MockWireGuardStore) DeletePeer(peer persistence.PeerIdentifier, interfaceIdentifier persistence.InterfaceIdentifier) error { - args := w.Called(peer, interfaceIdentifier) +func (w *MockWireGuardStore) DeletePeer(peer persistence.PeerIdentifier) error { + args := w.Called(peer) return args.Error(0) } diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index cea4179..37b744f 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -74,7 +74,7 @@ return errors.WithMessage(err, "failed to create low level interface") } - newInterface := &persistence.InterfaceConfig{Identifier: id} + newInterface := &persistence.InterfaceConfig{Identifier: id, Type: persistence.InterfaceTypeServer} m.interfaces[id] = newInterface m.peers[id] = make(map[persistence.PeerIdentifier]*persistence.PeerConfig) @@ -122,17 +122,20 @@ return nil } -func (m *wgCtrlManager) UpdateInterface(id persistence.InterfaceIdentifier, cfg *persistence.InterfaceConfig) error { +func (m *wgCtrlManager) UpdateInterface(cfg *persistence.InterfaceConfig) error { + if err := m.checkInterface(cfg); err != nil { + return errors.WithMessage(err, "interface validation failed") + } + m.mux.Lock() defer m.mux.Unlock() - if !m.deviceExists(id) { + if !m.deviceExists(cfg.Identifier) { return errors.New("interface does not exist") } - cfg.Identifier = id // ensure that the same device name is set // Update net-link attributes - link, err := m.nl.LinkByName(string(id)) + link, err := m.nl.LinkByName(string(cfg.Identifier)) if err != nil { return errors.WithMessage(err, "failed to open low level interface") } @@ -167,7 +170,7 @@ if cfg.FirewallMark != 0 { *fwMark = int(cfg.FirewallMark) } - err = m.wg.ConfigureDevice(string(id), wgtypes.Config{ + err = m.wg.ConfigureDevice(string(cfg.Identifier), wgtypes.Config{ PrivateKey: &pKey, ListenPort: &cfg.ListenPort, FirewallMark: fwMark, @@ -188,9 +191,9 @@ } // update internal map - m.interfaces[id] = cfg + m.interfaces[cfg.Identifier] = cfg - err = m.persistInterface(id, false) + err = m.persistInterface(cfg.Identifier, false) if err != nil { return errors.WithMessage(err, "failed to persist updated interface") } @@ -259,6 +262,10 @@ defer m.mux.Unlock() for _, peer := range peers { + if err := m.checkPeer(peer); err != nil { + return errors.WithMessage(err, "peer validation failed") + } + deviceId := peer.Interface.Identifier if !m.deviceExists(deviceId) { return errors.Errorf("device does not exist") @@ -362,7 +369,31 @@ m.mux.Lock() defer m.mux.Unlock() - // TODO: implement + newInterface := &cfg.InterfaceConfig + if err := m.checkInterface(newInterface); err != nil { + return errors.WithMessage(err, "interface validation failed") + } + + m.interfaces[newInterface.Identifier] = newInterface + m.peers[newInterface.Identifier] = make(map[persistence.PeerIdentifier]*persistence.PeerConfig) + + err := m.persistInterface(newInterface.Identifier, false) + if err != nil { + return errors.WithMessage(err, "failed to persist imported interface") + } + + for _, peer := range peers { + if err := m.checkPeer(peer); err != nil { + return errors.WithMessage(err, "peer validation failed") + } + + m.peers[newInterface.Identifier][peer.Identifier] = peer + + err = m.persistPeer(peer.Identifier, false) + if err != nil { + return errors.Wrapf(err, "failed to persist imported peer %s", peer.Identifier) + } + } return nil } @@ -470,9 +501,9 @@ var err error if delete { - err = m.store.DeletePeer(id, peer.Interface.Identifier) + err = m.store.DeletePeer(id) } else { - err = m.store.SavePeer(*peer, peer.Interface.Identifier) + err = m.store.SavePeer(*peer) } if err != nil { return errors.Wrapf(err, "failed to persist peer %s", id) @@ -495,6 +526,7 @@ cfg := &ImportableInterface{} cfg.Identifier = persistence.InterfaceIdentifier(device.Name) + cfg.Type = persistence.InterfaceTypeServer // default assume that the imported device is a server device cfg.FirewallMark = int32(device.FirewallMark) cfg.KeyPair = persistence.KeyPair{ PrivateKey: device.PrivateKey.String(), @@ -547,6 +579,37 @@ return peerCfg, nil } +func (m *wgCtrlManager) checkInterface(cfg *persistence.InterfaceConfig) error { + if cfg == nil { + return errors.New("interface config must not be nil") + } + if cfg.Identifier == "" { + return errors.New("missing interface identifier") + } + if cfg.Type == "" { + return errors.New("missing interface type") + } + + return nil +} + +func (m *wgCtrlManager) checkPeer(cfg *persistence.PeerConfig) error { + if cfg == nil { + return errors.New("peer config must not be nil") + } + if cfg.Identifier == "" { + return errors.New("missing peer identifier") + } + if cfg.Interface == nil { + return errors.New("missing peer interface") + } + if cfg.Interface.Identifier == "" { + return errors.New("missing peer interface identifier") + } + + return nil +} + func getWireGuardPeerConfig(devType persistence.InterfaceType, cfg *persistence.PeerConfig) (wgtypes.PeerConfig, error) { publicKey, err := wgtypes.ParseKey(cfg.KeyPair.PublicKey) if err != nil { diff --git a/internal/wireguard/wireguard_test.go b/internal/wireguard/wireguard_test.go index a3e17a6..77098c1 100644 --- a/internal/wireguard/wireguard_test.go +++ b/internal/wireguard/wireguard_test.go @@ -244,7 +244,7 @@ mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { nl.On("LinkDel", mock.Anything).Return(nil) st.On("DeleteInterface", persistence.InterfaceIdentifier("wg0")).Return(nil) - st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + st.On("DeletePeer", persistence.PeerIdentifier("peer0")).Return(errors.New("failure")) }, args: "wg0", wantErr: true, @@ -264,7 +264,7 @@ mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { nl.On("LinkDel", mock.Anything).Return(nil) st.On("DeleteInterface", persistence.InterfaceIdentifier("wg0")).Return(nil) - st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(nil) + st.On("DeletePeer", persistence.PeerIdentifier("peer0")).Return(nil) }, args: "wg0", wantErr: false, @@ -289,7 +289,6 @@ func TestWgCtrlManager_UpdateInterface(t *testing.T) { type args struct { - id persistence.InterfaceIdentifier cfg *persistence.InterfaceConfig } tests := []struct { @@ -310,10 +309,7 @@ peers: nil, }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, - args: args{ - id: "wg0", - }, - wantErr: true, + wantErr: true, }, { name: "NonExistentLowLevel", @@ -329,8 +325,7 @@ nl.On("LinkByName", "wg0").Return(nil, errors.New("failure")) }, args: args{ - id: "wg0", - cfg: &persistence.InterfaceConfig{}, + cfg: &persistence.InterfaceConfig{Identifier: "wg0", Type: persistence.InterfaceTypeServer}, }, wantErr: true, }, @@ -359,8 +354,8 @@ st.On("SaveInterface", mock.Anything).Return(nil) }, args: args{ - id: "wg0", cfg: &persistence.InterfaceConfig{ + Identifier: "wg0", Type: persistence.InterfaceTypeServer, Mtu: 234, AddressStr: "10.0.0.2/24,1.2.3.4/24", Enabled: true, KeyPair: persistence.KeyPair{PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, }, @@ -392,8 +387,8 @@ st.On("SaveInterface", mock.Anything).Return(nil) }, args: args{ - id: "wg0", cfg: &persistence.InterfaceConfig{ + Identifier: "wg0", Type: persistence.InterfaceTypeServer, Mtu: 234, AddressStr: "10.0.0.2/24,1.2.3.4/24", Enabled: false, KeyPair: persistence.KeyPair{PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, }, @@ -408,7 +403,7 @@ tt.manager.nl.(*MockNetlinkClient), tt.manager.store.(*MockWireGuardStore), ) - if err := tt.manager.UpdateInterface(tt.args.id, tt.args.cfg); (err != nil) != tt.wantErr { + if err := tt.manager.UpdateInterface(tt.args.cfg); (err != nil) != tt.wantErr { t.Errorf("UpdateInterface() error = %v, wantErr %v", err, tt.wantErr) } tt.manager.wg.(*MockWireGuardClient).AssertExpectations(t) @@ -460,7 +455,7 @@ }, }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { - st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + st.On("SavePeer", mock.Anything).Return(errors.New("failure")) }, args: args{ id: "wg0", @@ -482,7 +477,7 @@ }, }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { - st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(nil) + st.On("SavePeer", mock.Anything).Return(nil) }, args: args{ id: "wg0", @@ -578,7 +573,7 @@ peers: nil, }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, - args: []*persistence.PeerConfig{{Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, + args: []*persistence.PeerConfig{{Identifier: "peer0", Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, wantErr: true, }, { @@ -592,7 +587,7 @@ peers: nil, }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) {}, - args: []*persistence.PeerConfig{{Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, + args: []*persistence.PeerConfig{{Identifier: "peer0", Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}}}, wantErr: true, }, { @@ -610,8 +605,9 @@ }, args: []*persistence.PeerConfig{ { - KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, - Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + Identifier: "peer0", + KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, }, }, wantErr: true, @@ -630,12 +626,13 @@ }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) - st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + st.On("SavePeer", mock.Anything).Return(errors.New("failure")) }, args: []*persistence.PeerConfig{ { - KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, - Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + Identifier: "peer0", + KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, }, }, wantErr: true, @@ -654,12 +651,13 @@ }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) - st.On("SavePeer", mock.Anything, persistence.InterfaceIdentifier("wg0")).Return(nil) + st.On("SavePeer", mock.Anything).Return(nil) }, args: []*persistence.PeerConfig{ { - KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, - Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, + Identifier: "peer0", + KeyPair: persistence.KeyPair{PublicKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI=", PrivateKey: "pcDxSxSZp5x87cNoRJaHdAOzxrxDfDUn7pGmrY/AmzI="}, + Interface: &persistence.PeerInterfaceConfig{Identifier: "wg0"}, }, }, wantErr: false, @@ -748,7 +746,7 @@ }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) - st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(errors.New("failure")) + st.On("DeletePeer", persistence.PeerIdentifier("peer0")).Return(errors.New("failure")) }, args: "peer0", wantErr: true, @@ -773,7 +771,7 @@ }, mockSetup: func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) { wg.On("ConfigureDevice", "wg0", mock.Anything).Return(nil) - st.On("DeletePeer", persistence.PeerIdentifier("peer0"), persistence.InterfaceIdentifier("wg0")).Return(nil) + st.On("DeletePeer", persistence.PeerIdentifier("peer0")).Return(nil) }, args: "peer0", wantErr: false,