import numpy as np
import matplotlib.pyplot as plt
import time

from scipy.optimize import minimize, NonlinearConstraint, minimize_scalar, differential_evolution, root_scalar
from scipy.interpolate import interp1d
from scipy.stats import truncnorm, norm
from sklearn.isotonic import IsotonicRegression
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
import numpy as np
import os

# Threshold on |alpha| below which the transformation parameter is treated as
# zero; d_alpha_fun / d_alpha_inv_fun then select their log / exp branch.
tol_alpha = 1e-7

def d_alpha_fun(alpha):
    """Build the transformation d_alpha for a given parameter.

    The returned callable maps ``y`` to

    * ``log(y)``      when ``|alpha| < tol_alpha``,
    * ``-y**alpha``   when ``alpha`` is negative,
    * ``y**alpha``    when ``alpha`` is positive.
    """
    near_zero = abs(alpha) < tol_alpha

    if near_zero:
        def transform(y):
            return np.log(y)
    elif alpha < 0:
        def transform(y):
            return -np.power(y, alpha)
    else:
        def transform(y):
            return np.power(y, alpha)

    return transform

def d_alpha_inv_fun(alpha):
    """Build the inverse of the transformation produced by ``d_alpha_fun``.

    The returned callable maps ``x`` to

    * ``exp(x)``            when ``|alpha| < tol_alpha``,
    * ``(-x)**(1/alpha)``   when ``alpha`` is negative,
    * ``x**(1/alpha)``      when ``alpha`` is positive.
    """
    near_zero = abs(alpha) < tol_alpha

    if near_zero:
        def inverse(x):
            return np.exp(x)
    elif alpha < 0:
        def inverse(x):
            return np.power(-x, 1/alpha)
    else:
        def inverse(x):
            return np.power(x, 1/alpha)

    return inverse

def convex_matrix(w):
    """
    Build the (n-1) x n finite-difference matrix A for grid points ``w``.

    Row i holds -1/(w[i+1]-w[i]) in column i and +1/(w[i+1]-w[i]) in
    column i+1, so ``A.dot(phi)`` gives the slopes of the piecewise-linear
    function through the points (w[i], phi[i]).

    Gaps whose absolute value is below 1e-10 are replaced by 1e-10 to avoid
    dividing by zero.
    """
    n = len(w)
    A = np.zeros((n - 1, n))

    gaps = np.diff(w)
    gaps = np.where(np.abs(gaps) < 1e-10, 1e-10, gaps)
    inverse_gaps = 1.0 / gaps

    # Fill the main diagonal and the first super-diagonal in one shot each.
    rows = np.arange(n - 1)
    A[rows, rows] = -inverse_gaps
    A[rows, rows + 1] = inverse_gaps
    return A

def objective_function(phi, y_s, d_alpha_inv):
    """Sum of squared residuals between ``y_s`` and ``d_alpha_inv(phi)``."""
    residuals = y_s - d_alpha_inv(phi)
    return np.sum(np.square(residuals))

def build_constraints(phi, w, d_alpha, epsi):
    """
    Evaluate every constraint as a single vector c(phi), feasible when c <= 0.

    The vector is the concatenation, in this order, of:

      1. Upper value bound:  phi - d_alpha(1 - epsi)
      2. Lower value bound:  d_alpha(epsi) - phi
      3. Monotonicity:       -s, where s = convex_matrix(w).dot(phi) are the
                             slopes (so s >= 0 is required)
      4. Concavity:          diff(s) (successive slopes must not increase)
    """
    upper_bound = d_alpha(1 - epsi)
    lower_bound = d_alpha(epsi)

    slopes = convex_matrix(w).dot(phi)

    pieces = (
        phi - upper_bound,   # value constraint, upper side
        lower_bound - phi,   # value constraint, lower side
        -slopes,             # non-decreasing
        np.diff(slopes),     # concave
    )
    return np.concatenate(pieces)

