# prosf_sc_spambase.py
# -*- coding: utf-8 -*-

import argparse
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score


# ----------------------------
# Pro-SF components
# ----------------------------
def prelec_weight(p: torch.Tensor, alpha: float = 0.7, eps: float = 1e-6) -> torch.Tensor:
    """Apply the Prelec probability weighting ``w(p) = exp(-(-ln p) ** alpha)``.

    Args:
        p: Tensor of probabilities.
        alpha: Exponent of the weighting function.
        eps: Margin used to clamp ``p`` into ``[eps, 1 - eps]`` before the
            logarithm is taken. It should be large enough that ``1 - eps`` is
            distinguishable from 1 in the dtype of ``p``.

    Returns:
        Tensor of weighted probabilities with the same shape as ``p``.
    """
    # Clamp away from BOTH ends. Near 0 the logarithm diverges. At exactly 1,
    # -log(p) is 0 and the derivative of u ** alpha is infinite for alpha < 1,
    # so a saturated sigmoid (p == 1.0 in float32) would back-propagate
    # inf/NaN into whatever is being optimized.
    p = torch.clamp(p, eps, 1.0 - eps)
    return torch.exp(-torch.pow(-torch.log(p), alpha))


def prospect_value(delta: torch.Tensor, loss_aversion: float = 2.0) -> torch.Tensor:
    """Piecewise-linear prospect value around a reference point of 0.

    Non-negative entries of ``delta`` are returned unchanged; negative
    entries are multiplied by ``loss_aversion``.
    """
    is_loss = delta < 0
    amplified = loss_aversion * delta
    return torch.where(is_loss, amplified, delta)


