import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
import numpy as np

# To run this script, make sure you are in the project root directory and run:
# python -m plotting.experiment_2_plot

"""
Plotting for Experiment 2

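This script loads the 2D sweep results from 'data/exp_2/exp_2_data.pkl' and
visualizes them as heatmaps, with the retention rate (gamma) on the x-axis
and the discount factor (beta) on the y-axis.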

Plots Generated:
1. Max Levels Heatmap:
   - Color indicates the number of levels constructed.
   
2. Max Attribute Heatmap:
   - Color indicates the highest incentivizable attribute (mu_L).

Outputs:
    - A directory 'plots/exp_2/' created if it does not exist.
    - exp_2_heatmap_levels.png/pdf: Heatmap showing the number of levels (L).
    - exp_2_heatmap_attribute.png/pdf: Heatmap showing the max attribute (mu_L).
"""

def load_data():
    """
    Read the pickled results of Experiment 2 from disk.

    The file is looked up relative to the current working directory at
    'data/exp_2/exp_2_data.pkl'.

    Returns:
        Whatever object was pickled into the data file.

    Raises:
        FileNotFoundError: If the data file does not exist; the message
            points to the command that generates it.
    """
    pkl_path = os.path.join('data', 'exp_2', 'exp_2_data.pkl')

    if os.path.exists(pkl_path):
        # NOTE: unpickling can execute arbitrary code, so this file should
        # only ever come from a trusted run of the experiment script.
        with open(pkl_path, 'rb') as handle:
            payload = pickle.load(handle)
        return payload

    raise FileNotFoundError(
        f"Data file not found at {pkl_path}.\n"
        "Please run 'python -m experiments.experiment_2' first."
    )

def plot_heatmap(matrix, x_vals, y_vals, title, cbar_label, filename, cmap='viridis'):
    """
    Draw a 2D matrix as a heatmap and save it as both PNG and PDF.

    The matrix is flipped vertically before plotting (and the y labels are
    reversed to match), so the row paired with the first entry of y_vals
    ends up at the bottom of the figure.

    Args:
        matrix: 2D array-like of values to colour. Rows are labelled with
            y_vals and columns with x_vals.
        x_vals: Numeric column coordinates; shown with two decimals.
        y_vals: Numeric row coordinates; shown with two decimals.
        title: Descriptive title for the plot. It is currently NOT drawn on
            the figure; the parameter is kept so existing callers work.
        cbar_label: Text shown next to the colorbar.
        filename: Output file name. Its extension is discarded and the
            figure is written to 'plots/exp_2/' as '<name>.png' and
            '<name>.pdf'.
        cmap: Colormap passed through to seaborn (default 'viridis').

    Side effects:
        Creates 'plots/exp_2/' if needed, applies the seaborn theme
        globally, writes two image files and prints a confirmation line.
    """
    plot_dir = os.path.join('plots', 'exp_2')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(plot_dir, exist_ok=True)

    # Set style
    sns.set_theme(context="paper", style="white", font_scale=1.2)

    xticklabels = [f"{x:.2f}" for x in x_vals]
    # Reversed because the matrix itself is flipped vertically below.
    yticklabels = [f"{y:.2f}" for y in y_vals][::-1]

    # Use explicit figure/axes handles instead of pyplot's implicit
    # "current figure" so nothing here depends on global plotting state.
    fig, ax = plt.subplots(figsize=(8, 7))
    try:
        # Create Heatmap
        sns.heatmap(
            np.flipud(matrix),
            xticklabels=xticklabels,
            yticklabels=yticklabels,
            cmap=cmap,
            annot=False,
            cbar_kws={'label': cbar_label},
            ax=ax,
        )

        # Increase Colorbar label size. The colorbar axes is the last one
        # added to this freshly created figure.
        fig.axes[-1].yaxis.label.set_size(15)

        # Reduce tick density if too crowded. Each axis is judged on its
        # own length so non-square sweeps are thinned correctly.
        if len(xticklabels) > 10:
            ax.set_xticks(ax.get_xticks()[::2])
            ax.set_xticklabels(xticklabels[::2])
        if len(yticklabels) > 10:
            ax.set_yticks(ax.get_yticks()[::2])
            ax.set_yticklabels(yticklabels[::2])

        ax.set_xlabel(r"Retention Rate ($\gamma$)", fontsize=15)
        ax.set_ylabel(r"Discount Factor ($\beta$)", fontsize=15)
        # Keep the y tick labels horizontal.
        plt.setp(ax.get_yticklabels(), rotation=0)

        fig.tight_layout()

        base_name = os.path.splitext(filename)[0]
        save_path_base = os.path.join(plot_dir, base_name)

        fig.savefig(f"{save_path_base}.png", dpi=300)
        fig.savefig(f"{save_path_base}.pdf")

        print(f"Saved: {base_name}.png and {base_name}.pdf")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

if __name__ == "__main__":
    print("Loading data...")
    results = load_data()

    # Both heatmaps share the same axes: gamma across, beta down.
    beta_axis = results['beta_vals']
    gamma_axis = results['gamma_vals']

    # One entry per figure: (progress message, key in the results, plot options).
    heatmap_jobs = (
        (
            # 1. Max Levels Heatmap
            "Generating Max Levels Heatmap...",
            'max_levels',
            {
                'title': "Maximum Incentivizable Number of Levels",
                'cbar_label': "Number of Levels",
                'filename': "exp_2_heatmap_levels.png",
                'cmap': "Blues",
            },
        ),
        (
            # 2. Max Attribute Heatmap
            "Generating Max Attribute Heatmap...",
            'max_attribute',
            {
                'title': r"Maximum Incentivizable Attribute",
                'cbar_label': "Max Attribute Value",
                'filename': "exp_2_heatmap_attribute.png",
                'cmap': "rocket_r",
            },
        ),
    )

    for progress_message, result_key, plot_options in heatmap_jobs:
        print(progress_message)
        plot_heatmap(results[result_key], gamma_axis, beta_axis, **plot_options)