import math
import os
import time
import argparse

import numpy as np
import matplotlib.pyplot as plt
from numba import njit
from concurrent.futures import ProcessPoolExecutor, as_completed

from Traders import Traders

# Algorithm 1
def algo(T, K, delta, traders, GFT_pstar):
    """Run Algorithm 1 for ``T`` rounds and return its cumulative regret.

    Rounds 1..K**2 cycle through the K arms in order (arm ``(t - 1) % K``);
    at round K**2 the estimates F and G are frozen from the exploration data.
    Every later round refreshes the per-arm index with ``update_g_hat`` and
    plays its argmax (lowest arm index on ties).

    Args:
        T: number of rounds.
        K: number of arms; arm j posts the price j / K.
        delta: confidence parameter forwarded to ``zeta`` and ``update_g_hat``.
        traders: environment whose ``feedback(price)`` returns
            ``(supplier_cost, supplier_indicator, retailer_net_gain, retailer_quantity)``.
        GFT_pstar: benchmark per-round gain from trade that regret is measured against.

    Returns:
        Array of length T whose entry t - 1 is the cumulative regret after round t.
    """
    # K + 1 equispaced prices in [0, 1]; only the first K (arms 0..K-1) are ever posted.
    prices = np.arange(K + 1, dtype=float) / K
    n_explore = K ** 2

    # Per-arm running statistics.
    qty_sum = np.zeros(K, dtype=float)    # sum of observed retailer quantities
    accept_sum = np.zeros(K, dtype=int)   # sum of supplier indicators I{C <= P}
    pulls = np.zeros(K, dtype=int)        # number of times each arm was played
    F_est = np.zeros(K, dtype=float)      # frozen at the end of exploration
    G_est = np.zeros(K, dtype=float)      # frozen at the end of exploration
    ucb = np.zeros(K, dtype=float)        # index maximized during exploitation

    zeta_radius = zeta(K, delta)
    regret = np.zeros(T + 1, dtype=float)  # regret[0] = 0 seeds the running sum

    for t in range(1, T + 1):
        if t > n_explore:
            # Exploitation: play the arm with the largest upper confidence bound.
            update_g_hat(T, K, delta, qty_sum, accept_sum, pulls, G_est, F_est, ucb, zeta_radius)
            arm = int(np.argmax(ucb))
        else:
            # Exploration: round-robin over the K arms.
            arm = (t - 1) % K

        # The price is chosen without access to the traders' cost / net gain.
        price = prices[arm]
        cost, accepted, net_gain, quantity = traders.feedback(price)

        # Realized gain from trade is zero unless the supplier indicator is 1.
        gain = (price - cost) * quantity + net_gain if accepted == 1 else 0
        regret[t] = regret[t - 1] + (GFT_pstar - gain)

        pulls[arm] += 1
        qty_sum[arm] += quantity
        accept_sum[arm] += accepted

        if t == n_explore:
            # Freeze F_j = (sum_{i >= j} qty_sum[i]) / K^2 and
            # G_j = (sum_{i < j} accept_sum[i]) / K^2; used unchanged for t > K^2.
            for j in range(K):
                tail = 0.0
                for value in qty_sum[j:]:
                    tail += value
                F_est[j] = tail / n_explore
                head = 0.0
                for value in accept_sum[:j]:
                    head += value
                G_est[j] = head / n_explore

    return regret[1:]


@njit
def update_g_hat(T, K, delta, A_t, B_t, N_t, G, F, g_hat, computed_zeta):
    """Refresh, in place, the upper-confidence index ``g_hat[j]`` of every arm.

    For each arm j, with n = N_t[j] and radius xi(T, K, delta, n):

        g_hat[j] = (A_t[j]/n + xi) * (G[j] + computed_zeta)
                 + (B_t[j]/n + xi) * (F[j] + computed_zeta)

    Callers must ensure N_t[j] > 0 for every arm (it is a divisor).
    """
    for arm in range(K):
        pulls = N_t[arm]
        radius = xi(T, K, delta, pulls)
        # Optimistic estimates of the two per-arm empirical means.
        ucb_quantity = A_t[arm] / pulls + radius
        ucb_accept = B_t[arm] / pulls + radius
        g_hat[arm] = ucb_quantity * (G[arm] + computed_zeta) + ucb_accept * (F[arm] + computed_zeta)


