diff --git a/go.mod b/go.mod index 17ccbab..55c2d3b 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,19 @@ go 1.16 require ( - github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/kr/text v0.2.0 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pkg/errors v0.9.1 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 github.com/vishvananda/netlink v1.1.0 - golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97 // indirect + golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect golang.zx2c4.com/wireguard/wgctrl v0.0.0-20210506160403-92e472f520a5 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gorm.io/driver/mysql v1.1.2 - gorm.io/gorm v1.21.14 + gorm.io/driver/postgres v1.1.2 + gorm.io/driver/sqlite v1.1.6 + gorm.io/driver/sqlserver v1.0.9 + gorm.io/gorm v1.21.15 ) diff --git a/internal/persistence/database.go b/internal/persistence/database.go index dc7cf83..d3d2ebe 100644 --- a/internal/persistence/database.go +++ b/internal/persistence/database.go @@ -1 +1,82 @@ package persistence + +import ( + "os" + "path/filepath" + "time" + + "github.com/pkg/errors" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" +) + +type SupportedDatabase string + +const ( + SupportedDatabaseMySQL SupportedDatabase = "mysql" + SupportedDatabaseMsSQL SupportedDatabase = "mssql" + SupportedDatabasePostgres SupportedDatabase = "postgres" + SupportedDatabaseSQLite SupportedDatabase = "sqlite" +) + +type DatabaseFilterCondition func(tx *gorm.DB) *gorm.DB + +type DatabaseConfig struct { + Type SupportedDatabase + DSN string // On SQLite: the database file-path, otherwise the dsn (see: https://gorm.io/docs/connecting_to_the_database.html) +} + +type Database struct { + db *gorm.DB +} + +func NewDatabase(cfg DatabaseConfig) (*Database, error) { + d := &Database{} + + var gormDb *gorm.DB + var err error + + switch cfg.Type { + case SupportedDatabaseMySQL: + gormDb, err = gorm.Open(mysql.Open(cfg.DSN), &gorm.Config{}) + if err != nil { + return nil, errors.WithMessage(err, "failed to open MySQL database") + } + + sqlDB, _ := gormDb.DB() + sqlDB.SetConnMaxLifetime(time.Minute * 5) + sqlDB.SetMaxIdleConns(2) + sqlDB.SetMaxOpenConns(10) + err = sqlDB.Ping() // This DOES open a connection if necessary. This makes sure the database is accessible + if err != nil { + return nil, errors.WithMessage(err, "failed to ping MySQL database") + } + case SupportedDatabaseMsSQL: + gormDb, err = gorm.Open(sqlserver.Open(cfg.DSN), &gorm.Config{}) + if err != nil { + return nil, errors.WithMessage(err, "failed to open sqlserver database") + } + case SupportedDatabasePostgres: + gormDb, err = gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{}) + if err != nil { + return nil, errors.WithMessage(err, "failed to open Postgres database") + } + case SupportedDatabaseSQLite: + if _, err = os.Stat(filepath.Dir(cfg.DSN)); os.IsNotExist(err) { + if err = os.MkdirAll(filepath.Dir(cfg.DSN), 0700); err != nil { + return nil, errors.WithMessage(err, "failed to create database base directory") + } + } + gormDb, err = gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true}) + if err != nil { + return nil, errors.WithMessage(err, "failed to open sqlite database") + } + } + + d.db = gormDb + + return d, nil +} diff --git a/internal/persistence/models.go b/internal/persistence/models.go index acb58ce..2238f1d 100644 --- a/internal/persistence/models.go +++ b/internal/persistence/models.go @@ -154,5 +154,5 @@ // database internal fields CreatedAt time.Time UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index" json:",omitempty" swaggertype:"string"` + DeletedAt gorm.DeletedAt `gorm:"index" json:",omitempty"` } diff --git a/internal/persistence/users.go b/internal/persistence/users.go index 8f3d56a..a181852 100644 --- a/internal/persistence/users.go +++ b/internal/persistence/users.go @@ -1,19 +1,69 @@ package persistence -import "gorm.io/gorm" +import ( + "time" -type UserFilterCondition func(tx *gorm.DB) + "github.com/pkg/errors" +) -type UsersLoader interface { - GetUser(id UserIdentifier) (User, error) - GetUsers() ([]User, error) - GetUsersUnscoped() ([]User, error) - GetUsersFiltered(filter ...UserFilterCondition) ([]User, error) +func (d *Database) GetUser(id UserIdentifier) (User, error) { + var user User + if err := d.db.First(&user, id).Error; err != nil { + return User{}, errors.WithMessagef(err, "unable to find user %s", id) + } + return user, nil } -type Users interface { - UsersLoader +func (d *Database) GetUsers() ([]User, error) { + var users []User + if err := d.db.Find(&users).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find users") + } + return users, nil +} - SaveUser(user User) error - DeleteUser(identifier UserIdentifier) error +func (d *Database) GetUsersUnscoped() ([]User, error) { + var users []User + if err := d.db.Unscoped().Find(&users).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find unscoped users") + } + return users, nil +} + +func (d *Database) GetUsersFiltered(filters ...DatabaseFilterCondition) ([]User, error) { + var users []User + tx := d.db + for _, filter := range filters { + tx = filter(tx) + } + if err := tx.Find(&users).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find filtered users") + } + return users, nil +} + +func (d *Database) SaveUser(user User) error { + create := user.Uid == "" + now := time.Now() + + user.UpdatedAt = now + + if create { + user.CreatedAt = now + if err := d.db.Create(&user).Error; err != nil { + return errors.WithMessage(err, "unable to create new user") + } + } else { + if err := d.db.Save(&user).Error; err != nil { + return errors.WithMessagef(err, "unable to update user %s", user.Uid) + } + } + return nil +} + +func (d *Database) DeleteUser(id UserIdentifier) error { + if err := d.db.Delete(&User{}, id).Error; err != nil { + return errors.WithMessagef(err, "unable to delete user %s", id) + } + return nil } diff --git a/internal/persistence/wireguard.go b/internal/persistence/wireguard.go index 9f320f1..4a3e132 100644 --- a/internal/persistence/wireguard.go +++ b/internal/persistence/wireguard.go @@ -1,14 +1,71 @@ package persistence -type WireGuard interface { - GetAvailableInterfaces() ([]InterfaceIdentifier, error) +import ( + "github.com/pkg/errors" + "gorm.io/gorm/clause" +) - GetAllInterfaces(interfaceIdentifiers ...InterfaceIdentifier) (map[InterfaceConfig][]PeerConfig, error) - GetInterface(identifier InterfaceIdentifier) (InterfaceConfig, []PeerConfig, error) +func (d *Database) GetAvailableInterfaces() ([]InterfaceIdentifier, error) { + var interfaces []InterfaceConfig + if err := d.db.Select("identifier").Find(&interfaces).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find interfaces") + } - SaveInterface(cfg InterfaceConfig, peers []PeerConfig) error - SavePeer(peer PeerConfig, interfaceIdentifier InterfaceIdentifier) error + interfaceIds := make([]InterfaceIdentifier, len(interfaces)) + for i := range interfaces { + interfaceIds[i] = interfaces[i].Identifier + } - DeleteInterface(identifier InterfaceIdentifier) error - DeletePeer(peer PeerIdentifier, interfaceIdentifier InterfaceIdentifier) error + return interfaceIds, nil +} + +func (d *Database) GetAllInterfaces(ids ...InterfaceIdentifier) (map[InterfaceConfig][]PeerConfig, error) { + var interfaces []InterfaceConfig + if err := d.db.Where("interface IN ?", ids).Find(&interfaces).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find interfaces") + } + + interfaceMap := make(map[InterfaceConfig][]PeerConfig, len(interfaces)) + for i := range interfaces { + var peers []PeerConfig + if err := d.db.Where("interface = ?", interfaces[i].Identifier).Find(&peers).Error; err != nil { + return nil, errors.WithMessagef(err, "unable to find peers for %s", interfaces[i].Identifier) + } + interfaceMap[interfaces[i]] = peers + } + + return interfaceMap, nil +} + +func (d *Database) GetInterface(id InterfaceIdentifier) (InterfaceConfig, []PeerConfig, error) { + var iface InterfaceConfig + if err := d.db.First(&iface, id).Error; err != nil { + return InterfaceConfig{}, nil, errors.WithMessagef(err, "unable to find interface %s", id) + } + + var peers []PeerConfig + if err := d.db.Where("interface = ?", id).Find(&peers).Error; err != nil { + return InterfaceConfig{}, nil, errors.WithMessagef(err, "unable to find peers for %s", id) + } + + return iface, peers, nil +} + +func (d *Database) SaveInterface(cfg InterfaceConfig) error { + d.db.Clauses(clause.OnConflict{ + UpdateAll: true, + }).Create(&cfg) + return nil +} + +func (d *Database) SavePeer(peer PeerConfig, id InterfaceIdentifier) error { + return nil +} + +func (d *Database) DeleteInterface(id InterfaceIdentifier) error { + return nil +} + +func (d *Database) DeletePeer(peerId PeerIdentifier, id InterfaceIdentifier) error { + return nil } diff --git a/internal/user/authentication.go b/internal/user/authentication.go index b200aa5..a25bc4c 100644 --- a/internal/user/authentication.go +++ b/internal/user/authentication.go @@ -1,4 +1,56 @@ package user -type Authenticator interface { +import ( + "crypto/subtle" + + "github.com/h44z/wg-portal/internal/persistence" + "github.com/pkg/errors" + "golang.org/x/crypto/bcrypt" +) + +type PasswordAuthenticator struct { + store store +} + +func NewPasswordAuthenticator(store store) (*PasswordAuthenticator, error) { + a := &PasswordAuthenticator{ + store: store, + } + + return a, nil +} + +func (p *PasswordAuthenticator) PlaintextAuthentication(userId persistence.UserIdentifier, plainPassword string) error { + user, err := p.store.GetUser(userId) + if err != nil { + return errors.WithMessagef(err, "unable to load user %s", userId) + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(plainPassword)); err != nil { + return errors.WithMessage(err, "invalid password") + } + + return nil +} + +func (p *PasswordAuthenticator) HashedAuthentication(userId persistence.UserIdentifier, hashedPassword string) error { + user, err := p.store.GetUser(userId) + if err != nil { + return errors.WithMessagef(err, "unable to load user %s", userId) + } + + if subtle.ConstantTimeCompare([]byte(user.Password), []byte(hashedPassword)) != 1 { + return errors.New("invalid password") + } + + return nil +} + +func (p *PasswordAuthenticator) HashPassword(plain string) (string, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(plain), bcrypt.DefaultCost) + if err != nil { + return "", errors.WithMessage(err, "failed to hash password") + } + + return string(hash), nil } diff --git a/internal/user/manager.go b/internal/user/manager.go index 7ed2597..8bd2dda 100644 --- a/internal/user/manager.go +++ b/internal/user/manager.go @@ -2,8 +2,88 @@ import ( "github.com/h44z/wg-portal/internal/persistence" + "github.com/pkg/errors" ) +type Loader interface { + GetUser(id persistence.UserIdentifier) (persistence.User, error) + GetActiveUsers() ([]persistence.User, error) + GetAllUsers() ([]persistence.User, error) + GetFilteredUsers(filter ...filterCondition) ([]persistence.User, error) +} + +type Updater interface { + CreateUser(user persistence.User) error + UpdateUser(user persistence.User) error + DeleteUser(identifier persistence.UserIdentifier) error +} + +type Authenticator interface { + PlaintextAuthentication(userId persistence.UserIdentifier, plainPassword string) error + HashedAuthentication(userId persistence.UserIdentifier, hashedPassword string) error +} + +type PasswordHasher interface { + HashPassword(plain string) (string, error) +} + type Manager interface { - persistence.UsersLoader + Loader + Updater + Authenticator + PasswordHasher +} + +type PersistentManager struct { + store store + + authenticator Authenticator + hasher PasswordHasher +} + +func NewPersistentManager(store store) (*PersistentManager, error) { + if store == nil { + return nil, errors.New("user manager requires a valid store object") + } + + pwa, err := NewPasswordAuthenticator(store) + if err != nil { + return nil, errors.WithMessage(err, "failed to initialize authenticator") + } + + mgr := &PersistentManager{ + store: store, + authenticator: pwa, + hasher: pwa, + } + + return mgr, nil +} + +func (p *PersistentManager) GetUser(id persistence.UserIdentifier) (persistence.User, error) { + return p.store.GetUser(id) +} + +func (p *PersistentManager) GetActiveUsers() ([]persistence.User, error) { + return p.store.GetUsers() +} + +func (p *PersistentManager) GetAllUsers() ([]persistence.User, error) { + return p.store.GetUsersUnscoped() +} + +func (p *PersistentManager) GetFilteredUsers(filter ...filterCondition) ([]persistence.User, error) { + return p.store.GetUsersFiltered(filter...) +} + +func (p *PersistentManager) CreateUser(user persistence.User) error { + return p.store.SaveUser(user) +} + +func (p *PersistentManager) UpdateUser(user persistence.User) error { + return p.store.SaveUser(user) +} + +func (p *PersistentManager) DeleteUser(identifier persistence.UserIdentifier) error { + return p.store.DeleteUser(identifier) } diff --git a/internal/user/persistence.go b/internal/user/persistence.go new file mode 100644 index 0000000..ccc273d --- /dev/null +++ b/internal/user/persistence.go @@ -0,0 +1,14 @@ +package user + +import ( + "github.com/h44z/wg-portal/internal/persistence" +) + +type store interface { + GetUser(id persistence.UserIdentifier) (persistence.User, error) + GetUsers() ([]persistence.User, error) + GetUsersUnscoped() ([]persistence.User, error) + GetUsersFiltered(filters ...persistence.DatabaseFilterCondition) ([]persistence.User, error) + SaveUser(user persistence.User) error + DeleteUser(identifier persistence.UserIdentifier) error +} diff --git a/internal/wireguard/manager.go b/internal/wireguard/manager.go index 4b5477b..d16545e 100644 --- a/internal/wireguard/manager.go +++ b/internal/wireguard/manager.go @@ -51,6 +51,8 @@ PeerManager ImportManager ConfigFileGenerator + + ApplyDefaultConfigs(device persistence.InterfaceIdentifier) error } // @@ -82,3 +84,8 @@ return m, nil } + +func (p *PersistentManager) ApplyDefaultConfigs(device persistence.InterfaceIdentifier) error { + // TODO: implement + return nil +} diff --git a/internal/wireguard/persistence.go b/internal/wireguard/persistence.go index 0431c74..c1e1b7c 100644 --- a/internal/wireguard/persistence.go +++ b/internal/wireguard/persistence.go @@ -10,7 +10,7 @@ GetAllInterfaces(interfaceIdentifiers ...persistence.InterfaceIdentifier) (map[persistence.InterfaceConfig][]persistence.PeerConfig, error) GetInterface(identifier persistence.InterfaceIdentifier) (persistence.InterfaceConfig, []persistence.PeerConfig, error) - SaveInterface(cfg persistence.InterfaceConfig, peers []persistence.PeerConfig) error + SaveInterface(cfg persistence.InterfaceConfig) error SavePeer(peer persistence.PeerConfig, interfaceIdentifier persistence.InterfaceIdentifier) error DeleteInterface(identifier persistence.InterfaceIdentifier) error diff --git a/internal/wireguard/wireguard.go b/internal/wireguard/wireguard.go index 7c1fbe3..52cd7bf 100644 --- a/internal/wireguard/wireguard.go +++ b/internal/wireguard/wireguard.go @@ -383,17 +383,11 @@ return nil // nothing to do } - device := m.interfaces[id] - peers := make([]persistence.PeerConfig, 0, len(m.peers[id])) - for _, config := range m.peers[id] { - peers = append(peers, config) - } - var err error if delete { err = m.store.DeleteInterface(id) } else { - err = m.store.SaveInterface(device, peers) + err = m.store.SaveInterface(m.interfaces[id]) } if err != nil { return errors.Wrapf(err, "failed to persist interface")