diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 97e0a3e..8890db7 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -2,7 +2,6 @@ import ( "io" - "sync" "github.com/h44z/wg-portal/internal/lowlevel" @@ -30,7 +29,7 @@ type ImportManager interface { GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) - ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) + ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) error } type ConfigFileGenerator interface { @@ -59,28 +58,11 @@ type PersistentManager struct { WgCtrlKeyGenerator TemplateHandler - - mux sync.RWMutex // mutex to synchronize access to maps - - // external api clients - wg lowlevel.WireGuardClient - nl lowlevel.NetlinkClient - - // persistent backend - store store - - // internal holder of interface configurations - interfaces map[persistence.InterfaceIdentifier]persistence.InterfaceConfig - // internal holder of peer configurations - peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig + WgCtrlManager } func NewPersistentManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*PersistentManager, error) { - m := &PersistentManager{ - mux: sync.RWMutex{}, - wg: wg, - nl: nl, - } + m := &PersistentManager{} return m, nil } diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 6e68e8d..88710b7 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -275,6 +275,51 @@ return nil } +func (m *WgCtrlManager) GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) { + devices, err := m.wg.Devices() + if err != nil { + return nil, errors.WithMessage(err, "failed to get WireGuard device list") + } + + m.mux.RLock() + defer m.mux.RUnlock() + + interfaces := make(map[ImportableInterface][]persistence.PeerConfig, len(devices)) + for d, device := range devices { + if _, exists := m.interfaces[persistence.InterfaceIdentifier(device.Name)]; exists { + continue // interface already managed, skip + } + + cfg, err := m.convertWireGuardInterface(devices[d]) + if err != nil { + return nil, errors.WithMessagef(err, "failed to convert WireGuard interface %s", device.Name) + } + + interfaces[cfg] = make([]persistence.PeerConfig, len(device.Peers)) + + for p, peer := range device.Peers { + peerCfg, err := m.convertWireGuardPeer(&device.Peers[p], cfg) + if err != nil { + return nil, errors.WithMessagef(err, "failed to convert WireGuard peer %s from %s", + peer.PublicKey.String(), device.Name) + } + + interfaces[cfg][p] = peerCfg + } + } + + return interfaces, nil +} + +func (m *WgCtrlManager) ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) error { + m.mux.Lock() + defer m.mux.Unlock() + + // TODO: implement + + return nil +} + // // -- Helpers // @@ -403,6 +448,62 @@ return persistence.PeerConfig{}, errors.New("peer not found") } +func (m *WgCtrlManager) convertWireGuardInterface(device *wgtypes.Device) (ImportableInterface, error) { + cfg := ImportableInterface{} + + cfg.Identifier = persistence.InterfaceIdentifier(device.Name) + cfg.FirewallMark = int32(device.FirewallMark) + cfg.KeyPair = persistence.KeyPair{ + PrivateKey: device.PrivateKey.String(), + PublicKey: device.PublicKey.String(), + } + cfg.ListenPort = device.ListenPort + cfg.DriverType = device.Type.String() + + lowLevelInterface, err := m.nl.LinkByName(device.Name) + if err != nil { + return ImportableInterface{}, errors.WithMessagef(err, "failed to get low level interface for %s", device.Name) + } + cfg.Mtu = lowLevelInterface.Attrs().MTU + ipAddresses, err := m.nl.AddrList(lowLevelInterface) + if err != nil { + return ImportableInterface{}, errors.WithMessagef(err, "failed to get low level addresses for %s", device.Name) + } + cfg.AddressStr = ipAddressesToString(ipAddresses) + + return cfg, nil +} + +func (m *WgCtrlManager) convertWireGuardPeer(peer *wgtypes.Peer, dev ImportableInterface) (persistence.PeerConfig, error) { + peerCfg := persistence.PeerConfig{} + peerCfg.Identifier = persistence.PeerIdentifier(peer.PublicKey.String()) + peerCfg.KeyPair = persistence.KeyPair{ + PublicKey: peer.PublicKey.String(), + } + peerCfg.DisplayName = "Autodetected Peer (" + peer.PublicKey.String()[0:8] + ")" + if peer.Endpoint != nil { + peerCfg.Endpoint = persistence.NewStringConfigOption(peer.Endpoint.String(), true) + } + if peer.PresharedKey != (wgtypes.Key{}) { + peerCfg.PresharedKey = peer.PresharedKey.String() + } + allowedIPs := make([]string, len(peer.AllowedIPs)) // use allowed IP's as the peer IP's + for i, ip := range peer.AllowedIPs { + allowedIPs[i] = ip.String() + } + peerCfg.AllowedIPsStr = persistence.NewStringConfigOption(strings.Join(allowedIPs, ","), true) + peerCfg.PersistentKeepalive = persistence.NewIntConfigOption(int(peer.PersistentKeepaliveInterval.Seconds()), true) + + peerCfg.PeerInterfaceConfig = persistence.PeerInterfaceConfig{ + Identifier: dev.Identifier, + AddressStr: persistence.NewStringConfigOption(dev.AddressStr, true), // todo: correct? + DnsStr: persistence.NewStringConfigOption(dev.DnsStr, true), + Mtu: persistence.NewIntConfigOption(dev.Mtu, true), + } + + return peerCfg, nil +} + func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) { rawAddresses := strings.Split(addrStr, ",") addresses := make([]*netlink.Addr, 0, len(rawAddresses)) @@ -421,6 +522,15 @@ return addresses, nil } +func ipAddressesToString(addresses []netlink.Addr) string { + addressesStr := make([]string, len(addresses)) + for i := range addresses { + addressesStr[i] = addresses[i].String() + } + + return strings.Join(addressesStr, ",") +} + func getWireGuardPeerConfig(devType persistence.InterfaceType, cfg persistence.PeerConfig) (wgtypes.PeerConfig, error) { publicKey, err := wgtypes.ParseKey(cfg.KeyPair.PublicKey) if err != nil {