#!/usr/bin/env python3
"""Merge single-provider XML descriptions into one provider.xml file.

The script scans an input directory for ``*.xml`` files, collects every
``provider`` element found in them, optionally filters providers and their
operations, and writes the result below a single ``providers`` root element.
"""
from __future__ import annotations

import argparse
import copy
import sys
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

# Namespace URI that is registered as the default (prefix-less) namespace
# when the merged document is serialised.
PROVIDERS_NAMESPACE = "https://calco2la.to/schema/providers/v1"


@dataclass
class ProviderSource:
    """A ``provider`` element together with information about its origin."""

    provider_id: str  # id taken from the XML (or the fallback derived from the path)
    expected_id: str  # id suggested by the file or folder name
    path: Path  # file the element was read from
    element: ET.Element  # the parsed provider element


def local_name(tag: str) -> str:
    """Return local XML tag name, ignoring namespace if present."""
    if "}" in tag:
        return tag.rsplit("}", 1)[1]
    return tag


def indent_xml(element: ET.Element, level: int = 0) -> None:
    """Pretty-print *element* in place, using one tab per nesting level.

    Only ``text``/``tail`` values that are empty or whitespace-only are
    overwritten, so real character data is left untouched.
    """
    indentation = "\n" + level * "\t"
    child_indentation = "\n" + (level + 1) * "\t"
    children = list(element)

    if children:
        if not element.text or not element.text.strip():
            element.text = child_indentation
        for child in children:
            indent_xml(child, level + 1)
        # The last child's tail closes the parent, so it uses the parent's level.
        last_child = children[-1]
        if not last_child.tail or not last_child.tail.strip():
            last_child.tail = indentation

    if level and (not element.tail or not element.tail.strip()):
        element.tail = indentation


def parse_csv(value: str | None) -> set[str] | None:
    """Split a comma-separated string into a set of stripped, non-empty parts.

    Returns ``None`` when *value* is empty/``None`` or contains no usable part.
    """
    if not value:
        return None
    result = {part.strip() for part in value.split(",") if part.strip()}
    return result or None


def parse_provider_operations(values: list[str] | None) -> dict[str, set[str]]:
    """
    Parse repeated provider-specific operation filters.

    Example:
        --provider-operations atmosfair=flight,airports
        --provider-operations myclimate=flight

    When the same provider id is given more than once, the operations are
    combined instead of the later value silently replacing the earlier one.

    Raises:
        ValueError: if an entry has no ``=``, no provider id, or no operations.
    """
    result: dict[str, set[str]] = {}
    if not values:
        return result

    for value in values:
        if "=" not in value:
            raise ValueError(
                f"Invalid provider operation filter '{value}'. "
                "Expected format: provider_id=operation1,operation2"
            )
        provider_key, operations_raw = value.split("=", 1)
        provider_key = provider_key.strip()
        if not provider_key:
            raise ValueError(f"Invalid provider operation filter '{value}': missing provider id.")
        operations = parse_csv(operations_raw)
        if not operations:
            raise ValueError(f"Invalid provider operation filter '{value}': missing operations.")
        result.setdefault(provider_key, set()).update(operations)

    return result


def find_xml_files(input_dir: Path, output_file: Path | None = None) -> list[Path]:
    """Return all ``*.xml`` files below *input_dir*, sorted by path.

    Directories that happen to match the pattern are skipped, and
    *output_file* (if given) is excluded so a previous merge result located
    inside the input directory is not merged into itself.
    """
    files = sorted(path for path in input_dir.rglob("*.xml") if path.is_file())
    if output_file:
        output_file = output_file.resolve()
        files = [path for path in files if path.resolve() != output_file]
    return files


def expected_provider_id_for_file(input_dir: Path, path: Path) -> str:
    """
    If file is directly in input dir:
        providers/atmosfair.xml -> atmosfair
    If file is in provider folder:
        providers/atmosfair/provider.xml -> atmosfair
    """
    relative = path.relative_to(input_dir)
    if len(relative.parts) == 1:
        return path.stem
    return relative.parts[0]


