# src/ablation_iees.py
import copy
from typing import List, Dict, Any, Optional
import numpy as np

from src.train import evaluate_model_thresholds

try:
    import pandas as pd
except Exception:
    pd = None


def _build_variants_from_weights(base_w: List[float]) -> Dict[str, List[float]]:
    """
    Create weight masks for common ablations without assuming your internal IEES formula.
    Supported layouts:
      - len==3: [conf, relevance, progression]
      - len==4: [conf, act_rel, grad_rel, progression]
    Anything else -> we just expose 'IEES_full' and leave the rest empty.
    """
    n = len(base_w)
    W = {}

    def mask(keep_idxs):
        w = [0.0] * n
        for i in keep_idxs:
            if 0 <= i < n:
                w[i] = base_w[i]
        return w

    if n == 3:
        # indices: 0=conf, 1=rel, 2=prog  (most common)
        W["IEES_full"]     = copy.deepcopy(base_w)
        W["IEES_w/o_conf"] = mask([1, 2])
        W["IEES_w/o_rel"]  = mask([0, 2])
        W["IEES_w/o_prog"] = mask([0, 1])
        W["IEES_conf_only"]= mask([0])
        W["IEES_prog_only"]= mask([2])
        W["IEES_conf+prog"]= mask([0, 2])
    elif n == 4:
        # indices: 0=conf, 1=act_rel, 2=grad_rel, 3=prog
        W["IEES_full"]       = copy.deepcopy(base_w)
        W["IEES_w/o_conf"]   = mask([1, 2, 3])
        W["IEES_w/o_rel"]    = mask([0, 3])          # drop both act+grad
        W["IEES_w/o_prog"]   = mask([0, 1, 2])
        W["IEES_conf_only"]  = mask([0])
        W["IEES_prog_only"]  = mask([3])
        W["IEES_rel_only"]   = mask([1, 2])
        W["IEES_conf+prog"]  = mask([0, 3])
    else:
        # Unknown layout: only keep full
        W["IEES_full"] = copy.deepcopy(base_w)

    return W


def _flatten_results(rows: List[Dict[str, Any]]):
    """Pretty print to console even if pandas isn't available."""
    if not rows:
        print("[WARN] No ablation rows collected.")
        return
    # Column order
    cols = ["variant", "threshold", "avg_accuracy", "avg_flops", "weighted_overall_accuracy",
            "exit_mean_accuracy", "exit_counts"]
    # Header
    widths = [max(len(c), 12) for c in cols]
    fmt_row = "  ".join("{:" + str(w) + "}" for w in widths)
    print("\n[ABLATON] IEES component study")
    print(fmt_row.format(*cols))
    print("-" * sum(widths))
    for r in rows:
        vals = [r.get(c, "") for c in cols]
        # make exit_counts shorter
        if isinstance(vals[-1], list):
            vals[-1] = "[" + ",".join(str(v) for v in vals[-1]) + "]"
        print(fmt_row.format(*[str(v)[:max(12, len(str(c)))] for v, c in zip(vals, cols)]))


def _write_rows_csv(rows: List[Dict[str, Any]], path: str) -> None:
    """Write result rows to ``path`` with the stdlib csv module (pandas-free fallback)."""
    import csv

    # Union of keys across all rows in first-seen order, so no column is dropped.
    fieldnames = list(dict.fromkeys(k for r in rows for k in r))
    with open(path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.DictWriter(fh, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def run_iees_ablation(model,
                      dataloader,
                      device,
                      out_csv: Optional[str] = None,
                      taus: Optional[List[float]] = None,
                      approx_flops_per_exit: Optional[List[float]] = None,  # not used here; kept for API parity
                      max_samples: Optional[int] = None,
                      base_weights: Optional[List[float]] = None,
                      model_name: Optional[str] = None) -> "pd.DataFrame | list":
    """
    Runs:
      - Confidence-only (uses exit_criterion='confidence')
      - IEES_full and component ablations derived from 'base_weights'
    Every evaluation goes through evaluate_model_thresholds(). Each row it
    yields is tagged with a 'variant' key.

    Args:
      model, dataloader, device: forwarded to evaluate_model_thresholds
        (dataloader is passed as ``testloader``)
      out_csv: optional CSV path. It is written with pandas when available,
        otherwise with the stdlib csv module.
      taus: forwarded as ``threshold=``. What None means is decided by
        evaluate_model_thresholds.
      approx_flops_per_exit: unused; kept for API parity
      max_samples: unused; downsample the dataloader before calling this function
      base_weights: IEES weights (len 3 or 4 get component ablations; other
        lengths only 'IEES_full'). If None, the IEES variants are skipped.
      model_name: forwarded to evaluate_model_thresholds

    Returns:
      pandas.DataFrame if pandas is available, else a list of dict rows
      (which are also printed as a table).
    """
    if max_samples is not None:
        print("[INFO] max_samples is ignored; downsample the dataloader before calling run_iees_ablation.")

    rows: List[Dict[str, Any]] = []

    def _evaluate(criterion: str, weights: List[float]) -> List[Dict[str, Any]]:
        # Materialise the result so it can be safely iterated more than once.
        return list(evaluate_model_thresholds(
            model=model,
            model_name=model_name,
            testloader=dataloader,
            device=device,
            exit_criterion=criterion,
            threshold=taus,
            weights=weights,
        ))

    # 1) Confidence-only baseline
    rows.extend({"variant": "Confidence-only", **r} for r in _evaluate("confidence", []))

    # 2) IEES-based variants (full + ablations) if weights provided
    if base_weights is None:
        print("[INFO] No base_weights provided -> skipping IEES ablations.")
    else:
        # Some variants share a mask (e.g. 'IEES_w/o_rel' == 'IEES_conf+prog').
        # Run each unique mask once and reuse the rows under every label.
        results_by_mask: Dict[tuple, List[Dict[str, Any]]] = {}
        for vname, vweights in _build_variants_from_weights(base_weights).items():
            key = tuple(vweights)
            if key not in results_by_mask:
                results_by_mask[key] = _evaluate("iees", vweights)
            rows.extend({"variant": vname, **r} for r in results_by_mask[key])

    # Save/return
    if pd is not None:
        df = pd.DataFrame(rows)
        if out_csv:
            df.to_csv(out_csv, index=False)
        return df

    _flatten_results(rows)
    if out_csv:
        # Previously the CSV was silently skipped when pandas was missing.
        _write_rows_csv(rows, out_csv)
    return rows
