diff --git a/assets/templates/views/register.gohtml b/assets/templates/views/register.gohtml
index 44ee692..149e1de 100644
--- a/assets/templates/views/register.gohtml
+++ b/assets/templates/views/register.gohtml
@@ -14,15 +14,6 @@
-
{{ .csrfField }}
diff --git a/handlers/auth.go b/handlers/auth.go
index 51043c1..504b27c 100644
--- a/handlers/auth.go
+++ b/handlers/auth.go
@@ -1,7 +1,16 @@
package handlers
import (
+ "bytes"
+ "fmt"
"net/http"
+ "time"
+
+ "git.klink.asia/paul/certman/views"
+
+ "github.com/go-chi/chi"
+
+ "github.com/gorilla/securecookie"
"git.klink.asia/paul/certman/services"
@@ -12,17 +21,28 @@ func RegisterHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
// Get parameters
email := req.Form.Get("email")
- password := req.Form.Get("password")
user := models.User{}
user.Email = email
- user.SetPassword(password)
- err := p.DB.Create(&user).Error
+ // don't set a password, user will get password reset request via mail
+ user.HashedPassword = []byte{}
+
+ err := p.DB.CreateUser(&user)
if err != nil {
panic(err.Error)
}
+ if err := createPasswordReset(p, &user); err != nil {
+ p.Sessions.Flash(w, req,
+ services.Flash{
+ Type: "danger",
+ Message: "The registration email could not be generated.",
+ },
+ )
+ http.Redirect(w, req, "/register", http.StatusFound)
+ }
+
p.Sessions.Flash(w, req,
services.Flash{
Type: "success",
@@ -41,9 +61,7 @@ func LoginHandler(p *services.Provider) http.HandlerFunc {
email := req.Form.Get("email")
password := req.Form.Get("password")
- user := models.User{}
-
- err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
+ user, err := p.DB.GetUserByEmail(email)
if err != nil {
// could not find user
p.Sessions.Flash(
@@ -82,3 +100,80 @@ func LoginHandler(p *services.Provider) http.HandlerFunc {
http.Redirect(w, req, "/certs", http.StatusSeeOther)
}
}
+
+func ConfirmEmailHandler(p *services.Provider) http.HandlerFunc {
+ return func(w http.ResponseWriter, req *http.Request) {
+ v := views.NewWithSession(req, p.Sessions)
+
+ switch req.Method {
+ case "GET":
+ token := chi.URLParam(req, "token")
+ pwr, err := p.DB.GetPasswordResetByToken(token)
+ _ = pwr
+ if err != nil {
+ v.RenderError(w, 404)
+ return
+ }
+ v.Render(w, "email-set-password")
+ case "POST":
+ password := req.Form.Get("password")
+ token := req.Form.Get("token")
+ pwr, err := p.DB.GetPasswordResetByToken(token)
+ if err != nil {
+ v.RenderError(w, 404)
+ return
+ }
+
+ user, err := p.DB.GetUserByID(pwr.UserID)
+ if err != nil {
+ v.RenderError(w, 500)
+ return
+ }
+
+ user.SetPassword(password)
+
+ //err := p.DB.UpdateUser(user.ID, &user)
+ if err != nil {
+ v.RenderError(w, 500)
+ return
+ }
+
+ err = p.DB.DeletePasswordResetsByUserID(pwr.UserID)
+
+ default:
+ v.RenderError(w, 405)
+ }
+
+ // try to get post params
+
+ fmt.Fprintln(w, "Okay.")
+ }
+}
+
+func createPasswordReset(p *services.Provider, user *models.User) error {
+ // create the reset request
+ pwr := models.PasswordReset{
+ UserID: user.ID,
+ Token: string(securecookie.GenerateRandomKey(32)),
+ ValidUntil: time.Now().Add(6 * time.Hour),
+ }
+
+ if err := p.DB.CreatePasswordReset(&pwr); err != nil {
+ return err
+ }
+
+ var subject string
+ var text *bytes.Buffer
+
+ if user.EmailValid {
+ subject = "Password reset"
+ text.WriteString("Somebody (hopefully you) has requested a password reset.\nClick below to reset your password:\n")
+ } else {
+ // If the user email has not been confirmed yet, send out
+ // "mail confirmation"-mail instead
+ subject = "Email confirmation"
+ text.WriteString("Hello, thanks you for signing up!\nClick below to verify this email address\n")
+ }
+
+ return p.Email.Send(user.Email, subject, text.String(), "")
+}
diff --git a/handlers/cert.go b/handlers/cert.go
index 9c52607..5900289 100644
--- a/handlers/cert.go
+++ b/handlers/cert.go
@@ -31,8 +31,7 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc {
email := p.Sessions.GetUserEmail(req)
certname := req.FormValue("certname")
- user := models.User{}
- err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
+ user, err := p.DB.GetUserByEmail(email)
if err != nil {
fmt.Printf("Could not fetch user for mail %s\n", email)
}
@@ -62,9 +61,10 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc {
}
// Insert client into database
- if err := p.DB.Create(&client).Error; err != nil {
- panic(err.Error())
- }
+ _ = client
+ //if err := p.DB.Create(&client).Error; err != nil {
+ // panic(err.Error())
+ //}
p.Sessions.Flash(w, req,
services.Flash{
diff --git a/main.go b/main.go
index a67e1a9..872ccea 100644
--- a/main.go
+++ b/main.go
@@ -30,6 +30,7 @@ func main() {
Lifetime: 24 * time.Hour,
},
Email: &services.EmailConfig{
+ SMTPEnabled: false,
SMTPServer: "example.com",
SMTPPort: 25,
SMTPUsername: "test",
diff --git a/models/model.go b/models/model.go
new file mode 100644
index 0000000..68c2719
--- /dev/null
+++ b/models/model.go
@@ -0,0 +1,39 @@
+package models
+
+import (
+ "errors"
+ "time"
+)
+
+var (
+ // ErrNotImplemented gets thrown if some action was not attempted,
+ // because it is not implemented in the code yet.
+ ErrNotImplemented = errors.New("Not implemented")
+)
+
+// Model is a base model definition, including helpful fields for dealing with
+// models in a database
+type Model struct {
+ ID uint `gorm:"primary_key"`
+ CreatedAt time.Time
+ UpdatedAt time.Time
+ DeletedAt *time.Time `sql:"index"`
+}
+
+// Client represent the OpenVPN client configuration
+type Client struct {
+ Model
+ Name string
+ User User
+ UserID uint
+ Cert []byte
+ PrivateKey []byte
+}
+
+type ClientProvider interface {
+ CountClients() (uint, error)
+ CreateClient(*User) (*User, error)
+ ListClients(count, offset int) ([]*User, error)
+ GetClientByID(id uint) (*User, error)
+ DeleteClient(id uint) error
+}
diff --git a/models/models.go b/models/user.go
similarity index 56%
rename from models/models.go
rename to models/user.go
index 00764fe..467fd15 100644
--- a/models/models.go
+++ b/models/user.go
@@ -1,27 +1,11 @@
package models
import (
- "errors"
"time"
"golang.org/x/crypto/bcrypt"
)
-var (
- // ErrNotImplemented gets thrown if some action was not attempted,
- // because it is not implemented in the code yet.
- ErrNotImplemented = errors.New("Not implemented")
-)
-
-// Model is a base model definition, including helpful fields for dealing with
-// models in a database
-type Model struct {
- ID uint `gorm:"primary_key"`
- CreatedAt time.Time
- UpdatedAt time.Time
- DeletedAt *time.Time `sql:"index"`
-}
-
// User represents a User of the system which is able to log in
type User struct {
Model
@@ -50,27 +34,22 @@ func (u *User) CheckPassword(password string) error {
type UserProvider interface {
CountUsers() (uint, error)
- CreateUser(*User) (*User, error)
+ CreateUser(*User) error
ListUsers(count, offset int) ([]*User, error)
GetUserByID(id uint) (*User, error)
GetUserByEmail(email string) (*User, error)
DeleteUser(id uint) error
}
-// Client represent the OpenVPN client configuration
-type Client struct {
+type PasswordReset struct {
Model
- Name string
- User User
+ User *User
UserID uint
- Cert []byte
- PrivateKey []byte
+ Token string
+ ValidUntil time.Time
}
-type ClientProvider interface {
- CountClients() (uint, error)
- CreateClient(*User) (*User, error)
- ListClients(count, offset int) ([]*User, error)
- GetClientByID(id uint) (*User, error)
- DeleteClient(id uint) error
+type PasswordResetProvider interface {
+ CreatePasswordReset(*PasswordReset) error
+ GetPasswordResetByToken(token string) (*PasswordReset, error)
}
diff --git a/router/router.go b/router/router.go
index 9ea982b..c9c1a64 100644
--- a/router/router.go
+++ b/router/router.go
@@ -63,7 +63,7 @@ func HandleRoutes(provider *services.Provider) http.Handler {
r.Post("/", handlers.LoginHandler(provider))
})
- //r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db))
+ r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(provider))
r.Route("/forgot-password", func(r chi.Router) {
r.Get("/", v("forgot-password"))
diff --git a/services/db.go b/services/db.go
index e36d892..b0c26bd 100644
--- a/services/db.go
+++ b/services/db.go
@@ -21,7 +21,7 @@ type DBConfig struct {
// DB is a wrapper around gorm.DB to provide custom methods
type DB struct {
- *gorm.DB
+ gorm *gorm.DB
conf *DBConfig
}
@@ -38,39 +38,69 @@ func NewDB(conf *DBConfig) *DB {
db.LogMode(conf.Log)
return &DB{
- DB: db,
+ gorm: db,
conf: conf,
}
}
// CountUsers returns the number of Users in the datastore
func (db *DB) CountUsers() (uint, error) {
- return 0, ErrNotImplemented
+ var count uint
+ err := db.gorm.Find(&models.User{}).Count(&count).Error
+ return count, err
}
// CreateUser inserts a user into the datastore
-func (db *DB) CreateUser(*models.User) (*models.User, error) {
- return nil, ErrNotImplemented
+func (db *DB) CreateUser(user *models.User) error {
+ err := db.gorm.Create(&user).Error
+ return err
}
// ListUsers returns a slice of 'count' users, starting at 'offset'
func (db *DB) ListUsers(count, offset int) ([]*models.User, error) {
var users = make([]*models.User, 0)
- return users, ErrNotImplemented
+ err := db.gorm.Find(&users).Limit(count).Offset(offset).Error
+
+ return users, err
}
// GetUserByID returns a single user by ID
func (db *DB) GetUserByID(id uint) (*models.User, error) {
- return nil, ErrNotImplemented
+ var user models.User
+ err := db.gorm.Where("id = ?", id).First(&user).Error
+ return &user, err
}
// GetUserByEmail returns a single user by email
func (db *DB) GetUserByEmail(email string) (*models.User, error) {
- return nil, ErrNotImplemented
+ var user models.User
+ err := db.gorm.Where("email = ?", email).First(&user).Error
+ return &user, err
}
// DeleteUser removes a user from the datastore
func (db *DB) DeleteUser(id uint) error {
- return ErrNotImplemented
+ var user models.User
+ err := db.gorm.Where("id = ?", id).Delete(&user).Error
+ return err
+}
+
+// CreatePasswordReset creates a new password reset token
+func (db *DB) CreatePasswordReset(pwReset *models.PasswordReset) error {
+ err := db.gorm.Create(&pwReset).Error
+ return err
+}
+
+// GetPasswordResetByToken retrieves a PasswordReset by token
+func (db *DB) GetPasswordResetByToken(token string) (*models.PasswordReset, error) {
+ var pwReset models.PasswordReset
+ err := db.gorm.Where("token = ?", token).First(&pwReset).Error
+ return &pwReset, err
+}
+
+// DeletePasswordResetsByUserID deletes all pending password resets for a user
+func (db *DB) DeletePasswordResetsByUserID(uid uint) error {
+ err := db.gorm.Where("user_id = ?", uid).Delete(&models.PasswordReset{}).Error
+ return err
}
diff --git a/services/email.go b/services/email.go
index a0d9cca..baea8d9 100644
--- a/services/email.go
+++ b/services/email.go
@@ -14,6 +14,7 @@ var (
type EmailConfig struct {
From string
+ SMTPEnabled bool
SMTPServer string
SMTPPort int
SMTPUsername string
@@ -45,12 +46,19 @@ func (email *Email) Send(to, subject, text, html string) error {
return ErrMailUninitializedConfig
}
+ if !email.config.SMTPEnabled {
+ log.Printf("SMTP is disabled in config, printing out email text instead:\nTo: %s\n%s", to, text)
+ }
+
m := mail.NewMessage()
m.SetHeader("From", email.config.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", text)
- m.AddAlternative("text/html", html)
+
+ if len(html) > 0 {
+ m.AddAlternative("text/html", html)
+ }
// put email in chan
email.mailChan <- m
@@ -65,6 +73,11 @@ func (email *Email) Daemon() {
return
}
+ if !email.config.SMTPEnabled {
+ log.Print("SMTP is disabled in config, emails will be printed instead.")
+ return
+ }
+
log.Print("Running mail sending routine")
d := mail.NewDialer(