def find_provider_nodes(root: ET.Element) -> list[ET.Element]:
    """Return *root* itself if it is a ``provider``, else all descendants that are."""
    if local_name(root.tag) == "provider":
        return [root]
    return [element for element in root.iter() if local_name(element.tag) == "provider"]


def provider_id(provider: ET.Element, fallback: str) -> str:
    """Return the provider's id, looking at ``id``, ``name`` and ``provider``.

    Falls back to *fallback* when none of the attributes is set. As a side
    effect the resolved value is stored in the element's ``id`` attribute
    when that attribute is missing.
    """
    value = (
        provider.attrib.get("id")
        or provider.attrib.get("name")
        or provider.attrib.get("provider")
        or fallback
    )
    if "id" not in provider.attrib:
        provider.attrib["id"] = value
    return value


def read_providers(input_dir: Path, output_file: Path | None, strict_ids: bool) -> list[ProviderSource]:
    """Parse every XML file below *input_dir* and collect its provider elements.

    Files without a provider element are skipped with a warning on stderr.

    Raises:
        RuntimeError: if a file cannot be parsed, or if *strict_ids* is set
            and a provider id differs from the id suggested by its path.
    """
    providers: list[ProviderSource] = []

    for path in find_xml_files(input_dir, output_file):
        expected_id = expected_provider_id_for_file(input_dir, path)
        try:
            tree = ET.parse(path)
        except ET.ParseError as exc:
            raise RuntimeError(f"Could not parse XML file '{path}': {exc}") from exc

        provider_nodes = find_provider_nodes(tree.getroot())
        if not provider_nodes:
            print(f"Warning: no <provider> node found in {path}", file=sys.stderr)
            continue
        if len(provider_nodes) > 1:
            print(
                f"Warning: multiple <provider> nodes found in {path}; all will be considered.",
                file=sys.stderr,
            )

        for provider_node in provider_nodes:
            current_id = provider_id(provider_node, expected_id)
            if current_id != expected_id:
                if strict_ids:
                    raise RuntimeError(
                        f"Provider id mismatch in '{path}': "
                        f"expected '{expected_id}', found '{current_id}'."
                    )
                print(
                    f"Warning: provider id mismatch in '{path}': "
                    f"folder/file suggests '{expected_id}', XML says '{current_id}'.",
                    file=sys.stderr,
                )
            providers.append(
                ProviderSource(
                    provider_id=current_id,
                    expected_id=expected_id,
                    path=path,
                    element=provider_node,
                )
            )

    return providers


def operation_id(operation: ET.Element) -> str | None:
    """Return the operation's ``id``, ``name`` or ``operation`` attribute, if any."""
    return (
        operation.attrib.get("id")
        or operation.attrib.get("name")
        or operation.attrib.get("operation")
    )


def find_operations_container(provider: ET.Element) -> ET.Element | None:
    """Return the first direct ``operations`` child of *provider*, or ``None``."""
    for child in provider:
        if local_name(child.tag) == "operations":
            return child
    return None


def list_provider_operations(provider: ET.Element) -> list[str]:
    """Return the ids of all identifiable operations of *provider*, in document order."""
    operations_container = find_operations_container(provider)
    if operations_container is None:
        return []

    operation_ids: list[str] = []
    for child in operations_container:
        if local_name(child.tag) != "operation":
            continue
        op_id = operation_id(child)
        if op_id:
            operation_ids.append(op_id)
    return operation_ids


def filter_provider_operations(provider: ET.Element, allowed_operations: set[str] | None) -> ET.Element:
    """Return a deep copy of *provider* that keeps only the allowed operations.

    ``None`` means "no filter". With a filter, ``operation`` elements whose id
    is not in *allowed_operations* (including operations without any id) are
    removed; other children of the operations container are kept.
    """
    provider_copy = copy.deepcopy(provider)
    if allowed_operations is None:
        return provider_copy

    operations_container = find_operations_container(provider_copy)
    if operations_container is None:
        return provider_copy

    # Iterate over a snapshot because children are removed while looping.
    for child in list(operations_container):
        if local_name(child.tag) != "operation":
            continue
        if operation_id(child) not in allowed_operations:
            operations_container.remove(child)
    return provider_copy


