import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
from typing import Dict, List, Tuple, Optional

# Experiment-name prefix (e.g. "mountain_car_ppo_..."). It is interpolated into the
# directory-matching regexes below, so it must not carry stray whitespace.
ENV = "mountain_car"  # Change to "reacher" for Reacher environment

def load_sb3_data(
    base_dir: str,
    algos: List[str],
    architectures: Dict[str, str],
    reward_type: str,
    env: Optional[str] = None,
) -> pd.DataFrame:
    """
    Load Stable-Baselines3 evaluation curves for one reward type as a long DataFrame.

    Experiment directories in ``base_dir`` are matched by name as
    ``{env}_{algo}_{reward_type}_policy_{l1}_{l2}_seed_0``. Inside each one,
    ``seed_0`` ... ``seed_9`` may each contain an ``evaluations.npz`` holding
    ``timesteps`` and ``results`` arrays; ``results`` is averaged over axis 1
    to obtain one mean return per evaluation.

    The timestep grid comes from the first seed that loads successfully, with
    its first entry forced to 0 and its last entry forced to 300_000. Seeds
    whose number of evaluations differs from that grid are skipped with a
    message instead of aborting the whole load.

    For architectures ``(400,300)`` and ``(64,64)`` the raw reward matrix is
    additionally saved to ``{base_dir}/{reward_type}/{algo}.npz``.

    Args:
        base_dir: Directory containing the experiment sub-directories.
        algos: Lower-case algorithm names to keep (e.g. ``'ppo'``).
        architectures: Maps ``"(l1,l2)"`` keys (no space) to display labels.
        reward_type: Reward-type token that must appear in the directory name.
        env: Directory-name prefix; defaults to the module-level ``ENV``.

    Returns:
        DataFrame with columns Timestep, Reward, Algorithm (upper-cased),
        Architecture (label) and Run (0-based index among the seeds loaded
        for that experiment). Empty if nothing matched.
    """
    env_prefix = ENV if env is None else env
    # re.escape so the prefix is matched literally rather than as a regex.
    pattern = re.compile(
        rf"{re.escape(env_prefix)}_(.+?)_(\w+)_policy_(\d+)_(\d+)_seed_0"
    )
    all_data_list: List[pd.DataFrame] = []

    for dir_name in sorted(os.listdir(base_dir)):
        match = pattern.match(dir_name)
        if not match:
            continue

        algo, current_reward_type, l1, l2 = match.groups()
        arch_key = f"({l1},{l2})"
        if algo not in algos or arch_key not in architectures or current_reward_type != reward_type:
            continue

        exp_path = os.path.join(base_dir, dir_name)
        all_seed_rewards: List[np.ndarray] = []
        timesteps: Optional[np.ndarray] = None

        for seed in range(10):
            eval_file = os.path.join(exp_path, f"seed_{seed}", "evaluations.npz")
            if not os.path.exists(eval_file):
                continue

            try:
                with np.load(eval_file) as npz_file:
                    seed_timesteps = npz_file['timesteps'].copy()
                    mean_rewards_per_eval = np.mean(npz_file['results'], axis=1)
            except Exception as e:  # best effort: report and try the next seed
                print(f"Error loading {eval_file}: {e}")
                continue

            if timesteps is None:
                if len(seed_timesteps) == 0:
                    print(f"Skipping {eval_file}: no evaluation timesteps")
                    continue
                timesteps = seed_timesteps
                # Pin the first/last grid points to 0 and 300_000.
                # NOTE(review): hard-coded end point; the combined plot clips the
                # x-axis at 150_000 when ENV == "reacher" — confirm this is intended.
                timesteps[0] = 0
                timesteps[-1] = 300_000

            # All seeds must share the grid, otherwise the reward matrix is ragged.
            if len(mean_rewards_per_eval) != len(timesteps):
                print(
                    f"Skipping {eval_file}: {len(mean_rewards_per_eval)} evaluations, "
                    f"expected {len(timesteps)}"
                )
                continue
            all_seed_rewards.append(mean_rewards_per_eval)

        if timesteps is None or not all_seed_rewards:
            continue

        rewards_matrix = np.array(all_seed_rewards)
        n_runs, n_timesteps = rewards_matrix.shape

        if arch_key == '(400,300)' or arch_key == '(64,64)':
            # NOTE(review): both architectures write the same file name, so the
            # directory processed last (sorted order) wins.
            output_dir = os.path.join(base_dir, reward_type)
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, f'{algo}.npz')
            np.savez(output_path, timesteps=timesteps, all_rewards=rewards_matrix)

        df = pd.DataFrame({
            'Timestep': np.tile(timesteps, n_runs),
            'Reward': rewards_matrix.flatten(),
            'Algorithm': algo.upper(),
            'Architecture': architectures[arch_key],
            'Run': np.arange(n_runs).repeat(n_timesteps),
        })
        all_data_list.append(df)

    if not all_data_list:
        return pd.DataFrame()

    return pd.concat(all_data_list, ignore_index=True)


