import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import time
import random
import os

# add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.functions.train_larrp_unimodal import train_overcomplete_ae2

import argparse
# Command-line interface: each option below can be overridden per run.
parser = argparse.ArgumentParser(
    description='Swiss roll intrinsic dimensionality estimation',
)
parser.add_argument(
    '--gpu',
    type=int,
    default=0,
    help='GPU to use for the computation',
)
parser.add_argument(
    '--n_samples',
    type=int,
    default=10000,
    help='number of samples',
)
parser.add_argument(
    '--noise',
    type=float,
    default=0.0,
    help='noise level',
)
parser.add_argument(
    '--epochs',
    type=int,
    default=5000,
    help='training epochs',
)
parser.add_argument(
    '--r_square_threshold',
    type=float,
    default=0.05,
    help='R-squared threshold for rank reduction',
)
parser.add_argument(
    '--threshold_type',
    type=str,
    default='absolute',
    choices=['relative', 'absolute'],
    help='threshold type',
)
# Parsed once at import time; the rest of the script reads options from `args`.
args = parser.parse_args()

# Random seeds: the whole experiment is repeated once per entry.
seeds = [0, 42, 554, 9306, 89413]

# Training configuration
class Args:
    """Fixed training configuration handed to ``train_overcomplete_ae2``.

    Most values are hard-coded; ``epochs`` and ``gpu`` are copied from the
    module-level ``args`` namespace produced by argparse, so that namespace
    must exist before this class is instantiated.
    """

    def __init__(self):
        # --- latent space ---
        self.latent_dim = 3  # passed positionally to the trainer

        # --- optimisation ---
        self.batch_size = 512  # forwarded as batch_size=
        self.lr = 1e-3  # forwarded as lr=
        self.weight_decay = 2e-5  # forwarded as wd=
        self.dropout = 0.1  # forwarded as dropout=
        self.epochs = args.epochs  # taken from --epochs

        # --- autoencoder architecture ---
        self.ae_depth = 2  # forwarded as ae_depth=
        self.ae_width = 1  # forwarded as ae_width=; also part of the output dir name
        # alternative value kept for reference: self.ae_width = 10

        # --- rank reduction ---
        # NOTE(review): not read anywhere in this script; presumably consumed
        # by the trainer through this object - confirm.
        self.rank_or_sparse = 'rank'

        # --- GPU selection ---
        # NOTE(review): likewise only visible to the trainer via this object.
        self.multi_gpu = False
        self.gpu_ids = ''
        self.gpu = args.gpu  # taken from --gpu

# Instantiate the fixed training configuration once; it is shared by all seeds.
train_args = Args()

# Hyperparameters of the rank-reduction method. The seed loop looks these up
# by key, so the (plural) key names must stay exactly as they are.
method_hyperparameters = dict(
    r_square_thresholds=args.r_square_threshold,
    early_stopping=50,
    rank_reduction_frequencies=10,
    rank_reduction_thresholds=0.01,
    patiences=10,  # alternative value kept for reference: 20
)

# All artifacts of this run go below this directory (relative to the CWD).
base_out_dir = (
    "03_results/paper_results/unimodal/"
    f"swissroll_multiseed_n-{args.n_samples}_width-{train_args.ae_width}"
    f"_noise-{args.noise}_rsq-{args.r_square_threshold}"
)

# Use the requested CUDA device when CUDA is available, otherwise the CPU.
DEVICE = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Make sure the output directory exists before any seed writes into it.
os.makedirs(base_out_dir, exist_ok=True)

# Per-seed accumulators, filled by the seed loop and aggregated afterwards.
all_detailed_curves, all_results = [], []