def ask_selection(
    label: str,
    options: list[str],
    default_all: bool = True,
) -> set[str] | None:
    """Ask the user to pick entries from *options* by number or by name.

    Returns ``None`` for "everything" (only when *default_all* is true) or the
    set of selected option names. When there is nothing to choose from, the
    default answer is returned without prompting.

    Raises:
        ValueError: if a number is out of range or a name is not an option.
    """
    default_answer: set[str] | None = None if default_all else set()
    if not options:
        # Nothing to choose from: behave like pressing Enter.
        return default_answer

    print()
    print(label)
    for index, option in enumerate(options, start=1):
        print(f" {index}) {option}")

    if default_all:
        prompt = "Selection [Enter = all, comma-separated numbers or names]: "
    else:
        prompt = "Selection [Enter = none, comma-separated numbers or names]: "

    raw = input(prompt).strip()
    # Ignore empty parts (e.g. a trailing comma), like parse_csv does.
    parts = [part.strip() for part in raw.split(",") if part.strip()]
    if not parts:
        return default_answer

    selected: set[str] = set()
    for part in parts:
        if part.isdigit():
            index = int(part)
            if index < 1 or index > len(options):
                raise ValueError(f"Invalid selection number: {index}")
            selected.add(options[index - 1])
        else:
            if part not in options:
                raise ValueError(f"Invalid selection value: {part}")
            selected.add(part)
    return selected


def interactive_selection(
    providers: list[ProviderSource],
) -> tuple[set[str] | None, set[str] | None, dict[str, set[str]]]:
    """Interactively ask for providers, global operations and per-provider filters.

    Returns a tuple ``(selected_providers, global_operations,
    provider_operations)`` in the form expected by ``merge_providers``;
    ``None`` in the first two positions means "all".
    """
    provider_ids = sorted({provider.provider_id for provider in providers})
    selected_providers = ask_selection(
        "Available providers:",
        provider_ids,
        default_all=True,
    )
    effective_providers = (
        selected_providers if selected_providers is not None else set(provider_ids)
    )

    all_operations = sorted(
        {
            operation
            for provider in providers
            if provider.provider_id in effective_providers
            for operation in list_provider_operations(provider.element)
        }
    )
    global_operations = ask_selection(
        "Available operations:",
        all_operations,
        default_all=True,
    )

    provider_operations: dict[str, set[str]] = {}
    print()
    custom = input("Define provider-specific operation filters? [y/N]: ").strip().lower()
    if custom in {"y", "yes"}:
        for current_id in sorted(effective_providers):
            # Use the first source carrying this id; duplicates are rejected on merge.
            provider = next(p for p in providers if p.provider_id == current_id)
            ops = sorted(list_provider_operations(provider.element))
            if not ops:
                continue
            selected_ops = ask_selection(
                f"Operations for provider '{current_id}':",
                ops,
                default_all=True,
            )
            # None means "no provider-specific filter": the global filter applies.
            if selected_ops is not None:
                provider_operations[current_id] = selected_ops

    return selected_providers, global_operations, provider_operations


def _warn_unknown_ids(requested: Iterable[str], known: set[str], description: str) -> None:
    """Print a stderr warning listing requested ids that are not in *known*."""
    unknown = sorted(set(requested) - known)
    if unknown:
        print(
            f"Warning: {description} not found in input: {', '.join(unknown)}",
            file=sys.stderr,
        )


