diff --git a/cmd/cli/main.go b/cmd/cli/main.go new file mode 100644 index 0000000..1183771 --- /dev/null +++ b/cmd/cli/main.go @@ -0,0 +1,181 @@ +package main + +import ( + "fmt" + "log" + "os" + "strings" + + "github.com/h44z/wg-portal/internal/persistence" + + "github.com/h44z/wg-portal/internal/portal" + "github.com/pkg/errors" + + "github.com/urfave/cli/v2" +) + +const ( + dsnFlag = "dsn" + interfaceFlag = "interface" +) + +var backend *portal.Backend + +var globalFlags = []cli.Flag{ + &cli.StringFlag{ + Name: dsnFlag, + Value: "./sqlite.db", + Usage: "A DSN for the data store.", + }, +} + +var commands = []*cli.Command{ + { + Name: "list", + Aliases: []string{"l"}, + Usage: "list interfaces or peers", + Subcommands: []*cli.Command{ + { + Name: "interface", + Usage: "show interface information", + ArgsUsage: "", + Action: func(c *cli.Context) error { + if c.Args().Len() != 1 { + return errors.New("missing/invalid interface identifier") + } + interfaceIdentifier := persistence.InterfaceIdentifier(strings.TrimSpace(c.Args().Get(0))) + + cfg, err := backend.GetInterface(interfaceIdentifier) + if err != nil { + return errors.WithMessage(err, "failed to get interface") + } + + peers, err := backend.GetPeers(interfaceIdentifier) + if err != nil { + return errors.WithMessage(err, "failed to get interface peers") + } + + config, err := backend.GetInterfaceConfig(cfg, peers) + if err != nil { + return errors.WithMessage(err, "failed to get interface config") + } + + fmt.Println(config) + + return nil + }, + }, + { + Name: "interfaces", + Usage: "list all interfaces", + Action: func(c *cli.Context) error { + interfaces, err := backend.GetInterfaces() + if err != nil { + return errors.WithMessage(err, "failed to get all interfaces") + } + + fmt.Println("Managed WireGuard Interfaces:") + for i, cfg := range interfaces { + desc := "" + if cfg.DisplayName != "" { + desc = fmt.Sprintf(" (%s)", cfg.DisplayName) + } + fmt.Printf(" %d\t%s%s\n", i, cfg.Identifier, desc) + } + + importable, err := backend.GetImportableInterfaces() + if err != nil { + return errors.WithMessage(err, "failed to get importable interfaces") + } + + fmt.Println("Importable WireGuard Interfaces:") + i := 0 + for cfg := range importable { + fmt.Printf(" %d\t%s\n", i, cfg.Identifier) + i++ + } + + return nil + }, + }, + { + Name: "peers", + Usage: "list all peers", + ArgsUsage: "", + Action: func(c *cli.Context) error { + if c.Args().Len() != 1 { + return errors.New("missing/invalid interface identifier") + } + interfaceIdentifier := persistence.InterfaceIdentifier(strings.TrimSpace(c.Args().Get(0))) + + peers, err := backend.GetPeers(interfaceIdentifier) + if err != nil { + return errors.WithMessage(err, "failed to get all peers") + } + + fmt.Println("WireGuard Peers:") + for i, cfg := range peers { + desc := "" + if cfg.DisplayName != "" { + desc = fmt.Sprintf(" (%s)", cfg.DisplayName) + } + fmt.Printf(" %d\t%s%s\n", i, cfg.Identifier, desc) + } + return nil + }, + }, + }, + }, + { + Name: "import", + Aliases: []string{"i"}, + Usage: "import existing interface", + ArgsUsage: "", + Action: func(c *cli.Context) error { + if c.Args().Len() != 1 { + return errors.New("missing/invalid interface identifier") + } + importIdentifier := strings.TrimSpace(c.Args().Get(0)) + + err := backend.ImportInterface(persistence.InterfaceIdentifier(importIdentifier)) + if err != nil { + return err + } + + fmt.Println("Imported interface", importIdentifier) + + return nil + }, + }, +} + +func main() { + app := cli.NewApp() + app.Name = "wg-portal" + app.Version = "0.0.1" + app.Usage = "WireGuard Portal CLI client" + app.EnableBashCompletion = true + app.Commands = commands + app.Flags = globalFlags + app.Before = func(c *cli.Context) error { + dsn := c.String(dsnFlag) + database, err := persistence.NewDatabase(persistence.DatabaseConfig{ + Type: "sqlite", + DSN: dsn, + }) + if err != nil { + return errors.WithMessagef(err, "failed to initialize persistent store") + } + + backend, err = portal.NewBackend(database) + if err != nil { + return errors.WithMessagef(err, "backend failed to initialize") + } + return nil + } + + err := app.Run(os.Args) + if err != nil { + log.Fatal(err) + } +} diff --git a/internal/persistence/database.go b/internal/persistence/database.go index d3d2ebe..5d71d44 100644 --- a/internal/persistence/database.go +++ b/internal/persistence/database.go @@ -78,5 +78,5 @@ d.db = gormDb - return d, nil + return d, d.Migrate() } diff --git a/internal/persistence/migrations.go b/internal/persistence/migrations.go new file mode 100644 index 0000000..4f14f99 --- /dev/null +++ b/internal/persistence/migrations.go @@ -0,0 +1,7 @@ +package persistence + +func (d *Database) Migrate() error { + d.db.AutoMigrate(&InterfaceConfig{}, &User{}) + d.db.AutoMigrate(&PeerConfig{}) + return nil +} diff --git a/internal/persistence/models.go b/internal/persistence/models.go index 8341896..fa8a3d8 100644 --- a/internal/persistence/models.go +++ b/internal/persistence/models.go @@ -38,8 +38,8 @@ // WireGuard specific (for the [interface] section of the config file) - Identifier InterfaceIdentifier // device name, for example: wg0 - KeyPair KeyPair // private/public Key of the server interface + Identifier InterfaceIdentifier `gorm:"primaryKey"` // device name, for example: wg0 + KeyPair // private/public Key of the server interface ListenPort int // the listening port, for example: 51820 AddressStr string // the interface ip addresses, comma separated @@ -81,20 +81,20 @@ } type PeerInterfaceConfig struct { - Identifier InterfaceIdentifier // the interface identifier - Type InterfaceType // the interface type - PublicKey string // the interface public key + Identifier InterfaceIdentifier `gorm:"index;column:iface_identifier"` // the interface identifier + Type InterfaceType `gorm:"column:iface_type"` // the interface type + PublicKey string `gorm:"column:iface_pubkey"` // the interface public key - AddressStr StringConfigOption // the interface ip addresses, comma separated - DnsStr StringConfigOption // the dns server that should be set if the interface is up, comma separated - Mtu IntConfigOption // the device MTU - FirewallMark Int32ConfigOption // a firewall mark - RoutingTable StringConfigOption // the routing table + AddressStr StringConfigOption `gorm:"embedded;embeddedPrefix:iface_address_str_"` // the interface ip addresses, comma separated + DnsStr StringConfigOption `gorm:"embedded;embeddedPrefix:iface_dns_str_"` // the dns server that should be set if the interface is up, comma separated + Mtu IntConfigOption `gorm:"embedded;embeddedPrefix:iface_mtu_"` // the device MTU + FirewallMark Int32ConfigOption `gorm:"embedded;embeddedPrefix:iface_firewall_mark_"` // a firewall mark + RoutingTable StringConfigOption `gorm:"embedded;embeddedPrefix:iface_routing_table_"` // the routing table - PreUp StringConfigOption // action that is executed before the device is up - PostUp StringConfigOption // action that is executed after the device is up - PreDown StringConfigOption // action that is executed before the device is down - PostDown StringConfigOption // action that is executed after the device is down + PreUp StringConfigOption `gorm:"embedded;embeddedPrefix:iface_pre_up_"` // action that is executed before the device is up + PostUp StringConfigOption `gorm:"embedded;embeddedPrefix:iface_post_up_"` // action that is executed after the device is up + PreDown StringConfigOption `gorm:"embedded;embeddedPrefix:iface_pre_down_"` // action that is executed before the device is down + PostDown StringConfigOption `gorm:"embedded;embeddedPrefix:iface_post_down_"` // action that is executed after the device is down } type PeerConfig struct { @@ -102,21 +102,21 @@ // WireGuard specific (for the [peer] section of the config file) - Endpoint StringConfigOption // the endpoint address - AllowedIPsStr StringConfigOption // all allowed ip subnets, comma seperated + Endpoint StringConfigOption `gorm:"embedded;embeddedPrefix:endpoint_"` // the endpoint address + AllowedIPsStr StringConfigOption `gorm:"embedded;embeddedPrefix:allowed_ips_str_"` // all allowed ip subnets, comma seperated ExtraAllowedIPsStr string // all allowed ip subnets on the server side, comma seperated - KeyPair KeyPair // private/public Key of the peer + KeyPair // private/public Key of the peer PresharedKey string // the pre-shared Key of the peer - PersistentKeepalive IntConfigOption // the persistent keep-alive interval + PersistentKeepalive IntConfigOption `gorm:"embedded;embeddedPrefix:persistent_keep_alive_"` // the persistent keep-alive interval // WG Portal specific DisplayName string // a nice display name/ description for the peer - Identifier PeerIdentifier // peer unique identifier - UserIdentifier UserIdentifier // the owner + Identifier PeerIdentifier `gorm:"primaryKey"` // peer unique identifier + UserIdentifier UserIdentifier `gorm:"index"` // the owner // Interface settings for the peer, used to generate the [interface] section in the peer config file - Interface *PeerInterfaceConfig + Interface *PeerInterfaceConfig `gorm:"embedded"` } type UserSource string @@ -140,10 +140,10 @@ // User is the user model that gets linked to peer entries, by default an empty user model with only the email address is created type User struct { // required fields - Uid UserIdentifier `gorm:"primaryKey"` - Email string `form:"email" binding:"required,email"` - Source UserSource - IsAdmin bool + Identifier UserIdentifier `gorm:"primaryKey"` + Email string `form:"email" binding:"required,email"` + Source UserSource + IsAdmin bool // optional fields Firstname string `form:"firstname" binding:"omitempty"` diff --git a/internal/persistence/options.go b/internal/persistence/options.go index f8031d0..233c4d8 100644 --- a/internal/persistence/options.go +++ b/internal/persistence/options.go @@ -1,20 +1,12 @@ package persistence -// ConfigOption is an Overridable configuration option -type ConfigOption struct { - Value interface{} - Overridable bool -} - type StringConfigOption struct { - ConfigOption + Value string `gorm:"column:v"` + Overridable bool `gorm:"column:o"` } func (o StringConfigOption) GetValue() string { - if o.Value == nil { - return "" - } - return o.Value.(string) + return o.Value } func (o *StringConfigOption) SetValue(value string) { @@ -30,21 +22,19 @@ } func NewStringConfigOption(value string, overridable bool) StringConfigOption { - return StringConfigOption{ConfigOption{ + return StringConfigOption{ Value: value, Overridable: overridable, - }} + } } type IntConfigOption struct { - ConfigOption + Value int `gorm:"column:v"` + Overridable bool `gorm:"column:o"` } func (o IntConfigOption) GetValue() int { - if o.Value == nil { - return 0 - } - return o.Value.(int) + return o.Value } func (o *IntConfigOption) SetValue(value int) { @@ -60,22 +50,19 @@ } func NewIntConfigOption(value int, overridable bool) IntConfigOption { - return IntConfigOption{ConfigOption{ + return IntConfigOption{ Value: value, Overridable: overridable, - }} + } } type Int32ConfigOption struct { - ConfigOption + Value int32 `gorm:"column:v"` + Overridable bool `gorm:"column:o"` } func (o Int32ConfigOption) GetValue() int32 { - if o.Value == nil { - return 0 - } - - return o.Value.(int32) + return o.Value } func (o *Int32ConfigOption) SetValue(value int32) { @@ -91,22 +78,19 @@ } func NewInt32ConfigOption(value int32, overridable bool) Int32ConfigOption { - return Int32ConfigOption{ConfigOption{ + return Int32ConfigOption{ Value: value, Overridable: overridable, - }} + } } type BoolConfigOption struct { - ConfigOption + Value bool `gorm:"column:v"` + Overridable bool `gorm:"column:o"` } func (o BoolConfigOption) GetValue() bool { - if o.Value == nil { - return false - } - - return o.Value.(bool) + return o.Value } func (o *BoolConfigOption) SetValue(value bool) { @@ -122,8 +106,8 @@ } func NewBoolConfigOption(value bool, overridable bool) BoolConfigOption { - return BoolConfigOption{ConfigOption{ + return BoolConfigOption{ Value: value, Overridable: overridable, - }} + } } diff --git a/internal/persistence/users.go b/internal/persistence/users.go index 146c948..04d38fe 100644 --- a/internal/persistence/users.go +++ b/internal/persistence/users.go @@ -15,7 +15,7 @@ } func (d *Database) SaveUser(user User) error { - create := user.Uid == "" + create := user.Identifier == "" now := time.Now() user.UpdatedAt = now @@ -27,7 +27,7 @@ } } else { if err := d.db.Save(&user).Error; err != nil { - return errors.WithMessagef(err, "unable to update user %s", user.Uid) + return errors.WithMessagef(err, "unable to update user %s", user.Identifier) } } return nil diff --git a/internal/persistence/wireguard.go b/internal/persistence/wireguard.go index ee78270..318a3fc 100644 --- a/internal/persistence/wireguard.go +++ b/internal/persistence/wireguard.go @@ -21,14 +21,14 @@ func (d *Database) GetAllInterfaces(ids ...InterfaceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { var interfaces []InterfaceConfig - if err := d.db.Where("interface IN ?", ids).Find(&interfaces).Error; err != nil { + if err := d.db.Where("identifier IN ?", ids).Find(&interfaces).Error; err != nil { return nil, errors.WithMessage(err, "unable to find interfaces") } interfaceMap := make(map[InterfaceConfig][]PeerConfig, len(interfaces)) for i := range interfaces { var peers []PeerConfig - if err := d.db.Where("interface = ?", interfaces[i].Identifier).Find(&peers).Error; err != nil { + if err := d.db.Where("iface_identifier = ?", interfaces[i].Identifier).Find(&peers).Error; err != nil { return nil, errors.WithMessagef(err, "unable to find peers for %s", interfaces[i].Identifier) } interfaceMap[interfaces[i]] = peers @@ -44,7 +44,7 @@ } var peers []PeerConfig - if err := d.db.Where("interface = ?", id).Find(&peers).Error; err != nil { + if err := d.db.Where("identifier = ?", id).Find(&peers).Error; err != nil { return InterfaceConfig{}, nil, errors.WithMessage(err, "unable to find peers") } diff --git a/internal/portal/backend.go b/internal/portal/backend.go new file mode 100644 index 0000000..5875fea --- /dev/null +++ b/internal/portal/backend.go @@ -0,0 +1,73 @@ +package portal + +import ( + "github.com/h44z/wg-portal/internal/lowlevel" + "github.com/h44z/wg-portal/internal/persistence" + "github.com/h44z/wg-portal/internal/user" + "github.com/h44z/wg-portal/internal/wireguard" + "github.com/pkg/errors" + "golang.zx2c4.com/wireguard/wgctrl" +) + +// type alias +type UserManager = user.Manager +type WireGuardManager = wireguard.Manager + +type Backend struct { + UserManager + WireGuardManager +} + +func NewBackend(db *persistence.Database) (*Backend, error) { + wg, err := wgctrl.New() + if err != nil { + return nil, errors.WithMessage(err, "failed to get wgctrl handle") + } + + nl := &lowlevel.NetlinkManager{} + + wgm, err := wireguard.NewPersistentManager(wg, nl, db) + if err != nil { + return nil, errors.WithMessage(err, "failed to setup WireGuard manager") + } + + um, err := user.NewPersistentManager(db) + if err != nil { + return nil, errors.WithMessage(err, "failed to setup user manager") + } + + b := &Backend{ + UserManager: um, + WireGuardManager: wgm, + } + + return b, nil +} + +func (b *Backend) ImportInterface(identifier persistence.InterfaceIdentifier) error { + importable, err := b.GetImportableInterfaces() + if err != nil { + return errors.WithMessage(err, "failed to get importable interfaces") + } + + var interfaceConfig *wireguard.ImportableInterface + var peers []*persistence.PeerConfig + for cfg, peerList := range importable { + if cfg.Identifier == identifier { + interfaceConfig = cfg + peers = peerList + break + } + } + + if interfaceConfig == nil { + return errors.New("the given interface is not importable") + } + + err = b.WireGuardManager.ImportInterface(interfaceConfig, peers) + if err != nil { + return errors.WithMessagef(err, "failed to import interface") + } + + return nil +} diff --git a/internal/user/manager.go b/internal/user/manager.go index d50edbf..68646c1 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -50,10 +50,6 @@ } func NewPersistentManager(store store) (*PersistentManager, error) { - if store == nil { - return nil, errors.New("user manager requires a valid store object") - } - mgr := &PersistentManager{ store: store, @@ -91,7 +87,7 @@ // Order the users by uid sort.Slice(users, func(i, j int) bool { - return users[i].Uid < users[j].Uid + return users[i].Identifier < users[j].Identifier }) return users, nil @@ -108,7 +104,7 @@ // Order the users by uid sort.Slice(users, func(i, j int) bool { - return users[i].Uid < users[j].Uid + return users[i].Identifier < users[j].Identifier }) return users, nil @@ -127,7 +123,7 @@ // Order the users by uid sort.Slice(users, func(i, j int) bool { - return users[i].Uid < users[j].Uid + return users[i].Identifier < users[j].Identifier }) return users, nil @@ -141,13 +137,13 @@ p.mux.Lock() defer p.mux.Unlock() - if p.userExists(user.Uid) { + if p.userExists(user.Identifier) { return errors.New("user already exists") } - p.users[user.Uid] = user + p.users[user.Identifier] = user - err := p.persistUser(user.Uid, false) + err := p.persistUser(user.Identifier, false) if err != nil { return errors.WithMessage(err, "failed to persist created user") } @@ -163,13 +159,13 @@ p.mux.Lock() defer p.mux.Unlock() - if !p.userExists(user.Uid) { + if !p.userExists(user.Identifier) { return errors.New("user does not exists") } - p.users[user.Uid] = user + p.users[user.Identifier] = user - err := p.persistUser(user.Uid, false) + err := p.persistUser(user.Identifier, false) if err != nil { return errors.WithMessage(err, "failed to persist updated user") } @@ -210,7 +206,7 @@ for _, tmpUser := range users { user := tmpUser - p.users[user.Uid] = &user + p.users[user.Identifier] = &user } return nil @@ -252,7 +248,7 @@ if user == nil { return errors.New("user must not be nil") } - if user.Uid == "" { + if user.Identifier == "" { return errors.New("missing user identifier") } if user.Source == "" { diff --git a/internal/wireguard/functional_test.go b/internal/wireguard/functional_test.go index be9ecc4..a826f48 100644 --- a/internal/wireguard/functional_test.go +++ b/internal/wireguard/functional_test.go @@ -99,7 +99,7 @@ AddressStr: "10.98.87.76/24", Enabled: true, } - err = mgr.UpdateInterface(interfaceName, cfg) + err = mgr.UpdateInterface(cfg) assert.NoError(t, err) // Validate that the interface has been updated @@ -128,7 +128,7 @@ AddressStr: "10.98.87.76/24", Enabled: false, } - err = mgr.UpdateInterface(interfaceName, cfg) + err = mgr.UpdateInterface(cfg) assert.NoError(t, err) // Validate that the interface has been updated @@ -158,11 +158,11 @@ AddressStr: "10.98.87.76/24", Enabled: false, } - err = mgr.UpdateInterface(interfaceName, cfg) + err = mgr.UpdateInterface(cfg) assert.NoError(t, err) cfg.Enabled = true - err = mgr.UpdateInterface(interfaceName, cfg) + err = mgr.UpdateInterface(cfg) assert.NoError(t, err) // Validate that the interface has been updated diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index df2b543..6f70263 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -20,6 +20,7 @@ // InterfaceManager provides methods to create/update/delete physical WireGuard devices. type InterfaceManager interface { GetInterfaces() ([]*persistence.InterfaceConfig, error) + GetInterface(id persistence.InterfaceIdentifier) (*persistence.InterfaceConfig, error) CreateInterface(id persistence.InterfaceIdentifier) error DeleteInterface(id persistence.InterfaceIdentifier) error UpdateInterface(cfg *persistence.InterfaceConfig) error @@ -68,7 +69,7 @@ // type PersistentManager struct { - wgCtrlKeyGenerator + *wgCtrlKeyGenerator *templateHandler *wgCtrlManager } @@ -85,7 +86,7 @@ } m := &PersistentManager{ - wgCtrlKeyGenerator: wgCtrlKeyGenerator{}, + wgCtrlKeyGenerator: &wgCtrlKeyGenerator{}, wgCtrlManager: wgManager, templateHandler: tplManager, } diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 37b744f..fb10659 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -62,6 +62,17 @@ return interfaces, nil } +func (m *wgCtrlManager) GetInterface(id persistence.InterfaceIdentifier) (*persistence.InterfaceConfig, error) { + m.mux.RLock() + defer m.mux.RUnlock() + + if !m.deviceExists(id) { + return nil, errors.New("device does not exist") + } + + return m.interfaces[id], nil +} + func (m *wgCtrlManager) CreateInterface(id persistence.InterfaceIdentifier) error { m.mux.Lock() defer m.mux.Unlock() @@ -246,8 +257,8 @@ } peers := make([]*persistence.PeerConfig, 0, len(m.peers[interfaceId])) - for _, config := range m.peers[interfaceId] { - peers = append(peers, config) + for i := range m.peers[interfaceId] { + peers = append(peers, m.peers[interfaceId][i]) } sort.Slice(peers, func(i, j int) bool { @@ -424,8 +435,8 @@ if _, ok := m.peers[cfg.Identifier]; !ok { m.peers[cfg.Identifier] = make(map[persistence.PeerIdentifier]*persistence.PeerConfig) } - for _, peer := range peers { - m.peers[cfg.Identifier][peer.Identifier] = &peer + for p, peer := range peers { + m.peers[cfg.Identifier][peer.Identifier] = &peers[p] } } @@ -623,7 +634,7 @@ var endpoint *net.UDPAddr if cfg.Endpoint.Value != "" && devType == persistence.InterfaceTypeClient { - addr, err := net.ResolveUDPAddr("udp", cfg.Endpoint.Value.(string)) + addr, err := net.ResolveUDPAddr("udp", cfg.Endpoint.GetValue()) if err == nil { endpoint = addr }