"""Generic scenario runner built on top of Ray."""

from __future__ import annotations

import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence

import numpy as np
import pandas as pd
import ray

from .data import generate_federated_streams, proportional_chains, sample_allocations
from .dp import DPQuantile
from .metrics import interval_metrics
from .suite import MethodSpec, ScenarioSpec, SuiteSpec
from .transfer import TransferDPQuantile


@dataclass(frozen=True)
class CaseCombination:
    """One cell of the scenario parameter grid built by ``_generate_combinations``.

    Instances are immutable (``frozen=True``) and are handed to ``_run_suite``
    and ``_build_row`` for simulation and result reporting.
    """

    n_samples: int  # target sample size; also the key into ``spec.chain_lookup``
    source_prop: float  # passed to ``sample_allocations`` together with n_samples
    tau: float  # passed as ``tau`` to the quantile estimators and data generator
    biases: tuple[float, ...]  # one entry per site (length checked against n_sites)
    bias_id: int  # 1-based position of ``biases`` in the bias strategy's output
    rs: tuple[float, ...]  # one entry per site; rs[0] is used by the target-only model
    rs_id: int  # 1-based position of ``rs`` in ``spec.rs_options``
    k_base: int  # base chain count, scaled per site by ``proportional_chains``


def ensure_ray(ray_address: str | None = None) -> None:
    """Initialise Ray for this process unless it is already running.

    Args:
        ray_address: Cluster address handed to ``ray.init(address=...)``. When
            empty or ``None``, ``ray.init`` is called without an address.
    """
    if not ray.is_initialized():
        init_kwargs = {"ignore_reinit_error": True}
        if ray_address:
            init_kwargs["address"] = ray_address
        ray.init(**init_kwargs)


def _build_method_index(spec: ScenarioSpec) -> Dict[str, MethodSpec]:
    """Index every method of every suite in *spec* by its name.

    Names must be unique across all suites, because per-method results are
    later merged and looked up by name alone.

    Raises:
        ValueError: If two methods, in the same or different suites, share a name.
    """
    index: Dict[str, MethodSpec] = {}
    every_method = (m for s in spec.suites for m in s.methods)
    for method_spec in every_method:
        key = method_spec.name
        if key in index:
            raise ValueError(f"Duplicate method name detected: {key}")
        index[key] = method_spec
    return index


def _generate_combinations(spec: ScenarioSpec) -> List[CaseCombination]:
    """Expand *spec* into the full grid of simulation cases.

    For each target sample size the grid is the nested product of the bias
    vectors returned by ``spec.bias_strategy(n)``, ``spec.source_props``,
    ``spec.taus`` and ``spec.rs_options``, in that order. ``bias_id`` and
    ``rs_id`` are 1-based positions within their respective lists. If the bias
    strategy returns ``None`` or an empty iterable, a single all-zero bias
    vector is used.

    Raises:
        ValueError: If an rs vector or a bias vector does not have
            ``spec.n_sites`` entries, or ``spec.chain_lookup`` has no chain
            count for one of the target sample sizes.
    """
    n_sites = spec.n_sites

    # Validate and normalise the rs vectors once rather than once per grid cell.
    rs_table: List[tuple[float, ...]] = []
    for rs in spec.rs_options:
        if len(rs) != n_sites:
            raise ValueError("Each rs vector must match n_sites.")
        rs_table.append(tuple(float(r) for r in rs))

    combos: List[CaseCombination] = []
    for n_samples in spec.target_sample_sizes:
        n_int = int(n_samples)
        # The chain count depends only on n, so look it up once per sample size.
        k_base = spec.chain_lookup.get(n_int)
        if k_base is None:
            raise ValueError(f"No chain count defined for n={n_samples}")

        raw_biases = spec.bias_strategy(n_samples)
        # list(...) accepts any iterable (e.g. a 2-D NumPy array); truth-testing
        # such an array directly would raise "truth value is ambiguous".
        bias_list = [] if raw_biases is None else list(raw_biases)
        if not bias_list:
            bias_list = [[0.0] * n_sites]

        for bias_id, biases in enumerate(bias_list, 1):
            if len(biases) != n_sites:
                raise ValueError("Bias vector length must equal n_sites.")
            bias_tuple = tuple(float(b) for b in biases)
            for source_prop in spec.source_props:
                for tau in spec.taus:
                    for rs_id, rs_tuple in enumerate(rs_table, 1):
                        combos.append(
                            CaseCombination(
                                n_samples=n_int,
                                source_prop=float(source_prop),
                                tau=float(tau),
                                biases=bias_tuple,
                                bias_id=bias_id,
                                rs=rs_tuple,
                                rs_id=rs_id,
                                k_base=int(k_base),
                            )
                        )
    return combos


