package wireguard
import (
"net"
"sort"
"strings"
"sync"
"time"
"github.com/h44z/wg-portal/internal/lowlevel"
"github.com/h44z/wg-portal/internal/persistence"
"github.com/pkg/errors"
"github.com/vishvananda/netlink"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// wgCtrlManager manages WireGuard interfaces and their peers through the
// low-level WireGuard and netlink client abstractions. It keeps an in-memory
// copy of every interface and peer configuration and, when a store is
// configured, mirrors changes into that store.
type wgCtrlManager struct {
mux sync.RWMutex // mutex to synchronize access to maps and external api clients
// external api clients
wg lowlevel.WireGuardClient
nl lowlevel.NetlinkClient
// optional persistent backend; a nil store disables persistence entirely
store store
// internal holder of interface configurations, keyed by interface identifier
interfaces map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig
// internal holder of peer configurations, grouped by interface identifier and then keyed by peer identifier
peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig
}
// newWgCtrlManager builds a manager around the given WireGuard and netlink
// clients. If store is non-nil, the in-memory interface and peer maps are
// filled from it before the manager is returned.
func newWgCtrlManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*wgCtrlManager, error) {
	mgr := &wgCtrlManager{
		wg:    wg,
		nl:    nl,
		store: store,
	}
	mgr.interfaces = make(map[persistence.InterfaceIdentifier]*persistence.InterfaceConfig)
	mgr.peers = make(map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]*persistence.PeerConfig)

	if err := mgr.initializeFromStore(); err != nil {
		return nil, errors.WithMessage(err, "failed to initialize manager from store")
	}
	return mgr, nil
}
// GetInterfaces returns every managed interface ordered by identifier.
// The returned configs are the manager's internal instances, not copies.
// The error result is always nil.
func (m *wgCtrlManager) GetInterfaces() ([]*persistence.InterfaceConfig, error) {
	m.mux.RLock()
	defer m.mux.RUnlock()

	result := make([]*persistence.InterfaceConfig, 0, len(m.interfaces))
	for id := range m.interfaces {
		result = append(result, m.interfaces[id])
	}

	// Map iteration order is random; sort by device name for stable output.
	sort.Slice(result, func(a, b int) bool {
		return result[a].Identifier < result[b].Identifier
	})
	return result, nil
}
// GetInterface returns the internal configuration of the managed interface id,
// or an error if no such interface is managed.
func (m *wgCtrlManager) GetInterface(id persistence.InterfaceIdentifier) (*persistence.InterfaceConfig, error) {
	m.mux.RLock()
	defer m.mux.RUnlock()

	iface, found := m.interfaces[id]
	if !found {
		return nil, errors.New("device does not exist")
	}
	return iface, nil
}
// CreateInterface creates a new WireGuard link named id and registers it as a
// server interface with an empty peer set, then persists it. An error is
// returned if id is already managed or any step fails.
func (m *wgCtrlManager) CreateInterface(id persistence.InterfaceIdentifier) error {
	m.mux.Lock()
	defer m.mux.Unlock()

	if m.deviceExists(id) {
		return errors.New("device already exists")
	}
	if err := m.createLowLevelInterface(id); err != nil {
		return errors.WithMessage(err, "failed to create low level interface")
	}

	m.interfaces[id] = &persistence.InterfaceConfig{
		Identifier: id,
		Type:       persistence.InterfaceTypeServer,
	}
	m.peers[id] = make(map[persistence.PeerIdentifier]*persistence.PeerConfig)

	if err := m.persistInterface(id, false); err != nil {
		return errors.WithMessage(err, "failed to persist created interface")
	}
	return nil
}
// DeleteInterface removes the WireGuard link of interface id. It then deletes
// the interface and all of its peers from the store, and finally drops them
// from the in-memory maps. If any step fails, the error is returned and the
// in-memory maps are left untouched.
func (m *wgCtrlManager) DeleteInterface(id persistence.InterfaceIdentifier) error {
	m.mux.Lock()
	defer m.mux.Unlock()

	if !m.deviceExists(id) {
		return errors.New("interface does not exist")
	}

	link := &netlink.GenericLink{
		LinkAttrs: netlink.LinkAttrs{Name: string(id)},
		LinkType:  "wireguard",
	}
	if err := m.nl.LinkDel(link); err != nil {
		return errors.WithMessage(err, "failed to delete low level interface")
	}

	if err := m.persistInterface(id, true); err != nil {
		return errors.WithMessage(err, "failed to persist deleted interface")
	}
	for peerId := range m.peers[id] {
		if err := m.persistPeer(peerId, true); err != nil {
			return errors.WithMessagef(err, "failed to persist deleted peer %s", peerId)
		}
	}

	delete(m.interfaces, id)
	delete(m.peers, id)
	return nil
}
// UpdateInterface applies cfg to the already managed interface cfg.Identifier.
// Afterwards, cfg itself (not a copy) becomes the stored configuration and is
// persisted.
//
// Steps, in order; the first failing step aborts and its error is returned:
//   - netlink: set the MTU (only when cfg.Mtu is non-zero), then apply the
//     addresses parsed from cfg.AddressStr. The first address is applied with
//     AddrReplace and the remaining ones with AddrAdd.
//   - WireGuard: set the private key, the listen port and, when non-zero, the
//     firewall mark.
//   - netlink: bring the link up or down according to cfg.Enabled.
func (m *wgCtrlManager) UpdateInterface(cfg *persistence.InterfaceConfig) error {
	if err := m.checkInterface(cfg); err != nil {
		return errors.WithMessage(err, "interface validation failed")
	}

	m.mux.Lock()
	defer m.mux.Unlock()

	if !m.deviceExists(cfg.Identifier) {
		return errors.New("interface does not exist")
	}

	// Update net-link attributes
	link, err := m.nl.LinkByName(string(cfg.Identifier))
	if err != nil {
		return errors.WithMessage(err, "failed to open low level interface")
	}
	if cfg.Mtu != 0 {
		if err := m.nl.LinkSetMTU(link, cfg.Mtu); err != nil {
			return errors.WithMessage(err, "failed to set MTU")
		}
	}
	addresses, err := parseIpAddressString(cfg.AddressStr)
	if err != nil {
		return errors.WithMessage(err, "failed to parse ip address")
	}
	for i, addr := range addresses {
		if i == 0 {
			err = m.nl.AddrReplace(link, addr)
		} else {
			err = m.nl.AddrAdd(link, addr)
		}
		if err != nil {
			return errors.WithMessage(err, "failed to set ip address")
		}
	}

	// Update WireGuard attributes
	pKey, err := wgtypes.NewKey(GetPrivateKeyBytes(cfg.KeyPair))
	if err != nil {
		return errors.WithMessage(err, "failed to parse private key bytes")
	}
	// wgtypes.Config leaves the firewall mark unchanged when FirewallMark is
	// nil. Only point at a (freshly allocated) value when a mark is configured;
	// writing through an unallocated nil pointer would panic.
	var fwMark *int
	if cfg.FirewallMark != 0 {
		mark := int(cfg.FirewallMark)
		fwMark = &mark
	}
	err = m.wg.ConfigureDevice(string(cfg.Identifier), wgtypes.Config{
		PrivateKey:   &pKey,
		ListenPort:   &cfg.ListenPort,
		FirewallMark: fwMark,
	})
	if err != nil {
		return errors.WithMessage(err, "failed to update WireGuard settings")
	}

	// Update link state
	if cfg.Enabled {
		if err := m.nl.LinkSetUp(link); err != nil {
			return errors.WithMessage(err, "failed to enable low level interface")
		}
	} else {
		if err := m.nl.LinkSetDown(link); err != nil {
			return errors.WithMessage(err, "failed to disable low level interface")
		}
	}

	// update internal map
	m.interfaces[cfg.Identifier] = cfg
	if err := m.persistInterface(cfg.Identifier, false); err != nil {
		return errors.WithMessage(err, "failed to persist updated interface")
	}
	return nil
}
// ApplyDefaultConfigs copies the peer defaults (the PeerDef* fields) of
// interface id into every peer of that interface and persists each peer.
//
// The identity fields of the peer's interface section (identifier, type and
// public key) are always overwritten. All other values go through
// TrySetValue; presumably this only applies the default where the peer does
// not override it — TODO confirm against the persistence package.
func (m *wgCtrlManager) ApplyDefaultConfigs(id persistence.InterfaceIdentifier) error {
	m.mux.Lock()
	defer m.mux.Unlock()

	if !m.deviceExists(id) {
		return errors.New("device does not exist")
	}

	defaults := m.interfaces[id]
	for _, peer := range m.peers[id] {
		peer.Endpoint.TrySetValue(defaults.PeerDefEndpoint)
		peer.AllowedIPsStr.TrySetValue(defaults.PeerDefAllowedIPsStr)

		iface := peer.Interface
		iface.Identifier = defaults.Identifier
		iface.Type = defaults.Type
		iface.PublicKey = defaults.KeyPair.PublicKey
		iface.DnsStr.TrySetValue(defaults.PeerDefDnsStr)
		iface.Mtu.TrySetValue(defaults.PeerDefMtu)
		iface.FirewallMark.TrySetValue(defaults.PeerDefFirewallMark)
		iface.RoutingTable.TrySetValue(defaults.PeerDefRoutingTable)
		iface.PreUp.TrySetValue(defaults.PeerDefPreUp)
		iface.PostUp.TrySetValue(defaults.PeerDefPostUp)
		iface.PreDown.TrySetValue(defaults.PeerDefPreDown)
		iface.PostDown.TrySetValue(defaults.PeerDefPostDown)

		if err := m.persistPeer(peer.Identifier, false); err != nil {
			return errors.Wrapf(err, "failed to persist peer defaults to %s", peer.Identifier)
		}
	}
	return nil
}
// GetPeers returns the peers of interface interfaceId ordered by peer
// identifier, or an error if the interface is not managed. The returned
// configs are the manager's internal instances, not copies.
func (m *wgCtrlManager) GetPeers(interfaceId persistence.InterfaceIdentifier) ([]*persistence.PeerConfig, error) {
	m.mux.RLock()
	defer m.mux.RUnlock()

	if !m.deviceExists(interfaceId) {
		return nil, errors.New("device does not exist")
	}

	peerMap := m.peers[interfaceId]
	result := make([]*persistence.PeerConfig, 0, len(peerMap))
	for _, peer := range peerMap {
		result = append(result, peer)
	}
	sort.Slice(result, func(a, b int) bool {
		return result[a].Identifier < result[b].Identifier
	})
	return result, nil
}
// SavePeers handles each peer in order. For every peer it:
//   - validates the peer;
//   - pushes its WireGuard configuration to the device named by
//     peer.Interface.Identifier;
//   - stores the peer in memory under that interface and persists it.
//
// The first failure aborts the call. Peers handled before that failure stay
// applied.
func (m *wgCtrlManager) SavePeers(peers ...*persistence.PeerConfig) error {
	m.mux.Lock()
	defer m.mux.Unlock()

	for _, peer := range peers {
		if err := m.checkPeer(peer); err != nil {
			return errors.WithMessage(err, "peer validation failed")
		}

		deviceId := peer.Interface.Identifier
		device, found := m.interfaces[deviceId]
		if !found {
			return errors.Errorf("device does not exist")
		}

		wgPeer, err := getWireGuardPeerConfig(device.Type, peer)
		if err != nil {
			return errors.WithMessagef(err, "could not generate WireGuard peer configuration for %s", peer.Identifier)
		}
		update := wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}}
		if err := m.wg.ConfigureDevice(string(deviceId), update); err != nil {
			return errors.Wrapf(err, "could not save peer %s to WireGuard device %s", peer.Identifier, deviceId)
		}

		m.peers[deviceId][peer.Identifier] = peer
		if err := m.persistPeer(peer.Identifier, false); err != nil {
			return errors.Wrapf(err, "failed to persist updated peer %s", peer.Identifier)
		}
	}
	return nil
}
// RemovePeer removes peer id from its WireGuard device, deletes it from the
// store and finally drops it from the in-memory peer map of its interface.
// The interface is taken from peer.Interface.Identifier.
func (m *wgCtrlManager) RemovePeer(id persistence.PeerIdentifier) error {
	m.mux.Lock()
	defer m.mux.Unlock()

	peer, err := m.getPeer(id)
	if err != nil {
		return errors.Errorf("peer does not exist")
	}
	deviceId := peer.Interface.Identifier

	publicKey, err := wgtypes.ParseKey(peer.KeyPair.PublicKey)
	if err != nil {
		return errors.WithMessage(err, "invalid public key")
	}

	removal := wgtypes.Config{
		Peers: []wgtypes.PeerConfig{{PublicKey: publicKey, Remove: true}},
	}
	if err := m.wg.ConfigureDevice(string(deviceId), removal); err != nil {
		return errors.WithMessage(err, "could not remove peer from WireGuard interface")
	}
	if err := m.persistPeer(id, true); err != nil {
		return errors.WithMessage(err, "failed to persist deleted peer")
	}

	delete(m.peers[deviceId], id)
	return nil
}
// GetImportableInterfaces lists the WireGuard devices reported by the
// WireGuard client that are not managed yet. Each device is converted to an
// ImportableInterface and mapped to its converted peers. Any conversion
// failure aborts the whole listing.
func (m *wgCtrlManager) GetImportableInterfaces() (map[*ImportableInterface][]*persistence.PeerConfig, error) {
	devices, err := m.wg.Devices()
	if err != nil {
		return nil, errors.WithMessage(err, "failed to get WireGuard device list")
	}

	m.mux.RLock()
	defer m.mux.RUnlock()

	result := make(map[*ImportableInterface][]*persistence.PeerConfig, len(devices))
	for _, device := range devices {
		if m.deviceExists(persistence.InterfaceIdentifier(device.Name)) {
			continue // interface already managed, skip
		}

		iface, err := m.convertWireGuardInterface(device)
		if err != nil {
			return nil, errors.WithMessagef(err, "failed to convert WireGuard interface %s", device.Name)
		}

		converted := make([]*persistence.PeerConfig, 0, len(device.Peers))
		for i := range device.Peers {
			peerCfg, err := m.convertWireGuardPeer(&device.Peers[i], iface)
			if err != nil {
				return nil, errors.WithMessagef(err, "failed to convert WireGuard peer %s from %s",
					device.Peers[i].PublicKey.String(), device.Name)
			}
			converted = append(converted, peerCfg)
		}
		result[iface] = converted
	}
	return result, nil
}
// ImportInterface takes over management of an existing WireGuard device
// described by cfg (such as one returned by GetImportableInterfaces), together
// with its peers. The interface and every peer are stored in memory and
// persisted; the device itself is not reconfigured.
//
// All input is validated before any state changes:
//   - cfg must be non-nil, valid and not already managed;
//   - every peer must be valid and belong to cfg's interface.
//
// A persistence failure part-way through can still leave a partially imported
// interface in memory.
func (m *wgCtrlManager) ImportInterface(cfg *ImportableInterface, peers []*persistence.PeerConfig) error {
	// &cfg.InterfaceConfig would panic on a nil cfg before checkInterface runs.
	if cfg == nil {
		return errors.New("interface config must not be nil")
	}

	m.mux.Lock()
	defer m.mux.Unlock()

	newInterface := &cfg.InterfaceConfig
	if err := m.checkInterface(newInterface); err != nil {
		return errors.WithMessage(err, "interface validation failed")
	}
	id := newInterface.Identifier

	// Importing over a managed interface would reset its peer map and lose
	// track of the peers that are already managed.
	if m.deviceExists(id) {
		return errors.New("device already exists")
	}

	// Validate all peers up front so a bad peer cannot leave a half-imported
	// interface behind. RemovePeer locates a peer's device through
	// peer.Interface.Identifier, so it must match the imported interface.
	for _, peer := range peers {
		if err := m.checkPeer(peer); err != nil {
			return errors.WithMessage(err, "peer validation failed")
		}
		if peer.Interface.Identifier != id {
			return errors.Errorf("peer %s belongs to interface %s, not %s",
				peer.Identifier, peer.Interface.Identifier, id)
		}
	}

	m.interfaces[id] = newInterface
	m.peers[id] = make(map[persistence.PeerIdentifier]*persistence.PeerConfig, len(peers))
	if err := m.persistInterface(id, false); err != nil {
		return errors.WithMessage(err, "failed to persist imported interface")
	}

	for _, peer := range peers {
		m.peers[id][peer.Identifier] = peer
		if err := m.persistPeer(peer.Identifier, false); err != nil {
			return errors.Wrapf(err, "failed to persist imported peer %s", peer.Identifier)
		}
	}
	return nil
}
//
// -- Helpers
//
// initializeFromStore fills the in-memory interface and peer maps with every
// interface (and its peers) found in the store. It does nothing when no store
// is configured.
func (m *wgCtrlManager) initializeFromStore() error {
	if m.store == nil {
		return nil // no store, nothing to do
	}

	ids, err := m.store.GetAvailableInterfaces()
	if err != nil {
		return errors.WithMessage(err, "failed to get available interfaces")
	}
	stored, err := m.store.GetAllInterfaces(ids...)
	if err != nil {
		return errors.WithMessage(err, "failed to get all interfaces")
	}

	for key, val := range stored {
		// Copy the loop variables: the maps keep pointers into them, so every
		// interface and peer list needs its own addressable storage.
		ifaceCfg := key
		peerList := val

		ifaceId := ifaceCfg.Identifier
		m.interfaces[ifaceId] = &ifaceCfg

		peerMap, ok := m.peers[ifaceId]
		if !ok {
			peerMap = make(map[persistence.PeerIdentifier]*persistence.PeerConfig)
			m.peers[ifaceId] = peerMap
		}
		for i := range peerList {
			peerMap[peerList[i].Identifier] = &peerList[i]
		}
	}
	return nil
}
// createLowLevelInterface adds a netlink link of type "wireguard" named id and
// brings it up.
//
// If bringing the link up fails, the freshly created link is deleted again.
// Otherwise an orphaned link would remain, and a retry would likely fail in
// LinkAdd because the name is already taken. A rollback failure is included
// in the returned error.
func (m *wgCtrlManager) createLowLevelInterface(id persistence.InterfaceIdentifier) error {
	link := &netlink.GenericLink{
		LinkAttrs: netlink.LinkAttrs{
			Name: string(id),
		},
		LinkType: "wireguard",
	}

	if err := m.nl.LinkAdd(link); err != nil {
		return errors.Wrapf(err, "failed to create netlink interface")
	}

	if err := m.nl.LinkSetUp(link); err != nil {
		if delErr := m.nl.LinkDel(link); delErr != nil {
			return errors.Wrapf(err, "failed to enable netlink interface (rollback failed: %v)", delErr)
		}
		return errors.Wrapf(err, "failed to enable netlink interface")
	}
	return nil
}
// deviceExists reports whether interface id is currently managed.
// The caller must hold m.mux.
func (m *wgCtrlManager) deviceExists(id persistence.InterfaceIdentifier) bool {
	_, found := m.interfaces[id]
	return found
}
// persistInterface mirrors interface id into the optional store.
//   - With delete set, the interface is removed from the store.
//   - Otherwise the current in-memory config of id is saved; it must exist.
//
// Without a store this is a no-op.
func (m *wgCtrlManager) persistInterface(id persistence.InterfaceIdentifier, delete bool) error {
	if m.store == nil {
		return nil // nothing to do
	}

	if delete {
		if err := m.store.DeleteInterface(id); err != nil {
			return errors.Wrapf(err, "failed to persist interface")
		}
		return nil
	}

	if err := m.store.SaveInterface(*m.interfaces[id]); err != nil {
		return errors.Wrapf(err, "failed to persist interface")
	}
	return nil
}
// peerExists reports whether a peer with identifier id is managed on any
// interface. The caller must hold m.mux.
func (m *wgCtrlManager) peerExists(id persistence.PeerIdentifier) bool {
	_, err := m.getPeer(id)
	return err == nil
}
// persistPeer mirrors peer id into the optional store.
//   - With delete set, the peer is removed from the store.
//   - Otherwise the peer's current in-memory config is saved. If no managed
//     interface holds the peer, an error is returned instead of dereferencing
//     a nil pointer.
//
// Without a store this is a no-op.
func (m *wgCtrlManager) persistPeer(id persistence.PeerIdentifier, delete bool) error {
	if m.store == nil {
		return nil // nothing to do
	}

	var err error
	if delete {
		err = m.store.DeletePeer(id)
	} else {
		peer, lookupErr := m.getPeer(id)
		if lookupErr != nil {
			return errors.WithMessagef(lookupErr, "failed to persist peer %s", id)
		}
		err = m.store.SavePeer(*peer)
	}
	if err != nil {
		return errors.Wrapf(err, "failed to persist peer %s", id)
	}
	return nil
}
// getPeer searches every managed interface for peer id and returns the first
// match, or an error if the peer is not managed. The caller must hold m.mux.
func (m *wgCtrlManager) getPeer(id persistence.PeerIdentifier) (*persistence.PeerConfig, error) {
	for _, peerMap := range m.peers {
		if peer, found := peerMap[id]; found {
			return peer, nil
		}
	}
	return nil, errors.New("peer not found")
}
// convertWireGuardInterface builds an ImportableInterface from a WireGuard
// device.
//   - From the WireGuard device: keys, listen port, firewall mark and driver
//     type.
//   - From the netlink link with the same name: MTU and IP addresses.
//
// The interface type defaults to server.
func (m *wgCtrlManager) convertWireGuardInterface(device *wgtypes.Device) (*ImportableInterface, error) {
	link, err := m.nl.LinkByName(device.Name)
	if err != nil {
		return nil, errors.WithMessagef(err, "failed to get low level interface for %s", device.Name)
	}
	addrs, err := m.nl.AddrList(link)
	if err != nil {
		return nil, errors.WithMessagef(err, "failed to get low level addresses for %s", device.Name)
	}

	iface := &ImportableInterface{}
	iface.Identifier = persistence.InterfaceIdentifier(device.Name)
	iface.Type = persistence.InterfaceTypeServer // imported devices are assumed to be server devices
	iface.DriverType = device.Type.String()
	iface.KeyPair = persistence.KeyPair{
		PrivateKey: device.PrivateKey.String(),
		PublicKey:  device.PublicKey.String(),
	}
	iface.ListenPort = device.ListenPort
	iface.FirewallMark = int32(device.FirewallMark)
	iface.Mtu = link.Attrs().MTU
	iface.AddressStr = ipAddressesToString(addrs)
	return iface, nil
}
// convertWireGuardPeer builds a PeerConfig from a peer reported on device dev.
//   - The peer is identified by its public key string.
//   - Its allowed IPs, endpoint (if any), preshared key (if non-zero) and
//     keepalive interval are carried over.
//   - Its interface section is seeded with dev's identifier, addresses, DNS
//     and MTU.
//
// The error result is always nil.
func (m *wgCtrlManager) convertWireGuardPeer(peer *wgtypes.Peer, dev *ImportableInterface) (*persistence.PeerConfig, error) {
	publicKey := peer.PublicKey.String()

	// use allowed IP's as the peer IP's
	allowedIPs := make([]string, 0, len(peer.AllowedIPs))
	for i := range peer.AllowedIPs {
		allowedIPs = append(allowedIPs, peer.AllowedIPs[i].String())
	}
	keepAliveSeconds := int(peer.PersistentKeepaliveInterval.Seconds())

	peerCfg := &persistence.PeerConfig{}
	peerCfg.Identifier = persistence.PeerIdentifier(publicKey)
	peerCfg.KeyPair = persistence.KeyPair{PublicKey: publicKey}
	peerCfg.DisplayName = "Autodetected Peer (" + publicKey[0:8] + ")"
	if peer.Endpoint != nil {
		peerCfg.Endpoint = persistence.NewStringConfigOption(peer.Endpoint.String(), true)
	}
	if peer.PresharedKey != (wgtypes.Key{}) {
		peerCfg.PresharedKey = peer.PresharedKey.String()
	}
	peerCfg.AllowedIPsStr = persistence.NewStringConfigOption(strings.Join(allowedIPs, ","), true)
	peerCfg.PersistentKeepalive = persistence.NewIntConfigOption(keepAliveSeconds, true)
	peerCfg.Interface = &persistence.PeerInterfaceConfig{
		Identifier: dev.Identifier,
		// TODO: confirm the device's own addresses are the right seed for the peer's interface addresses.
		AddressStr: persistence.NewStringConfigOption(dev.AddressStr, true),
		DnsStr:     persistence.NewStringConfigOption(dev.DnsStr, true),
		Mtu:        persistence.NewIntConfigOption(dev.Mtu, true),
	}
	return peerCfg, nil
}
// checkInterface performs basic validation of an interface config. It
// requires a non-nil config with a non-empty identifier and type.
func (m *wgCtrlManager) checkInterface(cfg *persistence.InterfaceConfig) error {
	switch {
	case cfg == nil:
		return errors.New("interface config must not be nil")
	case cfg.Identifier == "":
		return errors.New("missing interface identifier")
	case cfg.Type == "":
		return errors.New("missing interface type")
	default:
		return nil
	}
}
// checkPeer performs basic validation of a peer config. It requires a non-nil
// config with a non-empty identifier and an interface section that has a
// non-empty interface identifier.
func (m *wgCtrlManager) checkPeer(cfg *persistence.PeerConfig) error {
	switch {
	case cfg == nil:
		return errors.New("peer config must not be nil")
	case cfg.Identifier == "":
		return errors.New("missing peer identifier")
	case cfg.Interface == nil:
		return errors.New("missing peer interface")
	case cfg.Interface.Identifier == "":
		return errors.New("missing peer interface identifier")
	default:
		return nil
	}
}
// getWireGuardPeerConfig converts cfg into the wgtypes.PeerConfig sent to a
// WireGuard device of type devType.
//
// The returned config creates the peer if the device does not know it yet,
// updates it otherwise, and replaces the peer's allowed IPs.
//
// Optional settings are included only when usable:
//   - the preshared key, if cfg.PresharedKey parses as a key;
//   - the endpoint, only for client devices and only if it resolves as a UDP
//     address (resolution failures are ignored);
//   - the keepalive interval, if it is non-zero.
//
// Allowed IPs:
//   - client and server devices take them from cfg.AllowedIPsStr;
//   - server devices additionally get cfg.ExtraAllowedIPsStr;
//   - any other device type gets an empty list.
func getWireGuardPeerConfig(devType persistence.InterfaceType, cfg *persistence.PeerConfig) (wgtypes.PeerConfig, error) {
	publicKey, err := wgtypes.ParseKey(cfg.KeyPair.PublicKey)
	if err != nil {
		return wgtypes.PeerConfig{}, errors.WithMessage(err, "invalid public key for peer")
	}

	var presharedKey *wgtypes.Key
	if tmpPresharedKey, err := wgtypes.ParseKey(cfg.PresharedKey); err == nil {
		presharedKey = &tmpPresharedKey
	}

	var endpoint *net.UDPAddr
	if cfg.Endpoint.Value != "" && devType == persistence.InterfaceTypeClient {
		if addr, err := net.ResolveUDPAddr("udp", cfg.Endpoint.GetValue()); err == nil {
			endpoint = addr
		}
	}

	var keepAlive *time.Duration
	if seconds := cfg.PersistentKeepalive.GetValue(); seconds != 0 {
		keepAliveDuration := time.Duration(seconds) * time.Second
		keepAlive = &keepAliveDuration
	}

	var peerAllowedIPs []*netlink.Addr
	if devType == persistence.InterfaceTypeClient || devType == persistence.InterfaceTypeServer {
		peerAllowedIPs, err = parseIpAddressString(cfg.AllowedIPsStr.GetValue())
		if err != nil {
			return wgtypes.PeerConfig{}, errors.WithMessage(err, "failed to parse allowed IP's")
		}
	}
	if devType == persistence.InterfaceTypeServer {
		peerExtraAllowedIPs, err := parseIpAddressString(cfg.ExtraAllowedIPsStr)
		if err != nil {
			return wgtypes.PeerConfig{}, errors.WithMessage(err, "failed to parse extra allowed IP's")
		}
		peerAllowedIPs = append(peerAllowedIPs, peerExtraAllowedIPs...)
	}

	allowedIPs := make([]net.IPNet, 0, len(peerAllowedIPs))
	for _, ip := range peerAllowedIPs {
		allowedIPs = append(allowedIPs, *ip.IPNet)
	}

	return wgtypes.PeerConfig{
		PublicKey: publicKey,
		Remove:    false,
		// SavePeers uses this config for new and existing peers alike. With
		// UpdateOnly set, WireGuard skips peers that are not yet on the device,
		// so newly saved peers would never be added.
		UpdateOnly:                  false,
		PresharedKey:                presharedKey,
		Endpoint:                    endpoint,
		PersistentKeepaliveInterval: keepAlive,
		ReplaceAllowedIPs:           true,
		AllowedIPs:                  allowedIPs,
	}, nil
}