Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4cce7b7454 | |||
| 109199ea45 | |||
| 40429d7618 | |||
| 11c40a75ac | |||
| d162c32a57 | |||
| 2a97ec72be | |||
| 09e0683ae0 | |||
| f6016bbdb1 | |||
| 3b31adf3e2 | |||
| efdf9fdade | |||
| 4474986909 |
@@ -13,7 +13,12 @@
|
|||||||
},
|
},
|
||||||
"server_config": {
|
"server_config": {
|
||||||
"port": 8080,
|
"port": 8080,
|
||||||
"base_url": "https://profile.example.com"
|
"base_url": "https://profile.example.com",
|
||||||
|
"session_store": "redis",
|
||||||
|
"redis_config": {
|
||||||
|
"redis_url": "redis://localhost:6379/0",
|
||||||
|
"prefix": ""
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"email_config": {
|
"email_config": {
|
||||||
"username": "noreply",
|
"username": "noreply",
|
||||||
|
|||||||
@@ -21,9 +21,16 @@ type StyleConfig struct {
|
|||||||
LogoPath string `json:"logo_path"`
|
LogoPath string `json:"logo_path"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RedisConfig struct {
|
||||||
|
RedisURL string `json:"redis_url"`
|
||||||
|
Prefix string `json:"prefix"`
|
||||||
|
}
|
||||||
|
|
||||||
type WebserverConfig struct {
|
type WebserverConfig struct {
|
||||||
Port int `json:"port"`
|
Port int `json:"port"`
|
||||||
BaseURL string `json:"base_url"`
|
BaseURL string `json:"base_url"`
|
||||||
|
SessionStore string `json:"session_store"`
|
||||||
|
RedisConfigInfo RedisConfig `json:"redis_config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmailConfig struct {
|
type EmailConfig struct {
|
||||||
|
|||||||
+10
-3
@@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"astraltech.xyz/accountmanager/src/logging"
|
"astraltech.xyz/accountmanager/src/logging"
|
||||||
|
"astraltech.xyz/accountmanager/src/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LoginPageData struct {
|
type LoginPageData struct {
|
||||||
@@ -32,9 +33,15 @@ func loginHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
logging.Infof("New Login request for %s\n", username)
|
logging.Infof("New Login request for %s\n", username)
|
||||||
newUserData, err := authenticateUser(username, password)
|
newUserData, err := authenticateUser(username, password)
|
||||||
userDataMutex.Lock()
|
|
||||||
userData[username] = newUserData
|
userDataErr := userData.Create(username, newUserData)
|
||||||
userDataMutex.Unlock()
|
if userDataErr == store.ErrKeyAlreadyExists {
|
||||||
|
userData.Update(username, newUserData)
|
||||||
|
} else if userDataErr != nil {
|
||||||
|
logging.Error(userDataErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err == ErrPasswordExpired {
|
if err == ErrPasswordExpired {
|
||||||
http.Redirect(w, r, "/reset-password?token=this_is_the_only_token_that_works", http.StatusFound)
|
http.Redirect(w, r, "/reset-password?token=this_is_the_only_token_that_works", http.StatusFound)
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
|
|||||||
+29
-11
@@ -7,7 +7,6 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
|
|
||||||
"astraltech.xyz/accountmanager/src/components"
|
"astraltech.xyz/accountmanager/src/components"
|
||||||
"astraltech.xyz/accountmanager/src/email"
|
"astraltech.xyz/accountmanager/src/email"
|
||||||
@@ -15,6 +14,7 @@ import (
|
|||||||
"astraltech.xyz/accountmanager/src/ldap"
|
"astraltech.xyz/accountmanager/src/ldap"
|
||||||
"astraltech.xyz/accountmanager/src/logging"
|
"astraltech.xyz/accountmanager/src/logging"
|
||||||
"astraltech.xyz/accountmanager/src/session"
|
"astraltech.xyz/accountmanager/src/session"
|
||||||
|
"astraltech.xyz/accountmanager/src/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -31,8 +31,7 @@ type UserData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
userData = make(map[string]*UserData)
|
userData store.KeyValueStore[*UserData]
|
||||||
userDataMutex sync.RWMutex
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrPasswordExpired = errors.New("Password expired")
|
var ErrPasswordExpired = errors.New("Password expired")
|
||||||
@@ -94,16 +93,22 @@ func profileHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data, err := userData.Get(sessionData.UserID)
|
||||||
|
if err != nil {
|
||||||
|
logging.Error(err.Error())
|
||||||
|
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if r.Method == http.MethodGet {
|
if r.Method == http.MethodGet {
|
||||||
tmpl := template.Must(template.ParseFiles("src/pages/profile_page.html"))
|
tmpl := template.Must(template.ParseFiles("src/pages/profile_page.html"))
|
||||||
userDataMutex.RLock()
|
|
||||||
tmpl.Execute(w, ProfileData{
|
tmpl.Execute(w, ProfileData{
|
||||||
Username: sessionData.UserID,
|
Username: sessionData.UserID,
|
||||||
Email: userData[sessionData.UserID].Email,
|
Email: data.Email,
|
||||||
DisplayName: userData[sessionData.UserID].DisplayName,
|
DisplayName: data.DisplayName,
|
||||||
CSRFToken: sessionData.CSRFToken,
|
CSRFToken: sessionData.CSRFToken,
|
||||||
})
|
})
|
||||||
userDataMutex.RUnlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -198,8 +203,10 @@ func changePasswordHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(`{"success": true}`))
|
w.Write([]byte(`{"success": true}`))
|
||||||
|
|
||||||
|
user_data, err := userData.Get(sessionData.UserID)
|
||||||
|
|
||||||
data := map[string]any{
|
data := map[string]any{
|
||||||
"Username": userData[sessionData.UserID].DisplayName,
|
"Username": user_data.DisplayName,
|
||||||
"ServiceName": "Astral Tech",
|
"ServiceName": "Astral Tech",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,13 +214,11 @@ func changePasswordHandler(w http.ResponseWriter, r *http.Request) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Errorf("Failed to load email template: %s", err.Error())
|
logging.Errorf("Failed to load email template: %s", err.Error())
|
||||||
}
|
}
|
||||||
noReplyEmail.SendEmail([]string{userData[sessionData.UserID].Email}, "Password expired", email_template)
|
noReplyEmail.SendEmail([]string{user_data.Email}, "Password expired", email_template)
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
logging.Info("Starting the server")
|
logging.Info("Starting the server")
|
||||||
sessionManager = session.GetSessionManager()
|
|
||||||
sessionManager.SetStoreType(session.Redis)
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
serverConfig, err = loadServerConfig("./data/config.json")
|
serverConfig, err = loadServerConfig("./data/config.json")
|
||||||
@@ -221,6 +226,19 @@ func main() {
|
|||||||
log.Fatal("Could not load server config")
|
log.Fatal("Could not load server config")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sessionManager = session.GetSessionManager()
|
||||||
|
if serverConfig.WebserverConfig.SessionStore == "in_memory" {
|
||||||
|
sessionManager.SetStoreType(session.InMemory)
|
||||||
|
userData = store.NewMemoryStore[*UserData]()
|
||||||
|
} else if serverConfig.WebserverConfig.SessionStore == "redis" {
|
||||||
|
sessionManager.SetStoreType(session.Redis, serverConfig.WebserverConfig.RedisConfigInfo.RedisURL, serverConfig.WebserverConfig.RedisConfigInfo.Prefix)
|
||||||
|
userData = store.NewRedisStore[*UserData](serverConfig.WebserverConfig.RedisConfigInfo.RedisURL, serverConfig.WebserverConfig.RedisConfigInfo.Prefix)
|
||||||
|
} else {
|
||||||
|
logging.Warnf("'%s' is an unknown session store type defaulting to in memory", serverConfig.WebserverConfig.SessionStore)
|
||||||
|
sessionManager.SetStoreType(session.InMemory)
|
||||||
|
userData = store.NewMemoryStore[*UserData]()
|
||||||
|
}
|
||||||
|
|
||||||
noReplyEmail = email.CreateEmailAccount(email.EmailAccountData{
|
noReplyEmail = email.CreateEmailAccount(email.EmailAccountData{
|
||||||
Username: serverConfig.EmailConfig.Username,
|
Username: serverConfig.EmailConfig.Username,
|
||||||
Password: serverConfig.EmailConfig.Password,
|
Password: serverConfig.EmailConfig.Password,
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
package session
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var ErrSessionNotFound = errors.New("session not found")
|
|
||||||
var ErrSessionAlreadyExists = errors.New("session already exists")
|
|
||||||
var ErrSessionExpired = errors.New("session expired")
|
|
||||||
var ErrSessionBackend = errors.New("session backend")
|
|
||||||
@@ -2,7 +2,6 @@ package session
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,9 +15,3 @@ func GenerateSessionToken(length int) (string, error) {
|
|||||||
token := base64.RawURLEncoding.EncodeToString(b)
|
token := base64.RawURLEncoding.EncodeToString(b)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// more helper
|
|
||||||
func hashSession(session_id string) string {
|
|
||||||
tokenEncoded := sha256.Sum256([]byte(session_id))
|
|
||||||
return base64.RawURLEncoding.EncodeToString(tokenEncoded[:])
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,94 +0,0 @@
|
|||||||
package session
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"astraltech.xyz/accountmanager/src/logging"
|
|
||||||
"astraltech.xyz/accountmanager/src/worker"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MemoryStore struct {
|
|
||||||
sessions map[string]*SessionData
|
|
||||||
lock sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMemoryStore() *MemoryStore {
|
|
||||||
logging.Debug("Creating new in memory session store")
|
|
||||||
store := &MemoryStore{
|
|
||||||
sessions: make(map[string]*SessionData),
|
|
||||||
}
|
|
||||||
worker.CreateWorker(time.Minute*5, store.cleanup)
|
|
||||||
return store
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MemoryStore) Create(sessionID string, session *SessionData) (err error) {
|
|
||||||
hashedSession := hashSession(sessionID)
|
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
_, exist := m.sessions[hashedSession]
|
|
||||||
if exist {
|
|
||||||
return ErrSessionAlreadyExists
|
|
||||||
}
|
|
||||||
|
|
||||||
m.sessions[hashedSession] = session
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (m *MemoryStore) Get(sessionID string) (*SessionData, error) {
|
|
||||||
m.lock.RLock()
|
|
||||||
hashed := hashSession(sessionID)
|
|
||||||
data, exists := m.sessions[hashed]
|
|
||||||
m.lock.RUnlock()
|
|
||||||
if exists == false {
|
|
||||||
return nil, ErrSessionNotFound
|
|
||||||
}
|
|
||||||
if time.Now().After(data.ExpiresAt) {
|
|
||||||
_ = m.Delete(sessionID) // ignore error
|
|
||||||
return nil, ErrSessionExpired
|
|
||||||
}
|
|
||||||
copy := *data
|
|
||||||
return ©, nil
|
|
||||||
}
|
|
||||||
func (m *MemoryStore) Update(sessionID string, session *SessionData) error {
|
|
||||||
hashedSession := hashSession(sessionID)
|
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
_, exist := m.sessions[hashedSession]
|
|
||||||
if !exist {
|
|
||||||
return ErrSessionNotFound
|
|
||||||
}
|
|
||||||
m.sessions[hashedSession] = session
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MemoryStore) cleanup() {
|
|
||||||
logging.Debug("Cleaning up memory store sessions")
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
|
|
||||||
deleted := 0
|
|
||||||
for id, session := range m.sessions {
|
|
||||||
if now.After(session.ExpiresAt) {
|
|
||||||
delete(m.sessions, id)
|
|
||||||
deleted = deleted + 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logging.Infof("Cleaned up %d stale sessions", deleted)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MemoryStore) Delete(sessionID string) error {
|
|
||||||
hashedSession := hashSession(sessionID)
|
|
||||||
|
|
||||||
m.lock.Lock()
|
|
||||||
defer m.lock.Unlock()
|
|
||||||
_, exist := m.sessions[hashedSession]
|
|
||||||
if !exist {
|
|
||||||
return ErrSessionNotFound
|
|
||||||
}
|
|
||||||
delete(m.sessions, hashedSession)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -6,12 +6,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"astraltech.xyz/accountmanager/src/logging"
|
"astraltech.xyz/accountmanager/src/logging"
|
||||||
|
"astraltech.xyz/accountmanager/src/store"
|
||||||
|
"astraltech.xyz/accountmanager/src/worker"
|
||||||
)
|
)
|
||||||
|
|
||||||
const SessionCookieName = "session_token"
|
const SessionCookieName = "session_token"
|
||||||
|
|
||||||
type SessionManager struct {
|
type SessionManager struct {
|
||||||
store SessionStore
|
store store.KeyValueStore[*SessionData]
|
||||||
}
|
}
|
||||||
|
|
||||||
var instance *SessionManager
|
var instance *SessionManager
|
||||||
@@ -31,17 +33,23 @@ func GetSessionManager() *SessionManager {
|
|||||||
return instance
|
return instance
|
||||||
}
|
}
|
||||||
|
|
||||||
func (manager *SessionManager) SetStoreType(storeType StoreType) {
|
func (manager *SessionManager) SetStoreType(storeType StoreType, params ...any) {
|
||||||
logging.Infof("Changing session manager store type")
|
logging.Infof("Changing session manager store type")
|
||||||
switch storeType {
|
switch storeType {
|
||||||
case InMemory:
|
case InMemory:
|
||||||
{
|
{
|
||||||
manager.store = NewMemoryStore()
|
manager.store = store.NewMemoryStore[*SessionData]()
|
||||||
|
worker.CreateWorker(time.Minute*5, func() {
|
||||||
|
inMemStore, _ := manager.store.(*store.MemoryStore[*SessionData])
|
||||||
|
cleanupInMemoryStore(inMemStore)
|
||||||
|
})
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
case Redis:
|
case Redis:
|
||||||
{
|
{
|
||||||
manager.store = NewRedisStore()
|
url, _ := params[0].(string)
|
||||||
|
prefix, _ := params[1].(string)
|
||||||
|
manager.store = store.NewRedisStore[*SessionData](url, prefix)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,10 +79,10 @@ func (manager *SessionManager) CreateSession(userID string) (cookie *http.Cookie
|
|||||||
Name: SessionCookieName,
|
Name: SessionCookieName,
|
||||||
Value: token,
|
Value: token,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
HttpOnly: true, // Essential: prevents JS access
|
HttpOnly: true,
|
||||||
Secure: true, // Set to TRUE in production (HTTPS)
|
Secure: true,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
MaxAge: 3600, // 1 hour
|
MaxAge: 3600,
|
||||||
}
|
}
|
||||||
return newCookie, nil
|
return newCookie, nil
|
||||||
}
|
}
|
||||||
@@ -93,9 +101,32 @@ func (manager *SessionManager) GetSession(r *http.Request) (*SessionData, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrSessionNotFound
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if time.Now().After(data.ExpiresAt) {
|
||||||
|
_ = manager.store.Delete(token)
|
||||||
|
return nil, ErrSessionExpired
|
||||||
|
}
|
||||||
|
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cleanupInMemoryStore(m *store.MemoryStore[*SessionData]) {
|
||||||
|
logging.Debug("Cleaning up memory store sessions")
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
m.Lock.Lock()
|
||||||
|
defer m.Lock.Unlock()
|
||||||
|
|
||||||
|
deleted := 0
|
||||||
|
for id, session := range m.Sessions {
|
||||||
|
if now.After(session.ExpiresAt) {
|
||||||
|
delete(m.Sessions, id)
|
||||||
|
deleted = deleted + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logging.Infof("Cleaned up %d stale sessions", deleted)
|
||||||
|
}
|
||||||
|
|
||||||
func (manager *SessionManager) DeleteSession(sessionId string) error {
|
func (manager *SessionManager) DeleteSession(sessionId string) error {
|
||||||
return manager.store.Delete(sessionId)
|
return manager.store.Delete(sessionId)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
package session
|
package session
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrSessionNotFound = errors.New("Session not found")
|
||||||
|
var ErrSessionAlreadyExists = errors.New("Session already exists")
|
||||||
|
var ErrSessionExpired = errors.New("Session expired")
|
||||||
|
var ErrSessionBackend = errors.New("Session backend")
|
||||||
|
|
||||||
type SessionData struct {
|
type SessionData struct {
|
||||||
UserID string `json:"userid"`
|
UserID string `json:"userid"`
|
||||||
CSRFToken string `json:"csrftoken"`
|
CSRFToken string `json:"csrftoken"`
|
||||||
ExpiresAt time.Time `json:"expiresat"`
|
ExpiresAt time.Time `json:"expiresat"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionStore interface {
|
|
||||||
Create(sessionID string, session *SessionData) error
|
|
||||||
Get(sessionID string) (*SessionData, error)
|
|
||||||
Update(sessionID string, session *SessionData) error
|
|
||||||
Delete(sessionID string) error
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
// A simple key value store that can either just be single instance in memory or a redis server (for now)
|
||||||
|
|
||||||
|
type KeyValueStore[Value any] interface {
|
||||||
|
Create(key string, value Value) error
|
||||||
|
Get(key string) (Value, error)
|
||||||
|
Update(key string, session Value) error
|
||||||
|
Delete(key string) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrKeyNotFound = errors.New("Key not found")
|
||||||
|
var ErrKeyAlreadyExists = errors.New("Key already exists")
|
||||||
|
var ErrKeyExpired = errors.New("Key expired")
|
||||||
|
var ErrKeyBackend = errors.New("Key backend")
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HashKey(key string) string {
|
||||||
|
tokenEncoded := sha256.Sum256([]byte(key))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(tokenEncoded[:])
|
||||||
|
}
|
||||||
@@ -0,0 +1,73 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"astraltech.xyz/accountmanager/src/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MemoryStore[Value any] struct {
|
||||||
|
Sessions map[string]Value
|
||||||
|
Lock sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMemoryStore[Value any]() *MemoryStore[Value] {
|
||||||
|
logging.Debug("Creating new in memory session store")
|
||||||
|
store := &MemoryStore[Value]{
|
||||||
|
Sessions: make(map[string]Value),
|
||||||
|
}
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore[Value]) Create(key string, session Value) (err error) {
|
||||||
|
hashedkey := HashKey(key)
|
||||||
|
|
||||||
|
m.Lock.Lock()
|
||||||
|
defer m.Lock.Unlock()
|
||||||
|
_, exist := m.Sessions[hashedkey]
|
||||||
|
if exist {
|
||||||
|
return ErrKeyAlreadyExists
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Sessions[hashedkey] = session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore[Value]) Get(key string) (Value, error) {
|
||||||
|
var data Value
|
||||||
|
|
||||||
|
m.Lock.RLock()
|
||||||
|
hashedkey := HashKey(key)
|
||||||
|
data, exists := m.Sessions[hashedkey]
|
||||||
|
m.Lock.RUnlock()
|
||||||
|
if exists == false {
|
||||||
|
return data, ErrKeyNotFound
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore[Value]) Update(sessionID string, session Value) error {
|
||||||
|
hashedkey := HashKey(sessionID)
|
||||||
|
|
||||||
|
m.Lock.Lock()
|
||||||
|
defer m.Lock.Unlock()
|
||||||
|
_, exist := m.Sessions[hashedkey]
|
||||||
|
if !exist {
|
||||||
|
return ErrKeyNotFound
|
||||||
|
}
|
||||||
|
m.Sessions[hashedkey] = session
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MemoryStore[Value]) Delete(sessionID string) error {
|
||||||
|
hashedkey := HashKey(sessionID)
|
||||||
|
|
||||||
|
m.Lock.Lock()
|
||||||
|
defer m.Lock.Unlock()
|
||||||
|
_, exist := m.Sessions[hashedkey]
|
||||||
|
if !exist {
|
||||||
|
return ErrKeyNotFound
|
||||||
|
}
|
||||||
|
delete(m.Sessions, hashedkey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package session
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,16 +9,18 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RedisStore struct {
|
type RedisStore[Value any] struct {
|
||||||
client *redis.Client
|
client *redis.Client
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
prefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRedisStore() *RedisStore {
|
func (m *RedisStore[Value]) RedisHash(value_to_hash string) string {
|
||||||
logging.Debug("Creating new redis session store")
|
return m.prefix + HashKey(value_to_hash)
|
||||||
|
}
|
||||||
|
|
||||||
// this will be replaced with a URL that can be parsed in the config file
|
func NewRedisStore[Value any](redis_server string, prefix string) *RedisStore[Value] {
|
||||||
redis_server := "redis://localhost:6379/0"
|
logging.Debug("Creating new redis session store")
|
||||||
|
|
||||||
opts, err := redis.ParseURL(redis_server)
|
opts, err := redis.ParseURL(redis_server)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -34,84 +36,79 @@ func NewRedisStore() *RedisStore {
|
|||||||
logging.Infof("Successfully connected to redis server %s", redis_server)
|
logging.Infof("Successfully connected to redis server %s", redis_server)
|
||||||
}
|
}
|
||||||
|
|
||||||
store := &RedisStore{
|
store := &RedisStore[Value]{
|
||||||
client: rdb,
|
client: rdb,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
prefix: prefix,
|
||||||
}
|
}
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
// return rdb.Set(ctx, key, data, 0).Err()
|
func (m *RedisStore[Value]) Create(key string, value Value) (err error) {
|
||||||
|
hashedSession := m.RedisHash(key)
|
||||||
|
|
||||||
func (m *RedisStore) Create(sessionID string, session *SessionData) (err error) {
|
data, err := json.Marshal(value)
|
||||||
hashedSession := hashSession(sessionID)
|
|
||||||
|
|
||||||
data, err := json.Marshal(*session)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := m.client.SetNX(m.ctx, hashedSession, data, time.Hour).Result()
|
created, err := m.client.SetNX(m.ctx, hashedSession, data, time.Hour).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
if !created {
|
if !created {
|
||||||
return ErrSessionAlreadyExists
|
return ErrKeyAlreadyExists
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (m *RedisStore) Get(sessionID string) (*SessionData, error) {
|
func (m *RedisStore[Value]) Get(sessionID string) (Value, error) {
|
||||||
hashed := hashSession(sessionID)
|
hashed := m.RedisHash(sessionID)
|
||||||
|
var session_data Value
|
||||||
|
|
||||||
data, err := m.client.Get(m.ctx, hashed).Bytes()
|
data, err := m.client.Get(m.ctx, hashed).Bytes()
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, ErrSessionNotFound
|
return session_data, ErrKeyNotFound
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return nil, ErrSessionBackend
|
return session_data, ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
var session_data SessionData
|
|
||||||
if err := json.Unmarshal(data, &session_data); err != nil {
|
if err := json.Unmarshal(data, &session_data); err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return nil, ErrSessionBackend
|
return session_data, ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(session_data.ExpiresAt) {
|
return session_data, nil
|
||||||
_ = m.Delete(sessionID)
|
|
||||||
return nil, ErrSessionBackend
|
|
||||||
}
|
|
||||||
return &session_data, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *RedisStore) Update(sessionID string, session *SessionData) error {
|
func (m *RedisStore[Value]) Update(key string, value Value) error {
|
||||||
hashedSession := hashSession(sessionID)
|
hashedSession := m.RedisHash(key)
|
||||||
|
|
||||||
data, err := json.Marshal(*session)
|
data, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := m.client.SetXX(m.ctx, hashedSession, data, time.Hour).Result()
|
updated, err := m.client.SetXX(m.ctx, hashedSession, data, time.Hour).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
if !updated {
|
if !updated {
|
||||||
return ErrSessionNotFound
|
return ErrKeyNotFound
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *RedisStore) Delete(sessionID string) error {
|
func (m *RedisStore[Value]) Delete(sessionID string) error {
|
||||||
hashedSession := hashSession(sessionID)
|
hashedSession := m.RedisHash(sessionID)
|
||||||
err := m.client.Del(m.ctx, hashedSession).Err()
|
err := m.client.Del(m.ctx, hashedSession).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user