import numpy as np
import matplotlib.pyplot as plt
import time

from scipy.optimize import minimize, NonlinearConstraint, minimize_scalar, differential_evolution, root_scalar
from scipy.interpolate import interp1d
from scipy.stats import truncnorm, norm
from sklearn.isotonic import IsotonicRegression  # (only used for reference/comparison)

import matplotlib as mpl
import pandas as pd
import os


###############################################
# 1. LS ALGORITHM FOR s-CONCAVITY
###############################################

# Threshold on |alpha|: when abs(alpha) < tol_alpha, d_alpha_fun returns log
# and d_alpha_inv_fun returns exp instead of the power transform.
tol_alpha = 1e-7

def d_alpha_fun(alpha):
    """
    Return the forward transform d_alpha as a callable of y.

      |alpha| < tol_alpha :  y -> log(y)
      alpha < 0           :  y -> -(y ** alpha)
      otherwise           :  y -> y ** alpha
    """
    if abs(alpha) < tol_alpha:
        return np.log

    if alpha < 0:
        def negated_power(y):
            return np.negative(np.power(y, alpha))
        return negated_power

    def power(y):
        return np.power(y, alpha)
    return power

def d_alpha_inv_fun(alpha):
    """
    Return the inverse transform of d_alpha_fun(alpha) as a callable of x.

      |alpha| < tol_alpha :  x -> exp(x)
      alpha < 0           :  x -> (-x) ** (1/alpha)
      otherwise           :  x -> x ** (1/alpha)
    """
    if abs(alpha) < tol_alpha:
        return np.exp

    # alpha is non-zero on this path, so the reciprocal is well defined.
    exponent = 1 / alpha

    if alpha < 0:
        def inverse_negated_power(x):
            return np.power(np.negative(x), exponent)
        return inverse_negated_power

    def inverse_power(x):
        return np.power(x, exponent)
    return inverse_power

def convex_matrix(w):
    """
    Build the (n-1) x n first-difference matrix A for the grid w, so that
    A.dot(phi) is the vector of slopes between consecutive grid points:

         A[i, i]   = -1/(w[i+1]-w[i])
         A[i, i+1] =  1/(w[i+1]-w[i])

    Any gap with absolute value below 1e-10 is replaced by 1e-10 so the
    reciprocal is always finite.
    """
    n = len(w)
    A = np.zeros((n - 1, n))

    gaps = np.diff(w)
    gaps = np.where(np.abs(gaps) < 1e-10, 1e-10, gaps)
    inv_gaps = 1.0 / gaps

    # Fill the diagonal and the super-diagonal in one shot each.
    rows = np.arange(n - 1)
    A[rows, rows] = -inv_gaps
    A[rows, rows + 1] = inv_gaps
    return A

def objective_function(phi, y_s, d_alpha_inv):
    """
    Sum of squared residuals between the observations y_s and the
    back-transformed parameters d_alpha_inv(phi).
    """
    residuals = y_s - d_alpha_inv(phi)
    return np.sum(np.square(residuals))

def build_constraints(phi, w, d_alpha, epsi):
    """
    Stack every constraint on phi into one vector c(phi), feasible when
    c(phi) <= 0 component-wise. Blocks, in order:

      1. Upper value bound:  phi - d_alpha(1-epsi)   (phi <= d_alpha(1-epsi))
      2. Lower value bound:  d_alpha(epsi) - phi     (phi >= d_alpha(epsi))
      3. Monotonicity:       -s, with s = convex_matrix(w).dot(phi)
                             (all slopes non-negative)
      4. Concavity:          diff(s)
                             (slopes non-increasing)
    """
    # Box constraints on the transformed values.
    above_cap = phi - d_alpha(1 - epsi)
    below_floor = d_alpha(epsi) - phi

    # Shape constraints, both expressed through the piecewise slopes.
    slopes = convex_matrix(w) @ phi
    decreasing = -slopes
    slope_increase = np.diff(slopes)

    return np.concatenate((above_cap, below_floor, decreasing, slope_increase))

