Refactor global variables into injected dependencies

This commit is contained in:
Paul 2018-01-29 16:52:59 +01:00
parent e5cd6e0be3
commit 4a79379d8e
10 changed files with 366 additions and 182 deletions

View file

@ -8,73 +8,77 @@ import (
"git.klink.asia/paul/certman/models" "git.klink.asia/paul/certman/models"
) )
func RegisterHandler(w http.ResponseWriter, req *http.Request) { func RegisterHandler(p *services.Provider) http.HandlerFunc {
// Get parameters return func(w http.ResponseWriter, req *http.Request) {
email := req.Form.Get("email") // Get parameters
password := req.Form.Get("password") email := req.Form.Get("email")
password := req.Form.Get("password")
user := models.User{} user := models.User{}
user.Email = email user.Email = email
user.SetPassword(password) user.SetPassword(password)
err := services.Database.Create(&user).Error err := p.DB.Create(&user).Error
if err != nil { if err != nil {
panic(err.Error) panic(err.Error)
}
p.Sessions.Flash(w, req,
services.Flash{
Type: "success",
Message: "The user was created. Check your inbox for the confirmation email.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
} }
services.SessionStore.Flash(w, req,
services.Flash{
Type: "success",
Message: "The user was created. Check your inbox for the confirmation email.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
} }
func LoginHandler(w http.ResponseWriter, req *http.Request) { func LoginHandler(p *services.Provider) http.HandlerFunc {
// Get parameters return func(w http.ResponseWriter, req *http.Request) {
email := req.Form.Get("email") // Get parameters
password := req.Form.Get("password") email := req.Form.Get("email")
password := req.Form.Get("password")
user := models.User{} user := models.User{}
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
if err != nil { if err != nil {
// could not find user // could not find user
services.SessionStore.Flash( p.Sessions.Flash(
w, req, services.Flash{ w, req, services.Flash{
Type: "warning", Message: "Invalid Email or Password.", Type: "warning", Message: "Invalid Email or Password.",
}, },
) )
http.Redirect(w, req, "/login", http.StatusFound) http.Redirect(w, req, "/login", http.StatusFound)
return return
}
if !user.EmailValid {
p.Sessions.Flash(
w, req, services.Flash{
Type: "warning", Message: "You need to confirm your email before logging in.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
if err := user.CheckPassword(password); err != nil {
// wrong password
p.Sessions.Flash(
w, req, services.Flash{
Type: "warning", Message: "Invalid Email or Password.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
// user is logged in, set cookie
p.Sessions.SetUserEmail(w, req, email)
http.Redirect(w, req, "/certs", http.StatusSeeOther)
} }
if !user.EmailValid {
services.SessionStore.Flash(
w, req, services.Flash{
Type: "warning", Message: "You need to confirm your email before logging in.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
if err := user.CheckPassword(password); err != nil {
// wrong password
services.SessionStore.Flash(
w, req, services.Flash{
Type: "warning", Message: "Invalid Email or Password.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
// user is logged in, set cookie
services.SessionStore.SetUserEmail(w, req, email)
http.Redirect(w, req, "/certs", http.StatusSeeOther)
} }

View file

@ -19,69 +19,75 @@ import (
"git.klink.asia/paul/certman/views" "git.klink.asia/paul/certman/views"
) )
func ListCertHandler(w http.ResponseWriter, req *http.Request) { func ListCertHandler(p *services.Provider) http.HandlerFunc {
v := views.New(req) return func(w http.ResponseWriter, req *http.Request) {
v.Render(w, "cert_list") v := views.NewWithSession(req, p.Sessions)
v.Render(w, "cert_list")
}
} }
func CreateCertHandler(w http.ResponseWriter, req *http.Request) { func CreateCertHandler(p *services.Provider) http.HandlerFunc {
email := services.SessionStore.GetUserEmail(req) return func(w http.ResponseWriter, req *http.Request) {
certname := req.FormValue("certname") email := p.Sessions.GetUserEmail(req)
certname := req.FormValue("certname")
user := models.User{} user := models.User{}
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
if err != nil { if err != nil {
fmt.Printf("Could not fetch user for mail %s\n", email) fmt.Printf("Could not fetch user for mail %s\n", email)
}
// Load CA master certificate
caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key")
if err != nil {
log.Fatalf("error loading ca keyfiles: %s", err)
panic(err.Error())
}
// Generate Keypair
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatalf("Could not generate keypair: %s", err)
}
// Generate Certificate
derBytes, err := CreateCertificate(key, caCert, caKey)
// Initialize new client config
client := models.Client{
Name: certname,
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Cert: derBytes,
UserID: user.ID,
}
// Insert client into database
if err := p.DB.Create(&client).Error; err != nil {
panic(err.Error())
}
p.Sessions.Flash(w, req,
services.Flash{
Type: "success",
Message: "The certificate was created successfully.",
},
)
http.Redirect(w, req, "/certs", http.StatusFound)
} }
// Load CA master certificate
caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key")
if err != nil {
log.Fatalf("error loading ca keyfiles: %s", err)
panic(err.Error())
}
// Generate Keypair
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatalf("Could not generate keypair: %s", err)
}
// Generate Certificate
derBytes, err := CreateCertificate(key, caCert, caKey)
// Initialize new client config
client := models.Client{
Name: certname,
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Cert: derBytes,
UserID: user.ID,
}
// Insert client into database
if err := services.Database.Create(&client).Error; err != nil {
panic(err.Error())
}
services.SessionStore.Flash(w, req,
services.Flash{
Type: "success",
Message: "The certificate was created successfully.",
},
)
http.Redirect(w, req, "/certs", http.StatusFound)
} }
func DownloadCertHandler(w http.ResponseWriter, req *http.Request) { func DownloadCertHandler(p *services.Provider) http.HandlerFunc {
//v := views.New(req) return func(w http.ResponseWriter, req *http.Request) {
// //v := views.New(req)
//derBytes, err := CreateCertificate(key, caCert, caKey) //
//pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) //derBytes, err := CreateCertificate(key, caCert, caKey)
// //pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
//pkBytes := x509.MarshalPKCS1PrivateKey(key) //
//pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes}) //pkBytes := x509.MarshalPKCS1PrivateKey(key)
return //pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes})
return
}
} }
func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) { func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {

34
main.go
View file

@ -3,22 +3,46 @@ package main
import ( import (
"log" "log"
"net/http" "net/http"
"time"
"github.com/gorilla/securecookie"
"git.klink.asia/paul/certman/services" "git.klink.asia/paul/certman/services"
"git.klink.asia/paul/certman/router" "git.klink.asia/paul/certman/router"
"git.klink.asia/paul/certman/views" "git.klink.asia/paul/certman/views"
// import sqlite3 driver once // import sqlite3 driver
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
func main() { func main() {
c := services.Config{
DB: &services.DBConfig{
Type: "sqlite3",
DSN: "db.sqlite3",
Log: true,
},
Sessions: &services.SessionsConfig{
SessionName: "_session",
CookieKey: string(securecookie.GenerateRandomKey(32)),
HttpOnly: true,
Lifetime: 24 * time.Hour,
},
Email: &services.EmailConfig{
SMTPServer: "example.com",
SMTPPort: 25,
SMTPUsername: "test",
SMTPPassword: "test",
From: "Mailtest <test@example.com>",
},
}
// Connect to the database serviceProvider := services.NewProvider(&c)
db := services.InitDB()
services.InitSession() // Start the mail daemon, which re-uses connections to send mails to the
// SMTP server
go serviceProvider.Email.Daemon()
//user := models.User{} //user := models.User{}
//user.Username = "test" //user.Username = "test"
@ -29,7 +53,7 @@ func main() {
// load and parse template files // load and parse template files
views.LoadTemplates() views.LoadTemplates()
mux := router.HandleRoutes(db) mux := router.HandleRoutes(serviceProvider)
err := http.ListenAndServe(":8000", mux) err := http.ListenAndServe(":8000", mux)
log.Fatalf(err.Error()) log.Fatalf(err.Error())

View file

@ -8,13 +8,15 @@ import (
// RequireLogin is a middleware that checks for a username in the active // RequireLogin is a middleware that checks for a username in the active
// session, and redirects to `/login` if no username was found. // session, and redirects to `/login` if no username was found.
func RequireLogin(next http.Handler) http.Handler { func RequireLogin(sessions *services.Sessions) func(http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) { return func(next http.Handler) http.Handler {
if username := services.SessionStore.GetUserEmail(req); username == "" { fn := func(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, "/login", http.StatusFound) if username := sessions.GetUserEmail(req); username == "" {
} http.Redirect(w, req, "/login", http.StatusFound)
}
next.ServeHTTP(w, req) next.ServeHTTP(w, req)
}
return http.HandlerFunc(fn)
} }
return http.HandlerFunc(fn)
} }

View file

@ -13,7 +13,6 @@ import (
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"github.com/jinzhu/gorm"
mw "git.klink.asia/paul/certman/middleware" mw "git.klink.asia/paul/certman/middleware"
) )
@ -26,15 +25,15 @@ var (
cookieKey = []byte("osx70sMD8HZG2ouUl8uKI4wcMugiJ2WH") cookieKey = []byte("osx70sMD8HZG2ouUl8uKI4wcMugiJ2WH")
) )
func HandleRoutes(db *gorm.DB) http.Handler { func HandleRoutes(provider *services.Provider) http.Handler {
mux := chi.NewMux() mux := chi.NewMux()
//mux.Use(middleware.RequestID) //mux.Use(middleware.RequestID)
mux.Use(middleware.Logger) // log requests mux.Use(middleware.Logger) // log requests
mux.Use(middleware.RealIP) // use proxy headers mux.Use(middleware.RealIP) // use proxy headers
mux.Use(middleware.RedirectSlashes) // redirect trailing slashes mux.Use(middleware.RedirectSlashes) // redirect trailing slashes
mux.Use(mw.Recoverer) // recover on panic mux.Use(mw.Recoverer) // recover on panic
mux.Use(services.SessionStore.Use) // use session storage mux.Use(provider.Sessions.Manager.Use) // use session storage
// we are serving the static files directly from the assets package // we are serving the static files directly from the assets package
// this either means we use the embedded files, or live-load // this either means we use the embedded files, or live-load
@ -56,26 +55,26 @@ func HandleRoutes(db *gorm.DB) http.Handler {
r.Route("/register", func(r chi.Router) { r.Route("/register", func(r chi.Router) {
r.Get("/", v("register")) r.Get("/", v("register"))
r.Post("/", handlers.RegisterHandler) r.Post("/", handlers.RegisterHandler(provider))
}) })
r.Route("/login", func(r chi.Router) { r.Route("/login", func(r chi.Router) {
r.Get("/", v("login")) r.Get("/", v("login"))
r.Post("/", handlers.LoginHandler) r.Post("/", handlers.LoginHandler(provider))
}) })
//r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db)) //r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db))
r.Route("/forgot-password", func(r chi.Router) { r.Route("/forgot-password", func(r chi.Router) {
r.Get("/", v("forgot-password")) r.Get("/", v("forgot-password"))
r.Post("/", handlers.LoginHandler) r.Post("/", handlers.LoginHandler(provider))
}) })
r.Route("/certs", func(r chi.Router) { r.Route("/certs", func(r chi.Router) {
r.Use(mw.RequireLogin) r.Use(mw.RequireLogin(provider.Sessions))
r.Get("/", handlers.ListCertHandler) r.Get("/", handlers.ListCertHandler(provider))
r.Post("/new", handlers.CreateCertHandler) r.Post("/new", handlers.CreateCertHandler(provider))
r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler) r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler(provider))
}) })
r.HandleFunc("/500", func(w http.ResponseWriter, req *http.Request) { r.HandleFunc("/500", func(w http.ResponseWriter, req *http.Request) {

View file

@ -5,7 +5,6 @@ import (
"log" "log"
"git.klink.asia/paul/certman/models" "git.klink.asia/paul/certman/models"
"git.klink.asia/paul/certman/settings"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
) )
@ -14,28 +13,34 @@ var (
ErrNotImplemented = errors.New("Not implemented") ErrNotImplemented = errors.New("Not implemented")
) )
var Database *gorm.DB type DBConfig struct {
Type string
DSN string
Log bool
}
// DB is a wrapper around gorm.DB to provide custom methods // DB is a wrapper around gorm.DB to provide custom methods
type DB struct { type DB struct {
*gorm.DB *gorm.DB
conf *DBConfig
} }
func InitDB() *gorm.DB { func NewDB(conf *DBConfig) *DB {
dsn := settings.Get("DATABASE_URL", "db.sqlite3")
// Establish connection // Establish connection
db, err := gorm.Open("sqlite3", dsn) db, err := gorm.Open(conf.Type, conf.DSN)
if err != nil { if err != nil {
log.Fatalf("Could not open database: %s", err.Error()) log.Fatalf("Could not open database: %s", err.Error())
} }
// Migrate models // Migrate models
db.AutoMigrate(models.User{}, models.Client{}) db.AutoMigrate(models.User{}, models.Client{})
db.LogMode(true) db.LogMode(conf.Log)
Database = db return &DB{
return db DB: db,
conf: conf,
}
} }
// CountUsers returns the number of Users in the datastore // CountUsers returns the number of Users in the datastore

103
services/email.go Normal file
View file

@ -0,0 +1,103 @@
package services
import (
"errors"
"log"
"time"
"github.com/go-mail/mail"
)
var (
ErrMailUninitializedConfig = errors.New("Mail: uninitialized config")
)
type EmailConfig struct {
From string
SMTPServer string
SMTPPort int
SMTPUsername string
SMTPPassword string
}
type Email struct {
config *EmailConfig
mailChan chan *mail.Message
}
func NewEmail(conf *EmailConfig) *Email {
if conf == nil {
log.Println(ErrMailUninitializedConfig)
}
return &Email{
config: conf,
mailChan: make(chan *mail.Message, 0),
}
}
// Send sends an email to the receiver
func (email *Email) Send(to, subject, text, html string) error {
if email.config == nil {
log.Print("Error: trying to send mail with uninitialized config.")
return ErrMailUninitializedConfig
}
m := mail.NewMessage()
m.SetHeader("From", email.config.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", text)
m.AddAlternative("text/html", html)
// put email in chan
email.mailChan <- m
return nil
}
// Daemon is a function that takes Mail and sends it without blocking.
// WIP
func (email *Email) Daemon() {
if email.config == nil {
log.Print("Error: trying to set up mail deamon with uninitialized config.")
return
}
d := mail.NewDialer(
email.config.SMTPServer,
email.config.SMTPPort,
email.config.SMTPUsername,
email.config.SMTPPassword)
var s mail.SendCloser
var err error
open := false
for {
select {
case m, ok := <-email.mailChan:
if !ok {
// channel is closed
return
}
if !open {
if s, err = d.Dial(); err != nil {
log.Print(err)
return
}
open = true
}
if err := mail.Send(s, m); err != nil {
log.Print(err)
}
// Close the connection if no email was sent in the last 30 seconds.
case <-time.After(30 * time.Second):
if open {
if err := s.Close(); err != nil {
panic(err)
}
open = false
}
}
}
}

24
services/provider.go Normal file
View file

@ -0,0 +1,24 @@
package services
type Config struct {
DB *DBConfig
Sessions *SessionsConfig
Email *EmailConfig
}
type Provider struct {
DB *DB
Sessions *Sessions
Email *Email
}
// NewProvider returns the ServiceProvider
func NewProvider(conf *Config) *Provider {
var provider = &Provider{}
provider.DB = NewDB(conf.DB)
provider.Sessions = NewSessions(conf.Sessions)
provider.Email = NewEmail(conf.Email)
return provider
}

View file

@ -8,16 +8,10 @@ import (
"net/http" "net/http"
"time" "time"
"git.klink.asia/paul/certman/settings"
"github.com/alexedwards/scs" "github.com/alexedwards/scs"
"github.com/gorilla/securecookie"
) )
var ( var (
// SessionName is the name of the session cookie
SessionName = "session"
// CookieKey is the key the cookies are encrypted and signed with
CookieKey = string(securecookie.GenerateRandomKey(32))
// FlashesKey is the key used for the flashes in the cookie // FlashesKey is the key used for the flashes in the cookie
FlashesKey = "_flashes" FlashesKey = "_flashes"
// UserEmailKey is the key used to reference usernames // UserEmailKey is the key used to reference usernames
@ -29,30 +23,33 @@ func init() {
gob.Register(Flash{}) gob.Register(Flash{})
} }
// SessionStore is a globally accessible sessions store for the application type SessionsConfig struct {
var SessionStore *Store SessionName string
CookieKey string
HttpOnly bool
Secure bool
Lifetime time.Duration
}
// Store is a wrapped scs.Store in order to implement custom // Sessions is a wrapped scs.Store in order to implement custom logic
// logic type Sessions struct {
type Store struct {
*scs.Manager *scs.Manager
} }
// InitSession populates the default sessions Store // NewSessions populates the default sessions Store
func InitSession() { func NewSessions(conf *SessionsConfig) *Sessions {
store := scs.NewCookieManager( store := scs.NewCookieManager(
CookieKey, conf.CookieKey,
) )
store.Name(conf.SessionName)
store.HttpOnly(true) store.HttpOnly(true)
store.Lifetime(24 * time.Hour) store.Lifetime(conf.Lifetime)
store.Secure(conf.Secure)
// Use secure cookies (HTTPS only) in production return &Sessions{store}
store.Secure(settings.Get("ENVIRONMENT", "") == "production")
SessionStore = &Store{store}
} }
func (store *Store) GetUserEmail(req *http.Request) string { func (store *Sessions) GetUserEmail(req *http.Request) string {
if store == nil { if store == nil {
// if store was not initialized, all requests fail // if store was not initialized, all requests fail
log.Println("Zero pointer when checking session for username") log.Println("Zero pointer when checking session for username")
@ -72,7 +69,7 @@ func (store *Store) GetUserEmail(req *http.Request) string {
return email return email
} }
func (store *Store) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) { func (store *Sessions) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) {
if store == nil { if store == nil {
// if store was not initialized, do nothing // if store was not initialized, do nothing
return return
@ -103,7 +100,7 @@ func (flash Flash) Render() template.HTML {
} }
// Flash add flash message to session data // Flash add flash message to session data
func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error { func (store *Sessions) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error {
var flashes []Flash var flashes []Flash
sess := store.Load(req) sess := store.Load(req)
@ -118,7 +115,7 @@ func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash)
} }
// Flashes returns a slice of flash messages from session data // Flashes returns a slice of flash messages from session data
func (store *Store) Flashes(w http.ResponseWriter, req *http.Request) []Flash { func (store *Sessions) Flashes(w http.ResponseWriter, req *http.Request) []Flash {
var flashes []Flash var flashes []Flash
sess := store.Load(req) sess := store.Load(req)
sess.PopObject(w, FlashesKey, &flashes) sess.PopObject(w, FlashesKey, &flashes)

View file

@ -13,8 +13,9 @@ import (
) )
type View struct { type View struct {
Vars map[string]interface{} Vars map[string]interface{}
Request *http.Request Request *http.Request
SessionStore *services.Sessions
} }
func New(req *http.Request) *View { func New(req *http.Request) *View {
@ -23,12 +24,29 @@ func New(req *http.Request) *View {
Vars: map[string]interface{}{ Vars: map[string]interface{}{
"CSRF_TOKEN": csrf.Token(req), "CSRF_TOKEN": csrf.Token(req),
"csrfField": csrf.TemplateField(req), "csrfField": csrf.TemplateField(req),
"username": services.SessionStore.GetUserEmail(req),
"Meta": map[string]interface{}{ "Meta": map[string]interface{}{
"Path": req.URL.Path, "Path": req.URL.Path,
"Env": "develop", "Env": "develop",
}, },
"flashes": []services.Flash{}, "flashes": []services.Flash{},
"username": "",
},
}
}
func NewWithSession(req *http.Request, sessionStore *services.Sessions) *View {
return &View{
Request: req,
SessionStore: sessionStore,
Vars: map[string]interface{}{
"CSRF_TOKEN": csrf.Token(req),
"csrfField": csrf.TemplateField(req),
"Meta": map[string]interface{}{
"Path": req.URL.Path,
"Env": "develop",
},
"flashes": []services.Flash{},
"username": sessionStore.GetUserEmail(req),
}, },
} }
} }
@ -43,8 +61,10 @@ func (view View) Render(w http.ResponseWriter, name string) {
return return
} }
// add flashes to template if view.SessionStore != nil {
view.Vars["flashes"] = services.SessionStore.Flashes(w, view.Request) // add flashes to template
view.Vars["flashes"] = view.SessionStore.Flashes(w, view.Request)
}
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)