"""FastAPI application entry point: app construction, middleware, and router wiring."""

from __future__ import annotations

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from sqlalchemy.orm import Session

from app.core.config import get_settings
from app.core.security import constant_time_equal, hash_token
from app.db.base import SessionLocal, init_db
from app.models import AppSession
from app.routers import auth, chat, groups, home, remote

settings = get_settings()

# HTTP methods treated as side-effect free; requests using them skip the CSRF check.
_CSRF_SAFE_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})
# Request header in which clients must send the CSRF token on unsafe requests.
_CSRF_HEADER_NAME = "x-csrf-token"

app = FastAPI(title=settings.app_name, version="0.1.0")

app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _error_response(status_code: int, code: str, message: str) -> JSONResponse:
    """Build a JSON error response in the app's ``{"error": {...}}`` envelope.

    Args:
        status_code: HTTP status code of the response.
        code: Machine-readable error code placed at ``error.code``.
        message: Human-readable message placed at ``error.message``.

    Returns:
        A ``JSONResponse`` whose ``error.details`` is always an empty dict.
    """
    return JSONResponse(
        status_code=status_code,
        content={"error": {"code": code, "message": message, "details": {}}},
    )


def _csrf_token_matches(session_id: str, token: str) -> bool:
    """Check *token* against the CSRF hash stored on the ``AppSession`` row *session_id*.

    Opens a dedicated DB session and always closes it, even if the lookup raises.

    Args:
        session_id: Primary key of the ``AppSession`` to look up (taken from the
            session cookie).
        token: CSRF token supplied by the client.

    Returns:
        False when no session row is found; otherwise the result of comparing
        ``hash_token(token)`` with the stored ``csrf_token_hash`` via
        ``constant_time_equal``.
    """
    db: Session = SessionLocal()
    try:
        session = db.get(AppSession, session_id)
        if not session:
            return False
        return constant_time_equal(session.csrf_token_hash, hash_token(token))
    finally:
        db.close()


@app.on_event("startup")
def on_startup() -> None:
    """Initialise the database when the application starts."""
    init_db()


@app.middleware("http")
async def csrf_protection(request: Request, call_next):
    """Reject unsafe-method requests from cookie sessions that lack a valid CSRF token.

    The check is skipped in dev mode, for safe methods (GET/HEAD/OPTIONS), and
    for requests that carry no session cookie. Otherwise the client must send
    the token in the ``x-csrf-token`` header. A missing header yields a 403
    ``csrf_required`` response, and a token that does not match the session's
    stored hash yields a 403 ``csrf_invalid`` response.
    """
    if settings.dev_mode or request.method in _CSRF_SAFE_METHODS:
        return await call_next(request)

    session_id = request.cookies.get(settings.session_cookie_name)
    if not session_id:
        return await call_next(request)

    # The token must come from a request header, never from a cookie. Browsers
    # attach cookies to cross-site requests automatically, so a cookie-sourced
    # token would always be present, and the check would prove nothing about
    # where the request originated.
    header_token = request.headers.get(_CSRF_HEADER_NAME)
    if not header_token:
        return _error_response(403, "csrf_required", "Security check failed.")

    # NOTE(review): this is a synchronous DB round trip inside an async
    # middleware, so it blocks the event loop while it runs. Consider
    # offloading it to a threadpool.
    if not _csrf_token_matches(session_id, header_token):
        return _error_response(403, "csrf_invalid", "Security check failed.")

    return await call_next(request)


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Convert uncaught exceptions into a generic 500 JSON error.

    In dev mode the exception is re-raised instead of being masked, so the
    original error stays visible during development.
    """
    if settings.dev_mode:
        raise exc
    return _error_response(500, "server_error", "Something went wrong.")


app.include_router(auth.router)
app.include_router(home.router)
app.include_router(chat.router)
app.include_router(groups.router)
app.include_router(remote.api_router)
app.include_router(remote.well_known_router)