initial untested XML commit
This commit is contained in:
500
merge_providers.py
Normal file
500
merge_providers.py
Normal file
@@ -0,0 +1,500 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
@dataclass
class ProviderSource:
    """A single <provider> element plus the bookkeeping about its origin."""

    # Identifier resolved for this provider (attribute value or fallback).
    provider_id: str
    # Identifier implied by the file or folder name the element came from.
    expected_id: str
    # XML file the element was parsed from.
    path: Path
    # The parsed <provider> element itself.
    element: ET.Element
|
||||
|
||||
|
||||
def local_name(tag: str) -> str:
    """Strip a leading ``{namespace}`` qualifier from an ElementTree tag.

    Tags without a ``}`` are returned unchanged; otherwise everything after
    the last ``}`` is returned.
    """
    # rpartition yields ("", "", tag) when no "}" is present, so index 2 is
    # always the unqualified part.
    return tag.rpartition("}")[2]
|
||||
|
||||
|
||||
def indent_xml(element: ET.Element, level: int = 0) -> None:
    """Pretty-print *element* in place using one tab per nesting level.

    Only whitespace-only (or missing) ``text``/``tail`` values are replaced,
    so existing textual content is left untouched.

    Args:
        element: Element whose subtree is re-indented.
        level: Nesting depth of *element*; 0 for the document root.
    """

    def is_blank(text: str | None) -> bool:
        return not text or not text.strip()

    own_pad = "\n" + "\t" * level
    kids = list(element)

    if kids:
        # Whitespace before the first child sits one level deeper.
        if is_blank(element.text):
            element.text = own_pad + "\t"

        for kid in kids:
            indent_xml(kid, level + 1)

        # The closing tag of this element lines up with its opening tag.
        if is_blank(kids[-1].tail):
            kids[-1].tail = own_pad

    # The root (level 0) gets no trailing whitespace.
    if level and is_blank(element.tail):
        element.tail = own_pad
|
||||
|
||||
|
||||
def parse_csv(value: str | None) -> set[str] | None:
    """Split a comma-separated string into a set of trimmed, non-empty items.

    Returns:
        The set of items, or ``None`` when *value* is empty/``None`` or
        contains no non-blank item.
    """
    if not value:
        return None

    tokens: set[str] = set()
    for chunk in value.split(","):
        token = chunk.strip()
        if token:
            tokens.add(token)

    return tokens if tokens else None
|
||||
|
||||
|
||||
def parse_provider_operations(values: list[str] | None) -> dict[str, set[str]]:
    """Turn repeated ``provider_id=op1,op2`` strings into a lookup table.

    Example command-line usage:
        --provider-operations atmosfair=flight,airports
        --provider-operations myclimate=flight

    Args:
        values: Raw option values, or ``None``/empty when the option was
            not given.

    Returns:
        Mapping of provider id to its set of operation ids. A later entry
        for the same provider id replaces an earlier one.

    Raises:
        ValueError: If an entry has no ``=``, no provider id, or no
            operations.
    """
    filters: dict[str, set[str]] = {}

    for value in values or []:
        raw_id, separator, operations_raw = value.partition("=")

        if not separator:
            raise ValueError(
                f"Invalid provider operation filter '{value}'. "
                "Expected format: provider_id=operation1,operation2"
            )

        name = raw_id.strip()
        if not name:
            raise ValueError(f"Invalid provider operation filter '{value}': missing provider id.")

        wanted = parse_csv(operations_raw)
        if not wanted:
            raise ValueError(f"Invalid provider operation filter '{value}': missing operations.")

        filters[name] = wanted

    return filters
