# -*- coding: utf-8 -*-
from __future__ import annotations

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score

from tabpfn import TabPFNClassifier
# from tabpfn.constants import ModelVersion


# =============================
# Config
# =============================
# Input files: the full raw dataset, and the strategic context set that
# TabPFN is fit on.
RAW_CSV_PATH: str = "spambase_data.csv"
CTX_CSV_PATH: str = "spambase_strategic_context.csv"

# Seed shared by the feature-freeze draw and LogisticRegression; name of the
# label column expected in both CSV files.
SEED: int = 42
LABEL_COL: str = "Spam"

# Mahalanobis best-response parameters:
#   FREEZE_FRAC - fraction of features frozen (never moved by the best response)
#   LAMBDA_COST - movement cost; the manipulation step is scaled by 1 / LAMBDA_COST
#   COV_RIDGE   - value added to the covariance diagonal
#   CLIP_NONNEG - clip the best-response output at zero from below
FREEZE_FRAC: float = 0.10
LAMBDA_COST: float = 6.0
COV_RIDGE: float = 1e-3
CLIP_NONNEG: bool = True



def make_mutable_mask(d: int, freeze_frac: float, seed: int) -> np.ndarray:
    """
    Return a length-``d`` boolean mask marking which features may be changed.

    ``int(np.round(d * freeze_frac))`` distinct indices (np.round rounds ties
    to even) are drawn without replacement from ``np.random.default_rng(seed)``
    and set to False (frozen). Every other entry is True (mutable). The same
    (d, freeze_frac, seed) always yields the same mask.
    """
    generator = np.random.default_rng(seed)
    num_frozen = int(np.round(d * freeze_frac))
    # Keep this exact call: any change to how indices are drawn would change
    # which features are frozen for a given seed.
    frozen = generator.choice(np.arange(d), size=num_frozen, replace=False)

    mask = np.full(d, True, dtype=bool)
    mask[frozen] = False
    return mask


def compute_covariance(X: np.ndarray, ridge: float) -> np.ndarray:
    """
    Return the unbiased sample covariance of the columns of X, plus ``ridge``
    on the diagonal.

    Parameters
    ----------
    X : array of shape (n_samples, n_features). A 1-D array is treated as a
        single feature observed n_samples times.
    ridge : value added to every diagonal entry. With ridge > 0 the result is
        strictly positive definite even when X's columns are collinear.

    Returns
    -------
    np.ndarray
        Always 2-D, of shape (n_features, n_features).

    Raises
    ------
    ValueError
        If X has fewer than 2 rows. The unbiased estimate divides by n - 1, so
        it would otherwise be a NaN matrix.
    """
    X = np.asarray(X, dtype=float)
    if X.shape[0] < 2:
        raise ValueError(
            f"need at least 2 samples to estimate a covariance, got {X.shape[0]}"
        )
    # np.cov returns a 0-d array when there is only one feature; atleast_2d
    # keeps the (d, d) contract so Sigma.shape[0] and Sigma[i, j] always work.
    Sigma = np.atleast_2d(np.cov(X, rowvar=False, bias=False))
    d = Sigma.shape[0]
    return Sigma + ridge * np.eye(d)


