diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 7901416..97e0a3e 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -40,8 +40,8 @@ type PeerManager interface { GetPeers(device persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) - SavePeers(device persistence.InterfaceIdentifier, peers ...persistence.PeerConfig) error - RemovePeer(device persistence.InterfaceIdentifier, peer persistence.PeerIdentifier) error + SavePeers(peers ...persistence.PeerConfig) error + RemovePeer(peer persistence.PeerIdentifier) error } type Manager interface { diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 232b3a9..6e68e8d 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -7,13 +7,11 @@ "sync" "time" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - - "github.com/vishvananda/netlink" - "github.com/h44z/wg-portal/internal/lowlevel" "github.com/h44z/wg-portal/internal/persistence" "github.com/pkg/errors" + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) type WgCtrlManager struct { @@ -32,6 +30,23 @@ peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig } +func NewWgCtrlManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*WgCtrlManager, error) { + m := &WgCtrlManager{ + mux: sync.RWMutex{}, + wg: wg, + nl: nl, + store: store, + interfaces: make(map[persistence.InterfaceIdentifier]persistence.InterfaceConfig), + peers: make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig), + } + + if err := m.initializeFromStore(); err != nil { + return nil, errors.WithMessage(err, "failed to initialize manager from store") + } + + return m, nil +} + func (m *WgCtrlManager) GetInterfaces() ([]persistence.InterfaceConfig, error) { m.mux.RLock() defer m.mux.RUnlock() @@ -61,6 +76,7 @@ newInterface := persistence.InterfaceConfig{Identifier: id} m.interfaces[id] = newInterface + m.peers[id] = make(map[persistence.PeerIdentifier]persistence.PeerConfig) err = m.persistInterface(id, false) if err != nil { @@ -93,7 +109,15 @@ return errors.WithMessage(err, "failed to persist deleted interface") } + for peerId := range m.peers[id] { + err = m.persistPeer(peerId, true) + if err != nil { + return errors.WithMessagef(err, "failed to persist deleted peer %s", peerId) + } + } + delete(m.interfaces, id) + delete(m.peers, id) return nil } @@ -255,6 +279,34 @@ // -- Helpers // +func (m *WgCtrlManager) initializeFromStore() error { + if m.store == nil { + return nil // no store, nothing to do + } + + interfaceIds, err := m.store.GetAvailableInterfaces() + if err != nil { + return errors.WithMessage(err, "failed to get available interfaces") + } + + interfaces, err := m.store.GetAllInterfaces(interfaceIds...) + if err != nil { + return errors.WithMessage(err, "failed to get all interfaces") + } + + for cfg, peers := range interfaces { + m.interfaces[cfg.Identifier] = cfg + if _, ok := m.peers[cfg.Identifier]; !ok { + m.peers[cfg.Identifier] = make(map[persistence.PeerIdentifier]persistence.PeerConfig) + } + for _, peer := range peers { + m.peers[cfg.Identifier][peer.Identifier] = peer + } + } + + return nil +} + func (m *WgCtrlManager) createLowLevelInterface(id persistence.InterfaceIdentifier) error { link := &netlink.GenericLink{ LinkAttrs: netlink.LinkAttrs{