diff --git a/internal/portal/api.go b/internal/portal/api.go index 8d7996b..598d685 100644 --- a/internal/portal/api.go +++ b/internal/portal/api.go @@ -1 +1,9 @@ package portal + +import "github.com/h44z/wg-portal/internal/wireguard" + +var man wireguard.Manager + +func init() { + man, _ = wireguard.NewPersistentManager(nil, nil, nil) +} diff --git a/internal/wireguard/keys.go b/internal/wireguard/keys.go index 657a7b1..50789d8 100644 --- a/internal/wireguard/keys.go +++ b/internal/wireguard/keys.go @@ -23,9 +23,9 @@ return base64.StdEncoding.EncodeToString(key) } -type WgCtrlKeyGenerator struct{} +type wgCtrlKeyGenerator struct{} -func (k WgCtrlKeyGenerator) GetFreshKeypair() (persistence.KeyPair, error) { +func (k wgCtrlKeyGenerator) GetFreshKeypair() (persistence.KeyPair, error) { privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { return persistence.KeyPair{}, errors.Wrap(err, "failed to generate private Key") @@ -37,7 +37,7 @@ }, nil } -func (k WgCtrlKeyGenerator) GetPreSharedKey() (persistence.PreSharedKey, error) { +func (k wgCtrlKeyGenerator) GetPreSharedKey() (persistence.PreSharedKey, error) { preSharedKey, err := wgtypes.GenerateKey() if err != nil { return "", errors.Wrap(err, "failed to generate pre-shared Key") diff --git a/internal/wireguard/keys_test.go b/internal/wireguard/keys_test.go index e8936a0..4c828f9 100644 --- a/internal/wireguard/keys_test.go +++ b/internal/wireguard/keys_test.go @@ -33,7 +33,7 @@ } func TestWgCtrlKeyGenerator_GetFreshKeypair(t *testing.T) { - m := WgCtrlKeyGenerator{} + m := wgCtrlKeyGenerator{} kp, err := m.GetFreshKeypair() assert.NoError(t, err) assert.NotEmpty(t, kp.PrivateKey) @@ -41,7 +41,7 @@ } func TestWgCtrlKeyGenerator_GetPreSharedKey(t *testing.T) { - m := WgCtrlKeyGenerator{} + m := wgCtrlKeyGenerator{} psk, err := m.GetPreSharedKey() assert.NoError(t, err) assert.NotEmpty(t, psk) diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 8890db7..4b5477b 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -3,6 +3,8 @@ import ( "io" + "github.com/pkg/errors" + "github.com/h44z/wg-portal/internal/lowlevel" "github.com/h44z/wg-portal/internal/persistence" @@ -56,13 +58,27 @@ // type PersistentManager struct { - WgCtrlKeyGenerator - TemplateHandler - WgCtrlManager + wgCtrlKeyGenerator + *templateHandler + *wgCtrlManager } func NewPersistentManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*PersistentManager, error) { - m := &PersistentManager{} + wgManager, err := newWgCtrlManager(wg, nl, store) + if err != nil { + return nil, errors.WithMessage(err, "failed to initialize WireGuard manager") + } + + tplManager, err := newTemplateHandler() + if err != nil { + return nil, errors.WithMessage(err, "failed to initialize template manager") + } + + m := &PersistentManager{ + wgCtrlKeyGenerator: wgCtrlKeyGenerator{}, + wgCtrlManager: wgManager, + templateHandler: tplManager, + } return m, nil } diff --git a/internal/wireguard/template.go b/internal/wireguard/template.go index 5ebe02d..3718e36 100644 --- a/internal/wireguard/template.go +++ b/internal/wireguard/template.go @@ -13,24 +13,24 @@ //go:embed tpl_files/* var TemplateFiles embed.FS -type TemplateHandler struct { +type templateHandler struct { templates *template.Template } -func NewTemplateHandler() (*TemplateHandler, error) { +func newTemplateHandler() (*templateHandler, error) { templateCache, err := template.New("WireGuard").ParseFS(TemplateFiles, "tpl_files/*.tpl") if err != nil { return nil, errors.Wrapf(err, "failed to parse template files") } - handler := &TemplateHandler{ + handler := &templateHandler{ templates: templateCache, } return handler, nil } -func (c TemplateHandler) GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) { +func (c templateHandler) GetInterfaceConfig(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) (io.Reader, error) { var tplBuff bytes.Buffer err := c.templates.ExecuteTemplate(&tplBuff, "interface.tpl", map[string]interface{}{ @@ -47,7 +47,7 @@ return &tplBuff, nil } -func (c TemplateHandler) GetPeerConfig(peer persistence.PeerConfig) (io.Reader, error) { +func (c templateHandler) GetPeerConfig(peer persistence.PeerConfig) (io.Reader, error) { var tplBuff bytes.Buffer err := c.templates.ExecuteTemplate(&tplBuff, "peer.tpl", map[string]interface{}{ diff --git a/internal/wireguard/template_test.go b/internal/wireguard/template_test.go index d855bb4..5a8c86e 100644 --- a/internal/wireguard/template_test.go +++ b/internal/wireguard/template_test.go @@ -12,7 +12,7 @@ ) func TestNewTemplateHandler(t *testing.T) { - got, err := NewTemplateHandler() + got, err := newTemplateHandler() assert.NoError(t, err) assert.NotNil(t, got) } @@ -60,7 +60,7 @@ }, } - c, _ := NewTemplateHandler() + c, _ := newTemplateHandler() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := c.GetInterfaceConfig(tt.args.cfg, tt.args.peers) @@ -115,7 +115,7 @@ wantErr: false, }, } - c, _ := NewTemplateHandler() + c, _ := newTemplateHandler() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := c.GetPeerConfig(tt.args.peer) diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 88710b7..7c1fbe3 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -14,7 +14,7 @@ "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type WgCtrlManager struct { +type wgCtrlManager struct { mux sync.RWMutex // mutex to synchronize access to maps and external api clients // external api clients @@ -30,8 +30,8 @@ peers map[persistence.InterfaceIdentifier]map[persistence.PeerIdentifier]persistence.PeerConfig } -func NewWgCtrlManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*WgCtrlManager, error) { - m := &WgCtrlManager{ +func newWgCtrlManager(wg lowlevel.WireGuardClient, nl lowlevel.NetlinkClient, store store) (*wgCtrlManager, error) { + m := &wgCtrlManager{ mux: sync.RWMutex{}, wg: wg, nl: nl, @@ -47,7 +47,7 @@ return m, nil } -func (m *WgCtrlManager) GetInterfaces() ([]persistence.InterfaceConfig, error) { +func (m *wgCtrlManager) GetInterfaces() ([]persistence.InterfaceConfig, error) { m.mux.RLock() defer m.mux.RUnlock() interfaces := make([]persistence.InterfaceConfig, 0, len(m.interfaces)) @@ -62,7 +62,7 @@ return interfaces, nil } -func (m *WgCtrlManager) CreateInterface(id persistence.InterfaceIdentifier) error { +func (m *wgCtrlManager) CreateInterface(id persistence.InterfaceIdentifier) error { m.mux.Lock() defer m.mux.Unlock() if m.deviceExists(id) { @@ -86,7 +86,7 @@ return nil } -func (m *WgCtrlManager) DeleteInterface(id persistence.InterfaceIdentifier) error { +func (m *wgCtrlManager) DeleteInterface(id persistence.InterfaceIdentifier) error { m.mux.Lock() defer m.mux.Unlock() @@ -122,7 +122,7 @@ return nil } -func (m *WgCtrlManager) UpdateInterface(id persistence.InterfaceIdentifier, cfg persistence.InterfaceConfig) error { +func (m *wgCtrlManager) UpdateInterface(id persistence.InterfaceIdentifier, cfg persistence.InterfaceConfig) error { m.mux.Lock() defer m.mux.Unlock() if !m.deviceExists(id) { @@ -192,7 +192,7 @@ return nil } -func (m *WgCtrlManager) GetPeers(interfaceId persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) { +func (m *wgCtrlManager) GetPeers(interfaceId persistence.InterfaceIdentifier) ([]persistence.PeerConfig, error) { m.mux.RLock() defer m.mux.RUnlock() if !m.deviceExists(interfaceId) { @@ -207,7 +207,7 @@ return peers, nil } -func (m *WgCtrlManager) SavePeers(peers ...persistence.PeerConfig) error { +func (m *wgCtrlManager) SavePeers(peers ...persistence.PeerConfig) error { m.mux.Lock() defer m.mux.Unlock() @@ -239,7 +239,7 @@ return nil } -func (m *WgCtrlManager) RemovePeer(id persistence.PeerIdentifier) error { +func (m *wgCtrlManager) RemovePeer(id persistence.PeerIdentifier) error { m.mux.Lock() defer m.mux.Unlock() @@ -275,7 +275,7 @@ return nil } -func (m *WgCtrlManager) GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) { +func (m *wgCtrlManager) GetImportableInterfaces() (map[ImportableInterface][]persistence.PeerConfig, error) { devices, err := m.wg.Devices() if err != nil { return nil, errors.WithMessage(err, "failed to get WireGuard device list") @@ -311,7 +311,7 @@ return interfaces, nil } -func (m *WgCtrlManager) ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) error { +func (m *wgCtrlManager) ImportInterface(cfg ImportableInterface, peers []persistence.PeerConfig) error { m.mux.Lock() defer m.mux.Unlock() @@ -324,7 +324,7 @@ // -- Helpers // -func (m *WgCtrlManager) initializeFromStore() error { +func (m *wgCtrlManager) initializeFromStore() error { if m.store == nil { return nil // no store, nothing to do } @@ -352,7 +352,7 @@ return nil } -func (m *WgCtrlManager) createLowLevelInterface(id persistence.InterfaceIdentifier) error { +func (m *wgCtrlManager) createLowLevelInterface(id persistence.InterfaceIdentifier) error { link := &netlink.GenericLink{ LinkAttrs: netlink.LinkAttrs{ Name: string(id), @@ -371,14 +371,14 @@ return nil } -func (m *WgCtrlManager) deviceExists(id persistence.InterfaceIdentifier) bool { +func (m *wgCtrlManager) deviceExists(id persistence.InterfaceIdentifier) bool { if _, ok := m.interfaces[id]; ok { return true } return false } -func (m *WgCtrlManager) persistInterface(id persistence.InterfaceIdentifier, delete bool) error { +func (m *wgCtrlManager) persistInterface(id persistence.InterfaceIdentifier, delete bool) error { if m.store == nil { return nil // nothing to do } @@ -402,7 +402,7 @@ return nil } -func (m *WgCtrlManager) peerExists(id persistence.PeerIdentifier) bool { +func (m *wgCtrlManager) peerExists(id persistence.PeerIdentifier) bool { for _, peers := range m.peers { if _, ok := peers[id]; ok { return true @@ -412,7 +412,7 @@ return false } -func (m *WgCtrlManager) persistPeer(id persistence.PeerIdentifier, delete bool) error { +func (m *wgCtrlManager) persistPeer(id persistence.PeerIdentifier, delete bool) error { if m.store == nil { return nil // nothing to do } @@ -438,7 +438,7 @@ return nil } -func (m *WgCtrlManager) getPeer(id persistence.PeerIdentifier) (persistence.PeerConfig, error) { +func (m *wgCtrlManager) getPeer(id persistence.PeerIdentifier) (persistence.PeerConfig, error) { for _, peers := range m.peers { if _, ok := peers[id]; ok { return peers[id], nil @@ -448,7 +448,7 @@ return persistence.PeerConfig{}, errors.New("peer not found") } -func (m *WgCtrlManager) convertWireGuardInterface(device *wgtypes.Device) (ImportableInterface, error) { +func (m *wgCtrlManager) convertWireGuardInterface(device *wgtypes.Device) (ImportableInterface, error) { cfg := ImportableInterface{} cfg.Identifier = persistence.InterfaceIdentifier(device.Name) @@ -474,7 +474,7 @@ return cfg, nil } -func (m *WgCtrlManager) convertWireGuardPeer(peer *wgtypes.Peer, dev ImportableInterface) (persistence.PeerConfig, error) { +func (m *wgCtrlManager) convertWireGuardPeer(peer *wgtypes.Peer, dev ImportableInterface) (persistence.PeerConfig, error) { peerCfg := persistence.PeerConfig{} peerCfg.Identifier = persistence.PeerIdentifier(peer.PublicKey.String()) peerCfg.KeyPair = persistence.KeyPair{ diff --git a/internal/wireguard/wireguard_test.go b/internal/wireguard/wireguard_test.go index d8c196b..f145496 100644 --- a/internal/wireguard/wireguard_test.go +++ b/internal/wireguard/wireguard_test.go @@ -18,14 +18,14 @@ func TestWgCtrlManager_CreateInterface(t *testing.T) { tests := []struct { name string - manager *WgCtrlManager + manager *wgCtrlManager mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) args persistence.InterfaceIdentifier wantErr bool }{ { name: "AlreadyExisting", - manager: &WgCtrlManager{ + manager: &wgCtrlManager{ mux: sync.RWMutex{}, wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, @@ -39,7 +39,7 @@ }, { name: "LinkAddFailure", - manager: &WgCtrlManager{ + manager: &wgCtrlManager{ mux: sync.RWMutex{}, wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, @@ -55,7 +55,7 @@ }, { name: "LinkSetupFailure", - manager: &WgCtrlManager{ + manager: &wgCtrlManager{ mux: sync.RWMutex{}, wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, @@ -72,7 +72,7 @@ }, { name: "PersistenceFailure", - manager: &WgCtrlManager{ + manager: &wgCtrlManager{ mux: sync.RWMutex{}, wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, @@ -90,7 +90,7 @@ }, { name: "Success", - manager: &WgCtrlManager{ + manager: &wgCtrlManager{ mux: sync.RWMutex{}, wg: &MockWireGuardClient{}, nl: &MockNetlinkClient{}, @@ -127,7 +127,7 @@ func TestWgCtrlManager_DeleteInterface(t *testing.T) { tests := []struct { name string - manager *WgCtrlManager + manager *wgCtrlManager mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) args persistence.InterfaceIdentifier wantErr bool @@ -154,7 +154,7 @@ func TestWgCtrlManager_GetInterfaces(t *testing.T) { tests := []struct { name string - manager *WgCtrlManager + manager *wgCtrlManager mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) want []persistence.InterfaceConfig wantErr bool @@ -190,7 +190,7 @@ } tests := []struct { name string - manager *WgCtrlManager + manager *wgCtrlManager mockSetup func(wg *MockWireGuardClient, nl *MockNetlinkClient, st *MockWireGuardStore) args args wantErr bool