import argparse
import copy
import os
import random
import yaml
from typing import Any, Callable, Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

try:
    from tqdm.auto import tqdm
except ImportError:  # pragma: no cover
    # tqdm is optional. Without it, fall back to a no-op wrapper that returns
    # the iterable unchanged and ignores progress-bar arguments (desc=, unit=, ...).
    # Only ImportError is caught so genuine failures inside tqdm still surface.
    def tqdm(x, *args, **kwargs):
        return x
from experiments.deployment_gate.dataset_loader import get_dataloaders
from experiments.deployment_gate.model_factory import get_resnet50, load_checkpoint
from experiments.deployment_gate.training import evaluate, train_candidate_models, train_source_model

# New and existing metric imports
from experiments.deployment_gate.metrics.trace import calculate_trace_bound
from experiments.deployment_gate.metrics.disagreement import compute_entropy, compute_symmetric_kl
from experiments.deployment_gate.metrics.divergences import calculate_mmd
from experiments.deployment_gate.metrics.domain_classifier import train_domain_classifier_and_iw
from experiments.deployment_gate.metrics.instrumentation import run_with_cost
from experiments.deployment_gate.metrics.mahalanobis import compute_mahalanobis
from experiments.deployment_gate.metrics.ood_scores import calculate_ood_scores
from experiments.deployment_gate.metrics.model_change import compute_outdisc_metrics, compute_ece_shift
from experiments.deployment_gate.metrics.shift import compute_shift_metrics


def _load_config(path: str) -> Dict[str, Any]:
    """Load the experiment configuration from a YAML file.

    Args:
        path: Filesystem path to the YAML config.

    Returns:
        The parsed top-level mapping. An empty file yields ``{}`` instead of
        ``None`` so callers can always treat the result as a dict.

    Raises:
        TypeError: If the document's top level is not a mapping.
    """
    # Explicit encoding: do not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise TypeError(
            f"Config {path!r} must contain a YAML mapping at top level, "
            f"got {type(config).__name__}"
        )
    return config


def _get_feature_extractor(model: nn.Module) -> Callable:
    if isinstance(model, nn.DataParallel):
        model = model.module
    backbone = nn.Sequential(*list(model.children())[:-1])

    def extractor(loader):
        backbone.eval()
        feats = []
        with torch.no_grad():
            for images, _ in loader:
                images = images.to(next(backbone.parameters()).device)
                f = backbone(images)
                f = torch.flatten(f, 1)
                feats.append(f.detach())
        return torch.cat(feats, dim=0) if feats else torch.empty(0)

    return extractor


def _log(msg: str) -> None:
    print(msg, flush=True)


