
from __future__ import annotations
from typing import Sequence, Optional, Dict, List, Tuple, Mapping, Any, Union, Iterable
import os, pickle, time, csv, json
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
from .cav import (
    sample_train_cav,
    sample_train_car,
    get_or_create_nystrom_basis,
    sample_train_car_nystrom_from_features,
)
from .metrics import cav_pairwise_mean_angle_deg, aggregate_variance_by_n

# Optional extra metadata merged into timing records (precompute_cavs_for_layer
# reads the "dataset", "model" and "concept" keys from it).
TimingMeta = Optional[Mapping[str, Any]]
# A filesystem location given either as a plain string or an ``os.PathLike``.
PathLike = Union[str, os.PathLike]

# Column order of the aggregated timing CSV; records are projected onto these
# columns (missing ones become empty strings) before being appended.
TIMING_COLUMNS = [
    # what was trained
    "dataset", "model", "method", "layer", "concept",
    # per-n timing statistics
    "n_random_samples",
    "avg_time_per_cav_s", "std_time_per_cav_s", "total_time_s", "cavs_trained",
    # experiment configuration / provenance
    "runs", "sets_per_run", "seed", "timestamp",
]


def _resolve_timing_path(timing_log_path: Optional[PathLike]) -> Optional[Path]:
    """Return a concrete path (file) where timing rows should be stored."""
    if timing_log_path is None:
        return None
    path = Path(timing_log_path)
    if path.suffix == "":
        path = path / "cav_timing.csv"
    return path


