import wandb
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from mount_car.utils.utils import gaussian_smooth


# Connect to the W&B public API and list every run logged under the project below.
wandb_project = "albicaron93/MC_EXPLORE"
api = wandb.Api()
runs = api.runs(wandb_project)

# Dictionary to store runs
uncer_methods = ['Entropy', 'IG']
noise_models = ['heteroskedastic', 'homoskedastic']
policy_types = ['PPO', 'MPC']

all_runs = {policy_types: {uncer: {noise: [] for noise in noise_models} for uncer in uncer_methods} for policy_types in policy_types}

# The config key that identifies which group the run belongs to.
# Adjust based on how your runs' configs are struct ured.
group_key = "uncertainty_method"

# Define color for each uncertainty method and each noise model combination
colors = {
    'Entropy': {'heteroskedastic': 'lightcoral', 'homoskedastic': 'darkred'},
    'IG': {'heteroskedastic': 'cornflowerblue', 'homoskedastic': 'darkblue'}
}

# ---------------------
# 1) Download data from W&B and group runs
# ---------------------
# Logged metric to collect for each run (plotted as 'Frac States Visited').
metric_key = "Env/Perc States Visited"

for run in runs:
    # The policy type is the run-name prefix before the first underscore
    # (e.g. "PPO_..." / "MPC_..."). Any other prefix, including look-alikes such
    # as "PPOv2_...", is not part of this comparison and is skipped instead of
    # raising a KeyError.
    pol_name = run.name.split('_')[0]
    if pol_name not in all_runs:
        continue

    # .get() lets runs that lack a config key fall through to the skip below
    # rather than crashing the whole download. Runs whose uncertainty method
    # was recorded as 'Error' are excluded explicitly.
    group_value = run.config.get(group_key)
    noise_model = run.config.get('noise_model')
    if group_value == 'Error':
        continue
    bucket = all_runs[pol_name].get(group_value, {}).get(noise_model)
    if bucket is None:
        # Unknown uncertainty method / noise model (or missing from config).
        continue

    # Note: Run.history() returns a down-sampled history (wandb default samples=500).
    history = run.history(keys=[metric_key])
    if metric_key not in history:
        # The metric was never logged for this run.
        continue
    values = history[metric_key].dropna().to_numpy()
    # Skip runs with no valid samples: empty arrays cannot be averaged with the rest.
    if values.size:
        bucket.append(values)

# ---------------------
# 2) Plotting
# ---------------------

# One subplot per policy type (MPC -> "PTS-BE", any other -> "BE"). Each subplot
# has one curve per (uncertainty method, noise model) pair. Every run is first
# smoothed with a Gaussian filter; the runs are then averaged and drawn with a
# 95% confidence band (mean +/- 1.96 standard errors).

# z-value of the two-sided 95% normal interval (1.645 would give only 90%).
Z_95 = 1.96
# Env steps between consecutive history samples. The original x-axis assumed
# 500 samples spanning 1000 steps. NOTE(review): confirm against the W&B logging cadence.
STEP_STRIDE = 2
# Subplot titles per policy type; policies not listed fall back to "BE".
subplot_titles = {'MPC': "PTS-BE"}

fig, axs = plt.subplots(1, len(policy_types), figsize=(3.2*(7/3), 3.2), squeeze=False)
axs = axs[0]  # squeeze=False always returns a 2-D grid; take its single row

for ax, policy_type in zip(axs, policy_types):
    plotted_any = False
    for uncer in uncer_methods:
        for noise in noise_models:
            group_runs = [r for r in all_runs[policy_type][uncer][noise] if len(r)]
            if not group_runs:
                # Nothing downloaded for this combination. Averaging an empty
                # list would give a NaN scalar that cannot be plotted against x.
                continue

            ns_mod = 'Heterosk' if noise == 'heteroskedastic' else 'Homosk'
            # Raw string: '\c' is an invalid escape sequence in a normal literal.
            unc_label = r'$H[\cdot]$' if uncer == 'Entropy' else 'EIG'

            smoothed = [np.asarray(gaussian_smooth(r, sigma=10)) for r in group_runs]
            # Runs may have different numbers of logged points. Truncate all of
            # them to the shortest so they form a rectangular (runs, points) array.
            n_points = min(len(s) for s in smoothed)
            if n_points == 0:
                continue
            curves = np.asarray([s[:n_points] for s in smoothed])
            x = np.arange(n_points) * STEP_STRIDE

            n_runs = curves.shape[0]
            mean = curves.mean(axis=0)
            if n_runs > 1:
                # Standard error of the mean, using the sample std (ddof=1).
                se = curves.std(axis=0, ddof=1) / np.sqrt(n_runs)
            else:
                # With a single run the spread is undefined; draw no band.
                se = np.zeros_like(mean)

            color = colors[uncer][noise]
            ax.plot(x, mean, label=f"{unc_label} {ns_mod}", color=color)
            ax.fill_between(x, mean - Z_95 * se, mean + Z_95 * se, alpha=0.2, color=color)
            plotted_any = True

    ax.set_title(subplot_titles.get(policy_type, "BE"))
    ax.set_xlabel('Steps')
    ax.set_ylabel('Frac States Visited')

    # The plotted metric is a fraction, so fix the y-axis to [0, 1].
    ax.set_ylim([0.0, 1.0])

    # Grid and legend (the legend only when there is something to label).
    if plotted_any:
        ax.legend()
    ax.grid()

fig.tight_layout()
fig.savefig('mpc_MC.pdf', dpi=300, bbox_inches='tight')
