// wg-portal / internal / wireguard / manager.go

package wireguard

import (
	"bufio"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/pkg/errors"
	"github.com/vishvananda/netlink"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

type (
	// KeyGenerator creates WireGuard key material: fresh key pairs and
	// pre-shared keys.
	KeyGenerator interface {
		GetFreshKeypair() (KeyPair, error)
		GetPreSharedKey() (PreSharedKey, error)
	}

	// DeviceManager provides methods to create/update/delete physical WireGuard devices.
	DeviceManager interface {
		CreateDevice(device DeviceIdentifier) error
		DeleteDevice(device DeviceIdentifier) error
		UpdateDevice(device DeviceIdentifier, cfg InterfaceConfig) error
	}

	// PeerManager lists, stores and removes the peers attached to a WireGuard device.
	PeerManager interface {
		GetPeers(device DeviceIdentifier) ([]PeerConfig, error)
		SavePeers(device DeviceIdentifier, peers ...PeerConfig) error
		RemovePeer(device DeviceIdentifier, peer PeerIdentifier) error
	}

	// Manager bundles key generation, device management and peer management.
	Manager interface {
		KeyGenerator
		DeviceManager
		PeerManager
	}
)

// ManagementUtil manages WireGuard devices and their peers through a WireGuard
// client and a netlink client, keeping an in-memory copy of every interface and
// peer configuration it has applied. *ManagementUtil implements Manager.
//
// NOTE(review): no constructor is visible here; writing to the interfaces/peers
// maps while they are nil panics, so they must be initialized before use.
type ManagementUtil struct {
	// configPath is the directory holding the "<interface>.conf" files read by parseConfigFile.
	configPath string

	// wg is the WireGuard control client (device listing, ConfigureDevice).
	wg Client
	// nl manages links and IP addresses via netlink.
	nl NetlinkClient
	// cp is not referenced in the code shown here; presumably it persists configurations — TODO confirm.
	cp ConfigPersister

	// internal holder of interface configurations, keyed by device name; the
	// presence of a key is what deviceExists checks.
	interfaces map[DeviceIdentifier]InterfaceConfig
	// internal holder of peer configurations: device name -> peer id -> peer config
	peers map[DeviceIdentifier]map[PeerIdentifier]PeerConfig
}

// GetFreshKeypair generates a new random WireGuard private key and returns it
// together with its derived public key, both in their string (base64) form.
func (m ManagementUtil) GetFreshKeypair() (KeyPair, error) {
	key, genErr := wgtypes.GeneratePrivateKey()
	if genErr != nil {
		return KeyPair{}, errors.Wrap(genErr, "failed to generate private Key")
	}

	var pair KeyPair
	pair.PrivateKey = key.String()
	pair.PublicKey = key.PublicKey().String()
	return pair, nil
}

// GetPreSharedKey generates a new random key suitable as a WireGuard
// pre-shared key and returns it in its string (base64) form.
func (m ManagementUtil) GetPreSharedKey() (PreSharedKey, error) {
	var psk PreSharedKey
	key, err := wgtypes.GenerateKey()
	if err == nil {
		psk = PreSharedKey(key.String())
		return psk, nil
	}
	return psk, errors.Wrap(err, "failed to generate pre-shared Key")
}

// CreateDevice creates a new WireGuard network link named after identifier,
// brings it up and registers an empty configuration and an empty peer set for
// it. It fails if the device is already tracked by this manager.
//
// If the link cannot be brought up, the freshly created link is removed again
// (best effort). Otherwise it would stay behind untracked, and every retry
// would fail in LinkAdd.
func (m *ManagementUtil) CreateDevice(identifier DeviceIdentifier) error {
	if m.deviceExists(identifier) {
		return errors.Errorf("device %s already exists", identifier)
	}
	link := &netlink.GenericLink{
		LinkAttrs: netlink.LinkAttrs{
			Name: string(identifier),
		},
		LinkType: "wireguard",
	}
	if err := m.nl.LinkAdd(link); err != nil {
		return errors.Wrapf(err, "failed to create WireGuard interface")
	}

	if err := m.nl.LinkSetUp(link); err != nil {
		// Roll back the half-created device. The LinkSetUp error is the one worth
		// reporting, so a rollback failure is intentionally dropped.
		_ = m.nl.LinkDel(link)
		return errors.Wrapf(err, "failed to enable WireGuard interface")
	}

	// Create the registries lazily so that a zero-value ManagementUtil does not
	// panic on its first assignment to a nil map.
	if m.interfaces == nil {
		m.interfaces = make(map[DeviceIdentifier]InterfaceConfig)
	}
	if m.peers == nil {
		m.peers = make(map[DeviceIdentifier]map[PeerIdentifier]PeerConfig)
	}
	m.interfaces[identifier] = InterfaceConfig{DeviceName: identifier}
	// Every device starts with its own (empty) peer set, so SavePeers can write into it.
	m.peers[identifier] = make(map[PeerIdentifier]PeerConfig)

	return nil
}

// DeleteDevice removes the WireGuard link identifier. It also forgets the
// stored interface configuration and all stored peers of that device. It fails
// if the device is not tracked by this manager.
func (m *ManagementUtil) DeleteDevice(identifier DeviceIdentifier) error {
	if !m.deviceExists(identifier) {
		return errors.Errorf("device %s does not exist", identifier)
	}
	err := m.nl.LinkDel(&netlink.GenericLink{
		LinkAttrs: netlink.LinkAttrs{
			Name: string(identifier),
		},
		LinkType: "wireguard",
	})
	if err != nil {
		return errors.Wrapf(err, "failed to delete WireGuard interface")
	}

	delete(m.interfaces, identifier)
	// Drop the peer registry too. Otherwise stale peers of the deleted device
	// would still be reported by peerExists, and could resurface via GetPeers if a
	// device with the same name is created again.
	delete(m.peers, identifier)

	return nil
}

// UpdateDevice applies cfg to the existing device identifier. It applies:
//   - MTU and IP addresses, via netlink;
//   - private key, listen port and firewall mark, via the WireGuard client;
//   - finally, the administrative link state (cfg.Enabled).
//
// On success, cfg (with DeviceName forced to identifier) becomes the stored
// configuration of the device.
func (m *ManagementUtil) UpdateDevice(identifier DeviceIdentifier, cfg InterfaceConfig) error {
	if !m.deviceExists(identifier) {
		return errors.Errorf("device %s does not exist", identifier)
	}
	cfg.DeviceName = identifier // ensure that the same device name is set

	// Validate what can be checked locally before touching the system, so an
	// invalid config does not leave the link half-configured.
	addresses, err := parseIpAddressString(cfg.AddressStr)
	if err != nil {
		return errors.Wrapf(err, "failed to parse interface addresses")
	}
	// An invalid key must be reported, not silently replaced by the zero key.
	pKey, err := wgtypes.NewKey(cfg.KeyPair.GetPrivateKeyBytes())
	if err != nil {
		return errors.Wrapf(err, "invalid private key for device %s", identifier)
	}

	// Update net-link attributes
	link, err := m.nl.LinkByName(string(identifier))
	if err != nil {
		return errors.Wrapf(err, "failed to open WireGuard interface")
	}
	// Mtu == 0 is the unset value (e.g. parseConfigFile leaves it at 0 without an
	// MTU line). Applying it would be rejected or leave the link unusable, so the
	// current MTU is kept in that case.
	if cfg.Mtu > 0 {
		if err := m.nl.LinkSetMTU(link, cfg.Mtu); err != nil {
			return errors.Wrapf(err, "failed to set MTU")
		}
	}
	for _, address := range addresses {
		// AddrReplace creates missing addresses and updates existing ones, so
		// repeated updates with the same addresses stay idempotent. AddrAdd fails
		// with EEXIST for an already assigned address.
		// NOTE: addresses that are no longer in cfg are not removed.
		if err := m.nl.AddrReplace(link, address); err != nil {
			return errors.Wrapf(err, "failed to set ip address %v", address)
		}
	}

	// Update WireGuard attributes. A nil firewall mark leaves the device's current
	// mark unchanged, which is what a zero cfg.FirewallMark maps to.
	var fwMark *int
	if cfg.FirewallMark != 0 {
		mark := int(cfg.FirewallMark)
		fwMark = &mark
	}
	err = m.wg.ConfigureDevice(string(identifier), wgtypes.Config{
		PrivateKey:   &pKey,
		ListenPort:   &cfg.ListenPort,
		FirewallMark: fwMark,
	})
	if err != nil {
		return errors.Wrapf(err, "failed to update WireGuard settings")
	}

	// Update link state
	if cfg.Enabled {
		if err := m.nl.LinkSetUp(link); err != nil {
			return errors.Wrapf(err, "failed to enable WireGuard interface")
		}
	} else {
		if err := m.nl.LinkSetDown(link); err != nil {
			return errors.Wrapf(err, "failed to disable WireGuard interface")
		}
	}

	m.interfaces[identifier] = cfg

	return nil
}

// GetPeers returns copies of the stored configurations of all peers of device,
// in map iteration (i.e. unspecified) order. It fails if the device is unknown.
func (m ManagementUtil) GetPeers(device DeviceIdentifier) ([]PeerConfig, error) {
	if !m.deviceExists(device) {
		return nil, errors.Errorf("device %s does not exist", device)
	}

	devicePeers := m.peers[device]
	result := make([]PeerConfig, 0, len(devicePeers))
	for id := range devicePeers {
		result = append(result, devicePeers[id])
	}
	return result, nil
}

// SavePeers applies each given peer to the WireGuard device and records its
// configuration under peer.Uid. Peers are applied one at a time. On the first
// failure the method returns; peers processed before it stay applied and
// recorded.
func (m ManagementUtil) SavePeers(device DeviceIdentifier, peers ...PeerConfig) error {
	if !m.deviceExists(device) {
		return errors.Errorf("device %s does not exist", device)
	}
	if len(peers) == 0 {
		return nil
	}

	// The per-device peer set may not exist yet, and writing into a nil map would
	// panic. The receiver is a value, but maps are references, so storing the new
	// inner map in m.peers is visible to the caller's ManagementUtil as well.
	devicePeers := m.peers[device]
	if devicePeers == nil {
		if m.peers == nil {
			return errors.Errorf("peer store for device %s is not initialized", device)
		}
		devicePeers = make(map[PeerIdentifier]PeerConfig, len(peers))
		m.peers[device] = devicePeers
	}

	deviceConfig := m.interfaces[device]

	for _, peer := range peers {
		wgPeer, err := getWireGuardPeerConfig(deviceConfig.Type, peer)
		if err != nil {
			return errors.Wrapf(err, "could not generate WireGuard peer configuration for %s", peer.Uid)
		}

		err = m.wg.ConfigureDevice(string(device), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}})
		if err != nil {
			return errors.Wrapf(err, "could not save peer %s to WireGuard device %s", peer.Uid, device)
		}

		devicePeers[peer.Uid] = peer
	}

	return nil
}

