import argparse
import json
import torch
import numpy as np
import os 
import pickle
import sys
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

# Publication-style matplotlib defaults. 'text.usetex' stays False, so the
# LaTeX-like look comes from matplotlib's own mathtext ('cm' = Computer Modern)
# and serif fonts rather than from a LaTeX installation.
_PAPER_RC_PARAMS = {
    # --- fonts ---
    'font.family': 'serif',
    'font.serif': ['Computer Modern', 'Times New Roman', 'DejaVu Serif'],
    'text.usetex': False,
    'mathtext.fontset': 'cm',
    # --- text sizes ---
    'font.size': 24,
    'axes.titlesize': 24,
    'axes.labelsize': 24,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'legend.fontsize': 24,
    'figure.titlesize': 22,
    # --- stroke widths ---
    'lines.linewidth': 3.5,
    'axes.linewidth': 1.5,
    'xtick.major.width': 1.2,
    'ytick.major.width': 1.2,
    'grid.linewidth': 1.0,
    # --- grid ---
    'axes.grid': True,
    'grid.alpha': 0.4,
}
plt.rcParams.update(_PAPER_RC_PARAMS)

# Absolute path of the directory that contains this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Put the parent of the script directory on sys.path so the `SAE` package
# import below can be resolved when this file is run directly as a script.
# This must stay ahead of the import that follows.
sys.path.append(os.path.dirname(SCRIPT_DIR))
# NOTE(review): compute_feature_importance is not referenced anywhere in the
# visible portion of this file — presumably used by main(); confirm.
from SAE.unlearning_utils import compute_feature_importance

def load_latents_dict(latents_path):
    """
    Load a concept-latents dictionary from disk, dispatching on the file extension.

    Args:
        latents_path (str or os.PathLike): Path to the latents file. The
            extension is matched case-insensitively:
              * '.pt' / '.pth' -- loaded with ``torch.load`` and returned as-is.
              * '.npy'         -- loaded with ``np.load(..., allow_pickle=True)``;
                                  the stored object is unwrapped with ``.item()``
                                  and each value is converted with ``torch.tensor``.
              * '.pkl'         -- loaded with ``pickle.load`` and returned as-is.

    Returns:
        The loaded object. For '.npy' files this is a dict mapping the stored
        keys to ``torch.Tensor`` values; for the other formats it is whatever
        object was serialized (no conversion is applied).

    Raises:
        ValueError: If the path does not end in a supported extension.
    """
    # os.fspath lets callers pass pathlib.Path objects as well as plain strings.
    path = os.fspath(latents_path)
    lowered = path.lower()

    if lowered.endswith(('.pt', '.pth')):
        return torch.load(path)

    if lowered.endswith('.npy'):
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # files from trusted sources.
        stored = np.load(path, allow_pickle=True).item()
        return {k: torch.tensor(v) for k, v in stored.items()}

    if lowered.endswith('.pkl'):
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # from trusted sources.
        with open(path, 'rb') as f:
            return pickle.load(f)

    raise ValueError(
        f"Unsupported file type for latents dict: {path!r} "
        "(expected .pt, .pth, .npy or .pkl)"
    )