def _seed_everything(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs with ``seed``.

    ``torch.manual_seed`` seeds all devices, including CUDA. This fixes the RNG
    streams only; it does not force deterministic cuDNN kernels.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _maybe_data_parallel(model: nn.Module, device: str) -> nn.Module:
    """Wrap ``model`` in ``nn.DataParallel`` on multi-GPU CUDA runs.

    Models that are already wrapped are returned unchanged, to avoid nesting
    DataParallel inside DataParallel.
    """
    if (
        device.startswith("cuda")
        and torch.cuda.device_count() > 1
        and not isinstance(model, nn.DataParallel)
    ):
        return nn.DataParallel(model)
    return model


def _add_costs(row: Dict[str, Any], costs: Dict[str, Any], tag: str) -> None:
    """Copy instrumentation costs into ``row`` as ``cost_<key>_<tag>`` columns."""
    for key, value in costs.items():
        row[f"cost_{key}_{tag}"] = value


def _trace_variant_columns(
    trace_result: Dict[str, Any],
    outdisc: Dict[str, Any],
    shift_multi: Dict[str, Any],
) -> Dict[str, float]:
    """Compute TRACE variants of the form ``outdisc_l2 + lips_t * shift_term``.

    ``outdisc_l2`` is the target-side output discrepancy. It falls back to the
    trace ``output_dist`` when ``outdisc_l2_mean`` is absent. ``lips_t`` is the
    candidate's Lipschitz estimate from ``trace_result``; it defaults to 0 when
    missing, which reduces the bound to ``outdisc_l2``. Each ``shift_term`` is a
    distribution-shift statistic, and any missing statistic counts as 0.
    Insertion order matches the CSV column order.
    """
    lips_t = trace_result.get("lipschitz_tilde_q", 0.0)
    outdisc_l2 = outdisc.get("outdisc_l2_mean", trace_result.get("output_dist", 0.0))

    def variant(shift_term: Any) -> float:
        return float(outdisc_l2 + lips_t * shift_term)

    return {
        # The frozen-backbone W1 term comes from trace_result; the rest from shift_multi.
        "trace_w1_avgpool": variant(trace_result.get("w1_term", 0.0)),
        "trace_w1_layer2": variant(shift_multi.get("w1_layer2", 0.0)),
        "trace_w1_layer3": variant(shift_multi.get("w1_layer3", 0.0)),
        "trace_mmd_avgpool": variant(shift_multi.get("mmd_avgpool", 0.0)),
        "trace_energy_avgpool": variant(shift_multi.get("energy_avgpool", 0.0)),
        "trace_sw2_avgpool": variant(shift_multi.get("sw2_avgpool", 0.0)),
    }


def main():
    """Run the deployment-gate experiment end to end.

    Steps:

    1. Parse CLI flags and merge them into the YAML config, then seed the RNGs.
    2. Build the dataloaders.
    3. Train the source model Q and the candidate models, or load them from
       checkpoints.
    4. For each candidate, record the change in source-test risk relative to Q
       (``delta_R_true``, ``is_harmful`` against ``tau``).
    5. Compute the candidate's TRACE, OOD and disagreement metrics, alongside
       shift metrics that are computed once and shared by all candidates.
    6. Write one row per candidate to ``<output_dir>/deployment_gate_results.csv``.
    """
    parser = argparse.ArgumentParser(description="trace Deployment Gate")
    parser.add_argument("--config", type=str, required=True, help="Path to YAML config")
    parser.add_argument("--device", type=str, help="cpu or cuda")
    parser.add_argument("--seed", type=int, help="RNG seed")
    parser.add_argument("--train-source", action="store_true", help="Force-retrain source model")
    parser.add_argument("--train-candidates", action="store_true", help="Force-retrain candidates")
    args = parser.parse_args()

    config = _load_config(args.config)
    if args.device:
        config["device"] = args.device
    # Compare against None: `--seed 0` is a valid seed but is falsy.
    if args.seed is not None:
        config["seed"] = args.seed
    device = config["device"]
    _seed_everything(config["seed"])

    # Data
    _log("[Setup] Building dataloaders...")
    loaders = get_dataloaders(
        source_domain=config["source_domain"],
        target_domain=config["target_domain"],
        batch_size=config["batch_size"],
        data_root=config.get("data_root"),
        num_workers=config.get("num_workers", 2),
        seed=config["seed"],
        max_samples_per_split=config.get("max_samples_per_split"),
    )

    # Source model Q
    output_dir = config.get("output_dir", "experiments/deployment_gate/results")
    os.makedirs(output_dir, exist_ok=True)
    source_ckpt_path = os.path.join(output_dir, "checkpoints", "source", "source_best.pth")

    if args.train_source or not os.path.exists(source_ckpt_path):
        _log("[Train] Training source model Q...")
        model_q, source_ckpt_path = train_source_model(config, loaders)
        _log("[Train] Source model training complete.")
    else:
        _log("[Load] Loading existing source model Q...")
        model_q = get_resnet50(pretrained=False, num_classes=config["num_classes"])
        model_q = load_checkpoint(model_q, source_ckpt_path, map_location=device).to(device)
    model_q = _maybe_data_parallel(model_q, device)

    # Candidate models \tilde{Q}: reuse the cache only if every checkpoint exists;
    # a partial cache would otherwise crash later inside load_checkpoint.
    candidates_dir = os.path.join(output_dir, "checkpoints", "candidates")
    num_cand = config.get("num_candidates", 20)
    expected_paths = [os.path.join(candidates_dir, f"tilde_q_{i+1}.pth") for i in range(num_cand)]
    missing = [p for p in expected_paths if not os.path.exists(p)]
    if args.train_candidates or missing:
        if missing and not args.train_candidates:
            _log(f"[Train] {len(missing)}/{num_cand} candidate checkpoints missing; retraining.")
        _log(f"[Train] Training {num_cand} candidate models...")
        candidate_paths = train_candidate_models(source_ckpt_path, config, loaders)
        _log("[Train] Candidate training complete.")
    else:
        _log("[Load] Loading existing candidate models...")
        candidate_paths = expected_paths

    # Ground truth risk on SOURCE TEST
    _log("[Eval] Evaluating source model Q on source_test for anchor risk...")
    src_test_metrics_q = evaluate(model_q, loaders["source_test"], device)
    risk_q = 1.0 - src_test_metrics_q["acc"]
    _log(f"[Eval] Source anchor risk (1-acc): {risk_q:.4f}")

    # Shared metrics: these depend only on Q and the data, not on the candidate.
    _log("[Metrics] Computing shared shift and domain metrics...")
    feature_extractor_q = _get_feature_extractor(model_q)
    fast_mode = bool(config.get("fast_mode", False))
    max_steps_domain = int(config.get("max_steps_domain_clf", 2000))
    shift_layers = list(config.get("shift_layers", ["layer2", "layer3", "avgpool"]))
    mmd, cost_mmd = run_with_cost(calculate_mmd, loaders["source_val"], loaders["target_eval"], feature_extractor_q, device=device)
    # NOTE(review): the configured step budget only takes effect in fast_mode;
    # otherwise 2000 steps are used. This preserves the existing behavior.
    auc, iw, dom_err = train_domain_classifier_and_iw(
        model_q, loaders["source_val"], loaders["target_train"], device=device,
        max_steps=(max_steps_domain if fast_mode else 2000),
    )
    shift_multi = compute_shift_metrics(
        model_q, loaders["source_val"], loaders["target_eval"], device=device, layers=shift_layers,
    )
    _log("[Metrics] Shared metrics complete.")

    results = []
    tau = float(config.get("tau", 0.0))
    cand_iter = tqdm(candidate_paths, desc="[Gate] Candidates", unit="model")
    for idx, path in enumerate(cand_iter, start=1):
        row: Dict[str, Any] = {"candidate_id": os.path.basename(path).replace(".pth", "")}
        _log(f"[Candidate {idx}/{len(candidate_paths)}] Loading and evaluating {row['candidate_id']}...")
        model_tilde_q = get_resnet50(False, config["num_classes"])
        model_tilde_q = load_checkpoint(model_tilde_q, path, map_location=device).to(device)
        model_tilde_q = _maybe_data_parallel(model_tilde_q, device)

        # Ground truth delta_R on SOURCE TEST
        src_test_metrics_t = evaluate(model_tilde_q, loaders["source_test"], device)
        risk_tilde_q = 1.0 - src_test_metrics_t["acc"]
        delta_r = risk_tilde_q - risk_q
        row["risk_q_source"] = float(risk_q)
        row["risk_tilde_q_source"] = float(risk_tilde_q)
        row["delta_R_true"] = delta_r
        row["delta_R_signed"] = delta_r
        row["is_harmful"] = int(delta_r > tau)
        # Optional: evaluate on target (sanity check improvement there)
        if "target_eval" in loaders:
            tgt_eval_metrics_t = evaluate(model_tilde_q, loaders["target_eval"], device)
            row["risk_tilde_q_target"] = float(1.0 - tgt_eval_metrics_t["acc"])
        _log(f"[Candidate {idx}] delta_R={row['delta_R_true']:.4f} harmful={row['is_harmful']} (tau={tau:.3f})")

        # trace/TRACE
        trace_result, cost_trace = run_with_cost(calculate_trace_bound, model_q, model_tilde_q, loaders["source_val"], loaders["target_eval"], device=device)
        row["trace_score"] = trace_result["bound"]
        row["trace_w1"] = trace_result["w1_term"]
        row["trace_output_dist"] = trace_result["output_dist"]
        _add_costs(row, cost_trace, "trace")
        row["lips_q_q99"] = trace_result.get("lipschitz_q", float("nan"))
        row["lips_t_q99"] = trace_result.get("lipschitz_tilde_q", float("nan"))

        # OOD on target
        ood_scores, cost_ood = run_with_cost(calculate_ood_scores, model_tilde_q, loaders["target_eval"], device=device)
        row["msp_score"] = ood_scores["msp_score"]
        row["energy_score"] = ood_scores["energy_score"]
        _add_costs(row, cost_ood, "ood")

        # Model-change on target
        outdisc, cost_outdisc = run_with_cost(compute_outdisc_metrics, model_q, model_tilde_q, loaders["target_eval"], device=device)
        row.update(outdisc)
        _add_costs(row, cost_outdisc, "outdisc")
        kl, cost_kl = run_with_cost(compute_symmetric_kl, model_q, model_tilde_q, loaders["target_eval"], device=device)
        row["kl_disagreement"] = kl
        _add_costs(row, cost_kl, "kl")

        if not fast_mode:
            try:
                ece_shift = compute_ece_shift(model_q, model_tilde_q, loaders["source_val"], device=device)
            except Exception as exc:
                # Best-effort metric: record NaN and continue, but do not hide the failure.
                _log(f"[Candidate {idx}] ece_shift failed ({type(exc).__name__}: {exc}); recording NaN")
                ece_shift = float("nan")
            row["ece_shift"] = ece_shift

        # Shift-only metrics (constant across candidates)
        row["mmd_score"] = mmd
        _add_costs(row, cost_mmd, "mmd")
        row["domain_classifier_auc"] = auc
        row["a_distance_proxy"] = float(2.0 * (1.0 - 2.0 * dom_err))
        row.update(shift_multi)

        # TRACE variants
        row.update(_trace_variant_columns(trace_result, outdisc, shift_multi))

        results.append(row)
        _log(f"[Candidate {idx}] Processed {row['candidate_id']}")

    df = pd.DataFrame(results)
    out_path = os.path.join(output_dir, "deployment_gate_results.csv")
    df.to_csv(out_path, index=False)
    _log(f"[Save] Saved results to {out_path}")


if __name__ == "__main__":
    # Script entry point: main() runs only when this file is executed directly,
    # not when it is imported as a module.
    main()