# Loop through seeds
for seed_idx, seed in enumerate(seeds):
    # One complete experiment per seed: generate the data, train the model,
    # record the final ranks, and write the per-seed CSV / NPZ / PNG artifacts
    # into base_out_dir. Results are also appended to all_detailed_curves and
    # all_results for the cross-seed aggregation that follows the loop.
    print(f"\n{'='*60}")
    print(f"Running experiment with seed {seed} ({seed_idx+1}/{len(seeds)})")
    print(f"{'='*60}")

    # Seed every RNG in use (torch, numpy, stdlib) for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print(f"Generating Swiss Roll with {args.n_samples} samples, noise={args.noise}")

    # Generate Swiss roll data. The second return value is used below only to
    # colour the scatter plots.
    swiss_roll_data, color = make_swiss_roll(n_samples=args.n_samples, noise=args.noise, random_state=seed)

    # Keep the unscaled points for the 3D visualization
    X_original = swiss_roll_data.copy()

    # Standardize the data
    scaler = StandardScaler()
    data = scaler.fit_transform(swiss_roll_data)

    # Convert to torch tensor
    data = torch.tensor(data, dtype=torch.float32)
    # 90% of the samples are declared as training samples.
    # NOTE(review): presumably the trainer uses the first n_samples_train rows
    # for training (the plots below slice [:n_samples_train]) - confirm.
    n_samples_train = int(0.9 * args.n_samples)

    print(f"Data shape: {data.shape}")
    print(f"Training samples: {n_samples_train}")
    print(f"Validation samples: {args.n_samples - n_samples_train}")

    print("Training with parameters:")
    print(f"  - Latent dimension: {train_args.latent_dim}")
    print(f"  - R-squared threshold: {method_hyperparameters['r_square_thresholds']}")
    print(f"  - Threshold type: {args.threshold_type}")
    print(f"  - Early stopping: {method_hyperparameters['early_stopping']}")
    print(f"  - Rank reduction frequency: {method_hyperparameters['rank_reduction_frequencies']}")
    print(f"  - Patience: {method_hyperparameters['patiences']}")

    # Train the model
    print("\nStarting training...")
    start_time = time.time()

    model, reps, train_loss, r_squares, rank_history, loss_curves, detailed_curves = train_overcomplete_ae2(
        data,
        n_samples_train,
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=method_hyperparameters["early_stopping"],
        initial_rank_ratio=1.0,
        rank_reduction_frequency=method_hyperparameters["rank_reduction_frequencies"],
        rank_reduction_threshold=method_hyperparameters["rank_reduction_thresholds"],
        # the warm-up length deliberately reuses the early-stopping value
        warmup_epochs=method_hyperparameters["early_stopping"],
        patience=method_hyperparameters["patiences"],
        min_rank=1,
        r_square_threshold=method_hyperparameters["r_square_thresholds"],
        threshold_type=args.threshold_type,
        verbose=False,
        model_name=None,
        return_detailed_history=True,
        #lr_scheduler=True
    )

    end_time = time.time()
    training_time = (end_time - start_time) / 60  # in minutes

    print(f"\nTraining completed in {training_time:.2f} minutes")

    # Store detailed curves for averaging later
    all_detailed_curves.append(detailed_curves)

    # Process results
    temp_df = pd.DataFrame(rank_history)

    # Take the last recorded rank entry; fall back to [0] when none exists
    if "ranks" in rank_history and len(rank_history["ranks"]) > 0:
        final_ranks = rank_history["ranks"][-1]
    else:
        final_ranks = [0]
    print(f"Final ranks (raw): {final_ranks}")

    # Normalize final_ranks to a plain list, whatever format it arrived in
    if isinstance(final_ranks, str):
        # Comma-separated string like "5, 3" (a single value like "5" also
        # parses here); anything unparsable is recorded as rank 0
        try:
            final_ranks = [int(x.strip()) for x in final_ranks.split(',')]
        except ValueError:
            final_ranks = [0]
    elif isinstance(final_ranks, np.ndarray):
        final_ranks = final_ranks.ravel().tolist()
    elif isinstance(final_ranks, tuple):
        final_ranks = list(final_ranks)
    elif not isinstance(final_ranks, list):
        # Bare scalar
        final_ranks = [int(final_ranks)]

    # Total rank as a plain int (it is later used as a slice bound)
    final_rank_sum = int(sum(final_ranks))

    # Convert final_ranks to string for DataFrame storage (scalar value for all rows)
    final_ranks_str = ', '.join(map(str, final_ranks))
    temp_df["final_ranks"] = final_ranks_str
    temp_df["final_rank_sum"] = final_rank_sum

    # Add experiment parameters (scalars broadcast to every row)
    temp_df["n_samples"] = args.n_samples
    temp_df["noise"] = args.noise
    temp_df["r_square_threshold"] = args.r_square_threshold
    temp_df["threshold_type"] = args.threshold_type
    temp_df["seed"] = seed
    temp_df["training_time_minutes"] = training_time
    temp_df["final_loss"] = train_loss
    temp_df["final_r_square"] = r_squares

    # Store results for averaging
    all_results.append({
        'seed': seed,
        'final_rank_sum': final_rank_sum,
        'final_loss': train_loss,
        'final_r_square': r_squares,
        'training_time': training_time
    })

    print(f"Final results for seed {seed}:")
    print(f"  - Final loss: {train_loss:.6f}")
    print(f"  - Final R-squared: {r_squares:.6f}")
    print(f"  - Final ranks: {final_ranks}")
    print(f"  - Training time: {training_time:.2f} minutes")

    # Save individual seed results
    out_file = f"{base_out_dir}/seed_{seed}_results.csv"
    print(f"\nSaving results to: {out_file}")

    temp_df.to_csv(out_file, mode='w', header=True, index=False)
    print("Results saved to file")

    # Unpack the detailed history once; it feeds both the NPZ file and the plots
    all_losses, all_rsquares, all_rsquare_epochs, all_ranks, all_lrs = detailed_curves
    # Move the representations to host memory once and reuse the array below
    latent_data = reps.cpu().numpy()

    # Save detailed training curves for this seed.
    # NOTE: non-array values (e.g. the rank_history dict) are stored as pickled
    # object arrays, so loading them requires np.load(..., allow_pickle=True).
    curves_file = out_file.replace('.csv', '_curves.npz')
    np.savez(curves_file,
             loss_curves=loss_curves,
             rank_history=rank_history,
             representations=latent_data,
             color=color[:n_samples_train],
             detailed_losses=all_losses,
             detailed_rsquares=all_rsquares,
             detailed_rsquare_epochs=all_rsquare_epochs,
             detailed_ranks=all_ranks)

    print(f"Training curves and representations saved to: {curves_file}")

    # Create individual visualization plots for this seed
    print("Creating individual visualization plots...")

    # Plot training curves with detailed data (one row, four panels)
    plt.figure(figsize=(20, 4))

    plt.subplot(1, 4, 1)
    if len(all_losses) > 0:
        plt.plot(all_losses)
        plt.title(f'Training Loss (Seed {seed})')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.grid(True)

    plt.subplot(1, 4, 2)
    if len(all_ranks) > 0:
        plt.plot(all_ranks, 'o-', linewidth=2)
        plt.title(f'Total Rank Over Epochs (Seed {seed})')
        plt.xlabel('Epoch')
        plt.ylabel('Total Rank')
        plt.grid(True)

    plt.subplot(1, 4, 3)
    if len(all_rsquares) > 0 and len(all_rsquare_epochs) > 0:
        plt.plot(all_rsquare_epochs, all_rsquares, 'o-', linewidth=2)
        plt.title(f'R² Values Over Epochs (Seed {seed})')
        plt.xlabel('Epoch')
        plt.ylabel('R²')
        plt.grid(True)

    # Plot learning rate if available
    plt.subplot(1, 4, 4)
    if all_lrs is not None and len(all_lrs) > 0:
        plt.plot(all_lrs, 'o-', linewidth=2)
        plt.title('Learning Rate')
        plt.xlabel('Epoch')
        plt.ylabel('LR')
        plt.grid(True)

    plt.tight_layout()
    plot_file = out_file.replace('.csv', '_training_plots.png')
    plt.savefig(plot_file, dpi=150, bbox_inches='tight')
    print(f"Training plots saved to: {plot_file}")
    plt.close()

    # Visualize latent space for this seed
    print("Creating latent space visualization...")

    # The PCA panel is shown only when more than two ranks remain AND the
    # latent array has more than two columns; the same flag drives both the
    # layout and the drawing so no panel is ever left empty.
    show_pca = final_rank_sum > 2 and latent_data.shape[1] > 2
    n_panels = 3 if show_pca else 2
    fig = plt.figure(figsize=(15, 5) if show_pca else (12, 5))
    ax1 = fig.add_subplot(1, n_panels, 1, projection='3d')
    ax2 = fig.add_subplot(1, n_panels, 2)

    # Original Swiss roll
    ax1.scatter(X_original[:n_samples_train, 0], X_original[:n_samples_train, 1], X_original[:n_samples_train, 2],
               c=color[:n_samples_train], cmap=plt.cm.viridis)
    ax1.set_title(f'Original Swiss Roll (3D) - Seed {seed}')
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_zlabel('Z')

    # Latent space (first 2 dimensions)
    ax2.scatter(latent_data[:, 0], latent_data[:, 1], c=color[:n_samples_train], cmap=plt.cm.viridis)
    ax2.set_title(f'Latent Space (First 2 Dims) - Seed {seed}')
    ax2.set_xlabel('Latent Dim 1')
    ax2.set_ylabel('Latent Dim 2')
    ax2.grid(True)

    # PCA visualization if final rank > 2
    if show_pca:
        ax3 = fig.add_subplot(1, n_panels, 3)
        # Use only the leading `final_rank_sum` latent columns for PCA
        active_dims = min(final_rank_sum, latent_data.shape[1])
        pca = PCA(n_components=2, random_state=seed)
        latent_2d = pca.fit_transform(latent_data[:, :active_dims])
        ax3.scatter(latent_2d[:, 0], latent_2d[:, 1], c=color[:n_samples_train], cmap=plt.cm.viridis)
        ax3.set_title(f'Latent Space (PCA from {active_dims}D) - Seed {seed}')
        ax3.set_xlabel('PC 1')
        ax3.set_ylabel('PC 2')
        ax3.grid(True)

    plt.tight_layout()
    latent_plot_file = out_file.replace('.csv', '_latent_visualization.png')
    plt.savefig(latent_plot_file, dpi=150, bbox_inches='tight')
    print(f"Latent space visualization saved to: {latent_plot_file}")
    plt.close()

    print(f"Seed {seed} completed!")

