diff --git a/assets/templates/views/register.gohtml b/assets/templates/views/register.gohtml index 44ee692..149e1de 100644 --- a/assets/templates/views/register.gohtml +++ b/assets/templates/views/register.gohtml @@ -14,15 +14,6 @@ -
- -
- - - - -
-
{{ .csrfField }}
diff --git a/handlers/auth.go b/handlers/auth.go index 51043c1..504b27c 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -1,7 +1,16 @@ package handlers import ( + "bytes" + "fmt" "net/http" + "time" + + "git.klink.asia/paul/certman/views" + + "github.com/go-chi/chi" + + "github.com/gorilla/securecookie" "git.klink.asia/paul/certman/services" @@ -12,17 +21,28 @@ func RegisterHandler(p *services.Provider) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { // Get parameters email := req.Form.Get("email") - password := req.Form.Get("password") user := models.User{} user.Email = email - user.SetPassword(password) - err := p.DB.Create(&user).Error + // don't set a password, user will get password reset request via mail + user.HashedPassword = []byte{} + + err := p.DB.CreateUser(&user) if err != nil { panic(err.Error) } + if err := createPasswordReset(p, &user); err != nil { + p.Sessions.Flash(w, req, + services.Flash{ + Type: "danger", + Message: "The registration email could not be generated.", + }, + ) + http.Redirect(w, req, "/register", http.StatusFound) + } + p.Sessions.Flash(w, req, services.Flash{ Type: "success", @@ -41,9 +61,7 @@ func LoginHandler(p *services.Provider) http.HandlerFunc { email := req.Form.Get("email") password := req.Form.Get("password") - user := models.User{} - - err := p.DB.Where(&models.User{Email: email}).Find(&user).Error + user, err := p.DB.GetUserByEmail(email) if err != nil { // could not find user p.Sessions.Flash( @@ -82,3 +100,80 @@ func LoginHandler(p *services.Provider) http.HandlerFunc { http.Redirect(w, req, "/certs", http.StatusSeeOther) } } + +func ConfirmEmailHandler(p *services.Provider) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + v := views.NewWithSession(req, p.Sessions) + + switch req.Method { + case "GET": + token := chi.URLParam(req, "token") + pwr, err := p.DB.GetPasswordResetByToken(token) + _ = pwr + if err != nil { + v.RenderError(w, 404) + return + } + v.Render(w, "email-set-password") + case "POST": + password := req.Form.Get("password") + token := req.Form.Get("token") + pwr, err := p.DB.GetPasswordResetByToken(token) + if err != nil { + v.RenderError(w, 404) + return + } + + user, err := p.DB.GetUserByID(pwr.UserID) + if err != nil { + v.RenderError(w, 500) + return + } + + user.SetPassword(password) + + //err := p.DB.UpdateUser(user.ID, &user) + if err != nil { + v.RenderError(w, 500) + return + } + + err = p.DB.DeletePasswordResetsByUserID(pwr.UserID) + + default: + v.RenderError(w, 405) + } + + // try to get post params + + fmt.Fprintln(w, "Okay.") + } +} + +func createPasswordReset(p *services.Provider, user *models.User) error { + // create the reset request + pwr := models.PasswordReset{ + UserID: user.ID, + Token: string(securecookie.GenerateRandomKey(32)), + ValidUntil: time.Now().Add(6 * time.Hour), + } + + if err := p.DB.CreatePasswordReset(&pwr); err != nil { + return err + } + + var subject string + var text *bytes.Buffer + + if user.EmailValid { + subject = "Password reset" + text.WriteString("Somebody (hopefully you) has requested a password reset.\nClick below to reset your password:\n") + } else { + // If the user email has not been confirmed yet, send out + // "mail confirmation"-mail instead + subject = "Email confirmation" + text.WriteString("Hello, thanks you for signing up!\nClick below to verify this email address\n") + } + + return p.Email.Send(user.Email, subject, text.String(), "") +} diff --git a/handlers/cert.go b/handlers/cert.go index 9c52607..5900289 100644 --- a/handlers/cert.go +++ b/handlers/cert.go @@ -31,8 +31,7 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc { email := p.Sessions.GetUserEmail(req) certname := req.FormValue("certname") - user := models.User{} - err := p.DB.Where(&models.User{Email: email}).Find(&user).Error + user, err := p.DB.GetUserByEmail(email) if err != nil { fmt.Printf("Could not fetch user for mail %s\n", email) } @@ -62,9 +61,10 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc { } // Insert client into database - if err := p.DB.Create(&client).Error; err != nil { - panic(err.Error()) - } + _ = client + //if err := p.DB.Create(&client).Error; err != nil { + // panic(err.Error()) + //} p.Sessions.Flash(w, req, services.Flash{ diff --git a/main.go b/main.go index a67e1a9..872ccea 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,7 @@ func main() { Lifetime: 24 * time.Hour, }, Email: &services.EmailConfig{ + SMTPEnabled: false, SMTPServer: "example.com", SMTPPort: 25, SMTPUsername: "test", diff --git a/models/model.go b/models/model.go new file mode 100644 index 0000000..68c2719 --- /dev/null +++ b/models/model.go @@ -0,0 +1,39 @@ +package models + +import ( + "errors" + "time" +) + +var ( + // ErrNotImplemented gets thrown if some action was not attempted, + // because it is not implemented in the code yet. + ErrNotImplemented = errors.New("Not implemented") +) + +// Model is a base model definition, including helpful fields for dealing with +// models in a database +type Model struct { + ID uint `gorm:"primary_key"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time `sql:"index"` +} + +// Client represent the OpenVPN client configuration +type Client struct { + Model + Name string + User User + UserID uint + Cert []byte + PrivateKey []byte +} + +type ClientProvider interface { + CountClients() (uint, error) + CreateClient(*User) (*User, error) + ListClients(count, offset int) ([]*User, error) + GetClientByID(id uint) (*User, error) + DeleteClient(id uint) error +} diff --git a/models/models.go b/models/user.go similarity index 56% rename from models/models.go rename to models/user.go index 00764fe..467fd15 100644 --- a/models/models.go +++ b/models/user.go @@ -1,27 +1,11 @@ package models import ( - "errors" "time" "golang.org/x/crypto/bcrypt" ) -var ( - // ErrNotImplemented gets thrown if some action was not attempted, - // because it is not implemented in the code yet. - ErrNotImplemented = errors.New("Not implemented") -) - -// Model is a base model definition, including helpful fields for dealing with -// models in a database -type Model struct { - ID uint `gorm:"primary_key"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt *time.Time `sql:"index"` -} - // User represents a User of the system which is able to log in type User struct { Model @@ -50,27 +34,22 @@ func (u *User) CheckPassword(password string) error { type UserProvider interface { CountUsers() (uint, error) - CreateUser(*User) (*User, error) + CreateUser(*User) error ListUsers(count, offset int) ([]*User, error) GetUserByID(id uint) (*User, error) GetUserByEmail(email string) (*User, error) DeleteUser(id uint) error } -// Client represent the OpenVPN client configuration -type Client struct { +type PasswordReset struct { Model - Name string - User User + User *User UserID uint - Cert []byte - PrivateKey []byte + Token string + ValidUntil time.Time } -type ClientProvider interface { - CountClients() (uint, error) - CreateClient(*User) (*User, error) - ListClients(count, offset int) ([]*User, error) - GetClientByID(id uint) (*User, error) - DeleteClient(id uint) error +type PasswordResetProvider interface { + CreatePasswordReset(*PasswordReset) error + GetPasswordResetByToken(token string) (*PasswordReset, error) } diff --git a/router/router.go b/router/router.go index 9ea982b..c9c1a64 100644 --- a/router/router.go +++ b/router/router.go @@ -63,7 +63,7 @@ func HandleRoutes(provider *services.Provider) http.Handler { r.Post("/", handlers.LoginHandler(provider)) }) - //r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db)) + r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(provider)) r.Route("/forgot-password", func(r chi.Router) { r.Get("/", v("forgot-password")) diff --git a/services/db.go b/services/db.go index e36d892..b0c26bd 100644 --- a/services/db.go +++ b/services/db.go @@ -21,7 +21,7 @@ type DBConfig struct { // DB is a wrapper around gorm.DB to provide custom methods type DB struct { - *gorm.DB + gorm *gorm.DB conf *DBConfig } @@ -38,39 +38,69 @@ func NewDB(conf *DBConfig) *DB { db.LogMode(conf.Log) return &DB{ - DB: db, + gorm: db, conf: conf, } } // CountUsers returns the number of Users in the datastore func (db *DB) CountUsers() (uint, error) { - return 0, ErrNotImplemented + var count uint + err := db.gorm.Find(&models.User{}).Count(&count).Error + return count, err } // CreateUser inserts a user into the datastore -func (db *DB) CreateUser(*models.User) (*models.User, error) { - return nil, ErrNotImplemented +func (db *DB) CreateUser(user *models.User) error { + err := db.gorm.Create(&user).Error + return err } // ListUsers returns a slice of 'count' users, starting at 'offset' func (db *DB) ListUsers(count, offset int) ([]*models.User, error) { var users = make([]*models.User, 0) - return users, ErrNotImplemented + err := db.gorm.Find(&users).Limit(count).Offset(offset).Error + + return users, err } // GetUserByID returns a single user by ID func (db *DB) GetUserByID(id uint) (*models.User, error) { - return nil, ErrNotImplemented + var user models.User + err := db.gorm.Where("id = ?", id).First(&user).Error + return &user, err } // GetUserByEmail returns a single user by email func (db *DB) GetUserByEmail(email string) (*models.User, error) { - return nil, ErrNotImplemented + var user models.User + err := db.gorm.Where("email = ?", email).First(&user).Error + return &user, err } // DeleteUser removes a user from the datastore func (db *DB) DeleteUser(id uint) error { - return ErrNotImplemented + var user models.User + err := db.gorm.Where("id = ?", id).Delete(&user).Error + return err +} + +// CreatePasswordReset creates a new password reset token +func (db *DB) CreatePasswordReset(pwReset *models.PasswordReset) error { + err := db.gorm.Create(&pwReset).Error + return err +} + +// GetPasswordResetByToken retrieves a PasswordReset by token +func (db *DB) GetPasswordResetByToken(token string) (*models.PasswordReset, error) { + var pwReset models.PasswordReset + err := db.gorm.Where("token = ?", token).First(&pwReset).Error + return &pwReset, err +} + +// DeletePasswordResetsByUserID deletes all pending password resets for a user +func (db *DB) DeletePasswordResetsByUserID(uid uint) error { + err := db.gorm.Where("user_id = ?", uid).Delete(&models.PasswordReset{}).Error + return err } diff --git a/services/email.go b/services/email.go index a0d9cca..baea8d9 100644 --- a/services/email.go +++ b/services/email.go @@ -14,6 +14,7 @@ var ( type EmailConfig struct { From string + SMTPEnabled bool SMTPServer string SMTPPort int SMTPUsername string @@ -45,12 +46,19 @@ func (email *Email) Send(to, subject, text, html string) error { return ErrMailUninitializedConfig } + if !email.config.SMTPEnabled { + log.Printf("SMTP is disabled in config, printing out email text instead:\nTo: %s\n%s", to, text) + } + m := mail.NewMessage() m.SetHeader("From", email.config.From) m.SetHeader("To", to) m.SetHeader("Subject", subject) m.SetBody("text/plain", text) - m.AddAlternative("text/html", html) + + if len(html) > 0 { + m.AddAlternative("text/html", html) + } // put email in chan email.mailChan <- m @@ -65,6 +73,11 @@ func (email *Email) Daemon() { return } + if !email.config.SMTPEnabled { + log.Print("SMTP is disabled in config, emails will be printed instead.") + return + } + log.Print("Running mail sending routine") d := mail.NewDialer(