def load_pgpe_data(
    base_dir: str,
    architectures: Dict[str, str],
    reward_type: str,
    env: Optional[str] = None,
) -> pd.DataFrame:
    """
    Load PGPE validation curves for one reward type as a long DataFrame.

    Experiment directories in ``base_dir`` are matched by name as
    ``{env}_pgpe_{reward_type}_policy_{l1}_{l2}_seed_0`` and must contain
    ``pgpe_histories.pkl``: a pickled sequence with one mapping per run, whose
    ``'val_reward'`` entry is a sequence of ``(timestep, reward)`` pairs.

    Each run is sorted by timestep and linearly interpolated onto the timestep
    grid of the first run that has at least two validation points. Outside a
    run's own range, its first/last value is held constant. Runs with fewer
    than two points are kept as all-NaN rows.

    For architectures ``(4,4)`` and ``(64,64)`` the interpolated matrix is also
    saved to ``{base_dir}/{reward_type}/PGPE.npz``.

    Args:
        base_dir: Directory containing the experiment sub-directories.
        architectures: Maps ``"(l1,l2)"`` keys (no space) to display labels.
        reward_type: Reward-type token that must appear in the directory name.
        env: Directory-name prefix; defaults to the module-level ``ENV``.

    Returns:
        DataFrame with columns Timestep, Reward, Algorithm (``'PGPE'``),
        Architecture (label) and Run (index within the pickled histories).
        Empty if nothing matched.
    """
    env_prefix = ENV if env is None else env
    pattern = re.compile(
        rf"{re.escape(env_prefix)}_pgpe_(.+?)_policy_(\d+)_(\d+)_seed_0"
    )
    all_data_list: List[pd.DataFrame] = []

    def _sorted_curve(run_history) -> Optional[Tuple[np.ndarray, np.ndarray]]:
        # One run's (timesteps, rewards) sorted by timestep, or None when the run
        # has fewer than two 'val_reward' points to interpolate between.
        if 'val_reward' not in run_history or len(run_history['val_reward']) < 2:
            return None
        raw_steps, raw_rewards = zip(*run_history['val_reward'])
        steps = np.asarray(raw_steps)
        rewards = np.asarray(raw_rewards, dtype=float)
        # np.interp needs increasing x-coordinates and does not verify them.
        order = np.argsort(steps, kind='stable')
        return steps[order], rewards[order]

    for dir_name in sorted(os.listdir(base_dir)):
        match = pattern.match(dir_name)
        if not match:
            continue

        current_reward_type, l1, l2 = match.groups()
        arch_key = f"({l1},{l2})"
        if current_reward_type != reward_type or arch_key not in architectures:
            continue

        history_file = os.path.join(base_dir, dir_name, 'pgpe_histories.pkl')
        if not os.path.exists(history_file):
            continue

        try:
            # SECURITY: pickle.load can execute arbitrary code — only load
            # result files produced by your own training runs.
            with open(history_file, 'rb') as f:
                histories = pickle.load(f)
        except Exception as e:
            print(f"ERROR loading PGPE history file {history_file}: {e}")
            continue

        curves = [_sorted_curve(h) for h in histories] if histories else []
        usable = [c for c in curves if c is not None]
        if not usable:
            continue

        # Canonical grid: timesteps of the first run with a usable history.
        canonical_timesteps = usable[0][0]
        n_runs = len(curves)
        n_timesteps = len(canonical_timesteps)
        # Runs without a usable history stay NaN.
        interp_rewards_matrix = np.full((n_runs, n_timesteps), np.nan)

        for i, curve in enumerate(curves):
            if curve is None:
                continue
            run_timesteps, run_rewards = curve
            interp_rewards_matrix[i, :] = np.interp(
                canonical_timesteps, run_timesteps, run_rewards,
                left=run_rewards[0], right=run_rewards[-1],
            )

        if arch_key == '(4,4)' or arch_key == '(64,64)':
            # NOTE(review): both architectures write 'PGPE.npz', so the directory
            # processed last (sorted order) overwrites the other.
            output_dir = os.path.join(base_dir, reward_type)
            os.makedirs(output_dir, exist_ok=True)
            output_path = os.path.join(output_dir, 'PGPE.npz')
            np.savez(output_path, timesteps=canonical_timesteps, all_rewards=interp_rewards_matrix)

        df_exp = pd.DataFrame({
            'Timestep': np.tile(canonical_timesteps, n_runs),
            'Reward': interp_rewards_matrix.flatten(),
            'Algorithm': 'PGPE',
            'Architecture': architectures[arch_key],
            'Run': np.arange(n_runs).repeat(n_timesteps),
        })
        all_data_list.append(df_exp)

    if not all_data_list:
        return pd.DataFrame()

    return pd.concat(all_data_list, ignore_index=True)