# Helper for the cross-seed averaging below
def _mean_std_of_padded_curves(curves):
    """Return per-epoch ``(mean, std)`` over curves of unequal length.

    Shorter curves are right-padded with their own last value so that every
    curve spans the longest run. Empty curves are left out of the statistics;
    when no curve holds any data, two empty arrays are returned.

    Assumes each curve is a 1-D sequence of numbers - TODO confirm against
    what train_overcomplete_ae2 returns.
    """
    rows = [np.asarray(curve, dtype=float) for curve in curves if len(curve) > 0]
    if not rows:
        return np.array([]), np.array([])
    target_len = max(len(row) for row in rows)
    stacked = np.stack([
        np.pad(row, (0, target_len - len(row)), mode='edge') for row in rows
    ])
    return stacked.mean(axis=0), stacked.std(axis=0)


# Create averaged results and plots
print(f"\n{'='*60}")
print("Creating averaged results across all seeds")
print(f"{'='*60}")

# Print summary statistics; the value after "±" is the standard error of the
# mean (sample standard deviation divided by sqrt(number of seeds))
results_df = pd.DataFrame(all_results)
n_seeds = len(results_df)
print("Summary across seeds:")
print(f"  - Mean final rank: {results_df['final_rank_sum'].mean():.2f} ± {results_df['final_rank_sum'].std() / np.sqrt(n_seeds):.2f}")
print(f"  - Mean final loss: {results_df['final_loss'].mean():.6f} ± {results_df['final_loss'].std() / np.sqrt(n_seeds):.6f}")
print(f"  - Mean final R²: {results_df['final_r_square'].mean():.6f} ± {results_df['final_r_square'].std() / np.sqrt(n_seeds):.6f}")
print(f"  - Mean training time: {results_df['training_time'].mean():.2f} ± {results_df['training_time'].std() / np.sqrt(n_seeds):.2f} minutes")

