176 lines
6.9 KiB
Python
176 lines
6.9 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
import httpx
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.security import hash_token, utc_now
|
|
from app.models import ConnectionToken, RemoteCachedObject, RemoteServerConnection, RemoteSyncCursor
|
|
from app.services.dashboard import home_dashboard
|
|
|
|
|
|
def mask_store_token(raw_token: str) -> str:
    """Return the form in which a raw connection token is persisted.

    The stored value is the literal marker ``"dev:"`` followed by the raw token;
    ``unmask_store_token`` strips the marker again.

    NOTE(review): this is a reversible placeholder, not encryption — the stored
    value still contains the raw token verbatim.
    """
    marker = "dev:"
    return f"{marker}{raw_token}"
|
|
|
|
|
|
def unmask_store_token(stored: str) -> str:
    """Recover a raw token from a value produced by ``mask_store_token``.

    A single leading ``"dev:"`` marker is removed; any value without that marker
    is returned unchanged.
    """
    marker = "dev:"
    if not stored.startswith(marker):
        return stored
    return stored[len(marker):]
|
|
|
|
|
|
def validate_connection_token(db: Session, raw_token: str) -> ConnectionToken | None:
    """Look up a connection token by its raw value and check that it is usable.

    The raw token is hashed with ``hash_token`` and matched against
    ``ConnectionToken.token_hash``.

    Returns:
        The matching ``ConnectionToken`` when it exists, has no ``revoked_at``
        value, and its ``expires_at`` (if set) is not earlier than ``utc_now()``;
        otherwise ``None``.
    """
    lookup = select(ConnectionToken).where(ConnectionToken.token_hash == hash_token(raw_token))
    candidate = db.scalar(lookup)
    if not candidate or candidate.revoked_at:
        return None
    if not candidate.expires_at:
        # No expiry recorded: the token stays valid until revoked.
        return candidate

    current = utc_now()
    deadline = candidate.expires_at
    if not deadline.tzinfo:
        # Naive stored timestamps are interpreted in utc_now()'s timezone so the
        # comparison below does not mix naive and aware datetimes.
        deadline = deadline.replace(tzinfo=current.tzinfo)
    return None if deadline < current else candidate
|
|
|
|
|
|
def manifest() -> dict[str, Any]:
    """Describe this server to remote peers.

    Returns:
        A dict with the configured ``server_name``, ``api_base`` (taken from
        ``settings.api_base_url``), the fixed ``protocol_version`` ``"0.1"``, and
        a ``capabilities`` map of feature flags: events, tasks, files, chat and
        polls are enabled; federation is disabled.
    """
    cfg = get_settings()
    enabled_features = ("events", "tasks", "files", "chat", "polls")
    capabilities = {feature: True for feature in enabled_features}
    capabilities["federation"] = False
    return {
        "server_name": cfg.server_name,
        "api_base": cfg.api_base_url,
        "protocol_version": "0.1",
        "capabilities": capabilities,
    }