def LS_SLSQP(w, y, lamb, epsi):
    """
    Shape-constrained least-squares fit solved with SLSQP.

    The unknowns phi are the fitted values on the transformed scale
    (d_alpha with parameter lamb) at the unique points of w. The objective is
    sum((y - d_alpha_inv(phi))**2), minimised subject to the constraints
    assembled by build_constraints (value bounds from epsi, non-negative
    slopes, non-increasing slopes).

    Parameters
    ----------
    w : array-like
        Design points. Duplicates are collapsed with np.unique; for a
        duplicated w only the y at its first occurrence is kept.
    y : array-like
        Responses, same length as w.
    lamb : float
        Transformation parameter passed to d_alpha_fun / d_alpha_inv_fun.
    epsi : float
        Margin for the value bounds d_alpha(epsi) <= phi <= d_alpha(1-epsi).

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolant (with linear extrapolation) through
        (w_unique, d_alpha_inv(phi_opt)).

    Notes
    -----
    If the optimiser reports failure, a message is printed and the last
    iterate is used anyway.
    """
    d_alpha = d_alpha_fun(lamb)
    d_alpha_inv = d_alpha_inv_fun(lamb)

    # Accept plain sequences as well as arrays (y is fancy-indexed below).
    w = np.asarray(w)
    y = np.asarray(y)

    w_s, idx = np.unique(w, return_index=True)
    y_s = y[idx]
    n = len(w_s)

    def obj_func(phi):
        return objective_function(phi, y_s, d_alpha_inv)

    # SLSQP expects 'ineq' constraints as fun(phi) >= 0, hence the sign flip.
    # A single vector-valued constraint evaluates build_constraints once per
    # call, instead of once per scalar component as separate constraints would.
    constraints_list = [
        {'type': 'ineq', 'fun': lambda phi: -build_constraints(phi, w_s, d_alpha, epsi)}
    ]

    # Start from a straight line between the two value bounds.
    phi_init = np.linspace(d_alpha(epsi), d_alpha(1 - epsi), n)
    options = {'maxiter': 10000, 'ftol': 1e-8}
    result = minimize(obj_func, phi_init, method='SLSQP', constraints=constraints_list, options=options)

    if not result.success:
        print("LS_SLSQP optimization not converged:", result.message)

    phi_opt = result.x
    F_hat = interp1d(w_s, d_alpha_inv(phi_opt), kind='linear', fill_value='extrapolate')
    return F_hat


###############################################
# 2. PAYOFF or INSTANTANEOUS REVENUE
###############################################

def payoff(psi_func, p, b, q):
    """Return p * psi_func(p*b + q): price times the link evaluated at the index."""
    index = p * b + q
    return p * psi_func(index)

# def p_opt(F, b, q, p_min, p_max):
#     def objective(p):
#         return -payoff(F, p, b, q)  # negate for maximization
#     result = minimize_scalar(objective, bounds=(p_min, p_max), method='bounded')
#     return result.x

def p_opt(F, b, q, p_min, p_max):
    """
    Return the price in [p_min, p_max] that maximises payoff(F, p, b, q).

    The search is a one-dimensional differential_evolution run (with the
    final polishing step enabled) on the negated payoff.
    """
    search = differential_evolution(
        lambda x: -payoff(F, x[0], b, q),  # negate: the solver minimises
        bounds=[(p_min, p_max)],
        polish=True,
    )
    best_price = search.x[0]
    return float(best_price)

# Selection of kappa_i via root finding
    
def solve_kappa(N, tau, bracket=(1e-8, 1-1e-8)):
    """
    Find kappa inside `bracket` (default: just inside (0, 1)) such that

        4 tau^(1/10) kappa^(3/2) = 5 N^(1/2) (1-kappa)^(7/5)

    using Brent's method. Raises RuntimeError if the solver reports that it
    did not converge.
    """
    def imbalance(kappa):
        # Left-hand side minus right-hand side of the defining equation.
        lhs = 4 * tau**0.1 * kappa**1.5
        rhs = 5 * np.sqrt(N) * (1 - kappa)**(7/5)
        return lhs - rhs

    sol = root_scalar(imbalance, bracket=bracket, method='brentq')
    if not sol.converged:
        raise RuntimeError("Root finding did not converge")
    return sol.root

###############################################
# 3. TRUNCATED GAUSSIAN SAMPLING FUNCTIONS
###############################################

def construct_m_and_s(p_m_arr, p_M_arr):
    """
    Turn per-coordinate lower/upper boundaries into (center, scale) arrays.

    center = midpoint of the two boundaries
    scale  = sqrt(half of the boundary width)
    """
    lower = np.array(p_m_arr, dtype=float)
    upper = np.array(p_M_arr, dtype=float)
    center = (lower + upper) * 0.5
    scale = np.sqrt((upper - lower) * 0.5)
    return center, scale

