import typer
from typing import List
from rich.console import Console
from ..config.models import TransitivityConfig
from ..config import load_config_with_args
from ..pipeline.transitivity_pipeline import transitivity_pipeline
# Typer application that the transitivity commands below are registered on.
transitivity_app = typer.Typer()
# Shared rich console used for all status output of this module.
console = Console()
@transitivity_app.command("run", context_settings={"ignore_unknown_options": True, "allow_extra_args": True})
def run_transitivity(
    ctx: typer.Context,
    set_: List[str] = typer.Option([], "--set", "-s",
        help="Override via dot-keys, e.g. -s wandb.project=my_project"),
) -> None:
    """Run the transitivity pipeline with config overrides from the command line.

    Overrides can be given with --set/-s (e.g. -s wandb.project=my_project) or
    as extra KEY=VALUE arguments; leading dashes on extra arguments are
    stripped, so --wandb.project=my_project works too. Extra arguments that do
    not contain '=' cannot be read as overrides; they are reported and ignored.
    """
    # The context settings on the decorator make unknown options and extra
    # positionals end up in ctx.args instead of aborting the command.
    extra_overrides: List[str] = []
    ignored_args: List[str] = []
    for arg in ctx.args:
        if "=" in arg:
            # Only leading dashes are removed; dashes inside the value stay.
            extra_overrides.append(arg.lstrip("-"))
        else:
            ignored_args.append(arg)

    # Surface arguments that would otherwise be dropped silently, e.g. the
    # space-separated form "--wandb.project my_project".
    if ignored_args:
        console.print(
            f"[yellow]Ignoring extra arguments without '=':[/yellow] {ignored_args}"
        )

    # Explicit --set values come first, followed by the free-form extras.
    all_overrides = list(set_) + extra_overrides
    console.print(f"[cyan]Collected overrides:[/cyan] {all_overrides}")

    # NOTE(review): prefix="APP" is presumably an environment-variable prefix;
    # confirm against load_config_with_args.
    config = load_config_with_args(TransitivityConfig, all_overrides, prefix="APP")
    # NOTE(review): presumably copies shared settings onto the per-case run
    # configs; confirm against TransitivityConfig.
    config.share_settings_to_runs()

    console.print(f"[green]Test datasets:[/green] {config.dataset.test_dataset_list}")
    console.print(f"[green]Number of cases:[/green] {len(config.cases)}")

    # Summarise each case before running: a and b are the source and target of
    # run_ab, c is the target of run_bc.
    for case_number, case in enumerate(config.cases, start=1):
        a = case.run_ab.model.source_model
        b = case.run_ab.model.target_model
        c = case.run_bc.model.target_model
        console.print(f"  Case {case_number}: {a} → {b} → {c} vs {a} → {c}")

    transitivity_pipeline(config)
# Script entry point: invoke the Typer app when this module is executed as
# __main__. The package-relative imports at the top of the file mean it has to
# be launched as a module (python -m ...) rather than as a standalone file.
if __name__ == "__main__":
    transitivity_app()