def run_scenario(
    spec: ScenarioSpec,
    output_dir: str | Path,
    ray_address: str | None = None,
    opt_lambda_scale: float = 8.0,
    suite_filter: str | None = None,
    n_sim_override: int | None = None,
    method_filter: list[str] | None = None,
    output_name: str | None = None,
) -> Path:
    """Simulate every parameter combination of *spec* and write the results to CSV.

    For each combination from ``_generate_combinations``, every suite in
    ``spec.suites`` (or only the one named ``suite_filter``) is run via
    ``_run_suite``. The per-method estimates are then reduced to one summary row
    by ``_build_row``. All rows are merged with any existing file at the output
    path (rows with matching keys are replaced) and written once at the end.

    Args:
        spec: Scenario definition.
        output_dir: Directory for the CSV; created if missing.
        ray_address: Ray cluster address forwarded to ``ensure_ray``.
        opt_lambda_scale: Multiplier for the data-driven lambda of methods that
            select lambda automatically.
        suite_filter: If given, only the suite with this name is run.
        n_sim_override: Replica count overriding ``spec.n_sim``.
        method_filter: If non-empty, only methods with these names are run.
        output_name: File name overriding ``spec.output_name``.

    Returns:
        The resolved path of the written CSV file.

    Raises:
        ValueError: If ``suite_filter`` names no suite of *spec*, or the scenario
            produces no parameter combinations.
        RuntimeError: If no selected suite returned estimates for a combination.
    """
    # Reject an unknown suite name before any work starts. Otherwise every suite
    # would be skipped and the run would fail with a misleading
    # "No suite returned any estimates." after Ray has been started.
    if suite_filter and all(suite.name != suite_filter for suite in spec.suites):
        available = ", ".join(str(suite.name) for suite in spec.suites)
        raise ValueError(f"Unknown suite {suite_filter!r}; available suites: {available}")
    ensure_ray(ray_address)
    combos = _generate_combinations(spec)
    total = len(combos)
    if total == 0:
        raise ValueError("Scenario produced no parameter combinations.")
    method_specs = _build_method_index(spec)
    output_filename = output_name or spec.output_name
    output_path = Path(output_dir).resolve() / output_filename
    output_path.parent.mkdir(parents=True, exist_ok=True)
    rows: List[Dict[str, float]] = []
    start = time.perf_counter()
    for idx, combo in enumerate(combos, 1):
        combo_start = time.perf_counter()
        true_qs = None
        method_data: Dict[str, Dict[str, np.ndarray | None]] = {}
        for suite in spec.suites:
            if suite_filter and suite.name != suite_filter:
                continue
            suite_true_qs, suite_methods = _run_suite(
                spec,
                combo,
                suite,
                opt_lambda_scale=opt_lambda_scale,
                n_sim_override=n_sim_override,
                method_filter=method_filter,
            )
            # Keep the true quantiles reported by the first suite that ran.
            if true_qs is None and suite_true_qs is not None:
                true_qs = suite_true_qs
            method_data.update(suite_methods)
        if true_qs is None:
            raise RuntimeError("No suite returned any estimates.")
        rows.append(_build_row(spec, combo, true_qs, method_data, method_specs))
        elapsed = time.perf_counter() - combo_start
        # ETA: mean wall time per finished combination times the number left.
        remaining = (time.perf_counter() - start) / idx * (total - idx)
        bias_msg = f"bias#{combo.bias_id}" if "bias_id" in spec.row_fields else ""
        rs_msg = f"rs_id={combo.rs_id}" if "rs" in spec.row_fields else ""
        print(
            f"[{idx:03d}/{total}] n={combo.n_samples:<6} τ={combo.tau:<4} "
            f"src={combo.source_prop:<4} {bias_msg} {rs_msg} "
            f"{elapsed:6.1f}s left≈{remaining/60:5.1f} min"
        )
    # Results are written only once, after every combination has finished.
    final_df = _merge_with_existing(output_path, rows)
    final_df.to_csv(output_path, index=False)
    elapsed_total = time.perf_counter() - start
    print(f"{spec.name} finished in {elapsed_total/60:5.1f} min")
    return output_path
