"""Session, request-context, and audit helpers.

A login session is an ``AppSession`` row whose id is stored in an HttpOnly
cookie. A second, script-readable cookie carries the raw CSRF token; only its
hash is persisted on the session row.
"""

from __future__ import annotations

from dataclasses import dataclass
from datetime import datetime

from fastapi import Depends, HTTPException, Request, Response, status
from sqlalchemy import select
from sqlalchemy.orm import Session

from app.core.config import get_settings
from app.core.security import hash_token, session_expiry, token_urlsafe, utc_now
from app.db.base import get_db
from app.models import (
    AppSession,
    AuditLog,
    HomeDevice,
    HomeProfile,
    Member,
    MemberDevice,
)

# Name of the non-HttpOnly cookie that carries the raw CSRF token.
CSRF_COOKIE_NAME = "grouphome_csrf"
# Max-Age applied to both the session cookie and the CSRF cookie (30 days).
COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30
# Member statuses treated as active membership when resolving via a home profile.
ACTIVE_MEMBER_STATUSES = ("joined", "verified")


@dataclass
class CurrentContext:
    """Identity resolved for the current request; every field is None when anonymous."""

    session: AppSession | None = None
    home_profile: HomeProfile | None = None
    member: Member | None = None
    home_device: HomeDevice | None = None
    member_device: MemberDevice | None = None

    @property
    def authenticated(self) -> bool:
        """True when a live session was found for the request."""
        return self.session is not None


def _expired(dt: datetime) -> bool:
    """Return True if ``dt`` lies in the past relative to ``utc_now()``.

    Naive datetimes are compared against ``utc_now()`` with its tzinfo
    stripped. NOTE(review): this presumes naive values are stored as UTC;
    confirm against the DB column type.
    """
    now = utc_now()
    if dt.tzinfo is None:
        return dt < now.replace(tzinfo=None)
    return dt < now


def create_session(
    db: Session,
    *,
    home_profile: HomeProfile | None = None,
    member: Member | None = None,
    home_device: HomeDevice | None = None,
    member_device: MemberDevice | None = None,
) -> tuple[AppSession, str]:
    """Create and flush a new ``AppSession`` linked to the given identities.

    Returns:
        The flushed session (so ``session.id`` is populated) and the raw CSRF
        token. Only the token's hash is stored; the raw value must be handed
        to the client, e.g. via ``set_session_cookies``.
    """
    csrf_token = token_urlsafe(24)
    session = AppSession(
        home_profile_id=home_profile.id if home_profile else None,
        member_id=member.id if member else None,
        home_device_id=home_device.id if home_device else None,
        member_device_id=member_device.id if member_device else None,
        csrf_token_hash=hash_token(csrf_token),
        expires_at=session_expiry(),
    )
    db.add(session)
    db.flush()
    return session, csrf_token


def set_session_cookies(response: Response, session: AppSession, csrf_token: str) -> None:
    """Set the HttpOnly session-id cookie and the script-readable CSRF cookie."""
    settings = get_settings()
    response.set_cookie(
        settings.session_cookie_name,
        session.id,
        httponly=True,
        secure=settings.cookie_secure,
        samesite="lax",
        max_age=COOKIE_MAX_AGE_SECONDS,
        path="/",
    )
    # Not HttpOnly on purpose: the client reads this value to echo it back
    # for CSRF checks.
    response.set_cookie(
        CSRF_COOKIE_NAME,
        csrf_token,
        httponly=False,
        secure=settings.cookie_secure,
        samesite="lax",
        max_age=COOKIE_MAX_AGE_SECONDS,
        path="/",
    )


def clear_session_cookies(response: Response) -> None:
    """Delete both cookies written by ``set_session_cookies``."""
    settings = get_settings()
    response.delete_cookie(settings.session_cookie_name, path="/")
    response.delete_cookie(CSRF_COOKIE_NAME, path="/")


def load_context_from_request(request: Request, db: Session) -> CurrentContext:
    """Resolve the request's session cookie into a ``CurrentContext``.

    An anonymous context is returned in any of these cases:
    - the cookie is missing;
    - the session row does not exist;
    - the session is revoked or expired.

    For a live session, the linked profile, member, and devices are loaded.
    ``last_seen_at`` is stamped on any non-revoked device, and the session is
    flushed.
    """
    settings = get_settings()
    session_id = request.cookies.get(settings.session_cookie_name)
    if not session_id:
        return CurrentContext()
    session = db.get(AppSession, session_id)
    if not session or session.revoked_at or _expired(session.expires_at):
        return CurrentContext()

    home_profile = db.get(HomeProfile, session.home_profile_id) if session.home_profile_id else None
    member = db.get(Member, session.member_id) if session.member_id else None
    home_device = db.get(HomeDevice, session.home_device_id) if session.home_device_id else None
    member_device = (
        db.get(MemberDevice, session.member_device_id) if session.member_device_id else None
    )

    # NOTE(review): a revoked device only skips the last_seen_at update here;
    # it does not de-authenticate the session. Confirm that device revocation
    # also revokes the sessions bound to it.
    now = utc_now()
    if home_device and not home_device.revoked_at:
        home_device.last_seen_at = now
    if member_device and not member_device.revoked_at:
        member_device.last_seen_at = now
    db.flush()
    return CurrentContext(session, home_profile, member, home_device, member_device)


