import wandb
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from mount_car.utils.utils import gaussian_smooth


# W&B project whose runs are analysed by this script.
WANDB_PROJECT = "albicaron93/MC_EXPLORE"

# Public-API client, then the collection of runs logged under the project.
api = wandb.Api()
runs = api.runs(WANDB_PROJECT)

# Experiment axes: the uncertainty estimators, the noise models, and the
# exploration policies whose runs are kept.
uncer_methods = ['Error', 'Entropy', 'IG']
noise_models = ['heteroskedastic', 'homoskedastic']
policy_types = ['Random']

# Container for the downloaded runs, indexed as all_runs[method][noise_model].
# Every leaf is a distinct, initially empty list.
all_runs = {
    method: {model: [] for model in noise_models}
    for method in uncer_methods
}

# Run-config key that identifies which group (uncertainty method) a run belongs to.
# Adjust it to match how your runs' configs are structured.
group_key = "uncertainty_method"

# Plot colour for each (uncertainty method, noise model) pair: the first shade
# goes to 'heteroskedastic' and the second to 'homoskedastic', following the
# order of `noise_models`.
_shades = {
    'Error': ('cornflowerblue', 'darkblue'),
    'Entropy': ('lightcoral', 'darkred'),
    'IG': ('mediumseagreen', 'darkgreen'),
}
colors = {method: dict(zip(noise_models, pair)) for method, pair in _shades.items()}

# ---------------------
# 1) Download data from W&B and group runs
# ---------------------
for run in runs:
    # Keep only runs collected with one of the requested policies. Using `.get`
    # skips runs whose config lacks a 'policy' entry instead of raising KeyError.
    if run.config.get('policy') not in policy_types:
        continue

    group_value = run.config.get(group_key)
    noise_model = run.config.get('noise_model')
    if group_value not in all_runs or noise_model not in all_runs[group_value]:
        # The run belongs to a method / noise model that is not being plotted.
        continue

    metric = f"Uncertainty/Raw_Avg_{group_value}"
    # `Run.history` returns a down-sampled history (500 rows by default).
    # `scan_history` pages through every logged row, i.e. the full history.
    values = [row.get(metric) for row in run.scan_history(keys=[metric])]
    uncertainties = np.asarray([v for v in values if v is not None], dtype=float)
    # Drop missing / NaN entries (what `.dropna()` did on the DataFrame column).
    uncertainties = uncertainties[~np.isnan(uncertainties)]

    # A run that never logged this metric contributes nothing; skip it rather
    # than appending an empty series that would break averaging later.
    if uncertainties.size:
        all_runs[group_value][noise_model].append(uncertainties)

# ---------------------
# 2) Plotting
# ---------------------

# One panel per uncertainty method (3 in total); within each panel, both noise models are overlaid as separate
# curves. Every run is first smoothed with a Gaussian filter, then the runs are averaged and plotted with a 95%
# confidence band around the mean.

# Define plot parameters
# Two-sided 95% normal quantile for the confidence band (1.9 under-covers 95%).
CI_Z = 1.96
# Width of the Gaussian smoothing applied to every run before averaging.
SMOOTH_SIGMA = 8

# One panel per uncertainty method. `squeeze=False` keeps `axs` indexable even
# with a single method; the total width is 2.8 * 3.4 for three panels, as before.
n_panels = len(uncer_methods)
fig, axs = plt.subplots(1, n_panels, figsize=(2.8 * 3.4 / 3 * n_panels, 2.8), squeeze=False)
axs = axs[0]

for ax, uncer in zip(axs, uncer_methods):
    for noise in noise_models:
        group_runs = all_runs[uncer][noise]
        if not group_runs:
            # No runs for this combination: np.mean([]) yields a scalar NaN and
            # len() on it would crash, so skip this curve.
            continue

        ns_mod = 'Heterosk' if noise == 'heteroskedastic' else 'Homosk'
        color = colors[uncer][noise]

        # Smooth each run, then align them. Runs may log different numbers of
        # points and NumPy cannot average ragged sequences, so every curve is
        # truncated to the shortest one.
        smoothed = [np.asarray(gaussian_smooth(r, sigma=SMOOTH_SIGMA)) for r in group_runs]
        horizon = min(len(s) for s in smoothed)
        stacked = np.stack([s[:horizon] for s in smoothed])

        n_runs = stacked.shape[0]
        mean = stacked.mean(axis=0)
        # Standard error of the mean uses the sample std (ddof=1). With a single
        # run it is undefined, so the band collapses to zero width.
        if n_runs > 1:
            se = stacked.std(axis=0, ddof=1) / np.sqrt(n_runs)
        else:
            se = np.zeros_like(mean)

        steps = np.arange(horizon)
        ax.plot(steps, mean, label=ns_mod, color=color)
        ax.fill_between(steps, mean - CI_Z * se, mean + CI_Z * se, alpha=0.2, color=color)

    # Titles and axis labels are per panel, so set them once per method
    # ('IG' is displayed as 'EIG').
    ax.set_title('EIG' if uncer == 'IG' else uncer)
    ax.set_xlabel('Model Updates')
    ax.set_ylabel('Uncertainty')

    # Grid and legend
    ax.legend()
    ax.grid()

plt.tight_layout()
fig.savefig('hh_uncertainty.pdf', dpi=300, bbox_inches='tight')
