import numpy as np
from linear_regression.fixed_point_equations.fpeqs import fixed_point_finder
from linear_regression.aux_functions.misc import classification_adversarial_error
from linear_regression.fixed_point_equations.regularisation.pstar_attacks_Lr_reg import (
    f_Lr_regularisation_Lpstar_attack,
)
from linear_regression.fixed_point_equations.classification.Adv_train_p_norm import (
    f_hat_Logistic_no_noise_Linf_adv_classif,
)
from os.path import join, exists
from mpi4py import MPI
from itertools import product
import os
from scipy.optimize import minimize_scalar

# Numerical constants. XATOL is handed to the scalar minimiser as `tol`;
# MIN_REG_PARAM is not referenced in the visible part of this script.
MIN_REG_PARAM = 1e-6
XATOL = 1e-6

# Logarithmically spaced grid of alpha values swept by the main loop.
alpha_min = 0.1
alpha_max = 0.46
n_alpha_pts = 50
alphas = np.logspace(np.log10(alpha_min), np.log10(alpha_max), num=n_alpha_pts)

# Linear grid of eps values; only a five-point window of it is selected ...
eps_min = 0.05
eps_max = 0.3
n_eps_pts = 100
epss = np.linspace(eps_min, eps_max, num=n_eps_pts)[-53:-48]
# ... and that window is then replaced by a single hand-picked value.
epss = [0.1]
print(epss)
print(len(epss))

# Regularisation orders to run and the value of pstar used throughout.
reg_orders = [1]
pstar = 1

# Cartesian product (eps, reg_order); eps varies slowest, as with
# itertools.product(epss, reg_orders).
pairs = [(eps, order) for eps in epss for order in reg_orders]

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# Each rank processes exactly one (eps, reg_order) pair, selected by its rank,
# so there must be at least as many pairs as ranks. An explicit raise is used
# instead of `assert` because assertions are removed under `python -O`, which
# would turn a mis-sized launch into a bare IndexError on the surplus ranks.
# Note: when there are fewer ranks than pairs, the pairs at index >= size are
# simply not processed by this launch.
if size > len(pairs):
    raise ValueError(
        f"Launched {size} MPI ranks but only {len(pairs)} (eps, reg_order) pairs are "
        "available; use at most one rank per pair."
    )

eps_t, reg_order = pairs[rank]
# The same eps value is passed as eps_t to the fixed-point equations and as
# eps_g to the adversarial-error evaluation.
eps_g = eps_t

data_folder_SE = "./data"

# Several MPI ranks may reach this point at the same time. A separate
# "check then create" can fail with FileExistsError when another rank creates
# the folder in between; exist_ok=True makes the creation safe in that case.
os.makedirs(data_folder_SE, exist_ok=True)

# One output file per (pstar, reg_order, alpha range, eps) combination.
file_name = f"aaa_SE_alpha_sweep_optimal_lambda_pstar_{pstar}_reg_order_{reg_order}_alpha_{alpha_min:.3f}_{alpha_max:.3f}_eps_{eps_t:.4e}.csv"

n_alpha_pts = len(alphas)
print(f"Rank: {rank}, {n_alpha_pts = }, {alphas[0] = }, {alphas[-1] = }")

# Per-alpha result buffers, one entry per point of `alphas`.
# They are initialised to NaN rather than left uninitialised (np.empty) so that
# any entry the sweep does not fill - e.g. an alpha whose optimisation fails -
# is written to the CSV as NaN instead of arbitrary memory contents.

# Fixed-point solution at the optimal regularisation.
ms_found = np.full((n_alpha_pts,), np.nan)
qs_found = np.full((n_alpha_pts,), np.nan)
Vs_found = np.full((n_alpha_pts,), np.nan)
Ps_found = np.full((n_alpha_pts,), np.nan)