// RemovePeer removes the peer from the WireGuard device and drops its stored
// configuration. It fails if the device is unknown, or if the peer is not
// registered on that particular device.
func (m ManagementUtil) RemovePeer(device DeviceIdentifier, peer PeerIdentifier) error {
	if !m.deviceExists(device) {
		return errors.Errorf("device %s does not exist", device)
	}
	// Look the peer up on this device only. A peer that exists on another device
	// must be reported as missing here, rather than yielding a zero-value config
	// (and a misleading "invalid public key" error).
	peerConfig, ok := m.peers[device][peer]
	if !ok {
		return errors.Errorf("peer %s does not exist", peer)
	}

	publicKey, err := wgtypes.ParseKey(peerConfig.KeyPair.PublicKey)
	if err != nil {
		return errors.Wrapf(err, "invalid public key for peer %s", peer)
	}

	wgPeer := wgtypes.PeerConfig{
		PublicKey: publicKey,
		Remove:    true,
	}

	err = m.wg.ConfigureDevice(string(device), wgtypes.Config{Peers: []wgtypes.PeerConfig{wgPeer}})
	if err != nil {
		return errors.Wrapf(err, "could not remove peer %s from WireGuard device %s", peer, device)
	}

	delete(m.peers[device], peer)

	return nil
}