def clamp_box(x: torch.Tensor, lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Clip ``x`` element-wise to the box ``[lo, hi]`` without tracking gradients.

    The upper bound is applied first and the lower bound second, so where
    ``lo`` exceeds ``hi`` the result equals ``lo``.
    """
    with torch.no_grad():
        capped = torch.minimum(x, hi)
        return torch.maximum(capped, lo)


def best_response_prosf(
    x_np: np.ndarray,
    w_np: np.ndarray,
    b_np: float,
    x_min: np.ndarray,
    x_max: np.ndarray,
    *,
    steps: int = 60,
    lr: float = 0.05,
    cost_l2: float = 0.02,
    linf_budget: float = 1.5,
    loss_aversion_lambda: float = 2.0,
    prelec_alpha: float = 0.7,
    abstain_if_no_gain: bool = True,
) -> np.ndarray:
    """
    Pro-SF-like best response, computed by projected gradient ascent (Adam) on
      U = v( w(p(x')) - w(p_ref) ) - cost_l2 * ||x'-x||^2
    where:
      p(x)  = sigmoid(w . x + b)
      p_ref = p(x)  (reference dependence)
      w(.)   = Prelec probability weighting
      v(.)   = loss aversion around 0

    Args:
        x_np: 1-D feature vector of the agent (the starting point).
        w_np: 1-D weight vector of the linear classifier.
        b_np: Intercept of the linear classifier.
        x_min: Per-feature lower bounds of the feasible box.
        x_max: Per-feature upper bounds of the feasible box.
        steps: Number of optimizer steps.
        lr: Adam learning rate.
        cost_l2: Coefficient of the squared-L2 movement cost.
        linf_budget: Maximum absolute change allowed per feature.
        loss_aversion_lambda: Multiplier applied to negative value differences.
        prelec_alpha: Exponent of the Prelec weighting.
        abstain_if_no_gain: If True, return the original point unless the
            weighted acceptance probability strictly increased.

    Returns:
        The manipulated feature vector as a float32 array, or a copy of
        ``x_np`` (in its own dtype) when the agent abstains or the
        optimization produced non-finite values.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x0 = torch.tensor(x_np, dtype=torch.float32, device=device)
    w = torch.tensor(w_np, dtype=torch.float32, device=device)
    b = torch.tensor(float(b_np), dtype=torch.float32, device=device)

    # Widen the box so it always contains the starting point. Otherwise a
    # point that starts outside [x_min, x_max] (e.g. a test sample beyond the
    # training range) would be dragged into the box by the projection, moving
    # it by more than linf_budget allows.
    lo = torch.minimum(torch.tensor(x_min, dtype=torch.float32, device=device), x0)
    hi = torch.maximum(torch.tensor(x_max, dtype=torch.float32, device=device), x0)

    x = x0.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x], lr=lr)

    # Reference weighted probability at the unmodified point.
    with torch.no_grad():
        p0 = torch.sigmoid(torch.dot(w, x0) + b)
        wpref = prelec_weight(p0, alpha=prelec_alpha)

    for _ in range(steps):
        opt.zero_grad()

        p = torch.sigmoid(torch.dot(w, x) + b)
        wp = prelec_weight(p, alpha=prelec_alpha)
        delta = wp - wpref
        v = prospect_value(delta, loss_aversion=loss_aversion_lambda)

        cost = cost_l2 * torch.sum((x - x0) ** 2)
        U = v - cost
        # The optimizer minimizes, so descend on -U to ascend on U.
        (-U).backward()
        opt.step()

        # Project back onto the feasible set: L-inf ball around x0, then box.
        with torch.no_grad():
            dx = torch.clamp(x - x0, -linf_budget, linf_budget)
            x.copy_(x0 + dx)
            x.copy_(clamp_box(x, lo, hi))

    x_star = x.detach()

    with torch.no_grad():
        # A diverged optimization is never a valid response; keep the original.
        if not bool(torch.isfinite(x_star).all()):
            return x_np.copy()

        if abstain_if_no_gain:
            p_star = torch.sigmoid(torch.dot(w, x_star) + b)
            wp_star = prelec_weight(p_star, alpha=prelec_alpha)
            gain = (wp_star - wpref).item()
            # "not gain > 0" (rather than "gain <= 0") also abstains on NaN.
            if not gain > 0:
                return x_np.copy()

    return x_star.cpu().numpy()


def manipulate_dataset_prosf(
    X: np.ndarray,
    y: np.ndarray,
    clf: LogisticRegression,
    x_min: np.ndarray,
    x_max: np.ndarray,
    *,
    steps: int,
    lr: float,
    cost_l2: float,
    linf_budget: float,
    loss_aversion_lambda: float,
    prelec_alpha: float,
) -> np.ndarray:
    """Return a copy of ``X`` in which every row with ``y == 0`` is replaced
    by its Pro-SF best response against the linear classifier ``clf``.

    Rows with any other label are left untouched, and ``X`` itself is not
    modified. The hyperparameters are forwarded to ``best_response_prosf``
    with abstention enabled.
    """
    weights = clf.coef_.reshape(-1).astype(np.float32)
    bias = float(clf.intercept_.reshape(-1)[0])

    response_kwargs = dict(
        steps=steps,
        lr=lr,
        cost_l2=cost_l2,
        linf_budget=linf_budget,
        loss_aversion_lambda=loss_aversion_lambda,
        prelec_alpha=prelec_alpha,
        abstain_if_no_gain=True,
    )

    manipulated = X.copy()
    # only spammers (label 0) manipulate
    for row in np.nonzero(y == 0)[0]:
        manipulated[row] = best_response_prosf(
            manipulated[row], weights, bias, x_min, x_max, **response_kwargs
        )
    return manipulated


def train_logreg(X: np.ndarray, y: np.ndarray, seed: int) -> LogisticRegression:
    """Fit a logistic regression (lbfgs solver, up to 2000 iterations) on
    ``(X, y)`` with ``random_state=seed`` and return the fitted estimator."""
    model = LogisticRegression(random_state=seed, max_iter=2000, solver="lbfgs")
    # fit() returns the estimator itself.
    return model.fit(X, y)


def main():
    """Run the Pro-SF strategic-classification experiment on a local CSV.

    Parses the command-line options, loads the data, derives the acceptance
    label, splits and standardizes the features, retrains a logistic
    regression against Pro-SF-manipulated training data for ``--rounds``
    rounds, and prints clean accuracy, strategic accuracy and the spam
    flip-to-accept rate on the test split.

    Raises:
        ValueError: If the label column is not present in the CSV.
    """
    parser = argparse.ArgumentParser()
    cli_options = (
        ("--data_path", str, "data.csv"),
        ("--label_col", str, "Sapm"),
        ("--rounds", int, 5),
        ("--seed", int, 0),
        # manipulation hyperparams
        ("--steps", int, 60),
        ("--lr", float, 0.05),
        ("--cost_l2", float, 0.02),
        ("--linf_budget", float, 1.5),
        ("--loss_lambda", float, 2.0),
        ("--prelec_alpha", float, 0.7),
    )
    for flag, caster, default in cli_options:
        parser.add_argument(flag, type=caster, default=default)

    args = parser.parse_args()

    # 1) Load local CSV
    df = pd.read_csv(args.data_path)
    if args.label_col not in df.columns:
        raise ValueError(f"Label col '{args.label_col}' not found. Columns: {list(df.columns)[:10]} ...")

    spam_flags = df[args.label_col].to_numpy()
    features = df.drop(columns=[args.label_col]).to_numpy(dtype=np.float32)
    spam_flags = np.asarray(spam_flags).astype(int)

    # label semantics: label column == 1 means spam;
    # the acceptance label is its complement (1 = accepted, 0 = spam).
    y = 1 - spam_flags

    # 2) Split
    X_train, X_test, y_train, y_test = train_test_split(
        features, y, test_size=0.25, random_state=args.seed, stratify=y
    )

    # 3) Standardize (statistics come from the training split only)
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train).astype(np.float32)
    X_test_s = scaler.transform(X_test).astype(np.float32)

    # Per-feature box bounds taken from the standardized training data.
    x_min = np.min(X_train_s, axis=0)
    x_max = np.max(X_train_s, axis=0)

    # Hyperparameters shared by every manipulation call below.
    manip_kwargs = dict(
        steps=args.steps,
        lr=args.lr,
        cost_l2=args.cost_l2,
        linf_budget=args.linf_budget,
        loss_aversion_lambda=args.loss_lambda,
        prelec_alpha=args.prelec_alpha,
    )

    # 4) Defense training loop: each round manipulates the clean training
    #    set against the current classifier, then refits on the result.
    X_base = X_train_s.copy()
    y_base = y_train.copy()

    clf = train_logreg(X_base, y_base, seed=args.seed)

    for _ in range(args.rounds):
        X_round = manipulate_dataset_prosf(
            X_base, y_base, clf, x_min, x_max, **manip_kwargs
        )
        clf = train_logreg(X_round, y_base, seed=args.seed)

    # 5) Evaluation
    acc_clean = accuracy_score(y_test, clf.predict(X_test_s))

    X_test_man = manipulate_dataset_prosf(
        X_test_s, y_test, clf, x_min, x_max, **manip_kwargs
    )
    acc_strat = accuracy_score(y_test, clf.predict(X_test_man))

    # Fraction of true-spam test rows predicted 0 before manipulation and 1 after.
    spam_rows = np.where(y_test == 0)[0]
    if len(spam_rows) == 0:
        flip_rate = float("nan")
    else:
        pred_before = clf.predict(X_test_s[spam_rows])
        pred_after = clf.predict(X_test_man[spam_rows])
        flipped = (pred_before == 0) & (pred_after == 1)
        flip_rate = float(np.mean(flipped))

    print("\n[Pro-SF SC]")
    print(f"data={args.data_path}, label_col={args.label_col}, rounds={args.rounds}")
    print(f"loss_lambda={args.loss_lambda}, prelec_alpha={args.prelec_alpha}")
    print(f"Clean test acc:     {acc_clean:.4f}")
    print(f"Strategic test acc: {acc_strat:.4f}")
    print(f"Spam flip-to-accept rate (y=0 -> pred=1): {flip_rate:.4f}\n")


# Script entry point: run the experiment only when executed directly.
if __name__ == "__main__":
    # Limit PyTorch to one intra-op CPU thread before running.
    torch.set_num_threads(1)
    main()