# "Hat" columns: they are saved to the CSV, but nothing in the visible part of
# this script assigns to them, so they stay NaN.
mhats_found = np.full((n_alpha_pts,), np.nan)
qhats_found = np.full((n_alpha_pts,), np.nan)
Vhats_found = np.full((n_alpha_pts,), np.nan)
Phats_found = np.full((n_alpha_pts,), np.nan)

# Optimal regularisation parameter found for each alpha.
reg_param_opts = np.full((n_alpha_pts,), np.nan)

# Error measures evaluated at the optimum.
estim_errors_se = np.full((n_alpha_pts,), np.nan)
adversarial_errors_found = np.full((n_alpha_pts,), np.nan)
gen_errors_se = np.full((n_alpha_pts,), np.nan)

# Starting points (m, q, V, P) for the fixed-point iteration, keyed by
# reg_order; any other order uses the fallback seed.
_seed_by_reg_order = {
    1: (0.348751, 9.11468, 270.038, 0.615440),
    2: (6.766e-01, 5.780e00, 4.442e03, 1.780e00),
    3: (3.873e-01, 2.571e00, 1.710e03, 1.446e00),
    4: (3.417e-01, 2.140e00, 8.220e02, 1.371e00),
}
_seed_fallback = (0.6604, 13.959, 25767.13, 1.364)

initial_condition = _seed_by_reg_order.get(reg_order, _seed_fallback)
m, q, V, P = initial_condition

# NOTE(review): the hard-coded point below replaces the per-order seed selected
# above for every reg_order - confirm this override is still intended.
initial_condition = (
    2.539323172300404607e00,
    4.429693432929278174e01,
    2.877865984749135350e03,
    2.867387394785570365e00,
)
m, q, V, P = initial_condition

# Initial guess for the regularisation parameter and the search window used by
# the first bounded minimisation.
reg_param_init = 2.497400307702702608e-03
min_bound, max_bound = reg_param_init / 5, reg_param_init * 5
# NOTE(review): the window derived from reg_param_init is immediately replaced
# by fixed bounds, and reg_param_init itself lies outside [1e-4, 1e-3] - confirm.
min_bound = 1e-4
max_bound = 1e-3

print(f"Starting the sweep for {reg_order = }, {eps_t = }")


def _solve_fixed_point(init_cond, var_kwargs, hat_kwargs, max_iter):
    """Run the fixed-point solver from ``init_cond`` and return its result (unpacked by callers as ``(m, q, V, P)``).

    Both the objective evaluations and the final solve at the optimum use the
    same equations, tolerance and update argument; only ``max_iter`` differs.
    """
    return fixed_point_finder(
        f_Lr_regularisation_Lpstar_attack,
        f_hat_Logistic_no_noise_Linf_adv_classif,
        init_cond,
        var_kwargs,
        hat_kwargs,
        abs_tol=1e-7,
        min_iter=10,
        args_update_function=(0.3,),
        max_iter=max_iter,
    )


def _mark_point_as_failed(idx):
    """Store NaN at ``idx`` in every result array this sweep fills.

    Without this, a failed alpha would leave whatever the arrays were created
    with (possibly uninitialised memory) in the saved file.
    """
    for results in (
        ms_found,
        qs_found,
        Vs_found,
        Ps_found,
        reg_param_opts,
        estim_errors_se,
        adversarial_errors_found,
        gen_errors_se,
    ):
        results[idx] = np.nan


