Files
comiaunicaty/backend/app/services/remote.py

176 lines
6.9 KiB
Python

from __future__ import annotations
from datetime import datetime
from typing import Any
import httpx
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.core.security import hash_token, utc_now
from app.models import ConnectionToken, RemoteCachedObject, RemoteServerConnection, RemoteSyncCursor
from app.services.dashboard import home_dashboard
def mask_store_token(raw_token: str) -> str:
    """Convert a raw access token into the form persisted in the database.

    The stored value is simply the token prefixed with the ``"dev:"`` marker;
    :func:`unmask_store_token` strips that marker again.

    NOTE(review): despite the column being named ``access_token_encrypted``
    (see ``sync_connection``), no encryption happens here -- the token is kept
    in plain text. Replace with real encryption before production use.
    """
    return "dev:{}".format(raw_token)
def unmask_store_token(stored: str) -> str:
    """Recover the raw token from its stored form.

    Strips one leading ``"dev:"`` marker (as added by :func:`mask_store_token`).
    Values without the marker are returned unchanged.
    """
    marker = "dev:"
    if not stored.startswith(marker):
        return stored
    return stored[len(marker):]
def validate_connection_token(db: Session, raw_token: str) -> ConnectionToken | None:
    """Resolve a raw bearer token to its ``ConnectionToken`` row if it is usable.

    The raw value is hashed with ``hash_token`` and matched against
    ``ConnectionToken.token_hash``.

    Returns:
        The matching row, or ``None`` when no row matches, the row has a
        ``revoked_at`` value, or its ``expires_at`` lies strictly in the past.
    """
    digest = hash_token(raw_token)
    candidate = db.scalar(select(ConnectionToken).where(ConnectionToken.token_hash == digest))
    if not candidate or candidate.revoked_at:
        return None

    expiry = candidate.expires_at
    if not expiry:
        # No expiry configured: the token never times out.
        return candidate

    current = utc_now()
    if not expiry.tzinfo:
        # A naive expiry is interpreted in utc_now()'s zone so that comparing
        # it with an aware timestamp does not raise TypeError.
        expiry = expiry.replace(tzinfo=current.tzinfo)
    return None if expiry < current else candidate
def manifest() -> dict[str, Any]:
    """Return this server's discovery document.

    The document advertises the configured ``server_name``, the configured
    ``api_base_url`` (published under the ``api_base`` key), protocol version
    ``"0.1"``, and a fixed capability map. Every feature in the map is enabled
    except ``federation``.
    """
    cfg = get_settings()
    capabilities: dict[str, bool] = {
        feature: True for feature in ("events", "tasks", "files", "chat", "polls")
    }
    capabilities["federation"] = False
    document: dict[str, Any] = {
        "server_name": cfg.server_name,
        "api_base": cfg.api_base_url,
        "protocol_version": "0.1",
    }
    document["capabilities"] = capabilities
    return document