def _append_timing_row(timing_path: Path, record: Dict[str, Any]) -> None:
    """Append one timing record as a CSV row to *timing_path*.

    Parent directories are created as needed, and the ``TIMING_COLUMNS`` header
    is written only when the file is still empty. Keys of *record* that are not
    in ``TIMING_COLUMNS`` are ignored. Missing columns are written as empty
    strings. The file is written as UTF-8 (like the per-run timing JSON), so
    non-ASCII concept/layer names do not depend on the platform default
    encoding.
    """
    timing_path.parent.mkdir(parents=True, exist_ok=True)
    row = {col: record.get(col, "") for col in TIMING_COLUMNS}
    with timing_path.open("a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=TIMING_COLUMNS)
        # In append mode the stream is positioned at end-of-file, so position 0
        # means the file is new/empty and still needs its header.
        if f.tell() == 0:
            writer.writeheader()
        writer.writerow(row)

def _run_timing_path(run_dir: str, run_id: int) -> str:
    """Return a JSON path for per-run timing saved alongside CAV files."""
    return os.path.join(run_dir, f"run_{run_id}_timing.json")

def _write_run_timing(run_dir: str, run_id: int, record: Dict[str, Any]) -> None:
    """Dump *record* as 2-space-indented UTF-8 JSON to ``run_<run_id>_timing.json``.

    The file lives in *run_dir*, next to that run's pickle, and any existing
    file there is overwritten.
    """
    target = os.path.join(run_dir, f"run_{run_id}_timing.json")
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(record, handle, indent=2)


def _run_pickle_path(out_dir: str, n: Any, run_id: int) -> str:
    """Return ``out_dir/<n>/run_<run_id>.pkl``, the cache file of one training run."""
    # ``str(n)`` (not ``str(int(n))``) so the existence check, the writer and
    # ``_load_cavs_directory`` all agree on the directory name.
    return os.path.join(out_dir, str(n), f"run_{run_id}.pkl")


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string with a ``Z`` suffix."""
    from datetime import timezone  # local import: the module only imports ``datetime``

    # Same format as the deprecated ``datetime.utcnow().isoformat() + "Z"``.
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


def _atomic_pickle_dump(path: str, obj: Any) -> None:
    """Pickle *obj* to *path* via a temp file + rename so *path* is never left truncated."""
    tmp_path = f"{path}.tmp{os.getpid()}"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def precompute_cavs_for_layer(
    X_pos: np.ndarray,
    X_neg: np.ndarray,
    layer_name: str,
    method: str,
    n_values: Iterable[int],
    runs: int,
    sets_per_run: int,
    out_dir: str,
    seed: int = 0,
    concept_name: Optional[str] = None,
    timing_log_path: Optional[PathLike] = None,
    timing_metadata: TimingMeta = None,
    skip_if_exists: bool = True,
    save_run_timings: bool = True,
    print_timing: Optional[bool] = None,
    # --- CAR Nyström knobs (only used when method indicates Nyström CAR) ---
    car_nystrom_components: int = 200,
    car_nystrom_gamma: float | str = "scale",
    car_nystrom_basis_seed: Optional[int] = None,
    car_nystrom_basis_path: Optional[str] = None,
) -> None:
    """
    Train and cache CAV/CAR objects for one layer and concept.

    For each ``n`` in *n_values* and each ``run_id`` in ``range(runs)``, trains
    ``sets_per_run`` objects (each on ``n`` sampled examples) and pickles the
    resulting list to::

        out_dir / str(n) / f"run_{run_id}.pkl"

    Pickles are written to a temporary file and renamed into place. An
    interrupted call therefore never leaves a truncated pickle that a later
    call with ``skip_if_exists=True`` would mistake for a finished run.

    Args:
        X_pos, X_neg: positive / negative example pools passed to the trainers.
        layer_name: layer label used in timing records and printed summaries.
        method: ``"car"`` (case-insensitive) trains full CARs with an RBF kernel.
            ``"car_nystrom"``, ``"car-nystrom"``, ``"nystrom_car"`` or
            ``"tcar_nystrom"`` train Nyström CARs. Anything else is forwarded
            as-is to ``sample_train_cav``.
        n_values: sample counts to train at. Materialized once into a list.
        runs: number of run files per ``n``.
        sets_per_run: number of objects trained per run file.
        out_dir: root directory for this layer/concept.
        seed: seeds the generator that draws each trainer's ``random_state``.
        concept_name: concept label for timing. Defaults to ``Path(out_dir).name``.
        timing_log_path: CSV file, or a directory that receives
            ``cav_timing.csv``. One row per ``n`` with newly trained objects is
            appended to it.
        timing_metadata: only the ``"dataset"``, ``"model"`` and ``"concept"``
            keys are used; their non-``None`` values override the defaults.
            Other keys are ignored.
        skip_if_exists: keep existing run pickles instead of retraining. Set to
            ``False`` to force recomputation (overwriting files) and capture
            fresh timings.
        save_run_timings: also write ``run_{id}_timing.json`` next to each newly
            written run pickle.
        print_timing: print per-``n`` and total summaries. ``None`` means
            "print iff a timing log path was given".
        car_nystrom_components, car_nystrom_gamma: Nyström basis size and gamma.
        car_nystrom_basis_seed: basis seed. Defaults to *seed*.
        car_nystrom_basis_path: where the basis is stored/loaded. Defaults to an
            ``.npz`` file in *out_dir* named after the basis parameters.
    """
    rng = np.random.default_rng(seed)
    n_values = list(n_values)
    method_l = (method or "").lower()
    is_car_full = method_l == "car"
    is_car_nystrom = method_l in {"car_nystrom", "car-nystrom", "nystrom_car", "tcar_nystrom"}

    # If Nyström CARs are requested, build/load the basis once and (if needed)
    # precompute Nyström features for the entire positive/negative pools.
    basis_path: Optional[str] = None
    basis = None
    Phi_pos_all: Optional[np.ndarray] = None
    Phi_neg_all: Optional[np.ndarray] = None

    if is_car_nystrom:
        basis_seed = int(seed if car_nystrom_basis_seed is None else car_nystrom_basis_seed)

        # Default: store the basis once in the concept directory (out_dir).
        if car_nystrom_basis_path is not None:
            basis_path = str(car_nystrom_basis_path)
        else:
            gamma_tag = str(car_nystrom_gamma).replace(".", "p").replace("/", "_")
            basis_path = os.path.join(
                out_dir,
                f"_nystrom_basis_m{int(car_nystrom_components)}_gamma{gamma_tag}_seed{basis_seed}.npz",
            )

        basis = get_or_create_nystrom_basis(
            X_pos_all=X_pos,
            X_neg_all=X_neg,
            basis_path=basis_path,
            n_components=int(car_nystrom_components),
            gamma=car_nystrom_gamma,
            basis_random_state=basis_seed,
        )

        # Only transform the full pools if at least one run will actually be
        # (re)trained; the existence check uses the same paths the writer uses.
        needs_compute = (not skip_if_exists) or any(
            not os.path.exists(_run_pickle_path(out_dir, n, run_id))
            for n in n_values
            for run_id in range(runs)
        )
        if needs_compute:
            Phi_pos_all = basis.transform(X_pos)
            Phi_neg_all = basis.transform(X_neg)

    timing_path = _resolve_timing_path(timing_log_path)
    if print_timing is None:
        print_timing = timing_path is not None
    overall_t0 = time.perf_counter()
    total_durations: List[float] = []
    base_record = {
        "dataset": None,
        "model": None,
        "method": method,
        "layer": layer_name,
        "concept": concept_name or Path(out_dir).name,
        "runs": int(runs),
        "sets_per_run": int(sets_per_run),
        "seed": int(seed),
    }
    if timing_metadata:
        for key, value in timing_metadata.items():
            if key in ("dataset", "model", "concept") and value is not None:
                base_record[key] = value

    for n in n_values:
        durations: List[float] = []
        for run_id in range(runs):
            run_dir = os.path.join(out_dir, str(n))
            os.makedirs(run_dir, exist_ok=True)
            path = _run_pickle_path(out_dir, n, run_id)
            if os.path.exists(path) and skip_if_exists:
                # Keep existing runs (caching). NOTE: skipped runs consume no RNG
                # draws, so the random_state of newly trained runs depends on
                # which runs were already cached.
                continue

            objs = []
            run_durations: List[float] = []
            for _ in range(sets_per_run):
                rs = int(rng.integers(0, 2**31 - 1))
                t0 = time.perf_counter()
                if is_car_full:
                    obj = sample_train_car(
                        X_pos_all=X_pos,
                        X_neg_all=X_neg,
                        n_examples=int(n),
                        kernel="rbf",
                        random_state=rs,
                    )
                elif is_car_nystrom:
                    if Phi_pos_all is None or Phi_neg_all is None:
                        # Safety net: only reachable if a run file vanished after
                        # the up-front existence check. Compute on demand.
                        if basis is None:
                            raise RuntimeError("Nyström basis was not initialized")
                        Phi_pos_all = basis.transform(X_pos)
                        Phi_neg_all = basis.transform(X_neg)
                    if basis_path is None:
                        raise RuntimeError("basis_path was not initialized")
                    obj = sample_train_car_nystrom_from_features(
                        Phi_pos_all,
                        Phi_neg_all,
                        n_examples=int(n),
                        basis_path=basis_path,
                        random_state=rs,
                    )
                else:
                    obj = sample_train_cav(
                        X_pos_all=X_pos,
                        X_neg_all=X_neg,
                        n_examples=int(n),
                        method=method,
                        random_state=rs,
                    )
                duration = time.perf_counter() - t0
                durations.append(duration)
                run_durations.append(duration)
                objs.append(obj)

            _atomic_pickle_dump(path, objs)
            if save_run_timings and run_durations:
                run_record = {
                    **base_record,
                    "n_random_samples": int(n),
                    "run_id": int(run_id),
                    "avg_time_per_cav_s": float(np.mean(run_durations)),
                    "std_time_per_cav_s": float(np.std(run_durations)),
                    "total_time_s": float(np.sum(run_durations)),
                    "cavs_trained": len(run_durations),
                    "timestamp": _utc_timestamp(),
                    "durations_s": [float(d) for d in run_durations],
                }
                _write_run_timing(run_dir, run_id, run_record)
        if durations:
            avg_time = float(np.mean(durations))
            std_time = float(np.std(durations))
            total_time = float(np.sum(durations))
            total_durations.extend(durations)
            if timing_path:
                record = {
                    **base_record,
                    "n_random_samples": int(n),
                    "avg_time_per_cav_s": avg_time,
                    "std_time_per_cav_s": std_time,
                    "total_time_s": total_time,
                    "cavs_trained": len(durations),
                    "timestamp": _utc_timestamp(),
                }
                _append_timing_row(timing_path, record)
            if print_timing:
                print(
                    "[CAV timing] "
                    f"layer={layer_name} concept={base_record['concept']} method={method} "
                    f"n={int(n)} cavs={len(durations)} "
                    f"avg={avg_time:.4f}s std={std_time:.4f}s total={total_time:.4f}s"
                )
    if print_timing:
        wall_time = time.perf_counter() - overall_t0
        if total_durations:
            avg_time = float(np.mean(total_durations))
            std_time = float(np.std(total_durations))
            total_time = float(np.sum(total_durations))
            print(
                "[CAV timing] total "
                f"layer={layer_name} concept={base_record['concept']} method={method} "
                f"cavs={len(total_durations)} "
                f"avg={avg_time:.4f}s std={std_time:.4f}s total={total_time:.4f}s "
                f"wall={wall_time:.4f}s"
            )
        else:
            print(
                "[CAV timing] "
                f"layer={layer_name} concept={base_record['concept']} method={method} "
                f"no new CAVs; wall={wall_time:.4f}s"
            )


def _load_cavs_directory(base_dir: str, n_values: Sequence[int], runs: int) -> Dict[Tuple[int,int], List[np.ndarray]]:
    out={}
    for n in n_values:
        for run_id in range(runs):
            p = os.path.join(base_dir, str(n), f"run_{run_id}.pkl")
            if not os.path.exists(p): continue
            try:
                with open(p, "rb") as f:
                    cavs = pickle.load(f)
                vecs=[]
                for item in cavs:
                    if isinstance(item, dict) and "vector" in item: vecs.append(item["vector"])
                    else: vecs.append(item)
                out[(n, run_id)] = [np.asarray(v, dtype=float) for v in vecs]
            except Exception:
                continue
    return out

def cav_variability_analysis(cav_root_for_layer_and_concept: str, layer_name: str, concept_name: str, n_values: Sequence[int], runs: int) -> pd.DataFrame:
    """Summarize, per ``n``, how spread out the CAVs within each cached run are.

    Each run with at least two vectors gets one record holding
    ``cav_pairwise_mean_angle_deg`` of its vectors (by its name, a mean pairwise
    angle in degrees). Runs with fewer than two vectors have no pair and are
    skipped. The records (``layer``, ``concept``, ``n``, ``run``, ``value``) are
    reduced by ``aggregate_variance_by_n``.
    """
    per_run = _load_cavs_directory(cav_root_for_layer_and_concept, n_values, runs)
    records = [
        {
            "layer": layer_name,
            "concept": concept_name,
            "n": n,
            "run": run_id,
            "value": cav_pairwise_mean_angle_deg(vectors),
        }
        for (n, run_id), vectors in per_run.items()
        if len(vectors) >= 2
    ]
    return aggregate_variance_by_n(records)

def sensitivity_variance_analysis(cav_root_for_layer_and_concept: str, layer_name: str, concept_name: str, gradients_for_layer: np.ndarray, n_values: Sequence[int], runs: int) -> pd.DataFrame:
    """Per run, the variance across CAVs of the mean gradient projection.

    For every CAV ``v`` in a run, ``mean(gradients_for_layer @ v)`` is taken.
    The run's value is the sample variance (``ddof=1``) of those means, or
    ``0.0`` for a single CAV. Runs with no vectors are skipped. The per-run
    records are reduced by ``aggregate_variance_by_n``.
    """
    per_run = _load_cavs_directory(cav_root_for_layer_and_concept, n_values, runs)
    records = []
    for (n, run_id), vectors in per_run.items():
        if not vectors:
            continue
        projection_means = [float(np.mean(gradients_for_layer @ vec)) for vec in vectors]
        spread = 0.0 if len(projection_means) <= 1 else float(np.var(projection_means, ddof=1))
        records.append(dict(layer=layer_name, concept=concept_name, n=n, run=run_id, value=spread))
    return aggregate_variance_by_n(records)

def tcav_score_variance_analysis(cav_root_for_layer_and_concept: str, layer_name: str, concept_name: str, gradients_for_layer: np.ndarray, n_values: Sequence[int], runs: int) -> pd.DataFrame:
    """Per run, the variance across CAVs of the TCAV score.

    A CAV's score is the fraction of rows where ``gradients_for_layer @ v`` is
    strictly positive. The run's value is the sample variance (``ddof=1``) of
    those scores, or ``0.0`` for a single CAV. Runs with no vectors are
    skipped. The per-run records are reduced by ``aggregate_variance_by_n``.
    """
    per_run = _load_cavs_directory(cav_root_for_layer_and_concept, n_values, runs)
    records = []
    for (n, run_id), vectors in per_run.items():
        if not vectors:
            continue
        tcav_scores = [float(((gradients_for_layer @ vec) > 0).mean()) for vec in vectors]
        spread = 0.0 if len(tcav_scores) <= 1 else float(np.var(tcav_scores, ddof=1))
        records.append(dict(layer=layer_name, concept=concept_name, n=n, run=run_id, value=spread))
    return aggregate_variance_by_n(records)

from .cache import save_df_bundle, try_load_df_bundle
def _cache_name(prefix: str, layer: str, concept: str) -> str:
    return f"{prefix}__{layer}__{concept}"

def cav_variability_analysis_cached(cav_root_for_layer_and_concept: str, layer_name: str, concept_name: str, n_values: Sequence[int], runs: int, cache_dir: Optional[str]=None, cache_key: Optional[str]=None, load_if_exists: bool=False, save: bool=True):
    """Cached wrapper around ``cav_variability_analysis``.

    Without *cache_dir* this simply runs the analysis. With *cache_dir* and
    *load_if_exists*, a bundle found by ``try_load_df_bundle`` for the same
    name (``cache_key`` or ``cav_variability__<layer>__<concept>``) and params
    is returned. Otherwise the analysis is recomputed and, if *save* is true
    and the result is non-empty, stored via ``save_df_bundle``.

    *n_values* is materialized into a list once. One-shot iterables therefore
    feed both the cache params and the analysis, instead of being exhausted by
    the params.
    """
    n_values = list(n_values)
    if not cache_dir:
        return cav_variability_analysis(cav_root_for_layer_and_concept, layer_name, concept_name, n_values, runs)

    params = {
        "type": "cav_variability",
        "layer": layer_name,
        "concept": concept_name,
        "n_values": list(n_values),
        "runs": int(runs),
        "cav_root": str(cav_root_for_layer_and_concept),
    }
    name = cache_key or _cache_name("cav_variability", layer_name, concept_name)
    if load_if_exists:
        cached = try_load_df_bundle(cache_dir, scope="analysis", name=name, params=params)
        if cached is not None:
            return cached
    df = cav_variability_analysis(cav_root_for_layer_and_concept, layer_name, concept_name, n_values, runs)
    if save and not df.empty:
        save_df_bundle(cache_dir, scope="analysis", name=name, params=params, df=df)
    return df

def sensitivity_variance_analysis_cached(cav_root_for_layer_and_concept: str, layer_name: str, concept_name: str, gradients_for_layer: np.ndarray, n_values: Sequence[int], runs: int, cache_dir: Optional[str]=None, cache_key: Optional[str]=None, load_if_exists: bool=False, save: bool=True):
    """Cached wrapper around ``sensitivity_variance_analysis``.

    Without *cache_dir* this simply runs the analysis. With *cache_dir* and
    *load_if_exists*, a bundle found by ``try_load_df_bundle`` for the same
    name (``cache_key`` or ``sensitivity_variance__<layer>__<concept>``) and
    params is returned. Otherwise the analysis is recomputed and, if *save* is
    true and the result is non-empty, stored via ``save_df_bundle``.

    The params include a fingerprint (dtype + SHA-1 of the contents) of
    *gradients_for_layer*, not just its shape. A result computed from different
    gradients of the same shape is therefore not mistaken for a cache hit.
    *n_values* is materialized once, so one-shot iterables work.
    """
    import hashlib  # local import: only used to fingerprint the gradients

    n_values = list(n_values)
    if not cache_dir:
        return sensitivity_variance_analysis(cav_root_for_layer_and_concept, layer_name, concept_name, gradients_for_layer, n_values, runs)

    grads = np.ascontiguousarray(gradients_for_layer)  # hashlib needs a C-contiguous buffer
    digest = hashlib.sha1(str(grads.dtype).encode("utf-8"))
    digest.update(grads)
    params = {
        "type": "sensitivity_variance",
        "layer": layer_name,
        "concept": concept_name,
        "n_values": list(n_values),
        "runs": int(runs),
        "cav_root": str(cav_root_for_layer_and_concept),
        "grad_shape": list(gradients_for_layer.shape),
        "grad_sha1": digest.hexdigest(),
    }
    name = cache_key or _cache_name("sensitivity_variance", layer_name, concept_name)
    if load_if_exists:
        cached = try_load_df_bundle(cache_dir, scope="analysis", name=name, params=params)
        if cached is not None:
            return cached
    df = sensitivity_variance_analysis(cav_root_for_layer_and_concept, layer_name, concept_name, gradients_for_layer, n_values, runs)
    if save and not df.empty:
        save_df_bundle(cache_dir, scope="analysis", name=name, params=params, df=df)
    return df

def tcav_score_variance_analysis_cached(cav_root_for_layer_and_concept: str, layer_name: str, concept_name: str, gradients_for_layer: np.ndarray, n_values: Sequence[int], runs: int, cache_dir: Optional[str]=None, cache_key: Optional[str]=None, load_if_exists: bool=False, save: bool=True):
    """Cached wrapper around ``tcav_score_variance_analysis``.

    Without *cache_dir* this simply runs the analysis. With *cache_dir* and
    *load_if_exists*, a bundle found by ``try_load_df_bundle`` for the same
    name (``cache_key`` or ``tcav_score_variance__<layer>__<concept>``) and
    params is returned. Otherwise the analysis is recomputed and, if *save* is
    true and the result is non-empty, stored via ``save_df_bundle``.

    The params include a fingerprint (dtype + SHA-1 of the contents) of
    *gradients_for_layer*, not just its shape. A result computed from different
    gradients of the same shape is therefore not mistaken for a cache hit.
    *n_values* is materialized once, so one-shot iterables work.
    """
    import hashlib  # local import: only used to fingerprint the gradients

    n_values = list(n_values)
    if not cache_dir:
        return tcav_score_variance_analysis(cav_root_for_layer_and_concept, layer_name, concept_name, gradients_for_layer, n_values, runs)

    grads = np.ascontiguousarray(gradients_for_layer)  # hashlib needs a C-contiguous buffer
    digest = hashlib.sha1(str(grads.dtype).encode("utf-8"))
    digest.update(grads)
    params = {
        "type": "tcav_score_variance",
        "layer": layer_name,
        "concept": concept_name,
        "n_values": list(n_values),
        "runs": int(runs),
        "cav_root": str(cav_root_for_layer_and_concept),
        "grad_shape": list(gradients_for_layer.shape),
        "grad_sha1": digest.hexdigest(),
    }
    name = cache_key or _cache_name("tcav_score_variance", layer_name, concept_name)
    if load_if_exists:
        cached = try_load_df_bundle(cache_dir, scope="analysis", name=name, params=params)
        if cached is not None:
            return cached
    df = tcav_score_variance_analysis(cav_root_for_layer_and_concept, layer_name, concept_name, gradients_for_layer, n_values, runs)
    if save and not df.empty:
        save_df_bundle(cache_dir, scope="analysis", name=name, params=params, df=df)
    return df