|
|
|
|
|
|
def sync_payload_for_token(db: Session, token: ConnectionToken | None) -> dict[str, Any]:
    """Build the ``/sync`` response body served to a remote server.

    For every group the token grants access to, this collects:

    * the local actions of the first member whose status is ``"joined"`` or
      ``"verified"`` (via ``_local_actions_for_member``),
    * all events,
    * all announcements flagged ``official``,
    * all files,
    * all threads, newest ``updated_at`` first.

    Every serialized item is tagged with ``source_type="remote"`` and this
    server's ``source_server_origin``.

    Args:
        db: Active SQLAlchemy session.
        token: Connection token of the requesting server. If it carries a
            ``group_id``, only that group is included (nothing, if the group no
            longer exists). If it is ``None`` or unscoped, every group is included.

    Returns:
        A dict with ``cursor`` and ``server_time`` (the same ISO-8601 timestamp)
        plus the ``actions``, ``events``, ``announcements``, ``files`` and
        ``threads`` lists.
    """
    # Kept function-local as in the original layout
    # (NOTE(review): presumably to avoid an import cycle — confirm).
    from sqlalchemy import desc

    from app.models import Announcement, Event, FileAsset, Group, Thread
    from app.services.dashboard import _local_actions_for_member
    from app.services.serializers import announcement_dict, event_dict, file_dict, thread_dict

    settings = get_settings()

    if token and token.group_id:
        scoped_group = db.get(Group, token.group_id)
        groups = [scoped_group] if scoped_group is not None else []
    else:
        # NOTE(review): token=None falls through to "all groups" exactly like an
        # unscoped token; presumably callers reject invalid tokens first — confirm.
        groups = list(db.scalars(select(Group)).all())

    actions: list[dict[str, Any]] = []
    events: list[dict[str, Any]] = []
    announcements: list[dict[str, Any]] = []
    files: list[dict[str, Any]] = []
    threads: list[dict[str, Any]] = []

    for group in groups:
        # Only the first eligible member's actions are exported per group.
        for member in group.members:
            if member.status in {"joined", "verified"}:
                actions.extend(_local_actions_for_member(db, member))
                break

        group_events = db.scalars(select(Event).where(Event.group_id == group.id)).all()
        events.extend(event_dict(item, group) for item in group_events)

        group_announcements = db.scalars(
            select(Announcement).where(Announcement.group_id == group.id, Announcement.official.is_(True))
        ).all()
        announcements.extend(announcement_dict(item, group) for item in group_announcements)

        group_files = db.scalars(select(FileAsset).where(FileAsset.group_id == group.id)).all()
        files.extend(file_dict(item, group) for item in group_files)

        group_threads = db.scalars(
            select(Thread).where(Thread.group_id == group.id).order_by(desc(Thread.updated_at))
        ).all()
        threads.extend(thread_dict(item, group=group) for item in group_threads)

    for collection in (actions, events, announcements, files, threads):
        for item in collection:
            item["source_type"] = "remote"
            item["source_server_origin"] = settings.server_origin

    # One timestamp for both fields so the cursor and server_time always agree.
    stamp = utc_now().isoformat()
    return {
        "cursor": stamp,
        "server_time": stamp,
        "actions": actions,
        "events": events,
        "announcements": announcements,
        "files": files,
        "threads": threads,
    }
|
|
|
|
|
|
def fetch_manifest(server_url: str) -> dict[str, Any]:
    """Download a remote server's ``/.well-known/group-platform.json`` manifest.

    Trailing slashes on *server_url* are stripped before the well-known path is
    appended. Redirects are followed, and the timeout comes from
    ``settings.remote_request_timeout_seconds``.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Whatever ``response.raise_for_status()`` raises for an error status
        (``httpx.HTTPStatusError``), plus any transport or JSON-decoding error.
    """
    request_timeout = get_settings().remote_request_timeout_seconds
    with httpx.Client(timeout=request_timeout, follow_redirects=True) as client:
        manifest_url = f"{server_url.rstrip('/')}/.well-known/group-platform.json"
        response = client.get(manifest_url)
        response.raise_for_status()
        return response.json()
|
|
|
|
|
|
def sync_connection(db: Session, connection: RemoteServerConnection) -> RemoteServerConnection:
    """Pull a remote server's ``/sync`` feed and refresh the local cache.

    Sends ``GET {connection.api_base}/sync`` with the stored bearer token. When a
    previous cursor is stored, the request also carries ``since=<cursor>``.

    On success:

    * every dict item in the ``actions`` / ``events`` / ``announcements`` /
      ``files`` / ``threads`` collections is upserted into ``RemoteCachedObject``,
      keyed by connection, object type and remote id;
    * the sync cursor is advanced;
    * the connection is marked ``"active"``.

    If fetching or decoding the response fails, or the body is not a JSON object,
    the connection is marked ``"error"`` and the message is stored in
    ``last_error``.

    The session is flushed but not committed (presumably the caller commits).

    Returns:
        The same ``connection`` object, updated in place.
    """
    settings = get_settings()
    cursor = db.scalar(select(RemoteSyncCursor).where(RemoteSyncCursor.remote_connection_id == connection.id))
    since = cursor.cursor if cursor is not None else None
    params = {"since": since} if since else {}
    raw_token = unmask_store_token(connection.access_token_encrypted)

    try:
        with httpx.Client(timeout=settings.remote_request_timeout_seconds, follow_redirects=True) as client:
            response = client.get(
                f"{connection.api_base.rstrip('/')}/sync",
                params=params,
                headers={"Authorization": f"Bearer {raw_token}"},
            )
            response.raise_for_status()
            payload = response.json()
    except Exception as exc:  # noqa: BLE001 - any fetch/decode failure is recorded on the connection
        return _record_sync_failure(db, connection, str(exc))

    if not isinstance(payload, dict):
        # Previously this crashed on payload.get(...) and left the status stale.
        return _record_sync_failure(
            db, connection, f"sync response is not a JSON object (got {type(payload).__name__})"
        )

    now = utc_now()
    _cache_sync_items(db, connection, payload, now)

    next_cursor = payload.get("cursor")
    if cursor is not None:
        # If the remote omitted a cursor, keep the previous position instead of
        # resetting it to None (which would force a full re-sync next time).
        if next_cursor:
            cursor.cursor = next_cursor
        cursor.updated_at = now
    else:
        db.add(RemoteSyncCursor(remote_connection_id=connection.id, cursor=next_cursor))

    connection.status = "active"
    connection.last_error = None
    connection.last_sync_at = now
    connection.updated_at = now
    db.flush()
    return connection


