# %%
import os
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=FutureWarning)

from main import *  # presumably provides run_svm_payment — verify against main.py
import numpy as np
import pandas as pd
from joblib import Parallel, delayed

# %%
# ---- Model / loss settings ---------------------------------------------
use_loss = 'hinge'   # loss name handed positionally to run_svm_payment
sigma_loss = 1.0     # forwarded as run_svm_payment(sigma_loss=...)
c = 1.0              # constant `c` passed positionally to run_svm_payment

# ---- Data / sweep settings ---------------------------------------------
d = 8    # feature dimension of every generated dataset
T = 10   # number of independent trials per sample size
ms = [1 << power for power in range(8, 18)]  # sample sizes 256, 512, ..., 131072

# Class means differ only along the first coordinate; the remaining
# d - 1 coordinates have mean zero.
mu_pos_scalar = 0.25    # class +1 mean along coordinate 0
mu_neg_scalar = -0.25   # class -1 mean along coordinate 0
mu_pos = np.zeros(d)
mu_pos[0] = mu_pos_scalar
mu_neg = np.zeros(d)
mu_neg[0] = mu_neg_scalar

# Standard deviations to sweep; each value is used for both classes.
# Alternative grid (previously commented out): np.linspace(2, 0.1, 31).
sigmas = [0.1, 0.2, 0.5, 1, 1.5, 2, 2.5]
base_seed = 12345  # root entropy mixed into every per-trial SeedSequence

# ---- run_svm_payment options -------------------------------------------
show_plots = False     # True -> show plots of the binary search
is_throw = True        # True -> throw out points outside the margin (author's note: only when the loss is Lipschitz)
k = None               # None -> k is set from the data's max norm (per author's note)
k_coef = 1.0           # coefficient multiplied into k
fit_intercept = True   # whether the SVM model fits an intercept

# %%
def generate_gaus_data(n_pos, n_neg, mu_pos, mu_neg, sigma_pos, sigma_neg, rng, d=1):
    """Sample a two-class Gaussian dataset.

    Rows are ordered with all class -1 samples first, followed by all
    class +1 samples; ``y`` and ``v`` follow exactly the same row order.

    Parameters
    ----------
    n_pos, n_neg : int
        Number of samples drawn for class +1 and class -1.
    mu_pos, mu_neg : float or array-like
        Mean (``loc``) of each class, broadcast against shape ``(n, d)``.
    sigma_pos, sigma_neg : float or array-like
        Standard deviation (``scale``) of each class.
    rng : object with a ``normal(loc, scale, size)`` method
        e.g. a ``numpy.random.Generator``. The class -1 block is drawn
        first, so the draw order is part of the reproducibility contract.
    d : int, default 1
        Number of features per sample.

    Returns
    -------
    X : ndarray, shape (n_neg + n_pos, d)
        Samples, class -1 rows first.
    y : ndarray, shape (n_neg + n_pos,)
        Float labels: -1.0 for the first ``n_neg`` rows, +1.0 afterwards.
    v : ndarray, shape (n_neg + n_pos,)
        Per-sample values; currently all ones.
    """
    # Draw class -1 before class +1: swapping these would change every
    # dataset produced from a given seed.
    X_neg = rng.normal(loc=mu_neg, scale=sigma_neg, size=(n_neg, d))  # label -1
    X_pos = rng.normal(loc=mu_pos, scale=sigma_pos, size=(n_pos, d))  # label +1
    X = np.vstack([X_neg, X_pos])
    y = np.concatenate([-np.ones(n_neg), np.ones(n_pos)])

    # Build v in the same (neg, pos) row order as X and y, so it stays
    # aligned with the rows if it is ever replaced by non-constant,
    # per-class values.
    v = np.concatenate([np.ones(n_neg), np.ones(n_pos)])
    return X, y, v

# %%
# Create the output directory before spending time on the sweep; without
# it, the first to_csv below raises FileNotFoundError only after the first
# m has been fully computed.
os.makedirs('acc_auc/final_data', exist_ok=True)

for m in ms:
    print(f"Running m = {m}...")

    def compute_metrics_for_m(t, m=m):
        """Run trial ``t`` at sample size ``m`` for every sigma in ``sigmas``.

        For each sigma, a fresh dataset of exactly ``m`` points is generated
        with ``generate_gaus_data`` and passed to
        ``run_svm_payment(..., payment_mode="exact")``. The first frame it
        returns is tagged with the columns ``t``, ``d``, ``sigma``,
        ``m_total`` and ``label``.

        ``m`` is bound as a default argument so the function captures the
        loop value at definition time rather than late-binding it.

        Returns
        -------
        pandas.DataFrame
            The per-sigma frames concatenated with a fresh index.
        """
        # joblib's loky workers are separate processes and do not inherit
        # the warning filters installed at the top of this script.
        warnings.filterwarnings("ignore")

        # Seed from (base_seed, t, m) so each (trial, sample size) pair is
        # reproducible, then spawn one independent child stream per sigma.
        ss_trial = np.random.SeedSequence([base_seed, t, m])
        child_seeds = ss_trial.spawn(len(sigmas))

        # Class sizes sum to exactly m, so 'm_total' matches the data size
        # even for odd m (identical to m // 2 per class for even m).
        n_pos = m // 2
        n_neg = m - n_pos

        dfs = []
        for sigma, child_ss in zip(sigmas, child_seeds):
            rng = np.random.default_rng(child_ss)
            x, y, v = generate_gaus_data(
                n_pos, n_neg, mu_pos, mu_neg, sigma, sigma, rng, d
            )

            df_exact, _ = run_svm_payment(
                x, y, v, c, use_loss,
                sigma_loss=sigma_loss, plot=show_plots,
                is_throw=is_throw, k=k, k_coef=k_coef,
                fit_intercept=fit_intercept,
                payment_mode="exact"
            )

            df_exact['t'] = t
            df_exact['d'] = d
            df_exact['sigma'] = sigma
            df_exact['m_total'] = m
            # Assumes 'agent' holds row indices into x / y — TODO confirm
            # against run_svm_payment.
            df_exact['label'] = y[df_exact['agent'].to_numpy().astype(int)]

            dfs.append(df_exact)

        return pd.concat(dfs, ignore_index=True)

    # Parallel run over trials: one job per t.
    vary_m = pd.concat(
        Parallel(n_jobs=-1, backend="loky")(
            delayed(compute_metrics_for_m)(t) for t in range(T)
        ),
        ignore_index=True
    )

    # Save after each m so completed sample sizes survive an interrupted sweep.
    filename = f'acc_auc/final_data/vary_m_{m}_d_{d}.csv'
    vary_m.to_csv(filename, index=False)
    print(f"Saved results for m = {m} to {filename}")


