import warnings

# Silence every warning for this run. The FutureWarning rule is already
# covered by the blanket rule; both end up at the front of warnings.filters.
warnings.simplefilter("ignore")
warnings.simplefilter("ignore", category=FutureWarning)

from main import *
from itertools import product
import os
import numpy as np
import pandas as pd
from joblib import Parallel, delayed, parallel_backend

# ------------------------
# Load data once
# ------------------------
weight_col = "ord__WKHP"
target = "PINCP"

df_full = pd.read_csv("acc_auc/final_data/NJ_data_with_noise.csv")

# Recode the weight column into three ordinal levels:
#   < 40 -> 1,   == 40 -> 2,   anything else (> 40, or NaN) -> 5.
# NOTE(review): the raw column is presumably ACS usual weekly hours worked
# (WKHP) — confirm against the data pipeline.
df_full[weight_col] = np.select(
    [df_full[weight_col] < 40, df_full[weight_col] == 40],
    [1, 2],
    default=5,
)

# Model inputs: every column except the weight, the target and the
# "original_weight" column.
updated_features = [
    col
    for col in df_full.columns
    if col not in {weight_col, "original_weight", target}
]

all_x = df_full[updated_features].to_numpy()
all_y = df_full[target].to_numpy()
all_w = df_full[weight_col].to_numpy()

# Row indices of each class (label 1 / label 0), in ascending order.
pos_idx = np.flatnonzero(all_y == 1)
neg_idx = np.flatnonzero(all_y == 0)

# ------------------------
# Parameters
# ------------------------
# Per-class training-set sizes to sweep (each trial draws `num` rows of each label).
number_samples = [500, 1000, 1250, 2500, 5000, 10000]
# Per-class size of the held-out test split.
test_size_per_class = 2500

# Losses to sweep. Only "log" is active; "hinge" and "squared_hinge" were
# swept previously, both with is_throw=True ("log" uses is_throw=False).
losses = ["log"]

# Values of the `c` argument passed to run_svm_payment.
cs = [1000]

# Largest Euclidean norm over all feature rows; passed to run_svm_payment as `k`.
k = np.linalg.norm(all_x, axis=1).max()
k_coef = 1.0
sigma_loss = 1.0
fit_intercept = True
show_plots = False

# Number of repeated trials, and the base value of each trial's RNG seed.
T = 1
base_seed = 42


# ------------------------
# Order-preserving data generation
# ------------------------
def generate_data(num, rng):
    """Draw a class-balanced train split and a disjoint class-balanced test split.

    Args:
        num: Number of training rows drawn (without replacement) from each
            class, i.e. from ``pos_idx`` and from ``neg_idx``.
        rng: ``numpy.random.Generator`` used for every draw.

    Returns:
        ``(X, y, v, X_test, y_test, v_test)``: features, labels and weights of
        the ``2 * num`` training rows (positives first, then negatives) and of
        the ``2 * test_size_per_class`` test rows (same ordering). No test row
        is also a training row.
    """
    class_pools = (pos_idx, neg_idx)

    # Training rows: `num` positives followed by `num` negatives.
    train_idx = np.concatenate(
        [rng.choice(pool, size=num, replace=False) for pool in class_pools]
    )

    # Rows still available for testing. Filtering each pool with a boolean
    # mask keeps the pool's original order, so the test draw depends only on
    # `rng` and on which rows went to training.
    available = np.ones(all_y.shape[0], dtype=bool)
    available[train_idx] = False

    test_idx = np.concatenate(
        [
            rng.choice(pool[available[pool]], size=test_size_per_class, replace=False)
            for pool in class_pools
        ]
    )

    return (
        all_x[train_idx],
        all_y[train_idx],
        all_w[train_idx],
        all_x[test_idx],
        all_y[test_idx],
        all_w[test_idx],
    )


# ------------------------
# One (t, num) experiment
# ------------------------
def run_trial(t, num, loss="log", out_dir="acc_auc/final_data/vary_loss"):
    """Run one (trial, sample-size) experiment over every value in ``cs``.

    Draws a train/test split with ``generate_data``, then fits
    ``run_svm_payment`` once per ``c`` in ``cs``. The per-``c`` result frames
    are concatenated and written to a single CSV in ``out_dir``.

    Args:
        t: Trial index; also feeds the RNG seed.
        num: Per-class training-set size; also feeds the RNG seed.
        loss: Loss name forwarded to ``run_svm_payment`` and recorded in the
            output (default ``"log"``, the value previously hard-coded).
        out_dir: Directory for the output CSV (created if missing).

    Raises:
        ValueError: if ``cs`` is empty (previously surfaced by ``pd.concat``).
    """
    # NOTE(review): this seed scheme collides whenever 1000 * (t1 - t2) equals
    # a difference of two sample sizes, e.g. (t=0, num=2500) and (t=2, num=500)
    # both give base_seed + 2500. Harmless while T < 3; revisit before raising T.
    rng = np.random.default_rng(base_seed + 1000 * t + num)

    X, y, v, X_test, y_test, v_test = generate_data(num, rng)

    all_results = []
    for c in cs:
        print(f"running: t={t}, num={num}, c={c}")

        df_results, model = run_svm_payment(
            X, y, v, c, loss,
            sigma_loss=sigma_loss,
            plot=show_plots,
            is_throw=False,
            k=k,
            k_coef=k_coef,
            fit_intercept=fit_intercept,
            payment_mode="exact",
        )

        df_results["t"] = t
        df_results["num"] = num
        df_results["c"] = c
        df_results["loss"] = loss
        # The "agent" column is used as a row index into the training labels.
        df_results["label"] = y[df_results["agent"].astype(int)]

        # 1 where the model classifies a test row correctly, else 0.
        test_alloc = (model.predict(X_test) == y_test).astype(int)

        df_results["test_acc"] = test_alloc.mean()
        # Weight-normalised share of test rows classified correctly.
        df_results["test_welfare"] = np.sum(test_alloc * v_test) / np.sum(v_test)

        all_results.append(df_results)

    if not all_results:
        raise ValueError("cs is empty: nothing to run for this trial")

    final_df = pd.concat(all_results, ignore_index=True)

    os.makedirs(out_dir, exist_ok=True)

    # The file holds results for every c, so tag it with all of them rather
    # than whichever value the loop variable happened to end on. For a single
    # c this yields the same name as before. Plain variables in the f-string
    # keep it valid on Python < 3.12 (nested same-type quotes need PEP 701).
    c_tag = "_".join(str(c) for c in cs)
    final_df.to_csv(
        os.path.join(out_dir, f"results_t_{t}_num_{num}_loss_{loss}_c_{c_tag}.csv"),
        index=False,
    )


# ------------------------
# Parallel execution
# ------------------------
print("start running...")

# One worker per CPU granted by SLURM; a single worker when not under SLURM.
n_jobs = int(os.environ.get("SLURM_CPUS_PER_TASK", "1"))

# Lazily yield one job per (trial, sample size) pair, trials outermost.
trial_jobs = (
    delayed(run_trial)(trial, n_per_class)
    for trial, n_per_class in product(range(T), number_samples)
)

# Limit each loky worker's inner (native) thread pools to one thread, so
# n_jobs workers do not oversubscribe the allocated cores.
with parallel_backend("loky", inner_max_num_threads=1):
    Parallel(n_jobs=n_jobs)(trial_jobs)