# Pairs of (object_type stored in RemoteCachedObject, collection key in the /sync payload).
_SYNC_COLLECTIONS: tuple[tuple[str, str], ...] = (
    ("action", "actions"),
    ("event", "events"),
    ("announcement", "announcements"),
    ("file", "files"),
    ("thread", "threads"),
)


def _record_sync_failure(
    db: Session, connection: RemoteServerConnection, message: str
) -> RemoteServerConnection:
    """Mark *connection* as errored with *message*, flush, and return it."""
    connection.status = "error"
    connection.last_error = message
    connection.updated_at = utc_now()
    db.flush()
    return connection


def _remote_id_for(object_type: str, item: dict[str, Any]) -> str:
    """Return the cache key for a synced item: its ``id``/``object_id`` or a content digest."""
    explicit = item.get("id") or item.get("object_id")
    if explicit:
        return str(explicit)
    # The old fallback f"{object_type}:{len(item)}" made every id-less item with the
    # same number of keys collide and overwrite each other in the cache. A digest of
    # the canonical JSON form keeps distinct items distinct. The item came from
    # response.json(), so it is JSON-serializable.
    # NOTE(review): confirm RemoteCachedObject.remote_id can hold this length.
    import hashlib
    import json

    canonical = json.dumps(item, sort_keys=True)
    digest = hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:32]
    return f"{object_type}:{digest}"


def _cache_sync_items(
    db: Session, connection: RemoteServerConnection, payload: dict[str, Any], now: datetime
) -> None:
    """Upsert every dict item of the known payload collections into RemoteCachedObject."""
    # Load this connection's cache once instead of issuing one SELECT per item.
    # New rows are registered in the same index, so a duplicate id inside a single
    # payload updates that row rather than inserting it twice.
    known = {
        (row.object_type, row.remote_id): row
        for row in db.scalars(
            select(RemoteCachedObject).where(RemoteCachedObject.remote_connection_id == connection.id)
        ).all()
    }
    for object_type, collection_name in _SYNC_COLLECTIONS:
        for item in payload.get(collection_name) or []:
            if not isinstance(item, dict):
                # Skip malformed entries rather than aborting the whole sync.
                continue
            remote_id = _remote_id_for(object_type, item)
            group_remote_id = str(item.get("source_group_id") or item.get("group_id") or "remote")
            group_name = str(item.get("source_group_name") or item.get("group_name") or connection.server_name)

            row = known.get((object_type, remote_id))
            if row is not None:
                row.group_remote_id = group_remote_id
                row.group_name = group_name
                row.payload_json = item
                row.cached_at = now
            else:
                row = RemoteCachedObject(
                    remote_connection_id=connection.id,
                    object_type=object_type,
                    remote_id=remote_id,
                    group_remote_id=group_remote_id,
                    group_name=group_name,
                    payload_json=item,
                )
                db.add(row)
                known[(object_type, remote_id)] = row
|
|
|