import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.datasets import make_swiss_roll, make_s_curve
import os
from pathlib import Path
import sys

# Make the project root (three directory levels above this file) importable so
# the shared ``config`` module below can be found. It is inserted at the FRONT
# of sys.path so the project's ``config`` wins over any unrelated top-level
# ``config`` module elsewhere on the path. The membership guard avoids piling up
# duplicate entries when the script is re-executed in the same interpreter
# (e.g. from an IDE or notebook).
project_root = Path(__file__).parent.parent.parent.absolute()
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))
from config import get_results_dir

# Configure matplotlib to match the paper style (see 040_mm_paramsim_plots.py).
# Reset to matplotlib's default style first so the overrides below are applied
# on a clean baseline, independent of whatever style was active beforehand.
plt.style.use('default')
plt.rcParams.update({
    # Font sizes in points: 11 pt body text, axes titles and labels;
    # 10 pt tick labels and legends; 12 pt figure titles.
    'font.size': 11,
    'axes.titlesize': 11,
    'axes.labelsize': 11,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 12,
    # Text is set in a sans-serif family. The serif (Times New Roman) settings
    # below are intentionally disabled; to switch to Times, set 'font.family'
    # to 'serif' and uncomment 'font.serif' (and optionally the Type 42 /
    # TrueType font-embedding options for PDF/PS output).
    'font.family': 'sans-serif',
    #'font.serif': ['Times New Roman', 'Times', 'Palatino', 'DejaVu Serif', 'serif'],
    #'pdf.fonttype': 42,
    #'ps.fonttype': 42,
    # STIX glyphs for mathtext-rendered expressions.
    'mathtext.fontset': 'stix',
    # 300 dpi for both figure rendering and saved files; saved files use a
    # tight bounding box with a small margin.
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.05,
    # Thin axis spines and grid lines.
    'axes.linewidth': 0.8,
    'grid.linewidth': 0.5,
})

# Result directories: every dataset lives in '<dataset>_<settings>/' under the
# unimodal paper-results folder (base_dir keeps a trailing '/').
base_dir = os.path.join(get_results_dir(), 'paper_results', 'unimodal') + '/'
settings = 'multiseed_n-10000_width-1_noise-0.0_rsq-0.05'

paths = {
    name: f'{base_dir}{name}_{settings}/'
    for name in ('hypersphere', 'swissroll', 'scurve')
}

# Seed whose latent representation and training curves are visualised,
# per dataset (currently 42 for all of them).
dataset_seeds = {name: 42 for name in ('hypersphere', 'swissroll', 'scurve')}

# Data-generation settings; the first 90% of samples form the training split
# that is drawn in the 3D panel.
n_samples = 10000
noise = 0.0
n_samples_train = int(n_samples * 0.9)

# Reduced figure: only the hypersphere row, five panels side by side.
# The 9.45-inch width matches the paper's combined plots.
datasets = ['hypersphere']
fig = plt.figure(figsize=(9.45, 2.32))

def make_hypersphere_data(n_samples, seed, noise=0.0):
    """Regenerate the 3D sphere dataset used by the hypersphere experiment.

    Points are drawn with the same procedure as ``032_hypersphere.py`` so the
    plotted data matches the experiment's data. Note that drawing the polar
    angle ``phi`` uniformly in [0, pi] is *not* area-uniform on the sphere:
    points concentrate near the poles. This is kept deliberately so the
    experiment is reproduced exactly.

    Parameters
    ----------
    n_samples : int
        Number of points to generate.
    seed : int
        Seed for the random generator.
    noise : float, optional
        Standard deviation of Gaussian noise added independently to every
        coordinate. No noise is added unless ``noise > 0``.

    Returns
    -------
    data : ndarray of shape (n_samples, 3)
        Cartesian coordinates of the points on the unit sphere (plus noise).
    color : ndarray of shape (n_samples,)
        The polar angle ``phi`` of each point, used for colouring.
    """
    # A private RandomState yields exactly the same draws as the legacy
    # ``np.random.seed(seed)`` + ``np.random.*`` recipe, but does not reset
    # the global NumPy RNG for the rest of the process.
    rng = np.random.RandomState(seed)
    phi = rng.uniform(0, np.pi, n_samples)        # polar angle
    theta = rng.uniform(0, 2 * np.pi, n_samples)  # azimuthal angle
    radius = 1.0
    x = radius * np.sin(phi) * np.cos(theta)
    y = radius * np.sin(phi) * np.sin(theta)
    z = radius * np.cos(phi)
    data = np.stack([x, y, z], axis=1)
    if noise > 0.0:
        # Drawn after phi/theta from the same stream, matching the original order.
        data += rng.normal(0, noise, data.shape)
    # Colour by phi, as in the original experiment.
    color = phi
    return data, color

