500 lines
14 KiB
Python
500 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import copy
|
|
import sys
|
|
import xml.etree.ElementTree as ET
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
|
|
@dataclass
class ProviderSource:
    """A single ``<provider>`` element plus the bookkeeping about its origin."""

    # Id resolved from the element's attributes (or the path-derived fallback).
    provider_id: str
    # Id implied by the file or folder name below the input directory.
    expected_id: str
    # XML file the element was read from.
    path: Path
    # The parsed <provider> element itself.
    element: ET.Element
|
|
|
|
|
|
def local_name(tag: str) -> str:
    """Return the tag without its ``{namespace}`` prefix, if it has one."""
    # rpartition gives ("", "", tag) when there is no "}", so the last field
    # is the bare tag name in both the namespaced and the plain case.
    return tag.rpartition("}")[2]
|
|
|
|
|
|
def indent_xml(element: ET.Element, level: int = 0) -> None:
    """
    Pretty-print *element* in place using one tab per nesting level.

    Only whitespace-only (or missing) ``text``/``tail`` values are replaced;
    real character data is left untouched. The tail of the top-level element
    (``level == 0``) is never modified.
    """

    def is_blank(value: str | None) -> bool:
        # True for None, "" and whitespace-only strings.
        return not (value and value.strip())

    own_pad = "\n" + "\t" * level
    kids = list(element)

    if kids:
        if is_blank(element.text):
            # First child starts on its own line, one level deeper.
            element.text = "\n" + "\t" * (level + 1)

        for kid in kids:
            indent_xml(kid, level + 1)

        # The recursion gave the last child a "next sibling" tail; pull it
        # back so the parent's closing tag lines up with its opening tag.
        if is_blank(kids[-1].tail):
            kids[-1].tail = own_pad

    if level and is_blank(element.tail):
        element.tail = own_pad
|
|
|
|
|
|
def parse_csv(value: str | None) -> set[str] | None:
    """
    Split a comma-separated string into a set of stripped, non-empty items.

    Returns ``None`` when *value* is missing/empty or contains no items.
    """
    if not value:
        return None

    items: set[str] = set()
    for piece in value.split(","):
        cleaned = piece.strip()
        if cleaned:
            items.add(cleaned)

    return items if items else None
|
|
|
|
|
|
def parse_provider_operations(values: list[str] | None) -> dict[str, set[str]]:
    """
    Parse repeated provider-specific operation filters.

    Example:
        --provider-operations atmosfair=flight,airports
        --provider-operations myclimate=flight

    Each value must have the form ``provider_id=operation1,operation2``.
    When the same provider id is given more than once, the operation sets
    are merged; an earlier occurrence is no longer silently replaced by a
    later one.

    Returns:
        Mapping of provider id to the set of allowed operation ids. Empty
        when *values* is ``None`` or empty.

    Raises:
        ValueError: if a value has no ``=``, no provider id, or no operations.
    """
    result: dict[str, set[str]] = {}

    if not values:
        return result

    for value in values:
        # Split on the first "=" only, so operation ids may contain "=".
        name, separator, operations_raw = value.partition("=")

        if not separator:
            raise ValueError(
                f"Invalid provider operation filter '{value}'. "
                "Expected format: provider_id=operation1,operation2"
            )

        name = name.strip()
        if not name:
            raise ValueError(f"Invalid provider operation filter '{value}': missing provider id.")

        operations = parse_csv(operations_raw)
        if not operations:
            raise ValueError(f"Invalid provider operation filter '{value}': missing operations.")

        # Accumulate instead of overwrite: the option can be used many times.
        result.setdefault(name, set()).update(operations)

    return result
|
|
|
|
|
|
def find_xml_files(input_dir: Path, output_file: Path | None = None) -> list[Path]:
    """
    Return every ``*.xml`` file below *input_dir* (recursively), sorted.

    Only regular files are returned: a directory whose name happens to end
    in ``.xml`` also matches the glob but cannot be parsed, so it is skipped.

    Args:
        input_dir: Directory to search.
        output_file: If given, any match that resolves to the same location
            is left out of the result.

    Returns:
        Sorted list of matching file paths.
    """
    excluded = output_file.resolve() if output_file is not None else None

    return sorted(
        path
        for path in input_dir.rglob("*.xml")
        if path.is_file() and (excluded is None or path.resolve() != excluded)
    )