def truncated_multivariate_gaussian(m, s, size=1, random_state=None):
    """
    Rejection-sample `size` points from independent normals N(m_i, s_i)
    restricted to the ellipsoid sum((x_i-m_i)^2/s_i^2) <= 1.

    Each round proposes exactly as many points as are still missing and
    keeps those that fall inside the ellipsoid; rounds repeat until the
    output array is full. Returns an array of shape (size, len(m)).
    """
    rng = np.random.default_rng(random_state)
    dim = len(m)
    samples = np.zeros((size, dim), dtype=float)

    filled = 0
    while filled < size:
        missing = size - filled
        draws = rng.normal(loc=m, scale=s, size=(missing, dim))

        # Squared standardized distance of every proposal from the center.
        radius_sq = np.square((draws - m) / s).sum(axis=1)
        kept = draws[radius_sq <= 1.0]

        # At most `missing` rows can survive, so the slice never overruns.
        samples[filled:filled + len(kept)] = kept
        filled += len(kept)

    return samples

def sample_truncated_multivariate_times_tau(p_m_arr, p_M_arr, tau, random_state=None):
    """
    Draw tau samples from the truncated multivariate Gaussian whose center
    and scale are derived from the boundaries by construct_m_and_s.
    """
    center, scale = construct_m_and_s(p_m_arr, p_M_arr)
    return truncated_multivariate_gaussian(center, scale, size=tau, random_state=random_state)


###############################################
# 5. EXPERIMENT: SINGLE RUN EXPERIMENT FUNCTION
###############################################

def _instant_regret(psi_i, theta_i, p_star, i, p_i):
    """Payoff of agent i at p_star[i] minus its payoff at p_i, others held at p_star."""
    b_i = theta_i[i]
    q_i = np.delete(p_star, i) @ np.delete(theta_i, i)
    return payoff(psi_i, p_star[i], b_i, q_i) - payoff(psi_i, p_i, b_i, q_i)


