// Source: wg-portal/internal/wireguard/wireguard_ip.go
// Last commit: Christoph Haas, 11 Oct 2021 - "wip: ip handling, refactoring, tests"

package wireguard

import (
	"bytes"
	"net"
	"sort"
	"strings"

	"github.com/h44z/wg-portal/internal/persistence"
	"github.com/pkg/errors"
	"github.com/vishvananda/netlink"
)

// GetAllUsedIPs returns every address configured on the peers of the interface identified
// by id, sorted by IP in ascending byte order. The read lock on m.mux is held for the whole
// call.
//
// An error is returned if the interface does not exist or if the address string of any
// peer cannot be parsed.
func (m *wgCtrlManager) GetAllUsedIPs(id persistence.InterfaceIdentifier) ([]*netlink.Addr, error) {
	m.mux.RLock()
	defer m.mux.RUnlock()

	if exists := m.deviceExists(id); !exists {
		return nil, errors.New("interface does not exist")
	}

	var collected []*netlink.Addr // stays nil when no peer has an address
	for _, peer := range m.peers[id] {
		peerAddrs, parseErr := parseIpAddressString(peer.Interface.AddressStr.GetValue())
		if parseErr != nil {
			return nil, errors.WithMessagef(parseErr, "unable to parse addresses of peer %s", peer.Identifier)
		}
		collected = append(collected, peerAddrs...)
	}

	byIP := func(a, b int) bool {
		return bytes.Compare(collected[a].IP, collected[b].IP) < 0
	}
	sort.Slice(collected, byIP)

	return collected, nil
}

// GetUsedIPs returns the addresses of the peers of the interface identified by id that lie
// inside subnetCidr, sorted by IP in ascending byte order. Peer addresses outside the subnet
// are ignored.
//
// The read lock on m.mux is held for the whole call. Do not call this while already holding
// m.mux: sync.RWMutex read locks are not reentrant.
//
// An error is returned if the interface does not exist, subnetCidr cannot be parsed, or the
// address string of any peer cannot be parsed.
func (m *wgCtrlManager) GetUsedIPs(id persistence.InterfaceIdentifier, subnetCidr string) ([]*netlink.Addr, error) {
	m.mux.RLock()
	defer m.mux.RUnlock()

	if exists := m.deviceExists(id); !exists {
		return nil, errors.New("interface does not exist")
	}

	subnet, subnetErr := parseCIDR(subnetCidr)
	if subnetErr != nil {
		return nil, errors.WithMessagef(subnetErr, "unable to parse subnet addresses")
	}

	var inSubnet []*netlink.Addr // stays nil when nothing matches
	for _, peer := range m.peers[id] {
		peerAddrs, parseErr := parseIpAddressString(peer.Interface.AddressStr.GetValue())
		if parseErr != nil {
			return nil, errors.WithMessagef(parseErr, "unable to parse addresses of peer %s", peer.Identifier)
		}

		for _, candidate := range peerAddrs {
			if !subnet.Contains(candidate.IP) {
				continue
			}
			inSubnet = append(inSubnet, candidate)
		}
	}

	byIP := func(a, b int) bool {
		return bytes.Compare(inSubnet[a].IP, inSubnet[b].IP) < 0
	}
	sort.Slice(inSubnet, byIP)

	return inSubnet, nil
}

// GetFreshIp returns an address inside subnetCidr that is not yet assigned to any peer of the
// interface identified by id. The network address of the subnet is never returned. For IPv4
// subnets the broadcast address is skipped as well. The returned address carries the
// subnet's mask.
//
// By default the search starts at the beginning of the subnet, so the lowest free address is
// returned. If increment is passed and its first value is true, the search starts at the
// highest address currently used inside the subnet. Addresses below it are then not handed
// out again.
//
// An error is returned if the interface does not exist, if subnetCidr or a peer's address
// string cannot be parsed, or if no free address is left ("ip range exceeded").
func (m *wgCtrlManager) GetFreshIp(id persistence.InterfaceIdentifier, subnetCidr string, increment ...bool) (*netlink.Addr, error) {
	m.mux.RLock()
	defer m.mux.RUnlock()

	if !m.deviceExists(id) {
		return nil, errors.New("interface does not exist")
	}

	subnet, err := parseCIDR(subnetCidr)
	if err != nil {
		return nil, errors.WithMessagef(err, "unable to parse subnet addresses")
	}
	v4 := isV4(subnet)

	// Collect the used addresses here instead of calling m.GetUsedIPs. That method acquires
	// m.mux.RLock again, and recursive read locking of a sync.RWMutex deadlocks as soon as a
	// writer queues between the two RLock calls.
	used := make(map[string]struct{})
	var highestUsed net.IP
	for _, peer := range m.peers[id] {
		addresses, err := parseIpAddressString(peer.Interface.AddressStr.GetValue())
		if err != nil {
			err = errors.WithMessagef(err, "unable to parse addresses of peer %s", peer.Identifier)
			return nil, errors.WithMessagef(err, "unable to load used IP addresses")
		}

		for _, address := range addresses {
			addrIP := address.IP.To16() // compare everything in the 16-byte form
			if !subnet.Contains(addrIP) {
				continue
			}
			used[string(addrIP)] = struct{}{}
			if highestUsed == nil || bytes.Compare(addrIP, highestUsed) > 0 {
				highestUsed = addrIP
			}
		}
	}

	// These two addresses are not usable. Derive the network address by masking:
	// netlink.ParseAddr keeps the address as written, so subnet.IP may contain host bits
	// (e.g. "10.0.0.1/24").
	networkAddr := subnet.IP.Mask(subnet.Mask).To16()
	broadcast := broadcastAddr(subnet).IP.To16()

	// Start at the first address of the subnet, or just at the highest used one when
	// incrementing. Copy into a private slice because increaseIP mutates it in place.
	start := networkAddr
	if len(increment) != 0 && increment[0] && highestUsed != nil {
		start = highestUsed
	}
	ip := &netlink.Addr{
		IPNet: &net.IPNet{IP: make(net.IP, net.IPv6len), Mask: subnet.Mask},
	}
	copy(ip.IP, start)

	for ; subnet.Contains(ip.IP); increaseIP(ip) {
		if bytes.Equal(ip.IP, networkAddr) {
			continue
		}
		if v4 && bytes.Equal(ip.IP, broadcast) {
			continue
		}
		if _, taken := used[string(ip.IP)]; !taken {
			return ip, nil
		}
	}

	return nil, errors.New("ip range exceeded")
}