def plot_concept_scores(all_scores, output_dir=None, save_format='pdf', plot_style='line', marker_size=6, line_width=2.5):
    """
    Plot, for each concept, the score of every neuron averaged across timesteps.

    One figure is produced per concept, with neuron index on the x-axis and the
    averaged score on the y-axis. Each figure also shows a dashed horizontal
    line at the mean score and a text box with the max (and its neuron index)
    and min. Y-limits are set from the data range with 10% padding.

    Concepts whose list of timestep scores is empty are skipped with a printed
    message instead of aborting the remaining plots.

    Args:
        all_scores (dict): Concept name -> sequence of per-timestep score arrays.
            The arrays of one concept are averaged element-wise, so they must
            share a shape.
        output_dir (str, optional): Directory to save plots into (created if
            missing). If falsy, each plot is shown with ``plt.show()`` instead.
        save_format (str): Format/extension used when saving ('png', 'pdf', 'svg', ...).
        plot_style (str): 'line', 'scatter', or 'both'. Any other value draws
            no data series (only the mean line and statistics box).
        marker_size (float): Marker size for scatter plots.
        line_width (float): Line width for line plots.

    Returns:
        None
    """
    # Create output directory if specified
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Individual plots for each concept (averaged across timesteps)
    for concept, timestep_scores in all_scores.items():
        # np.mean over an empty list yields a NaN scalar that cannot be plotted;
        # skip such concepts so one bad entry does not abort the whole run.
        if len(timestep_scores) == 0:
            print(f"Skipping {concept}: no timestep scores to average")
            continue

        # Convert to numpy arrays and compute average across timesteps
        score_arrays = [np.array(scores) for scores in timestep_scores]
        avg_scores = np.mean(score_arrays, axis=0)
        neuron_indices = np.arange(len(avg_scores))

        # Create figure with appropriate size for papers
        fig, ax = plt.subplots(figsize=(12, 6))

        if plot_style == 'line':
            ax.plot(neuron_indices, avg_scores, linewidth=line_width, color='#1f77b4')
        elif plot_style == 'scatter':
            ax.scatter(neuron_indices, avg_scores, s=marker_size, alpha=0.7, color='#1f77b4')
        elif plot_style == 'both':
            ax.plot(neuron_indices, avg_scores, linewidth=line_width, alpha=0.8, color='#1f77b4')
            ax.scatter(neuron_indices, avg_scores, s=marker_size, alpha=0.9, color='#ff7f0e')

        # Axis labels (the per-concept title is intentionally disabled)
        # ax.set_title(concept, fontweight='bold', pad=20)
        ax.set_xlabel('Neuron Index', fontweight='bold')
        ax.set_ylabel('Score', fontweight='bold')

        # Dynamically set y-limits based on data range
        min_score = np.min(avg_scores)
        max_score = np.max(avg_scores)

        # Pad by 10% of the range; fall back to a small constant when all
        # scores are equal so the axis does not collapse to zero height.
        score_range = max_score - min_score
        padding = score_range * 0.1 if score_range > 0 else 0.001
        ax.set_ylim(min_score - padding, max_score + padding)

        # Enhanced grid, drawn beneath the data
        ax.grid(True, alpha=0.4, linestyle='--', linewidth=1.0)
        ax.set_axisbelow(True)

        # Summary statistics
        mean_score = np.mean(avg_scores)
        max_neuron = np.argmax(avg_scores)

        # Mean line; its label is only visible if the legend below is re-enabled
        ax.axhline(mean_score, color='red', linestyle='--', alpha=0.8, linewidth=2.0,
                   label=f'Mean: {mean_score:.4f}')
        # ax.legend(loc='upper right', frameon=True, fancybox=True, shadow=True)

        # Statistics text box in the top-left corner (axes coordinates)
        textstr = (f'Max: {max_score:.4f} (neuron {max_neuron})\n'
                   f'Min: {min_score:.4f}')
        props = dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.8,
                     edgecolor='black', linewidth=1.2)
        ax.text(0.02, 0.98, textstr, transform=ax.transAxes, fontsize=20,
                verticalalignment='top', bbox=props, family='monospace')

        # Tick formatting, including minor ticks
        ax.tick_params(axis='both', which='major', length=6, width=1.2)
        ax.tick_params(axis='both', which='minor', length=4, width=1.0)
        ax.minorticks_on()

        fig.tight_layout()

        if output_dir:
            filename = f'{concept}_avg_scores.{save_format}'
            filepath = os.path.join(output_dir, filename)
            # Save through the figure object rather than pyplot's implicit
            # "current figure" so the right figure is always written.
            fig.savefig(filepath, dpi=300, bbox_inches='tight',
                        facecolor='white', edgecolor='none',
                        format=save_format, pad_inches=0.1)
            print(f"Saved average score plot for {concept} to {filepath}")
        else:
            plt.show()

        # Close this specific figure to release its memory before the next concept.
        plt.close(fig)

if __name__ == "__main__":
    # Command-line entry point: parse the arguments below and hand them to main().
    # NOTE(review): main() is not defined in the visible portion of this file —
    # confirm it exists, otherwise this script fails with NameError after parsing.
    parser = argparse.ArgumentParser(description="Compute and save concept feature importance scores with enhanced plotting for LaTeX papers.")
    parser.add_argument("--model_checkpoint", type=str, required=False, 
                        help="Path to the model checkpoint (not used for score computation, but kept for clarity).")
    # Help text lists every extension accepted by load_latents_dict.
    parser.add_argument("--latents_path", type=str, required=True, 
                        help="Path to the concept latents dictionary (.pt, .pth, .npy or .pkl).")
    parser.add_argument("--concept_type", type=str, choices=["objects", "styles"], required=True,
                        help="Type of concepts (objects or styles).")
    parser.add_argument("--num_timesteps", type=int, required=True,
                        help="Number of timesteps to process per concept.")
    parser.add_argument("--output_json", type=str, required=True,
                        help="Path to output JSON file.")
    
    # Score plotting arguments with enhanced defaults.
    # These CLI defaults (marker_size=10.0, line_width=3.5) differ from the
    # keyword defaults of plot_concept_scores (6 and 2.5).
    parser.add_argument("--plot_scores", action="store_true",
                        help="Generate score plots for each concept (neuron index vs score).")
    parser.add_argument("--plot_output_dir", type=str, default=None,
                        help="Directory to save plots. If not specified, plots will be displayed.")
    parser.add_argument("--save_format", type=str, default="pdf", choices=["png", "pdf", "svg", "jpg"],
                        help="Format to save plots.")
    parser.add_argument("--plot_style", type=str, default="line", choices=["line", "scatter", "both"],
                        help="Style of plot: line, scatter, or both.")
    parser.add_argument("--marker_size", type=float, default=10.0,
                        help="Size of markers for scatter plots.")
    parser.add_argument("--line_width", type=float, default=3.5,
                        help="Width of lines for line plots.")
    # NOTE(review): no histogram code is visible in this file; presumably the
    # three options below are consumed by main() — confirm.
    parser.add_argument("--plot_timesteps_concept", type=str, default=None,
                        help="Generate timestep-specific histograms for a particular concept.")
    parser.add_argument("--bins", type=int, default=50,
                        help="Number of bins for histograms.")
    parser.add_argument("--alpha", type=float, default=0.8,
                        help="Transparency for histogram bars.")
    
    args = parser.parse_args()
    main(args)