import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.datasets import make_swiss_roll, make_s_curve
import os
from pathlib import Path
import sys

# Put the repository root (three levels above this file) on sys.path so the
# shared project ``config`` module can be imported.
project_root = Path(__file__).parent.parent.parent.absolute()
sys.path.append(str(project_root))
from config import get_results_dir

# Datasets drawn in the figure, one per row from top to bottom.
datasets = ['hypersphere', 'swissroll', 'scurve']

# Each dataset's experiment folder is
# <results>/paper_results/unimodal/<dataset>_<settings>/ (trailing slash kept).
base_dir = os.path.join(get_results_dir(), 'paper_results', 'unimodal') + '/'
settings = 'multiseed_n-10000_width-1_noise-0.0_rsq-0.05'
paths = {name: f'{base_dir}{name}_{settings}/' for name in datasets}

# Seed whose per-seed outputs (R² curve, latent embedding) are visualised.
# Edit an entry to show a different seed for that dataset.
dataset_seeds = dict(
    hypersphere=42,
    swissroll=42,
    scurve=42,
)

# Data-generation parameters. NOTE(review): presumably these must agree with
# the n-/noise- fields encoded in ``settings`` above — confirm when changing.
n_samples = 10000
noise = 0.0
# Only the first 90% of the regenerated points are drawn in the data panel.
n_samples_train = int(0.9 * n_samples)

# One figure holding a 3x5 grid: data | R² | latent | avg loss | avg rank.
fig = plt.figure(figsize=(22, 12))

def make_hypersphere_data(n_samples, seed, noise=0.0):
    """Regenerate the unit-sphere point cloud used by the hypersphere experiment.

    Reseeds NumPy's *global* RNG with ``seed`` and then draws, in this order,
    the polar angles, the azimuthal angles and (only when ``noise > 0``) the
    Gaussian noise. The draw order is what keeps the output identical to the
    experiment script (032_hypersphere.py, per the original comment).

    Note: drawing the polar angle uniformly on [0, pi] is *not* uniform over
    the sphere's surface; points are denser near the poles. This is kept on
    purpose so the data matches what the experiment trained on.

    Args:
        n_samples: number of points to draw.
        seed: value passed to ``np.random.seed``.
        noise: standard deviation of isotropic Gaussian noise added to every
            coordinate. Values ``<= 0.0`` disable the noise.

    Returns:
        ``(data, color)``. ``data`` has shape ``(n_samples, 3)``. ``color`` is the
        ``(n_samples,)`` array of polar angles, used for colouring the points.
    """
    np.random.seed(seed)
    polar = np.random.uniform(0, np.pi, n_samples)
    azimuth = np.random.uniform(0, 2 * np.pi, n_samples)

    ring = np.sin(polar)  # distance from the z-axis on the unit sphere
    data = np.column_stack((ring * np.cos(azimuth),
                            ring * np.sin(azimuth),
                            np.cos(polar)))

    if noise > 0.0:
        data = data + np.random.normal(0, noise, data.shape)

    return data, polar

def _show_placeholder(ax, message):
    """Replace the contents of ``ax`` with a centred text ``message``.

    Used for panels whose data is missing or failed to load. The axes is
    cleared first so no partially drawn plot is left under the message. The
    limits are pinned to the unit square so the text sits in the middle.
    """
    ax.cla()
    ax.text(0.5, 0.5, message, ha='center', va='center')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)