# Define confidence radii for the algorithm
def zeta(K, delta):
    """Return the confidence radius (1 + sqrt(2 * log(8K / delta))) / K used for F and G."""
    log_term = math.log(8 * K / delta)
    return (1 + math.sqrt(2 * log_term)) / K


@njit
def xi(T, K, delta, n):
    """Return the per-arm confidence radius sqrt(2 * log(8 K T / delta) / n) for n pulls."""
    log_term = math.log(8 * K * T / delta)
    return math.sqrt(2 * log_term / n)


# Compute MC estimate of the price maximizing the expected gain from trade (g) along with the corresponding gain
def estimate_optimal_price(prices, traders_obj, n_samples):
    """Monte-Carlo search for the grid price with the highest mean gain from trade.

    Each price is evaluated with ``n_samples`` consecutive calls to
    ``traders_obj.feedback(price)``. A sample contributes
    ``(price - supplier_cost) * retailer_quantity + retailer_net_gain`` when the
    supplier indicator equals 1, and 0 otherwise.

    Args:
        prices: sequence of candidate prices (evaluated in order).
        traders_obj: object exposing ``feedback(price)``.
        n_samples: number of samples per price.

    Returns:
        ``(best_price, best_mean_gain)``; ties go to the earliest price.
    """
    def mean_gain(price):
        # Average realized gain from trade at a single price.
        total = 0.0
        for _ in range(n_samples):
            cost, sold, net_gain, quantity = traders_obj.feedback(price)
            if sold == 1:
                total += (price - cost) * quantity + net_gain
        return total / n_samples

    g_means = [mean_gain(price) for price in prices]
    best = int(np.argmax(g_means))
    return prices[best], g_means[best]


# Run a single repetition/trial of the algorithm
def single_rep(T, K, delta, traders, GFT_pstar, rep_seed):
    """Execute one trial of ``algo`` after reseeding NumPy's global RNG with ``rep_seed``.

    ``run_simulation`` submits this function to a process pool, once per trial.
    Returns the cumulative-regret array produced by ``algo``.
    """
    # NOTE(review): reseeding makes the trial reproducible only if Traders draws its
    # randomness from NumPy's global RNG — confirm.
    np.random.seed(rep_seed)
    trajectory = algo(T, K, delta, traders, GFT_pstar)
    return trajectory


# Run the algorithm for n_trials repetitions and return the mean regret and 95% confidence intervals
def run_simulation(T, traders, MC_price_grid, n_trials, seed=1):
    """Run ``algo`` for ``n_trials`` trials and summarize the cumulative regret.

    Args:
        T: time horizon (rounds per trial); must be >= 1.
        traders: environment passed to ``estimate_optimal_price`` and to every
            trial. It is sent to worker processes, so it must be picklable.
        MC_price_grid: candidate prices for the Monte-Carlo benchmark.
        n_trials: number of trials; trial r is seeded with ``seed + r``.
        seed: seed for NumPy's global RNG in this process (used by the benchmark
            estimate) and base seed for the trials.

    Returns:
        Three lists of length T: the mean cumulative regret per round, and the
        lower / upper ends of its normal-approximation 95% confidence interval.
    """
    np.random.seed(seed)
    # Number of arms K = ceil(T^(1/3)) and confidence parameter delta = 1 / T.
    K = _ceil_cube_root(T)
    delta = 1 / T

    # Monte-Carlo estimate of the best expected gain from trade on the benchmark grid.
    num_mc_samples = 20000
    _, GFT_pstar = estimate_optimal_price(MC_price_grid, traders, num_mc_samples)

    cum_regret = np.zeros((n_trials, T))
    with ProcessPoolExecutor(max_workers=os.cpu_count()) as pool:
        futures = [pool.submit(single_rep, T, K, delta, traders, GFT_pstar, seed + r)
                   for r in range(n_trials)]
        # Row r holds the trial seeded with seed + r. Filling rows in submission order
        # (not completion order) keeps the floating-point reductions below reproducible.
        for r, future in enumerate(futures):
            cum_regret[r] = future.result()

    mean_regret = cum_regret.mean(axis=0)

    # Standard error from the sample standard deviation (ddof=1). With a single
    # trial there is no spread to estimate, so the interval collapses to the mean.
    std = cum_regret.std(axis=0, ddof=1 if n_trials > 1 else 0)
    half_width = 1.96 * (std / np.sqrt(n_trials))
    lower_CI = mean_regret - half_width
    upper_CI = mean_regret + half_width

    return mean_regret.tolist(), lower_CI.tolist(), upper_CI.tolist()