def LS_SLSQP(w, y, lamb, epsi):
    """
    Shape-constrained least-squares fit solved with SLSQP.

    The unknowns ``phi`` live on the transformed scale ``d_alpha`` (selected by
    ``lamb``); the objective is the squared error between the observations and
    ``d_alpha_inv(phi)``, subject to the value, monotonicity and concavity
    constraints assembled by ``build_constraints``.

    Parameters
    ----------
    w : array-like
        Design points. Duplicates are collapsed with ``np.unique``; only the
        response at the first occurrence of each distinct value is kept.
    y : array-like
        Responses, aligned with ``w``.
    lamb : float
        Transformation parameter passed to ``d_alpha_fun`` / ``d_alpha_inv_fun``.
    epsi : float
        Margin: fitted values are constrained to
        ``[d_alpha(epsi), d_alpha(1 - epsi)]`` on the transformed scale.

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolant (with linear extrapolation) through the fitted
        values ``d_alpha_inv(phi_opt)`` at the distinct, sorted design points.

    Notes
    -----
    A message is printed (no exception is raised) when SLSQP reports failure;
    the last iterate is still used to build the interpolant.
    """
    d_alpha = d_alpha_fun(lamb)
    d_alpha_inv = d_alpha_inv_fun(lamb)

    # Accept plain sequences as well as arrays (fancy indexing below needs arrays).
    w = np.asarray(w, dtype=float)
    y = np.asarray(y, dtype=float)

    w_s, idx = np.unique(w, return_index=True)
    y_s = y[idx]
    n = len(w_s)

    def obj_func(phi):
        return objective_function(phi, y_s, d_alpha_inv)

    def feasible_margin(phi):
        # SLSQP 'ineq' constraints mean fun(phi) >= 0, whereas build_constraints
        # returns c(phi) <= 0, hence the sign flip.
        return -build_constraints(phi, w_s, d_alpha, epsi)

    # One vector-valued constraint: the full constraint vector is evaluated
    # once per call instead of once per component, and its length always
    # matches what build_constraints actually returns.
    constraints_list = [{'type': 'ineq', 'fun': feasible_margin}]

    phi_init = np.linspace(d_alpha(epsi), d_alpha(1 - epsi), n)
    options = {'maxiter': 10000, 'ftol': 1e-8}
    result = minimize(obj_func, phi_init, method='SLSQP', constraints=constraints_list, options=options)

    if not result.success:
        print("LS_SLSQP optimization not converged:", result.message)

    phi_opt = result.x
    F_hat = interp1d(w_s, d_alpha_inv(phi_opt), kind='linear', fill_value='extrapolate')
    return F_hat


def payoff(psi_func, p, b, q):
    """Return ``p * psi_func(p * b + q)``: price times the response at index p*b + q."""
    index = p * b + q
    return p * psi_func(index)

# def p_opt(F, b, q, p_min, p_max):
#     def objective(p):
#         return -payoff(F, p, b, q)  # negate for maximization
#     result = minimize_scalar(objective, bounds=(p_min, p_max), method='bounded')
#     return result.x

def p_opt(F, b, q, p_min, p_max):
    """
    Maximize ``payoff(F, p, b, q)`` over ``p`` in ``[p_min, p_max]``.

    The search uses ``differential_evolution`` (with polishing) on the negated
    payoff and returns the best price found as a Python float.
    """
    def negative_payoff(candidate):
        # differential_evolution passes a length-1 vector; negate to maximize.
        return -payoff(F, candidate[0], b, q)

    outcome = differential_evolution(negative_payoff, bounds=[(p_min, p_max)], polish=True)
    best_price = outcome.x[0]
    return float(best_price)

# Selection of kappa_i via root finding
    
def solve_kappa(N, tau, bracket=(1e-8, 1-1e-8)):
    """
    Find kappa inside ``bracket`` satisfying

        4 tau^(1/10) kappa^(3/2) = 5 N^(1/2) (1-kappa)^(7/5)

    using Brent's method. Raises ``RuntimeError`` if the solver reports that
    it did not converge.
    """
    def residual(kappa):
        return 4 * tau**0.1 * kappa**1.5 - 5 * np.sqrt(N) * (1 - kappa)**(7/5)

    sol = root_scalar(residual, bracket=bracket, method='brentq')
    if not sol.converged:
        raise RuntimeError("Root finding did not converge")
    return sol.root
        
def construct_m_and_s(p_m_arr, p_M_arr):
    """
    Derive the center and scale vectors from lower/upper boundaries.

    Returns ``(m, sigma)`` where ``m`` is the midpoint of each interval and
    ``sigma`` is the square root of half the interval length.
    """
    lower = np.asarray(p_m_arr, dtype=float)
    upper = np.asarray(p_M_arr, dtype=float)

    midpoint = (lower + upper) / 2.0
    scale = np.sqrt((upper - lower) / 2.0)
    return midpoint, scale

