import numpy as np
import os
import argparse
import yaml
import matplotlib.pyplot as plt
import seaborn as sns

from algorithm import *
from environment import *

"""
Description:

Simulate the online learning procedure. The environment is a linear bandit problem with safety constraints.
The learning objective is to minimize the cumulative regret and the statistical estimation error of the
unknown linear parameter theta^*, subject to the safety constraints.

In each time step, the algorithms interact with the environment in the following manner:
    - The environment determines the safety cost associated with each arm in the arm set.
    - Each algorithm selects an arm based on its historical observation data.
    - The environment generates and reveals the reward for the selected arm.

"""


def simulate_online_learning(exp_config):
    """Run one episode of the safe linear bandit simulation.

    Every algorithm listed in ``exp_config["algorithms"]`` is built from
    ``ALG_REGISTRY`` and plays against one shared ``Environment`` for ``T``
    rounds; in each round all algorithms see the same safety costs and the
    same reward vector.

    Args:
        exp_config: parsed experiment configuration. Keys read here are
            ``"environment"`` (must contain ``"T"``), ``"algorithms"`` (a list
            of dicts with ``"name"`` and ``"params"``) and ``"secrets"``.

    Returns:
        The dict returned by ``env.get_results()``, extended with
        ``"estimation_error"``: for each algorithm, in configuration order, the
        L2 distance between its final estimate and ``env.theta_star``.
    """
    env_cfg = exp_config["environment"]
    horizon = env_cfg["T"]

    # Environment settings are merged last, so they override an algorithm's
    # own params on key clashes.
    learners = []
    for spec in exp_config["algorithms"]:
        learner_cls = ALG_REGISTRY[spec["name"]]
        learners.append(learner_cls({**spec["params"], **env_cfg}))

    # Only the environment is given the contents of exp_config["secrets"].
    env = Environment({**env_cfg, **exp_config["secrets"]})

    for t in range(1, horizon + 1):
        # Query this round's safety costs first, then the rewards of all arms.
        costs = env.get_safety_costs_t()
        rewards = env.get_rewards_t()

        chosen_arms = []
        for learner in learners:
            arm = learner.select_arm(t, costs)
            chosen_arms.append(arm)
            # Bandit feedback: the learner is handed only the reward of the arm
            # it pulled, which it uses to refresh its estimate of theta.
            learner.update(t, arm, rewards[arm], costs)

        env.update_statistics(t, chosen_arms, rewards, costs)

    results = env.get_results()
    results["estimation_error"] = [
        np.linalg.norm(learner.get_estimation() - env.theta_star, ord=2)
        for learner in learners
    ]
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_config", type=str,
                        default="config/example.yaml")
    args = parser.parse_args()

    with open(args.exp_config, "r") as f:
        exp_config = yaml.load(f, Loader=yaml.FullLoader)

    epoch_num = exp_config["epoch_num"]
    seed = exp_config["seed"]
    np.random.seed(seed)
    T = exp_config["environment"]["T"]
    alg_num = len(exp_config["algorithms"])
    # Store all epoch results: shape (epoch_num, T, alg_num)
    all_cumu_regret = np.zeros((epoch_num, T, alg_num))
    all_cumu_safety_cost = np.zeros((epoch_num, T, alg_num))
    all_estimation_error = np.zeros((epoch_num, alg_num))
    all_PO_error = np.zeros((epoch_num, alg_num))
    for epoch in range(epoch_num):
        results = simulate_online_learning(exp_config)
        all_cumu_regret[epoch, :, :] = np.array(results["cumu_regret"]) / T
        all_cumu_safety_cost[epoch, :, :] = np.array(
            results["cumu_safety_cost"]) / T
        all_estimation_error[epoch, :] = np.array(results["estimation_error"])

    # Calculate mean and percentiles
    cumu_regret_mean = np.mean(all_cumu_regret, axis=0)
    cumu_regret_p25 = np.percentile(all_cumu_regret, 10, axis=0)
    cumu_regret_p75 = np.percentile(all_cumu_regret, 90, axis=0)

    cumu_safety_cost_mean = np.mean(all_cumu_safety_cost, axis=0)
    cumu_safety_cost_p25 = np.percentile(all_cumu_safety_cost, 20, axis=0)
    cumu_safety_cost_p75 = np.percentile(all_cumu_safety_cost, 80, axis=0)

    fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1)


    for i, alg in enumerate(exp_config["algorithms"]):
        # Plot mean line
        ax1.plot(cumu_regret_mean[:, i], label=f"{alg['name']}")
        # Add shaded region between 25th and 75th percentiles
        ax1.fill_between(
            range(T),
            cumu_regret_p25[:, i],
            cumu_regret_p75[:, i],
            alpha=0.3
        )

        # Plot mean line
        ax2.plot(cumu_safety_cost_mean[:, i],
                 label=f"{alg['name']}")
        # Add shaded region between 25th and 75th percentiles
        ax2.fill_between(
            range(T),
            cumu_safety_cost_p25[:, i],
            cumu_safety_cost_p75[:, i],
            alpha=0.25
        )

        ax3.boxplot(all_estimation_error[:, i], positions=[
                    i], widths=0.2)

    ax1.legend()
    ax2.legend()
   
    ax1.grid(True)
    ax2.grid(True)
    ax3.grid(True)
    ax1.set_xlabel("Time steps")
    ax2.set_xlabel("Time steps")
    ax3.set_xlabel(f"Algorithms")
    ax1.set_ylabel("Regret / $T$")
    ax2.set_ylabel("Safety cost / $T$")
    ax3.set_ylabel("Estimation error")
    ax3.set_xticks(range(alg_num), [f"{alg['name']}" for alg in exp_config["algorithms"]])
    
    plt.tight_layout()
    plt.show()