def _ceil_cube_root(n):
    """Return ceil(n ** (1/3)) computed exactly with integer arithmetic (n >= 1)."""
    # The float cube root can land just above an exact integer root
    # (27 ** (1 / 3) == 3.0000000000000004), which np.ceil would round up to 4.
    k = max(1, round(n ** (1 / 3)))
    while k ** 3 < n:
        k += 1
    while k > 1 and (k - 1) ** 3 >= n:
        k -= 1
    return k


def main(T, n_trials):
    """Run all regret experiments and save the resulting figures.

    Writes one ``Regret-<retailer_label>-Utility.png`` per retailer utility
    (empirical regret vs. a scaled Theorem 3.3 upper bound) and
    ``Regret-Lifting-Assumptions.png`` (empirical regret vs. T/c lower bounds
    when one of Assumptions 1.1-1.4 is lifted).

    Args:
        T: time horizon, i.e. number of rounds per trial.
        n_trials: number of repetitions averaged for each curve.
    """
    rounds = np.arange(1, T + 1)
    _plot_upper_bound_experiments(T, n_trials, rounds)
    _plot_lower_bound_experiments(T, n_trials, rounds)


def _plot_upper_bound_experiments(T, n_trials, rounds):
    """Plot regret for each supplier distribution, one figure per retailer utility."""
    # Supplier entries: (legend label, Traders supplier env, params, plot color).
    supplier_configs = [
        ("Uniform", "uniform", (), 'limegreen'),
        ("Beta", "beta", (2, 5), 'orange'),
        ("Trunc Log-Normal", "trunc_log_normal", (-0.5, 1), 'blue'),
        ("Beta Mixture", "two_beta_mixture", (0.75, 2, 5, 5, 2), 'red'),
    ]
    # Retailer entries: (label, Traders retailer env, params).
    retailer_configs = [
        ("Capped-Linear", "cap_linear", ()),
        ("Exponential", "exp", ()),
    ]

    # Theoretical upper bound from Theorem 3.3 (scaled down by a factor of 0.45 for clarity).
    # It is obtained by substituting K = ceil(T^(1/3)) and delta = 1/T into the final
    # regret inequality derived in the proof of Theorem 3.3, using simple estimations.
    theo_upper_bound = 0.45 * ((2.5 + np.sqrt(np.log(rounds))) * np.power(rounds, 2 / 3))

    # Benchmark price grid for the MC estimate of the best expected gain from trade.
    # For T = 7 * 10^5, K = ceil(T^(1/3)) = 89; the grid is fixed at 89 intervals so
    # regret is measured against the same benchmark for every time horizon.
    r = 89
    MC_price_grid = [j / r for j in range(r + 1)]

    for retailer_label, retailer_env, _retailer_params in retailer_configs:
        all_regret = {}
        for supplier_label, supplier_env, supplier_params, color in supplier_configs:
            # NOTE(review): Traders takes a single params argument and only the
            # supplier's params are forwarded; the retailer params (all empty here)
            # are unused. Confirm how Traders expects non-empty retailer params.
            traders = Traders(supplier_env, retailer_env, supplier_params)
            regret, lower_CI, upper_CI = run_simulation(T, traders, MC_price_grid, n_trials, seed=1)
            all_regret[supplier_label] = (regret, lower_CI, upper_CI, color)

        fig = plt.figure(figsize=(4.3, 5))
        for supplier_label, (regrets, lower_CI, upper_CI, color) in all_regret.items():
            plt.plot(rounds, regrets, label=supplier_label, color=color, linewidth=1)
            plt.fill_between(rounds, lower_CI, upper_CI, color=color, alpha=0.3)

        # Scaled-down theoretical upper bound.
        plt.plot(rounds, theo_upper_bound, "--", label="Theorem 3.3 Upper Bound", color='black')
        plt.xlabel("Round")
        plt.ylabel("Regret")
        plt.title(f"Regret over Time for Different Seller \nValuations with {retailer_label} Utility")
        plt.legend(loc="upper left")
        plt.tight_layout()
        plt.savefig(f"Regret-{retailer_label}-Utility.png", dpi=400)
        # Release the figure; pyplot otherwise keeps every figure open.
        plt.close(fig)

        print(f"Finished Simulations for {retailer_label} Utility")


