From 4a79379d8e4b75ef25da96aa39d7bfc98119c096 Mon Sep 17 00:00:00 2001 From: Paul Date: Mon, 29 Jan 2018 16:52:59 +0100 Subject: [PATCH] Refactor global variables into injected dependencies --- handlers/auth.go | 126 +++++++++++++++++++------------------ handlers/cert.go | 120 ++++++++++++++++++----------------- main.go | 34 ++++++++-- middleware/requirelogin.go | 16 ++--- router/router.go | 27 ++++---- services/db.go | 23 ++++--- services/email.go | 103 ++++++++++++++++++++++++++++++ services/provider.go | 24 +++++++ services/sessions.go | 43 ++++++------- views/views.go | 32 ++++++++-- 10 files changed, 366 insertions(+), 182 deletions(-) create mode 100644 services/email.go create mode 100644 services/provider.go diff --git a/handlers/auth.go b/handlers/auth.go index 5968650..51043c1 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -8,73 +8,77 @@ import ( "git.klink.asia/paul/certman/models" ) -func RegisterHandler(w http.ResponseWriter, req *http.Request) { - // Get parameters - email := req.Form.Get("email") - password := req.Form.Get("password") +func RegisterHandler(p *services.Provider) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + // Get parameters + email := req.Form.Get("email") + password := req.Form.Get("password") - user := models.User{} - user.Email = email - user.SetPassword(password) + user := models.User{} + user.Email = email + user.SetPassword(password) - err := services.Database.Create(&user).Error - if err != nil { - panic(err.Error) + err := p.DB.Create(&user).Error + if err != nil { + panic(err.Error) + } + + p.Sessions.Flash(w, req, + services.Flash{ + Type: "success", + Message: "The user was created. Check your inbox for the confirmation email.", + }, + ) + + http.Redirect(w, req, "/login", http.StatusFound) + return } - - services.SessionStore.Flash(w, req, - services.Flash{ - Type: "success", - Message: "The user was created. Check your inbox for the confirmation email.", - }, - ) - - http.Redirect(w, req, "/login", http.StatusFound) - return } -func LoginHandler(w http.ResponseWriter, req *http.Request) { - // Get parameters - email := req.Form.Get("email") - password := req.Form.Get("password") +func LoginHandler(p *services.Provider) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + // Get parameters + email := req.Form.Get("email") + password := req.Form.Get("password") - user := models.User{} + user := models.User{} - err := services.Database.Where(&models.User{Email: email}).Find(&user).Error - if err != nil { - // could not find user - services.SessionStore.Flash( - w, req, services.Flash{ - Type: "warning", Message: "Invalid Email or Password.", - }, - ) - http.Redirect(w, req, "/login", http.StatusFound) - return + err := p.DB.Where(&models.User{Email: email}).Find(&user).Error + if err != nil { + // could not find user + p.Sessions.Flash( + w, req, services.Flash{ + Type: "warning", Message: "Invalid Email or Password.", + }, + ) + http.Redirect(w, req, "/login", http.StatusFound) + return + } + + if !user.EmailValid { + p.Sessions.Flash( + w, req, services.Flash{ + Type: "warning", Message: "You need to confirm your email before logging in.", + }, + ) + http.Redirect(w, req, "/login", http.StatusFound) + return + } + + if err := user.CheckPassword(password); err != nil { + // wrong password + p.Sessions.Flash( + w, req, services.Flash{ + Type: "warning", Message: "Invalid Email or Password.", + }, + ) + http.Redirect(w, req, "/login", http.StatusFound) + return + } + + // user is logged in, set cookie + p.Sessions.SetUserEmail(w, req, email) + + http.Redirect(w, req, "/certs", http.StatusSeeOther) } - - if !user.EmailValid { - services.SessionStore.Flash( - w, req, services.Flash{ - Type: "warning", Message: "You need to confirm your email before logging in.", - }, - ) - http.Redirect(w, req, "/login", http.StatusFound) - return - } - - if err := user.CheckPassword(password); err != nil { - // wrong password - services.SessionStore.Flash( - w, req, services.Flash{ - Type: "warning", Message: "Invalid Email or Password.", - }, - ) - http.Redirect(w, req, "/login", http.StatusFound) - return - } - - // user is logged in, set cookie - services.SessionStore.SetUserEmail(w, req, email) - - http.Redirect(w, req, "/certs", http.StatusSeeOther) } diff --git a/handlers/cert.go b/handlers/cert.go index 5f59ab6..9c52607 100644 --- a/handlers/cert.go +++ b/handlers/cert.go @@ -19,69 +19,75 @@ import ( "git.klink.asia/paul/certman/views" ) -func ListCertHandler(w http.ResponseWriter, req *http.Request) { - v := views.New(req) - v.Render(w, "cert_list") +func ListCertHandler(p *services.Provider) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + v := views.NewWithSession(req, p.Sessions) + v.Render(w, "cert_list") + } } -func CreateCertHandler(w http.ResponseWriter, req *http.Request) { - email := services.SessionStore.GetUserEmail(req) - certname := req.FormValue("certname") +func CreateCertHandler(p *services.Provider) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + email := p.Sessions.GetUserEmail(req) + certname := req.FormValue("certname") - user := models.User{} - err := services.Database.Where(&models.User{Email: email}).Find(&user).Error - if err != nil { - fmt.Printf("Could not fetch user for mail %s\n", email) + user := models.User{} + err := p.DB.Where(&models.User{Email: email}).Find(&user).Error + if err != nil { + fmt.Printf("Could not fetch user for mail %s\n", email) + } + + // Load CA master certificate + caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key") + if err != nil { + log.Fatalf("error loading ca keyfiles: %s", err) + panic(err.Error()) + } + + // Generate Keypair + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatalf("Could not generate keypair: %s", err) + } + + // Generate Certificate + derBytes, err := CreateCertificate(key, caCert, caKey) + + // Initialize new client config + client := models.Client{ + Name: certname, + PrivateKey: x509.MarshalPKCS1PrivateKey(key), + Cert: derBytes, + UserID: user.ID, + } + + // Insert client into database + if err := p.DB.Create(&client).Error; err != nil { + panic(err.Error()) + } + + p.Sessions.Flash(w, req, + services.Flash{ + Type: "success", + Message: "The certificate was created successfully.", + }, + ) + + http.Redirect(w, req, "/certs", http.StatusFound) } - - // Load CA master certificate - caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key") - if err != nil { - log.Fatalf("error loading ca keyfiles: %s", err) - panic(err.Error()) - } - - // Generate Keypair - key, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - log.Fatalf("Could not generate keypair: %s", err) - } - - // Generate Certificate - derBytes, err := CreateCertificate(key, caCert, caKey) - - // Initialize new client config - client := models.Client{ - Name: certname, - PrivateKey: x509.MarshalPKCS1PrivateKey(key), - Cert: derBytes, - UserID: user.ID, - } - - // Insert client into database - if err := services.Database.Create(&client).Error; err != nil { - panic(err.Error()) - } - - services.SessionStore.Flash(w, req, - services.Flash{ - Type: "success", - Message: "The certificate was created successfully.", - }, - ) - - http.Redirect(w, req, "/certs", http.StatusFound) } -func DownloadCertHandler(w http.ResponseWriter, req *http.Request) { - //v := views.New(req) - // - //derBytes, err := CreateCertificate(key, caCert, caKey) - //pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) - // - //pkBytes := x509.MarshalPKCS1PrivateKey(key) - //pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes}) - return +func DownloadCertHandler(p *services.Provider) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + //v := views.New(req) + // + //derBytes, err := CreateCertificate(key, caCert, caKey) + //pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + // + //pkBytes := x509.MarshalPKCS1PrivateKey(key) + //pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes}) + return + } } func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) { diff --git a/main.go b/main.go index 15f0c37..b5a1e4a 100644 --- a/main.go +++ b/main.go @@ -3,22 +3,46 @@ package main import ( "log" "net/http" + "time" + + "github.com/gorilla/securecookie" "git.klink.asia/paul/certman/services" "git.klink.asia/paul/certman/router" "git.klink.asia/paul/certman/views" - // import sqlite3 driver once + // import sqlite3 driver _ "github.com/mattn/go-sqlite3" ) func main() { + c := services.Config{ + DB: &services.DBConfig{ + Type: "sqlite3", + DSN: "db.sqlite3", + Log: true, + }, + Sessions: &services.SessionsConfig{ + SessionName: "_session", + CookieKey: string(securecookie.GenerateRandomKey(32)), + HttpOnly: true, + Lifetime: 24 * time.Hour, + }, + Email: &services.EmailConfig{ + SMTPServer: "example.com", + SMTPPort: 25, + SMTPUsername: "test", + SMTPPassword: "test", + From: "Mailtest ", + }, + } - // Connect to the database - db := services.InitDB() + serviceProvider := services.NewProvider(&c) - services.InitSession() + // Start the mail daemon, which re-uses connections to send mails to the + // SMTP server + go serviceProvider.Email.Daemon() //user := models.User{} //user.Username = "test" @@ -29,7 +53,7 @@ func main() { // load and parse template files views.LoadTemplates() - mux := router.HandleRoutes(db) + mux := router.HandleRoutes(serviceProvider) err := http.ListenAndServe(":8000", mux) log.Fatalf(err.Error()) diff --git a/middleware/requirelogin.go b/middleware/requirelogin.go index a1917a2..98bbac7 100644 --- a/middleware/requirelogin.go +++ b/middleware/requirelogin.go @@ -8,13 +8,15 @@ import ( // RequireLogin is a middleware that checks for a username in the active // session, and redirects to `/login` if no username was found. -func RequireLogin(next http.Handler) http.Handler { - fn := func(w http.ResponseWriter, req *http.Request) { - if username := services.SessionStore.GetUserEmail(req); username == "" { - http.Redirect(w, req, "/login", http.StatusFound) - } +func RequireLogin(sessions *services.Sessions) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, req *http.Request) { + if username := sessions.GetUserEmail(req); username == "" { + http.Redirect(w, req, "/login", http.StatusFound) + } - next.ServeHTTP(w, req) + next.ServeHTTP(w, req) + } + return http.HandlerFunc(fn) } - return http.HandlerFunc(fn) } diff --git a/router/router.go b/router/router.go index 5883cbe..9ea982b 100644 --- a/router/router.go +++ b/router/router.go @@ -13,7 +13,6 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/gorilla/csrf" - "github.com/jinzhu/gorm" mw "git.klink.asia/paul/certman/middleware" ) @@ -26,15 +25,15 @@ var ( cookieKey = []byte("osx70sMD8HZG2ouUl8uKI4wcMugiJ2WH") ) -func HandleRoutes(db *gorm.DB) http.Handler { +func HandleRoutes(provider *services.Provider) http.Handler { mux := chi.NewMux() //mux.Use(middleware.RequestID) - mux.Use(middleware.Logger) // log requests - mux.Use(middleware.RealIP) // use proxy headers - mux.Use(middleware.RedirectSlashes) // redirect trailing slashes - mux.Use(mw.Recoverer) // recover on panic - mux.Use(services.SessionStore.Use) // use session storage + mux.Use(middleware.Logger) // log requests + mux.Use(middleware.RealIP) // use proxy headers + mux.Use(middleware.RedirectSlashes) // redirect trailing slashes + mux.Use(mw.Recoverer) // recover on panic + mux.Use(provider.Sessions.Manager.Use) // use session storage // we are serving the static files directly from the assets package // this either means we use the embedded files, or live-load @@ -56,26 +55,26 @@ func HandleRoutes(db *gorm.DB) http.Handler { r.Route("/register", func(r chi.Router) { r.Get("/", v("register")) - r.Post("/", handlers.RegisterHandler) + r.Post("/", handlers.RegisterHandler(provider)) }) r.Route("/login", func(r chi.Router) { r.Get("/", v("login")) - r.Post("/", handlers.LoginHandler) + r.Post("/", handlers.LoginHandler(provider)) }) //r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db)) r.Route("/forgot-password", func(r chi.Router) { r.Get("/", v("forgot-password")) - r.Post("/", handlers.LoginHandler) + r.Post("/", handlers.LoginHandler(provider)) }) r.Route("/certs", func(r chi.Router) { - r.Use(mw.RequireLogin) - r.Get("/", handlers.ListCertHandler) - r.Post("/new", handlers.CreateCertHandler) - r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler) + r.Use(mw.RequireLogin(provider.Sessions)) + r.Get("/", handlers.ListCertHandler(provider)) + r.Post("/new", handlers.CreateCertHandler(provider)) + r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler(provider)) }) r.HandleFunc("/500", func(w http.ResponseWriter, req *http.Request) { diff --git a/services/db.go b/services/db.go index f86af66..e36d892 100644 --- a/services/db.go +++ b/services/db.go @@ -5,7 +5,6 @@ import ( "log" "git.klink.asia/paul/certman/models" - "git.klink.asia/paul/certman/settings" "github.com/jinzhu/gorm" ) @@ -14,28 +13,34 @@ var ( ErrNotImplemented = errors.New("Not implemented") ) -var Database *gorm.DB +type DBConfig struct { + Type string + DSN string + Log bool +} // DB is a wrapper around gorm.DB to provide custom methods type DB struct { *gorm.DB + + conf *DBConfig } -func InitDB() *gorm.DB { - dsn := settings.Get("DATABASE_URL", "db.sqlite3") - +func NewDB(conf *DBConfig) *DB { // Establish connection - db, err := gorm.Open("sqlite3", dsn) + db, err := gorm.Open(conf.Type, conf.DSN) if err != nil { log.Fatalf("Could not open database: %s", err.Error()) } // Migrate models db.AutoMigrate(models.User{}, models.Client{}) - db.LogMode(true) + db.LogMode(conf.Log) - Database = db - return db + return &DB{ + DB: db, + conf: conf, + } } // CountUsers returns the number of Users in the datastore diff --git a/services/email.go b/services/email.go new file mode 100644 index 0000000..428a37d --- /dev/null +++ b/services/email.go @@ -0,0 +1,103 @@ +package services + +import ( + "errors" + "log" + "time" + + "github.com/go-mail/mail" +) + +var ( + ErrMailUninitializedConfig = errors.New("Mail: uninitialized config") +) + +type EmailConfig struct { + From string + SMTPServer string + SMTPPort int + SMTPUsername string + SMTPPassword string +} + +type Email struct { + config *EmailConfig + + mailChan chan *mail.Message +} + +func NewEmail(conf *EmailConfig) *Email { + if conf == nil { + log.Println(ErrMailUninitializedConfig) + } + + return &Email{ + config: conf, + mailChan: make(chan *mail.Message, 0), + } +} + +// Send sends an email to the receiver +func (email *Email) Send(to, subject, text, html string) error { + if email.config == nil { + log.Print("Error: trying to send mail with uninitialized config.") + return ErrMailUninitializedConfig + } + + m := mail.NewMessage() + m.SetHeader("From", email.config.From) + m.SetHeader("To", to) + m.SetHeader("Subject", subject) + m.SetBody("text/plain", text) + m.AddAlternative("text/html", html) + + // put email in chan + email.mailChan <- m + return nil +} + +// Daemon is a function that takes Mail and sends it without blocking. +// WIP +func (email *Email) Daemon() { + if email.config == nil { + log.Print("Error: trying to set up mail deamon with uninitialized config.") + return + } + + d := mail.NewDialer( + email.config.SMTPServer, + email.config.SMTPPort, + email.config.SMTPUsername, + email.config.SMTPPassword) + + var s mail.SendCloser + var err error + open := false + for { + select { + case m, ok := <-email.mailChan: + if !ok { + // channel is closed + return + } + if !open { + if s, err = d.Dial(); err != nil { + log.Print(err) + return + } + open = true + } + if err := mail.Send(s, m); err != nil { + log.Print(err) + } + // Close the connection if no email was sent in the last 30 seconds. + case <-time.After(30 * time.Second): + if open { + if err := s.Close(); err != nil { + panic(err) + } + open = false + } + } + } +} diff --git a/services/provider.go b/services/provider.go new file mode 100644 index 0000000..ab2abaa --- /dev/null +++ b/services/provider.go @@ -0,0 +1,24 @@ +package services + +type Config struct { + DB *DBConfig + Sessions *SessionsConfig + Email *EmailConfig +} + +type Provider struct { + DB *DB + Sessions *Sessions + Email *Email +} + +// NewProvider returns the ServiceProvider +func NewProvider(conf *Config) *Provider { + var provider = &Provider{} + + provider.DB = NewDB(conf.DB) + provider.Sessions = NewSessions(conf.Sessions) + provider.Email = NewEmail(conf.Email) + + return provider +} diff --git a/services/sessions.go b/services/sessions.go index 9efe6a5..36ee58b 100644 --- a/services/sessions.go +++ b/services/sessions.go @@ -8,16 +8,10 @@ import ( "net/http" "time" - "git.klink.asia/paul/certman/settings" "github.com/alexedwards/scs" - "github.com/gorilla/securecookie" ) var ( - // SessionName is the name of the session cookie - SessionName = "session" - // CookieKey is the key the cookies are encrypted and signed with - CookieKey = string(securecookie.GenerateRandomKey(32)) // FlashesKey is the key used for the flashes in the cookie FlashesKey = "_flashes" // UserEmailKey is the key used to reference usernames @@ -29,30 +23,33 @@ func init() { gob.Register(Flash{}) } -// SessionStore is a globally accessible sessions store for the application -var SessionStore *Store +type SessionsConfig struct { + SessionName string + CookieKey string + HttpOnly bool + Secure bool + Lifetime time.Duration +} -// Store is a wrapped scs.Store in order to implement custom -// logic -type Store struct { +// Sessions is a wrapped scs.Store in order to implement custom logic +type Sessions struct { *scs.Manager } -// InitSession populates the default sessions Store -func InitSession() { +// NewSessions populates the default sessions Store +func NewSessions(conf *SessionsConfig) *Sessions { store := scs.NewCookieManager( - CookieKey, + conf.CookieKey, ) + store.Name(conf.SessionName) store.HttpOnly(true) - store.Lifetime(24 * time.Hour) + store.Lifetime(conf.Lifetime) + store.Secure(conf.Secure) - // Use secure cookies (HTTPS only) in production - store.Secure(settings.Get("ENVIRONMENT", "") == "production") - - SessionStore = &Store{store} + return &Sessions{store} } -func (store *Store) GetUserEmail(req *http.Request) string { +func (store *Sessions) GetUserEmail(req *http.Request) string { if store == nil { // if store was not initialized, all requests fail log.Println("Zero pointer when checking session for username") @@ -72,7 +69,7 @@ func (store *Store) GetUserEmail(req *http.Request) string { return email } -func (store *Store) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) { +func (store *Sessions) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) { if store == nil { // if store was not initialized, do nothing return @@ -103,7 +100,7 @@ func (flash Flash) Render() template.HTML { } // Flash add flash message to session data -func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error { +func (store *Sessions) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error { var flashes []Flash sess := store.Load(req) @@ -118,7 +115,7 @@ func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash) } // Flashes returns a slice of flash messages from session data -func (store *Store) Flashes(w http.ResponseWriter, req *http.Request) []Flash { +func (store *Sessions) Flashes(w http.ResponseWriter, req *http.Request) []Flash { var flashes []Flash sess := store.Load(req) sess.PopObject(w, FlashesKey, &flashes) diff --git a/views/views.go b/views/views.go index 563d31c..c0e33e7 100644 --- a/views/views.go +++ b/views/views.go @@ -13,8 +13,9 @@ import ( ) type View struct { - Vars map[string]interface{} - Request *http.Request + Vars map[string]interface{} + Request *http.Request + SessionStore *services.Sessions } func New(req *http.Request) *View { @@ -23,12 +24,29 @@ func New(req *http.Request) *View { Vars: map[string]interface{}{ "CSRF_TOKEN": csrf.Token(req), "csrfField": csrf.TemplateField(req), - "username": services.SessionStore.GetUserEmail(req), "Meta": map[string]interface{}{ "Path": req.URL.Path, "Env": "develop", }, - "flashes": []services.Flash{}, + "flashes": []services.Flash{}, + "username": "", + }, + } +} + +func NewWithSession(req *http.Request, sessionStore *services.Sessions) *View { + return &View{ + Request: req, + SessionStore: sessionStore, + Vars: map[string]interface{}{ + "CSRF_TOKEN": csrf.Token(req), + "csrfField": csrf.TemplateField(req), + "Meta": map[string]interface{}{ + "Path": req.URL.Path, + "Env": "develop", + }, + "flashes": []services.Flash{}, + "username": sessionStore.GetUserEmail(req), }, } } @@ -43,8 +61,10 @@ func (view View) Render(w http.ResponseWriter, name string) { return } - // add flashes to template - view.Vars["flashes"] = services.SessionStore.Flashes(w, view.Request) + if view.SessionStore != nil { + // add flashes to template + view.Vars["flashes"] = view.SessionStore.Flashes(w, view.Request) + } w.Header().Set("Content-Type", "text/html; charset=utf-8") w.WriteHeader(http.StatusOK)