# rational_sc_spambase.py
# -*- coding: utf-8 -*-

import argparse
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score



@torch.no_grad()
def clamp_box(x: torch.Tensor, lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Project ``x`` element-wise onto the box ``[lo, hi]``.

    The upper bound is applied first and the lower bound second, so wherever
    ``lo > hi`` the result equals ``lo``. The function runs with autograd
    disabled, so the returned tensor carries no gradient history.
    """
    capped = torch.minimum(x, hi)
    return torch.maximum(capped, lo)


def best_response_rational(
    x_np: np.ndarray,
    w_np: np.ndarray,
    b_np: float,
    x_min: np.ndarray,
    x_max: np.ndarray,
    *,
    steps: int = 40,
    lr: float = 0.05,
    cost_l2: float = 0.02,
    linf_budget: float = 1.5,
    abstain_if_no_gain: bool = True,
) -> np.ndarray:
    """
    Approximate the rational best response of one agent to a linear classifier.

    Objective (maximized with projected Adam steps):
      U(x') = p(x') - cost_l2 * ||x' - x||^2,   p(x) = sigmoid(w^T x + b)

    Constraints (re-imposed by projection after every optimizer step):
      L_inf: |x' - x| <= linf_budget   (applied first)
      box:   x' in [x_min, x_max]      (applied second)

    Args:
      x_np: the agent's original 1-D feature vector.
      w_np: classifier weight vector, same length as ``x_np``.
      b_np: classifier intercept.
      x_min, x_max: per-feature lower / upper bounds of the box.
      steps: number of Adam iterations.
      lr: Adam learning rate.
      cost_l2: coefficient of the squared-L2 manipulation cost.
      linf_budget: maximum absolute change allowed per feature.
      abstain_if_no_gain: when True, the agent keeps its original features
        unless the manipulated point has strictly higher *net utility*
        (probability minus cost) than staying put.

    Returns:
      The manipulated feature vector as a float32 array, or a copy of
      ``x_np`` when the agent abstains.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x0 = torch.tensor(x_np, dtype=torch.float32, device=device)
    w = torch.tensor(w_np, dtype=torch.float32, device=device)
    b = torch.tensor(float(b_np), dtype=torch.float32, device=device)
    lo = torch.tensor(x_min, dtype=torch.float32, device=device)
    hi = torch.tensor(x_max, dtype=torch.float32, device=device)

    x = x0.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x], lr=lr)

    for _ in range(steps):
        opt.zero_grad()
        p = torch.sigmoid(torch.dot(w, x) + b)
        cost = cost_l2 * torch.sum((x - x0) ** 2)
        # Adam minimizes, so descend on the negated utility.
        (cost - p).backward()
        opt.step()

        with torch.no_grad():
            # Project back onto the feasible set: L_inf ball around x0 first,
            # then the feature box (upper bound before lower bound).
            dx = torch.clamp(x - x0, -linf_budget, linf_budget)
            x.copy_(x0 + dx)
            x.copy_(torch.max(torch.min(x, hi), lo))

    x_star = x.detach()

    if abstain_if_no_gain:
        with torch.no_grad():
            # Staying put costs nothing, so its utility is just p(x0).
            u_stay = torch.sigmoid(torch.dot(w, x0) + b)
            u_move = (
                torch.sigmoid(torch.dot(w, x_star) + b)
                - cost_l2 * torch.sum((x_star - x0) ** 2)
            )
        # A rational agent never pays more in cost than it gains in
        # acceptance probability; comparing probabilities alone would let an
        # unconverged iterate with negative net utility through.
        if (u_move - u_stay).item() <= 0:
            return x_np.copy()

    return x_star.cpu().numpy()


def manipulate_dataset_rational(
    X: np.ndarray,
    y: np.ndarray,
    clf: LogisticRegression,
    x_min: np.ndarray,
    x_max: np.ndarray,
    *,
    steps: int,
    lr: float,
    cost_l2: float,
    linf_budget: float,
) -> np.ndarray:
    """
    Return a copy of X in which every y=0 row is replaced by its best response.

    Only the y=0 agents (spammers) manipulate, aiming to be predicted as
    1 (ham/accept) by ``clf``; rows with any other label are left untouched.
    The input ``X`` itself is never modified. Each manipulating row is solved
    independently with abstention enabled, using the first intercept and the
    flattened coefficient vector of ``clf``.
    """
    weights = np.ravel(clf.coef_).astype(np.float32)
    bias = float(np.ravel(clf.intercept_)[0])

    manipulated = X.copy()
    spammer_rows = np.nonzero(y == 0)[0]
    for row in spammer_rows:
        manipulated[row] = best_response_rational(
            manipulated[row],
            weights,
            bias,
            x_min,
            x_max,
            steps=steps,
            lr=lr,
            cost_l2=cost_l2,
            linf_budget=linf_budget,
            abstain_if_no_gain=True,
        )
    return manipulated


def train_logreg(X: np.ndarray, y: np.ndarray, seed: int) -> LogisticRegression:
    """Fit a logistic regression (L-BFGS, up to 2000 iterations) on (X, y) and return it."""
    model = LogisticRegression(max_iter=2000, solver="lbfgs", random_state=seed)
    # ``fit`` hands back the estimator it was called on.
    return model.fit(X, y)


