import wandb
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
from mount_car.utils.utils import gaussian_smooth


# Initialize the W&B public API client. It is used below to list the runs of
# each project (api.runs) and to download their logged metrics (run.history).
api = wandb.Api()

# Maze variants (fragments of the W&B project names) and the title shown on
# each plot panel; the two lists are aligned index-by-index.
# NOTE(review): map 'Open' is titled 'Double U-Shape' while 'Open_Diverse_G'
# is titled 'Open' -- confirm this pairing is intentional.
maze_list = ['Open_Diverse_G', 'UMaze', 'Open', 'Medium']
maze_labels = ['Open', 'U-Shape', 'Double U-Shape', 'Obstacles']
runs_list = ["PointMaze_" + maze_name + "-v3_MPC" for maze_name in maze_list]

# Right end of the x-axis (steps) for each maze, aligned with runs_list.
x_ticks = [2000, 3000, 5000, 8000]

# Policies to compare: plain SAC first, then every (base algorithm,
# uncertainty method) pair named "<base>_<uncertainty>", matching the
# 'alg_name' values read from the run configs.
uncer_methods = ['IG']
policy_ts = ['BE_SAC', 'BayesExp_MPC_SAC']
policy_types = ['SAC'] + [
    "_".join((base, method)) for base in policy_ts for method in uncer_methods
]
policy_labels = ['SAC', 'BE-SAC', 'PTS-BE-SAC']

# all_runs[project][policy] -> list of per-run reward arrays. Each leaf list
# is a distinct object so appending to one never affects another.
all_runs = {
    project: {policy: [] for policy in policy_types}
    for project in runs_list
}

for run_proj in runs_list:

    # Fetch every run logged to this W&B project.
    runs = api.runs(f"albicaron93/{run_proj}")

    # ---------------------
    # 1) Download data from W&B and group runs
    # ---------------------
    for run in runs:

        policy_name = run.config.get('alg_name')

        # Skip runs with no algorithm name, entropy-based variants, and any
        # algorithm that is not one of the plotted policies (indexing
        # all_runs with it would otherwise raise KeyError).
        if policy_name is None or 'Entropy' in policy_name:
            continue
        if policy_name not in all_runs[run_proj]:
            continue

        # NOTE: Run.history() returns a *sampled* subset of the logged rows
        # (controlled by its `samples` argument), not the full history.
        history = run.history(keys=["Train/Episode Reward"])
        if "Train/Episode Reward" not in history:
            # The run never logged this metric, so the column is absent.
            continue
        rews = history["Train/Episode Reward"].dropna().to_numpy()

        # Empty traces carry no information and would break stacking later.
        if rews.size:
            all_runs[run_proj][policy_name].append(rews)


# ---------------------
# 2) Plotting
# ---------------------

# One panel per maze in a single row. squeeze=False keeps `axes` indexable
# even when only one maze is plotted; the width per panel matches the
# original 4-panel layout.
n_panels = len(runs_list)
fig, axes = plt.subplots(1, n_panels, figsize=(3.5 * (13 / 12) * n_panels, 3.5), squeeze=False)
axes = axes[0]

for k, maze in enumerate(runs_list):

    ax = axes[k]
    x_tick = x_ticks[k]

    for policy, policy_label in zip(policy_types, policy_labels):

        curves = all_runs[maze][policy]
        max_len = max((len(c) for c in curves), default=0)
        n_steps = min((len(c) for c in curves), default=0)
        if n_steps == 0:
            # No usable runs for this policy on this maze; nothing to plot.
            continue

        # Runs may log different numbers of episodes, and np.array on ragged
        # input fails. Truncate each run to the shortest one so they stack
        # into a (n_runs, n_steps) tensor.
        tsr = np.stack([np.asarray(c[:n_steps], dtype=float) for c in curves])
        solv_prob = tsr.mean(axis=0)

        # Smooth the probability, then keep it inside [0, 1] so the Bernoulli
        # variance below can never go negative (which would give NaN).
        solv_prob = np.clip(gaussian_smooth(solv_prob, sigma=2), 0.0, 1.0)

        # Standard error of a Bernoulli mean across runs (assumes the logged
        # episode reward is a 0/1 success indicator, as the y-label suggests).
        std_error = np.sqrt(solv_prob * (1 - solv_prob) / tsr.shape[0])

        # Episodes are spread evenly over [0, x_tick] for a full-length run.
        # When truncation shortened the curve, the x-range shrinks by the same
        # fraction (assumes evenly spaced episodes -- confirm).
        x = np.linspace(0, x_tick * n_steps / max_len, n_steps)

        # 1.64 ~ z_{0.95}: a ~90% normal-approximation confidence band.
        ax.plot(x, solv_prob, label=policy_label, linewidth=1)
        ax.fill_between(x, solv_prob - 1.64 * std_error, solv_prob + 1.64 * std_error, alpha=0.3)

    # Per-panel decoration, applied once per axis rather than once per policy.
    ax.set_title(maze_labels[k])
    ax.set_xlabel('Steps')
    ax.grid(True)
    ax.set_ylim([-0.05, 1.05])
    ax.legend()

# Set y-axis label
axes[0].set_ylabel('$p(r^e_{t} = 1)$')

plt.tight_layout()
plt.savefig('maze_solved_prob.pdf', dpi=300, bbox_inches='tight')
