Refactor global variables into injected dependencies
This commit is contained in:
parent
e5cd6e0be3
commit
4a79379d8e
10 changed files with 366 additions and 182 deletions
126
handlers/auth.go
126
handlers/auth.go
|
@ -8,73 +8,77 @@ import (
|
|||
"git.klink.asia/paul/certman/models"
|
||||
)
|
||||
|
||||
func RegisterHandler(w http.ResponseWriter, req *http.Request) {
|
||||
// Get parameters
|
||||
email := req.Form.Get("email")
|
||||
password := req.Form.Get("password")
|
||||
func RegisterHandler(p *services.Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
// Get parameters
|
||||
email := req.Form.Get("email")
|
||||
password := req.Form.Get("password")
|
||||
|
||||
user := models.User{}
|
||||
user.Email = email
|
||||
user.SetPassword(password)
|
||||
user := models.User{}
|
||||
user.Email = email
|
||||
user.SetPassword(password)
|
||||
|
||||
err := services.Database.Create(&user).Error
|
||||
if err != nil {
|
||||
panic(err.Error)
|
||||
err := p.DB.Create(&user).Error
|
||||
if err != nil {
|
||||
panic(err.Error)
|
||||
}
|
||||
|
||||
p.Sessions.Flash(w, req,
|
||||
services.Flash{
|
||||
Type: "success",
|
||||
Message: "The user was created. Check your inbox for the confirmation email.",
|
||||
},
|
||||
)
|
||||
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
services.SessionStore.Flash(w, req,
|
||||
services.Flash{
|
||||
Type: "success",
|
||||
Message: "The user was created. Check your inbox for the confirmation email.",
|
||||
},
|
||||
)
|
||||
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
func LoginHandler(w http.ResponseWriter, req *http.Request) {
|
||||
// Get parameters
|
||||
email := req.Form.Get("email")
|
||||
password := req.Form.Get("password")
|
||||
func LoginHandler(p *services.Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
// Get parameters
|
||||
email := req.Form.Get("email")
|
||||
password := req.Form.Get("password")
|
||||
|
||||
user := models.User{}
|
||||
user := models.User{}
|
||||
|
||||
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error
|
||||
if err != nil {
|
||||
// could not find user
|
||||
services.SessionStore.Flash(
|
||||
w, req, services.Flash{
|
||||
Type: "warning", Message: "Invalid Email or Password.",
|
||||
},
|
||||
)
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
|
||||
if err != nil {
|
||||
// could not find user
|
||||
p.Sessions.Flash(
|
||||
w, req, services.Flash{
|
||||
Type: "warning", Message: "Invalid Email or Password.",
|
||||
},
|
||||
)
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
if !user.EmailValid {
|
||||
p.Sessions.Flash(
|
||||
w, req, services.Flash{
|
||||
Type: "warning", Message: "You need to confirm your email before logging in.",
|
||||
},
|
||||
)
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
if err := user.CheckPassword(password); err != nil {
|
||||
// wrong password
|
||||
p.Sessions.Flash(
|
||||
w, req, services.Flash{
|
||||
Type: "warning", Message: "Invalid Email or Password.",
|
||||
},
|
||||
)
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// user is logged in, set cookie
|
||||
p.Sessions.SetUserEmail(w, req, email)
|
||||
|
||||
http.Redirect(w, req, "/certs", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
if !user.EmailValid {
|
||||
services.SessionStore.Flash(
|
||||
w, req, services.Flash{
|
||||
Type: "warning", Message: "You need to confirm your email before logging in.",
|
||||
},
|
||||
)
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
if err := user.CheckPassword(password); err != nil {
|
||||
// wrong password
|
||||
services.SessionStore.Flash(
|
||||
w, req, services.Flash{
|
||||
Type: "warning", Message: "Invalid Email or Password.",
|
||||
},
|
||||
)
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
// user is logged in, set cookie
|
||||
services.SessionStore.SetUserEmail(w, req, email)
|
||||
|
||||
http.Redirect(w, req, "/certs", http.StatusSeeOther)
|
||||
}
|
||||
|
|
120
handlers/cert.go
120
handlers/cert.go
|
@ -19,69 +19,75 @@ import (
|
|||
"git.klink.asia/paul/certman/views"
|
||||
)
|
||||
|
||||
func ListCertHandler(w http.ResponseWriter, req *http.Request) {
|
||||
v := views.New(req)
|
||||
v.Render(w, "cert_list")
|
||||
func ListCertHandler(p *services.Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
v := views.NewWithSession(req, p.Sessions)
|
||||
v.Render(w, "cert_list")
|
||||
}
|
||||
}
|
||||
|
||||
func CreateCertHandler(w http.ResponseWriter, req *http.Request) {
|
||||
email := services.SessionStore.GetUserEmail(req)
|
||||
certname := req.FormValue("certname")
|
||||
func CreateCertHandler(p *services.Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
email := p.Sessions.GetUserEmail(req)
|
||||
certname := req.FormValue("certname")
|
||||
|
||||
user := models.User{}
|
||||
err := services.Database.Where(&models.User{Email: email}).Find(&user).Error
|
||||
if err != nil {
|
||||
fmt.Printf("Could not fetch user for mail %s\n", email)
|
||||
user := models.User{}
|
||||
err := p.DB.Where(&models.User{Email: email}).Find(&user).Error
|
||||
if err != nil {
|
||||
fmt.Printf("Could not fetch user for mail %s\n", email)
|
||||
}
|
||||
|
||||
// Load CA master certificate
|
||||
caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key")
|
||||
if err != nil {
|
||||
log.Fatalf("error loading ca keyfiles: %s", err)
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
// Generate Keypair
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not generate keypair: %s", err)
|
||||
}
|
||||
|
||||
// Generate Certificate
|
||||
derBytes, err := CreateCertificate(key, caCert, caKey)
|
||||
|
||||
// Initialize new client config
|
||||
client := models.Client{
|
||||
Name: certname,
|
||||
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
|
||||
Cert: derBytes,
|
||||
UserID: user.ID,
|
||||
}
|
||||
|
||||
// Insert client into database
|
||||
if err := p.DB.Create(&client).Error; err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
p.Sessions.Flash(w, req,
|
||||
services.Flash{
|
||||
Type: "success",
|
||||
Message: "The certificate was created successfully.",
|
||||
},
|
||||
)
|
||||
|
||||
http.Redirect(w, req, "/certs", http.StatusFound)
|
||||
}
|
||||
|
||||
// Load CA master certificate
|
||||
caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key")
|
||||
if err != nil {
|
||||
log.Fatalf("error loading ca keyfiles: %s", err)
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
// Generate Keypair
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not generate keypair: %s", err)
|
||||
}
|
||||
|
||||
// Generate Certificate
|
||||
derBytes, err := CreateCertificate(key, caCert, caKey)
|
||||
|
||||
// Initialize new client config
|
||||
client := models.Client{
|
||||
Name: certname,
|
||||
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
|
||||
Cert: derBytes,
|
||||
UserID: user.ID,
|
||||
}
|
||||
|
||||
// Insert client into database
|
||||
if err := services.Database.Create(&client).Error; err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
services.SessionStore.Flash(w, req,
|
||||
services.Flash{
|
||||
Type: "success",
|
||||
Message: "The certificate was created successfully.",
|
||||
},
|
||||
)
|
||||
|
||||
http.Redirect(w, req, "/certs", http.StatusFound)
|
||||
}
|
||||
|
||||
func DownloadCertHandler(w http.ResponseWriter, req *http.Request) {
|
||||
//v := views.New(req)
|
||||
//
|
||||
//derBytes, err := CreateCertificate(key, caCert, caKey)
|
||||
//pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
//
|
||||
//pkBytes := x509.MarshalPKCS1PrivateKey(key)
|
||||
//pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes})
|
||||
return
|
||||
func DownloadCertHandler(p *services.Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
//v := views.New(req)
|
||||
//
|
||||
//derBytes, err := CreateCertificate(key, caCert, caKey)
|
||||
//pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
//
|
||||
//pkBytes := x509.MarshalPKCS1PrivateKey(key)
|
||||
//pem.Encode(w, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: pkBytes})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateKey, error) {
|
||||
|
|
34
main.go
34
main.go
|
@ -3,22 +3,46 @@ package main
|
|||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
|
||||
"git.klink.asia/paul/certman/services"
|
||||
|
||||
"git.klink.asia/paul/certman/router"
|
||||
"git.klink.asia/paul/certman/views"
|
||||
|
||||
// import sqlite3 driver once
|
||||
// import sqlite3 driver
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func main() {
|
||||
c := services.Config{
|
||||
DB: &services.DBConfig{
|
||||
Type: "sqlite3",
|
||||
DSN: "db.sqlite3",
|
||||
Log: true,
|
||||
},
|
||||
Sessions: &services.SessionsConfig{
|
||||
SessionName: "_session",
|
||||
CookieKey: string(securecookie.GenerateRandomKey(32)),
|
||||
HttpOnly: true,
|
||||
Lifetime: 24 * time.Hour,
|
||||
},
|
||||
Email: &services.EmailConfig{
|
||||
SMTPServer: "example.com",
|
||||
SMTPPort: 25,
|
||||
SMTPUsername: "test",
|
||||
SMTPPassword: "test",
|
||||
From: "Mailtest <test@example.com>",
|
||||
},
|
||||
}
|
||||
|
||||
// Connect to the database
|
||||
db := services.InitDB()
|
||||
serviceProvider := services.NewProvider(&c)
|
||||
|
||||
services.InitSession()
|
||||
// Start the mail daemon, which re-uses connections to send mails to the
|
||||
// SMTP server
|
||||
go serviceProvider.Email.Daemon()
|
||||
|
||||
//user := models.User{}
|
||||
//user.Username = "test"
|
||||
|
@ -29,7 +53,7 @@ func main() {
|
|||
// load and parse template files
|
||||
views.LoadTemplates()
|
||||
|
||||
mux := router.HandleRoutes(db)
|
||||
mux := router.HandleRoutes(serviceProvider)
|
||||
|
||||
err := http.ListenAndServe(":8000", mux)
|
||||
log.Fatalf(err.Error())
|
||||
|
|
|
@ -8,13 +8,15 @@ import (
|
|||
|
||||
// RequireLogin is a middleware that checks for a username in the active
|
||||
// session, and redirects to `/login` if no username was found.
|
||||
func RequireLogin(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, req *http.Request) {
|
||||
if username := services.SessionStore.GetUserEmail(req); username == "" {
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
}
|
||||
func RequireLogin(sessions *services.Sessions) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, req *http.Request) {
|
||||
if username := sessions.GetUserEmail(req); username == "" {
|
||||
http.Redirect(w, req, "/login", http.StatusFound)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, req)
|
||||
next.ServeHTTP(w, req)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
mw "git.klink.asia/paul/certman/middleware"
|
||||
)
|
||||
|
@ -26,15 +25,15 @@ var (
|
|||
cookieKey = []byte("osx70sMD8HZG2ouUl8uKI4wcMugiJ2WH")
|
||||
)
|
||||
|
||||
func HandleRoutes(db *gorm.DB) http.Handler {
|
||||
func HandleRoutes(provider *services.Provider) http.Handler {
|
||||
mux := chi.NewMux()
|
||||
|
||||
//mux.Use(middleware.RequestID)
|
||||
mux.Use(middleware.Logger) // log requests
|
||||
mux.Use(middleware.RealIP) // use proxy headers
|
||||
mux.Use(middleware.RedirectSlashes) // redirect trailing slashes
|
||||
mux.Use(mw.Recoverer) // recover on panic
|
||||
mux.Use(services.SessionStore.Use) // use session storage
|
||||
mux.Use(middleware.Logger) // log requests
|
||||
mux.Use(middleware.RealIP) // use proxy headers
|
||||
mux.Use(middleware.RedirectSlashes) // redirect trailing slashes
|
||||
mux.Use(mw.Recoverer) // recover on panic
|
||||
mux.Use(provider.Sessions.Manager.Use) // use session storage
|
||||
|
||||
// we are serving the static files directly from the assets package
|
||||
// this either means we use the embedded files, or live-load
|
||||
|
@ -56,26 +55,26 @@ func HandleRoutes(db *gorm.DB) http.Handler {
|
|||
|
||||
r.Route("/register", func(r chi.Router) {
|
||||
r.Get("/", v("register"))
|
||||
r.Post("/", handlers.RegisterHandler)
|
||||
r.Post("/", handlers.RegisterHandler(provider))
|
||||
})
|
||||
|
||||
r.Route("/login", func(r chi.Router) {
|
||||
r.Get("/", v("login"))
|
||||
r.Post("/", handlers.LoginHandler)
|
||||
r.Post("/", handlers.LoginHandler(provider))
|
||||
})
|
||||
|
||||
//r.Post("/confirm-email/{token}", handlers.ConfirmEmailHandler(db))
|
||||
|
||||
r.Route("/forgot-password", func(r chi.Router) {
|
||||
r.Get("/", v("forgot-password"))
|
||||
r.Post("/", handlers.LoginHandler)
|
||||
r.Post("/", handlers.LoginHandler(provider))
|
||||
})
|
||||
|
||||
r.Route("/certs", func(r chi.Router) {
|
||||
r.Use(mw.RequireLogin)
|
||||
r.Get("/", handlers.ListCertHandler)
|
||||
r.Post("/new", handlers.CreateCertHandler)
|
||||
r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler)
|
||||
r.Use(mw.RequireLogin(provider.Sessions))
|
||||
r.Get("/", handlers.ListCertHandler(provider))
|
||||
r.Post("/new", handlers.CreateCertHandler(provider))
|
||||
r.HandleFunc("/download/{ID}", handlers.DownloadCertHandler(provider))
|
||||
})
|
||||
|
||||
r.HandleFunc("/500", func(w http.ResponseWriter, req *http.Request) {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"log"
|
||||
|
||||
"git.klink.asia/paul/certman/models"
|
||||
"git.klink.asia/paul/certman/settings"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
|
@ -14,28 +13,34 @@ var (
|
|||
ErrNotImplemented = errors.New("Not implemented")
|
||||
)
|
||||
|
||||
var Database *gorm.DB
|
||||
type DBConfig struct {
|
||||
Type string
|
||||
DSN string
|
||||
Log bool
|
||||
}
|
||||
|
||||
// DB is a wrapper around gorm.DB to provide custom methods
|
||||
type DB struct {
|
||||
*gorm.DB
|
||||
|
||||
conf *DBConfig
|
||||
}
|
||||
|
||||
func InitDB() *gorm.DB {
|
||||
dsn := settings.Get("DATABASE_URL", "db.sqlite3")
|
||||
|
||||
func NewDB(conf *DBConfig) *DB {
|
||||
// Establish connection
|
||||
db, err := gorm.Open("sqlite3", dsn)
|
||||
db, err := gorm.Open(conf.Type, conf.DSN)
|
||||
if err != nil {
|
||||
log.Fatalf("Could not open database: %s", err.Error())
|
||||
}
|
||||
|
||||
// Migrate models
|
||||
db.AutoMigrate(models.User{}, models.Client{})
|
||||
db.LogMode(true)
|
||||
db.LogMode(conf.Log)
|
||||
|
||||
Database = db
|
||||
return db
|
||||
return &DB{
|
||||
DB: db,
|
||||
conf: conf,
|
||||
}
|
||||
}
|
||||
|
||||
// CountUsers returns the number of Users in the datastore
|
||||
|
|
103
services/email.go
Normal file
103
services/email.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/go-mail/mail"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMailUninitializedConfig = errors.New("Mail: uninitialized config")
|
||||
)
|
||||
|
||||
type EmailConfig struct {
|
||||
From string
|
||||
SMTPServer string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
}
|
||||
|
||||
type Email struct {
|
||||
config *EmailConfig
|
||||
|
||||
mailChan chan *mail.Message
|
||||
}
|
||||
|
||||
func NewEmail(conf *EmailConfig) *Email {
|
||||
if conf == nil {
|
||||
log.Println(ErrMailUninitializedConfig)
|
||||
}
|
||||
|
||||
return &Email{
|
||||
config: conf,
|
||||
mailChan: make(chan *mail.Message, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Send sends an email to the receiver
|
||||
func (email *Email) Send(to, subject, text, html string) error {
|
||||
if email.config == nil {
|
||||
log.Print("Error: trying to send mail with uninitialized config.")
|
||||
return ErrMailUninitializedConfig
|
||||
}
|
||||
|
||||
m := mail.NewMessage()
|
||||
m.SetHeader("From", email.config.From)
|
||||
m.SetHeader("To", to)
|
||||
m.SetHeader("Subject", subject)
|
||||
m.SetBody("text/plain", text)
|
||||
m.AddAlternative("text/html", html)
|
||||
|
||||
// put email in chan
|
||||
email.mailChan <- m
|
||||
return nil
|
||||
}
|
||||
|
||||
// Daemon is a function that takes Mail and sends it without blocking.
|
||||
// WIP
|
||||
func (email *Email) Daemon() {
|
||||
if email.config == nil {
|
||||
log.Print("Error: trying to set up mail deamon with uninitialized config.")
|
||||
return
|
||||
}
|
||||
|
||||
d := mail.NewDialer(
|
||||
email.config.SMTPServer,
|
||||
email.config.SMTPPort,
|
||||
email.config.SMTPUsername,
|
||||
email.config.SMTPPassword)
|
||||
|
||||
var s mail.SendCloser
|
||||
var err error
|
||||
open := false
|
||||
for {
|
||||
select {
|
||||
case m, ok := <-email.mailChan:
|
||||
if !ok {
|
||||
// channel is closed
|
||||
return
|
||||
}
|
||||
if !open {
|
||||
if s, err = d.Dial(); err != nil {
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
open = true
|
||||
}
|
||||
if err := mail.Send(s, m); err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
// Close the connection if no email was sent in the last 30 seconds.
|
||||
case <-time.After(30 * time.Second):
|
||||
if open {
|
||||
if err := s.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
open = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
24
services/provider.go
Normal file
24
services/provider.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package services
|
||||
|
||||
type Config struct {
|
||||
DB *DBConfig
|
||||
Sessions *SessionsConfig
|
||||
Email *EmailConfig
|
||||
}
|
||||
|
||||
type Provider struct {
|
||||
DB *DB
|
||||
Sessions *Sessions
|
||||
Email *Email
|
||||
}
|
||||
|
||||
// NewProvider returns the ServiceProvider
|
||||
func NewProvider(conf *Config) *Provider {
|
||||
var provider = &Provider{}
|
||||
|
||||
provider.DB = NewDB(conf.DB)
|
||||
provider.Sessions = NewSessions(conf.Sessions)
|
||||
provider.Email = NewEmail(conf.Email)
|
||||
|
||||
return provider
|
||||
}
|
|
@ -8,16 +8,10 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"git.klink.asia/paul/certman/settings"
|
||||
"github.com/alexedwards/scs"
|
||||
"github.com/gorilla/securecookie"
|
||||
)
|
||||
|
||||
var (
|
||||
// SessionName is the name of the session cookie
|
||||
SessionName = "session"
|
||||
// CookieKey is the key the cookies are encrypted and signed with
|
||||
CookieKey = string(securecookie.GenerateRandomKey(32))
|
||||
// FlashesKey is the key used for the flashes in the cookie
|
||||
FlashesKey = "_flashes"
|
||||
// UserEmailKey is the key used to reference usernames
|
||||
|
@ -29,30 +23,33 @@ func init() {
|
|||
gob.Register(Flash{})
|
||||
}
|
||||
|
||||
// SessionStore is a globally accessible sessions store for the application
|
||||
var SessionStore *Store
|
||||
type SessionsConfig struct {
|
||||
SessionName string
|
||||
CookieKey string
|
||||
HttpOnly bool
|
||||
Secure bool
|
||||
Lifetime time.Duration
|
||||
}
|
||||
|
||||
// Store is a wrapped scs.Store in order to implement custom
|
||||
// logic
|
||||
type Store struct {
|
||||
// Sessions is a wrapped scs.Store in order to implement custom logic
|
||||
type Sessions struct {
|
||||
*scs.Manager
|
||||
}
|
||||
|
||||
// InitSession populates the default sessions Store
|
||||
func InitSession() {
|
||||
// NewSessions populates the default sessions Store
|
||||
func NewSessions(conf *SessionsConfig) *Sessions {
|
||||
store := scs.NewCookieManager(
|
||||
CookieKey,
|
||||
conf.CookieKey,
|
||||
)
|
||||
store.Name(conf.SessionName)
|
||||
store.HttpOnly(true)
|
||||
store.Lifetime(24 * time.Hour)
|
||||
store.Lifetime(conf.Lifetime)
|
||||
store.Secure(conf.Secure)
|
||||
|
||||
// Use secure cookies (HTTPS only) in production
|
||||
store.Secure(settings.Get("ENVIRONMENT", "") == "production")
|
||||
|
||||
SessionStore = &Store{store}
|
||||
return &Sessions{store}
|
||||
}
|
||||
|
||||
func (store *Store) GetUserEmail(req *http.Request) string {
|
||||
func (store *Sessions) GetUserEmail(req *http.Request) string {
|
||||
if store == nil {
|
||||
// if store was not initialized, all requests fail
|
||||
log.Println("Zero pointer when checking session for username")
|
||||
|
@ -72,7 +69,7 @@ func (store *Store) GetUserEmail(req *http.Request) string {
|
|||
return email
|
||||
}
|
||||
|
||||
func (store *Store) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) {
|
||||
func (store *Sessions) SetUserEmail(w http.ResponseWriter, req *http.Request, email string) {
|
||||
if store == nil {
|
||||
// if store was not initialized, do nothing
|
||||
return
|
||||
|
@ -103,7 +100,7 @@ func (flash Flash) Render() template.HTML {
|
|||
}
|
||||
|
||||
// Flash add flash message to session data
|
||||
func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error {
|
||||
func (store *Sessions) Flash(w http.ResponseWriter, req *http.Request, flash Flash) error {
|
||||
var flashes []Flash
|
||||
|
||||
sess := store.Load(req)
|
||||
|
@ -118,7 +115,7 @@ func (store *Store) Flash(w http.ResponseWriter, req *http.Request, flash Flash)
|
|||
}
|
||||
|
||||
// Flashes returns a slice of flash messages from session data
|
||||
func (store *Store) Flashes(w http.ResponseWriter, req *http.Request) []Flash {
|
||||
func (store *Sessions) Flashes(w http.ResponseWriter, req *http.Request) []Flash {
|
||||
var flashes []Flash
|
||||
sess := store.Load(req)
|
||||
sess.PopObject(w, FlashesKey, &flashes)
|
||||
|
|
|
@ -13,8 +13,9 @@ import (
|
|||
)
|
||||
|
||||
type View struct {
|
||||
Vars map[string]interface{}
|
||||
Request *http.Request
|
||||
Vars map[string]interface{}
|
||||
Request *http.Request
|
||||
SessionStore *services.Sessions
|
||||
}
|
||||
|
||||
func New(req *http.Request) *View {
|
||||
|
@ -23,12 +24,29 @@ func New(req *http.Request) *View {
|
|||
Vars: map[string]interface{}{
|
||||
"CSRF_TOKEN": csrf.Token(req),
|
||||
"csrfField": csrf.TemplateField(req),
|
||||
"username": services.SessionStore.GetUserEmail(req),
|
||||
"Meta": map[string]interface{}{
|
||||
"Path": req.URL.Path,
|
||||
"Env": "develop",
|
||||
},
|
||||
"flashes": []services.Flash{},
|
||||
"flashes": []services.Flash{},
|
||||
"username": "",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewWithSession(req *http.Request, sessionStore *services.Sessions) *View {
|
||||
return &View{
|
||||
Request: req,
|
||||
SessionStore: sessionStore,
|
||||
Vars: map[string]interface{}{
|
||||
"CSRF_TOKEN": csrf.Token(req),
|
||||
"csrfField": csrf.TemplateField(req),
|
||||
"Meta": map[string]interface{}{
|
||||
"Path": req.URL.Path,
|
||||
"Env": "develop",
|
||||
},
|
||||
"flashes": []services.Flash{},
|
||||
"username": sessionStore.GetUserEmail(req),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -43,8 +61,10 @@ func (view View) Render(w http.ResponseWriter, name string) {
|
|||
return
|
||||
}
|
||||
|
||||
// add flashes to template
|
||||
view.Vars["flashes"] = services.SessionStore.Flashes(w, view.Request)
|
||||
if view.SessionStore != nil {
|
||||
// add flashes to template
|
||||
view.Vars["flashes"] = view.SessionStore.Flashes(w, view.Request)
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
|
Loading…
Reference in a new issue