def best_response_mahalanobis(
    X: np.ndarray,
    w: np.ndarray,
    Sigma: np.ndarray,
    mutable_mask: np.ndarray,
    lam: float,
    clip_nonneg: bool = True,
) -> np.ndarray:
    """
    Shift every row of X by the same step along Sigma_mm @ w_m.

        x' = x + (1/lam) * Sigma_mm @ w_m   (only on mutable dims)

    NOTE: Only changes X, never touches y. Frozen dims are never altered.

    When Sigma is positive semi-definite (as compute_covariance produces), the
    unclipped step never decreases the linear score w @ x, since the change is
    w_m @ Sigma_mm @ w_m / lam >= 0.
    NOTE(review): rows are pushed toward the class with the positive score;
    confirm this direction matches the intended agent incentive.

    Parameters
    ----------
    X : (n, d) feature matrix. It is not modified; a float copy is returned.
    w : weight vector that flattens to exactly d entries.
    Sigma : (d, d) matrix. Only the mutable-by-mutable sub-block is used.
    mutable_mask : length-d boolean mask; True marks dims that may move.
    lam : movement cost; must be > 0 (the step is scaled by 1 / lam).
    clip_nonneg : if True, clip the moved (mutable) dims at 0 from below.

    Returns
    -------
    np.ndarray
        (n, d) float array: X with the mutable dims shifted.

    Raises
    ------
    ValueError
        If lam is not > 0, X is not 2-D, or w / mutable_mask do not have one
        entry per column of X.
    """
    X = np.asarray(X, dtype=float)
    w = np.asarray(w, dtype=float).reshape(-1)
    mask = np.asarray(mutable_mask, dtype=bool).reshape(-1)

    # `not lam > 0` also rejects NaN. lam == 0 would otherwise give an inf
    # step, and lam < 0 would silently reverse the direction.
    if not lam > 0:
        raise ValueError(f"lam must be > 0, got {lam!r}")
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D (n, d), got shape {X.shape}")
    d = X.shape[1]
    if w.shape[0] != d or mask.shape[0] != d:
        raise ValueError(
            f"w ({w.shape[0]}) and mutable_mask ({mask.shape[0]}) must both "
            f"have one entry per column of X ({d})"
        )

    m_idx = np.flatnonzero(mask)
    Sigma_mm = np.asarray(Sigma, dtype=float)[np.ix_(m_idx, m_idx)]
    delta_m = (Sigma_mm @ w[m_idx]) / float(lam)

    # Same step for every row (broadcast over the sample axis).
    moved = X[:, m_idx] + delta_m[None, :]
    if clip_nonneg:
        # Clip only the dims we moved, so frozen features stay byte-identical
        # to the input even if they are negative.
        moved = np.clip(moved, 0.0, None)

    X_tilde = X.copy()
    X_tilde[:, m_idx] = moved
    return X_tilde


def ensure_same_feature_order(df_ctx: pd.DataFrame, df_raw: pd.DataFrame) -> list[str]:
    """
    Return the context frame's feature columns (all columns except LABEL_COL)
    in the context frame's column order.

    That order is treated as canonical for both frames.

    Raises ValueError if df_raw lacks any of these columns. The message reports
    how many are missing and shows up to five of them.
    """
    available = set(df_raw.columns) - {LABEL_COL}
    feature_order = [col for col in df_ctx.columns if col != LABEL_COL]

    absent = [col for col in feature_order if col not in available]
    if absent:
        raise ValueError(f"Raw is missing {len(absent)} context feature cols, e.g. {absent[:5]}")
    return feature_order


# =============================
# Main
# =============================
def _load_labeled_csv(path: str, which: str) -> pd.DataFrame:
    """Read the CSV at ``path`` and require that it has a LABEL_COL column.

    ``which`` only names the file in the error message (e.g. "RAW_CSV").
    """
    df = pd.read_csv(path)
    if LABEL_COL not in df.columns:
        raise ValueError(f"{which} must contain label column '{LABEL_COL}'")
    return df


def _positive_class_column(classes: np.ndarray, y_true: np.ndarray) -> int:
    """Return the predict_proba column for the positive label.

    For binary targets, roc_auc_score treats the greater label in ``y_true`` as
    positive, so the score must be the probability of that label. Looking it up
    in ``classes`` avoids assuming it sits in column 1.
    """
    pos_label = np.max(y_true)
    hits = np.flatnonzero(np.asarray(classes) == pos_label)
    if hits.size == 0:
        raise ValueError(
            f"Positive label {pos_label!r} not among classifier classes {list(classes)}"
        )
    return int(hits[0])