def truncated_multivariate_gaussian(m, s, size=1, random_state=None):
    """
    Rejection-sample ``size`` points from an independent Gaussian (mean ``m``,
    per-coordinate scale ``s``) restricted to the ellipsoid
    sum((x_i - m_i)^2 / s_i^2) <= 1.

    Each pass proposes exactly as many points as are still missing and keeps
    the ones that fall inside the ellipsoid. Returns an array of shape
    ``(size, len(m))``.
    """
    rng = np.random.default_rng(random_state)
    dim = len(m)
    samples = np.zeros((size, dim), dtype=float)

    filled = 0
    while filled < size:
        remaining = size - filled
        candidates = rng.normal(loc=m, scale=s, size=(remaining, dim))

        # Squared standardized distance from the center, one value per proposal.
        standardized = (candidates - m) / s
        inside = np.sum(standardized**2, axis=1) <= 1.0
        kept = candidates[inside]

        take = min(len(kept), remaining)
        if take > 0:
            samples[filled:filled + take] = kept[:take]
            filled += take

    return samples

def sample_truncated_multivariate_times_tau(p_m_arr, p_M_arr, tau, random_state=None):
    """
    Draw ``tau`` points from the truncated multivariate Gaussian whose center
    and scale are derived from the boundaries via ``construct_m_and_s``.
    """
    center, scale = construct_m_and_s(p_m_arr, p_M_arr)
    return truncated_multivariate_gaussian(center, scale, size=tau, random_state=random_state)
    