def aggregate_and_plot_seaborn(color_map: Dict, base_dir: str = 'baselines'):
    """
    Plot the learning curves of all algorithms, one figure per reward type.

    Scans ``base_dir`` for SB3 (``{ENV}_{algo}_{reward}_policy...``) and PGPE
    (``{ENV}_pgpe_{reward}_policy...``) experiment directories. For every reward
    type found (in sorted order), it loads the data with ``load_sb3_data`` and
    ``load_pgpe_data`` and draws a seaborn line plot. In that plot, colour encodes
    the algorithm, the dash pattern encodes the architecture size, and the band is
    a 95% confidence interval over runs (``seed=42``). Each figure is saved as
    ``{base_dir}/seaborn_performance_{reward_type}.pdf``, shown, then closed.

    Args:
        color_map: Palette mapping algorithm labels ('DDPG', 'PPO', ...) to colours.
        base_dir: Directory that holds the experiment sub-directories.

    Side effects: updates the global matplotlib ``rcParams`` and writes PDFs.
    """
    algos = ['ppo', 'ddpg', 'sac', 'td3']
    # NOTE(review): the loaders build keys as f"({l1},{l2})" with no space, so the
    # '(64, 64)' entry can never match; '(64,64)' and '(32,32)' both map to 'Medium'.
    # Confirm this mapping is intended before changing it.
    architectures = {'(4,4)': 'Small', '(64,64)': 'Medium', '(32,32)': 'Medium', '(64, 64)': 'Large', '(400,300)': 'Large'}
    # seaborn dash specs: (dash, gap) lengths; () draws a solid line.
    style_map = {'Small': (1, 1), 'Medium': (4, 2), 'Large': ()}
    arch_order = ['Small', 'Medium', 'Large']
    algo_order = ['DDPG', 'PPO', 'SAC', 'TD3', 'PGPE']

    sns.set_theme(style="whitegrid")
    plt.rcParams.update({
        "font.family": "serif", "font.serif": ["Times New Roman", "DejaVu Serif"],
        "font.size": 16, "axes.labelsize": 20, "xtick.labelsize": 16,
        "ytick.labelsize": 16, "legend.fontsize": 12, "figure.figsize": (10, 6),
    })

    if not os.path.isdir(base_dir):
        print(f"Error: Directory '{base_dir}' not found.")
        return

    reward_types = set()
    sb3_pattern = re.compile(rf"{ENV}_(.+?)_(\w+)_policy")
    pgpe_pattern = re.compile(rf"{ENV}_pgpe_(.+?)_policy")

    for dir_name in os.listdir(base_dir):
        sb3_match = sb3_pattern.match(dir_name)
        pgpe_match = pgpe_pattern.match(dir_name)
        if sb3_match and sb3_match.group(1) in algos:
            reward_types.add(sb3_match.group(2))
        elif pgpe_match:
            reward_types.add(pgpe_match.group(1))

    # Sorted so figures are produced in a deterministic order across runs.
    ordered_reward_types = sorted(reward_types)
    print(f"\nFound reward types: {ordered_reward_types}. Generating plots...")

    for reward_type in ordered_reward_types:
        print(f"\n--- Processing plot for reward type: '{reward_type}' ---")

        # Drop empty frames before concatenating (concat of empty entries is deprecated in pandas).
        frames = [
            df for df in (
                load_sb3_data(base_dir, algos, architectures, reward_type),
                load_pgpe_data(base_dir, architectures, reward_type),
            )
            if not df.empty
        ]
        if not frames:
            print(f"No data found for reward type '{reward_type}'. Skipping.")
            continue
        master_df = pd.concat(frames, ignore_index=True)

        master_df['Algorithm'] = pd.Categorical(master_df['Algorithm'], categories=algo_order, ordered=True)
        master_df['Architecture'] = pd.Categorical(master_df['Architecture'], categories=arch_order, ordered=True)
        master_df = master_df.sort_values(by=['Algorithm', 'Architecture', 'Timestep'])

        fig, ax = plt.subplots()
        sns.lineplot(
            data=master_df, x="Timestep", y="Reward",
            hue="Algorithm", style="Architecture", ax=ax,
            errorbar=('ci', 95), linewidth=2,
            dashes=style_map,
            err_kws={'alpha': 0.15},
            hue_order=algo_order,
            style_order=arch_order,
            palette=color_map,
            legend=False,  # legends are produced separately by generate_legend
            seed=42,
        )

        ax.set_xlabel('Timesteps')
        right = 150000 if ENV == "reacher" else 300000
        ax.set_xlim(left=0, right=right)
        ax.set_ylabel('Return')
        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        fig.tight_layout()

        output_file = os.path.join(base_dir, f'seaborn_performance_{reward_type}.pdf')
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Plot successfully saved to '{output_file}'")
        plt.show()
        # plt.show() does not release figures under non-interactive backends.
        plt.close(fig)

