import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.channels import make_standard_montage
from mne.viz import plot_topomap
import pickle
import torch
import torch.nn.functional as F
import pandas as pd
from datetime import datetime
import os

# Restrict CUDA to the device with index '0'; inside this process it is the
# only visible GPU and is addressed as cuda:0.
# NOTE(review): this runs after `import torch`; presumably it still takes
# effect because CUDA has not been initialised at this point - confirm.
# NOTE(review): with a single visible device, any non-zero id passed through
# --device_ids will not correspond to a visible GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

def plot_real_vs_generated_topography(real_eeg, generated_eeg, train_eeg, times, ch_names, subject_id, save_dir=None):
    """
    Plot topography comparison of training set EEG, real EEG and generated EEG
    The fourth row shows the difference between training set and test set (train_eeg - real_eeg)

    The figure has one column per 100 ms window between 0 and 1 s. Every map
    shows the per-channel mean over the samples that fall into that window, and
    all maps share one symmetric colour range derived from the 2nd/98th
    percentiles of the pooled data.

    Parameters:
    -----------
    real_eeg : np.array
        Real EEG data (test set), shape (n_channels, n_timepoints)
    generated_eeg : np.array
        Generated EEG data, shape (n_channels, n_timepoints)
    train_eeg : np.array
        Average EEG data of training set, shape (n_channels, n_timepoints)
    times : np.array
        Time points array, unit in seconds. If its length does not match the
        EEG time dimension it is replaced by np.linspace(0, 1, n_timepoints).
    ch_names : list
        Electrode name list, one entry per EEG channel
    subject_id : int
        Subject ID
    save_dir : str, optional
        Save directory path. When omitted the figure is shown instead of saved.

    Raises:
    -------
    ValueError
        If an EEG array is not 2-D, the three arrays differ in shape, or the
        number of channel names differs from the number of channels.
    """

    print(f"Subject {subject_id} - Debug information:")
    print(f"train_eeg.shape: {train_eeg.shape}")
    print(f"real_eeg.shape: {real_eeg.shape}")
    print(f"generated_eeg.shape: {generated_eeg.shape}")
    print(f"times.shape: {times.shape}")
    print(f"times range: {times.min():.3f} - {times.max():.3f}")

    # Fail early with a readable message instead of a broadcasting or MNE error later on
    for array_name, array in (('train_eeg', train_eeg), ('real_eeg', real_eeg),
                              ('generated_eeg', generated_eeg)):
        if array.ndim != 2:
            raise ValueError(
                f"{array_name} must be 2-D (n_channels, n_timepoints), got shape {array.shape}")
    if not (train_eeg.shape == real_eeg.shape == generated_eeg.shape):
        raise ValueError(
            f"EEG arrays must share one shape, got train {train_eeg.shape}, "
            f"real {real_eeg.shape}, generated {generated_eeg.shape}")
    if len(ch_names) != real_eeg.shape[0]:
        raise ValueError(
            f"ch_names has {len(ch_names)} entries but the EEG data has {real_eeg.shape[0]} channels")

    # Ensure time array matches EEG data time dimension
    n_timepoints = real_eeg.shape[1]
    if len(times) != n_timepoints:
        print(f"Warning: Time array length ({len(times)}) does not match EEG time dimension ({n_timepoints})")
        print(f"Recreating time array...")
        times = np.linspace(0, 1, n_timepoints)
        print(f"New time array: {times.min():.3f} - {times.max():.3f}, length: {len(times)}")

    # Create standard electrode layout
    montage = make_standard_montage('standard_1020')

    # Create MNE Info object (the sampling rate is only metadata here)
    info = mne.create_info(ch_names=ch_names, sfreq=250, ch_types='eeg')
    info.set_montage(montage)

    # Define time windows (100ms per window)
    time_windows = [(0.0, 0.1), (0.1, 0.2), (0.2, 0.3), (0.3, 0.4),
                    (0.4, 0.5), (0.5, 0.6), (0.6, 0.7), (0.7, 0.8),
                    (0.8, 0.9), (0.9, 1.0)]

    # Calculate difference data (difference between training set and test set)
    diff_data = train_eeg - real_eeg

    # Symmetric colour range shared by every map, taken from all plotted data
    all_data = np.concatenate([train_eeg, real_eeg, generated_eeg, diff_data], axis=1)
    vmin, vmax = np.percentile(all_data, [2, 98])
    vlim = max(abs(vmin), abs(vmax))

    from matplotlib.gridspec import GridSpec

    fig = plt.figure(figsize=(24, 10))
    try:
        # Use GridSpec to create optimized layout (4 rows)
        gs = GridSpec(4, len(time_windows),
                      figure=fig,
                      left=0.05,
                      right=0.89,
                      top=0.90,
                      bottom=0.06,
                      hspace=0.02,
                      wspace=0.10)

        colorbar_image = None  # image handle the shared colorbar is built from
        row_positions = []     # y0 of each row, taken from the first plotted column
        last_window = len(time_windows) - 1

        for i, (tmin, tmax) in enumerate(time_windows):
            # Windows are half-open [tmin, tmax); the final one also includes its
            # right edge so a sample lying exactly on the last boundary is kept.
            below_upper = (times <= tmax) if i == last_window else (times < tmax)
            time_mask = (times >= tmin) & below_upper
            print(f"Subject {subject_id} - Time window {tmin}-{tmax}s: Found {np.sum(time_mask)} time points")
            if np.sum(time_mask) == 0:
                print(f"Warning: No time points found for time window {tmin}-{tmax}s")
                print(f"Skipping this time window...")
                continue

            # Per-channel mean inside the window, in row order:
            # training, real test, generated, training - real
            row_values = (
                np.mean(train_eeg[:, time_mask], axis=1),
                np.mean(real_eeg[:, time_mask], axis=1),
                np.mean(generated_eeg[:, time_mask], axis=1),
                np.mean(diff_data[:, time_mask], axis=1),
            )

            # Decide once per column, before the inner loop starts appending
            record_positions = not row_positions
            for row, window_mean in enumerate(row_values):
                ax = fig.add_subplot(gs[row, i])
                # plot_topomap returns (image, contours); only the image can feed a colorbar
                image, _ = plot_topomap(window_mean, info, axes=ax, show=False,
                                        vlim=(-vlim, vlim), cmap='RdBu_r', contours=0)
                if row == 0:
                    ax.set_title(f'{int(tmin*1000)}-{int(tmax*1000)} ms',
                                 fontsize=12, pad=10)
                    colorbar_image = image
                if record_positions:
                    row_positions.append(ax.get_position().y0)

        # Add row labels
        row_labels = ['Training EEG', 'Real Test EEG', 'Generated EEG', 'Train-Test Difference']
        for i, label in enumerate(row_labels):
            if i < len(row_positions):
                fig.text(0.01, row_positions[i] + 0.04, label,
                         rotation=90, verticalalignment='center',
                         fontsize=14, fontweight='bold')

        # Add colorbar (every map uses the same range, so one image is enough)
        if colorbar_image is not None:
            cbar_ax = fig.add_axes([0.91, 0.06, 0.015, 0.84])  # [left, bottom, width, height]
            cbar = fig.colorbar(colorbar_image, cax=cbar_ax)
            cbar.set_label('Amplitude (μV)', fontsize=12)
        else:
            print(f"Warning: No time window could be plotted for subject {subject_id}; colorbar omitted")

        # Add overall title
        fig.suptitle(f'Training vs Test vs Generated EEG Topography Comparison (Subject {subject_id})',
                     fontsize=16, y=0.95)

        # Save figure if save_dir is provided
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, f'subject_{subject_id}_topography_comparison.png')
            # Save through the figure object so another open figure can never be written by mistake
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Subject {subject_id} topographic map saved to: {save_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if plotting failed part-way through
        plt.close(fig)


def load_subject_data(subject_id, data_path):
    """
    Load subject data (test set) and average it down to a single array

    Parameters:
    -----------
    subject_id : int
        Subject ID (currently only used in log messages)
    data_path : str
        Data path (currently unused, see the note on the file path below)

    Returns:
    --------
    real_eeg_avg : np.array or None
        4-D input is averaged over axis 1 and the resulting 3-D array over
        axis 0; 3-D input is averaged over axis 0; any other rank is returned
        unchanged. None is returned when the file is missing or cannot be read.
    """
    # Format subject ID (e.g., 1 -> '01')
    subject_str = str(subject_id).zfill(2)

    # NOTE(review): the path below is a fixed placeholder, so neither data_path
    # nor subject_str influences which file is read. It presumably has to be
    # rebuilt from those two values - confirm the real directory layout.
    test_eeg_file = "/path/to/folder"
    print(f"Loading file: {test_eeg_file}")

    try:
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load trusted files.
        loaded = np.load(test_eeg_file, allow_pickle=True)
        if isinstance(loaded, np.lib.npyio.NpzFile):
            # Indexing reads the array into memory, so the archive can be closed right away
            with loaded:
                test_eeg_data = loaded['preprocessed_eeg_data']
        else:
            # A dict written with np.save comes back as a 0-d object array, not a dict
            if isinstance(loaded, np.ndarray) and loaded.dtype == object and loaded.ndim == 0:
                loaded = loaded.item()
            if isinstance(loaded, dict):
                loaded = loaded['preprocessed_eeg_data']
            test_eeg_data = np.asarray(loaded)

        print(f"Subject {subject_id} - Raw test EEG data shape: {test_eeg_data.shape}")

        # Average repeated trials
        if test_eeg_data.ndim == 4:
            real_eeg = np.mean(test_eeg_data, axis=1)  # Average along trial dimension
            print(f"Subject {subject_id} - Averaged real EEG data shape: {real_eeg.shape}")
        else:
            real_eeg = test_eeg_data

        # Final averaging (average all test data)
        if real_eeg.ndim == 3:
            real_eeg_avg = np.mean(real_eeg, axis=0)  # Average along sample dimension
            print(f"Subject {subject_id} - Final averaged real EEG data shape: {real_eeg_avg.shape}")
        else:
            real_eeg_avg = real_eeg

        return real_eeg_avg

    except FileNotFoundError:
        print(f"Error: Cannot find data file for subject {subject_id}: {test_eeg_file}")
        return None
    except Exception as e:
        # Best effort by design: the caller skips the subject when None comes back
        print(f"Error loading data for subject {subject_id}: {e}")
        return None


def load_training_data(subject_id, data_path):
    """
    Load training data for the subject and average across trials and samples

    Parameters:
    -----------
    subject_id : int
        Subject ID (currently only used in log messages)
    data_path : str
        Data path (currently unused, see the note on the file path below)

    Returns:
    --------
    train_eeg_avg : np.array or None
        4-D input is averaged over axis 1 and the resulting 3-D array over
        axis 0; 3-D input is averaged over axis 0; any other rank is returned
        unchanged. None is returned when the file is missing or cannot be read.
    """
    # Format subject ID (e.g., 1 -> '01')
    subject_str = str(subject_id).zfill(2)

    # NOTE(review): the path below is a fixed placeholder, so neither data_path
    # nor subject_str influences which file is read. It presumably has to be
    # rebuilt from those two values - confirm the real directory layout.
    train_eeg_file = "/path/to/folder"
    print(f"Loading training set file: {train_eeg_file}")

    try:
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load trusted files.
        loaded = np.load(train_eeg_file, allow_pickle=True)
        if isinstance(loaded, np.lib.npyio.NpzFile):
            # Indexing reads the array into memory, so the archive can be closed right away
            with loaded:
                train_eeg_data = loaded['preprocessed_eeg_data']
        else:
            # A dict written with np.save comes back as a 0-d object array, not a dict
            if isinstance(loaded, np.ndarray) and loaded.dtype == object and loaded.ndim == 0:
                loaded = loaded.item()
            if isinstance(loaded, dict):
                loaded = loaded['preprocessed_eeg_data']
            train_eeg_data = np.asarray(loaded)

        print(f"Subject {subject_id} - Raw training EEG data shape: {train_eeg_data.shape}")

        # Average repeated trials (if data has 4 dimensions)
        if train_eeg_data.ndim == 4:
            print(f"Subject {subject_id} - Averaging repeated trials...")
            train_eeg_data = np.mean(train_eeg_data, axis=1)  # Average along trial dimension
            print(f"Subject {subject_id} - Training EEG data shape after averaging trials: {train_eeg_data.shape}")

        # Final averaging (average all training data)
        if train_eeg_data.ndim == 3:
            train_eeg_avg = np.mean(train_eeg_data, axis=0)  # Average along sample dimension
            print(f"Subject {subject_id} - Final averaged training EEG data shape: {train_eeg_avg.shape}")
        else:
            train_eeg_avg = train_eeg_data

        return train_eeg_avg

    except FileNotFoundError:
        print(f"Error: Cannot find training data file for subject {subject_id}: {train_eeg_file}")
        return None
    except Exception as e:
        # Best effort by design: the caller skips the subject when None comes back
        print(f"Error loading training data for subject {subject_id}: {e}")
        return None


def generate_eeg_for_subject(subject_id, model_path, data_path, device):
    """
    Run the subject's trained image-to-EEG model over the test loader and
    return the mean of everything it generated

    Parameters:
    -----------
    subject_id : int
        Subject ID; selects the "subject<ID>" checkpoint folder and the data loader
    model_path : str
        Root folder that holds one "subject<ID>/best_model_mse.pth" per subject
    data_path : str
        Data path, forwarded to get_eeg_dls
    device : torch.device
        Device the models and image batches are moved to

    Returns:
    --------
    generated_eeg_avg : np.array or None
        Generated EEG averaged over the first axis of the concatenated model
        outputs (a leading singleton axis is dropped afterwards). None when the
        checkpoint is missing, its weights cannot be loaded, or the loader
        produced no batches.
    """
    print(f"Subject {subject_id} - Starting to generate EEG signals...")

    # Zero-padded ID (e.g., 1 -> '01'); computed here but not used further down
    subject_str = str(subject_id).zfill(2)

    # Bail out before any heavy import if there is no checkpoint for this subject
    checkpoint = f"{model_path}/subject{subject_id}/best_model_mse.pth"
    if not os.path.exists(checkpoint):
        print(f"Error: Model file does not exist: {checkpoint}")
        return None

    # Project-local modules, imported lazily
    from model import DiffusionEEGModel, ImageToEEGModel
    from clipper import Clipper
    from data import get_eeg_dls
    import utils

    # Assemble the network: CLIP image encoder feeding the diffusion EEG model
    clip_model = Clipper(clip_variant='ViT-L/14', device=device)
    diffusion_settings = dict(eeg_channels=63, eeg_length=250, hidden_dim=768, device=device)
    diffusion_model = DiffusionEEGModel(**diffusion_settings).to(device)
    model = ImageToEEGModel(clip_model, diffusion_model).to(device)

    # Restore the trained weights
    try:
        weights = torch.load(checkpoint, map_location=device)
        model.load_state_dict(weights)
        print(f"Subject {subject_id} - Successfully loaded model")
    except Exception as e:
        print(f"Subject {subject_id} - Error loading model: {e}")
        return None

    # Only the second loader returned by get_eeg_dls is used; it supplies the images
    loader_settings = dict(
        subject=subject_id,
        data_path=data_path,
        batch_size=32,
        val_batch_size=32,
        num_workers=0,
        seed=42,
    )
    _, test_dl = get_eeg_dls(**loader_settings)

    # Collect the generated EEG of every batch as numpy arrays
    batch_outputs = []
    model.eval()

    with torch.no_grad():
        for batch_number, (_eeg_batch, image_batch) in enumerate(test_dl, start=1):
            print(f"Subject {subject_id} - Processing batch {batch_number}/{len(test_dl)}")

            image_batch = image_batch.float().to(device)
            result = model(image_batch, mode='test')
            batch_outputs.append(result['generated_eeg'].cpu().numpy())

    if not batch_outputs:
        print(f"Subject {subject_id} - No generated EEG data")
        return None

    stacked = np.concatenate(batch_outputs, axis=0)
    print(f"Subject {subject_id} - All generated EEG shape: {stacked.shape}")

    # Mean over the first (sample) axis
    averaged = np.mean(stacked, axis=0)
    print(f"Subject {subject_id} - Averaged generated EEG shape: {averaged.shape}")

    # Reduce any leftover leading axes
    has_singleton_lead = averaged.ndim == 3 and averaged.shape[0] == 1
    if has_singleton_lead:
        averaged = averaged[0]
        print(f"Subject {subject_id} - Final generated EEG shape: {averaged.shape}")
    elif averaged.ndim == 4:
        averaged = np.mean(averaged, axis=(0, 1))
        print(f"Subject {subject_id} - Final generated EEG shape: {averaged.shape}")

    return averaged


def compute_similarity_metrics(real_eeg, generated_eeg, train_eeg):
    """
    Compute similarity metrics between real EEG and generated EEG

    All arrays are flattened first, so each metric is computed over every
    (channel, timepoint) value at once.

    Parameters:
    -----------
    real_eeg : np.array
        Real EEG data, shape (n_channels, n_timepoints)
    generated_eeg : np.array
        Generated EEG data, shape (n_channels, n_timepoints)
    train_eeg : np.array
        Training EEG data, shape (n_channels, n_timepoints)

    Returns:
    --------
    metrics : dict
        'overall_correlation'    : Pearson r between real and generated
        'mse'                    : mean squared error between real and generated
        'train_real_correlation' : Pearson r between training and real
        'train_gen_correlation'  : Pearson r between training and generated

    Raises:
    -------
    ValueError
        If the three arrays do not share the same shape.
    """
    # Imported here so the function also works when this module is imported
    # rather than run as a script.
    from scipy.stats import pearsonr

    print(f"Computing similarity metrics...")

    real_eeg = np.asarray(real_eeg)
    generated_eeg = np.asarray(generated_eeg)
    train_eeg = np.asarray(train_eeg)
    if not (real_eeg.shape == generated_eeg.shape == train_eeg.shape):
        raise ValueError(
            f"EEG arrays must share one shape, got real {real_eeg.shape}, "
            f"generated {generated_eeg.shape}, train {train_eeg.shape}")

    # Flatten data for correlation calculation
    real_flat = real_eeg.ravel()
    gen_flat = generated_eeg.ravel()
    train_flat = train_eeg.ravel()

    # pearsonr returns (statistic, p-value); only the statistic is kept
    correlation, _ = pearsonr(real_flat, gen_flat)
    train_real_corr, _ = pearsonr(train_flat, real_flat)
    train_gen_corr, _ = pearsonr(train_flat, gen_flat)

    mse = np.mean((real_flat - gen_flat) ** 2)

    # Plain floats keep the text and CSV output free of numpy scalar types
    return {
        'overall_correlation': float(correlation),
        'mse': float(mse),
        'train_real_correlation': float(train_real_corr),
        'train_gen_correlation': float(train_gen_corr),
    }


def save_metrics_to_file(metrics, subject_id, save_dir):
    """
    Write one subject's similarity metrics to a small text report

    Parameters:
    -----------
    metrics : dict
        Must provide 'overall_correlation', 'mse', 'train_real_correlation'
        and 'train_gen_correlation'
    subject_id : int
        Subject ID, used in the file name and the report title
    save_dir : str
        Target directory (created if it does not exist)
    """
    os.makedirs(save_dir, exist_ok=True)

    report_name = f'subject_{subject_id}_similarity_metrics.txt'
    txt_path = os.path.join(save_dir, report_name)

    with open(txt_path, 'w', encoding='utf-8') as report:
        # Title block first, then one line per metric with six decimals
        report.writelines([
            f"Subject {subject_id} - EEG Similarity Metrics\n",
            "=" * 50 + "\n\n",
        ])
        report.write(f"Pearson correlation coefficient:     {metrics['overall_correlation']:.6f}\n")
        report.write(f"Mean Squared Error (MSE):     {metrics['mse']:.6f}\n")
        report.write(f"Training-Real correlation:    {metrics['train_real_correlation']:.6f}\n")
        report.write(f"Training-Generated correlation: {metrics['train_gen_correlation']:.6f}\n")

    print(f"Subject {subject_id} similarity metrics saved to: {txt_path}")


def process_all_subjects(data_path, model_path, save_dir, device_ids=None):
    """
    Process all subjects and generate comparison plots and metrics

    For subjects 1 to 10 this loads the test and training EEG, generates EEG
    with the subject's model, plots the topography comparison, computes the
    similarity metrics and writes them to disk. A subject whose data cannot be
    loaded or whose processing raises is reported and skipped. A summary report
    is written when at least one subject produced metrics.

    Parameters:
    -----------
    data_path : str
        Data path
    model_path : str
        Model path
    save_dir : str
        Save directory
    device_ids : list, optional
        List of GPU device IDs to use, assigned to subjects round-robin.
        None or an empty list means [0].
    """
    import traceback

    # A None default avoids sharing one list object between calls; an empty
    # list would otherwise make the round-robin modulo divide by zero.
    if not device_ids:
        device_ids = [0]

    # NOTE(review): this list has 64 entries, 'C4' and 'F1' each appear twice
    # and 'FP2' differs from 'Fp2' only by case, while generate_eeg_for_subject
    # builds the model with eeg_channels=63. It presumably needs correcting
    # against the recording's real channel order - confirm before trusting
    # the topographic maps.
    ch_names = [
        'Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 
        'T7', 'T8', 'P7', 'P8', 'Fz', 'Cz', 'Pz', 'Oz', 'FC1', 'FC2', 'CP1', 'CP2', 
        'FC5', 'FC6', 'CP5', 'CP6', 'F9', 'FT10', 'TP9', 'TP10', 'F1', 'F2', 'C1', 'C2',
        'P1', 'P2', 'AF3', 'AF4', 'FC3', 'FC4', 'CP3', 'C4', 'PO3', 'PO4', 'F5', 'F6',
        'C5', 'C6', 'P5', 'P6', 'AF7', 'AF8', 'FT7', 'FT8', 'TP7', 'TP8', 'PO7', 'PO8',
        'F1', 'FP2', 'CPz', 'POz'
    ]

    # Time points: 250 values from 0 to 1 second inclusive.
    # NOTE(review): with both endpoints included the spacing is 1/249 s, not
    # the 1/250 s of a 250 Hz recording - confirm this is intended.
    times = np.linspace(0, 1, 250)

    # Process each subject
    all_metrics = {}

    for subject_id in range(1, 11):
        print(f"\n{'='*60}")
        print(f"Processing subject {subject_id}")
        print(f"{'='*60}")

        # Select device (round-robin)
        device_id = device_ids[(subject_id - 1) % len(device_ids)]
        device = torch.device(f'cuda:{device_id}' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")

        try:
            # Load real EEG data (test set)
            print(f"Subject {subject_id} - Loading real EEG data...")
            real_eeg = load_subject_data(subject_id, data_path)
            if real_eeg is None:
                print(f"Subject {subject_id} - Failed to load real EEG data, skipping...")
                continue

            # Load training EEG data
            print(f"Subject {subject_id} - Loading training EEG data...")
            train_eeg = load_training_data(subject_id, data_path)
            if train_eeg is None:
                print(f"Subject {subject_id} - Failed to load training EEG data, skipping...")
                continue

            # Generate EEG data using model
            print(f"Subject {subject_id} - Generating EEG data...")
            generated_eeg = generate_eeg_for_subject(subject_id, model_path, data_path, device)
            if generated_eeg is None:
                print(f"Subject {subject_id} - Failed to generate EEG data, skipping...")
                continue

            print(f"Subject {subject_id} - Data shapes:")
            print(f"  Training EEG: {train_eeg.shape}")
            print(f"  Real EEG: {real_eeg.shape}")
            print(f"  Generated EEG: {generated_eeg.shape}")

            # Fallback for arrays that are still 3-D: keep only the first entry
            # along axis 0 (this is a selection, not a transpose or an average)
            if train_eeg.ndim == 3:
                train_eeg = train_eeg[0]
            if real_eeg.ndim == 3:
                real_eeg = real_eeg[0]
            if generated_eeg.ndim == 3:
                generated_eeg = generated_eeg[0]

            # Ensure data are 2D (channels, timepoints)
            if train_eeg.ndim != 2 or real_eeg.ndim != 2 or generated_eeg.ndim != 2:
                print(f"Subject {subject_id} - Data dimension mismatch, skipping...")
                continue

            # Plot topography comparison
            print(f"Subject {subject_id} - Plotting topography comparison...")
            plot_real_vs_generated_topography(
                real_eeg, generated_eeg, train_eeg, times, ch_names, subject_id, save_dir
            )

            # Compute similarity metrics
            print(f"Subject {subject_id} - Computing similarity metrics...")
            metrics = compute_similarity_metrics(real_eeg, generated_eeg, train_eeg)
            all_metrics[subject_id] = metrics

            # Save metrics to file
            save_metrics_to_file(metrics, subject_id, save_dir)

            print(f"Subject {subject_id} - Processing completed successfully!")

        except Exception as e:
            # One failing subject must not abort the remaining ones
            print(f"Subject {subject_id} - Error during processing: {e}")
            traceback.print_exc()
            continue

    # Generate summary report
    if all_metrics:
        generate_summary_report(all_metrics, save_dir)
    else:
        print("No metrics data available for summary report")


def generate_summary_report(all_metrics, save_dir):
    """
    Write the cross-subject summary as a text table and as a CSV file

    The text table lists correlation and MSE per subject, followed by their
    mean and standard deviation (np.std) over all subjects. The CSV holds every
    metric of every subject, one row per subject.

    Parameters:
    -----------
    all_metrics : dict
        Maps subject ID to that subject's metrics dictionary
    save_dir : str
        Save directory (created if it does not exist)
    """
    os.makedirs(save_dir, exist_ok=True)

    separator = "-" * 60 + "\n"

    # --- text table ---
    summary_txt = os.path.join(save_dir, 'all_subjects_summary.txt')
    with open(summary_txt, 'w', encoding='utf-8') as out:
        out.write("All Subjects EEG Similarity Metrics Summary\n")
        out.write("=" * 60 + "\n\n")

        out.write(f"{'Subject':<8} {'Correlation':<12} {'MSE':<12}\n")
        out.write(separator)

        # One row per subject
        for subject_id, subject_metrics in all_metrics.items():
            out.write(f"{subject_id:<8} {subject_metrics['overall_correlation']:<12.4f} "
                      f"{subject_metrics['mse']:<12.6f}\n")

        out.write(separator)

        # Aggregate rows: mean first, then standard deviation
        correlations = [entry['overall_correlation'] for entry in all_metrics.values()]
        errors = [entry['mse'] for entry in all_metrics.values()]
        for row_label, aggregate in (('Average', np.mean), ('Std Dev', np.std)):
            out.write(f"{row_label:<8} {aggregate(correlations):<12.4f} {aggregate(errors):<12.6f}\n")

    print(f"All subjects summary report saved to: {summary_txt}")

    # --- CSV with every metric ---
    csv_path = os.path.join(save_dir, 'all_subjects_metrics.csv')
    table = pd.DataFrame.from_dict(all_metrics, orient='index')
    table.index.name = 'Subject'
    table.to_csv(csv_path)
    print(f"All subjects metrics saved to CSV: {csv_path}")


if __name__ == "__main__":
    import argparse
    # Binds `pearsonr` at module level when the file is run as a script; the
    # metric code calls it by that name, so this import must stay.
    from scipy.stats import pearsonr

    parser = argparse.ArgumentParser(description='Plot EEG topography comparison')
    parser.add_argument('--data_path', type=str, required=True, help='Data path')
    parser.add_argument('--model_path', type=str, required=True, help='Model path')
    parser.add_argument('--save_dir', type=str, default='./topography_results', help='Save directory')
    parser.add_argument('--device_ids', type=str, default='0', help='GPU device IDs (comma separated)')

    args = parser.parse_args()

    # Parse device IDs; blank entries (e.g. from a trailing comma) are ignored
    # and malformed input ends with a normal argparse usage error.
    id_tokens = [token.strip() for token in args.device_ids.split(',')]
    try:
        device_ids = [int(token) for token in id_tokens if token]
    except ValueError:
        parser.error(f"--device_ids must be a comma separated list of integers, got {args.device_ids!r}")
    if not device_ids:
        parser.error("--device_ids must name at least one device")

    # Process all subjects
    process_all_subjects(args.data_path, args.model_path, args.save_dir, device_ids)