def sync_payload_for_token(db: Session, token: ConnectionToken | None) -> dict[str, Any]:
    """Build the body of the ``/sync`` response served to a remote client.

    Scope: if *token* is set and has a ``group_id``, only that group is
    exported. Otherwise, including when *token* is ``None``, every group is
    exported.
    NOTE(review): the ``None`` case exports every group -- confirm callers
    always authenticate before reaching this function.

    For each exported group the payload contains:

    * ``actions`` -- ``_local_actions_for_member`` for the first member whose
      status is ``"joined"`` or ``"verified"`` (nothing if there is none);
    * ``events`` and ``files`` -- every row belonging to the group;
    * ``announcements`` -- only rows with ``official`` set;
    * ``threads`` -- ordered by ``updated_at``, most recent first.

    Every item dict is tagged with ``source_type="remote"`` and this server's
    ``server_origin``. No ``since`` filtering happens here; the full content
    is returned on every call.

    Returns:
        A dict with ``cursor`` and ``server_time`` (both the same ISO-8601
        timestamp, taken once) plus the five item lists.
    """
    # Function-scope imports kept from the original; presumably they avoid an
    # import cycle with the dashboard/serializer services (NOTE(review): confirm).
    from sqlalchemy import desc

    from app.models import Announcement, Event, FileAsset, Group, Thread
    from app.services.dashboard import _local_actions_for_member
    from app.services.serializers import announcement_dict, event_dict, file_dict, thread_dict

    settings = get_settings()
    actions: list[dict[str, Any]] = []
    events: list[dict[str, Any]] = []
    announcements: list[dict[str, Any]] = []
    files: list[dict[str, Any]] = []
    threads: list[dict[str, Any]] = []

    if token and token.group_id:
        # An unknown/deleted group makes db.get return None; export nothing then.
        scoped_group = db.get(Group, token.group_id)
        groups = [scoped_group] if scoped_group else []
    else:
        # Use the loaded rows directly instead of collecting ids and re-fetching each.
        groups = db.scalars(select(Group)).all()

    active_statuses = {"joined", "verified"}
    for group in groups:
        representative = next(
            (member for member in group.members if member.status in active_statuses),
            None,
        )
        if representative is not None:
            actions.extend(_local_actions_for_member(db, representative))
        events.extend(
            event_dict(item, group)
            for item in db.scalars(select(Event).where(Event.group_id == group.id)).all()
        )
        announcements.extend(
            announcement_dict(item, group)
            for item in db.scalars(
                select(Announcement).where(
                    Announcement.group_id == group.id,
                    Announcement.official.is_(True),
                )
            ).all()
        )
        files.extend(
            file_dict(item, group)
            for item in db.scalars(select(FileAsset).where(FileAsset.group_id == group.id)).all()
        )
        threads.extend(
            thread_dict(item, group=group)
            for item in db.scalars(
                select(Thread).where(Thread.group_id == group.id).order_by(desc(Thread.updated_at))
            ).all()
        )

    for collection in (actions, events, announcements, files, threads):
        for item in collection:
            item["source_type"] = "remote"
            item["source_server_origin"] = settings.server_origin

    # One timestamp so cursor and server_time always agree.
    generated_at = utc_now().isoformat()
    return {
        "cursor": generated_at,
        "server_time": generated_at,
        "actions": actions,
        "events": events,
        "announcements": announcements,
        "files": files,
        "threads": threads,
    }
def fetch_manifest(server_url: str) -> dict[str, Any]:
    """Download and decode a remote server's discovery document.

    Sends a GET to ``<server_url>/.well-known/group-platform.json``. Trailing
    slashes on *server_url* are stripped first. Redirects are followed, and
    the configured ``remote_request_timeout_seconds`` is used as the timeout.

    Raises:
        httpx.HTTPStatusError: when the response has a 4xx/5xx status.
        Transport errors from httpx and JSON decoding errors also propagate.

    NOTE(review): redirects are followed to wherever the remote points; if
    *server_url* is user-supplied, consider SSRF restrictions.
    """
    timeout = get_settings().remote_request_timeout_seconds
    well_known_url = server_url.rstrip("/") + "/.well-known/group-platform.json"
    with httpx.Client(timeout=timeout, follow_redirects=True) as http:
        reply = http.get(well_known_url)
        reply.raise_for_status()
        return reply.json()
def _remote_id_for(object_type: str, item: dict[str, Any]) -> str:
    """Return the id under which *item* is cached locally.

    Prefers the item's own ``id`` or ``object_id``. Items with neither get a
    deterministic content hash. The previous fallback, ``"<type>:<key count>"``,
    collapsed every id-less item with the same number of fields onto one cache row.
    """
    explicit_id = item.get("id") or item.get("object_id")
    if explicit_id:
        return str(explicit_id)
    import hashlib
    import json

    # Items come from response.json(), so they are always JSON-serializable;
    # sort_keys makes the hash independent of key order.
    canonical = json.dumps(item, sort_keys=True, separators=(",", ":"))
    # Truncated digest keeps the id short. NOTE(review): confirm it fits the
    # remote_id column length.
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:32]
    return f"{object_type}:{digest}"