def plot_pgpe_only(pgpe_color: Tuple, base_dir: str = 'baselines'):
    """
    Plot PGPE learning curves alone, one figure per reward type.

    Reward types are discovered from ``{ENV}_pgpe_{reward}_policy...`` directory
    names in ``base_dir`` and processed in sorted order. Both hue and style
    encode the architecture, and every hue level is given ``pgpe_color``, so the
    architectures differ only by dash pattern. Each figure is saved as
    ``{base_dir}/seaborn_pgpe_only_{reward_type}.pdf``, shown, then closed.

    Args:
        pgpe_color: Colour used for every PGPE line.
        base_dir: Directory that holds the experiment sub-directories.

    Side effects: updates the global matplotlib ``rcParams`` and writes PDFs.
    """
    # NOTE(review): the loader builds keys as f"({l1},{l2})" with no space, so the
    # '(64, 64)' entry never matches and PGPE (64,64) runs are not plotted here.
    # Confirm whether that exclusion is intended.
    architectures = {'(4,4)': 'Small', '(32,32)': 'Medium', '(64, 64)': 'Large', '(400,300)': 'Large'}
    # seaborn dash specs: (dash, gap) lengths; () draws a solid line.
    style_map = {'Small': (1, 1), 'Medium': (4, 2), 'Large': ()}
    arch_order = ['Small', 'Medium', 'Large']
    # Map every architecture to the single PGPE colour.
    arch_color_map = {arch: pgpe_color for arch in arch_order}

    sns.set_theme(style="whitegrid")
    plt.rcParams.update({
        "font.family": "serif", "font.serif": ["Times New Roman", "DejaVu Serif"],
        "font.size": 16, "axes.labelsize": 20, "xtick.labelsize": 16,
        "ytick.labelsize": 16, "legend.fontsize": 12, "figure.figsize": (10, 6),
    })

    if not os.path.isdir(base_dir):
        print(f"Error: Directory '{base_dir}' not found.")
        return

    reward_types = set()
    pgpe_pattern = re.compile(rf"{ENV}_pgpe_(\w+)_policy")
    for dir_name in os.listdir(base_dir):
        pgpe_match = pgpe_pattern.match(dir_name)
        if pgpe_match:
            reward_types.add(pgpe_match.group(1))

    # Sorted so figures are produced in a deterministic order across runs.
    ordered_reward_types = sorted(reward_types)
    print(f"\nFound PGPE reward types: {ordered_reward_types}. Generating PGPE-only plots...")

    for reward_type in ordered_reward_types:
        print(f"\n--- Processing PGPE-only plot for reward type: '{reward_type}' ---")

        pgpe_df = load_pgpe_data(base_dir, architectures, reward_type)
        if pgpe_df.empty:
            print(f"No PGPE data found for reward type '{reward_type}'. Skipping.")
            continue

        pgpe_df['Architecture'] = pd.Categorical(pgpe_df['Architecture'], categories=arch_order, ordered=True)
        pgpe_df = pgpe_df.sort_values(by=['Architecture', 'Timestep'])

        fig, ax = plt.subplots()
        sns.lineplot(
            data=pgpe_df, x="Timestep", y="Reward",
            hue="Architecture",
            style="Architecture",
            dashes=style_map,
            hue_order=arch_order,
            style_order=arch_order,
            ax=ax,
            err_kws={'alpha': 0.15},
            palette=arch_color_map,
            errorbar=('ci', 95),
            linewidth=2,
            legend=False,  # legends are produced separately by generate_legend
            seed=42,
        )

        ax.set_xlabel('Timesteps')
        ax.set_ylabel('Return')
        ax.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
        ax.set_xlim(left=0)

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        fig.tight_layout()

        output_file = os.path.join(base_dir, f'seaborn_pgpe_only_{reward_type}.pdf')
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Plot successfully saved to '{output_file}'")
        plt.show()
        # plt.show() does not release figures under non-interactive backends.
        plt.close(fig)

