import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
import numpy as np

# To run this script:
# python -m plotting.experiment_1_plot

"""
Plotting for Experiment 1

This script visualizes how the internal structure of the threshold sequence
(mu_2, mu_3, mu_4, mu_5) evolves as we vary environmental parameters.

It generates up to four line plots (the fourth only if the data contains a 'c_minus' sweep):
1. Threshold Evolution vs Retention (gamma)
2. Threshold Evolution vs Patience (beta)
3. Threshold Evolution vs Improvement Cost (c_plus)
4. Threshold Evolution vs Gaming Cost (c_minus)

Input: data/exp_1/exp_1_data.pkl (Generated by experiment_1.py)
Output: plots/exp_1/ (each plot is saved as both .png and .pdf)
"""

def load_data():
    """Load the pickled Experiment 1 results.

    Reads ``data/exp_1/exp_1_data.pkl`` relative to the current working
    directory.

    Returns:
        The object stored in the pickle file (presumably a mapping keyed by
        sweep name such as 'gamma' or 'beta' -- confirm against
        experiment_1.py).

    Raises:
        FileNotFoundError: If the data file does not exist.
    """
    file_path = os.path.join('data', 'exp_1', 'exp_1_data.pkl')
    # Open first and translate the failure (EAFP) instead of pre-checking with
    # os.path.exists: the pre-check is racy and reports permission problems
    # as "not found".
    try:
        f = open(file_path, 'rb')
    except FileNotFoundError:
        raise FileNotFoundError(
            f"Data file not found at {file_path}. Run experiment_1.py first."
        ) from None
    with f:
        # NOTE(security): unpickling can execute arbitrary code; only load
        # data files produced by a trusted run of experiment_1.py.
        return pickle.load(f)

def plot_threshold_evolution(data, param_label, filename):
    """Plot the threshold curves of one parameter sweep and save PNG + PDF.

    Args:
        data: Mapping with the sweep results. Must contain 'x_vals' (the
            x-axis values). Each of 'mu_2', 'mu_3', 'mu_4', 'mu_5' that is
            present is drawn as one line; missing ones are skipped.
        param_label: Text for the x-axis label (may contain mathtext).
        filename: Output file name. Its extension is discarded and the
            figure is written into ``plots/exp_1/`` as both ``.png`` and
            ``.pdf``.

    Raises:
        KeyError: If ``data`` is a dict without an 'x_vals' entry.
    """
    plot_dir = os.path.join('plots', 'exp_1')
    # exist_ok=True avoids the race of checking for the directory first.
    os.makedirs(plot_dir, exist_ok=True)

    # Read the required key before allocating a figure so malformed input
    # fails without leaving an open figure behind.
    x_vals = data['x_vals']

    # Set professional style (must happen before the figure is created).
    sns.set_theme(context="paper", style="darkgrid", palette="deep", font_scale=1.2)
    fig, ax = plt.subplots(figsize=(8, 6))

    try:
        # (data key, legend label, color, marker) for each threshold level.
        lines = [
            ('mu_2', r'$\mu_2$', 'C0', 'o'),
            ('mu_3', r'$\mu_3$', 'C1', 's'),
            ('mu_4', r'$\mu_4$', 'C2', '^'),
            ('mu_5', r'$\mu_5$', 'C3', 'D'),
        ]

        for key, label, color, marker in lines:
            if key not in data:
                continue
            ax.plot(
                x_vals,
                data[key],
                label=label,
                color=color,
                marker=marker,
                linewidth=2,
                markersize=6,
                alpha=0.8,
            )

        ax.set_xlabel(param_label, fontsize=16)
        ax.set_ylabel("Attribute Value", fontsize=16)
        ax.legend(
            title="Threshold Level",
            loc='upper left',
            frameon=True,
            framealpha=0.9,
            fontsize=14,        # Size of the labels (mu_2, etc.)
            title_fontsize=15,  # Size of the title "Threshold Level"
        )
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        # --- SAVE LOGIC ---
        # Strip any extension so the same base name is used for both formats.
        base_name = os.path.splitext(filename)[0]
        save_path_base = os.path.join(plot_dir, base_name)

        fig.savefig(f"{save_path_base}.png", dpi=300)
        fig.savefig(f"{save_path_base}.pdf")

        print(f"Saved plots: {base_name}.png and {base_name}.pdf")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

if __name__ == "__main__":
    # One entry per sweep:
    # (results key, progress name, x-axis label, output file, required?)
    # Required sweeps raise if missing; optional ones are silently skipped.
    sweeps = [
        ('gamma', "Gamma", r'Retention Rate ($\gamma$)',
         'exp_1_gamma_evolution.png', True),
        ('beta', "Beta", r'Discount Factor ($\beta$)',
         'exp_1_beta_evolution.png', True),
        ('c_plus', "Improvement Cost", r'Improvement Cost ($c^+$)',
         'exp_1_cplus_evolution.png', True),
        ('c_minus', "Gaming Cost", r'Gaming Cost ($c^-$)',
         'exp_1_cminus_evolution.png', False),
    ]

    print("Loading Experiment 1 Data...")
    try:
        results = load_data()

        for key, name, axis_label, out_file, required in sweeps:
            if not required and key not in results:
                continue
            print(f"Plotting {name} Sensitivity...")
            plot_threshold_evolution(results[key], axis_label, out_file)

        print("\nDone! Check the 'plots/exp_1/' directory.")

    except Exception as e:
        # Top-level boundary: report on stderr and exit non-zero so callers
        # (shells, pipelines) can detect the failure. The exception type is
        # included because e.g. a KeyError's message is only the bare key.
        raise SystemExit(f"Error: {type(e).__name__}: {e}")