#%%
import h5py
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from typing import Dict, List, Tuple
import seaborn as sns
from tqdm import tqdm

def load_activations(file_path: Path) -> Tuple[Dict[int, np.ndarray], Dict[str, np.ndarray]]:
    """Load per-layer activations and metadata from an HDF5 file.

    The file must contain two groups: ``activations``, whose dataset names
    are integer layer indices, and ``metadata``, holding one dataset per
    metadata field.

    Args:
        file_path: Path to the ``.h5`` file to read.

    Returns:
        A ``(activations, metadata)`` tuple. ``activations`` maps each layer
        index to the full contents of its dataset, inserted in ascending
        layer order. ``metadata`` maps each field name to an array; fields
        whose elements are byte strings are decoded from UTF-8 into an
        array of ``str``.

    Raises:
        KeyError: If the ``activations`` or ``metadata`` group is missing.
        ValueError: If a dataset name under ``activations`` is not an integer.
    """
    with h5py.File(file_path, 'r') as f:
        activation_group = f['activations']
        # Dataset names are strings, so sort them numerically to keep
        # layer 2 ahead of layer 10 in the resulting dict.
        activations = {
            int(name): activation_group[name][:]
            for name in sorted(activation_group.keys(), key=int)
        }

        metadata = {}
        for key, dataset in f['metadata'].items():
            values = dataset[:]
            # Probe the first element only when there is one; an empty
            # dataset would otherwise raise IndexError here.
            if len(values) > 0 and isinstance(values[0], bytes):
                # Keep the result an ndarray (as the return type promises) so
                # callers can build elementwise masks such as `values == 'x'`.
                values = np.array([s.decode('utf-8') for s in values])
            metadata[key] = values

    return activations, metadata

def compute_pca(activations: np.ndarray, n_components: int = 2) -> Tuple[np.ndarray, PCA]:
    """Fit a PCA model to ``activations`` and project them onto it.

    Args:
        activations: Data matrix handed to ``PCA.fit_transform``.
        n_components: Number of principal components to keep.

    Returns:
        A ``(projected, model)`` tuple: the transformed data and the fitted
        ``PCA`` instance, which can be reused to project other data.
    """
    reducer = PCA(n_components=n_components)
    projected = reducer.fit_transform(activations)
    return projected, reducer

def _scatter_by_condition(ax, points, meta, colors):
    """Scatter the first two PCA components on ``ax``, one series per condition."""
    # np.asarray guarantees elementwise comparison even if a field is a list.
    sandbagging = np.asarray(meta['sandbagging_environment'])
    correct = np.asarray(meta['is_correct'])
    for condition, color in colors.items():
        # Condition keys have the form '<environment>_<outcome>'.
        env, outcome = condition.split('_')
        mask = (
            (sandbagging == (env == 'sandbagging')) &
            (correct == (outcome == 'correct'))
        )
        ax.scatter(
            points[mask, 0],
            points[mask, 1],
            c=color,
            alpha=0.6,
            label=condition.replace('_', ' ').title(),
            s=50
        )