# Fill one row of the 3x5 grid per dataset:
#   col 1 original 3D data, col 2 per-seed R² curve, col 3 per-seed latent,
#   col 4 seed-averaged loss, col 5 seed-averaged rank.
for i, ds in enumerate(datasets):
    print(f"Processing dataset: {ds}")
    print(f"Looking for data in: {paths[ds]}")

    # Without the experiment folder there is nothing to plot for this row.
    if not os.path.exists(paths[ds]):
        print(f"Directory does not exist: {paths[ds]}")
        # paths[ds] ends with '/', so os.path.dirname() on it would return the
        # missing directory itself. Normalise away the trailing separator first
        # so the real parent folder is listed.
        parent_dir = os.path.dirname(os.path.normpath(paths[ds]))
        if os.path.exists(parent_dir):
            print(f"Available directories in {parent_dir}:")
            for item in os.listdir(parent_dir):
                if ds.lower() in item.lower():
                    print(f"  - {item}")
        continue

    # Regenerate the original data with the dataset-specific seed so it
    # matches the experiment. Each generator receives the seed explicitly.
    dataset_seed = dataset_seeds[ds]
    if ds == 'hypersphere':
        X_original, color_orig = make_hypersphere_data(n_samples, dataset_seed, noise)
    elif ds == 'swissroll':
        X_original, color_orig = make_swiss_roll(n_samples=n_samples, noise=noise, random_state=dataset_seed)
    elif ds == 'scurve':
        X_original, color_orig = make_s_curve(n_samples=n_samples, noise=noise, random_state=dataset_seed)
    else:
        raise ValueError(f"Unknown dataset: {ds}")

    print(f"Generated {ds} data shape: {X_original.shape}")

    # Original data in 3D (column 1) - training portion only.
    ax1 = fig.add_subplot(3, 5, 5*i + 1, projection='3d')
    ax1.scatter(X_original[:n_samples_train, 0], X_original[:n_samples_train, 1], X_original[:n_samples_train, 2],
               c=color_orig[:n_samples_train], cmap=plt.cm.viridis, s=8, alpha=0.4)
    ax1.set_title(f'{ds.capitalize()} Data')
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')

    # Per-seed archive holding the R² curve and latent representations.
    # NOTE: allow_pickle=True is kept from the original (the archives may hold
    # object arrays). It can execute pickled code, so only point this script at
    # trusted, locally produced result files.
    curves_file = os.path.join(paths[ds], f'seed_{dataset_seed}_results_curves.npz')
    print(f"Looking for curves file: {curves_file}")

    # R² values (column 2) - for the visualized seed only
    ax2 = fig.add_subplot(3, 5, 5*i + 2)
    try:
        if not os.path.exists(curves_file):
            print(f"Curves file does not exist: {curves_file}")
            print(f"Available files in {paths[ds]}:")
            for file in os.listdir(paths[ds]):
                if 'seed_' in file and file.endswith('_curves.npz'):
                    print(f"  - {file}")
            _show_placeholder(ax2, f'No R² data for {ds}')
        else:
            # Read what is needed, then let the context manager close the
            # archive. A bare np.load() on an .npz keeps the file handle open.
            with np.load(curves_file, allow_pickle=True) as curves:
                print(f"Loaded curves with keys: {list(curves.keys())}")
                has_r2 = ('detailed_rsquares' in curves.files
                          and 'detailed_rsquare_epochs' in curves.files)
                if has_r2:
                    rsquares = curves['detailed_rsquares']
                    rsquare_epochs = curves['detailed_rsquare_epochs']

            if not has_r2:
                _show_placeholder(ax2, f'No R² data for {ds}')
            else:
                if len(rsquares) > 0 and len(rsquare_epochs) > 0:
                    ax2.plot(rsquare_epochs, rsquares, 'g-o', linewidth=2, markersize=4, label=f'Seed {dataset_seed}')
                    # Threshold line at first_r2_value - threshold
                    first_r2_value = rsquares[0]
                    threshold = 0.05
                    threshold_line = first_r2_value - threshold
                    ax2.axhline(y=threshold_line, color='r', linestyle='--', alpha=0.7,
                               label=f'Threshold ({first_r2_value:.3f} - {threshold})')
                    # Only add a legend when labelled artists exist, to avoid
                    # matplotlib's "no artists with labels" warning.
                    ax2.legend(fontsize=8)
                ax2.set_title(f'{ds.capitalize()} R² (Seed {dataset_seed})')
                ax2.set_xlabel('Epoch')
                ax2.set_ylabel('R²')
                ax2.grid(True)
    except Exception as e:
        # Best-effort plotting: report the error and keep building the figure.
        print(f"Error loading R² data for {ds}: {e}")
        _show_placeholder(ax2, f'R² error: {str(e)[:30]}')

    # Latent representations (column 3). Created once so error paths reuse it
    # instead of stacking a second axes on the same grid slot.
    ax3 = fig.add_subplot(3, 5, 5*i + 3)
    try:
        if os.path.exists(curves_file):
            with np.load(curves_file, allow_pickle=True) as curves:
                latent_data = curves['representations']
                color_latent = curves['color']
            print(f"Latent data shape: {latent_data.shape}, Color shape: {color_latent.shape}")

            ax3.scatter(latent_data[:, 0], latent_data[:, 1], c=color_latent, cmap=plt.cm.viridis, s=8)
            ax3.set_title(f'{ds.capitalize()} Latent (Seed {dataset_seed})')
            ax3.set_xlabel('Latent Dim 1')
            ax3.set_ylabel('Latent Dim 2')
            ax3.grid(True)
        else:
            _show_placeholder(ax3, f'No latent for {ds}')
    except Exception as e:
        print(f"Error loading latent data for {ds}: {e}")
        _show_placeholder(ax3, f'No latent for {ds}\n{str(e)[:50]}')

    # Seed-averaged training loss (column 4) and total rank (column 5).
    avg_curves_file = os.path.join(paths[ds], 'averaged_curves.npz')
    print(f"Looking for averaged curves: {avg_curves_file}")

    ax4 = fig.add_subplot(3, 5, 5*i + 4)
    ax5 = fig.add_subplot(3, 5, 5*i + 5)
    try:
        if not os.path.exists(avg_curves_file):
            print(f"Averaged curves file does not exist: {avg_curves_file}")
            # Show placeholders like the other columns instead of blank slots.
            _show_placeholder(ax4, f'No avg loss for {ds}')
            _show_placeholder(ax5, f'No avg rank for {ds}')
        else:
            with np.load(avg_curves_file, allow_pickle=True) as avg_curves:
                print(f"Loaded averaged curves with keys: {list(avg_curves.keys())}")
                avg_losses = avg_curves['avg_losses']
                std_losses = avg_curves['std_losses']
                avg_ranks = avg_curves['avg_ranks']
                std_ranks = avg_curves['std_ranks']
                # Fall back to 5 seeds when the archive does not list them.
                # NOTE(review): assumed to match the multiseed run — confirm.
                n_seeds = len(avg_curves['seeds']) if 'seeds' in avg_curves.files else 5

            # Standard error of the mean across seeds: std / sqrt(n_seeds).
            sem_losses = std_losses / np.sqrt(n_seeds)
            epochs = np.arange(len(avg_losses))
            print(f"Loss data - length: {len(avg_losses)}, n_seeds: {n_seeds}")

            ax4.plot(epochs, avg_losses, 'b-', linewidth=2, label='Mean')
            ax4.fill_between(epochs, avg_losses - sem_losses, avg_losses + sem_losses, alpha=0.3, color='blue', label='±1 SEM')
            ax4.set_title(f'{ds.capitalize()} Avg Loss')
            ax4.set_xlabel('Epoch')
            ax4.set_ylabel('Loss')
            ax4.legend()
            ax4.grid(True)

            sem_ranks = std_ranks / np.sqrt(n_seeds)
            epochs_r = np.arange(len(avg_ranks))
            print(f"Rank data - length: {len(avg_ranks)}")

            ax5.plot(epochs_r, avg_ranks, 'r-', linewidth=2, label='Mean')
            ax5.fill_between(epochs_r, avg_ranks - sem_ranks, avg_ranks + sem_ranks, alpha=0.3, color='red', label='±1 SEM')
            ax5.axhline(y=2, color='k', linestyle='--', alpha=0.7, label='Target Rank (2)')
            ax5.set_title(f'{ds.capitalize()} Avg Rank')
            ax5.set_xlabel('Epoch')
            ax5.set_ylabel('Total Rank')
            ax5.legend()
            ax5.grid(True)
    except Exception as e:
        print(f"Error loading averaged curves for {ds}: {e}")
        _show_placeholder(ax4, f'No avg loss for {ds}\n{str(e)[:50]}')
        _show_placeholder(ax5, f'No avg rank for {ds}\n{str(e)[:50]}')

    print(f"Finished processing {ds}\n")

# Write the finished figure to <results>/paper_results/figures/ (the folder is
# created on first use), then open it interactively.
output_dir = os.path.join(get_results_dir(), 'paper_results', 'figures')
os.makedirs(output_dir, exist_ok=True)  # no-op when the folder already exists

fig.tight_layout()
output_file = os.path.join(output_dir, 'unimodal_3d_analysis.png')
fig.savefig(output_file, dpi=200, bbox_inches='tight')
print(f"Figure saved to: {output_file}")
plt.show()