|
||||
|
||||
|
||||
def find_xml_files(input_dir: Path, output_file: Path | None = None) -> list[Path]:
    """Collect every ``*.xml`` file below *input_dir*, sorted by path.

    Args:
        input_dir: Directory that is searched recursively.
        output_file: Optional path of the merge result; if it lies inside
            *input_dir* it is excluded so a previous run's output is not
            merged into itself.

    Returns:
        Sorted list of regular files matching ``*.xml``.
    """
    excluded = output_file.resolve() if output_file is not None else None

    return sorted(
        path
        for path in input_dir.rglob("*.xml")
        # rglob matches on the name only, so a directory called "foo.xml"
        # would otherwise be handed to the XML parser.
        if path.is_file() and (excluded is None or path.resolve() != excluded)
    )
|
||||
|
||||
|
||||
def expected_provider_id_for_file(input_dir: Path, path: Path) -> str:
    """Derive the provider id implied by where *path* sits under *input_dir*.

    A file directly inside the input directory is named after the provider:
        providers/atmosfair.xml -> atmosfair

    A file in a sub-folder takes the name of the top-level folder:
        providers/atmosfair/provider.xml -> atmosfair
    """
    parts = path.relative_to(input_dir).parts
    return path.stem if len(parts) == 1 else parts[0]
|
||||
|
||||
|
||||
def find_provider_nodes(root: ET.Element) -> list[ET.Element]:
    """Return the <provider> elements of a parsed document.

    If *root* itself is a <provider> it is returned alone; otherwise every
    <provider> element anywhere in the tree is returned in document order.
    Namespaces are ignored when matching the tag.
    """

    def is_provider(node: ET.Element) -> bool:
        return local_name(node.tag) == "provider"

    if is_provider(root):
        return [root]

    return list(filter(is_provider, root.iter()))
|
||||
|
||||
|
||||
def provider_id(provider: ET.Element, fallback: str) -> str:
    """Resolve a provider's identifier and make sure the element carries it.

    The first non-empty value among the ``id``, ``name`` and ``provider``
    attributes is used; if none is set, *fallback* is used.

    Side effect: when the element has no ``id`` attribute, or an empty one,
    ``id`` is set to the resolved value so the serialized element agrees
    with the identifier returned here.

    Args:
        provider: The <provider> element to inspect (may be modified).
        fallback: Identifier to use when no attribute supplies one.

    Returns:
        The resolved identifier.
    """
    attributes = provider.attrib

    value = fallback
    for key in ("id", "name", "provider"):
        candidate = attributes.get(key)
        if candidate:
            value = candidate
            break

    # An empty id="" is skipped above, so it must be overwritten here too;
    # checking only for a missing key would leave id="" in the output.
    if not attributes.get("id"):
        attributes["id"] = value

    return value
|
||||
|
||||
|
||||
def read_providers(input_dir: Path, output_file: Path | None, strict_ids: bool) -> list[ProviderSource]:
    """Parse every XML file under *input_dir* and collect its providers.

    Files without a <provider> element are skipped with a warning on
    stderr. Files with several <provider> elements contribute all of them,
    also with a warning.

    Args:
        input_dir: Directory searched for provider XML files.
        output_file: Merge target, excluded from the search.
        strict_ids: When true, a provider whose id differs from the one
            implied by its file/folder name is an error instead of a warning.

    Returns:
        One ProviderSource per <provider> element found.

    Raises:
        RuntimeError: If a file is not well-formed XML, or on an id mismatch
            while *strict_ids* is set.
    """
    collected: list[ProviderSource] = []

    for path in find_xml_files(input_dir, output_file):
        expected_id = expected_provider_id_for_file(input_dir, path)

        try:
            document = ET.parse(path)
        except ET.ParseError as exc:
            raise RuntimeError(f"Could not parse XML file '{path}': {exc}") from exc

        nodes = find_provider_nodes(document.getroot())

        if not nodes:
            print(f"Warning: no <provider> node found in {path}", file=sys.stderr)
            continue

        if len(nodes) > 1:
            print(
                f"Warning: multiple <provider> nodes found in {path}; all will be considered.",
                file=sys.stderr,
            )

        for node in nodes:
            current_id = provider_id(node, expected_id)

            if current_id != expected_id:
                # A mismatch is fatal in strict mode, otherwise only reported.
                if strict_ids:
                    raise RuntimeError(
                        f"Provider id mismatch in '{path}': "
                        f"expected '{expected_id}', found '{current_id}'."
                    )
                print(
                    f"Warning: provider id mismatch in '{path}': "
                    f"folder/file suggests '{expected_id}', XML says '{current_id}'.",
                    file=sys.stderr,
                )

            collected.append(
                ProviderSource(
                    provider_id=current_id,
                    expected_id=expected_id,
                    path=path,
                    element=node,
                )
            )

    return collected