def create_pca_plots(
    base_acts: Dict[int, np.ndarray], base_meta: Dict[str, np.ndarray],
    ft_acts: Dict[int, np.ndarray], ft_meta: Dict[str, np.ndarray],
    splits: List[str] = ['test', 'validation'],
    n_components: int = 2,
    *,
    split: str = 'test',
):
    """Create PCA plots comparing base and finetuned model activations.

    For every layer in ``base_acts`` a PCA is fit on the base model's
    activations, and the finetuned model's activations are projected onto
    the same components. The first two components are drawn in a 2 x
    n_layers grid (base model on top, finetuned below), coloured by the
    combination of ``sandbagging_environment`` and ``is_correct``. The figure
    is written to ``figures/pca_comparison_<split>.png``.

    Args:
        base_acts: Base model activations keyed by layer index.
        base_meta: Base model metadata; must provide
            ``'sandbagging_environment'`` and ``'is_correct'``.
        ft_acts: Finetuned model activations keyed by layer index; must
            contain every layer present in ``base_acts``.
        ft_meta: Finetuned model metadata with the same two fields.
        splits: Unused; kept so existing callers keep working. The default
            list is never mutated.
        n_components: Number of PCA components to fit. Only the first two
            are plotted, so it must be at least 2.
        split: Name of the data split, used in the figure title and the
            output file name.

    Raises:
        ValueError: If ``n_components`` is below 2 or ``base_acts`` is empty.
        KeyError: If ``ft_acts`` lacks a layer present in ``base_acts``.
    """
    # Validate before touching matplotlib so bad input fails fast and clearly.
    if n_components < 2:
        raise ValueError(
            f"n_components must be at least 2 to plot PC1 vs PC2, got {n_components}"
        )
    layer_indices = sorted(base_acts.keys())
    if not layer_indices:
        raise ValueError("base_acts is empty; there are no layers to plot")
    missing = [idx for idx in layer_indices if idx not in ft_acts]
    if missing:
        raise KeyError(f"ft_acts is missing layers present in base_acts: {missing}")

    # Set up the plotting style
    sns.set_theme(style="whitegrid")

    # Colors for different conditions
    colors = {
        'normal_correct': '#3d81eb',
        'normal_incorrect': '#355382',
        'sandbagging_correct': '#e3469a',
        'sandbagging_incorrect': '#782250'
    }

    # squeeze=False keeps `axes` two-dimensional even for a single layer,
    # so axes[row, col] indexing below is always valid.
    n_layers = len(layer_indices)
    fig, axes = plt.subplots(2, n_layers, figsize=(6*n_layers, 10), squeeze=False)
    try:
        fig.suptitle(f'PCA of Model Activations - {split.capitalize()} Split', fontsize=16, y=1.02)

        for i, layer_idx in enumerate(layer_indices):
            print(f"Processing layer {layer_idx}")

            # Fit on the base model, then project the finetuned model onto
            # the same components so both rows share one coordinate system.
            pca_results_base, pca = compute_pca(base_acts[layer_idx], n_components=n_components)
            pca_results_ft = pca.transform(ft_acts[layer_idx])

            # Top row: base model
            ax = axes[0, i]
            _scatter_by_condition(ax, pca_results_base, base_meta, colors)
            ax.set_title(f'Base Model - Layer {layer_idx}')
            if i == 0:
                ax.set_ylabel('PC2')
                ax.legend(bbox_to_anchor=(0, 1.02, 1, 0.2), loc="lower left", mode="expand", ncol=2)

            # Bottom row: finetuned model
            ax = axes[1, i]
            _scatter_by_condition(ax, pca_results_ft, ft_meta, colors)
            ax.set_title(f'Finetuned Model - Layer {layer_idx}')
            ax.set_xlabel('PC1')
            if i == 0:
                ax.set_ylabel('PC2')

        fig.tight_layout()

        # Create output directory if it doesn't exist
        output_dir = Path('figures')
        output_dir.mkdir(parents=True, exist_ok=True)

        fig.savefig(output_dir / f'pca_comparison_{split}.png', dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def main():
    """Load base and finetuned activations for one split and plot their PCA."""
    # Configuration
    base_dir = Path('activations')
    base_model_path = base_dir / 'gemma-2-9b-it'
    finetuned_model_path = base_dir / 'gemma-sandbagging-0w4j7rba-step1024'

    split = 'test'

    # Load data for both models
    base_acts, base_meta = load_activations(base_model_path / f"{split}.h5")
    ft_acts, ft_meta = load_activations(finetuned_model_path / f"{split}.h5")

    # Pass the split through so the title and file name match the loaded data.
    create_pca_plots(base_acts, base_meta, ft_acts, ft_meta, split=split)

    print("\nPlots have been saved to the 'figures' directory")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()