# Save summary results
summary_file = f"{base_out_dir}/summary_results.csv"
results_df.to_csv(summary_file, index=False)
print(f"Summary results saved to: {summary_file}")

# Create averaged training curves plot
print("Creating averaged training curves...")

# Each detailed_curves tuple is (losses, rsquares, rsquare_epochs, ranks, lrs).
# Runs can differ in length, so losses and ranks are padded with their last
# value before averaging.
avg_losses, std_losses = _mean_std_of_padded_curves([curves[0] for curves in all_detailed_curves])
avg_ranks, std_ranks = _mean_std_of_padded_curves([curves[3] for curves in all_detailed_curves])

# R² is recorded at varying epochs per seed, so it is kept per seed, not averaged
all_rsquares_data = [curves[1] for curves in all_detailed_curves]
all_rsquare_epochs_data = [curves[2] for curves in all_detailed_curves]

# Create the averaged plot
plt.figure(figsize=(15, 4))

# Average loss curve
plt.subplot(1, 3, 1)
epochs_loss = np.arange(len(avg_losses))
plt.plot(epochs_loss, avg_losses, 'b-', linewidth=2, label='Mean')
plt.fill_between(epochs_loss, avg_losses - std_losses, avg_losses + std_losses,
                alpha=0.3, color='blue', label='±1 std')