# The sweep runs from the last alpha to the first, warm-starting each point
# (initial condition and search window) from the previous successful one.
for jprime, alpha in enumerate(reversed(alphas)):
    print(
        f"Optimization for alpha = {alpha}, eps = {eps_t},  {jprime = }, bounds = {min_bound, max_bound}"
    )
    # Index of this alpha in the original (non-reversed) `alphas` array.
    j = n_alpha_pts - jprime - 1

    # The fixed-point equations receive the regularisation divided by alpha.
    f_kwargs = {"reg_param": reg_param_init / alpha, "reg_order": reg_order, "pstar": pstar}
    f_hat_kwargs = {"alpha": alpha, "eps_t": eps_t}

    def minimize_fun(reg_param: float, alpha: float):
        """Adversarial error at the fixed point obtained for ``reg_param``.

        Updates ``f_kwargs`` in place with ``reg_param / alpha`` before solving.
        """
        print(
            f"Testing reg_param = {reg_param}, for reg_order = {reg_order}, alpha = {alpha}, eps_t = {eps_t}"
        )
        f_kwargs.update({"reg_param": reg_param / alpha})
        m, q, V, P = _solve_fixed_point(initial_condition, f_kwargs, f_hat_kwargs, max_iter=1_000)
        return classification_adversarial_error(m, q, P, eps_g, pstar)

    try:
        obj = minimize_scalar(
            minimize_fun,
            args=(alpha,),
            method="bounded",
            bounds=(min_bound, max_bound),
            tol=XATOL,
            options={"maxiter": 75},
        )
    except RuntimeError as err:
        # Best effort: skip this alpha, but say so and leave an explicit NaN row
        # instead of silently keeping stale array contents.
        print(
            f"Optimisation raised RuntimeError for reg_order = {reg_order}, alpha = {alpha}, eps_t = {eps_t}: {err}"
        )
        _mark_point_as_failed(j)
        continue

    if not obj.success:
        print(f"Optimisation failed for reg_order = {reg_order}, alpha = {alpha}, eps_t = {eps_t}")
        _mark_point_as_failed(j)
        continue

    print(
        f"Optimisation successful for reg_order = {reg_order:.2f}, alpha = {alpha}, eps_t = {eps_t}"
    )
    reg_param_opt = obj.x

    # Re-solve at the optimum with a larger iteration budget and store the result.
    f_kwargs.update({"reg_param": float(reg_param_opt / alpha)})
    ms_found[j], qs_found[j], Vs_found[j], Ps_found[j] = _solve_fixed_point(
        initial_condition, f_kwargs, f_hat_kwargs, max_iter=10_000
    )

    reg_param_opts[j] = reg_param_opt

    # Warm start for the next alpha: re-centre the search window on the optimum
    # just found and start the fixed-point iteration from this solution.
    reg_param_init = reg_param_opt
    min_bound, max_bound = reg_param_opt / 5, reg_param_opt * 5
    initial_condition = (ms_found[j], qs_found[j], Vs_found[j], Ps_found[j])

    estim_errors_se[j] = 1 - 2 * ms_found[j] + qs_found[j]
    adversarial_errors_found[j] = classification_adversarial_error(
        ms_found[j], qs_found[j], Ps_found[j], eps_g, pstar
    )
    gen_errors_se[j] = np.arccos(ms_found[j] / np.sqrt(qs_found[j])) / np.pi

print(f"Saving data for reg_order = {reg_order}, eps_t = {eps_t}")

# Columns of the output CSV, in the order they are written. The scalar sweep
# parameters are broadcast to one value per alpha so every column has the same
# length.
data = {
    "alpha": alphas,
    "eps": np.full_like(alphas, eps_t),
    "reg_order": np.full_like(alphas, reg_order),
    "m": ms_found,
    "q": qs_found,
    "V": Vs_found,
    "P": Ps_found,
    "mhat": mhats_found,
    "qhat": qhats_found,
    "Vhat": Vhats_found,
    "Phat": Phats_found,
    "reg_param": reg_param_opts,
    "estim_error": estim_errors_se,
    "adversarial_error": adversarial_errors_found,
    "generalisation_error": gen_errors_se,
}

# One row per alpha, one column per entry of `data`; the header line lists the
# column names and is written without a comment prefix.
data_array = np.column_stack(list(data.values()))
header = ",".join(data)
np.savetxt(
    join(data_folder_SE, file_name),
    data_array,
    delimiter=",",
    header=header,
    comments="",
)