def single_run_experiment(T, N, rng_seed, p_m_arr, p_M_arr, psi, theta_true, p_star, p_initial):
    """
    Run one simulated experiment over T periods and return summary metrics.

    The horizon is split into an exploration phase of ``tau`` periods (random
    prices, observations recorded) followed by an exploitation phase in which
    each of the N prices is chosen by ``p_opt`` using the estimated ``psi_hat``
    and ``theta_hat``.

    Parameters
    ----------
    T : int
        Number of periods; the main loop runs over ``t = 0, ..., T-1``.
    N : int
        Number of price coordinates (index ``i`` runs over ``range(N)``).
    rng_seed
        Seed passed to ``np.random.default_rng``.
    p_m_arr, p_M_arr
        Per-coordinate lower / upper price boundaries.
    psi
        Indexable collection of callables; ``psi[i](w)`` is the true response
        of coordinate ``i`` at index value ``w``.
    theta_true
        Indexable collection; ``theta_true[i]`` is dotted with the price vector
        (presumably a length-N vector -- confirm against caller).
    p_star
        Reference price vector used as the benchmark in the regret.
    p_initial
        Price vector used as the "previous" prices in the first exploitation
        period.

    Returns
    -------
    tuple
        ``(sum_diff_theta, l2_error_psi, diff_p, total_regret_run)``:
        sum over i of ``||theta_hat[i] - theta_true[i]||_2``; mean over i of the
        mean squared error of ``psi_hat[i]`` on its training points; Euclidean
        distance between the final prices and ``p_star``; and the NaN-ignoring
        sum of per-period regrets. The first three are NaN when the
        corresponding estimate was never produced (the estimation steps only
        fire if the loop reaches ``t == n_i`` / ``t == tau``).
    """
    rng = np.random.default_rng(rng_seed)
    # NOTE: m and s are not used anywhere below in this function.
    m, s = construct_m_and_s(p_m_arr, p_M_arr)

    # Define tau (number of initial samples) and n_i (for psi estimation)
    # tau = ceil(T^(5/7) / 1.5); n_i = ceil(tau * kappa*), kappa* from solve_kappa.
    tau = int(np.ceil(T**(5/7) / 1.5))
    kappa_star = solve_kappa(N, tau)
    n_i = int(np.ceil(tau * kappa_star))

    # Prepare data arrays
    # Row t holds period t's prices (P), noisy responses (Y) and true indices (W).
    P = np.zeros((tau, N))
    Y = np.zeros((tau, N))
    W = np.zeros((tau, N))

    theta_hat = None
    psi_hat = None
    p_T_final = None
    regrets = np.zeros(T)
    l2_error_psi = np.nan  # will be computed later

    # Data collection and updates
    for t in range(T):
        total_payoff_star = 0.0
        total_payoff_chosen = 0.0

        # 1) Data collection for t < tau
        if t < tau:
            # Sample p from the truncated multivariate Gaussian
            # (one draw per period; passing rng keeps a single random stream)
            p_t_sample = sample_truncated_multivariate_times_tau(p_m_arr, p_M_arr, 1, random_state=rng)[0]
            P[t, :] = p_t_sample

            for i in range(N):
                # Observation: true response at the index theta_i . p plus
                # uniform noise on [-0.05, 0.05).
                w_t_i = np.dot(theta_true[i], p_t_sample)
                noise_t_i = rng.uniform(-0.05, 0.05)
                y_t_i = psi[i](w_t_i) + noise_t_i
                Y[t, i] = y_t_i
                W[t, i] = w_t_i

                # Regret terms: own coefficient b_i is the i-th entry of theta_i,
                # gamma_i the remaining entries. Both payoffs use q_i computed
                # from p_star (the other coordinates are held at p_star).
                b_i = theta_true[i][i]
                gamma_i = np.delete(theta_true[i], i)
                q_i = np.delete(p_star, i) @ gamma_i
                payoff_star = payoff(psi[i], p_star[i], b_i, q_i)
                payoff_chosen = payoff(psi[i], p_t_sample[i], b_i, q_i)

                total_payoff_star += payoff_star
                total_payoff_chosen += payoff_chosen

            regrets[t] = total_payoff_star - total_payoff_chosen

        # 2) Estimate theta at t == n_i using all data so far
        # Least squares of Y[:, i] on the column-centered prices of the first
        # n_i rows; each coefficient vector is rescaled to unit Euclidean norm.
        if t == n_i:
            X_for_theta = P[:n_i,:]
            X_centered = X_for_theta - X_for_theta.mean(axis=0, keepdims=True)
            #P_centered = P - m
            theta_hat_temp = []
            for i in range(N):
                beta_hat_i = np.linalg.lstsq(X_centered, Y[:n_i, i], rcond=None)[0]
                norm_beta = np.linalg.norm(beta_hat_i, 2)
                if norm_beta > 0:
                    beta_hat_i /= norm_beta
                theta_hat_temp.append(beta_hat_i)
            theta_hat = theta_hat_temp

        # 3) Estimate psi at t == tau (using data from rows [n_i, tau))
        if t == tau:
            psi_hat_temp = []
            for i in range(N):
                w_i = W[n_i:tau, i]
                y_i = Y[n_i:tau, i]
                iso_reg = LS_SLSQP(w_i, y_i, lamb=0, epsi=0.001)
                # reg=iso_reg binds the current fit (avoids late binding in the
                # loop); predictions are clipped to [0, 1].
                psi_hat_i = lambda u, reg=iso_reg: np.clip(reg(u), 0, 1)
                psi_hat_temp.append(psi_hat_i)
            psi_hat = psi_hat_temp

            # Compute L2 error of psi estimation (on training data)
            # NOTE: the quantity stored is the mean over i of the mean squared
            # error (no square root is taken).
            errors = []
            for i in range(N):
                w_i_train = W[n_i:tau, i]
                y_true_i_train = np.array([psi[i](w_val) for w_val in w_i_train])
                y_pred_i_train = np.array([psi_hat[i](w_val) for w_val in w_i_train])
                mse_i = np.mean((y_true_i_train - y_pred_i_train) ** 2)
                errors.append(mse_i)
            l2_error_psi = np.mean(errors)

        # 4) For t >= tau, update p_t using estimated psi (and theta_hat if available)
        if t >= tau:
            if t == tau:
                p_t_minus_1 = p_initial  # starting point for p_t
            p_t_list = []
            for i in range(N):
                # Other coordinates' prices from the previous period.
                p_hat_minus_i = np.delete(p_t_minus_1, i)

                # Use estimated psi if available; otherwise, default to midpoint
                if psi_hat is not None and theta_hat is not None:
                    b_hat_i = theta_hat[i][i]
                    gamma_hat_i = np.delete(theta_hat[i], i)
                    q_hat = p_hat_minus_i @ gamma_hat_i
                    p_hat_t_i = p_opt(psi_hat[i], b_hat_i, q_hat, p_m_arr[i], p_M_arr[i])
                else:
                    p_hat_t_i = (p_m_arr[i] + p_M_arr[i]) / 2.0

                ### regret ###
                # Evaluated with the true psi and theta; q_i again uses p_star
                # for the other coordinates.
                b_i = theta_true[i][i]
                gamma_i = np.delete(theta_true[i], i)
                q_i = np.delete(p_star, i) @ gamma_i
                payoff_star = payoff(psi[i], p_star[i], b_i, q_i)
                payoff_chosen = payoff(psi[i], p_hat_t_i, b_i, q_i)
                total_payoff_star += payoff_star
                total_payoff_chosen += payoff_chosen

                p_t_list.append(p_hat_t_i)
            # All coordinates are updated from the same previous-period vector.
            p_t_minus_1 = np.array(p_t_list)

            regrets[t] = total_payoff_star - total_payoff_chosen

            if t == T - 1:
                p_T_final = p_t_minus_1

    # After time T, compute performance metrics
    if theta_hat is not None:
        sum_diff_theta = sum(np.linalg.norm(theta_hat[i] - theta_true[i], 2) for i in range(N))
    else:
        sum_diff_theta = np.nan

    if p_T_final is not None and p_star is not None:
        diff_p = np.linalg.norm(p_T_final - p_star, 2)
    else:
        diff_p = np.nan

    total_regret_run = np.nansum(regrets)
    return sum_diff_theta, l2_error_psi, diff_p, total_regret_run