def single_run_experiment(T, N, rng_seed, p_m_arr, p_M_arr, psi, theta_true, p_star, p_initial):
    """
    Run one simulated experiment over T periods with N agents.

    For each agent i the run has an exploration phase (t < tau[i]) with
    randomly drawn prices, an estimate of theta_i at t == n_i[i], an estimate
    of psi_i at t == tau[i], and afterwards an exploitation phase in which the
    price is chosen by p_opt from the estimates.

    Parameters
    ----------
    T : int
        Number of periods.
    N : int
        Number of agents.
    rng_seed : int or None
        Seed for the run's random generator. Every random draw of the run
        (exploration lengths, exploration prices, observation noise) comes
        from this generator.
    p_m_arr, p_M_arr : sequence of float
        Per-agent lower / upper price boundaries.
    psi : sequence of callables
        psi[i](w) is the true link function of agent i.
    theta_true : sequence of length-N arrays
        theta_true[i] is the true coefficient vector of agent i.
    p_star : array of length N
        Reference price vector used for the regret and the final distance.
    p_initial : array of length N
        Price vector used to (re)start the exploitation iterate.

    Returns
    -------
    (sum_diff_theta, l2_error_psi, diff_p, total_regret_run)
        sum_diff_theta   : sum over agents of ||theta_hat_i - theta_true_i||_2
                           (NaN if some agent was never estimated).
        l2_error_psi     : mean over agents of the mean squared error of
                           psi_hat_i on its own training points (NaN if no
                           psi was estimated).
        diff_p           : ||p_T_final - p_star||_2 (NaN if no agent reached
                           exploitation by the last period).
        total_regret_run : sum of per-period regrets.
    """
    rng = np.random.default_rng(rng_seed)

    # Exploration length tau[i] per agent, and the split point n_i[i] inside
    # it: rows [0, n_i) feed the theta estimate, rows [n_i, tau) the psi one.
    tau_base = int(np.ceil(T**(5/7) / 1.5))
    tau = [int(np.ceil(tau_base + tau_base * rng.uniform(0.25, 0.75))) for _ in range(N)]
    kappa_star = [solve_kappa(N, tau_i) for tau_i in tau]
    n_i = [int(np.ceil(tau_i * kappa_i)) for tau_i, kappa_i in zip(tau, kappa_star)]
    # Agents are visited in increasing order of tau within each period.
    order = list(np.argsort(tau))

    P_GLOBAL = np.zeros((T, N))
    Y_GLOBAL = np.zeros((T, N))
    W_GLOBAL = np.zeros((T, N))

    theta_hat = [None] * N
    psi_hat = [None] * N
    p_T_final = None
    regrets = np.zeros(T)
    errors = []

    # Exploration prices: independent normal draws per agent, centred at the
    # boundary midpoint with scale sqrt(half-width).
    # NOTE(review): these draws are NOT truncated, although the file provides
    # truncated_multivariate_gaussian — confirm which sampler is intended.
    for i in range(N):
        n_explore = min(tau[i], T)  # never write past the T available rows
        P_GLOBAL[:n_explore, i] = rng.normal(
            loc=0.5 * (p_m_arr[i] + p_M_arr[i]),
            scale=np.sqrt(0.5 * (p_M_arr[i] - p_m_arr[i])),
            size=n_explore,
        )

    for t in range(T):
        for i in order:
            # 1) Data collection for t < tau
            if t < tau[i]:
                w_t_i = np.dot(theta_true[i], P_GLOBAL[t, :])
                noise_t_i = rng.uniform(-0.05, 0.05)
                Y_GLOBAL[t, i] = psi[i](w_t_i) + noise_t_i
                W_GLOBAL[t, i] = w_t_i

                # Each agent contributes its own regret exactly once.
                regrets[t] += _instant_regret(psi[i], theta_true[i], p_star, i, P_GLOBAL[t, i])

            # 2) Estimate theta at t == n_i using the first n_i rows
            if t == n_i[i]:
                X_for_theta = P_GLOBAL[:n_i[i], :]
                X_centered = X_for_theta - X_for_theta.mean(axis=0, keepdims=True)
                beta_hat_i = np.linalg.lstsq(X_centered, Y_GLOBAL[:n_i[i], i], rcond=None)[0]
                norm_beta_i = np.linalg.norm(beta_hat_i, 2)
                if norm_beta_i > 0:
                    beta_hat_i /= norm_beta_i  # keep only the direction
                theta_hat[i] = beta_hat_i

            # 3) Estimate psi at t == tau (using data from rows [n_i, tau))
            if t == tau[i]:
                w_i = W_GLOBAL[n_i[i]:tau[i], i]
                y_i = Y_GLOBAL[n_i[i]:tau[i], i]
                iso_reg = LS_SLSQP(w_i, y_i, lamb=0, epsi=0.001)
                # Bind the fitted curve as a default so each agent keeps its own.
                psi_hat[i] = lambda u, reg=iso_reg: np.clip(reg(u), 0, 1)

                # Mean squared error of psi_hat on its training points
                y_true_i_train = np.array([psi[i](w_val) for w_val in w_i])
                y_pred_i_train = np.array([psi_hat[i](w_val) for w_val in w_i])
                errors.append(np.mean((y_true_i_train - y_pred_i_train) ** 2))

            # 4) For t >= tau, update p_t using estimated psi (and theta_hat if available)
            if t >= tau[i]:
                if t == tau[i]:
                    # NOTE(review): this restarts the whole iterate from
                    # p_initial every time an agent leaves exploration,
                    # discarding other agents' current entries — confirm
                    # this is intended.
                    p_t_minus_1 = np.copy(p_initial)
                p_hat_minus_i = np.delete(p_t_minus_1, i)

                if psi_hat[i] is not None and theta_hat[i] is not None:
                    b_hat_i = theta_hat[i][i]
                    gamma_hat_i = np.delete(theta_hat[i], i)
                    q_hat = p_hat_minus_i @ gamma_hat_i
                    p_hat_t_i = p_opt(psi_hat[i], b_hat_i, q_hat, p_m_arr[i], p_M_arr[i])
                else:
                    # No estimates yet: fall back to the boundary midpoint.
                    p_hat_t_i = (p_m_arr[i] + p_M_arr[i]) / 2.0

                P_GLOBAL[t, i] = p_hat_t_i
                p_t_minus_1[i] = p_hat_t_i

                regrets[t] += _instant_regret(psi[i], theta_true[i], p_star, i, p_hat_t_i)

                if t == T - 1:
                    p_T_final = p_t_minus_1

    # After time T, compute performance metrics
    l2_error_psi = np.mean(errors) if errors else np.nan

    if all(theta_hat_i is not None for theta_hat_i in theta_hat):
        sum_diff_theta = sum(np.linalg.norm(theta_hat[i] - theta_true[i], 2) for i in range(N))
    else:
        sum_diff_theta = np.nan

    if p_T_final is not None and p_star is not None:
        diff_p = np.linalg.norm(p_T_final - p_star, 2)
    else:
        diff_p = np.nan

    total_regret_run = np.nansum(regrets)
    return sum_diff_theta, l2_error_psi, diff_p, total_regret_run