def _upsert_cached_object(
    db: Session,
    connection: RemoteServerConnection,
    object_type: str,
    item: dict[str, Any],
    now: datetime,
) -> None:
    """Insert a new ``RemoteCachedObject`` for *item* or refresh the existing one."""
    remote_id = _remote_id_for(object_type, item)
    group_remote_id = str(item.get("source_group_id") or item.get("group_id") or "remote")
    group_name = str(item.get("source_group_name") or item.get("group_name") or connection.server_name)
    existing = db.scalar(
        select(RemoteCachedObject).where(
            RemoteCachedObject.remote_connection_id == connection.id,
            RemoteCachedObject.object_type == object_type,
            RemoteCachedObject.remote_id == remote_id,
        )
    )
    if existing:
        existing.group_remote_id = group_remote_id
        existing.group_name = group_name
        existing.payload_json = item
        existing.cached_at = now
    else:
        db.add(
            RemoteCachedObject(
                remote_connection_id=connection.id,
                object_type=object_type,
                remote_id=remote_id,
                group_remote_id=group_remote_id,
                group_name=group_name,
                payload_json=item,
            )
        )


def sync_connection(db: Session, connection: RemoteServerConnection) -> RemoteServerConnection:
    """Pull ``/sync`` from a remote server and refresh the local cache.

    Sends the stored cursor as ``since`` (when one exists) together with the
    connection's bearer token. Each item in the ``actions``, ``events``,
    ``announcements``, ``files`` and ``threads`` collections is upserted as a
    ``RemoteCachedObject``. Non-object items and ``null`` collections are
    skipped.

    Failures are recorded on the connection (``status="error"``,
    ``last_error``) instead of raised. Failures include network/HTTP errors,
    undecodable JSON, and a response body that is not a JSON object.

    On success the stored cursor is advanced to the response's ``cursor``. If
    the response carries no cursor, the previous one is kept rather than reset
    to ``None``. The connection is then marked ``active``.

    Returns:
        The same *connection* object, after ``db.flush()``.
    """
    settings = get_settings()
    cursor = db.scalar(select(RemoteSyncCursor).where(RemoteSyncCursor.remote_connection_id == connection.id))
    since = cursor.cursor if cursor else None
    params = {"since": since} if since else {}
    raw_token = unmask_store_token(connection.access_token_encrypted)
    try:
        with httpx.Client(timeout=settings.remote_request_timeout_seconds, follow_redirects=True) as client:
            response = client.get(
                f"{connection.api_base.rstrip('/')}/sync",
                params=params,
                headers={"Authorization": f"Bearer {raw_token}"},
            )
            response.raise_for_status()
            payload = response.json()
        if not isinstance(payload, dict):
            # Otherwise payload.get below would raise outside this handler.
            raise ValueError("sync response is not a JSON object")
    except Exception as exc:  # noqa: BLE001 - best-effort: record the failure on the connection
        connection.status = "error"
        connection.last_error = str(exc)
        connection.updated_at = utc_now()
        db.flush()
        return connection

    now = utc_now()
    for object_type, collection_name in (
        ("action", "actions"),
        ("event", "events"),
        ("announcement", "announcements"),
        ("file", "files"),
        ("thread", "threads"),
    ):
        # "or []" also covers an explicit null collection in the payload.
        for item in payload.get(collection_name) or []:
            if isinstance(item, dict):
                _upsert_cached_object(db, connection, object_type, item, now)

    next_cursor = payload.get("cursor")
    if next_cursor is not None:
        if cursor:
            cursor.cursor = next_cursor
            cursor.updated_at = now
        else:
            db.add(RemoteSyncCursor(remote_connection_id=connection.id, cursor=next_cursor))

    connection.status = "active"
    connection.last_error = None
    connection.last_sync_at = now
    connection.updated_at = now
    db.flush()
    return connection