|
|
|
|
|
|
def expected_provider_id_for_file(input_dir: Path, path: Path) -> str:
    """
    Derive the provider id implied by where *path* sits below *input_dir*.

    File directly in the input dir:
        providers/atmosfair.xml -> atmosfair

    File inside a provider folder (at any depth):
        providers/atmosfair/provider.xml -> atmosfair
    """
    parts = path.relative_to(input_dir).parts
    # One component means the file is a direct child: use its stem.
    # Otherwise the first folder below the input dir names the provider.
    return path.stem if len(parts) == 1 else parts[0]
|
|
|
|
|
|
def find_provider_nodes(root: ET.Element) -> list[ET.Element]:
    """
    Collect the ``provider`` elements of a document, ignoring namespaces.

    If *root* itself is a provider it is returned alone; otherwise every
    element in the tree (in document order) named ``provider`` is returned.
    """
    if local_name(root.tag) == "provider":
        return [root]

    matches: list[ET.Element] = []
    for node in root.iter():
        if local_name(node.tag) == "provider":
            matches.append(node)
    return matches
|
|
|
|
|
|
def provider_id(provider: ET.Element, fallback: str) -> str:
    """
    Resolve the id of a provider element and make sure it has an ``id``.

    The first non-empty attribute among ``id``, ``name`` and ``provider``
    wins; if none is set, *fallback* is used. When the element has no ``id``
    attribute at all, the resolved value is written back to it (the element
    is modified in place).
    """
    attributes = provider.attrib

    resolved = fallback
    for key in ("id", "name", "provider"):
        candidate = attributes.get(key)
        if candidate:
            resolved = candidate
            break

    # Only fills a missing attribute; an existing (even empty) id is kept.
    attributes.setdefault("id", resolved)

    return resolved
|
|
|
|
|
|
def read_providers(input_dir: Path, output_file: Path | None, strict_ids: bool) -> list[ProviderSource]:
    """
    Parse every XML file below *input_dir* and collect its provider elements.

    Files without a ``<provider>`` element are skipped with a warning on
    stderr; files with several are accepted with a warning. An id that does
    not match the file/folder name is a warning, or an error when
    *strict_ids* is set.

    Raises:
        RuntimeError: if a file is not well-formed XML, or on an id mismatch
            while *strict_ids* is true.
    """
    collected: list[ProviderSource] = []

    for path in find_xml_files(input_dir, output_file):
        expected_id = expected_provider_id_for_file(input_dir, path)

        try:
            tree = ET.parse(path)
        except ET.ParseError as exc:
            raise RuntimeError(f"Could not parse XML file '{path}': {exc}") from exc

        nodes = find_provider_nodes(tree.getroot())

        if not nodes:
            print(f"Warning: no <provider> node found in {path}", file=sys.stderr)
            continue

        if len(nodes) > 1:
            print(
                f"Warning: multiple <provider> nodes found in {path}; all will be considered.",
                file=sys.stderr,
            )

        for node in nodes:
            # The path-derived id doubles as fallback when the XML has none.
            current_id = provider_id(node, expected_id)

            if current_id != expected_id:
                if strict_ids:
                    raise RuntimeError(
                        f"Provider id mismatch in '{path}': "
                        f"expected '{expected_id}', found '{current_id}'."
                    )
                print(
                    f"Warning: provider id mismatch in '{path}': "
                    f"folder/file suggests '{expected_id}', XML says '{current_id}'.",
                    file=sys.stderr,
                )

            collected.append(
                ProviderSource(
                    provider_id=current_id,
                    expected_id=expected_id,
                    path=path,
                    element=node,
                )
            )

    return collected
|
|
|
|
|
|
def operation_id(operation: ET.Element) -> str | None:
    """
    Return the identifier of an operation element.

    ``id`` is preferred, then ``name``; if both are missing or empty the
    value of the ``operation`` attribute is returned as-is (possibly ``None``).
    """
    attributes = operation.attrib

    for key in ("id", "name"):
        found = attributes.get(key)
        if found:
            return found

    return attributes.get("operation")
|
|
|
|
|
|
def find_operations_container(provider: ET.Element) -> ET.Element | None:
    """Return the first direct child named ``operations``, or ``None``."""
    return next(
        (node for node in provider if local_name(node.tag) == "operations"),
        None,
    )
|
|
|
|
|
|
def list_provider_operations(provider: ET.Element) -> list[str]:
    """
    List the ids of a provider's operations in document order.

    Only ``operation`` children of the provider's ``operations`` container
    are considered; operations without a usable id are left out. Returns an
    empty list when the provider has no container.
    """
    container = find_operations_container(provider)
    if container is None:
        return []

    candidate_ids = (
        operation_id(node)
        for node in container
        if local_name(node.tag) == "operation"
    )
    return [found for found in candidate_ids if found]
