# -*- coding: utf-8 -*-
from __future__ import annotations

import os
import time
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# ---------------------------
# 0) Offline hard block (optional)
# ---------------------------
# These defaults are installed before the tabpfn imports further down, so they
# are already present in the process environment when those packages load.
# setdefault() only writes a value when the variable is unset: anything the
# caller exported beforehand (e.g. HF_HUB_OFFLINE=0) is left untouched.
# NOTE(review): presumably read by huggingface_hub / transformers / tokenizers
# to disable network access and tokenizer parallelism — confirm against their docs.
os.environ.setdefault("HF_HUB_OFFLINE", "1")
os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# ---------------------------
# TabPFN finetune imports
# ---------------------------
import tabpfn.finetuning.train_util as train_util
import tabpfn.finetuning.finetuned_classifier as finetuned_classifier_mod
import tabpfn.finetuning.finetuned_base as finetuned_base_mod

from tabpfn import TabPFNClassifier
from tabpfn.finetuning.finetuned_classifier import FinetunedTabPFNClassifier



def _extract_model_spec(model_obj):
    """Return the first path-like checkpoint reference stored on ``model_obj``.

    The attributes ``model_path``, ``model_path_``, ``model_specs``,
    ``model_specs_``, ``model_spec`` and ``model_spec_`` are probed in that
    order.  A value is accepted only when it is a ``str``, a ``pathlib.Path``
    or exposes ``__fspath__``; missing attributes, ``None`` and every other
    kind of object (including non-path spec objects) are skipped.

    Raises:
        TypeError: If none of the probed attributes holds a path-like value.
    """
    candidate_attrs = (
        "model_path",
        "model_path_",
        "model_specs",
        "model_specs_",
        "model_spec",
        "model_spec_",
    )
    for name in candidate_attrs:
        if not hasattr(model_obj, name):
            continue
        value = getattr(model_obj, name)
        # None is neither str/Path nor path-like, so it falls through here too.
        looks_like_path = isinstance(value, (str, Path)) or hasattr(value, "__fspath__")
        if value is not None and looks_like_path:
            return value
    raise TypeError(f"Cannot extract model_path / ModelSpecs from type {type(model_obj)}.")


def _clone_model_for_evaluation_patched(model, eval_config, model_class):
    """Build a fresh ``model_class`` instance to be used for evaluation.

    The constructor keyword arguments are a copy of ``eval_config`` with any
    ``"model_path"`` entry removed; ``model_path`` is instead set to the
    path-like value that ``_extract_model_spec`` finds on ``model``.
    ``eval_config`` itself is left unmodified.

    NOTE(review): the clone is created from a path only, so it presumably
    reloads weights from that file rather than reusing ``model``'s in-memory
    state — confirm this is what the evaluation loop expects.
    """
    init_kwargs = {
        key: value
        for key, value in dict(eval_config).items()
        if key != "model_path"
    }
    return model_class(model_path=_extract_model_spec(model), **init_kwargs)


def apply_patch():
    """Install the patched ``clone_model_for_evaluation`` into the TabPFN finetuning modules.

    ``train_util`` and ``finetuned_classifier_mod`` are always overwritten;
    ``finetuned_base_mod`` is overwritten only when it already exposes an
    attribute with that name.  A confirmation line is printed afterwards.
    """
    replacement = _clone_model_for_evaluation_patched

    for module in (train_util, finetuned_classifier_mod):
        module.clone_model_for_evaluation = replacement

    # Only replace an existing binding here; never add a new attribute to this module.
    if hasattr(finetuned_base_mod, "clone_model_for_evaluation"):
        finetuned_base_mod.clone_model_for_evaluation = replacement

    print("[patch] clone_model_for_evaluation patched ✅", flush=True)


# ==============================================================================
# 2) Config (EDIT THESE)
# ==============================================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # GPU when torch reports CUDA support, otherwise CPU

