"""Command-line interface for running different scenarios."""

from __future__ import annotations

import argparse
from collections.abc import Iterable
from dataclasses import replace
from pathlib import Path
from typing import Sequence

from .biases import privacy_vectors
from .plotting import (
    plot_case_bias,
    plot_case_bias_smallk,
    plot_case_gbias,
    plot_case_gbias_changesites,
    plot_case_r,
)
from .runner import run_scenario
from .scenarios import SCENARIOS, _constant_rs, _gbias_bias_strategy


def _case_r_plotter(case_name: str):
    """Bind ``plot_case_r`` to one case_r variant for the ``PLOTTERS`` registry.

    The returned callable accepts keyword arguments only, forwards them all to
    ``plot_case_r`` and always supplies ``case=case_name`` itself.
    """

    def _wrapped(**plot_options):
        # ``case`` is fixed by the closure; passing it again raises TypeError.
        return plot_case_r(case=case_name, **plot_options)

    return _wrapped


# Registry for the ``plot`` subcommand: CLI case name -> plotting callable.
# ``main`` invokes every entry with keyword arguments only.
PLOTTERS = dict(
    case_bias=plot_case_bias,
    case_bias_smallK=plot_case_bias_smallk,
    case_gbias=plot_case_gbias,
    case_gbias_changesites=plot_case_gbias_changesites,
    case_r=_case_r_plotter("case_r"),
    case_r_consvar=_case_r_plotter("case_r_consvar"),
)


def _format_numeric(value: float) -> str:
    text = f"{value:.6g}"
    if "." in text:
        text = text.rstrip("0").rstrip(".")
    return text


def _build_suffix(dist_type: str, taus: Sequence[float]) -> str:
    tau_part = "-".join(_format_numeric(tau) for tau in taus) if taus else "na"
    return f"{dist_type}_tau{tau_part}"


def _build_case_r_suffix(dist_type: str, tau: float | None, target_r: float | None) -> str:
    parts = [dist_type]
    if tau is not None:
        parts.append(f"tau{_format_numeric(tau)}")
    if target_r is not None:
        parts.append(f"targetR{_format_numeric(target_r)}")
    return "_".join(parts)


def _replace_r(rs_vec: Sequence[float], new_r: float) -> tuple[float, ...]:
    if not rs_vec:
        return tuple()
    return tuple(new_r for _ in rs_vec)


def _append_suffix(existing: str | None, addition: str) -> str:
    return addition if not existing else f"{existing}_{addition}"


def _tweak_case_gbias_changesites(spec, n_sites: int):
    if n_sites <= 0:
        raise ValueError("n_sites must be positive.")
    base_r = spec.rs_options[0][0] if spec.rs_options and spec.rs_options[0] else 0.5
    bias_strategy = _gbias_bias_strategy(n_sites)
    rs_options = _constant_rs(n_sites, base_r)
    return replace(spec, n_sites=n_sites, bias_strategy=bias_strategy, rs_options=rs_options)


def _output_name_with_suffix(
    base_name: str,
    dist_type: str,
    taus: Sequence[float],
    extra: str | None = None,
    is_case_r: bool = False,
    target_r: float | None = None,
    case_name: str | None = None,
    tau_values: Sequence[float] | None = None,
    r_value: float | None = None,
    n_sites_value: int | None = None,
) -> str:
    if is_case_r:
        tau_val = taus[0] if taus else None
        suffix = _build_case_r_suffix(dist_type, tau_val, target_r)
        prefix = case_name or Path(base_name).stem
        return f"{prefix}_{suffix}.csv"
    if case_name in {"case_bias", "case_bias_target", "case_gbias", "case_gbias_target", "case_gbias_changesites"}:
        tau_iter = tau_values or taus or ()
        if not isinstance(tau_iter, (list, tuple)):
            tau_iter = (tau_iter,)
        tau_suffix = "_".join(_format_numeric(t) for t in tau_iter)
        name = f"{case_name}_{dist_type}_tau{tau_suffix}"
        if r_value is not None:
            name += f"_r{_format_numeric(r_value)}"
        if n_sites_value is not None:
            name += f"_nsites{n_sites_value}"
        return f"{name}.csv"
    base_path = Path(base_name)
    suffix = _build_suffix(dist_type, taus)
    if extra:
        suffix = f"{suffix}_{extra}"
    return f"{base_path.stem}_{suffix}{base_path.suffix or '.csv'}"