//
// ---- Helpers
//

// getWireGuardPeerConfig converts a stored PeerConfig into the wgtypes.PeerConfig
// that adds or updates that peer on a device of the given type.
//
// How the fields are filled:
//   - Endpoint: only set for client devices. An endpoint that cannot be
//     resolved is skipped (best effort).
//   - Allowed IPs: client and server devices use AllowedIPsString; server
//     devices additionally get ExtraAllowedIPsString. Any other device type gets
//     an empty list. The result replaces the peer's current allowed IPs.
//   - Pre-shared key: empty means "none", but a malformed value is an error.
//   - Persistent keepalive: set only when it is a positive number of seconds.
func getWireGuardPeerConfig(deviceType InterfaceType, peer PeerConfig) (wgtypes.PeerConfig, error) {
	publicKey, err := wgtypes.ParseKey(peer.KeyPair.PublicKey)
	if err != nil {
		return wgtypes.PeerConfig{}, errors.Wrapf(err, "invalid public key for peer %s", peer.Uid)
	}

	// A malformed key must not silently configure the peer without its pre-shared key.
	var presharedKey *wgtypes.Key
	if peer.PresharedKey != "" {
		key, err := wgtypes.ParseKey(peer.PresharedKey)
		if err != nil {
			return wgtypes.PeerConfig{}, errors.Wrapf(err, "invalid pre-shared key for peer %s", peer.Uid)
		}
		presharedKey = &key
	}

	// The Value fields are interface-typed. Checked assertions make an unset (nil)
	// or unexpectedly typed value mean "not configured" instead of panicking.
	var endpoint *net.UDPAddr
	if endpointStr, ok := peer.Endpoint.Value.(string); ok && endpointStr != "" && deviceType == InterfaceTypeClient {
		if addr, err := net.ResolveUDPAddr("udp", endpointStr); err == nil {
			endpoint = addr
		}
	}

	var keepAlive *time.Duration
	if seconds, ok := peer.PersistentKeepalive.Value.(int); ok && seconds > 0 {
		keepAliveDuration := time.Duration(seconds) * time.Second
		keepAlive = &keepAliveDuration
	}

	var peerAllowedIPs []*netlink.Addr
	if deviceType == InterfaceTypeClient || deviceType == InterfaceTypeServer {
		peerAllowedIPs, err = parseIpAddressString(peer.AllowedIPsString.GetValue())
		if err != nil {
			return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse allowed IP's for peer %s", peer.Uid)
		}
	}
	if deviceType == InterfaceTypeServer {
		peerExtraAllowedIPs, err := parseIpAddressString(peer.ExtraAllowedIPsString)
		if err != nil {
			return wgtypes.PeerConfig{}, errors.Wrapf(err, "failed to parse extra allowed IP's for peer %s", peer.Uid)
		}
		peerAllowedIPs = append(peerAllowedIPs, peerExtraAllowedIPs...)
	}

	allowedIPs := make([]net.IPNet, 0, len(peerAllowedIPs))
	for _, ip := range peerAllowedIPs {
		allowedIPs = append(allowedIPs, *ip.IPNet)
	}

	// UpdateOnly must stay false. SavePeers is the only path that adds peers, and
	// with UpdateOnly set, WireGuard ignores peers that are not yet on the device.
	return wgtypes.PeerConfig{
		PublicKey:                   publicKey,
		Remove:                      false,
		UpdateOnly:                  false,
		PresharedKey:                presharedKey,
		Endpoint:                    endpoint,
		PersistentKeepaliveInterval: keepAlive,
		ReplaceAllowedIPs:           true,
		AllowedIPs:                  allowedIPs,
	}, nil
}