first_nonzero_seen = False  # Flag to track the first nonzero label


def sci_formatter(x, pos):
    """
    Format a tick value as a LaTeX power-of-ten label.

    Signature matches what ``matplotlib.ticker.FuncFormatter`` expects
    (``x`` is the tick value, ``pos`` the tick position, unused here).

    * ``0`` is rendered as ``"0"``.
    * The first nonzero value seen (tracked through the module-level flag
      ``first_nonzero_seen``) is rendered as ``$10^{e}$``.
    * Every later value is rendered as ``$m \\times 10^{e}$`` with an integer
      mantissa ``m``.

    The mantissa is rounded rather than truncated, so floating-point noise
    such as 0.3 / 0.1 == 2.9999999999999996 yields 3 instead of 2. Negative
    values are formatted from their magnitude with a leading minus sign
    instead of failing on the logarithm of a negative number.
    """
    global first_nonzero_seen
    if x == 0:
        return "0"

    magnitude = abs(x)
    sign = "-" if x < 0 else ""
    exponent = int(np.floor(np.log10(magnitude)))

    if not first_nonzero_seen:
        first_nonzero_seen = True
        return rf"${sign}10^{{{exponent}}}$"

    mantissa = int(round(magnitude / (10.0 ** exponent)))  # Normalize to 10^exponent
    if mantissa == 10:
        # Rounding carried over (e.g. 9.6 -> 10): renormalize to 1 x 10^(e+1).
        mantissa = 1
        exponent += 1
    return rf"${sign}{mantissa} \times 10^{{{exponent}}}$"