def main():
    """
    Evaluate TabPFN (fit on the strategic context) on manipulated raw data.

    Steps:
      1. Load the raw dataset and the strategic context; both must have LABEL_COL.
      2. Fit a logistic-regression "deployed rule" f on all raw rows. Its
         weights give the manipulation direction.
      3. Fit TabPFN on the strategic context only.
      4. Move every raw row with best_response_mahalanobis (labels unchanged),
         then print TabPFN's ROC AUC and accuracy on the moved rows.
    """
    # ---- Load raw full dataset and strategic context
    df_raw = _load_labeled_csv(RAW_CSV_PATH, "RAW_CSV")
    df_ctx = _load_labeled_csv(CTX_CSV_PATH, "CTX_CSV")

    # ---- Align features using context column order
    feat_cols = ensure_same_feature_order(df_ctx, df_raw)

    # Full raw (for training f and full testing)
    X_all = df_raw[feat_cols].to_numpy(dtype=float)
    y_all = df_raw[LABEL_COL].to_numpy(dtype=int).reshape(-1)

    # ---- Train deployed rule f on ALL raw
    f = LogisticRegression(solver="liblinear", random_state=SEED, max_iter=4000)
    f.fit(X_all, y_all)
    # A binary fit has a single coefficient row. More rows would mean a
    # multiclass label, and flattening them would mix per-class weights.
    if f.coef_.shape[0] != 1:
        raise ValueError(
            f"Expected a binary label in '{LABEL_COL}', got {len(f.classes_)} classes"
        )
    w = f.coef_.reshape(-1)

    # ---- Estimate Sigma (Mahalanobis)
    Sigma = compute_covariance(X_all, ridge=COV_RIDGE)

    # ---- Freeze a FREEZE_FRAC share of the features. Per the original note,
    # (d, FREEZE_FRAC, SEED) must match the run that generated CTX_CSV_PATH so
    # the same features are frozen here.
    d = len(feat_cols)
    mutable_mask = make_mutable_mask(d, FREEZE_FRAC, SEED)

    # ---- TabPFN: fit on strategic context only
    X_ctx = df_ctx[feat_cols].to_numpy(dtype=float)
    y_ctx = df_ctx[LABEL_COL].to_numpy(dtype=int).reshape(-1)

    clf = TabPFNClassifier(device="auto")
    # If you want v2 explicitly:
    # clf = TabPFNClassifier.create_default_for_version(ModelVersion.V2)

    clf.fit(X_ctx, y_ctx)

    # ---- FULL TEST: manipulate ALL samples, labels unchanged
    X_all_manip = best_response_mahalanobis(
        X=X_all, w=w, Sigma=Sigma, mutable_mask=mutable_mask,
        lam=LAMBDA_COST, clip_nonneg=CLIP_NONNEG
    )

    # ---- Predict with TabPFN in a single inference pass. The hard label is
    # the argmax of the class probabilities (sklearn predict/predict_proba
    # convention), so calling predict() too would run the model twice.
    prediction_probabilities = clf.predict_proba(X_all_manip)
    classes = np.asarray(clf.classes_)
    predictions = classes[np.argmax(prediction_probabilities, axis=1)]

    # ---- Evaluate
    acc = accuracy_score(y_all, predictions)
    pos_col = _positive_class_column(classes, y_all)
    auc = roc_auc_score(y_all, prediction_probabilities[:, pos_col])

    frozen = int((~mutable_mask).sum())
    print(f"[OK] TabPFN fit on strategic context: n_ctx={len(X_ctx)} d={d}")
    print(f"[OK] FULL TEST on ALL manipulated X (labels unchanged): n_test={len(X_all_manip)}")
    print(f"[OK] Frozen features: {frozen}/{d} ({FREEZE_FRAC*100:.1f}%)")
    print("ROC AUC:", auc)
    print("Accuracy:", acc)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