def _plot_lower_bound_experiments(T, n_trials, rounds):
    """Plot regret when each of Assumptions 1.1-1.4 is lifted, with its T/c lower bound."""
    # Entries: (label, supplier env, retailer env, params, lower-bound constant c, color).
    # For Assumption 1.3, with x' as defined in the proof in Appendix E, we choose
    # x' = 5/9 so that it is not on the algorithm's price grid.
    # NOTE(review): `(5 / 9)` is a plain float, not a 1-tuple like the other params
    # entries; presumably Traders expects a scalar for 'two_point_cost' — confirm
    # before changing it to `(5 / 9,)`.
    configs = [
        ("Lift Assumption 1.1", 'F_mu_adversary', 'F_mu_adversary', (), 24, 'limegreen'),
        ("Lift Assumption 1.2", 'mu', 'mu', (), 24, 'orange'),
        ("Lift Assumption 1.3", 'two_point_cost', 'two_type_buyer', (5 / 9), 10, 'blue'),
        ("Lift Assumption 1.4", 'uniform', 'quadratic', (), 5000, 'red'),
    ]

    # Benchmark grid with 90 intervals: it contains x' = 5/9 (= 50/90), while the
    # algorithm's grid {j/89} (T = 7 * 10^5) does not.
    r = 89 + 1
    MC_price_grid = [j / r for j in range(r + 1)]

    all_regret = {}
    for label, supplier_env, retailer_env, params, lower_bound_const, color in configs:
        traders = Traders(supplier_env, retailer_env, params)
        regret, lower_CI, upper_CI = run_simulation(T, traders, MC_price_grid, n_trials, seed=1)
        all_regret[label] = (regret, lower_CI, upper_CI, lower_bound_const, color)

    fig = plt.figure(figsize=(4.3, 5))
    for label, (regrets, lower_CI, upper_CI, lower_bound_const, color) in all_regret.items():
        plt.plot(rounds, regrets, label=label, color=color, linewidth=1)
        plt.fill_between(rounds, lower_CI, upper_CI, color=color, alpha=0.3)

        # Theoretical linear lower bound T / c for this configuration.
        plt.plot(rounds, rounds / lower_bound_const, "--", color=color, label=f"T/{lower_bound_const} Lower Bound")

    plt.xlabel("Round")
    plt.ylabel("Regret")
    # Cap the y-axis at T / 4; anything above is clipped from view.
    plt.ylim(top=T / 4)
    plt.title("Regret over Time when Lifting \nAssumptions 1.1-1.4")
    plt.legend(loc="upper left")
    plt.tight_layout()
    plt.savefig("Regret-Lifting-Assumptions.png", dpi=400)
    plt.close(fig)

    print("Finished Simulations for Lifting Assumptions")

def _positive_int(text):
    """argparse ``type=`` callback: parse *text* as an int and require it to be >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options for the regret simulations.

    Args:
        argv: list of argument strings to parse; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with integer attributes ``T`` (time horizon) and
        ``n_trials`` (number of repetitions), both validated to be >= 1.
    """
    parser = argparse.ArgumentParser(description="Run regret simulations.")
    # Both values must be >= 1: T = 0 would make delta = 1 / T divide by zero, and
    # n_trials = 0 would average over an empty set of trials.
    parser.add_argument("--T", type=_positive_int, default=7 * (10 ** 5),
                        help="Time horizon (default: 7e5)")
    parser.add_argument("--n_trials", type=_positive_int, default=30,
                        help="Number of trials to run (default: 30)")

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Command-line entry point: parse the options, then run every experiment.
    cli_options = parse_args()
    main(T=cli_options.T, n_trials=cli_options.n_trials)