// increaseIP advances ip.IP in place to the next address. It adds one to the last byte and
// carries into the preceding bytes while a byte wraps from 0xff to 0x00. An address made
// only of 0xff bytes wraps around to all zeros.
// Based on http://play.golang.org/p/m8TNTtygK0
func increaseIP(ip *netlink.Addr) {
	addr := ip.IP
	for pos := len(addr) - 1; pos >= 0; pos-- {
		addr[pos]++
		if addr[pos] != 0 {
			return // no carry left to propagate
		}
	}
}

// broadcastAddr returns the last address of the network described by n. For IPv4 this is
// the broadcast address; for IPv6 it is the highest address of the prefix. The returned IP
// is a freshly allocated 16-byte slice (IPv4 in IPv4-mapped form) and carries n's mask.
func broadcastAddr(n *netlink.Addr) *netlink.Addr {
	// The golang net package doesn't make it easy to calculate the broadcast address. :(

	// Expand the mask to 16 bytes so it lines up with the 16-byte IP form. A 4-byte IPv4
	// mask covers the last four bytes. The all-ones padding in front of it keeps the
	// IPv4-mapped prefix of the IP unchanged.
	mask := make(net.IPMask, net.IPv6len)
	for i := range mask {
		mask[i] = 0xff
	}
	if len(n.Mask) == net.IPv4len {
		copy(mask[net.IPv6len-net.IPv4len:], n.Mask)
	} else {
		copy(mask, n.Mask)
	}

	// Allocate a fresh slice rather than starting from net.IPv6zero. Writing into that
	// shared package-level variable would corrupt it for every other user of the net
	// package and race between concurrent callers.
	broadcast := make(net.IP, net.IPv6len)
	ip := n.IP.To16() // normalizes 4-byte IPv4 input; nil for an invalid IP
	for i := range ip {
		broadcast[i] = ip[i] | ^mask[i]
	}

	return &netlink.Addr{
		IPNet: &net.IPNet{IP: broadcast, Mask: n.Mask},
	}
}

// isV4 reports whether n holds an IPv4 address, either as 4 bytes or in the 16-byte
// IPv4-mapped form (both are accepted by net.IP.To4).
func isV4(n *netlink.Addr) bool {
	return n.IP.To4() != nil
}

// parseCIDR parses an address in CIDR notation (e.g. "10.0.0.1/24") with netlink.ParseAddr.
// The IP of the result is normalized to the 16-byte representation for both IPv4 and IPv6,
// so addresses of either family can be compared byte-wise.
func parseCIDR(cidr string) (*netlink.Addr, error) {
	addr, err := netlink.ParseAddr(cidr)
	if err == nil {
		// To16 returns a 16-byte IP unchanged and expands a 4-byte IPv4 address.
		addr.IP = addr.IP.To16()
		return addr, nil
	}

	return nil, errors.WithMessagef(err, "failed to parse cidr")
}

// parseIpAddressString parses a comma-separated list of CIDR addresses such as
// "10.0.0.2/32, fd00::2/128". Whitespace around each entry is trimmed and empty entries are
// skipped. The result is sorted by IP in ascending byte order. The first entry that fails
// to parse aborts the whole call with an error.
func parseIpAddressString(addrStr string) ([]*netlink.Addr, error) {
	parts := strings.Split(addrStr, ",")
	result := make([]*netlink.Addr, 0, len(parts))
	for _, part := range parts {
		entry := strings.TrimSpace(part)
		if entry == "" {
			continue // tolerate empty entries, e.g. from a trailing comma
		}

		parsed, err := parseCIDR(entry)
		if err != nil {
			return nil, errors.Wrapf(err, "failed to parse IP address %s", entry)
		}
		result = append(result, parsed)
	}

	byIP := func(a, b int) bool {
		return bytes.Compare(result[a].IP, result[b].IP) < 0
	}
	sort.Slice(result, byIP)

	return result, nil
}

// ipAddressesToString renders each address with its String method and joins the results
// with "," (no spaces). An empty or nil slice yields "".
func ipAddressesToString(addresses []netlink.Addr) string {
	rendered := make([]string, 0, len(addresses))
	for idx := range addresses {
		rendered = append(rendered, addresses[idx].String())
	}

	return strings.Join(rendered, ",")
}