diff --git a/internal/lowlevel/wrappers.go b/internal/lowlevel/wrappers.go index 0aa57ab..9b409ed 100644 --- a/internal/lowlevel/wrappers.go +++ b/internal/lowlevel/wrappers.go @@ -25,6 +25,7 @@ LinkSetMTU(link netlink.Link, mtu int) error AddrReplace(link netlink.Link, addr *netlink.Addr) error AddrAdd(link netlink.Link, addr *netlink.Addr) error + AddrList(link netlink.Link) ([]netlink.Addr, error) } type NetlinkManager struct { @@ -53,3 +54,21 @@ func (n NetlinkManager) AddrAdd(link netlink.Link, addr *netlink.Addr) error { return netlink.AddrAdd(link, addr) } + +func (n NetlinkManager) AddrList(link netlink.Link) ([]netlink.Addr, error) { + listIPv4, err := netlink.AddrList(link, netlink.FAMILY_V4) + if err != nil { + return nil, err + } + + listIPv6, err := netlink.AddrList(link, netlink.FAMILY_V6) + if err != nil { + return nil, err + } + + ipAddresses := make([]netlink.Addr, 0, len(listIPv4)+len(listIPv6)) + ipAddresses = append(ipAddresses, listIPv4...) + ipAddresses = append(ipAddresses, listIPv6...) + + return ipAddresses, nil +} diff --git a/internal/wireguard/backend_db.go b/internal/wireguard/backend_db.go index 414b9e9..c0eadeb 100644 --- a/internal/wireguard/backend_db.go +++ b/internal/wireguard/backend_db.go @@ -10,6 +10,8 @@ "gorm.io/gorm" ) +var DatabaseBackendName = "db" + type DatabaseBackend struct { db *gorm.DB } @@ -26,6 +28,10 @@ return backend, nil } +func (d DatabaseBackend) Name() string { + return DatabaseBackendName +} + func (d DatabaseBackend) SaveInterface(cfg InterfaceConfig, _ []PeerConfig) error { iface, peerDefaults := convertInterface(cfg) @@ -170,9 +176,7 @@ return interfaceConfig, peerConfigs, nil } -func (d DatabaseBackend) LoadAll(ignored ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { - interfaceIdentifiers := []DeviceIdentifier{} // TODO: fill this ?! - +func (d DatabaseBackend) LoadAll(interfaceIdentifiers ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { result := make(map[InterfaceConfig][]PeerConfig) for _, identifier := range interfaceIdentifiers { iface, peers, err := d.Load(identifier) @@ -185,6 +189,20 @@ return result, nil } +func (d DatabaseBackend) GetAvailableInterfaces() ([]DeviceIdentifier, error) { + var iface []dbInterfaceConfig + if err := d.db.Find(&iface).Error; err != nil { + return nil, errors.Wrap(err, "failed to load interfaces from db") + } + + interfaces := make([]DeviceIdentifier, len(iface)) + for i := range iface { + interfaces[i] = DeviceIdentifier(iface[i].DeviceName) + } + + return interfaces, nil +} + // // --- Models // diff --git a/internal/wireguard/backend_file.go b/internal/wireguard/backend_file.go index 6b330a4..de4535a 100644 --- a/internal/wireguard/backend_file.go +++ b/internal/wireguard/backend_file.go @@ -18,6 +18,10 @@ return backend, nil } +func (f FileBackend) Name() string { + return "file" +} + func (f FileBackend) SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error { configContents, err := f.fileGenerator.GetInterfaceConfig(cfg, peers) if err != nil { @@ -58,9 +62,13 @@ } func (f FileBackend) Load(identifier DeviceIdentifier) (InterfaceConfig, []PeerConfig, error) { - panic("implement me") + return InterfaceConfig{}, nil, nil } -func (f FileBackend) LoadAll(ignored ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { - panic("implement me") +func (f FileBackend) LoadAll(interfaceIdentifiers ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { + return nil, nil +} + +func (f FileBackend) GetAvailableInterfaces() ([]DeviceIdentifier, error) { + return nil, nil } diff --git a/internal/wireguard/configuration.go b/internal/wireguard/configuration.go index a026201..ed8e6fa 100644 --- a/internal/wireguard/configuration.go +++ b/internal/wireguard/configuration.go @@ -187,8 +187,13 @@ DisabledAt *time.Time } +type Name interface { + Name() string +} + // ConfigWriter provides methods for updating persistent backends (like a database or a WireGuard configuration file) type ConfigWriter interface { + Name SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error SavePeer(peer PeerConfig, cfg InterfaceConfig) error DeleteInterface(cfg InterfaceConfig, peers []PeerConfig) error @@ -197,6 +202,8 @@ // ConfigLoader provides methods to load interface and peer configurations from a persistent backend. type ConfigLoader interface { + Name Load(identifier DeviceIdentifier) (InterfaceConfig, []PeerConfig, error) - LoadAll(ignored ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) + LoadAll(interfaceIdentifiers ...DeviceIdentifier) (map[InterfaceConfig][]PeerConfig, error) + GetAvailableInterfaces() ([]DeviceIdentifier, error) } diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index defb042..7c2dca8 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -5,6 +5,7 @@ "fmt" "net" "os" + "sort" "strconv" "strings" "sync" @@ -24,6 +25,7 @@ // DeviceManager provides methods to create/update/delete physical WireGuard devices. type DeviceManager interface { + GetDevices() ([]InterfaceConfig, error) CreateDevice(device DeviceIdentifier) error DeleteDevice(device DeviceIdentifier) error UpdateDevice(device DeviceIdentifier, cfg InterfaceConfig) error @@ -35,6 +37,8 @@ RemovePeer(device DeviceIdentifier, peer PeerIdentifier) error } +type Opt func(svc *ManagementUtil) + type Manager interface { KeyGenerator DeviceManager @@ -47,6 +51,8 @@ wg lowlevel.WireGuardClient nl lowlevel.NetlinkClient + unmanagedInterfaces []DeviceIdentifier // Those interfaces are completely ignored by WireGuard Portal + // config writers and loaders are used to populate the internal config maps cw []ConfigWriter cl []ConfigLoader @@ -57,6 +63,44 @@ peers map[DeviceIdentifier]map[PeerIdentifier]PeerConfig } +func NewManagementUtil(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, opts ...Opt) (*ManagementUtil, error) { + m := &ManagementUtil{ + mux: sync.RWMutex{}, + wg: wg, + nl: nl, + } + + for _, opt := range opts { + opt(m) + } + + // initialize + err := m.initialize() + if err != nil { + return nil, errors.Wrap(err, "failed to initialize WireGuard manager") + } + + return m, nil +} + +func IgnoredInterfaces(ignored ...DeviceIdentifier) Opt { + return func(m *ManagementUtil) { + m.unmanagedInterfaces = ignored + } +} + +func ConfigLoaders(cl ...ConfigLoader) Opt { + return func(m *ManagementUtil) { + m.cl = cl + } +} + +func ConfigWriters(cw ...ConfigWriter) Opt { + return func(m *ManagementUtil) { + m.cw = cw + } +} + func (m *ManagementUtil) GetFreshKeypair() (KeyPair, error) { privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -78,25 +122,29 @@ return PreSharedKey(preSharedKey.String()), nil } +func (m *ManagementUtil) GetDevices() ([]InterfaceConfig, error) { + interfaces := make([]InterfaceConfig, 0, len(m.interfaces)) + for _, iface := range interfaces { + interfaces = append(interfaces, iface) + } + // Order the interfaces by device name + sort.Slice(interfaces, func(i, j int) bool { + return interfaces[i].DeviceName < interfaces[j].DeviceName + }) + + return interfaces, nil +} + func (m *ManagementUtil) CreateDevice(identifier DeviceIdentifier) error { m.mux.Lock() defer m.mux.Unlock() if m.deviceExists(identifier) { return errors.Errorf("device %s already exists", identifier) } - link := &netlink.GenericLink{ - LinkAttrs: netlink.LinkAttrs{ - Name: string(identifier), - }, - LinkType: "wireguard", - } - err := m.nl.LinkAdd(link) - if err != nil { - return errors.Wrapf(err, "failed to create WireGuard interface") - } - if err := m.nl.LinkSetUp(link); err != nil { - return errors.Wrapf(err, "failed to enable WireGuard interface") + err := m.createWgDevice(identifier) + if err != nil { + return errors.Wrapf(err, "failed to create WireGuard interface %s", identifier) } newInterface := InterfaceConfig{DeviceName: identifier} @@ -110,6 +158,25 @@ return nil } +func (m *ManagementUtil) createWgDevice(identifier DeviceIdentifier) error { + link := &netlink.GenericLink{ + LinkAttrs: netlink.LinkAttrs{ + Name: string(identifier), + }, + LinkType: "wireguard", + } + err := m.nl.LinkAdd(link) + if err != nil { + return errors.Wrapf(err, "failed to create WireGuard interface %s", identifier) + } + + if err := m.nl.LinkSetUp(link); err != nil { + return errors.Wrapf(err, "failed to enable WireGuard interface %s", identifier) + } + + return nil +} + func (m *ManagementUtil) DeleteDevice(identifier DeviceIdentifier) error { m.mux.Lock() defer m.mux.Unlock() @@ -284,6 +351,154 @@ return nil } +// TODO: implement/think about +func (m *ManagementUtil) initialize() error { + // Load all interfaces from the database + backendInterfaces, err := m.getBackendInterfaces(DatabaseBackendName) + if err != nil { + return errors.Wrap(err, "failed to load backend interfaces") + } + + /*// Get a list of available WireGuard interfaces + wgInterfaces, err := m.wg.Devices() + if err != nil { + return errors.Wrap(err, "failed to load WireGuard interfaces") + } + + // Create missing WireGuard interfaces + for _, backendInterface := range backendInterfaces { + exists := false + for _, wgInterface := range wgInterfaces { + if string(backendInterface) == wgInterface.Name { + exists = true + break + } + } + if !exists { + err := m.createWgDevice(backendInterface) + if err != nil { + return errors.Wrapf(err, "failed to create WireGuard interface %s found in backend", backendInterface) + } + } + }*/ + + // Load config options from database backend, populate internal state maps + err = m.loadBackendInterfaces(DatabaseBackendName, backendInterfaces...) + if err != nil { + return errors.Wrap(err, "failed to load interface configurations from backend") + } + + // Load missing config options from current interfaces, populate internal state maps + err = m.loadWireGuardInterfaces() + if err != nil { + return errors.Wrap(err, "failed to load interface configurations from WireGuard") + } + + // Persists currently loaded configurations + // TODO + + // Apply configuration options from internal state maps to current interfaces + // TODO + + return nil +} + +func (m *ManagementUtil) getBackendInterfaces(backend string) ([]DeviceIdentifier, error) { + // Load all interfaces from the config loader backends + uniqueInterfaces := make(map[DeviceIdentifier]struct{}) + for _, cl := range m.cl { + if cl.Name() != backend { + continue + } + + backendInterfaces, err := cl.GetAvailableInterfaces() + if err != nil { + return nil, errors.Wrapf(err, "failed to load available interfaces from backend %s", cl.Name()) + } + for _, iface := range backendInterfaces { + uniqueInterfaces[iface] = struct{}{} + } + } + + interfaces := make([]DeviceIdentifier, 0, len(uniqueInterfaces)) + for iface := range uniqueInterfaces { + interfaces = append(interfaces, iface) + } + return interfaces, nil +} + +func (m *ManagementUtil) loadBackendInterfaces(backend string, identifiers ...DeviceIdentifier) error { + for _, cl := range m.cl { + if cl.Name() != backend { + continue + } + ifaceAndPeers, err := cl.LoadAll(identifiers...) + if err != nil { + return errors.Wrapf(err, "failed to load interfaces from backend %s", cl.Name()) + } + + for iface, peers := range ifaceAndPeers { + m.interfaces[iface.DeviceName] = iface + for _, peer := range peers { + m.peers[iface.DeviceName][peer.Uid] = peer + } + } + } + return nil +} + +func (m *ManagementUtil) loadWireGuardInterfaces() error { + // Get a list of available WireGuard interfaces + wgInterfaces, err := m.wg.Devices() + if err != nil { + return errors.Wrap(err, "failed to load WireGuard interfaces") + } + + for _, iface := range wgInterfaces { + if m.interfaceIsIgnored(DeviceIdentifier(iface.Name)) { + continue + } + + devId := DeviceIdentifier(iface.Name) + if _, existing := m.interfaces[devId]; !existing { + m.interfaces[devId] = m.convertWireGuardInterface(*iface) + } + + for _, peer := range iface.Peers { + peerPublicKey := peer.PublicKey.String() + + // check if peer exists, compare public keys + existing := false + for _, existingPeer := range m.peers[devId] { + if existingPeer.KeyPair.PublicKey == peerPublicKey { + existing = true + break + } + } + + if !existing { + // Use the peers public key as UID + m.peers[devId][PeerIdentifier(peerPublicKey)] = m.convertWireGuardPeer(peer) + } + + } + } + return nil +} + +func (m *ManagementUtil) restoreBackendInterfaces() error { + return nil +} + +func (m *ManagementUtil) interfaceIsIgnored(name DeviceIdentifier) bool { + for _, iface := range m.unmanagedInterfaces { + if iface == name { + return true + } + } + return false +} + // // ---- Helpers // @@ -594,3 +809,53 @@ return addresses, nil } + +func (m *ManagementUtil) convertWireGuardInterface(device wgtypes.Device) InterfaceConfig { + cfg := InterfaceConfig{ + DeviceName: DeviceIdentifier(device.Name), + KeyPair: KeyPair{PublicKey: device.PublicKey.String(), PrivateKey: device.PrivateKey.String()}, + ListenPort: device.ListenPort, + FirewallMark: int32(device.FirewallMark), + DriverType: device.Type.String(), + } + + link, err := m.nl.LinkByName(device.Name) + if err != nil || link.Attrs() == nil { + return cfg + } + cfg.Mtu = link.Attrs().MTU + + addresses, err := m.nl.AddrList(link) + if err != nil { + return cfg + } + addressesStr := make([]string, len(addresses)) + for i := range addresses { + addressesStr[i] = addresses[i].String() + } + cfg.AddressStr = strings.Join(addressesStr, ",") + + return cfg +} + +func (m *ManagementUtil) convertWireGuardPeer(peer wgtypes.Peer) PeerConfig { + cfg := PeerConfig{ + KeyPair: KeyPair{PublicKey: peer.PublicKey.String()}, + } + + if peer.Endpoint != nil { + cfg.Endpoint.Value = peer.Endpoint.String() + } + + if peer.PresharedKey != (wgtypes.Key{}) { + cfg.PresharedKey = peer.PresharedKey.String() + } + + ipAddresses := make([]string, len(peer.AllowedIPs)) // use allowed IP's as the peer IP's + for i, ip := range peer.AllowedIPs { + ipAddresses[i] = ip.String() + } + cfg.AddressStr.Value = strings.Join(ipAddresses, ",") + + return cfg +}