plt.title(f'Average Training Loss (n={len(seeds)} seeds)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Average rank curve
plt.subplot(1, 3, 2)
epochs_rank = np.arange(len(avg_ranks))
plt.plot(epochs_rank, avg_ranks, 'r-', linewidth=2, label='Mean')
plt.fill_between(epochs_rank, avg_ranks - std_ranks, avg_ranks + std_ranks,
                alpha=0.3, color='red', label='±1 std')
plt.title(f'Average Total Rank (n={len(seeds)} seeds)')
plt.xlabel('Epoch')
plt.ylabel('Total Rank')
plt.legend()
plt.grid(True)

# R² curves (individual traces since timing varies)
plt.subplot(1, 3, 3)
# Ten distinct colours, so each of the seeds gets its own trace colour
palette = plt.cm.tab10.colors
has_rsquare_trace = False
for i, (rsquares, epochs) in enumerate(zip(all_rsquares_data, all_rsquare_epochs_data)):
    if len(rsquares) > 0 and len(epochs) > 0:
        plt.plot(epochs, rsquares, 'o-', linewidth=2, alpha=0.7,
                color=palette[i % len(palette)], label=f'Seed {seeds[i]}')
        has_rsquare_trace = True
plt.title(f'R² Values Over Epochs (n={len(seeds)} seeds)')
plt.xlabel('Epoch')
plt.ylabel('R²')
if has_rsquare_trace:
    # Skip the legend when nothing was drawn, to avoid an empty-legend warning
    plt.legend()
plt.grid(True)

plt.tight_layout()
averaged_plot_file = f"{base_out_dir}/averaged_training_plots.png"
plt.savefig(averaged_plot_file, dpi=150, bbox_inches='tight')
print(f"Averaged training plots saved to: {averaged_plot_file}")
plt.close()

# Convert the remaining Python lists to arrays for saving. The per-seed R²
# series may be ragged, hence dtype=object; such arrays are pickled inside the
# .npz and need np.load(..., allow_pickle=True) to be read back.
all_rsquares_data = np.array(all_rsquares_data, dtype=object)
all_rsquare_epochs_data = np.array(all_rsquare_epochs_data, dtype=object)
seeds = np.asarray(seeds)

# Save averaged detailed curves
averaged_curves_file = f"{base_out_dir}/averaged_curves.npz"
np.savez(averaged_curves_file,
         avg_losses=avg_losses,
         std_losses=std_losses,
         avg_ranks=avg_ranks,
         std_ranks=std_ranks,
         all_rsquares=all_rsquares_data,
         all_rsquare_epochs=all_rsquare_epochs_data,
         seeds=seeds)

print(f"Averaged curves data saved to: {averaged_curves_file}")

print("Multi-seed Swiss roll experiment completed successfully!")
print("Results summary:")
print(f"  - Base directory: {base_out_dir}")
print("  - Individual seed results: seed_X_results.csv")
print(f"  - Summary results: {summary_file}")
print(f"  - Averaged plots: {averaged_plot_file}")
print(f"  - Averaged curves: {averaged_curves_file}")