def save_plot(N):
    """
    Draw the four summary metrics for a single N and save the figure.

    For each of the three result files ``DATA/data_N_{N}_L_*.csv`` the mean
    curve and its +/- CI band are drawn in a 1x4 panel row (theta error, psi
    error, price distance, regret). The figure is written to
    ``PLOTS/stacked_overdraw_N_{N}.pdf``, shown, and a confirmation line is
    printed.
    """
    data_folder = "DATA"
    plot_folder = "PLOTS"
    os.makedirs(plot_folder, exist_ok=True)

    # Global look: whitegrid style, LaTeX text, serif fonts.
    plt.style.use('seaborn-v0_8-whitegrid')
    rc_overrides = (
        ('text.usetex', True),
        ('font.family', 'serif'),
        ('font.size', 16),
        ('axes.labelsize', 18),
        ('axes.titlesize', 20),
        ('legend.fontsize', 16),
        ('xtick.labelsize', 16),
        ('ytick.labelsize', 14),
    )
    for rc_key, rc_value in rc_overrides:
        mpl.rcParams[rc_key] = rc_value

    # (legend label, CSV file, line colour) for each dataset, in draw order.
    series = (
        (r"$L_{\mathbf{\Gamma}} \approx 1$", f"data_N_{N}_L_1.csv", "red"),
        (r"$L_{\mathbf{\Gamma}} \, middle$", f"data_N_{N}_L_middle.csv", "blue"),
        (r"$L_{\mathbf{\Gamma}} \approx 0$", f"data_N_{N}_L_0.csv", "green"),
    )

    # Columns are pulled from each file up front, in this order.
    column_order = (
        "T_values",
        "mean_theta_diffs", "ci_theta_diffs",
        "mean_l2_errors", "ci_l2_errors",
        "mean_p_diffs", "ci_p_diffs",
        "mean_regrets", "ci_regrets",
    )

    figure, axes = plt.subplots(1, 4, figsize=(20, 5))
    ax1, ax2, ax3, ax4 = axes.ravel()

    # One (axis, mean column, CI column) triple per panel.
    panels = (
        (ax1, "mean_theta_diffs", "ci_theta_diffs"),  # (A) Theta error
        (ax2, "mean_l2_errors", "ci_l2_errors"),      # (B) Psi error
        (ax3, "mean_p_diffs", "ci_p_diffs"),          # (C) p difference
        (ax4, "mean_regrets", "ci_regrets"),          # (D) Regret
    )

    for label, filename, colour in series:
        df = pd.read_csv(os.path.join(data_folder, filename))
        columns = {name: df[name].to_numpy() for name in column_order}
        T_values = columns["T_values"]

        for ax, mean_key, ci_key in panels:
            center = columns[mean_key]
            half_width = columns[ci_key]
            ax.plot(T_values, center, '-o', linewidth=2, alpha=0.8, color=colour, label=label)
            ax.fill_between(T_values,
                            center - half_width,
                            center + half_width,
                            color=colour, alpha=0.2)

    # Titles, axis labels and grids (only the first panel carries the N label).
    panel_titles = (
        r'$\sum_{i=1}^{' + str(N) + r'} \|\widehat{\theta}_i^{(T)} - \theta_i\|_2$',
        r'$\sum_{i=1}^{' + str(N) + r'} \|\widehat{\psi}_i^{(T)} - \psi_i\|_2$',
        r'$\|\mathbf{p}^{(T)} - \mathbf{p}^\star\|_2$',
        "Total Expected Regret",
    )
    for ax, title in zip((ax1, ax2, ax3, ax4), panel_titles):
        ax.set_title(title, fontsize=20, fontweight='bold')
        ax.set_xlabel("T")
        if ax is ax1:
            ax.set_ylabel(f"N={N}")
        ax.grid(True, linestyle="--", alpha=0.5)

    for ax in (ax1, ax2, ax3, ax4):
        ax.legend()

    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(wspace=0.3, hspace=0.3)

    # Write to disk before showing.
    plot_file = os.path.join(plot_folder, f"stacked_overdraw_N_{N}.pdf")
    plt.savefig(plot_file, bbox_inches="tight", dpi=300)
    plt.show()

    print(f"Figure saved successfully to {plot_file}")

