Refactor code, delete old references

This commit is contained in:
Paul 2018-02-03 18:14:47 +01:00
parent 71e830d52f
commit 9ef61b19bb
16 changed files with 435 additions and 223 deletions

4
.gitignore vendored
View File

@ -1,3 +1,7 @@
*_vfsdata.go
certman
db.sqlite3
*.crt
*.key
clients.json
.env

View File

@ -15,61 +15,73 @@ before_script:
- cd $GOPATH/src/$REPO_NAME
stages:
- test
- build
- release
- test
- build
- release
format:
stage: test
tags:
- docker
script:
# we use tags="dev" so there is no dependency on the prebuilt assets yet
- go get -tags="dev" -v $(go list ./... | grep -v /vendor/) # get missing dependencies
- go fmt $(go list ./... | grep -v /vendor/)
- go vet $(go list ./... | grep -v /vendor/)
- go test -tags="dev" -race $(go list ./... | grep -v /vendor/) -v -coverprofile .testCoverage.txt
# Use coverage parsing regex: ^coverage:\s(\d+(?:\.\d+)?%)
stage: test
tags:
- docker
script:
# we use tags="dev" so there is no dependency on the prebuilt assets yet
- go get -tags="dev" -v $(go list ./... | grep -v /vendor/) # get missing dependencies
- go fmt $(go list ./... | grep -v /vendor/)
- go vet $(go list ./... | grep -v /vendor/)
- go test -tags="dev" -race $(go list ./... | grep -v /vendor/) -v -coverprofile .testCoverage.txt
# Use coverage parsing regex: ^coverage:\s(\d+(?:\.\d+)?%)
compile:
stage: build
tags:
- docker
script:
# we use tags="dev" so there is no dependency on the prebuilt assets yet
- go get -tags="dev" -v $(go list ./... | grep -v /vendor/) # get missing dependencies
stage: build
tags:
- docker
script:
# we use tags="dev" so there is no dependency on the prebuilt assets yet
- go get -tags="dev" -v $(go list ./... | grep -v /vendor/) # get missing dependencies
# generate assets
- go get github.com/shurcooL/vfsgen/cmd/vfsgendev
- go generate git.klink.asia/paul/certman/assets
# generate assets
- go get github.com/shurcooL/vfsgen/cmd/vfsgendev
- go generate git.klink.asia/paul/certman/assets
# build binaries -- list of supported plattforms is here:
# https://stackoverflow.com/a/20728862
- GOOS=linux GOARCH=amd64 go build -o $CI_PROJECT_DIR/certman
#- GOOS=linux GOARCH=arm GOARM=6 go build -o $CI_PROJECT_DIR/certman.arm
#- GOOS=windows GOARCH=amd64 go build -o $CI_PROJECT_DIR/certman.exe
artifacts:
expire_in: "8 hrs"
paths:
- certman
# - certman.arm
# - certman.exe
# build binaries -- list of supported plattforms is here:
# https://stackoverflow.com/a/20728862
- GOOS=linux GOARCH=amd64 go build -tags "netgo" -o $CI_PROJECT_DIR/certman
- GOOS=linux GOARCH=arm GOARM=6 go build -tags "netgo" -o $CI_PROJECT_DIR/certman.arm
- GOOS=windows GOARCH=amd64 go build -tags "netgo" -o $CI_PROJECT_DIR/certman.exe
artifacts:
expire_in: "8 hrs"
paths:
- certman
- certman.arm
- certman.exe
minify:
stage: release
tags:
- docker
dependencies:
- compile
image:
name: znly/upx:latest
entrypoint: ["/bin/sh", "-c"]
script:
- upx --best --brute $CI_PROJECT_DIR/certman certman.arm certman.exe
artifacts:
paths:
- certman
#- certman.arm
#- certman.exe
only:
- tags
stage: release
tags:
- docker
dependencies:
- compile
image:
name: znly/upx:latest
entrypoint: ["/bin/sh", "-c"]
script:
- upx --best --brute $CI_PROJECT_DIR/certman $CI_PROJECT_DIR/certman.arm $CI_PROJECT_DIR/certman.exe
artifacts:
paths:
- certman
- certman.arm
- certman.exe
only:
- tags
build_image:
stage: release
tags:
- dind
image: "docker:latest"
script:
- docker login -u gitlab-ci-token -p $CI_JOB_TOKEN $CI_REGISTRY
- docker build -t $CI_REGISTRY_IMAGE:${CI_COMMIT_REF_NAME#v} .
- docker push $CI_REGISTRY_IMAGE:${CI_COMMIT_REF_NAME#v}
only:
- tags

21
Dockerfile Normal file
View File

@ -0,0 +1,21 @@
FROM golang:1.9
WORKDIR /go/src/git.klink.asia/paul/certman
ADD . .
RUN \
go get github.com/shurcooL/vfsgen/cmd/vfsgendev && \
go generate git.klink.asia/paul/certman/assets && \
go get -v git.klink.asia/paul/certman && \
go build -tags netgo
FROM scratch
ENV \
OAUTH2_CLIENT_ID="" \
OAUTH2_CLIENT_SECRET="" \
APP_KEY="" \
OAUTH2_AUTH_URL="https://gitlab.example.com/oauth/authorize" \
OAUTH2_TOKEN_URL="https://gitlab.example.com/oauth/token" \
USER_ENDPOINT="https://gitlab.example.com/api/v4/user" \
OAUTH2_REDIRECT_URL="https://certman.example.com/login/oauth2/redirect"
COPY --from=0 /go/src/git.klink.asia/paul/certman/certman /
ENTRYPOINT ["/certman"]

View File

@ -1,4 +1,4 @@
{{ define "base" }}
{{ define "base" }}# Client configuration for {{ .User }}@{{ .Name }}
client
dev tun
proto udp
@ -45,12 +45,10 @@ Yo95ZQ==
</ca>
<cert>
{{ .Cert | html }}
</cert>
{{ .Cert | html }}</cert>
<key>
{{ .Key | html }}
</key>
{{ .Key | html }}</key>
<tls-auth>
#

View File

@ -1,5 +1,5 @@
{{ define "meta" }}
<title>Log in</title>
<title>Certificate List</title>
{{ end}}
{{ define "content" }}
@ -8,7 +8,7 @@
<div class="container">
<div class="columns">
<div class="column">
<h1>Certificates</h1>
<h1 class="title">Certificates for {{ .username }}:</h1>
<table class="table">
<thead>
@ -25,7 +25,7 @@
{{ $.username }}@
</a>
</p>
<p class="control is-marginless is-expanded">
<p class="control is-marginless is-expanded">
<input name="certname" class="input" type="text" placeholder="Certificate name (e.g. Laptop)">
</p>
</div>
@ -36,20 +36,23 @@
<tbody>
{{ range .Clients }}
<tr>
<td class="is-vcentered"><p>{{ $.username }}@{{ .Name }}</p></td>
<td class="is-vcentered"><p>{{ .User }}@{{ .Name }}</p></td>
<td><time title="{{ .CreatedAt.UTC }}">{{ .CreatedAt | humanDate }}</time></td>
<td>
<div class="field has-addons">
<p class="control is-marginless is-expanded">
<a href="/certs/download/{{ .Name }}" class="button is-primary is-fullwidth">Download</a>
</p>
<p class="control is-marginless">
<a class="button is-danger">
<span class="icon is-small">
<i class="fas fa-trash"></i>
</span>
</a>
</p>
<div class="control is-marginless">
<form action="/certs/delete/{{ .Name }}" method="POST">
{{ $.csrfField }}
<button class="button is-danger" type="submit">
<span class="icon is-small">
<i class="fas fa-trash"></i>
</span>
</button>
</form>
</div>
</div>
</td>
</tr>

37
handlers/README.md Normal file
View File

@ -0,0 +1,37 @@
# Certman
Certman is a simple certificate manager web service for OpenVPN.
## Installation
### Binary
There are prebuilt binary files for this application. They are statically
linked and have no additional dependencies. Supported plattforms are:
* Windows (XP and up)
* Linux (2.6.16 and up)
* Linux ARM (for raspberry pi, 3.0 and up)
Simply download them from the "artifacts" section of this project.
### Docker
A prebuilt docker image (10MB) is available:
```bash
docker pull docker.klink.asia/paul/certman
```
### From Source-Docker
You can easily build your own docker image from source
```bash
docker build -t docker.klink.asia/paul/certman .
```
## Configuration
Certman assumes the root certificates of the VPN CA are located in the same
directory as the binary, If that is not the case you need to copy over the
`ca.crt` and `ca.key` files before you are able to generate certificates
with this tool.
Additionally, the project is configured by the following environment
variables:
* `OAUTH2_CLIENT_ID` the Client ID, assigned during client registration
* `OAUTH2_CLIENT_SECRET` the Client secret, assigned during client registration
* `OAUTH2_AUTH_URL` the URL to the "/authorize" endpoint of the identity provider
* `OAUTH2_TOKEN_URL` the URL to the "/token" endpoint of the identity provider
* `OAUTH2_REDIRECT_URL` the redirect URL used by the app, usually the hostname suffixed by "/login/oauth2/redirect"
* `USER_ENDPOINT` the URL to the Identity provider user endpoint, for gitlab this is "/api/v4/user". The "username" attribute of the returned JSON will used for authentication.
* `APP_KEY` random ASCII string, 32 characters in length. Used for cookie generation.

View File

@ -12,25 +12,14 @@ import (
"git.klink.asia/paul/certman/services"
)
var GitlabConfig = &oauth2.Config{
ClientID: os.Getenv("OAUTH2_CLIENT_ID"),
ClientSecret: os.Getenv("OAUTH2_CLIENT_SECRET"),
Scopes: []string{"read_user"},
RedirectURL: os.Getenv("HOST") + "/login/oauth2/redirect",
Endpoint: oauth2.Endpoint{
AuthURL: os.Getenv("OAUTH2_AUTH_URL"),
TokenURL: os.Getenv("OAUTH2_TOKEN_URL"),
},
}
func OAuth2Endpoint(p *services.Provider) http.HandlerFunc {
func OAuth2Endpoint(p *services.Provider, config *oauth2.Config) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
v := views.NewWithSession(req, p.Sessions)
code := req.FormValue("code")
// exchange code for token
accessToken, err := GitlabConfig.Exchange(oauth2.NoContext, code)
accessToken, err := config.Exchange(oauth2.NoContext, code)
if err != nil {
fmt.Println(err)
http.NotFound(w, req)
@ -39,9 +28,9 @@ func OAuth2Endpoint(p *services.Provider) http.HandlerFunc {
if accessToken.Valid() {
// generate a client using the access token
httpClient := GitlabConfig.Client(oauth2.NoContext, accessToken)
httpClient := config.Client(oauth2.NoContext, accessToken)
apiRequest, err := http.NewRequest("GET", "https://git.klink.asia/api/v4/user", nil)
apiRequest, err := http.NewRequest("GET", os.Getenv("USER_ENDPOINT"), nil)
if err != nil {
v.RenderError(w, http.StatusNotFound)
return
@ -78,9 +67,9 @@ func OAuth2Endpoint(p *services.Provider) http.HandlerFunc {
}
}
func GetLoginHandler(p *services.Provider) http.HandlerFunc {
func GetLoginHandler(p *services.Provider, config *oauth2.Config) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
authURL := GitlabConfig.AuthCodeURL("", oauth2.AccessTypeOnline)
authURL := config.AuthCodeURL("", oauth2.AccessTypeOnline)
http.Redirect(w, req, authURL, http.StatusFound)
}
}

View File

@ -8,10 +8,12 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"html/template"
"io/ioutil"
"log"
"math/big"
"net/http"
"strings"
"time"
"git.klink.asia/paul/certman/models"
@ -27,7 +29,7 @@ func ListClientsHandler(p *services.Provider) http.HandlerFunc {
username := p.Sessions.GetUsername(req)
clients, _ := p.DB.ListClientsForUser(username, 100, 0)
clients, _ := p.ClientCollection.ListClientsForUser(username)
v.Vars["Clients"] = clients
v.Render(w, "client_list")
@ -39,7 +41,8 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc {
username := p.Sessions.GetUsername(req)
certname := req.FormValue("certname")
if !IsByteLength(certname, 2, 64) || !IsAlphanumeric(certname) {
// Validate certificate Name
if !IsByteLength(certname, 2, 64) || !IsDNSName(certname) {
p.Sessions.Flash(w, req,
services.Flash{
Type: "danger",
@ -50,6 +53,10 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc {
return
}
// lowercase the certificate name, to avoid problems with the case
// insensitive matching inside OpenVPN
certname = strings.ToLower(certname)
// Load CA master certificate
caCert, caKey, err := loadX509KeyPair("ca.crt", "ca.key")
if err != nil {
@ -78,13 +85,14 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc {
// Initialize new client config
client := models.Client{
Name: certname,
CreatedAt: time.Now(),
PrivateKey: x509.MarshalPKCS1PrivateKey(key),
Cert: derBytes,
User: username,
}
// Insert client into database
if err := p.DB.CreateClient(&client); err != nil {
if err := p.ClientCollection.CreateClient(&client); err != nil {
log.Println(err.Error())
p.Sessions.Flash(w, req,
services.Flash{
@ -107,6 +115,40 @@ func CreateCertHandler(p *services.Provider) http.HandlerFunc {
}
}
func DeleteCertHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
v := views.New(req)
// detemine own username
username := p.Sessions.GetUsername(req)
name := chi.URLParam(req, "name")
client, err := p.ClientCollection.GetClientByNameUser(name, username)
if err != nil {
v.RenderError(w, http.StatusNotFound)
return
}
err = p.ClientCollection.DeleteClient(client.ID)
if err != nil {
p.Sessions.Flash(w, req,
services.Flash{
Type: "danger",
Message: "Failed to delete certificate",
},
)
http.Redirect(w, req, "/certs", http.StatusFound)
}
p.Sessions.Flash(w, req,
services.Flash{
Type: "success",
Message: template.HTML(fmt.Sprintf("Successfully deleted client <strong>%s</strong>", client.Name)),
},
)
http.Redirect(w, req, "/certs", http.StatusFound)
}
}
func DownloadCertHandler(p *services.Provider) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
v := views.New(req)
@ -114,7 +156,7 @@ func DownloadCertHandler(p *services.Provider) http.HandlerFunc {
username := p.Sessions.GetUsername(req)
name := chi.URLParam(req, "name")
client, err := p.DB.GetClientByNameUser(name, username)
client, err := p.ClientCollection.GetClientByNameUser(name, username)
if err != nil {
v.RenderError(w, http.StatusNotFound)
return
@ -143,7 +185,6 @@ func DownloadCertHandler(p *services.Provider) http.HandlerFunc {
w.Header().Set("Content-Type", "application/x-openvpn-profile")
w.Header().Set("Content-Disposition", "attachment; filename=\"config.ovpn\"")
w.WriteHeader(http.StatusOK)
log.Println(vars)
t.Execute(w, vars)
return
}
@ -159,10 +200,8 @@ func loadX509KeyPair(certFile, keyFile string) (*x509.Certificate, *rsa.PrivateK
if err != nil {
return nil, nil, err
}
cpb, cr := pem.Decode(cf)
fmt.Println(string(cr))
kpb, kr := pem.Decode(kf)
fmt.Println(string(kr))
cpb, _ := pem.Decode(cf)
kpb, _ := pem.Decode(kf)
crt, err := x509.ParseCertificate(cpb.Bytes)
if err != nil {
@ -191,11 +230,10 @@ func CreateCertificate(commonName string, key interface{}, caCert *x509.Certific
SerialNumber: serialNumber,
Subject: subj,
NotBefore: time.Now(),
NotAfter: time.Now().Add(24 * time.Hour * 356 * 5),
NotBefore: time.Now().Add(-5 * time.Minute), // account for clock shift
NotAfter: time.Now().Add(24 * time.Hour * 356 * 5), // 5 years ought to be enough!
SignatureAlgorithm: x509.SHA256WithRSA,
//KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageDataEncipherment,
SignatureAlgorithm: x509.SHA256WithRSA,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true,
}

View File

@ -1,6 +1,10 @@
package handlers
import "regexp"
import (
"net"
"regexp"
"strings"
)
const (
Email string = "^(((([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+(\\.([a-zA-Z]|\\d|[!#\\$%&'\\*\\+\\-\\/=\\?\\^_`{\\|}~]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])+)*)|((\\x22)((((\\x20|\\x09)*(\\x0d\\x0a))?(\\x20|\\x09)+)?(([\\x01-\\x08\\x0b\\x0c\\x0e-\\x1f\\x7f]|\\x21|[\\x23-\\x5b]|[\\x5d-\\x7e]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(\\([\\x01-\\x09\\x0b\\x0c\\x0d-\\x7f]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}]))))*(((\\x20|\\x09)*(\\x0d\\x0a))?(\\x20|\\x09)+)?(\\x22)))@((([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])([a-zA-Z]|\\d|-|\\.|_|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*([a-zA-Z]|\\d|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.)+(([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])|(([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])([a-zA-Z]|\\d|-|_|~|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])*([a-zA-Z]|[\\x{00A0}-\\x{D7FF}\\x{F900}-\\x{FDCF}\\x{FDF0}-\\x{FFEF}])))\\.?$"
@ -118,3 +122,17 @@ func IsNull(str string) bool {
func IsByteLength(str string, min, max int) bool {
return len(str) >= min && len(str) <= max
}
// IsDNSName will validate the given string as a DNS name
func IsDNSName(str string) bool {
if str == "" || len(strings.Replace(str, ".", "", -1)) > 255 {
// constraints already violated
return false
}
return !IsIP(str) && rxDNSName.MatchString(str)
}
// IsIP checks if a string is either IP version 4 or 6.
func IsIP(str string) bool {
return net.ParseIP(str) != nil
}

32
main.go
View File

@ -1,43 +1,53 @@
package main
import (
"errors"
"log"
"net/http"
"os"
"time"
"github.com/gorilla/securecookie"
"git.klink.asia/paul/certman/services"
"git.klink.asia/paul/certman/router"
"git.klink.asia/paul/certman/views"
// import sqlite3 driver
_ "github.com/mattn/go-sqlite3"
)
func main() {
log.Println("Initializing certman")
if err := checkCAFilesExist(); err != nil {
log.Fatalf("Could not read CA files: %s", err)
}
c := services.Config{
DB: &services.DBConfig{
Type: "sqlite3",
DSN: "db.sqlite3",
Log: true,
},
CollectionPath: "./clients.json",
Sessions: &services.SessionsConfig{
SessionName: "_session",
CookieKey: string(securecookie.GenerateRandomKey(32)),
CookieKey: os.Getenv("APP_KEY"),
HttpOnly: true,
Lifetime: 24 * time.Hour,
},
}
log.Println(".. services")
serviceProvider := services.NewProvider(&c)
// load and parse template files
log.Println(".. templates")
views.LoadTemplates()
mux := router.HandleRoutes(serviceProvider)
log.Println(".. server")
err := http.ListenAndServe(":8000", mux)
log.Fatalf(err.Error())
}
func checkCAFilesExist() error {
for _, filename := range []string{"ca.crt", "ca.key"} {
if _, err := os.Stat(filename); os.IsNotExist(err) {
return errors.New(filename + " not readable")
}
}
return nil
}

View File

@ -11,20 +11,12 @@ var (
ErrNotImplemented = errors.New("Not implemented")
)
// Model is a base model definition, including helpful fields for dealing with
// models in a database
type Model struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt *time.Time `sql:"index"`
}
// Client represent the OpenVPN client configuration
type Client struct {
Model
Name string `gorm:"index;unique_index:idx_name_user"`
User string `gorm:"index;unique_index:idx_name_user"`
ID uint
CreatedAt time.Time
Name string
User string
Cert []byte
PrivateKey []byte
}

View File

@ -6,6 +6,7 @@ import (
"strings"
"git.klink.asia/paul/certman/services"
"golang.org/x/oauth2"
"git.klink.asia/paul/certman/assets"
"git.klink.asia/paul/certman/handlers"
@ -35,6 +36,18 @@ func HandleRoutes(provider *services.Provider) http.Handler {
mux.Use(mw.Recoverer) // recover on panic
mux.Use(provider.Sessions.Manager.Use) // use session storage
// TODO: move this code away from here
oauth2Config := &oauth2.Config{
ClientID: os.Getenv("OAUTH2_CLIENT_ID"),
ClientSecret: os.Getenv("OAUTH2_CLIENT_SECRET"),
Scopes: []string{"read_user"},
RedirectURL: os.Getenv("OAUTH2_REDIRECT_URL"),
Endpoint: oauth2.Endpoint{
AuthURL: os.Getenv("OAUTH2_AUTH_URL"),
TokenURL: os.Getenv("OAUTH2_TOKEN_URL"),
},
}
// we are serving the static files directly from the assets package
// this either means we use the embedded files, or live-load
// from the file system (if `--tags="dev"` is used).
@ -54,8 +67,8 @@ func HandleRoutes(provider *services.Provider) http.Handler {
r.HandleFunc("/", http.RedirectHandler("certs", http.StatusFound).ServeHTTP)
r.Route("/login", func(r chi.Router) {
r.Get("/", handlers.GetLoginHandler(provider))
r.Get("/oauth2/redirect", handlers.OAuth2Endpoint(provider))
r.Get("/", handlers.GetLoginHandler(provider, oauth2Config))
r.Get("/oauth2/redirect", handlers.OAuth2Endpoint(provider, oauth2Config))
})
r.Route("/certs", func(r chi.Router) {
@ -63,7 +76,10 @@ func HandleRoutes(provider *services.Provider) http.Handler {
r.Get("/", handlers.ListClientsHandler(provider))
r.Post("/new", handlers.CreateCertHandler(provider))
r.HandleFunc("/download/{name}", handlers.DownloadCertHandler(provider))
r.Post("/delete/{name}", handlers.DeleteCertHandler(provider))
})
r.Get("/unconfigured-backend", handlers.NotFoundHandler)
})
// what should happen if no route matches

169
services/clientstore.go Normal file
View File

@ -0,0 +1,169 @@
package services
import (
"encoding/json"
"errors"
"io/ioutil"
"log"
"os"
"sync"
"git.klink.asia/paul/certman/models"
)
var (
ErrNilCertificate = errors.New("Trying to store nil certificate")
ErrDuplicate = errors.New("Client with that name already exists")
ErrUserNotExists = errors.New("User does not exist")
ErrClientNotExists = errors.New("Client does not exist")
)
type ClientCollection struct {
sync.RWMutex
path string
Clients map[uint]*models.Client
UserIndex map[string]map[string]uint
LastID uint
}
func NewClientCollection(path string) *ClientCollection {
// empty collection
var clientCollection = ClientCollection{
path: path,
Clients: make(map[uint]*models.Client),
UserIndex: make(map[string]map[string]uint),
LastID: 0,
}
raw, err := ioutil.ReadFile(path)
if os.IsNotExist(err) {
return &clientCollection
} else if err != nil {
log.Println(err)
return &clientCollection
}
if err := json.Unmarshal(raw, &clientCollection); err != nil {
log.Println(err)
}
return &clientCollection
}
// CreateClient inserts a client into the datastore
func (db *ClientCollection) CreateClient(client *models.Client) error {
db.Lock()
defer db.Unlock()
if client == nil {
return ErrNilCertificate
}
db.LastID++ // increment Id
client.ID = db.LastID
userIndex, exists := db.UserIndex[client.User]
if !exists {
// create user index if not exists
db.UserIndex[client.User] = make(map[string]uint)
userIndex = db.UserIndex[client.User]
}
if _, exists = userIndex[client.Name]; exists {
return ErrDuplicate
}
// if all went well, add client and set the index
db.Clients[client.ID] = client
userIndex[client.Name] = client.ID
db.UserIndex[client.User] = userIndex
return db.save()
}
// ListClientsForUser returns a slice of 'count' client for user 'user', starting at 'offset'
func (db *ClientCollection) ListClientsForUser(user string) ([]*models.Client, error) {
db.RLock()
defer db.RUnlock()
var clients = make([]*models.Client, 0)
userIndex, exists := db.UserIndex[user]
if !exists {
return nil, errors.New("user does not exist")
}
for _, clientID := range userIndex {
clients = append(clients, db.Clients[clientID])
}
return clients, nil
}
// GetClientByID returns a single client by ID
func (db *ClientCollection) GetClientByID(id uint) (*models.Client, error) {
client, exists := db.Clients[id]
if !exists {
return nil, ErrClientNotExists
}
return client, nil
}
// GetClientByNameUser returns a single client by ID
func (db *ClientCollection) GetClientByNameUser(name, user string) (*models.Client, error) {
db.RLock()
defer db.RUnlock()
userIndex, exists := db.UserIndex[user]
if !exists {
return nil, ErrUserNotExists
}
clientID, exists := userIndex[name]
if !exists {
return nil, ErrClientNotExists
}
client, exists := db.Clients[clientID]
if !exists {
return nil, ErrClientNotExists
}
return client, nil
}
// DeleteClient removes a client from the datastore
func (db *ClientCollection) DeleteClient(id uint) error {
db.Lock()
defer db.Unlock()
client, exists := db.Clients[id]
if !exists {
return nil // nothing to delete
}
userIndex, exists := db.UserIndex[client.User]
if !exists {
return ErrUserNotExists
}
delete(userIndex, client.Name) // delete client index
// if index is now empty, delete the user entry
if len(userIndex) == 0 {
delete(db.UserIndex, client.User)
}
// finally delete the client
delete(db.Clients, id)
return db.save()
}
func (c *ClientCollection) save() error {
collectionJSON, _ := json.Marshal(c)
return ioutil.WriteFile(c.path, collectionJSON, 0600)
}

View File

@ -1,95 +0,0 @@
package services
import (
"errors"
"log"
"git.klink.asia/paul/certman/models"
"github.com/jinzhu/gorm"
)
// Error Definitions
var (
ErrNotImplemented = errors.New("Not implemented")
)
type DBConfig struct {
Type string
DSN string
Log bool
}
// DB is a wrapper around gorm.DB to provide custom methods
type DB struct {
gorm *gorm.DB
conf *DBConfig
}
func NewDB(conf *DBConfig) *DB {
// Establish connection
db, err := gorm.Open(conf.Type, conf.DSN)
if err != nil {
log.Fatalf("Could not open database: %s", err.Error())
}
// Migrate models
db.AutoMigrate(models.Client{})
db.LogMode(conf.Log)
return &DB{
gorm: db,
conf: conf,
}
}
// CountClients returns the number of clients in the datastore
func (db *DB) CountClients() (uint, error) {
var count uint
err := db.gorm.Find(&models.Client{}).Count(&count).Error
return count, err
}
// CreateClient inserts a client into the datastore
func (db *DB) CreateClient(client *models.Client) error {
err := db.gorm.Create(&client).Error
return err
}
// ListClients returns a slice of 'count' client, starting at 'offset'
func (db *DB) ListClients(count, offset int) ([]*models.Client, error) {
var clients = make([]*models.Client, 0)
err := db.gorm.Find(&clients).Limit(count).Offset(offset).Error
return clients, err
}
// ListClientsForUser returns a slice of 'count' client for user 'user', starting at 'offset'
func (db *DB) ListClientsForUser(user string, count, offset int) ([]*models.Client, error) {
var clients = make([]*models.Client, 0)
err := db.gorm.Find(&clients).Where("user = ?", user).Limit(count).Offset(offset).Error
return clients, err
}
// GetClientByID returns a single client by ID
func (db *DB) GetClientByID(id uint) (*models.Client, error) {
var client models.Client
err := db.gorm.Where("id = ?", id).First(&client).Error
return &client, err
}
// GetClientByNameUser returns a single client by ID
func (db *DB) GetClientByNameUser(name, user string) (*models.Client, error) {
var client models.Client
err := db.gorm.Where("name = ?", name).Where("user = ?", user).First(&client).Error
return &client, err
}
// DeleteClient removes a client from the datastore
func (db *DB) DeleteClient(id uint) error {
err := db.gorm.Where("id = ?", id).Delete(&models.Client{}).Error
return err
}

View File

@ -1,20 +1,20 @@
package services
type Config struct {
DB *DBConfig
Sessions *SessionsConfig
CollectionPath string
Sessions *SessionsConfig
}
type Provider struct {
DB *DB
Sessions *Sessions
ClientCollection *ClientCollection
Sessions *Sessions
}
// NewProvider returns the ServiceProvider
func NewProvider(conf *Config) *Provider {
var provider = &Provider{}
provider.DB = NewDB(conf.DB)
provider.ClientCollection = NewClientCollection(conf.CollectionPath)
provider.Sessions = NewSessions(conf.Sessions)
return provider

View File

@ -52,7 +52,7 @@ func NewSessions(conf *SessionsConfig) *Sessions {
func (store *Sessions) GetUsername(req *http.Request) string {
if store == nil {
// if store was not initialized, all requests fail
log.Println("Zero pointer when checking session for username")
log.Println("Nil pointer when checking session for username")
return ""
}