6 Commits

7 changed files with 273 additions and 160 deletions

View File

@@ -13,21 +13,27 @@ import (
"time" "time"
"astraltech.xyz/accountmanager/src/logging" "astraltech.xyz/accountmanager/src/logging"
"astraltech.xyz/accountmanager/src/session"
) )
var ( var (
ldapServer *LDAPServer ldapServer *LDAPServer
ldapServerMutex sync.Mutex ldapServerMutex sync.Mutex
serverConfig *ServerConfig serverConfig *ServerConfig
sessionManager *session.SessionManager
) )
type UserData struct { type UserData struct {
isAuth bool isAuth bool
Username string
DisplayName string DisplayName string
Email string Email string
} }
var (
userData = make(map[string]UserData)
userDataMutex sync.RWMutex
)
var ( var (
photoCreatedTimestamp = make(map[string]time.Time) photoCreatedTimestamp = make(map[string]time.Time)
photoCreatedMutex sync.Mutex photoCreatedMutex sync.Mutex
@@ -82,14 +88,13 @@ func authenticateUser(username, password string) (UserData, error) {
entry := userSearch.LDAPSearch.Entries[0] entry := userSearch.LDAPSearch.Entries[0]
user := UserData{ user := UserData{
isAuth: true, isAuth: true,
Username: username,
DisplayName: entry.GetAttributeValue("displayName"), DisplayName: entry.GetAttributeValue("displayName"),
Email: entry.GetAttributeValue("mail"), Email: entry.GetAttributeValue("mail"),
} }
photoData := entry.GetRawAttributeValue("jpegphoto") photoData := entry.GetRawAttributeValue("jpegphoto")
if len(photoData) > 0 { if len(photoData) > 0 {
createUserPhoto(user.Username, photoData) createUserPhoto(username, photoData)
} }
return user, nil return user, nil
} }
@@ -118,15 +123,19 @@ func loginHandler(w http.ResponseWriter, r *http.Request) {
password := r.FormValue("password") password := r.FormValue("password")
logging.Infof("New Login request for %s\n", username) logging.Infof("New Login request for %s\n", username)
userData, err := authenticateUser(username, password) newUserData, err := authenticateUser(username, password)
userDataMutex.Lock()
userData[username] = newUserData
userDataMutex.Unlock()
if err != nil { if err != nil {
log.Print(err) log.Print(err)
tmpl.Execute(w, LoginPageData{IsHiddenClassList: ""}) tmpl.Execute(w, LoginPageData{IsHiddenClassList: ""})
} else { } else {
if userData.isAuth == true { if newUserData.isAuth == true {
cookie := createSession(&userData) cookie, err := sessionManager.CreateSession(username)
if cookie == nil { if err != nil {
http.Error(w, "Session error", 500) logging.Error(err.Error())
http.Error(w, "Session error", http.StatusInternalServerError)
return return
} }
http.SetCookie(w, cookie) http.SetCookie(w, cookie)
@@ -147,20 +156,23 @@ type ProfileData struct {
func profileHandler(w http.ResponseWriter, r *http.Request) { func profileHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
exist, sessionData := validateSession(r) sessionData, err := sessionManager.GetSession(r)
if !exist { if err != nil {
logging.Error(err.Error())
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
return return
} }
if r.Method == http.MethodGet { if r.Method == http.MethodGet {
tmpl := template.Must(template.ParseFiles("src/pages/profile_page.html")) tmpl := template.Must(template.ParseFiles("src/pages/profile_page.html"))
userDataMutex.RLock()
tmpl.Execute(w, ProfileData{ tmpl.Execute(w, ProfileData{
Username: sessionData.data.Username, Username: sessionData.UserID,
Email: sessionData.data.Email, Email: userData[sessionData.UserID].Email,
DisplayName: sessionData.data.DisplayName, DisplayName: userData[sessionData.UserID].DisplayName,
CSRFToken: sessionData.CSRFToken, CSRFToken: sessionData.CSRFToken,
}) })
userDataMutex.RUnlock()
return return
} }
} }
@@ -224,27 +236,29 @@ func logoutHandler(w http.ResponseWriter, r *http.Request) {
} }
token := cookie.Value token := cookie.Value
exist, sessionData := validateSession(r) sessionData, err := sessionManager.GetSession(r)
if exist { if err != nil {
logging.Error(err.Error())
}
if r.FormValue("csrf_token") != sessionData.CSRFToken { if r.FormValue("csrf_token") != sessionData.CSRFToken {
http.Error(w, "Unable to log user out", http.StatusForbidden) http.Error(w, "Unable to log user out", http.StatusForbidden)
logging.Debugf("%s attempted to logout with invalid csrf token", sessionData.data.Username) logging.Debugf("%s attempted to logout with invalid csrf token", sessionData.UserID)
return return
} }
} logging.Infof("handling logout event for %s", sessionData.UserID)
logging.Infof("handling logout event for %s", sessionData.data.Username)
deleteSession(token) sessionManager.DeleteSession(token)
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
} }
func uploadPhotoHandler(w http.ResponseWriter, r *http.Request) { func uploadPhotoHandler(w http.ResponseWriter, r *http.Request) {
exist, sessionData := validateSession(r) sessionData, err := sessionManager.GetSession(r)
if !exist { if err != nil {
logging.Error(err.Error())
http.Redirect(w, r, "/login", http.StatusSeeOther) http.Redirect(w, r, "/login", http.StatusSeeOther)
return return
} }
err := r.ParseMultipartForm(10 << 20) // 10MB limit err = r.ParseMultipartForm(10 << 20) // 10MB limit
if err != nil { if err != nil {
http.Error(w, "Bad request", http.StatusBadRequest) http.Error(w, "Bad request", http.StatusBadRequest)
return return
@@ -272,11 +286,11 @@ func uploadPhotoHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Failed to read file", http.StatusInternalServerError) http.Error(w, "Failed to read file", http.StatusInternalServerError)
return return
} }
userDN := fmt.Sprintf("uid=%s,cn=users,cn=accounts,%s", sessionData.data.Username, serverConfig.LDAPConfig.BaseDN) userDN := fmt.Sprintf("uid=%s,cn=users,cn=accounts,%s", sessionData.UserID, serverConfig.LDAPConfig.BaseDN)
ldapServerMutex.Lock() ldapServerMutex.Lock()
defer ldapServerMutex.Unlock() defer ldapServerMutex.Unlock()
modifyLDAPAttribute(ldapServer, userDN, "jpegphoto", []string{string(data)}) modifyLDAPAttribute(ldapServer, userDN, "jpegphoto", []string{string(data)})
createUserPhoto(sessionData.data.Username, data) createUserPhoto(sessionData.UserID, data)
} }
func faviconHandler(w http.ResponseWriter, r *http.Request) { func faviconHandler(w http.ResponseWriter, r *http.Request) {
@@ -289,37 +303,18 @@ func logoHandler(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, serverConfig.StyleConfig.LogoPath) http.ServeFile(w, r, serverConfig.StyleConfig.LogoPath)
} }
func cleanupSessions() {
logging.Debug("Cleaning up stale session\n")
sessionMutex.Lock()
sessions_to_delete := []string{}
for session_token, session_data := range sessions {
timeUntilRemoval := time.Minute * 5
if session_data.loggedIn {
timeUntilRemoval = time.Hour
}
if time.Since(session_data.timeCreated) > timeUntilRemoval {
sessions_to_delete = append(sessions_to_delete, session_token)
}
}
sessionMutex.Unlock()
for _, session_id := range sessions_to_delete {
deleteSession(session_id)
}
}
func changePasswordHandler(w http.ResponseWriter, r *http.Request) { func changePasswordHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
exist, sessionData := validateSession(r) sessionData, err := sessionManager.GetSession(r)
if !exist { if err != nil {
logging.Error(err.Error())
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(`{"success": false, "error": "Not authenticated"}`)) w.Write([]byte(`{"success": false, "error": "Not authenticated"}`))
return return
} }
err := r.ParseMultipartForm(10 << 20) err = r.ParseMultipartForm(10 << 20)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"success": false, "error": "Bad request"}`)) w.Write([]byte(`{"success": false, "error": "Bad request"}`))
@@ -344,7 +339,7 @@ func changePasswordHandler(w http.ResponseWriter, r *http.Request) {
userDN := fmt.Sprintf( userDN := fmt.Sprintf(
"uid=%s,cn=users,cn=accounts,%s", "uid=%s,cn=users,cn=accounts,%s",
sessionData.data.Username, sessionData.UserID,
serverConfig.LDAPConfig.BaseDN, serverConfig.LDAPConfig.BaseDN,
) )
@@ -368,6 +363,7 @@ func changePasswordHandler(w http.ResponseWriter, r *http.Request) {
func main() { func main() {
logging.Info("Starting the server") logging.Info("Starting the server")
sessionManager = session.CreateSessionManager(session.InMemory)
var err error = nil var err error = nil
blankPhotoData, err = ReadFile("static/blank_profile.jpg") blankPhotoData, err = ReadFile("static/blank_profile.jpg")
@@ -385,7 +381,6 @@ func main() {
ldapServerMutex.Unlock() ldapServerMutex.Unlock()
defer closeLDAPServer(ldapServer) defer closeLDAPServer(ldapServer)
createWorker(time.Minute*5, cleanupSessions)
HandleFunc("/favicon.ico", faviconHandler) HandleFunc("/favicon.ico", faviconHandler)
HandleFunc("/logo", logoHandler) HandleFunc("/logo", logoHandler)

View File

@@ -1,106 +0,0 @@
package main
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"net/http"
"sync"
"time"
"astraltech.xyz/accountmanager/src/logging"
)
type SessionData struct {
loggedIn bool
data *UserData
timeCreated time.Time
CSRFToken string
}
var (
sessions = make(map[string]SessionData)
sessionMutex sync.Mutex
)
func GenerateSessionToken(length int) (string, error) {
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
token := base64.RawURLEncoding.EncodeToString(b)
return token, nil
}
func createSession(userData *UserData) *http.Cookie {
logging.Debugf("Creating a new session for %s", userData.Username)
token, err := GenerateSessionToken(32) // Use crypto/rand for this
if err != nil {
logging.Error(err.Error())
return nil
}
CSRFToken, err := GenerateSessionToken(32)
if err != nil {
logging.Error(err.Error())
return nil
}
tokenEncoded := sha256.Sum256([]byte(token))
tokenEncodedString := string(tokenEncoded[:])
sessionMutex.Lock()
defer sessionMutex.Unlock()
loggedIn := false
if userData != nil {
loggedIn = true
}
sessions[tokenEncodedString] = SessionData{
data: userData,
timeCreated: time.Now(),
CSRFToken: CSRFToken,
loggedIn: loggedIn,
}
cookie := &http.Cookie{
Name: "session_token",
Value: token,
Path: "/",
HttpOnly: true, // Essential: prevents JS access
Secure: true, // Set to TRUE in production (HTTPS)
SameSite: http.SameSiteLaxMode,
MaxAge: 3600, // 1 hour
}
return cookie
}
func validateSession(r *http.Request) (bool, *SessionData) {
logging.Debugf("Validating session")
cookie, err := r.Cookie("session_token")
if err != nil {
logging.Error(err.Error())
return false, &SessionData{}
}
token := cookie.Value
tokenEncoded := sha256.Sum256([]byte(token))
tokenEncodedString := string(tokenEncoded[:])
sessionMutex.Lock()
sessionData, exists := sessions[tokenEncodedString]
sessionMutex.Unlock()
if !exists || !sessionData.loggedIn {
return false, &SessionData{}
}
logging.Infof("Validated session for %s", sessionData.data.Username)
return true, &sessionData
}
func deleteSession(session_id string) {
sessionMutex.Lock()
tokenEncoded := sha256.Sum256([]byte(session_id))
tokenEncodedString := string(tokenEncoded[:])
delete(sessions, tokenEncodedString)
sessionMutex.Unlock()
}

View File

@@ -0,0 +1,24 @@
package session
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
)
// helper function for secure session storage
func GenerateSessionToken(length int) (string, error) {
b := make([]byte, length)
_, err := rand.Read(b)
if err != nil {
return "", err
}
token := base64.RawURLEncoding.EncodeToString(b)
return token, nil
}
// more helper
func hashSession(session_id string) string {
tokenEncoded := sha256.Sum256([]byte(session_id))
return base64.RawURLEncoding.EncodeToString(tokenEncoded[:])
}

View File

@@ -0,0 +1,99 @@
package session
import (
"errors"
"sync"
"time"
"astraltech.xyz/accountmanager/src/logging"
"astraltech.xyz/accountmanager/src/worker"
)
var ErrSessionNotFound = errors.New("session not found")
var ErrSessionAlreadyExists = errors.New("session already exists")
var ErrSessionExpired = errors.New("session expired")
type MemoryStore struct {
sessions map[string]*SessionData
lock sync.RWMutex
}
func NewMemoryStore() *MemoryStore {
logging.Debug("Creating new in memory session store")
store := &MemoryStore{
sessions: make(map[string]*SessionData),
}
worker.CreateWorker(time.Minute*5, store.cleanup)
return store
}
func (m *MemoryStore) Create(sessionID string, session *SessionData) (err error) {
hashedSession := hashSession(sessionID)
m.lock.Lock()
defer m.lock.Unlock()
_, exist := m.sessions[hashedSession]
if exist {
return ErrSessionAlreadyExists
}
m.sessions[hashedSession] = session
return nil
}
func (m *MemoryStore) Get(sessionID string) (*SessionData, error) {
m.lock.RLock()
hashed := hashSession(sessionID)
data, exists := m.sessions[hashed]
m.lock.RUnlock()
if exists == false {
return nil, ErrSessionNotFound
}
if time.Now().After(data.ExpiresAt) {
_ = m.Delete(sessionID) // ignore error
return nil, ErrSessionExpired
}
copy := *data
return &copy, nil
}
func (m *MemoryStore) Update(sessionID string, session *SessionData) error {
hashedSession := hashSession(sessionID)
m.lock.Lock()
defer m.lock.Unlock()
_, exist := m.sessions[hashedSession]
if !exist {
return ErrSessionNotFound
}
m.sessions[hashedSession] = session
return nil
}
func (m *MemoryStore) cleanup() {
logging.Debug("Cleaning up memory store sessions")
now := time.Now()
m.lock.Lock()
defer m.lock.Unlock()
deleted := 0
for id, session := range m.sessions {
if now.After(session.ExpiresAt) {
delete(m.sessions, id)
deleted = deleted + 1
}
}
logging.Infof("Cleaned up %d stale sessions", deleted)
}
func (m *MemoryStore) Delete(sessionID string) error {
hashedSession := hashSession(sessionID)
m.lock.Lock()
defer m.lock.Unlock()
_, exist := m.sessions[hashedSession]
if !exist {
return ErrSessionNotFound
}
delete(m.sessions, hashedSession)
return nil
}

View File

@@ -0,0 +1,85 @@
package session
import (
"net/http"
"time"
"astraltech.xyz/accountmanager/src/logging"
)
const SessionCookieName = "session_token"
type SessionManager struct {
store SessionStore
}
type StoreType int
const (
InMemory StoreType = iota
)
func CreateSessionManager(storeType StoreType) *SessionManager {
sessionManager := SessionManager{}
switch storeType {
case InMemory:
{
sessionManager.store = NewMemoryStore()
break
}
}
return &sessionManager
}
func (manager *SessionManager) CreateSession(userID string) (cookie *http.Cookie, err error) {
logging.Debugf("Creating a new session for %s", userID)
token, err := GenerateSessionToken(32) // Use crypto/rand for this
if err != nil {
return nil, err
}
CSRFToken, err := GenerateSessionToken(32)
if err != nil {
return nil, err
}
newSessionData := SessionData{
UserID: userID,
CSRFToken: CSRFToken,
ExpiresAt: time.Now().Add(time.Hour),
}
err = manager.store.Create(token, &newSessionData)
if err != nil {
return nil, err
}
newCookie := &http.Cookie{
Name: SessionCookieName,
Value: token,
Path: "/",
HttpOnly: true, // Essential: prevents JS access
Secure: true, // Set to TRUE in production (HTTPS)
SameSite: http.SameSiteLaxMode,
MaxAge: 3600, // 1 hour
}
return newCookie, nil
}
func (manager *SessionManager) GetSession(r *http.Request) (*SessionData, error) {
logging.Debug("Validating session from request")
cookie, err := r.Cookie(SessionCookieName)
if err != nil {
return nil, ErrSessionNotFound
}
token := cookie.Value
if token == "" {
return nil, ErrSessionNotFound
}
data, err := manager.store.Get(token)
if err != nil {
return nil, ErrSessionNotFound
}
return data, nil
}
func (manager *SessionManager) DeleteSession(sessionId string) error {
return manager.store.Delete(sessionId)
}

View File

@@ -0,0 +1,16 @@
package session
import "time"
type SessionData struct {
UserID string
CSRFToken string
ExpiresAt time.Time
}
type SessionStore interface {
Create(sessionID string, session *SessionData) error
Get(sessionID string) (*SessionData, error)
Update(sessionID string, session *SessionData) error
Delete(sessionID string) error
}

View File

@@ -1,4 +1,4 @@
package main package worker
import ( import (
"time" "time"
@@ -6,7 +6,7 @@ import (
"astraltech.xyz/accountmanager/src/logging" "astraltech.xyz/accountmanager/src/logging"
) )
func createWorker(interval time.Duration, task func()) { func CreateWorker(interval time.Duration, task func()) {
logging.Debugf("Creating worker that runs on a %s interval", interval.String()) logging.Debugf("Creating worker that runs on a %s interval", interval.String())
go func() { go func() {
ticker := time.NewTicker(interval) ticker := time.NewTicker(interval)