def plot_all():
    """
    Draw the full grid of summary metrics for N in (2, 4, 6).

    Rows correspond to N, columns to the four metrics (theta error, psi error,
    price distance, regret). Each panel overlays the three datasets read from
    ``DATA/data_N_{N}_L_*.csv`` as mean curves with +/- CI bands. The figure
    is written to ``PLOTS/grid_overdraw_N246.pdf``, shown, and a confirmation
    line is printed.
    """
    data_folder = "DATA"
    plot_folder = "PLOTS"
    os.makedirs(plot_folder, exist_ok=True)

    # Global look: whitegrid style, LaTeX text, serif fonts, black text.
    plt.style.use('seaborn-v0_8-whitegrid')
    rc_overrides = (
        ('text.usetex', True),
        ('font.family', 'serif'),
        ('font.size', 16),
        ('axes.labelsize', 18),
        ('axes.titlesize', 20),
        ('legend.fontsize', 14),
        ('xtick.labelsize', 18),
        ('ytick.labelsize', 18),
        ('text.color', 'black'),
        ('axes.labelcolor', 'black'),
    )
    for rc_key, rc_value in rc_overrides:
        mpl.rcParams[rc_key] = rc_value

    # (legend label, file suffix, colour, marker) per dataset, in draw order.
    series = (
        (r"$L_{\mathbf{\Gamma}} \approx 1$", "_L_1.csv", "red", "o"),         # circle
        (r"$L_{\mathbf{\Gamma}} \approx 0.5$", "_L_middle.csv", "blue", "^"),  # triangle
        (r"$L_{\mathbf{\Gamma}} \approx 0$", "_L_0.csv", "green", "s"),        # square
    )
    MARKERSIZE = 10
    MARKEREDGE = 1.4

    N_list = [2, 4, 6]

    # Columns are pulled from each file up front, in this order.
    column_order = (
        "T_values",
        "mean_theta_diffs", "ci_theta_diffs",
        "mean_l2_errors", "ci_l2_errors",
        "mean_p_diffs", "ci_p_diffs",
        "mean_regrets", "ci_regrets",
    )
    # (mean column, CI column) for panel columns 1..4.
    metric_columns = (
        ("mean_theta_diffs", "ci_theta_diffs"),  # (A) Theta error
        ("mean_l2_errors", "ci_l2_errors"),      # (B) Psi error
        ("mean_p_diffs", "ci_p_diffs"),          # (C) p difference
        ("mean_regrets", "ci_regrets"),          # (D) Regret
    )

    fig, axes = plt.subplots(len(N_list), 4, figsize=(22, 12), sharex=False)
    axes = np.array(axes)
    bottom_row = len(N_list) - 1

    for row, N in enumerate(N_list):
        ax1, ax2, ax3, ax4 = axes[row]
        row_axes = (ax1, ax2, ax3, ax4)

        for label, suffix, colour, marker in series:
            filename = f"data_N_{N}{suffix}"
            df = pd.read_csv(os.path.join(data_folder, filename))
            columns = {name: df[name].to_numpy() for name in column_order}
            T_values = columns["T_values"]

            for ax, (mean_key, ci_key) in zip(row_axes, metric_columns):
                center = columns[mean_key]
                half_width = columns[ci_key]
                ax.plot(T_values, center, '-', marker=marker, markersize=MARKERSIZE,
                        markeredgewidth=MARKEREDGE, linewidth=2, alpha=0.8,
                        color=colour, label=label)
                ax.fill_between(T_values, center - half_width,
                                center + half_width, color=colour, alpha=0.2)
                ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.7)

        # Column titles appear on the top row only.
        if row == 0:
            ax1.set_title(r'$\sum_{i=1}^N \|\widehat{\theta}_i^{(T)} - \theta_i\|_2$', fontsize=25, fontweight='bold')
            ax2.set_title(r'$\sum_{i=1}^N \|\widehat{\psi}_i^{(T)} - \psi_i\|_2$', fontsize=25, fontweight='bold')
            ax3.set_title(r'$\|\mathbf{p}^{(T)} - \mathbf{p}^\star\|_2$', fontsize=25, fontweight='bold')
            ax4.set_title("Total Expected Regret", fontsize=25, fontweight='bold')

        # The first panel of each row carries the N label.
        ax1.set_ylabel(f"N={N}", fontsize=25)

        # Legends are drawn on the top row only.
        if row == 0:
            for ax in row_axes:
                ax.legend(fontsize=20)

        # X-axis labels are drawn on the bottom row only.
        if row == bottom_row:
            for ax in row_axes:
                ax.set_xlabel("T", fontsize=25)

    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(wspace=0.25, hspace=0.25)

    # Write to disk before showing.
    plot_file = os.path.join(plot_folder, "grid_overdraw_N246.pdf")
    plt.savefig(plot_file, bbox_inches="tight", dpi=300)
    plt.show()

    print(f"Figure saved successfully to {plot_file}")

