Newer
Older
wg-portal / internal / server / helper.go
package server

import (
	"crypto/md5"
	"fmt"
	"io/ioutil"
	"syscall"
	"time"

	"github.com/h44z/wg-portal/internal/common"
	"github.com/pkg/errors"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// PrepareNewPeer returns a peer pre-populated for the current device without
// storing it or touching the WireGuard interface.
//
// The returned peer:
//   - is marked IsNew;
//   - copies the device's AllowedIPsStr;
//   - gets one address from GetAvailableIp for every entry of the device's IPs
//     (also rendered into IPsStr);
//   - receives a freshly generated preshared key and private/public key pair.
//
// Its UID is "u" followed by the hex MD5 digest of the public key. Any error
// yields a zero Peer.
func (s *Server) PrepareNewPeer() (Peer, error) {
	dev := s.peers.GetDevice()

	var p Peer
	p.IsNew = true
	p.AllowedIPsStr = dev.AllowedIPsStr

	addrs := make([]string, len(dev.IPs))
	for idx, network := range dev.IPs {
		addr, err := s.peers.GetAvailableIp(network)
		if err != nil {
			return Peer{}, err
		}
		addrs[idx] = addr
	}
	p.IPs = addrs
	p.IPsStr = common.ListToString(addrs)

	presharedKey, err := wgtypes.GenerateKey()
	if err != nil {
		return Peer{}, err
	}
	privateKey, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return Peer{}, err
	}
	publicKey := privateKey.PublicKey().String()

	p.PresharedKey = presharedKey.String()
	p.PrivateKey = privateKey.String()
	p.PublicKey = publicKey
	p.UID = fmt.Sprintf("u%x", md5.Sum([]byte(publicKey)))
	return p, nil
}

// CreatePeerByEmail creates a new peer owned by the user with the given email.
//
// The user is loaded or created first via GetOrCreateUser. The peer starts from
// the same defaults PrepareNewPeer produces (device allowed IPs, free addresses,
// fresh keys, UID). It is then bound to that user and email, and its Identifier
// is "<Firstname> <Lastname> (<identifierSuffix>)". When disabled is true, the
// peer is created with DeactivatedAt set to the current time. The result of
// CreatePeer is returned.
func (s *Server) CreatePeerByEmail(email, identifierSuffix string, disabled bool) error {
	user, err := s.users.GetOrCreateUser(email)
	if err != nil {
		return errors.WithMessagef(err, "failed to load/create related user %s", email)
	}

	// Reuse the shared defaults (addresses, keys, UID) instead of duplicating them.
	peer, err := s.PrepareNewPeer()
	if err != nil {
		return errors.WithMessagef(err, "failed to prepare new peer for user %s", email)
	}
	// PrepareNewPeer sets IsNew; peers created here have always carried the zero value.
	peer.IsNew = false
	peer.User = user
	peer.Email = email
	peer.Identifier = fmt.Sprintf("%s %s (%s)", user.Firstname, user.Lastname, identifierSuffix)
	if disabled {
		now := time.Now()
		peer.DeactivatedAt = &now
	}

	return s.CreatePeer(peer)
}

// CreatePeer fills in missing defaults on peer and creates it.
//
// AllowedIPsStr is always taken from the device. When peer has no IPs, one free
// address is allocated from each of the device's IPs, and IPs/IPsStr are set.
//
// Keys are completed only where missing:
//   - neither a private nor a public key: a new key pair and a new preshared
//     key are generated;
//   - a private key but no public key: the public key is derived from it;
//   - a public key (with or without a private key): it is kept as supplied,
//     together with whatever preshared key the caller set.
//
// UID is derived from the public key. A peer that is not deactivated is added
// to the WireGuard interface first. The peer is then stored in the database and
// the WireGuard config file is rewritten. The first failing step's error is
// returned.
func (s *Server) CreatePeer(peer Peer) error {
	device := s.peers.GetDevice()
	peer.AllowedIPsStr = device.AllowedIPsStr
	if len(peer.IPs) == 0 { // also covers nil: len of a nil slice is 0
		peer.IPs = make([]string, len(device.IPs))
		for i := range device.IPs {
			freeIP, err := s.peers.GetAvailableIp(device.IPs[i])
			if err != nil {
				return err
			}
			peer.IPs[i] = freeIP
		}
		peer.IPsStr = common.ListToString(peer.IPs)
	}

	switch {
	case peer.PrivateKey == "" && peer.PublicKey == "":
		// Nothing supplied: generate a complete key set.
		psk, err := wgtypes.GenerateKey()
		if err != nil {
			return err
		}
		key, err := wgtypes.GeneratePrivateKey()
		if err != nil {
			return err
		}
		peer.PresharedKey = psk.String()
		peer.PrivateKey = key.String()
		peer.PublicKey = key.PublicKey().String()
	case peer.PublicKey == "":
		// Only the private key is known. Derive the public key so the peer is not
		// created with an empty key (and a UID computed from "").
		key, err := wgtypes.ParseKey(peer.PrivateKey)
		if err != nil {
			return err
		}
		peer.PublicKey = key.PublicKey().String()
	}
	// A caller-supplied public key is never replaced. WireGuard identifies the
	// peer by it, so swapping it for a generated one would silently break that peer.

	peer.UID = fmt.Sprintf("u%x", md5.Sum([]byte(peer.PublicKey)))

	// Create WireGuard interface peer (deactivated peers are only stored).
	if peer.DeactivatedAt == nil {
		if err := s.wg.AddPeer(peer.GetConfig()); err != nil {
			return err
		}
	}

	// Create in database
	if err := s.peers.CreatePeer(peer); err != nil {
		return err
	}

	return s.WriteWireGuardConfigFile()
}