|
|
|
|
|
|
def filter_provider_operations(provider: ET.Element, allowed_operations: set[str] | None) -> ET.Element:
    """
    Return a deep copy of *provider* restricted to the allowed operations.

    ``None`` means "no filter": the copy is returned unchanged. Otherwise
    every ``operation`` child of the ``operations`` container whose id is not
    in *allowed_operations* (including operations without an id) is removed
    from the copy. The original element is never modified.
    """
    clone = copy.deepcopy(provider)
    if allowed_operations is None:
        return clone

    container = find_operations_container(clone)
    if container is not None:
        # Decide first, remove afterwards, so the container is not mutated
        # while it is being iterated.
        rejected = [
            node
            for node in container
            if local_name(node.tag) == "operation"
            and operation_id(node) not in allowed_operations
        ]
        for node in rejected:
            container.remove(node)

    return clone
|
|
|
|
|
|
def ask_selection(
    label: str,
    options: list[str],
    default_all: bool = True,
) -> set[str] | None:
    """
    Print a numbered menu and read a comma-separated selection from stdin.

    Args:
        label: Heading printed above the menu.
        options: Selectable names, shown numbered from 1.
        default_all: Whether an empty answer means "everything" (``None``)
            or "nothing" (empty set).

    Returns:
        ``None`` for "no restriction", otherwise the set of chosen names.
        When *options* is empty nothing is prompted and the default answer
        is returned. With ``default_all`` that is ``None`` rather than an
        empty set, because an empty set used as a filter would discard
        everything instead of keeping everything.

    Raises:
        ValueError: if an entry is neither a valid menu number nor one of
            the option names.
    """
    if not options:
        return None if default_all else set()

    print()
    print(label)
    for index, option in enumerate(options, start=1):
        print(f" {index}) {option}")

    if default_all:
        prompt = "Selection [Enter = all, comma-separated numbers or names]: "
    else:
        prompt = "Selection [Enter = none, comma-separated numbers or names]: "

    raw = input(prompt).strip()

    if not raw:
        return None if default_all else set()

    selected: set[str] = set()

    for part in raw.split(","):
        part = part.strip()

        if part.isdigit():
            # Menu numbers are 1-based.
            index = int(part)
            if index < 1 or index > len(options):
                raise ValueError(f"Invalid selection number: {index}")
            selected.add(options[index - 1])
        else:
            if part not in options:
                raise ValueError(f"Invalid selection value: {part}")
            selected.add(part)

    return selected
|
|
|
|
|
|
def interactive_selection(
    providers: list[ProviderSource],
) -> tuple[set[str] | None, set[str] | None, dict[str, set[str]]]:
    """
    Ask the user which providers and operations to merge.

    Returns:
        ``(selected_providers, global_operations, provider_operations)``,
        where ``None`` for either of the first two means "no restriction"
        and the dict holds per-provider operation sets.
    """
    known_ids = sorted({source.provider_id for source in providers})

    selected_providers = ask_selection(
        "Available providers:",
        known_ids,
        default_all=True,
    )

    # No explicit choice means every provider takes part in what follows.
    active_ids = selected_providers if selected_providers else set(known_ids)

    offered: set[str] = set()
    for source in providers:
        if source.provider_id in active_ids:
            offered.update(list_provider_operations(source.element))

    global_operations = ask_selection(
        "Available operations:",
        sorted(offered),
        default_all=True,
    )

    provider_operations: dict[str, set[str]] = {}

    print()
    answer = input("Define provider-specific operation filters? [y/N]: ").strip().lower()

    if answer not in {"y", "yes"}:
        return selected_providers, global_operations, provider_operations

    for chosen_id in sorted(active_ids):
        # Only the first source with this id is offered.
        source = next(p for p in providers if p.provider_id == chosen_id)
        available = sorted(list_provider_operations(source.element))

        if not available:
            continue

        selected_ops = ask_selection(
            f"Operations for provider '{chosen_id}':",
            available,
            default_all=True,
        )

        # An empty answer yields None and stores no per-provider entry.
        if selected_ops is not None:
            provider_operations[chosen_id] = selected_ops

    return selected_providers, global_operations, provider_operations