def _build_row(
    spec: ScenarioSpec,
    combo: CaseCombination,
    true_qs: np.ndarray,
    method_data: Dict[str, Dict[str, np.ndarray | None]],
    method_specs: Dict[str, MethodSpec],
) -> Dict[str, float]:
    """Summarise one parameter combination as a flat result row.

    The row always starts with ``n_samples``, ``source_prop`` and ``tau``. It
    then contains ``bias_id`` and/or ``rs`` when listed in ``spec.row_fields``
    (in that list's order; other field names are ignored). Finally, for each
    method it holds every metric from ``interval_metrics``, keyed
    ``"<method>_<metric>"`` and scored against ``true_qs[0]``.
    """
    optional_fields = {
        "bias_id": lambda: combo.bias_id,
        "rs": lambda: tuple(combo.rs),
    }
    row: Dict[str, float] = dict(
        n_samples=combo.n_samples,
        source_prop=combo.source_prop,
        tau=combo.tau,
    )
    for field_name in spec.row_fields:
        getter = optional_fields.get(field_name)
        if getter is not None:
            row[field_name] = getter()
    reference = float(true_qs[0])
    for method_name, payload in method_data.items():
        z_value = method_specs[method_name].z_score
        summary = interval_metrics(payload["estimates"], payload["variances"], reference, z_value)
        row.update({f"{method_name}_{key}": value for key, value in summary.items()})
    return row


def _run_suite(
    spec: ScenarioSpec,
    combo: CaseCombination,
    suite: SuiteSpec,
    opt_lambda_scale: float,
    n_sim_override: int | None = None,
    method_filter: list[str] | None = None,
) -> tuple[np.ndarray | None, Dict[str, Dict[str, np.ndarray | None]]]:
    """Simulate one suite for one parameter combination and collect per-method estimates.

    Replicas run in parallel as Ray tasks; their count is ``n_sim_override`` if
    given, else ``spec.n_sim``. Replica ``i`` seeds NumPy's global RNG with
    ``spec.base_seed + i`` before generating its data. It reseeds with
    ``target_seed + i`` (``target_seed`` falls back to ``spec.base_seed``)
    immediately before fitting each ``"target"``-strategy method.

    Args:
        spec: Scenario definition (sites, distribution type, seeds, replica count).
        combo: The parameter combination being simulated.
        suite: Suite supplying the mechanism, burn-in ratio, ``lr`` settings,
            lambda grid and methods.
        opt_lambda_scale: Multiplier for the data-driven lambda used by methods
            with ``select_lambda`` set.
        n_sim_override: Replica count overriding ``spec.n_sim``.
        method_filter: If non-empty, only methods whose names are listed run.

    Returns:
        A pair ``(true_qs, method_data)``:

        * ``true_qs`` is the value reported by the first replica.
        * ``method_data`` maps each selected method name to a dict with keys
          ``"estimates"``, ``"variances"`` and ``"weights"``.
        * Estimates and variances hold one entry per replica.
        * Weights are stacked row-wise with ``np.vstack``, or are ``None`` when
          the first replica reported no weights.

        ``(None, {})`` is returned when the suite has no methods or
        ``method_filter`` excludes all of them.

    Raises:
        RuntimeError: If no replicas were run (replica count of 0).
    """
    if not suite.methods:
        return None, {}
    # Fallback lambda for methods that neither select lambda nor fix lambda_value.
    lambda_grid = tuple(float(x) for x in suite.lambda_grid) or (1.0,)
    selected_methods = [
        method for method in suite.methods if (not method_filter or method.name in method_filter)
    ]
    if not selected_methods:
        return None, {}
    method_names = [method.name for method in selected_methods]
    base_seed = spec.base_seed
    target_seed_base = spec.target_seed if spec.target_seed is not None else spec.base_seed

    # A new remote function is defined on every call. It closes over spec, combo,
    # suite, selected_methods, lambda_grid and opt_lambda_scale.
    @ray.remote
    def _replica(main_seed: int, target_seed: int):
        # NOTE(review): statement order below fixes how the global RNG is
        # consumed; reordering would change simulated results.
        np.random.seed(main_seed)
        sample_sizes = sample_allocations(spec.n_sites, combo.n_samples, combo.source_prop)
        datas, true_qs = generate_federated_streams(
            spec.dist_type, combo.tau, sample_sizes, combo.biases
        )
        K_list = proportional_chains(sample_sizes, combo.k_base)
        method_results = {}
        transfer_model = None
        # The transfer model is fitted only if some selected method needs it.
        requires_transfer = any(method.strategy != "target" for method in selected_methods)
        if requires_transfer:
            transfer_model = TransferDPQuantile(
                K_list=K_list,
                rs=combo.rs,
                tau=combo.tau,
                mechanism=suite.mechanism,
                burn_in_ratio=suite.burn_in_ratio,
                c0=suite.lr.c0,
                a=suite.lr.a,
                b0=suite.lr.b0,
                true_q=true_qs[0],
            )
            transfer_model.fit(datas)
        for method in selected_methods:
            if method.strategy == "target":
                # Reseed so the target-only fit does not depend on how many draws
                # the transfer fit consumed (presuming DPQuantile uses the
                # global NumPy RNG).
                np.random.seed(target_seed)
                target_model = DPQuantile(
                    tau=combo.tau,
                    r=combo.rs[0],
                    mechanism=suite.mechanism,
                    true_q=true_qs[0],
                    burn_in_ratio=suite.burn_in_ratio,
                )
                # Only the first stream is used; presumably datas[0] is the target site.
                target_model.fit(datas[0])
                method_results[method.name] = {
                    "estimate": target_model.Q_avg,
                    "variance": target_model.get_variance(),
                    "weights": None,
                }
            else:
                # Guaranteed by requires_transfer above.
                assert transfer_model is not None
                if method.select_lambda:
                    # Data-driven lambda: scale * log(n) / sqrt(n), with n clamped to >= 1.
                    n_target = max(combo.n_samples, 1)
                    lam = opt_lambda_scale * np.log(n_target) / np.sqrt(n_target)
                else:
                    lam = method.lambda_value if method.lambda_value is not None else lambda_grid[0]
                weights, estimate, variance = transfer_model.aggregate(lam, method.strategy)
                method_results[method.name] = {
                    "estimate": estimate,
                    "variance": variance,
                    "weights": weights,
                }
        return {"true_qs": true_qs, "methods": method_results}

    n_simu = n_sim_override if n_sim_override is not None else spec.n_sim
    # Replica i uses seeds (base_seed + i, target_seed_base + i).
    futures = [
        _replica.remote(base_seed + i, target_seed_base + i) for i in range(n_simu)
    ]
    records = ray.get(futures)
    if not records:
        raise RuntimeError("No simulation records returned.")
    # NOTE(review): assumes true_qs is identical across replicas; only the first is kept.
    true_qs = np.asarray(records[0]["true_qs"], dtype=float)
    method_data: Dict[str, Dict[str, np.ndarray | None]] = {}
    for name in method_names:
        estimates = np.array([rec["methods"][name]["estimate"] for rec in records])
        variances = np.array([rec["methods"][name]["variance"] for rec in records])
        weights_samples = [rec["methods"][name]["weights"] for rec in records]
        # Assumes every replica reports weights of the same kind as the first
        # (all None for target methods, stackable arrays otherwise).
        if weights_samples[0] is None:
            weights = None
        else:
            weights = np.vstack(weights_samples)
        method_data[name] = {"estimates": estimates, "variances": variances, "weights": weights}
    return true_qs, method_data