def plot_all():
    """
    Draw the four summary panels (theta error, psi error, price distance,
    regret) against T for N = 2, 4, 6.

    Reads DATA/data_N_{2,4,6}.csv (columns T_values and, per metric, a
    mean_* and a ci_* column), plots mean curves with +/- ci bands, saves
    the figure to PLOTS/stacked_overdraw_N.pdf, shows it, and prints the
    output path.
    """
    data_folder = "DATA"
    plot_folder = "PLOTS"
    os.makedirs(plot_folder, exist_ok=True)

    # Base style first, then the typography overrides.
    plt.style.use('seaborn-v0_8-whitegrid')
    mpl.rcParams.update({
        'text.usetex': True,
        'font.family': 'serif',
        'font.size': 16,
        'axes.labelsize': 18,
        'axes.titlesize': 20,
        'legend.fontsize': 16,
        'xtick.labelsize': 16,
        'ytick.labelsize': 14,
    })

    # One entry per curve, in draw order: (legend label, csv file, color, marker)
    series = [
        (r"$N=2$", "data_N_2.csv", "red", "o"),    # circle
        (r"$N=4$", "data_N_4.csv", "blue", "^"),   # triangle
        (r"$N=6$", "data_N_6.csv", "green", "s"),  # square
    ]
    marker_size = 8
    marker_edge = 1

    # One entry per panel, left to right: (mean column, ci column, title)
    panels = [
        ("mean_theta_diffs", "ci_theta_diffs",
         r'$\sum_{i=1}^{N} \|\widehat{\theta}_i^{(T)} - \theta_i\|_2$'),
        ("mean_l2_errors", "ci_l2_errors",
         r'$\sum_{i=1}^{N} \|\widehat{\psi}_i^{(T)} - \psi_i\|_2$'),
        ("mean_p_diffs", "ci_p_diffs",
         r'$\|\mathbf{p}^{(T)} - \mathbf{p}^\star\|_2$'),
        ("mean_regrets", "ci_regrets",
         "Total Expected Regret"),
    ]

    fig_main, axes = plt.subplots(1, 4, figsize=(20, 5))
    ax1, ax2, ax3, ax4 = axes.ravel()
    panel_axes = [ax1, ax2, ax3, ax4]

    for label, filename, color, marker in series:
        df = pd.read_csv(os.path.join(data_folder, filename))

        # Pull every column before drawing anything for this file.
        T_values = df["T_values"].to_numpy()
        curves = [(df[mean_col].to_numpy(), df[ci_col].to_numpy())
                  for mean_col, ci_col, _ in panels]

        for ax, (mean_vals, ci_vals) in zip(panel_axes, curves):
            ax.plot(T_values, mean_vals, '-', marker=marker, markersize=marker_size,
                    markeredgewidth=marker_edge, linewidth=2, alpha=0.8,
                    color=color, label=label)
            ax.fill_between(T_values, mean_vals - ci_vals, mean_vals + ci_vals,
                            color=color, alpha=0.2)

    # Titles and labels
    for ax, (_, _, title) in zip(panel_axes, panels):
        ax.set_title(title, fontsize=25, fontweight='bold')
        ax.set_xlabel("T")
        ax.grid(True, linestyle="--", alpha=0.5)

    # Legends
    for ax in panel_axes:
        ax.legend()

    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(wspace=0.3, hspace=0.3)

    # Save
    plot_file = os.path.join(plot_folder, "stacked_overdraw_N.pdf")
    plt.savefig(plot_file, bbox_inches="tight", dpi=300)
    plt.show()

    print(f"Figure saved successfully to {plot_file}")