// deviceExists reports whether an interface configuration is registered for identifier.
func (m ManagementUtil) deviceExists(identifier DeviceIdentifier) bool {
	_, found := m.interfaces[identifier]
	return found
}

// peerExists reports whether any device has a peer registered under identifier.
func (m ManagementUtil) peerExists(identifier PeerIdentifier) bool {
	found := false
	for device := range m.peers {
		if _, ok := m.peers[device][identifier]; ok {
			found = true
			break
		}
	}
	return found
}

// loadExistingInterfaces builds one InterfaceConfig per WireGuard device reported
// by the WireGuard client.
//
// Values come from two sources:
//   - The live device: name, key pair, listen port, firewall mark and driver type.
//   - The device's config file: DNS, display name, hook commands, addresses,
//     routing table and MTU. These are copied only when parseConfigFile
//     succeeds; if it fails, the error is ignored and those fields stay empty.
//
// NOTE(review): each entry whose config file was parsed is printed to stdout (debug output).
// TODO: fix/implement
func (m ManagementUtil) loadExistingInterfaces() ([]InterfaceConfig, error) {
	wgDevices, err := m.wg.Devices()
	if err != nil {
		return nil, errors.Wrapf(err, "failed to get WireGuard device list")
	}

	result := make([]InterfaceConfig, len(wgDevices))
	for idx, dev := range wgDevices {
		cfg := &result[idx]
		cfg.DeviceName = DeviceIdentifier(dev.Name)
		cfg.KeyPair = KeyPair{
			PrivateKey: dev.PrivateKey.String(),
			PublicKey:  dev.PublicKey.String(),
		}
		cfg.ListenPort = dev.ListenPort
		cfg.FirewallMark = int32(dev.FirewallMark)
		cfg.DriverType = dev.Type.String()

		fileCfg, _, parseErr := m.parseConfigFile(dev.Name)
		if parseErr != nil {
			continue // config file missing or not parseable: keep the live-device values only
		}
		cfg.Dns, cfg.DisplayName = fileCfg.Dns, fileCfg.DisplayName
		cfg.PreUp, cfg.PostUp = fileCfg.PreUp, fileCfg.PostUp
		cfg.PreDown, cfg.PostDown = fileCfg.PreDown, fileCfg.PostDown
		cfg.AddressStr = fileCfg.AddressStr
		cfg.RoutingTable = fileCfg.RoutingTable
		cfg.Mtu = fileCfg.Mtu

		fmt.Println(*cfg)
	}

	return result, nil
}