def _merge_with_existing(output_path: Path, new_rows: List[Dict[str, float]]) -> pd.DataFrame:
    """Combine freshly computed rows with results already stored at *output_path*.

    Rows are keyed by ``n_samples``, ``source_prop`` and ``tau``. ``bias_id``
    and ``rs`` are added to the key when those columns appear in the new rows.
    Existing rows whose key matches a new row are replaced by the new row; all
    other existing rows are kept. The result is sorted by the key columns and
    re-indexed from 0.

    Args:
        output_path: CSV file holding earlier results; it may not exist yet.
        new_rows: Row dicts produced by the current run.

    Returns:
        The merged DataFrame. It is not written to disk here.
    """
    new_df = pd.DataFrame(new_rows)
    key_cols = ["n_samples", "source_prop", "tau"]
    if "bias_id" in new_df.columns:
        key_cols.append("bias_id")
    if "rs" in new_df.columns:
        key_cols.append("rs")
        # ``to_csv`` stores tuple cells as their ``str`` form, so rows reloaded
        # from disk carry strings like "(1.0, 0.5)". Compare and sort on that same
        # text form. Otherwise reloaded rows never match new ones (duplicates pile
        # up), and sorting a column mixing str and tuple raises TypeError.
        new_df["rs"] = new_df["rs"].map(
            lambda value: str(tuple(value)) if isinstance(value, (tuple, list)) else value
        )
    if output_path.exists():
        old_df = pd.read_csv(output_path)
        flagged = old_df.merge(
            new_df[key_cols].drop_duplicates(), on=key_cols, how="left", indicator=True
        )
        replaced = flagged["_merge"] == "both"
        n_replaced = int(replaced.sum())
        if n_replaced:
            print(
                f"Warning: overwriting {n_replaced} rows in {output_path.name}; "
                "keeping latest results."
            )
        keep = flagged[~replaced].drop(columns="_merge")
        combined = pd.concat([keep, new_df], ignore_index=True, sort=False)
    else:
        combined = new_df
    return combined.sort_values(key_cols).reset_index(drop=True)