def generate_legend(entries, output_filename: str, base_dir: str = 'baselines', fontsize: int = 100):
    """
    Save a standalone single-row legend (no axes) to ``base_dir/output_filename``.

    Args:
        entries: Mapping of legend label -> keyword arguments for ``plt.Line2D``
            (e.g. ``{'color': ..., 'linestyle': ...}``). The caller's dicts are
            not modified; a line width of 50 is used when neither ``lw`` nor
            ``linewidth`` is supplied.
        output_filename: File name written inside ``base_dir``.
        base_dir: Output directory (must already exist).
        fontsize: Font size of the legend labels.
    """
    # At least one column/slot so an empty mapping doesn't yield a zero-width
    # figure or ncol=0.
    n_slots = max(len(entries), 1)
    # Heuristic wide layout: ~2 inches per entry, 1 inch tall.
    fig, ax = plt.subplots(figsize=(n_slots * 2, 1))

    try:
        lines = []
        labels = []
        for label, properties in entries.items():
            line_kwargs = dict(properties)  # copy: never mutate the caller's dict
            if 'lw' not in line_kwargs and 'linewidth' not in line_kwargs:
                line_kwargs['lw'] = 50
            lines.append(plt.Line2D([0], [0], **line_kwargs))
            labels.append(label)

        ax.legend(lines, labels, ncol=n_slots, frameon=False, loc='center', fontsize=fontsize)
        ax.axis('off')

        legend_file_path = os.path.join(base_dir, output_filename)
        fig.savefig(legend_file_path, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(f"Legend successfully saved to '{legend_file_path}'")

if __name__ == '__main__':
    # --- Colour and line-style maps shared by the plots and the legends ---
    main_algos = ['DDPG', 'PPO', 'SAC', 'TD3', 'PGPE']
    # Draw len(main_algos) + 3 husl colours; the algorithms take palette slots 3 onwards.
    full_palette = sns.color_palette("husl", n_colors=len(main_algos) + 3)
    master_color_map = dict(zip(main_algos, full_palette[3:]))
    pgpe_color = master_color_map['PGPE']

    # Matplotlib (offset, (on, off)) linestyles, used only for the standalone legends.
    master_style_map = {'Small': (0, (1, 1)), 'Medium': (0, (2, 1)), 'Large': '-'}
    arch_labels = ('Small', 'Medium', 'Large')

    # Set to True to regenerate the performance figures in addition to the legends.
    run_plots = False
    if run_plots:
        separator = "\n" + "=" * 50 + "\n"
        aggregate_and_plot_seaborn(master_color_map, base_dir='baselines')
        print(separator)
        plot_pgpe_only(pgpe_color, base_dir='baselines')
        print(separator)

    if ENV != "mountain_car":
        # Any other environment: one two-entry legend contrasting PGPE with SAC.
        comparison_entries = {
            name: {'color': master_color_map[name], 'linestyle': '-'}
            for name in ('PGPE', 'SAC')
        }
        generate_legend(comparison_entries, 'legend_pgpe_architectures.pdf')
    else:
        print("Generating standalone legends...")

        # Colour legend: one entry per algorithm.
        generate_legend(
            {algo: {'color': master_color_map[algo]} for algo in main_algos},
            'legend_main_colors.pdf',
        )

        # Style legend: architecture dash patterns drawn in black.
        generate_legend(
            {arch: {'color': 'black', 'linestyle': master_style_map[arch]} for arch in arch_labels},
            'legend_main_styles.pdf',
        )

        # Same dash patterns as the style legend, but in the PGPE colour
        # (for the PGPE-only figures).
        generate_legend(
            {arch: {'color': pgpe_color, 'linestyle': master_style_map[arch]} for arch in arch_labels},
            'legend_pgpe_architectures.pdf',
        )