def main():
    """Run the rational strategic-classification experiment on a local CSV.

    Steps:
      1. Load the CSV and turn the spam label into an acceptance label
         (spam=1 -> y=0, ham=0 -> y=1).
      2. Stratified 75/25 train/test split.
      3. Standardize features with statistics from the training split and
         take the scaled training min/max as the feasible feature box.
      4. Train a classifier, then for ``--rounds`` rounds let the y=0 training
         agents best-respond to the current classifier and refit on the
         manipulated features (labels unchanged).
      5. Print clean accuracy, accuracy under test-time manipulation, and the
         fraction of test spammers flipped from rejected to accepted.

    Raises:
      ValueError: if the label column is missing, or if it contains values
        other than 0 and 1 after integer conversion.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_path", type=str, default="data.csv")
    ap.add_argument("--label_col", type=str, default="Sapm")  # user specified
    ap.add_argument("--rounds", type=int, default=1)
    ap.add_argument("--seed", type=int, default=10)

    # manipulation hyperparams
    ap.add_argument("--steps", type=int, default=40)
    ap.add_argument("--lr", type=float, default=0.05)
    ap.add_argument("--cost_l2", type=float, default=0.02)
    ap.add_argument("--linf_budget", type=float, default=1.5)

    args = ap.parse_args()

    # 1) Load local CSV
    df = pd.read_csv(args.data_path)
    if args.label_col not in df.columns:
        raise ValueError(f"Label col '{args.label_col}' not found. Columns: {list(df.columns)[:10]} ...")

    y_raw = df[args.label_col].to_numpy()
    X = df.drop(columns=[args.label_col]).to_numpy(dtype=np.float32)

    # Labels must be 0/1 ints. Casting alone does not guarantee that, and any
    # other value would silently corrupt the `1 - y_raw` flip below.
    y_raw = np.asarray(y_raw).astype(int)
    if not ((y_raw == 0) | (y_raw == 1)).all():
        raise ValueError(
            f"Label col '{args.label_col}' must contain only 0/1 values; "
            f"found {np.unique(y_raw)[:10].tolist()} ..."
        )

    # The label column marks spam. Work with the acceptance label
    # y = 1 - spam instead, so that:
    #   spam=1 -> y=0 (manipulators)
    #   ham =0 -> y=1
    y = 1 - y_raw

    # 2) Train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=args.seed, stratify=y
    )

    # 3) Standardize (statistics come from the training split only)
    scaler = StandardScaler()
    X_train_s = scaler.fit_transform(X_train).astype(np.float32)
    X_test_s = scaler.transform(X_test).astype(np.float32)

    # Bounds in scaled space
    x_min = X_train_s.min(axis=0)
    x_max = X_train_s.max(axis=0)

    # 4) Defense training loop
    X_curr = X_train_s.copy()
    y_curr = y_train.copy()

    clf = train_logreg(X_curr, y_curr, seed=args.seed)

    for _ in range(args.rounds):
        # Agents always start from their original (unmanipulated) features and
        # respond to the most recent classifier.
        X_man = manipulate_dataset_rational(
            X_curr, y_curr, clf, x_min, x_max,
            steps=args.steps, lr=args.lr, cost_l2=args.cost_l2, linf_budget=args.linf_budget,
        )
        clf = train_logreg(X_man, y_curr, seed=args.seed)

    # 5) Evaluation
    y_pred_clean = clf.predict(X_test_s)
    acc_clean = accuracy_score(y_test, y_pred_clean)

    X_test_man = manipulate_dataset_rational(
        X_test_s, y_test, clf, x_min, x_max,
        steps=args.steps, lr=args.lr, cost_l2=args.cost_l2, linf_budget=args.linf_budget,
    )
    y_pred_strat = clf.predict(X_test_man)
    acc_strat = accuracy_score(y_test, y_pred_strat)

    # Share of test spammers that were rejected before manipulating and are
    # accepted afterwards.
    idx_spam = np.where(y_test == 0)[0]
    if len(idx_spam) > 0:
        before = clf.predict(X_test_s[idx_spam])
        after = clf.predict(X_test_man[idx_spam])
        flip_rate = float(np.mean((before == 0) & (after == 1)))
    else:
        flip_rate = float("nan")

    print("\n[Rational SC]")
    print(f"data={args.data_path}, label_col={args.label_col}, rounds={args.rounds}")
    print(f"Clean test acc:     {acc_clean:.4f}")
    print(f"Strategic test acc: {acc_strat:.4f}")
    print(f"Spam flip-to-accept rate (y=0 -> pred=1): {flip_rate:.4f}\n")


if __name__ == "__main__":
    # Script entry point: cap PyTorch's CPU intra-op parallelism at a single
    # thread, then run the experiment.
    # NOTE(review): presumably chosen because each best-response solve is a
    # tiny per-sample problem where extra threads only add overhead — confirm.
    torch.set_num_threads(1)
    main()
