start to move session stores into there own key value in memory store
This commit is contained in:
@@ -1,8 +0,0 @@
|
|||||||
package session
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var ErrSessionNotFound = errors.New("session not found")
|
|
||||||
var ErrSessionAlreadyExists = errors.New("session already exists")
|
|
||||||
var ErrSessionExpired = errors.New("session expired")
|
|
||||||
var ErrSessionBackend = errors.New("session backend")
|
|
||||||
@@ -2,7 +2,6 @@ package session
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,9 +15,3 @@ func GenerateSessionToken(length int) (string, error) {
|
|||||||
token := base64.RawURLEncoding.EncodeToString(b)
|
token := base64.RawURLEncoding.EncodeToString(b)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// more helper
|
|
||||||
func hashSession(session_id string) string {
|
|
||||||
tokenEncoded := sha256.Sum256([]byte(session_id))
|
|
||||||
return base64.RawURLEncoding.EncodeToString(tokenEncoded[:])
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -43,10 +43,6 @@ func (m *MemoryStore) Get(sessionID string) (*SessionData, error) {
|
|||||||
if exists == false {
|
if exists == false {
|
||||||
return nil, ErrSessionNotFound
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
if time.Now().After(data.ExpiresAt) {
|
|
||||||
_ = m.Delete(sessionID) // ignore error
|
|
||||||
return nil, ErrSessionExpired
|
|
||||||
}
|
|
||||||
copy := *data
|
copy := *data
|
||||||
return ©, nil
|
return ©, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,13 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"astraltech.xyz/accountmanager/src/logging"
|
"astraltech.xyz/accountmanager/src/logging"
|
||||||
|
"astraltech.xyz/accountmanager/src/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
const SessionCookieName = "session_token"
|
const SessionCookieName = "session_token"
|
||||||
|
|
||||||
type SessionManager struct {
|
type SessionManager struct {
|
||||||
store SessionStore
|
store store.Store[string, *SessionData]
|
||||||
}
|
}
|
||||||
|
|
||||||
var instance *SessionManager
|
var instance *SessionManager
|
||||||
@@ -93,6 +94,7 @@ func (manager *SessionManager) GetSession(r *http.Request) (*SessionData, error)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrSessionNotFound
|
return nil, ErrSessionNotFound
|
||||||
}
|
}
|
||||||
|
// TODO: handle token expiry here
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,10 +7,3 @@ type SessionData struct {
|
|||||||
CSRFToken string `json:"csrftoken"`
|
CSRFToken string `json:"csrftoken"`
|
||||||
ExpiresAt time.Time `json:"expiresat"`
|
ExpiresAt time.Time `json:"expiresat"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SessionStore interface {
|
|
||||||
Create(sessionID string, session *SessionData) error
|
|
||||||
Get(sessionID string) (*SessionData, error)
|
|
||||||
Update(sessionID string, session *SessionData) error
|
|
||||||
Delete(sessionID string) error
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -0,0 +1,10 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
// A simple key value store that can either just be single instance in memory or a redis server (for now)
|
||||||
|
|
||||||
|
type Store[Value any] interface {
|
||||||
|
Create(key string, value Value) error
|
||||||
|
Get(key string) (Value, error)
|
||||||
|
Update(key string, session Value) error
|
||||||
|
Delete(key string) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var ErrKeyNotFound = errors.New("Key not found")
|
||||||
|
var ErrKeyAlreadyExists = errors.New("Key already exists")
|
||||||
|
var ErrKeyExpired = errors.New("Key expired")
|
||||||
|
var ErrKeyBackend = errors.New("Key backend")
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
)
|
||||||
|
|
||||||
|
func HashKey(key string) string {
|
||||||
|
tokenEncoded := sha256.Sum256([]byte(key))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(tokenEncoded[:])
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package session
|
package store
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,16 +9,17 @@ import (
|
|||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RedisStore struct {
|
type RedisStore[Value any] struct {
|
||||||
client *redis.Client
|
client *redis.Client
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
prefix string
|
||||||
}
|
}
|
||||||
|
|
||||||
func RedisHash(sessionID string) string {
|
func (m *RedisStore[Value]) RedisHash(value_to_hash string) string {
|
||||||
return "selfservicedashboard_" + hashSession(sessionID)
|
return m.prefix + HashKey(value_to_hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRedisStore() *RedisStore {
|
func NewRedisStore[Value any]() *RedisStore[Value] {
|
||||||
logging.Debug("Creating new redis session store")
|
logging.Debug("Creating new redis session store")
|
||||||
|
|
||||||
// this will be replaced with a URL that can be parsed in the config file
|
// this will be replaced with a URL that can be parsed in the config file
|
||||||
@@ -38,82 +39,78 @@ func NewRedisStore() *RedisStore {
|
|||||||
logging.Infof("Successfully connected to redis server %s", redis_server)
|
logging.Infof("Successfully connected to redis server %s", redis_server)
|
||||||
}
|
}
|
||||||
|
|
||||||
store := &RedisStore{
|
store := &RedisStore[Value]{
|
||||||
client: rdb,
|
client: rdb,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *RedisStore) Create(sessionID string, session *SessionData) (err error) {
|
func (m *RedisStore[Value]) Create(key string, value Value) (err error) {
|
||||||
hashedSession := RedisHash(sessionID)
|
hashedSession := m.RedisHash(key)
|
||||||
|
|
||||||
data, err := json.Marshal(*session)
|
data, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
created, err := m.client.SetNX(m.ctx, hashedSession, data, time.Hour).Result()
|
created, err := m.client.SetNX(m.ctx, hashedSession, data, time.Hour).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
if !created {
|
if !created {
|
||||||
return ErrSessionAlreadyExists
|
return ErrKeyAlreadyExists
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (m *RedisStore) Get(sessionID string) (*SessionData, error) {
|
func (m *RedisStore[Value]) Get(sessionID string) (Value, error) {
|
||||||
hashed := RedisHash(sessionID)
|
hashed := m.RedisHash(sessionID)
|
||||||
|
var session_data Value
|
||||||
|
|
||||||
data, err := m.client.Get(m.ctx, hashed).Bytes()
|
data, err := m.client.Get(m.ctx, hashed).Bytes()
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, ErrSessionNotFound
|
return session_data, ErrKeyNotFound
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return nil, ErrSessionBackend
|
return session_data, ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
var session_data SessionData
|
if err := json.Unmarshal(data, session_data); err != nil {
|
||||||
if err := json.Unmarshal(data, &session_data); err != nil {
|
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return nil, ErrSessionBackend
|
return session_data, ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(session_data.ExpiresAt) {
|
return session_data, nil
|
||||||
_ = m.Delete(sessionID)
|
|
||||||
return nil, ErrSessionBackend
|
|
||||||
}
|
|
||||||
return &session_data, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *RedisStore) Update(sessionID string, session *SessionData) error {
|
func (m *RedisStore[Value]) Update(key string, value Value) error {
|
||||||
hashedSession := RedisHash(sessionID)
|
hashedSession := m.RedisHash(key)
|
||||||
|
|
||||||
data, err := json.Marshal(*session)
|
data, err := json.Marshal(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
updated, err := m.client.SetXX(m.ctx, hashedSession, data, time.Hour).Result()
|
updated, err := m.client.SetXX(m.ctx, hashedSession, data, time.Hour).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
|
|
||||||
if !updated {
|
if !updated {
|
||||||
return ErrSessionNotFound
|
return ErrKeyNotFound
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *RedisStore) Delete(sessionID string) error {
|
func (m *RedisStore[Value]) Delete(sessionID string) error {
|
||||||
hashedSession := RedisHash(sessionID)
|
hashedSession := m.RedisHash(sessionID)
|
||||||
err := m.client.Del(m.ctx, hashedSession).Err()
|
err := m.client.Del(m.ctx, hashedSession).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.Error(err.Error())
|
logging.Error(err.Error())
|
||||||
return ErrSessionBackend
|
return ErrKeyBackend
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user