|
||||
|
||||
|
||||
def operation_id(operation: ET.Element) -> str | None:
    """Return the identifier of an <operation> element.

    The ``id`` attribute wins, then ``name``; if neither is non-empty the
    value of the ``operation`` attribute is returned as-is (which is
    ``None`` when that attribute is missing).
    """
    attributes = operation.attrib

    for key in ("id", "name"):
        found = attributes.get(key)
        if found:
            return found

    return attributes.get("operation")
|
||||
|
||||
|
||||
def find_operations_container(provider: ET.Element) -> ET.Element | None:
    """Return the first direct <operations> child of *provider*, if any.

    Only immediate children are examined; namespaces are ignored.
    """
    return next(
        (child for child in provider if local_name(child.tag) == "operations"),
        None,
    )
|
||||
|
||||
|
||||
def list_provider_operations(provider: ET.Element) -> list[str]:
    """List the ids of the operations declared by *provider*.

    Looks at the <operation> children of the provider's <operations>
    container, in document order. Operations without an identifier are left
    out; a provider without a container yields an empty list.
    """
    container = find_operations_container(provider)
    if container is None:
        return []

    identifiers = (
        operation_id(node)
        for node in container
        if local_name(node.tag) == "operation"
    )
    return [identifier for identifier in identifiers if identifier]
|
||||
|
||||
|
||||
def filter_provider_operations(provider: ET.Element, allowed_operations: set[str] | None) -> ET.Element:
    """Return a deep copy of *provider* restricted to the allowed operations.

    The input element is never modified.

    Args:
        provider: The <provider> element to copy.
        allowed_operations: Operation ids to keep, or ``None`` to keep
            everything. When a set is given, <operation> elements whose id
            is not in it (including those with no id) are removed from the
            copy's <operations> container; other children are kept.

    Returns:
        The filtered copy.
    """
    clone = copy.deepcopy(provider)

    container = None
    if allowed_operations is not None:
        container = find_operations_container(clone)

    if container is not None:
        # Decide first, remove afterwards, so the container is not mutated
        # while it is being iterated.
        rejected = [
            node
            for node in container
            if local_name(node.tag) == "operation"
            and operation_id(node) not in allowed_operations
        ]
        for node in rejected:
            container.remove(node)

    return clone