def merge_providers(
    providers: list[ProviderSource],
    selected_providers: set[str] | None,
    global_operations: set[str] | None,
    provider_operations: dict[str, set[str]],
) -> ET.Element:
    """Build a ``providers`` root element containing the selected providers.

    A provider-specific entry in *provider_operations* takes precedence over
    *global_operations*. Ids that are requested but do not exist in
    *providers* only produce a warning on stderr.

    Raises:
        RuntimeError: if two selected providers share the same id.
    """
    known_ids = {provider.provider_id for provider in providers}
    if selected_providers is not None:
        _warn_unknown_ids(selected_providers, known_ids, "selected provider id(s)")
    _warn_unknown_ids(
        provider_operations,
        known_ids,
        "provider id(s) used in provider-specific operation filters",
    )

    root = ET.Element("providers")
    seen_provider_ids: set[str] = set()

    for provider in providers:
        if selected_providers is not None and provider.provider_id not in selected_providers:
            continue
        if provider.provider_id in seen_provider_ids:
            raise RuntimeError(
                f"Duplicate provider id '{provider.provider_id}' after selection. "
                "Provider ids must be unique in the merged output."
            )
        seen_provider_ids.add(provider.provider_id)

        operations_for_provider = provider_operations.get(
            provider.provider_id,
            global_operations,
        )
        provider_element = filter_provider_operations(
            provider.element,
            operations_for_provider,
        )
        root.append(provider_element)

    return root


def write_output(root: ET.Element, output_file: Path) -> None:
    """Indent *root* and write it to *output_file* as UTF-8 with an XML declaration.

    Missing parent directories are created. The providers namespace is
    registered as the default namespace so namespaced elements are written
    without a generated ``ns0:`` prefix.
    """
    output_file.parent.mkdir(parents=True, exist_ok=True)
    indent_xml(root)
    # NOTE(review): the synthetic <providers> root is created without a
    # namespace; if namespaced children cause the default xmlns to be declared,
    # the root falls into that namespace in the written file - confirm intended.
    ET.register_namespace("", PROVIDERS_NAMESPACE)
    tree = ET.ElementTree(root)
    tree.write(
        output_file,
        encoding="utf-8",
        xml_declaration=True,
    )


def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for the merge tool."""
    parser = argparse.ArgumentParser(
        description="Merge single-provider XML descriptions into one provider.xml file."
    )
    parser.add_argument(
        "input_dir",
        type=Path,
        help="Directory containing provider XML files or provider folders.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=Path,
        default=Path("provider.xml"),
        help="Output file. Default: provider.xml",
    )
    parser.add_argument(
        "--providers",
        help="Comma-separated provider ids to merge. Default: all providers.",
    )
    parser.add_argument(
        "--operations",
        help="Comma-separated global operation ids to keep. Default: all operations.",
    )
    parser.add_argument(
        "--provider-operations",
        action="append",
        help=(
            "Provider-specific operation filter. "
            "Format: provider_id=operation1,operation2. "
            "Can be used multiple times."
        ),
    )
    parser.add_argument(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactively select providers and operations.",
    )
    parser.add_argument(
        "--strict-provider-id",
        action="store_true",
        help="Fail if folder/file name and XML provider id do not match.",
    )
    return parser


def main() -> int:
    """Run the command-line tool and return the process exit code (0 or 1)."""
    parser = build_arg_parser()
    args = parser.parse_args()

    input_dir: Path = args.input_dir
    output_file: Path = args.output

    if not input_dir.exists():
        print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
        return 1
    if not input_dir.is_dir():
        print(f"Input path is not a directory: {input_dir}", file=sys.stderr)
        return 1

    # Top-level boundary: any failure is reported as a single error line.
    try:
        providers = read_providers(
            input_dir=input_dir,
            output_file=output_file,
            strict_ids=args.strict_provider_id,
        )
        if not providers:
            print("No providers found.", file=sys.stderr)
            return 1

        if args.interactive:
            selected_providers, global_operations, provider_operations = interactive_selection(providers)
        else:
            selected_providers = parse_csv(args.providers)
            global_operations = parse_csv(args.operations)
            provider_operations = parse_provider_operations(args.provider_operations)

        merged_root = merge_providers(
            providers=providers,
            selected_providers=selected_providers,
            global_operations=global_operations,
            provider_operations=provider_operations,
        )
        write_output(merged_root, output_file)
    except Exception as exc:
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    print(f"Created merged provider configuration: {output_file}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())