// UpdatePeer applies peer to the WireGuard interface, stores it in the
// database, and rewrites the WireGuard config file.
//
// updateTime identifies deactivations made by this update. When
// peer.DeactivatedAt equals updateTime, the peer is removed from the interface.
//
// For an active peer (DeactivatedAt == nil), the action depends on the stored
// peer for the same public key:
//   - its Peer field is non-nil (presumably: the peer is present on the
//     interface): the peer is updated;
//   - otherwise: the peer is added.
//
// A peer deactivated at some other time is left as-is on the interface. The
// first failing step's error is returned.
func (s *Server) UpdatePeer(peer Peer, updateTime time.Time) error {
	currentPeer := s.peers.GetPeerByKey(peer.PublicKey)

	// Update WireGuard device
	var err error
	switch {
	case peer.DeactivatedAt != nil && peer.DeactivatedAt.Equal(updateTime):
		// Compare the time values. Comparing the pointer with &updateTime (the
		// address of this function's own parameter) can never match.
		err = s.wg.RemovePeer(peer.PublicKey)
	case peer.DeactivatedAt == nil && currentPeer.Peer != nil:
		err = s.wg.UpdatePeer(peer.GetConfig())
	case peer.DeactivatedAt == nil && currentPeer.Peer == nil:
		err = s.wg.AddPeer(peer.GetConfig())
	}
	if err != nil {
		return err
	}

	// Update in database
	if err := s.peers.UpdatePeer(peer); err != nil {
		return err
	}

	return s.WriteWireGuardConfigFile()
}

// DeletePeer removes peer from the WireGuard interface and then from the
// database, and finally rewrites the WireGuard config file. Processing stops
// at the first failing step, whose error is returned.
func (s *Server) DeletePeer(peer Peer) error {
	err := s.wg.RemovePeer(peer.PublicKey)
	if err == nil {
		err = s.peers.DeletePeer(peer)
	}
	if err != nil {
		return err
	}
	return s.WriteWireGuardConfigFile()
}

// RestoreWireGuardInterface adds every active peer whose Peer field is nil
// (presumably: not currently present on the interface) to the WireGuard
// interface. Peers with a non-nil Peer field are skipped. The first AddPeer
// error aborts the loop and is returned.
func (s *Server) RestoreWireGuardInterface() error {
	for _, active := range s.peers.GetActivePeers() {
		if active.Peer != nil {
			continue
		}
		if err := s.wg.AddPeer(active.GetConfig()); err != nil {
			return err
		}
	}
	return nil
}

// WriteWireGuardConfigFile renders the device configuration with all active
// peers and writes it to the path in config WG.WireGuardConfig.
//
// An empty path disables writing, and nil is returned. Otherwise the file must
// already exist and be writable by this process (checked with access(2)). If
// not, an error naming the path is returned and nothing is written.
func (s *Server) WriteWireGuardConfigFile() error {
	path := s.config.WG.WireGuardConfig
	if path == "" {
		return nil // writing disabled
	}

	// access(2) expects R_OK/W_OK/X_OK bits, not open(2) flags such as O_RDWR.
	// The two happen to share the value 2, which hid the mix-up.
	const accessWriteOK = 0x2 // W_OK
	if err := syscall.Access(path, accessWriteOK); err != nil {
		// The bare errno ("permission denied", "no such file or directory") has no path.
		return fmt.Errorf("WireGuard config file %s is not writable: %w", path, err)
	}

	device := s.peers.GetDevice()
	cfg, err := device.GetConfigFile(s.peers.GetActivePeers())
	if err != nil {
		return err
	}
	// 0600 because the rendered config presumably contains the interface private
	// key. The mode only applies when the file is created, which the access
	// check above currently prevents; an existing file keeps its permissions.
	if err := ioutil.WriteFile(path, cfg, 0600); err != nil {
		return err
	}
	return nil
}