|
||||
|
||||
|
||||
def ask_selection(
    label: str,
    options: list[str],
    default_all: bool = True,
) -> set[str] | None:
    """Print a numbered menu and read a selection from standard input.

    The answer is a comma-separated list in which every item is either a
    1-based menu number or the literal option name. Blank items (as in
    ``"1,,2"`` or a trailing comma) are ignored.

    Args:
        label: Heading printed above the menu.
        options: Selectable values, shown in the given order.
        default_all: Meaning of an empty answer: ``True`` selects
            everything, ``False`` selects nothing.

    Returns:
        ``None`` when the default "everything" applies (callers treat this
        as "no filter"), otherwise the set of chosen options. With no
        options to choose from, nothing is printed or read and the default
        is returned directly.

    Raises:
        ValueError: If a number is out of range or a name is not an option.
    """
    # Single definition of "the user accepted the default".
    default_result: set[str] | None = None if default_all else set()

    if not options:
        # Nothing to choose from: an empty set here would mean "filter
        # everything out", which is not what default_all=True promises.
        return default_result

    print()
    print(label)
    for index, option in enumerate(options, start=1):
        print(f" {index}) {option}")

    if default_all:
        prompt = "Selection [Enter = all, comma-separated numbers or names]: "
    else:
        prompt = "Selection [Enter = none, comma-separated numbers or names]: "

    raw = input(prompt).strip()

    # Drop blank items so stray commas do not count as invalid names.
    parts = [part.strip() for part in raw.split(",") if part.strip()]
    if not parts:
        return default_result

    selected: set[str] = set()

    for part in parts:
        if part.isdigit():
            index = int(part)
            if index < 1 or index > len(options):
                raise ValueError(f"Invalid selection number: {index}")
            selected.add(options[index - 1])
        else:
            if part not in options:
                raise ValueError(f"Invalid selection value: {part}")
            selected.add(part)

    return selected
|
||||
|
||||
|
||||
def interactive_selection(
    providers: list[ProviderSource],
) -> tuple[set[str] | None, set[str] | None, dict[str, set[str]]]:
    """Ask the user which providers and operations to merge.

    Three questions are asked in order: which providers, which operations
    globally, and (optionally) which operations per provider.

    Returns:
        ``(selected_providers, global_operations, provider_operations)``.
        The first two are ``None`` when the user accepted "all"; the
        mapping only contains providers for which an explicit choice was
        made.
    """
    known_ids = sorted({source.provider_id for source in providers})

    selected_providers = ask_selection(
        "Available providers:",
        known_ids,
        default_all=True,
    )

    # "All" (None) means every known provider takes part below.
    effective_providers = selected_providers or set(known_ids)

    offered: set[str] = set()
    for source in providers:
        if source.provider_id in effective_providers:
            offered.update(list_provider_operations(source.element))

    global_operations = ask_selection(
        "Available operations:",
        sorted(offered),
        default_all=True,
    )

    provider_operations: dict[str, set[str]] = {}

    print()
    answer = input("Define provider-specific operation filters? [y/N]: ").strip().lower()

    if answer in {"y", "yes"}:
        for chosen_id in sorted(effective_providers):
            # The first source carrying this id supplies the operation list.
            source = next(p for p in providers if p.provider_id == chosen_id)
            ops = sorted(list_provider_operations(source.element))

            if not ops:
                continue

            selected_ops = ask_selection(
                f"Operations for provider '{chosen_id}':",
                ops,
                default_all=True,
            )

            # None means "all": leave this provider to the global filter.
            if selected_ops is not None:
                provider_operations[chosen_id] = selected_ops

    return selected_providers, global_operations, provider_operations
|
||||
|
||||
|
||||
def merge_providers(
    providers: list[ProviderSource],
    selected_providers: set[str] | None,
    global_operations: set[str] | None,
    provider_operations: dict[str, set[str]],
) -> ET.Element:
    """Build the merged <providers> tree from the selected sources.

    Args:
        providers: All provider sources that were read.
        selected_providers: Ids to include, or ``None`` for all.
        global_operations: Operation ids to keep for providers without a
            specific filter, or ``None`` for all.
        provider_operations: Per-provider operation filters; an entry here
            takes precedence over *global_operations*.

    Returns:
        A new ``providers`` element holding filtered copies of the selected
        providers, in input order.

    Raises:
        RuntimeError: If two selected sources share a provider id.
    """
    merged = ET.Element("providers")
    emitted: set[str] = set()

    for provider in providers:
        current = provider.provider_id

        if selected_providers is not None and current not in selected_providers:
            continue

        if current in emitted:
            raise RuntimeError(
                f"Duplicate provider id '{provider.provider_id}' after selection. "
                "Provider ids must be unique in the merged output."
            )
        emitted.add(current)

        if current in provider_operations:
            allowed = provider_operations[current]
        else:
            allowed = global_operations

        merged.append(filter_provider_operations(provider.element, allowed))

    return merged