// parseConfigFile parses the WireGuard configuration file
// <configPath>/<interfaceName>.conf (INI syntax as used by wg-quick) and returns
// the settings of its [Interface] section.
//
// Parsing rules:
//   - Option names are matched case-insensitively. Each option is split from
//     its value at the first '=' (base64 keys end in '=' themselves).
//   - Blank lines and '#' comment lines are skipped.
//   - Options before the first section header belong to [Interface].
//   - Unknown [Interface] options are ignored.
//   - [Peer] sections are validated (only peer options are allowed there) but not
//     converted yet, so the returned peer slice is always nil.
//
// TODO: build PeerConfig values from [Peer] sections and read the additional
// metadata stored in comments.
func (m ManagementUtil) parseConfigFile(interfaceName string) (InterfaceConfig, []PeerConfig, error) {
	configFile := filepath.Join(m.configPath, interfaceName+".conf")

	file, err := os.Open(configFile)
	if err != nil {
		return InterfaceConfig{}, nil, errors.Wrapf(err, "unable to open config file for interface %s", interfaceName)
	}
	defer file.Close() // opened read-only: a close error carries nothing worth reporting

	scanner := bufio.NewScanner(file)

	peerSection := false
	iface := InterfaceConfig{}
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())

		switch {
		case line == "" || strings.HasPrefix(line, "#"):
			// Blank line or comment: nothing to parse (yet).
		case strings.HasPrefix(line, "["): // Config section
			if !strings.HasSuffix(line, "]") {
				return InterfaceConfig{}, nil, errors.Errorf("configuration file contains invalid line %s", line)
			}
			section := strings.ToLower(strings.TrimSpace(line[1 : len(line)-1]))
			switch section {
			case "peer":
				peerSection = true
			case "interface":
				peerSection = false
			default:
				return InterfaceConfig{}, nil, errors.Errorf("configuration file contains unsupported section %s", section)
			}
		default: // Config option
			optionParts := strings.SplitN(line, "=", 2)
			if len(optionParts) != 2 {
				return InterfaceConfig{}, nil, errors.Errorf("configuration file contains invalid line %s", line)
			}
			option := strings.ToLower(strings.TrimSpace(optionParts[0]))
			value := strings.TrimSpace(optionParts[1])

			if peerSection != isPeerConfigOption(option) {
				return InterfaceConfig{}, nil, errors.Errorf("config section contains invalid option %s", option)
			}
			if peerSection {
				continue // TODO: collect peer options into a PeerConfig
			}
			if err := applyInterfaceConfigOption(&iface, option, value); err != nil {
				return InterfaceConfig{}, nil, err
			}
		}
	}

	if err := scanner.Err(); err != nil {
		return InterfaceConfig{}, nil, errors.Wrapf(err, "unable to scan config file for interface %s", interfaceName)
	}

	return iface, nil, nil
}