BASE_CKPT = "./tabpfn-v2.5-classifier-v2.5_default.ckpt"  # base checkpoint; main() raises FileNotFoundError if it is missing
RAW_CSV = "./spambase_data.csv"  # full raw data: fits the logistic rule and serves as the (manipulated) test set
CTX_CSV = "./spambase_strategic_context.csv"  # context rows passed to TabPFN fit / finetune

LABEL_COL = "Spam"  # label column; must be present in both CSVs

# Must match your strategic context setup
SEED = 42  # used for the logistic regression, the freeze-mask RNG and the baseline TabPFN
FREEZE_FRAC = 0.10  # fraction of feature columns that are frozen (not shifted)

# Mahalanobis best response
LAMBDA_COST = 6.0  # the shift is divided by this value
COV_RIDGE = 1e-3  # added to the diagonal of the feature covariance
CLIP_NONNEG = True  # clip manipulated inputs at 0 from below

# Baseline inference setting
N_ESTIMATORS_BASE = 3  # n_estimators of the baseline TabPFNClassifier

# Finetune hyperparams
EPOCHS_FT = 10  # passed as epochs
LR_FT = 1e-5  # passed as learning_rate
WD_FT = 0.01  # passed as weight_decay

OUT_DIR = Path("./spambase_ft_strategic").resolve()  # created by main(); holds the per-run finetune dirs
REPORT_TXT = OUT_DIR / "report.txt"  # opened with mode "w", so each run overwrites it


# ==============================================================================
# 3) Minimal helpers
# ==============================================================================
def ensure_dir(p: Path) -> None:
    """Create directory ``p`` together with any missing parents.

    An already existing directory is accepted silently.
    """
    p.mkdir(exist_ok=True, parents=True)

