import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
import numpy as np


# To run this script:
# python -m plotting.experiment_5_plot

"""
Plotting for Experiment 5

This script generates the time-series visualizations for the FICO-based simulation.
It produces mean trajectories with standard deviation bands for three key metrics:
1. Attribute Value (x): Tracking genuine skill improvement.
2. Agent Level (l): Visualizing promotion/relegation dynamics.
3. Improvement Fraction: The share of total effort spent on genuine improvement rather than gaming, a+/(a+ + a-).

Inputs: 
    - data/exp_5/results_plot_A.pkl (Parameter Ablation)
    - data/exp_5/results_plot_B.pkl (Level Ablation)
Outputs: 
    - plots/exp_5/exp_5_A_*.png (and .pdf)
    - plots/exp_5/exp_5_B_*.png (and .pdf)
"""


def load_data(filename):
    """
    Unpickle and return an Experiment 5 results file.

    The file is resolved relative to the current working directory as
    ``data/exp_5/<filename>``.

    Raises:
        FileNotFoundError: if that path does not exist.
    """
    # NOTE: pickle can execute arbitrary code while loading; only point this
    # at result files produced by this project's own experiment runs.
    data_path = os.path.join('data', 'exp_5', filename)
    if os.path.exists(data_path):
        with open(data_path, 'rb') as handle:
            return pickle.load(handle)
    raise FileNotFoundError(f"Data file not found at {data_path}.")

def plot_trajectory_comparison(results_dict, param_label, filename, y_label,
                               data_key='traj_x', *, show=True):
    """
    Plot mean trajectories with +/- 1 standard-deviation bands, one per swept value.

    For every key ``val`` of ``results_dict`` (in sorted order) the array
    ``results_dict[val][data_key]`` is averaged over axis 0 and drawn as a
    line; its standard deviation over axis 0 is drawn as a shaded band.
    Line styles cycle through four patterns so curves stay distinguishable
    even when there are more values than styles.

    Args:
        results_dict: Mapping from parameter value to a per-value result dict.
            Keys are sorted before plotting, so they must be mutually comparable.
        param_label: Label (e.g. a LaTeX symbol) used in the legend entries
            ``"{param_label} = {val}"``.
        filename: Output file name. Its extension is stripped and the figure
            is written to ``plots/exp_5/`` as both ``.png`` (300 dpi) and ``.pdf``.
        y_label: Y-axis label.
        data_key: Key of the trajectory array inside each result dict. The
            array is reduced over axis 0 — presumably shaped (runs, time_steps);
            confirm against the data producer.
        show: If True (the default, matching the original behaviour), display
            the figure with ``plt.show()`` after saving. Pass False for
            non-interactive batch runs.
    """
    plot_dir = os.path.join('plots', 'exp_5')
    os.makedirs(plot_dir, exist_ok=True)

    sns.set_theme(context="paper", style="darkgrid", palette="deep")
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        colors = sns.color_palette("deep", n_colors=len(results_dict))
        styles = ['-', '--', '-.', ':']

        for i, val in enumerate(sorted(results_dict)):
            trajectory_matrix = results_dict[val][data_key]

            mean_traj = np.mean(trajectory_matrix, axis=0)
            std_traj = np.std(trajectory_matrix, axis=0)
            time_steps = np.arange(len(mean_traj))
            color = colors[i]

            # Cycle through styles if there are more values than styles.
            ax.plot(time_steps, mean_traj, label=f"{param_label} = {val}",
                    color=color, linewidth=2.5,
                    linestyle=styles[i % len(styles)])
            ax.fill_between(time_steps, mean_traj - std_traj,
                            mean_traj + std_traj, color=color, alpha=0.15)

        ax.set_xlabel("Time Step (t)", fontsize=16)
        ax.set_ylabel(y_label, fontsize=16)
        ax.tick_params(axis='both', labelsize=16)
        # Skip the legend when nothing was drawn (avoids the
        # "No artists with labels found" warning on empty input).
        if results_dict:
            ax.legend(loc='lower right', frameon=True, framealpha=0.9,
                      fontsize=16)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        save_path_base = os.path.join(plot_dir, os.path.splitext(filename)[0])
        fig.savefig(f"{save_path_base}.png", dpi=300)
        fig.savefig(f"{save_path_base}.pdf")
        if show:
            plt.show()
    finally:
        # Always release the figure, even if plotting or saving failed,
        # so repeated calls do not accumulate open figures.
        plt.close(fig)