def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser with its ``list``, ``run`` and ``plot`` subcommands."""
    parser = argparse.ArgumentParser(description="Run experiments.")
    subparsers = parser.add_subparsers(dest="command", required=True)

    list_parser = subparsers.add_parser("list", help="List available scenarios.")
    list_parser.add_argument("--verbose", action="store_true", help="Show descriptions.")

    run_parser = subparsers.add_parser("run", help="Execute one or more scenarios.")
    run_parser.add_argument("cases", nargs="+", choices=sorted(SCENARIOS.keys()))
    run_parser.add_argument(
        "--output-dir", default="output", help="Directory to store CSV files."
    )
    run_parser.add_argument(
        "--opt-lambda-scale",
        type=float,
        default=8.0,
        help="Scaling constant c in λ = c·log(N)/sqrt(N) for the opt method.",
    )
    run_parser.add_argument(
        "--simulations",
        type=int,
        default=None,
        help="Override the default number of Monte-Carlo replications.",
    )
    run_parser.add_argument(
        "--only",
        help="Run only the suite/method whose name matches this value (e.g., dpsgd).",
    )
    run_parser.add_argument(
        "--methods",
        nargs="+",
        default=None,
        help="Limit aggregation to these method names (e.g., opt mse).",
    )
    run_parser.add_argument(
        "--target-r",
        type=float,
        help="Override the target-site randomized-response rate (case_r variants).",
    )
    run_parser.add_argument(
        "--r",
        type=float,
        help="Override source-site randomized-response rate for constant-r scenarios.",
    )
    run_parser.add_argument(
        "--dist-type",
        help="Override the scenario distribution type (e.g., normal, laplace).",
    )
    run_parser.add_argument(
        "--tau",
        type=float,
        help="Override the scenario tau with a single value.",
    )
    run_parser.add_argument(
        "--n-sites",
        type=int,
        help="Override the number of participating sites (case_gbias_changesites).",
    )
    run_parser.add_argument("--ray-address", default=None, help="Optional Ray cluster address.")

    plot_parser = subparsers.add_parser("plot", help="Plot figures from generated CSVs.")
    plot_parser.add_argument("case", choices=sorted(PLOTTERS.keys()))
    plot_parser.add_argument(
        "--input-dir",
        default="output",
        help="Directory containing the aggregated CSVs.",
    )
    plot_parser.add_argument(
        "--output-dir",
        default="output",
        help="Directory to store generated figures.",
    )
    plot_parser.add_argument(
        "--tau",
        type=float,
        nargs="+",
        default=[0.25],
        help="Tau values to include (multiple allowed; first used for single-tau plots).",
    )
    plot_parser.add_argument(
        "--dist-type",
        default="normal",
        help="Distribution tag used when resolving data files (case_bias/case_gbias/case_r).",
    )
    plot_parser.add_argument(
        "--target-r",
        type=float,
        help="Target-site randomized-response rate (case_r variants).",
    )
    plot_parser.add_argument(
        "--r",
        type=float,
        help="Override all site randomized-response rates when plotting constant-r cases.",
    )
    plot_parser.add_argument(
        "--keep-samples",
        type=int,
        nargs="*",
        default=None,
        help="Optional list of target sample sizes to keep (default matches legacy plots).",
    )
    plot_parser.add_argument(
        "--figure-name",
        default=None,
        help="Override the default figure file name (defaults to <case>_<tau>_<n>.pdf).",
    )
    plot_parser.add_argument(
        "--n-sites",
        type=int,
        help="Number of sites (case_gbias_changesites plots).",
    )
    return parser


def _list_scenarios(verbose: bool) -> None:
    """Print every registered scenario name, with its description if *verbose*."""
    for name, spec in SCENARIOS.items():
        if verbose:
            print(f"{name:20s} - {spec.description}")
        else:
            print(name)


def _run_cases(args: argparse.Namespace) -> None:
    """Run each scenario in ``args.cases`` with the CLI overrides applied.

    Each completed case prints its output path.

    Raises:
        ValueError: if ``--n-sites`` is combined with any case other than
            ``case_gbias_changesites``. This is checked before any scenario
            starts, so an invalid invocation does not first spend time
            running the valid cases that precede it.
    """
    case_r_variants = {"case_r", "case_r_consvar"}
    # Cases whose CSV name is built here rather than left to run_scenario's default.
    named_output_cases = {
        "case_bias",
        "case_bias_target",
        "case_gbias",
        "case_gbias_target",
        "case_r",
        "case_r_consvar",
        "case_gbias_changesites",
    }

    if args.n_sites is not None and any(case != "case_gbias_changesites" for case in args.cases):
        raise ValueError("--n-sites override is only supported for case_gbias_changesites.")

    for case in args.cases:
        spec = SCENARIOS[case]
        # NOTE(review): every case in named_output_cases takes the case_r or
        # bias-family branch of _output_name_with_suffix, which ignore
        # ``extra``. Those names get their nsites/targetR/r parts from
        # n_sites_value/target_r/r_value instead.
        extra_suffix = None
        if args.dist_type:
            spec = replace(spec, dist_type=args.dist_type)
        if args.tau is not None:
            spec = replace(spec, taus=(args.tau,))
        if args.n_sites is not None:
            spec = _tweak_case_gbias_changesites(spec, args.n_sites)
            extra_suffix = _append_suffix(extra_suffix, f"nsites{args.n_sites}")
        if args.target_r is not None and case in case_r_variants:
            # Rebuild rs_options via privacy_vectors, reusing the scenario's
            # first/last source rates (index 1 of its first/last vectors)
            # when present. TODO confirm exact semantics in biases.privacy_vectors.
            rs_count = len(spec.rs_options) if spec.rs_options else 1
            existing_r_start = spec.rs_options[0][1] if spec.rs_options else 0.25
            existing_r_end = spec.rs_options[-1][1] if spec.rs_options else args.target_r
            rs_override = tuple(
                tuple(vec)
                for vec in privacy_vectors(
                    args.target_r,
                    existing_r_start,
                    existing_r_end,
                    rs_count,
                    spec.n_sites,
                )
            )
            spec = replace(spec, rs_options=rs_override)
            extra_suffix = _append_suffix(extra_suffix, f"targetR{_format_numeric(args.target_r)}")
        if args.r is not None and spec.rs_options:
            rs_override = tuple(_replace_r(vec, args.r) for vec in spec.rs_options)
            spec = replace(spec, rs_options=rs_override)
            # case_r variants are named by targetR, not by r.
            if case not in case_r_variants:
                extra_suffix = _append_suffix(extra_suffix, f"r{_format_numeric(args.r)}")
        output_name = None
        if case in named_output_cases:
            output_name = _output_name_with_suffix(
                spec.output_name,
                spec.dist_type,
                spec.taus,
                extra=extra_suffix,
                is_case_r=case in case_r_variants,
                target_r=args.target_r,
                case_name=case,
                tau_values=args.tau if case == "case_gbias" else None,
                r_value=args.r,
                n_sites_value=spec.n_sites if case == "case_gbias_changesites" else None,
            )
        path = run_scenario(
            spec,
            output_dir=args.output_dir,
            ray_address=args.ray_address,
            opt_lambda_scale=args.opt_lambda_scale,
            suite_filter=args.only,
            n_sim_override=args.simulations,
            method_filter=args.methods,
            output_name=output_name,
        )
        print(f"{case} completed → {path}")


def _default_figure_name(args: argparse.Namespace, keep_samples: list[int] | None) -> str:
    """Return the default ``<case>_[<dist>_]tau<t1>_<t2>...[...].pdf`` figure name.

    The optional components, in order, are:

    * the kept sample sizes;
    * the distribution prefix for the cases listed below;
    * ``_targetR<r>`` for case_r variants;
    * ``_nsites<n>`` for ``case_gbias_changesites``.
    """
    # Joining a single tau yields exactly ``tau<t>``, so one expression covers both cases.
    suffix = "tau" + "_".join(_format_numeric(t) for t in args.tau)
    if keep_samples:
        suffix += "_" + "_".join(str(s) for s in keep_samples)
    if args.case in {"case_bias", "case_bias_smallK", "case_gbias_changesites", "case_r", "case_r_consvar"}:
        suffix = f"{args.dist_type}_{suffix}"
    if args.case in {"case_r", "case_r_consvar"} and args.target_r is not None:
        suffix += f"_targetR{_format_numeric(args.target_r)}"
    if args.case == "case_gbias_changesites" and args.n_sites is not None:
        suffix += f"_nsites{args.n_sites}"
    return f"{args.case}_{suffix}.pdf"


def _plot_case(args: argparse.Namespace) -> None:
    """Render the figure for ``args.case`` through ``PLOTTERS`` and print its path."""
    plot_fn = PLOTTERS[args.case]
    # ``--keep-samples`` with no values yields []; treat that like "not given".
    keep_samples = args.keep_samples if args.keep_samples else None
    tau_values = args.tau
    figure_name = args.figure_name or _default_figure_name(args, keep_samples)
    plot_kwargs = {
        "input_dir": args.input_dir,
        "output_dir": args.output_dir,
        "tau": tau_values[0],
        "keep_samples": keep_samples,
        "figure_name": figure_name,
    }
    if args.case == "case_gbias" and len(tau_values) > 1:
        plot_kwargs["taus"] = tau_values
    if args.case in {"case_bias", "case_bias_smallK", "case_gbias", "case_gbias_changesites", "case_r", "case_r_consvar"}:
        plot_kwargs["dist_type"] = args.dist_type
        if args.r is not None:
            plot_kwargs["r_value"] = args.r
    if args.case == "case_gbias_changesites" and args.n_sites is not None:
        plot_kwargs["n_sites"] = args.n_sites
    if args.case in {"case_r", "case_r_consvar"} and args.target_r is not None:
        plot_kwargs["target_r"] = args.target_r
    out_path = plot_fn(**plot_kwargs)
    print(f"{args.case} figure saved → {out_path}")


def main() -> None:
    """CLI entry point: parse arguments and dispatch to list, run or plot."""
    args = _build_parser().parse_args()
    if args.command == "list":
        _list_scenarios(args.verbose)
    elif args.command == "run":
        _run_cases(args)
    elif args.command == "plot":
        _plot_case(args)


if __name__ == "__main__":
    # The relative imports at the top of this module mean this only works
    # via ``python -m <package>.<module>``, not as a plain script path.
    main()