def now_str() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``."""
    timestamp_format = "%Y%m%d_%H%M%S"
    return time.strftime(timestamp_format, time.localtime())

def compute_covariance(X: np.ndarray, ridge: float) -> np.ndarray:
    """Return the ridge-regularised sample covariance of the columns of ``X``.

    Args:
        X: 2-D array with one row per observation and one column per feature.
            At least two rows are needed for the unbiased (N-1) estimate to
            be defined.
        ridge: Value added to every diagonal entry of the covariance.

    Returns:
        A ``(d, d)`` array ``cov(X) + ridge * I`` where ``d`` is the number of
        columns of ``X`` (also when ``d == 1``).
    """
    # np.cov collapses a single-feature input to a 0-d array; atleast_2d
    # restores the (1, 1) matrix so the ridge term can be added uniformly
    # instead of failing on ``shape[0]``.
    Sigma = np.atleast_2d(np.cov(X, rowvar=False, bias=False))
    return Sigma + ridge * np.eye(Sigma.shape[0])

def make_mutable_mask(d: int, freeze_frac: float, seed: int) -> np.ndarray:
    """Return a length-``d`` boolean mask that is ``False`` on the frozen features.

    ``round(d * freeze_frac)`` distinct indices are drawn without replacement
    from a ``numpy`` generator seeded with ``seed``; those positions are marked
    frozen and every other position is mutable (``True``).
    """
    generator = np.random.default_rng(seed)
    frozen_count = int(np.round(d * freeze_frac))
    frozen_positions = generator.choice(np.arange(d), size=frozen_count, replace=False)

    # Mark the frozen positions, then invert to obtain the mutable mask.
    is_frozen = np.zeros(d, dtype=bool)
    is_frozen[frozen_positions] = True
    return ~is_frozen

def best_response_mahalanobis(
    X: np.ndarray,
    w: np.ndarray,
    Sigma: np.ndarray,
    mutable_mask: np.ndarray,
    lam: float,
    clip_nonneg: bool = True,
) -> np.ndarray:
    """
    Shift every row of X by the same best-response step on the mutable features.

    x' = x + (1/lam) * Sigma_mm @ w_m   (only mutable dims)

    Sigma_mm is Sigma restricted to the rows/columns selected by mutable_mask
    and w_m is w restricted the same way.  Columns outside the mask receive no
    shift.  When clip_nonneg is true the whole result (frozen columns
    included) is clipped at 0 from below.

    X is not modified; a new float array is returned.
    NOTE: Only changes X, NEVER changes y.
    """
    weights = np.asarray(w, dtype=float).reshape(-1)
    mutable_cols = np.flatnonzero(mutable_mask)

    # One shift vector shared by all rows.
    Sigma_sub = Sigma[np.ix_(mutable_cols, mutable_cols)]
    shift = (Sigma_sub @ weights[mutable_cols]) / float(lam)

    # np.array always copies, so the caller's X stays untouched.
    moved = np.array(X, dtype=float)
    moved[:, mutable_cols] += shift

    return np.clip(moved, 0.0, None) if clip_nonneg else moved

def align_feature_order(ctx_df: pd.DataFrame, raw_df: pd.DataFrame) -> list[str]:
    """Return the context frame's feature columns, in context order.

    Every column of ``ctx_df`` other than ``LABEL_COL`` counts as a feature.
    Each of them must also be a non-label column of ``raw_df``.

    Raises:
        ValueError: If ``raw_df`` lacks one or more context feature columns;
            the message lists up to five of them.
    """
    raw_features = {col for col in raw_df.columns if col != LABEL_COL}

    ctx_features = []
    absent = []
    for col in ctx_df.columns:
        if col == LABEL_COL:
            continue
        ctx_features.append(col)
        if col not in raw_features:
            absent.append(col)

    if absent:
        raise ValueError(f"RAW is missing {len(absent)} feature cols from CTX, e.g. {absent[:5]}")
    return ctx_features



# ==============================================================================
# 4) Main
# ==============================================================================
def _fit_accepts_output_dir(fit_method) -> bool:
    """Return True when ``fit_method`` can be called with an ``output_dir`` keyword."""
    import inspect

    try:
        params = inspect.signature(fit_method).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected: fall back to the plain two-argument call.
        return False
    return "output_dir" in params or any(
        param.kind is inspect.Parameter.VAR_KEYWORD for param in params.values()
    )


def main():
    """Compare baseline TabPFN against a finetuned TabPFN on manipulated Spambase inputs.

    Steps:
      1. Patch ``clone_model_for_evaluation`` and create ``OUT_DIR``.
      2. Load ``RAW_CSV`` and ``CTX_CSV`` and align their feature columns.
      3. Fit a logistic regression on all raw rows and use its weights to
         build the manipulated copy of the raw inputs (labels are untouched).
      4. Fit the baseline ``TabPFNClassifier`` on the context rows and score it
         on the manipulated inputs.
      5. Finetune ``FinetunedTabPFNClassifier`` on the same context rows and
         score it on the same manipulated inputs.
      6. Print the accuracy report and write it to ``REPORT_TXT``.

    Raises:
        FileNotFoundError: If ``BASE_CKPT`` does not exist.
        ValueError: If either CSV lacks ``LABEL_COL`` or the raw CSV lacks a
            feature column present in the context CSV.
    """
    apply_patch()
    ensure_dir(OUT_DIR)

    base_ckpt = Path(BASE_CKPT).expanduser().resolve()
    if not base_ckpt.exists():
        raise FileNotFoundError(f"Base ckpt not found: {base_ckpt}")

    raw_df = pd.read_csv(RAW_CSV)
    ctx_df = pd.read_csv(CTX_CSV)

    if LABEL_COL not in raw_df.columns:
        raise ValueError(f"RAW CSV must contain label col '{LABEL_COL}'")
    if LABEL_COL not in ctx_df.columns:
        raise ValueError(f"CTX CSV must contain label col '{LABEL_COL}'")

    feat_cols = align_feature_order(ctx_df, raw_df)

    # full raw (for the deployed rule + full test)
    X_all = raw_df[feat_cols].to_numpy(dtype=float)
    y_all = raw_df[LABEL_COL].to_numpy(dtype=int).reshape(-1)  # labels unchanged

    # strategic context (for TabPFN fit/finetune)
    X_ctx = ctx_df[feat_cols].to_numpy(dtype=float)
    y_ctx = ctx_df[LABEL_COL].to_numpy(dtype=int).reshape(-1)

    # train the deployed rule on ALL raw rows (its weights drive the best response)
    deployed_rule = LogisticRegression(solver="liblinear", random_state=SEED, max_iter=4000)
    deployed_rule.fit(X_all, y_all)
    w = deployed_rule.coef_.reshape(-1)

    # Sigma + freeze mask (must match context gen)
    Sigma = compute_covariance(X_all, ridge=COV_RIDGE)
    mutable_mask = make_mutable_mask(len(feat_cols), FREEZE_FRAC, seed=SEED)

    # full manipulated test inputs
    X_all_manip = best_response_mahalanobis(
        X=X_all, w=w, Sigma=Sigma, mutable_mask=mutable_mask,
        lam=LAMBDA_COST, clip_nonneg=CLIP_NONNEG
    )

    # -----------------------
    # Baseline (ICL fit)
    # -----------------------
    base = TabPFNClassifier(
        model_path=str(base_ckpt),
        device=DEVICE,
        n_estimators=N_ESTIMATORS_BASE,
        ignore_pretraining_limits=True,
        differentiable_input=False,
        random_state=int(SEED),
        fit_mode="fit_preprocessors",
    )
    base.fit(X_ctx, y_ctx)
    pred_base = base.predict(X_all_manip)
    acc_base = accuracy_score(y_all, pred_base)

    # -----------------------
    # Official finetune
    # -----------------------
    ft_out = OUT_DIR / f"ft_run__{now_str()}"
    ensure_dir(ft_out)

    ft = FinetunedTabPFNClassifier(
        device=DEVICE,
        epochs=int(EPOCHS_FT),
        learning_rate=float(LR_FT),
        weight_decay=float(WD_FT),
        validation_split_ratio=0.1,
        early_stopping=False,
        n_estimators_finetune=1,
        n_estimators_validation=1,
        n_estimators_final_inference=2,
        extra_classifier_kwargs={
            "model_path": str(base_ckpt),
            "ignore_pretraining_limits": True,
        },
    )

    # Decide from the signature whether fit() takes output_dir instead of
    # catching TypeError around the call: a TypeError raised inside training
    # would otherwise be swallowed and silently trigger a second full finetune.
    if _fit_accepts_output_dir(ft.fit):
        ft.fit(X_ctx, y_ctx, output_dir=ft_out)
    else:
        ft.fit(X_ctx, y_ctx)

    pred_ft = ft.predict(X_all_manip)
    acc_ft = accuracy_score(y_all, pred_ft)

    # -----------------------
    # report
    # -----------------------
    frozen = int((~mutable_mask).sum())
    lines = [
        "=" * 80,
        "Spambase | Strategic Context Finetune (Classifier) | FULL manipulated test | ACC only",
        f"time={now_str()}",
        f"DEVICE={DEVICE}",
        f"Base ckpt={base_ckpt}",
        f"Context rows={X_ctx.shape[0]} | dims={X_ctx.shape[1]}",
        f"FULL test rows={X_all.shape[0]} (all manipulated) | label unchanged: raw['{LABEL_COL}']",
        f"Freeze: {frozen}/{len(feat_cols)} ({FREEZE_FRAC*100:.1f}%) | lambda={LAMBDA_COST} | ridge={COV_RIDGE}",
        "-" * 80,
        f"[BASELINE TabPFNClassifier.fit]     ACC={acc_base:.6f}",
        f"[FINETUNE FinetunedTabPFNClassifier.fit] ACC={acc_ft:.6f}",
        f"Finetune output_dir={ft_out}",
        "=" * 80,
    ]
    report_text = "\n".join(lines)

    print(report_text, flush=True)
    # Mode "w": the report of a previous run is overwritten.
    with open(REPORT_TXT, "w", encoding="utf-8") as report_file:
        report_file.write(report_text + "\n")
    print(f"\nSaved: {REPORT_TXT}", flush=True)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
