Refactor global variables into injected dependencies

This commit is contained in:
Paul 2018-01-29 16:52:59 +01:00
parent e5cd6e0be3
commit 4a79379d8e
10 changed files with 366 additions and 182 deletions

View file

@ -8,73 +8,77 @@ import (
"git.klink.asia/paul/certman/models"
)
func RegisterHandler(w http.ResponseWriter, req *http.Request) {
// Get parameters
email := req.Form.Get("email")
password := req.Form.Get("password")
func RegisterHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
// Get parameters
email := req.Form.Get("email")
password := req.Form.Get("password")
user := models.User{}
user.Email = email
user.SetPassword(password)
user := models.User{}
user.Email = email
user.SetPassword(password)
err := services.Database.Create(&user).Error
if err != nil {
panic(err.Error)
err := p.DB.Create(&user).Error
if err != nil {
panic(err.Error)
}
p.Sessions.Flash(w, req,
services.Flash{
Type: "success",
Message: "The user was created. Check your inbox for the confirmation email.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
services.SessionStore.Flash(w, req,
services.Flash{
Type: "success",
Message: "The user was created. Check your inbox for the confirmation email.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
func LoginHandler(w http.ResponseWriter, req *http.Request) {
// Get parameters
email := req.Form.Get("email")
password := req.Form.Get("password")
func LoginHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
// Get parameters
email := req.Form.Get("email")
password := req.Form.Get("password")
user := models.User{}
user := models.User{}
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error
if err != nil {
// could not find user
services.SessionStore.Flash(
w, req, services.Flash{
Type: "warning", Message: "Invalid Email or Password.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
if err != nil {
// could not find user
p.Sessions.Flash(
w, req, services.Flash{
Type: "warning", Message: "Invalid Email or Password.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
if !user.EmailValid {
p.Sessions.Flash(
w, req, services.Flash{
Type: "warning", Message: "You need to confirm your email before logging in.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
if err := user.CheckPassword(password); err != nil {
// wrong password
p.Sessions.Flash(
w, req, services.Flash{
Type: "warning", Message: "Invalid Email or Password.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
// user is logged in, set cookie
p.Sessions.SetUserEmail(w, req, email)
http.Redirect(w, req, "/certs", http.StatusSeeOther)
}
if !user.EmailValid {
services.SessionStore.Flash(
w, req, services.Flash{
Type: "warning", Message: "You need to confirm your email before logging in.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
if err := user.CheckPassword(password); err != nil {
// wrong password
services.SessionStore.Flash(
w, req, services.Flash{
Type: "warning", Message: "Invalid Email or Password.",
},
)
http.Redirect(w, req, "/login", http.StatusFound)
return
}
// user is logged in, set cookie
services.SessionStore.SetUserEmail(w, req, email)
http.Redirect(w, req, "/certs", http.StatusSeeOther)
}

View file

@ -19,69 +19,75 @@ import (
"git.klink.asia/paul/certman/views"
)
func ListCertHandler(w http.ResponseWriter, req *http.Request) {
v := views.New(req)
v.Render(w, "cert_list")
func ListCertHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
v := views.NewWithSession(req, p.Sessions)
v.Render(w, "cert_list")
}
}
func CreateCertHandler(w http.ResponseWriter, req *http.Request) {
email := services.SessionStore.GetUserEmail(req)
certname := req.FormValue("certname")
func CreateCertHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
email := p.Sessions.GetUserEmail(req)
certname := req.FormValue("certname")
user := models.User{}
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error
if err != nil {
fmt.Printf("Could not fetch user for mail %s\n", email)
user := models.User{}
err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
if err != nil {
fmt.Printf("Could not fetch user for mail %s\n", email)
}
// Load CA master certificate
caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key")
if err != nil {
log.Fatalf("error loading ca keyfiles: %s", err)
panic(err.Error())
}
// Generate Keypair
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatalf("Could not generate keypair: %s", err)
}
// Generate Certificate
derBytes, err := CreateCertificate(key, caCert, caKey)
// Initialize new client config
client := models.Client{
Name: certname,
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Cert: derBytes,
UserID: user.ID,
}
// Insert client into database
if err := p.DB.Create(&client).Error; err != nil {
panic(err.Error())
}
p.Sessions.Flash(w, req,
services.Flash{
Type: "success",
Message: "The certificate was created successfully.",
},
)
http.Redirect(w, req, "/certs", http.StatusFound)
}
// Load CA master certificate
caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key")
if err != nil {
log.Fatalf("error loading ca keyfiles: %s", err)
panic(err.Error())
}
// Generate Keypair
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
log.Fatalf("Could not generate keypair: %s", err)
}
// Generate Certificate
derBytes, err := CreateCertificate(key, caCert, caKey)
// Initialize new client config
client := models.Client{
Name: certname,
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Cert: derBytes,
UserID: user.ID,
}
// Insert client into database
if err := services.Database.Create(&client).Error; err != nil {
panic(err.Error())
}
services.SessionStore.Flash(w, req,
services.Flash{
Type: "success",
Message: "The certificate was created successfully.",
},
)
http.Redirect(w, req, "/certs", http.StatusFound)
}
func DownloadCertHandler(w http.ResponseWriter, req *http.Request) {
//v := views.New(req)
//
//derBytes, err := CreateCertificate(key, caCert, caKey)
//pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
//
//pkBytes := x509.MarshalPKCS1PrivateKey(key)
//pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes})
return
func DownloadCertHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
//v := views.New(req)
//
//derBytes, err := CreateCertificate(key, caCert, caKey)
//pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
//
//pkBytes := x509.MarshalPKCS1PrivateKey(key)
//pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes})
return
}
}
func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {

34
main.go
View file

@ -3,22 +3,46 @@ package main
import (
"log"
"net/http"
"time"
"github.com/gorilla/securecookie"
"git.klink.asia/paul/certman/services"
"git.klink.asia/paul/certman/router"
"git.klink.asia/paul/certman/views"
// import sqlite3 driver once
// import sqlite3 driver
_ "github.com/mattn/go-sqlite3"
)
func main() {
c := services.Config{
DB: &services.DBConfig{
Type: "sqlite3",
DSN: "db.sqlite3",
Log: true,
},
Sessions: &services.SessionsConfig{
SessionName: "_session",
CookieKey: string(securecookie.GenerateRandomKey(32)),
HttpOnly: true,
Lifetime: 24 * time.Hour,
},
Email: &services.EmailConfig{
SMTPServer: "example.com",
SMTPPort: 25,
SMTPUsername: "test",
SMTPPassword: "test",
From: "Mailtest <test@example.com>",
},
}
// Connect to the database
db := services.InitDB()
serviceProvider := services.NewProvider(&c)
services.InitSession()
// Start the mail daemon, which re-uses connections to send mails to the
// SMTP server
go serviceProvider.Email.Daemon()
//user := models.User{}
//user.Username = "test"
@ -29,7 +53,7 @@ func main() {
// load and parse template files
views.LoadTemplates()
mux := router.HandleRoutes(db)
mux := router.HandleRoutes(serviceProvider)
err := http.ListenAndServe(":8000", mux)
log.Fatalf(err.Error())

View file

@ -8,13 +8,15 @@ import (
// RequireLogin is a middleware that checks for a username in the active
// session, and redirects to `/login` if no username was found.
func RequireLogin(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
if username := services.SessionStore.GetUserEmail(req); username == "" {
http.Redirect(w, req, "/login", http.StatusFound)
}
func RequireLogin(sessions *services.Sessions) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, req *http.Request) {
if username := sessions.GetUserEmail(req); username == "" {
http.Redirect(w, req, "/login", http.StatusFound)
}
next.ServeHTTP(w, req)
next.ServeHTTP(w, req)
}
return http.HandlerFunc(fn)
}
return http.HandlerFunc(fn)
}

View file

@ -13,7 +13,6 @@ import (
"github.com/go-chi/chi"
"github.com/go-chi/chi/middleware"
"github.com/gorilla/csrf"
"github.com/jinzhu/gorm"
mw "git.klink.asia/paul/certman/middleware"
)
@ -26,15 +25,15 @@ var (
cookieKey = []byte("osx70sMD8HZG2ouUl8uKI4wcMugiJ2WH")
)
func HandleRoutes(db *gorm.DB) http.Handler {
func HandleRoutes(provider *services.Provider) http.Handler {
mux := chi.NewMux()
//mux.Use(middleware.RequestID)
mux.Use(middleware.Logger) // log requests
mux.Use(middleware.RealIP) // use proxy headers
mux.Use(middleware.RedirectSlashes) // redirect trailing slashes
mux.Use(mw.Recoverer) // recover on panic
mux.Use(services.SessionStore.Use) // use session storage
mux.Use(middleware.Logger) // log requests
mux.Use(middleware.RealIP) // use proxy headers
mux.Use(middleware.RedirectSlashes) // redirect trailing slashes
mux.Use(mw.Recoverer) // recover on panic
mux.Use(provider.Sessions.Manager.Use) // use session storage
// we are serving the static files directly from the assets package
// this either means we use the embedded files, or live-load
@ -56,26 +55,26 @@ func HandleRoutes(db *gorm.DB) http.Handler {
r.Route("/register", func(r chi.Router) {
r.Get("/", v("register"))
r.Post("/", handlers.RegisterHandler)
r.Post("/", handlers.RegisterHandler(provider))
})
r.Route("/login", func(r chi.Router) {
r.Get("/", v("login"))
r.Post("/", handlers.LoginHandler)
r.Post("/", handlers.LoginHandler(provider))
})
//r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db))
r.Route("/forgot-password", func(r chi.Router) {
r.Get("/", v("forgot-password"))
r.Post("/", handlers.LoginHandler)
r.Post("/", handlers.LoginHandler(provider))
})
r.Route("/certs", func(r chi.Router) {
r.Use(mw.RequireLogin)
r.Get("/", handlers.ListCertHandler)
r.Post("/new", handlers.CreateCertHandler)
r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler)
r.Use(mw.RequireLogin(provider.Sessions))
r.Get("/", handlers.ListCertHandler(provider))
r.Post("/new", handlers.CreateCertHandler(provider))
r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler(provider))
})
r.HandleFunc("/500", func(w http.ResponseWriter, req *http.Request) {

View file

@ -5,7 +5,6 @@ import (
"log"
"git.klink.asia/paul/certman/models"
"git.klink.asia/paul/certman/settings"
"github.com/jinzhu/gorm"
)
@ -14,28 +13,34 @@ var (
ErrNotImplemented = errors.New("Not implemented")
)
var Database *gorm.DB
type DBConfig struct {
Type string
DSN string
Log bool
}
// DB is a wrapper around gorm.DB to provide custom methods
type DB struct {
*gorm.DB
conf *DBConfig
}
func InitDB() *gorm.DB {
dsn := settings.Get("DATABASE_URL", "db.sqlite3")
func NewDB(conf *DBConfig) *DB {
// Establish connection
db, err := gorm.Open("sqlite3", dsn)
db, err := gorm.Open(conf.Type, conf.DSN)
if err != nil {
log.Fatalf("Could not open database: %s", err.Error())
}
// Migrate models
db.AutoMigrate(models.User{}, models.Client{})
db.LogMode(true)
db.LogMode(conf.Log)
Database = db
return db
return &DB{
DB: db,
conf: conf,
}
}
// CountUsers returns the number of Users in the datastore

103
services/email.go Normal file
View file

@ -0,0 +1,103 @@
package services
import (
"errors"
"log"
"time"
"github.com/go-mail/mail"
)
var (
ErrMailUninitializedConfig = errors.New("Mail: uninitialized config")
)
type EmailConfig struct {
From string
SMTPServer string
SMTPPort int
SMTPUsername string
SMTPPassword string
}
type Email struct {
config *EmailConfig
mailChan chan *mail.Message
}
func NewEmail(conf *EmailConfig) *Email {
if conf == nil {
log.Println(ErrMailUninitializedConfig)
}
return &Email{
config: conf,
mailChan: make(chan *mail.Message, 0),
}
}
// Send sends an email to the receiver
func (email *Email) Send(to, subject, text, html string) error {
if email.config == nil {
log.Print("Error: trying to send mail with uninitialized config.")
return ErrMailUninitializedConfig
}
m := mail.NewMessage()
m.SetHeader("From", email.config.From)
m.SetHeader("To", to)
m.SetHeader("Subject", subject)
m.SetBody("text/plain", text)
m.AddAlternative("text/html", html)
// put email in chan
email.mailChan <- m
return nil
}
// Daemon is a function that takes Mail and sends it without blocking.
// WIP
func (email *Email) Daemon() {
if email.config == nil {
log.Print("Error: trying to set up mail deamon with uninitialized config.")
return
}
d := mail.NewDialer(
email.config.SMTPServer,
email.config.SMTPPort,
email.config.SMTPUsername,
email.config.SMTPPassword)
var s mail.SendCloser
var err error
open := false
for {
select {
case m, ok := <-email.mailChan:
if !ok {
// channel is closed
return
}
if !open {
if s, err = d.Dial(); err != nil {
log.Print(err)
return
}
open = true
}
if err := mail.Send(s, m); err != nil {
log.Print(err)
}
// Close the connection if no email was sent in the last 30 seconds.
case <-time.After(30 * time.Second):
if open {
if err := s.Close(); err != nil {
panic(err)
}
open = false
}
}
}
}

24
services/provider.go Normal file
View file

@ -0,0 +1,24 @@
package services
type Config struct {
DB *DBConfig
Sessions *SessionsConfig
Email *EmailConfig
}
type Provider struct {
DB *DB
Sessions *Sessions
Email *Email
}
// NewProvider returns the ServiceProvider
func NewProvider(conf *Config) *Provider {
var provider = &Provider{}
provider.DB = NewDB(conf.DB)
provider.Sessions = NewSessions(conf.Sessions)
provider.Email = NewEmail(conf.Email)
return provider
}

View file

@ -8,16 +8,10 @@ import (
"net/http"
"time"
"git.klink.asia/paul/certman/settings"
"github.com/alexedwards/scs"
"github.com/gorilla/securecookie"
)
var (
// SessionName is the name of the session cookie
SessionName = "session"
// CookieKey is the key the cookies are encrypted and signed with
CookieKey = string(securecookie.GenerateRandomKey(32))
// FlashesKey is the key used for the flashes in the cookie
FlashesKey = "_flashes"
// UserEmailKey is the key used to reference usernames
@ -29,30 +23,33 @@ func init() {
gob.Register(Flash{})
}
// SessionStore is a globally accessible sessions store for the application
var SessionStore *Store
type SessionsConfig struct {
SessionName string
CookieKey string
HttpOnly bool
Secure bool
Lifetime time.Duration
}
// Store is a wrapped scs.Store in order to implement custom
// logic
type Store struct {
// Sessions is a wrapped scs.Store in order to implement custom logic
type Sessions struct {
*scs.Manager
}
// InitSession populates the default sessions Store
func InitSession() {
// NewSessions populates the default sessions Store
func NewSessions(conf *SessionsConfig) *Sessions {
store := scs.NewCookieManager(
CookieKey,
conf.CookieKey,
)
store.Name(conf.SessionName)
store.HttpOnly(true)
store.Lifetime(24 * time.Hour)
store.Lifetime(conf.Lifetime)
store.Secure(conf.Secure)
// Use secure cookies (HTTPS only) in production
store.Secure(settings.Get("ENVIRONMENT", "") == "production")
SessionStore = &Store{store}
return &Sessions{store}
}
func (store *Store) GetUserEmail(req *http.Request) string {
func (store *Sessions) GetUserEmail(req *http.Request) string {
if store == nil {
// if store was not initialized, all requests fail
log.Println("Zero pointer when checking session for username")
@ -72,7 +69,7 @@ func (store *Store) GetUserEmail(req *http.Request) string {
return email
}
func (store *Store) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) {
func (store *Sessions) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) {
if store == nil {
// if store was not initialized, do nothing
return
@ -103,7 +100,7 @@ func (flash Flash) Render() template.HTML {
}
// Flash add flash message to session data
func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error {
func (store *Sessions) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error {
var flashes []Flash
sess := store.Load(req)
@ -118,7 +115,7 @@ func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash)
}
// Flashes returns a slice of flash messages from session data
func (store *Store) Flashes(w http.ResponseWriter, req *http.Request) []Flash {
func (store *Sessions) Flashes(w http.ResponseWriter, req *http.Request) []Flash {
var flashes []Flash
sess := store.Load(req)
sess.PopObject(w, FlashesKey, &flashes)

View file

@ -13,8 +13,9 @@ import (
)
type View struct {
Vars map[string]interface{}
Request *http.Request
Vars map[string]interface{}
Request *http.Request
SessionStore *services.Sessions
}
func New(req *http.Request) *View {
@ -23,12 +24,29 @@ func New(req *http.Request) *View {
Vars: map[string]interface{}{
"CSRF_TOKEN": csrf.Token(req),
"csrfField": csrf.TemplateField(req),
"username": services.SessionStore.GetUserEmail(req),
"Meta": map[string]interface{}{
"Path": req.URL.Path,
"Env": "develop",
},
"flashes": []services.Flash{},
"flashes": []services.Flash{},
"username": "",
},
}
}
func NewWithSession(req *http.Request, sessionStore *services.Sessions) *View {
return &View{
Request: req,
SessionStore: sessionStore,
Vars: map[string]interface{}{
"CSRF_TOKEN": csrf.Token(req),
"csrfField": csrf.TemplateField(req),
"Meta": map[string]interface{}{
"Path": req.URL.Path,
"Env": "develop",
},
"flashes": []services.Flash{},
"username": sessionStore.GetUserEmail(req),
},
}
}
@ -43,8 +61,10 @@ func (view View) Render(w http.ResponseWriter, name string) {
return
}
// add flashes to template
view.Vars["flashes"] = services.SessionStore.Flashes(w, view.Request)
if view.SessionStore != nil {
// add flashes to template
view.Vars["flashes"] = view.SessionStore.Flashes(w, view.Request)
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)