def get_optional_context(request: Request, db: Session = Depends(get_db)) -> CurrentContext:
    """FastAPI dependency: the current context, possibly anonymous."""
    return load_context_from_request(request, db)


def get_current_context(request: Request, db: Session = Depends(get_db)) -> CurrentContext:
    """FastAPI dependency: the current context, or HTTP 401 if not authenticated.

    Raises:
        HTTPException: 401 with a ``not_authenticated`` error payload.
    """
    ctx = load_context_from_request(request, db)
    if not ctx.authenticated:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail={
                "error": {
                    "code": "not_authenticated",
                    "message": "Open an invite or recover access to continue.",
                    "details": {},
                }
            },
        )
    return ctx


def get_members_for_context(db: Session, ctx: CurrentContext) -> list[Member]:
    """Return the members the current identity acts as.

    - With a home profile: every joined/verified member attached to it.
    - Otherwise: the session's single member, unless that member has left.
      The "left" exclusion keeps this consistent with ``get_member_for_group``.
    - With neither: an empty list.
    """
    if ctx.home_profile:
        return list(
            db.scalars(
                select(Member).where(
                    Member.home_profile_id == ctx.home_profile.id,
                    Member.status.in_(ACTIVE_MEMBER_STATUSES),
                )
            ).all()
        )
    if ctx.member and ctx.member.status != "left":
        return [ctx.member]
    return []


def get_member_for_group(db: Session, ctx: CurrentContext, group_id: str) -> Member | None:
    """Return the member representing the current identity in ``group_id``, if any.

    The session's own member wins if it belongs to the group and has not left.
    Otherwise the first joined/verified member of the home profile in that
    group is returned.
    """
    if ctx.member and ctx.member.group_id == group_id and ctx.member.status != "left":
        return ctx.member
    if ctx.home_profile:
        return db.scalar(
            select(Member).where(
                Member.home_profile_id == ctx.home_profile.id,
                Member.group_id == group_id,
                Member.status.in_(ACTIVE_MEMBER_STATUSES),
            )
        )
    return None


def ensure_home_profile(
    db: Session, ctx: CurrentContext, display_name: str | None = None
) -> HomeProfile:
    """Return the context's home profile, attaching or creating one if needed.

    Resolution order:
    1. ``ctx.home_profile`` if already set.
    2. The profile the session's member is already linked to. This avoids
       creating an orphan profile that contains none of the member's records.
    3. A new profile named ``display_name``, else the member's display name,
       else ``"GroupHome member"``. The new profile is linked to the member
       when the member has no profile yet.

    In cases 2 and 3 the session is re-pointed at the profile, ``ctx`` is
    updated in place, and the changes are flushed.
    """
    if ctx.home_profile:
        return ctx.home_profile

    profile: HomeProfile | None = None
    if ctx.member and ctx.member.home_profile_id is not None:
        profile = db.get(HomeProfile, ctx.member.home_profile_id)

    if profile is None:
        # Fall back to the default when the member's display name is empty/None.
        name = (
            display_name
            or (ctx.member.display_name if ctx.member else None)
            or "GroupHome member"
        )
        profile = HomeProfile(primary_display_name=name)
        db.add(profile)
        db.flush()  # populate profile.id before it is referenced below
        if ctx.member and ctx.member.home_profile_id is None:
            ctx.member.home_profile_id = profile.id

    if ctx.session:
        ctx.session.home_profile_id = profile.id
    ctx.home_profile = profile
    db.flush()
    return profile


def audit(
    db: Session,
    *,
    ctx: CurrentContext | None = None,
    action: str,
    resource_type: str = "",
    resource_id: str = "",
    details: dict | None = None,
) -> None:
    """Stage an ``AuditLog`` row for ``action``; the caller commits.

    The actor member and home profile are taken from ``ctx`` when present.
    """
    db.add(
        AuditLog(
            actor_member_id=ctx.member.id if ctx and ctx.member else None,
            actor_home_profile_id=ctx.home_profile.id if ctx and ctx.home_profile else None,
            action=action,
            resource_type=resource_type,
            resource_id=resource_id,
            details_json=details or {},
        )
    )