|
||||
|
||||
|
||||
def write_output(root: ET.Element, output_file: Path) -> None:
    """Serialize *root* to *output_file* as indented UTF-8 XML.

    Missing parent directories are created. *root* is re-indented in place
    before writing, and the providers schema URI is registered as the
    default (prefix-less) namespace for serialization.
    """
    target_dir = output_file.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    indent_xml(root)

    # NOTE(review): register_namespace changes ElementTree's module-wide
    # prefix table, so it also affects later serialization in this process.
    ET.register_namespace("", "https://calco2la.to/schema/providers/v1")

    document = ET.ElementTree(root)
    document.write(output_file, encoding="utf-8", xml_declaration=True)
|
||||
|
||||
|
||||
def build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the merge script.

    Defines the positional ``input_dir`` plus the options ``-o/--output``,
    ``--providers``, ``--operations``, ``--provider-operations``
    (repeatable), ``-i/--interactive`` and ``--strict-provider-id``.
    """
    parser = argparse.ArgumentParser(
        description="Merge single-provider XML descriptions into one provider.xml file."
    )
    register = parser.add_argument

    # Where to read from and where to write to.
    register(
        "input_dir",
        type=Path,
        help="Directory containing provider XML files or provider folders.",
    )
    register(
        "-o",
        "--output",
        type=Path,
        default=Path("provider.xml"),
        help="Output file. Default: provider.xml",
    )

    # Non-interactive filters.
    register(
        "--providers",
        help="Comma-separated provider ids to merge. Default: all providers.",
    )
    register(
        "--operations",
        help="Comma-separated global operation ids to keep. Default: all operations.",
    )
    register(
        "--provider-operations",
        action="append",
        help=(
            "Provider-specific operation filter. "
            "Format: provider_id=operation1,operation2. "
            "Can be used multiple times."
        ),
    )

    # Behaviour switches.
    register(
        "-i",
        "--interactive",
        action="store_true",
        help="Interactively select providers and operations.",
    )
    register(
        "--strict-provider-id",
        action="store_true",
        help="Fail if folder/file name and XML provider id do not match.",
    )

    return parser
|
||||
|
||||
|
||||
def main() -> int:
    """Run the merge from the command line.

    Returns:
        0 on success; 1 if the input directory is missing or not a
        directory, no providers were found, or any step raised an exception
        (the message is printed to stderr).
    """
    args = build_arg_parser().parse_args()

    input_dir: Path = args.input_dir
    output_file: Path = args.output

    # Reject an unusable input location before touching any files.
    if not input_dir.exists():
        print(f"Input directory does not exist: {input_dir}", file=sys.stderr)
        return 1
    if not input_dir.is_dir():
        print(f"Input path is not a directory: {input_dir}", file=sys.stderr)
        return 1

    try:
        providers = read_providers(
            input_dir=input_dir,
            output_file=output_file,
            strict_ids=args.strict_provider_id,
        )
        if not providers:
            print("No providers found.", file=sys.stderr)
            return 1

        # Interactive mode replaces the command-line filters entirely.
        if args.interactive:
            selection = interactive_selection(providers)
        else:
            selection = (
                parse_csv(args.providers),
                parse_csv(args.operations),
                parse_provider_operations(args.provider_operations),
            )
        selected_providers, global_operations, provider_operations = selection

        merged_root = merge_providers(
            providers=providers,
            selected_providers=selected_providers,
            global_operations=global_operations,
            provider_operations=provider_operations,
        )
        write_output(merged_root, output_file)
    except Exception as exc:
        # Top-level boundary: report the failure instead of a traceback.
        print(f"Error: {exc}", file=sys.stderr)
        return 1

    print(f"Created merged provider configuration: {output_file}")
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|
||||
Reference in New Issue
Block a user