def plot_all_slopes():
    """
    Draw two log-log panels (price distance and regret against T) for
    N = 2, 4, 6, each with a fitted straight line whose slope is shown in
    the legend.

    Reads DATA/data_N_{2,4,6}.csv (columns T_values, mean_p_diffs,
    ci_p_diffs, mean_regrets, ci_regrets), saves the figure to
    PLOTS/stacked_overdraw_N_slopes.pdf, shows it, and prints the output path.
    """
    data_folder = "DATA"
    plot_folder = "PLOTS"
    os.makedirs(plot_folder, exist_ok=True)

    # Base style first, then the typography overrides.
    plt.style.use('seaborn-v0_8-whitegrid')
    mpl.rcParams.update({
        'text.usetex': True,
        'font.family': 'serif',
        'font.size': 16,
        'axes.labelsize': 18,
        'axes.titlesize': 20,
        'legend.fontsize': 20,
        'xtick.labelsize': 16,
        'ytick.labelsize': 14,
    })

    def loglog_fit(x, y):
        # Degree-1 least-squares fit of log(y) on log(x), using only the
        # points where both x and y are strictly positive.
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        positive = (y > 0) & (x > 0)
        slope, intercept = np.polyfit(np.log(x[positive]), np.log(y[positive]), 1)
        return slope, intercept

    # One entry per curve, in draw order: (fallback label, csv file, color, marker)
    series = [
        (r"$N=2$", "data_N_2.csv", "red", "o"),    # circle
        (r"$N=4$", "data_N_4.csv", "blue", "^"),   # triangle
        (r"$N=6$", "data_N_6.csv", "green", "s"),  # square
    ]
    marker_size = 8
    marker_edge = 1

    EPS = 1e-300  # floor applied to the band edges before taking logs

    # One entry per panel, left to right: (mean column, ci column, title)
    panels = [
        ("mean_p_diffs", "ci_p_diffs",
         r'log $\|\mathbf{p}^{(T)} - \mathbf{p}^\star\|_2$'),
        ("mean_regrets", "ci_regrets",
         "log Total Expected Regret"),
    ]

    # Create 2 subplots
    fig_main, axes = plt.subplots(1, 2, figsize=(15, 6))
    ax1, ax2 = axes.ravel()
    panel_axes = [ax1, ax2]

    for label, filename, color, marker in series:
        df = pd.read_csv(os.path.join(data_folder, filename))

        # Pull every column before drawing anything for this file.
        T_values = df["T_values"].to_numpy()
        curves = [(df[mean_col].to_numpy(), df[ci_col].to_numpy())
                  for mean_col, ci_col, _ in panels]

        for ax, (mean_vals, ci_vals) in zip(panel_axes, curves):
            band_low = np.maximum(mean_vals - ci_vals, EPS)
            band_high = np.maximum(mean_vals + ci_vals, EPS)

            # Regression slope in log-log; it becomes the legend entry when finite.
            slope, intercept = loglog_fit(T_values, mean_vals)
            slope_is_finite = np.isfinite(slope)
            legend_text = rf"m={slope:.2f}" if slope_is_finite else label

            # Data are logged by hand; the axes themselves stay linear.
            ax.plot(np.log(T_values), np.log(mean_vals), '-',
                    marker=marker, markersize=marker_size,
                    markeredgewidth=marker_edge, linewidth=2, alpha=0.8,
                    color=color, label=legend_text)
            ax.fill_between(np.log(T_values), np.log(band_low), np.log(band_high),
                            color=color, alpha=0.2)

            # Fitted line over the observed range of log(T), hidden from the legend.
            if slope_is_finite:
                log_x = np.linspace(np.log(T_values.min()), np.log(T_values.max()), 200)
                log_y = intercept + slope * log_x
                ax.plot(log_x, log_y, '--', color=color, lw=2, alpha=0.9,
                        label='_nolegend_')

    # Titles and labels
    for ax, (_, _, title) in zip(panel_axes, panels):
        ax.set_title(title, fontsize=25, fontweight='bold')
        ax.set_xlabel("log(T)")
        ax.grid(True, linestyle="--", alpha=0.5)

    # Legends
    for ax in panel_axes:
        ax.legend()

    plt.tight_layout(pad=3.0)
    plt.subplots_adjust(wspace=0.3, hspace=0.3)

    # Save
    plot_file = os.path.join(plot_folder, "stacked_overdraw_N_slopes.pdf")
    plt.savefig(plot_file, bbox_inches="tight", dpi=300)
    plt.show()

    print(f"Figure saved successfully to {plot_file}")