|
|
|
|
|
|
def merge_providers(
    providers: list[ProviderSource],
    selected_providers: set[str] | None,
    global_operations: set[str] | None,
    provider_operations: dict[str, set[str]],
) -> ET.Element:
    """
    Build the merged ``<providers>`` element.

    Args:
        providers: All provider sources that were read.
        selected_providers: Ids to include; ``None`` includes every provider.
        global_operations: Operation ids to keep for providers without their
            own filter; ``None`` keeps all operations.
        provider_operations: Per-provider operation filters; an entry here
            takes precedence over *global_operations* for that provider.

    Returns:
        A new ``providers`` element holding filtered copies of the selected
        provider elements, in input order.

    Raises:
        RuntimeError: if two selected providers share the same id.

    Selected ids that match no provider are reported on stderr, so a
    mistyped id does not silently produce an incomplete file.
    """
    root = ET.Element("providers")

    seen_provider_ids: set[str] = set()

    for provider in providers:
        if selected_providers is not None and provider.provider_id not in selected_providers:
            continue

        if provider.provider_id in seen_provider_ids:
            raise RuntimeError(
                f"Duplicate provider id '{provider.provider_id}' after selection. "
                "Provider ids must be unique in the merged output."
            )

        seen_provider_ids.add(provider.provider_id)

        # A provider-specific filter replaces the global one, it does not
        # intersect with it.
        operations_for_provider = provider_operations.get(
            provider.provider_id,
            global_operations,
        )

        root.append(filter_provider_operations(provider.element, operations_for_provider))

    if selected_providers is not None:
        unknown = sorted(set(selected_providers) - seen_provider_ids)
        if unknown:
            print(
                f"Warning: selected provider id(s) not found: {', '.join(unknown)}",
                file=sys.stderr,
            )

    return root
|
|
|
|
|
|
def write_output(root: ET.Element, output_file: Path) -> None:
    """
    Pretty-print *root* and write it to *output_file* as UTF-8 XML.

    Missing parent directories are created. *root* is re-indented in place
    before it is serialized, and the file starts with an XML declaration.
    """
    target_dir = output_file.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    indent_xml(root)

    # Register the providers schema under the empty prefix so elements in
    # that namespace are serialized without an "ns0:"-style prefix.
    ET.register_namespace("", "https://calco2la.to/schema/providers/v1")

    ET.ElementTree(root).write(output_file, encoding="utf-8", xml_declaration=True)
|
|
|
|
|
|
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the merge tool."""
    cli = argparse.ArgumentParser(
        description="Merge single-provider XML descriptions into one provider.xml file."
    )

    # Where to read from and where to write to.
    cli.add_argument("input_dir", type=Path, help="Directory containing provider XML files or provider folders.")
    cli.add_argument("-o", "--output", type=Path, default=Path("provider.xml"), help="Output file. Default: provider.xml")

    # Non-interactive filters.
    cli.add_argument("--providers", help="Comma-separated provider ids to merge. Default: all providers.")
    cli.add_argument("--operations", help="Comma-separated global operation ids to keep. Default: all operations.")
    cli.add_argument(
        "--provider-operations",
        action="append",
        help=(
            "Provider-specific operation filter. "
            "Format: provider_id=operation1,operation2. "
            "Can be used multiple times."
        ),
    )

    # Behavior switches.
    cli.add_argument("-i", "--interactive", action="store_true", help="Interactively select providers and operations.")
    cli.add_argument("--strict-provider-id", action="store_true", help="Fail if folder/file name and XML provider id do not match.")

    return cli
|
|
|
|
|
|
def main() -> int:
    """
    Run the merge tool from the command line.

    Returns:
        ``0`` on success, ``1`` if the input directory is unusable, no
        providers were found, or any step raised an exception (the message
        is printed to stderr).
    """
    args = build_arg_parser().parse_args()

    input_dir: Path = args.input_dir
    output_file: Path = args.output

    if not input_dir.exists():
        print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
        return 1

    if not input_dir.is_dir():
        print(f"Input path is not a directory: {input_dir}", file=sys.stderr)
        return 1

    try:
        providers = read_providers(input_dir, output_file, args.strict_provider_id)

        if not providers:
            print("No providers found.", file=sys.stderr)
            return 1

        # Interactive answers replace the command-line filters entirely.
        if args.interactive:
            selection = interactive_selection(providers)
        else:
            selection = (
                parse_csv(args.providers),
                parse_csv(args.operations),
                parse_provider_operations(args.provider_operations),
            )
        selected_providers, global_operations, provider_operations = selection

        merged_root = merge_providers(
            providers,
            selected_providers,
            global_operations,
            provider_operations,
        )
        write_output(merged_root, output_file)

    except Exception as exc:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    print(f"Created merged provider configuration: {output_file}")
    return 0
|
|
|
|
|
|
# Script entry point: propagate main()'s return value as the exit status.
if __name__ == "__main__":
    sys.exit(main())