// isPeerConfigOption reports whether the lower-case option name belongs to a [Peer] section.
func isPeerConfigOption(option string) bool {
	switch option {
	case "endpoint", "publickey", "allowedips", "persistentkeepalive", "presharedkey":
		return true
	}
	return false
}

// applyInterfaceConfigOption stores one [Interface] option (lower-case name) in iface.
// Unknown options are ignored.
func applyInterfaceConfigOption(iface *InterfaceConfig, option, value string) error {
	switch option {
	case "privatekey":
		key, err := wgtypes.ParseKey(value)
		if err != nil {
			return errors.Wrapf(err, "interface section has no valid private Key")
		}
		iface.KeyPair = KeyPair{
			PrivateKey: key.String(),
			PublicKey:  key.PublicKey().String(),
		}
	case "address":
		iface.AddressStr = value
	case "listenport":
		port, err := strconv.Atoi(value)
		if err != nil {
			return errors.Wrapf(err, "interface section has invalid listen port Value")
		}
		iface.ListenPort = port
	case "postup":
		iface.PostUp = value
	case "postdown":
		iface.PostDown = value
	case "preup":
		iface.PreUp = value
	case "predown":
		iface.PreDown = value
	case "mtu":
		mtu, err := strconv.Atoi(value)
		if err != nil {
			return errors.Wrapf(err, "interface section has invalid MTU Value")
		}
		iface.Mtu = mtu
	case "dns":
		iface.Dns = value
	case "table":
		iface.RoutingTable = value
	case "fwmark":
		fwMark, err := parseFirewallMark(value)
		if err != nil {
			return errors.Wrapf(err, "interface section has invalid fwmark Value")
		}
		iface.FirewallMark = fwMark
	case "saveconfig":
		saveConfig, err := strconv.ParseBool(value)
		if err != nil {
			return errors.Wrapf(err, "interface section has invalid save-config Value")
		}
		iface.SaveConfig = saveConfig
	}
	return nil
}

// parseFirewallMark parses a FwMark value as wg(8) accepts it. "off" means 0,
// a "0x" prefix selects hexadecimal, and anything else is decimal. The mark is
// 32 bits wide.
func parseFirewallMark(value string) (int32, error) {
	if strings.EqualFold(value, "off") {
		return 0, nil
	}
	digits, base := value, 10
	if strings.HasPrefix(value, "0x") {
		digits, base = value[2:], 16
	}
	mark, err := strconv.ParseUint(digits, base, 32)
	if err != nil {
		return 0, err
	}
	// FirewallMark is stored as int32; keep the 32-bit pattern for marks >= 2^31.
	return int32(uint32(mark)), nil
}

// parseIpAddressString parses a comma-separated list of addresses with
// netlink.ParseAddr (which expects CIDR notation, e.g. "10.0.0.1/24, fd00::1/64").
// Whitespace around entries is trimmed and empty entries are skipped. The first
// invalid entry aborts parsing with an error.
func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) {
	parts := strings.Split(addrStr, ",")
	result := make([]*netlink.Addr, 0, len(parts))
	for _, part := range parts {
		candidate := strings.TrimSpace(part)
		if candidate == "" {
			continue // skip empty string
		}
		parsed, err := netlink.ParseAddr(candidate)
		if err != nil {
			return nil, errors.Wrapf(err, "failed to parse IP address %s", candidate)
		}
		result = append(result, parsed)
	}
	return result, nil
}