72 lines
2.4 KiB
Python
72 lines
2.4 KiB
Python
from __future__ import annotations
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.security import constant_time_equal, hash_token
|
|
from app.db.base import SessionLocal, init_db
|
|
from app.models import AppSession
|
|
from app.routers import auth, chat, groups, home, remote
|
|
|
|
|
|
# Configuration is resolved once at import time and shared by the handlers
# defined in this module.
settings = get_settings()

# The ASGI application; middleware, handlers and routers attach to it below.
app = FastAPI(title=settings.app_name, version="0.1.0")

# CORS policy: the origin list comes from settings, credentials are allowed,
# and every method and request header is accepted.
_cors_options = {
    "allow_origins": settings.allowed_origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
@app.on_event("startup")
def on_startup() -> None:
    """Call ``init_db()`` when the application emits its startup event."""
    init_db()
|
|
|
|
|
|
def _csrf_rejection(code: str) -> JSONResponse:
    """Build the 403 JSON response returned when the CSRF check fails."""
    return JSONResponse(
        status_code=403,
        content={"error": {"code": code, "message": "Security check failed.", "details": {}}},
    )


def _csrf_token_matches(session_id: str, token: str) -> bool:
    """Return True when ``token`` hashes to the CSRF hash stored for the session.

    Returns False when no ``AppSession`` row exists for ``session_id``.
    """
    # NOTE(review): this is a synchronous lookup invoked from an async
    # middleware; if it proves slow, consider offloading it to a threadpool.
    db: Session = SessionLocal()
    try:
        session = db.get(AppSession, session_id)
        if not session:
            return False
        return bool(constant_time_equal(session.csrf_token_hash, hash_token(token)))
    finally:
        db.close()


@app.middleware("http")
async def csrf_protection(request: Request, call_next):
    """Reject state-changing requests that lack a valid CSRF token.

    The check is skipped when ``settings.dev_mode`` is set, for GET/HEAD/OPTIONS
    requests, and for requests that carry no session cookie. Otherwise the
    ``x-csrf-token`` header must be present and its hash must match the hash
    stored on the session row.

    Returns:
        A 403 JSON response with code ``csrf_required`` (no token supplied) or
        ``csrf_invalid`` (unknown session or token mismatch); otherwise the
        response produced by ``call_next``.
    """
    if settings.dev_mode or request.method in {"GET", "HEAD", "OPTIONS"}:
        return await call_next(request)

    session_id = request.cookies.get(settings.session_cookie_name)
    if not session_id:
        # No session cookie on the request: nothing to validate here.
        return await call_next(request)

    # The token is accepted from the request header only. Cookies are attached
    # by the browser automatically, including on forged cross-site requests, so
    # a token read from a cookie proves nothing about who built the request.
    token = request.headers.get("x-csrf-token")
    if not token:
        return _csrf_rejection("csrf_required")
    if not _csrf_token_matches(session_id, token):
        return _csrf_rejection("csrf_invalid")

    return await call_next(request)
|
|
|
|
|
|
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Handle exceptions that no other handler caught.

    When ``settings.dev_mode`` is set the exception is re-raised unchanged;
    otherwise a generic 500 JSON error body is returned.
    """
    if not settings.dev_mode:
        error_body = {"code": "server_error", "message": "Something went wrong.", "details": {}}
        return JSONResponse(status_code=500, content={"error": error_body})
    raise exc
|
|
|
|
|
|
# Attach every router to the application, in this order.
for _router in (
    auth.router,
    home.router,
    chat.router,
    groups.router,
    remote.api_router,
    remote.well_known_router,
):
    app.include_router(_router)
|