def plot_all_slopes():
    """
    Draw log-log versions of the price-distance and regret curves with fitted
    slopes for N in (2, 4, 6).

    Rows correspond to N; the left column shows log ||p^(T) - p*||_2 and the
    right column log total regret, both against log T. For each dataset read
    from ``DATA/data_N_{N}_L_*.csv`` the panel shows the log-transformed mean
    curve, its CI band (bounds floored at 1e-300 before the log) and a dashed
    least-squares line; the legend entry is the fitted slope ``m`` (or the
    dataset label when no finite slope is available). The figure is written
    to ``PLOTS/grid_overdraw_N246_loglog.pdf``, shown, and a confirmation line
    is printed.
    """
    def loglog_fit(x, y):
        # Fit log y = b + m log x on the finite, strictly positive pairs;
        # returns (m, b), or (nan, nan) when fewer than two pairs qualify.
        xs = np.asarray(x, dtype=float)
        ys = np.asarray(y, dtype=float)
        usable = np.isfinite(xs) & np.isfinite(ys) & (xs > 0) & (ys > 0)
        if usable.sum() < 2:
            return np.nan, np.nan
        slope, intercept = np.polyfit(np.log(xs[usable]), np.log(ys[usable]), 1)
        return slope, intercept

    data_folder = "DATA"
    plot_folder = "PLOTS"
    os.makedirs(plot_folder, exist_ok=True)

    # Global look: whitegrid style, LaTeX text, serif fonts, black text.
    plt.style.use('seaborn-v0_8-whitegrid')
    rc_overrides = (
        ('text.usetex', True),
        ('font.family', 'serif'),
        ('font.size', 16),
        ('axes.labelsize', 18),
        ('axes.titlesize', 20),
        ('legend.fontsize', 14),
        ('xtick.labelsize', 18),
        ('ytick.labelsize', 18),
        ('text.color', 'black'),
        ('axes.labelcolor', 'black'),
    )
    for rc_key, rc_value in rc_overrides:
        mpl.rcParams[rc_key] = rc_value

    # (dataset label, file suffix, colour, marker) per dataset, in draw order.
    series = (
        (r"$L_{\mathbf{\Gamma}} \approx 1$", "_L_1.csv", "red", "o"),         # circle
        (r"$L_{\mathbf{\Gamma}} \approx 0.5$", "_L_middle.csv", "blue", "^"),  # triangle
        (r"$L_{\mathbf{\Gamma}} \approx 0$", "_L_0.csv", "green", "s"),        # square
    )
    MARKERSIZE = 10
    MARKEREDGE = 1.4
    EPS = 1e-300

    N_list = [2, 4, 6]

    # Columns are pulled from each file up front, in this order.
    column_order = ("T_values", "mean_p_diffs", "ci_p_diffs", "mean_regrets", "ci_regrets")

    fig, axes = plt.subplots(len(N_list), 2, figsize=(15, 14), sharex=False)
    axes = np.array(axes)
    bottom_row = len(N_list) - 1

    for row, N in enumerate(N_list):
        ax1, ax2 = axes[row]
        # Left panel: price distance; right panel: regret.
        panels = (
            (ax1, "mean_p_diffs", "ci_p_diffs"),
            (ax2, "mean_regrets", "ci_regrets"),
        )

        for label, suffix, colour, marker in series:
            filename = f"data_N_{N}{suffix}"
            df = pd.read_csv(os.path.join(data_folder, filename))
            columns = {name: df[name].to_numpy() for name in column_order}
            T_values = columns["T_values"]

            for ax, mean_key, ci_key in panels:
                center = columns[mean_key]
                half_width = columns[ci_key]

                # Floor the band at EPS so the log below stays defined.
                band_low = np.maximum(center - half_width, EPS)
                band_high = np.maximum(center + half_width, EPS)

                slope, intercept = loglog_fit(T_values, center)
                legend_text = rf"m={slope:.2f}" if np.isfinite(slope) else label

                # Log-transformed curve and band.
                ax.plot(np.log(T_values), np.log(center), '-',
                        marker=marker, markersize=MARKERSIZE,
                        markeredgewidth=MARKEREDGE, linewidth=2, alpha=0.8,
                        color=colour, label=legend_text)
                ax.fill_between(np.log(T_values), np.log(band_low), np.log(band_high),
                                color=colour, alpha=0.2)

                # Dashed regression line, kept out of the legend.
                if np.isfinite(slope):
                    log_grid = np.linspace(np.log(T_values.min()), np.log(T_values.max()), 200)
                    fitted = intercept + slope * log_grid
                    ax.plot(log_grid, fitted, '--', color=colour, lw=2, alpha=0.9, label='_nolegend_')

        # Column titles appear on the top row only.
        if row == 0:
            ax1.set_title(r'$\log \|\mathbf{p}^{(T)} - \mathbf{p}^\star\|_2$', fontsize=25, fontweight='bold')
            ax2.set_title(r'$\log$ Total Expected Regret', fontsize=25, fontweight='bold')

        # The first panel of each row carries the N label.
        ax1.set_ylabel(f"N={N}", fontsize=25)

        # Every row gets legends so the slopes are visible per N.
        ax1.legend(fontsize=16)
        ax2.legend(fontsize=16)

        # X-axis labels are drawn on the bottom row only.
        if row == bottom_row:
            ax1.set_xlabel(r"$\log T$", fontsize=22)
            ax2.set_xlabel(r"$\log T$", fontsize=22)

    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(wspace=0.25, hspace=0.25)

    # Write to disk before showing.
    plot_file = os.path.join(plot_folder, "grid_overdraw_N246_loglog.pdf")
    plt.savefig(plot_file, bbox_inches="tight", dpi=300)
    plt.show()

    print(f"Figure saved successfully to {plot_file}")