def add_derived_metrics(results):
    """
    Augment Plot-A results in place with derived trajectories.

    ``results`` is a two-level mapping ``{sweep_name: {param_value: result}}``.
    For every ``result`` dict this:

    * adds ``'improvement_fraction'`` = (a+ + eps) / (a+ + a- + eps), with
      eps = 1e-6, computed element-wise from ``'traj_ap'`` (a+) and
      ``'traj_am'`` (a-). The eps guards against division by zero; note that
      when a+ and a- are both zero the fraction evaluates to 1.0, not 0.5.
    * replaces ``'traj_l'`` with ``'traj_l' + 1`` (presumably converting
      0-based levels to 1-based for display — confirm against the simulation).

    Returns the same ``results`` object for convenience.
    """
    for sweep in results.values():
        for result in sweep.values():
            traj_a_plus = result['traj_ap']
            traj_a_minus = result['traj_am']
            result['improvement_fraction'] = (traj_a_plus + 1e-6) / (
                traj_a_minus + traj_a_plus + 1e-6)
            result['traj_l'] = result['traj_l'] + 1
    return results


def plot_parameter_ablation():
    """Generate the Plot A (parameter ablation) figures; print and skip if the data file is missing."""
    try:
        results_A = load_data('results_plot_A.pkl')
    except FileNotFoundError as e:
        # Only a missing input file is treated as "skip this plot"; errors
        # raised while plotting propagate instead of being mistaken for it.
        print(e)
        return

    add_derived_metrics(results_A)

    # Configuration for the sweeps.
    # Format: (results key, LaTeX symbol, attribute filename,
    #          level filename, action filename)
    params = [
        ('gamma', r'$\gamma$', 'exp_5_A_gamma_attr.png',
         'exp_5_A_gamma_level.png', 'exp_5_A_gamma_action.png'),
        ('beta',  r'$\beta$',  'exp_5_A_beta_attr.png',
         'exp_5_A_beta_level.png', 'exp_5_A_beta_action.png'),
        ('delta', r'$\delta$', 'exp_5_A_delta_attr.png',
         'exp_5_A_delta_level.png', 'exp_5_A_delta_action.png'),
    ]

    for key, sym, attr_fname, level_fname, action_fname in params:
        if key not in results_A:
            continue
        sweep = results_A[key]

        # 1. Attribute (x) plot
        plot_trajectory_comparison(
            sweep, sym, attr_fname, "Attribute Value ($x$)", data_key='traj_x'
        )

        # 2. Level (l) plot
        plot_trajectory_comparison(
            sweep, sym, level_fname, "Agent Level ($l$)", data_key='traj_l'
        )

        # 3. Action (improvement fraction) plot
        plot_trajectory_comparison(
            sweep, sym, action_fname,
            "Improvement Fraction ($\\frac{a^+}{a^++a^-}$)",
            data_key='improvement_fraction'
        )


def plot_level_ablation():
    """Generate the Plot B (level ablation) figures; print and skip if the data file is missing."""
    try:
        results_B = load_data('results_plot_B.pkl')
    except FileNotFoundError as e:
        print(e)
        return

    # NOTE(review): unlike Plot A, 'traj_l' is plotted here without the +1
    # shift applied by add_derived_metrics. If both pickles store levels with
    # the same indexing, the level axes of the A and B figures are offset by
    # one — confirm against the simulation code.
    plot_trajectory_comparison(
        results_B, r'$L$', 'exp_5_B_levels.png', "Agent Level ($l$)", data_key='traj_l'
    )

    # Attribute (x) plot for the level ablation as well, for completeness.
    plot_trajectory_comparison(
        results_B, r'$L$', 'exp_5_B_attr.png', "Attribute Value ($x$)", data_key='traj_x'
    )


def main():
    """Produce all Experiment 5 figures (Plot A, then Plot B)."""
    plot_parameter_ablation()
    plot_level_ablation()


if __name__ == "__main__":
    main()