# Keys under which the experiment may have stored each per-epoch curve,
# tried in order (the first key present wins).
_RANK_KEYS = ('ranks', 'total_ranks', 'detailed_ranks', 'ranks_over_epochs')
_LOSS_KEYS = ('losses', 'train_losses', 'detailed_losses', 'loss_curve', 'final_losses')


def _first_present(npz, keys):
    """Return the array stored under the first of *keys* found in *npz*, else None."""
    for key in keys:
        if key in npz.files:
            return npz[key]
    return None


def _load_curves(curves_file):
    """Load a seed's ``.npz`` curves archive; return None if missing or unreadable."""
    if not os.path.exists(curves_file):
        return None
    try:
        # allow_pickle=True lets object arrays be unpickled; only acceptable
        # because these files come from our own experiment runs.
        loaded = np.load(curves_file, allow_pickle=True)
    except Exception as err:
        print(f"Could not load curves file {curves_file}: {err}")
        return None
    if not hasattr(loaded, 'files'):
        # np.load returned a plain array rather than an .npz archive.
        print(f"Curves file {curves_file} is not an .npz archive")
        return None
    return loaded


for i, ds in enumerate(datasets):
    print(f"Processing dataset: {ds}")
    print(f"Looking for data in: {paths[ds]}")

    # Regenerate the original hypersphere data using the chosen seed.
    dataset_seed = dataset_seeds[ds]
    X_original, color_orig = make_hypersphere_data(n_samples, dataset_seed, noise)

    # Column 1: original data in 3D (training split only).
    ax1 = fig.add_subplot(1, 5, 1, projection='3d')
    X_train = X_original[:n_samples_train]
    ax1.scatter(X_train[:, 0], X_train[:, 1], X_train[:, 2],
                c=color_orig[:n_samples_train], cmap=plt.cm.viridis, s=1, alpha=0.2)
    ax1.set_title('Sphere')
    # Hide ticks and tick labels on the 3D panel; the negative labelpad pulls
    # the axis labels in tight against the tick-less axes.
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_zticks([])
    ax1.set_xlabel('X', fontsize=11, labelpad=-10)
    ax1.set_ylabel('Y', fontsize=11, labelpad=-10)
    # The z-label needs a slightly larger negative pad to sit as close on 3D axes.
    ax1.set_zlabel('Z', fontsize=11, labelpad=-12)

    # Load the seed-specific curves archive once; every panel below reuses it.
    curves_file = os.path.join(paths[ds], f'seed_{dataset_seed}_results_curves.npz')
    print(f"Looking for curves file: {curves_file}")
    curves = _load_curves(curves_file)

    # Number of epochs spanned by each epoch-indexed panel that found data;
    # used below to place shared x-ticks on columns 3-5.
    epoch_extents = []

    # Column 2: latent representation (seed-specific).
    ax2 = fig.add_subplot(1, 5, 2)
    if curves is not None:
        try:
            if 'representations' in curves.files and 'color' in curves.files:
                latent_data = curves['representations']
                color_latent = curves['color']
                ax2.scatter(latent_data[:, 0], latent_data[:, 1], c=color_latent,
                            cmap=plt.cm.viridis, s=4, alpha=0.6)
                ax2.set_title('Representation')
                ax2.set_xlabel('Latent 1')
                ax2.set_ylabel('Latent 2')
            else:
                ax2.text(0.5, 0.5, 'No latent data', ha='center', va='center')
        except Exception:
            ax2.text(0.5, 0.5, 'Latent error', ha='center', va='center')
    else:
        ax2.text(0.5, 0.5, 'Curves file missing', ha='center', va='center')

    # Column 3: rank (intrinsic-dimension estimate) per epoch for this seed.
    ax3 = fig.add_subplot(1, 5, 3)
    if curves is not None:
        try:
            seed_ranks = _first_present(curves, _RANK_KEYS)
            if seed_ranks is not None:
                epoch_extents.append(len(seed_ranks))
                epochs_r = np.arange(len(seed_ranks))
                # Rank trace: green (colorblind-friendly).
                ax3.plot(epochs_r, seed_ranks, color='#33a02c', linewidth=2)
                # Reference line at rank 2, presumably the sphere surface's true
                # intrinsic dimension: light gray dashed.
                ax3.axhline(y=2, color='#cccccc', linestyle='--', linewidth=1, alpha=0.9)
                ax3.set_title('ID Estimate')
                ax3.set_xlabel('Epoch')
                ax3.set_ylabel('Rank')
            else:
                ax3.text(0.5, 0.5, 'No seed rank data', ha='center', va='center')
        except Exception:
            ax3.text(0.5, 0.5, 'Rank error', ha='center', va='center')
    else:
        ax3.text(0.5, 0.5, 'Curves file missing', ha='center', va='center')

    # Column 4: training loss per epoch for this seed. Guarded on the loaded
    # archive (not mere file existence) so an unreadable file is reported the
    # same way as in the other panels.
    ax4 = fig.add_subplot(1, 5, 4)
    if curves is not None:
        try:
            seed_losses = _first_present(curves, _LOSS_KEYS)
            if seed_losses is not None:
                epoch_extents.append(len(seed_losses))
                epochs = np.arange(len(seed_losses))
                # Loss trace: blue (colorblind-friendly).
                ax4.plot(epochs, seed_losses, color='#1f78b4', linewidth=2)
                ax4.set_title('Train Loss')
                ax4.set_xlabel('Epoch')
                ax4.set_ylabel('MSE Loss')
            else:
                ax4.text(0.5, 0.5, 'No seed loss data', ha='center', va='center')
        except Exception:
            ax4.text(0.5, 0.5, 'Loss error', ha='center', va='center')
    else:
        ax4.text(0.5, 0.5, 'Curves file missing', ha='center', va='center')

    # Column 5: R² distortion metric at the epochs where it was measured.
    ax5 = fig.add_subplot(1, 5, 5)
    if curves is not None:
        try:
            if 'detailed_rsquares' in curves.files and 'detailed_rsquare_epochs' in curves.files:
                rsquares = curves['detailed_rsquares']
                rsquare_epochs = curves['detailed_rsquare_epochs']
                if len(rsquare_epochs) > 0:
                    # R² is plotted against explicit epoch values, so its extent
                    # is the last measured epoch + 1, not the number of samples.
                    epoch_extents.append(int(np.max(rsquare_epochs)) + 1)
                # R² trace: purple, dotted with markers (colorblind-friendly).
                ax5.plot(rsquare_epochs, rsquares, color='#6a51a3', linestyle=':',
                         marker='o', linewidth=1, markersize=4)
                if len(rsquares) > 0:
                    # Threshold 0.05 below the first R² value (presumably the
                    # rsq-0.05 tolerance in `settings`): light gray dashed.
                    threshold_line = rsquares[0] - 0.05
                    ax5.axhline(y=threshold_line, color='#cccccc', linestyle='--',
                                linewidth=1, alpha=0.9)
                ax5.set_title('Distortion Metric')
                ax5.set_xlabel('Epoch')
                ax5.set_ylabel('R²')
            else:
                ax5.text(0.5, 0.5, 'No R² data', ha='center', va='center')
        except Exception:
            ax5.text(0.5, 0.5, 'R² error', ha='center', va='center')
    else:
        ax5.text(0.5, 0.5, 'Curves file missing', ha='center', va='center')

    # Share epoch ticks across Rank (ax3), Loss (ax4), and R² (ax5): fixed
    # positions 0/200/400, dropping any beyond the longest plotted range.
    max_extent = max(epoch_extents, default=0)
    if max_extent > 0:
        tick_positions = [t for t in (0, 200, 400) if t <= max_extent - 1]
        for a in (ax3, ax4, ax5):
            a.set_xticks(tick_positions)
            a.set_xticklabels([str(t) for t in tick_positions])

    # Release the open .npz archive now that every panel has been drawn.
    if curves is not None:
        curves.close()

    print(f"Finished processing {ds}\n")

# Write the finished figure into the shared paper-figures directory (created on
# demand), report where it went, then display it interactively.
output_dir = os.path.join(get_results_dir(), 'paper_results', 'figures')
output_file = os.path.join(output_dir, 'unimodal_3d_analysis_sphere.png')
os.makedirs(output_dir, exist_ok=True)

plt.tight_layout()
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"Figure saved to: {output_file}")
plt.show()
