"""
Unified Image-to-BrainSignal DiT Training Script
Supports both EEG and MEG modalities
Uses Accelerate for distributed training
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import argparse
import os
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator

from clipper import Clipper
from model import DiffusionBrainSignalModel, ImageToBrainSignalModel, DEFAULT_CONFIG
from data import get_dataloaders
import utils


def parse_args():
    """Build the command-line interface and return the parsed arguments.

    Options that are left as ``None`` on the command line are filled in
    afterwards from the selected modality:

    * ``signal_channels`` / ``signal_length`` come from
      ``DEFAULT_CONFIG[modality]`` (keys ``'channels'`` / ``'length'``),
      falling back to ``DEFAULT_CONFIG['eeg']`` for an unknown modality.
    * ``batch_size`` becomes 16 for EEG and 4 otherwise.
    """
    parser = argparse.ArgumentParser(description='Image-to-BrainSignal Training Script (DiT version)')
    add = parser.add_argument

    # --- Modality selection (core parameter) ---
    add('--modality', type=str, default='eeg', choices=['eeg', 'meg'],
        help='Data modality: eeg or meg')

    # --- Data ---
    add('--data_path', type=str, required=True, help='Dataset path')
    add('--subject', type=int, default=1, help='Subject ID')
    add('--batch_size', type=int, default=None, help='Training batch size (default: EEG=16, MEG=4)')
    add('--val_batch_size', type=int, default=32, help='Validation batch size')
    add('--num_workers', type=int, default=4, help='Number of data loader workers')

    # --- Model ---
    add('--clip_variant', type=str, default='ViT-L/14',
        choices=['RN50', 'ViT-L/14', 'ViT-B/32', 'RN50x64'], help='CLIP model variant')
    add('--signal_channels', type=int, default=None,
        help='Number of signal channels (default: EEG=63, MEG=271)')
    add('--signal_length', type=int, default=None,
        help='Number of signal sampling points (default: EEG=250, MEG=200)')
    add('--hidden_dim', type=int, default=768, help='Hidden dimension')
    add('--num_train_timesteps', type=int, default=1000, help='Diffusion training timesteps')

    # --- DiT backbone (defaults follow the DiT-XAttention-B/4 configuration) ---
    add('--patch_size_h', type=int, default=4, help='DiT patch height')
    add('--patch_size_w', type=int, default=4, help='DiT patch width')
    add('--depth', type=int, default=12, help='DiT Transformer layers')
    add('--num_heads', type=int, default=12, help='DiT attention heads')
    add('--mlp_ratio', type=float, default=4.0, help='DiT MLP ratio')
    add('--dropout', type=float, default=0.1, help='DiT Dropout ratio')
    add('--learn_sigma', action='store_true', help='Whether to learn variance')

    # --- Optimisation ---
    add('--epochs', type=int, default=100, help='Number of training epochs')
    add('--lr', type=float, default=1e-4, help='Learning rate')
    add('--weight_decay', type=float, default=1e-5, help='Weight decay')
    add('--warmup_steps', type=int, default=1000, help='Learning rate warmup steps')
    add('--gradient_clip', type=float, default=1.0, help='Gradient clipping')
    add('--loss_type', type=str, default='mse', choices=['mse', 'mae'],
        help='Loss function type (mse: MSE Loss, mae: MAE Loss)')

    # --- Misc / runtime ---
    add('--device', type=str, default='cuda', help='Device type')
    add('--gpu_id', type=int, default=None, help='Specify GPU ID (e.g., 0, 1, 2)')
    add('--seed', type=int, default=42, help='Random seed')
    add('--save_dir', type=str, default='./checkpoints', help='Model save directory')
    add('--log_interval', type=int, default=100, help='Logging interval')

    # --- Optional caption conditioning ---
    add('--caption_path', type=str, default=None,
        help='Captions JSON file path (optional), for injecting text semantics into cross-attention')

    args = parser.parse_args()

    # Resolve modality-dependent defaults for anything the user left unset.
    modality_key = args.modality.lower()
    fallback_cfg = DEFAULT_CONFIG['eeg']
    modality_cfg = DEFAULT_CONFIG.get(modality_key, fallback_cfg)

    for arg_name, cfg_key in (('signal_channels', 'channels'), ('signal_length', 'length')):
        if getattr(args, arg_name) is None:
            setattr(args, arg_name, modality_cfg[cfg_key])

    if args.batch_size is None:
        args.batch_size = 16 if modality_key == 'eeg' else 4

    return args


def preprocess_signal_data(signal_data, modality):
    """
    Cast a batch of brain signals to float32 and give it a 4-D layout.

    * EEG only: a 4-D input is averaged over dimension 1. Per the data
      comments this is the repeat axis, e.g. ``[batch, n_repeats, 63, 250]``
      with n_repeats=4 (train) / 80 (val).
    * Any modality: a 3-D input gets a singleton axis inserted at dim 1,
      e.g. ``[batch, C, T] -> [batch, 1, C, T]``.
    * Anything else is returned with only the dtype changed.
    """
    signal = signal_data.float()

    # Collapse repeated trials for EEG so that every sample is 3-D.
    if modality == 'eeg' and signal.dim() == 4:
        signal = signal.mean(dim=1)

    # Shared final step: add the singleton "image channel" axis to 3-D data.
    return signal.unsqueeze(1) if signal.dim() == 3 else signal


def train_epoch(model, dataloader, optimizer, criterion, accelerator, args, epoch):
    """Train the model for one epoch.

    For each batch the signal is preprocessed and the model is called with
    ``mode='train'``. The loss ``criterion(outputs['noise_pred'],
    outputs['noise'])`` is backpropagated through ``accelerator.backward``,
    optionally gradient-clipped, and followed by an optimizer step.

    Args:
        model: the (accelerate-prepared) image-to-signal model.
        dataloader: yields ``(signal, image)``. When the dataset exposes a
            non-None ``captions`` attribute it yields
            ``(signal, image, captions)`` instead.
        optimizer: optimizer over the model parameters.
        criterion: loss module applied to predicted vs. true noise. It is MSE
            or MAE depending on ``--loss_type``; the progress/log label reads
            "MSE" in both cases.
        accelerator: the ``accelerate.Accelerator`` driving training.
        args: parsed CLI arguments. Uses ``modality``, ``epochs``,
            ``gradient_clip`` and ``log_interval``.
        epoch: zero-based epoch index, used for display only.

    Returns:
        ``(avg_loss, avg_mse_loss)``. Both are the mean per-batch training
        loss; two values are returned for caller compatibility. Both are NaN
        if the dataloader yielded no batches.
    """
    model.train()
    batch_losses = []

    # Batches carry a third element (captions) only when the dataset has them.
    has_captions = getattr(dataloader.dataset, 'captions', None) is not None

    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{args.epochs}', disable=not accelerator.is_local_main_process)

    for batch_idx, batch_data in enumerate(pbar):
        if has_captions:
            signal_data, image_data, captions = batch_data
        else:
            signal_data, image_data = batch_data
            captions = None

        signal_data = preprocess_signal_data(signal_data, args.modality)
        image_data = image_data.float()

        # Forward pass: the model returns both the injected and predicted noise.
        outputs = model(image_data, signal_data, captions=captions, mode='train')
        loss = criterion(outputs['noise_pred'], outputs['noise'])

        optimizer.zero_grad()
        accelerator.backward(loss)

        # clip_grad_norm_ with a non-positive max-norm would zero (or flip)
        # every gradient, so treat gradient_clip <= 0 as "clipping disabled".
        if accelerator.sync_gradients and args.gradient_clip > 0:
            accelerator.clip_grad_norm_(model.parameters(), args.gradient_clip)

        optimizer.step()

        loss_value = loss.item()
        batch_losses.append(loss_value)

        pbar.set_postfix({'MSE': f'{loss_value:.4f}'})

        # Guard against log_interval <= 0, which would divide by zero.
        if (args.log_interval > 0 and batch_idx % args.log_interval == 0
                and accelerator.is_local_main_process):
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, MSE: {loss_value:.4f}')

    avg_loss = float(np.mean(batch_losses)) if batch_losses else float('nan')
    return avg_loss, avg_loss


def validate_epoch(model, dataloader, criterion, accelerator, args):
    """Generate signals for the validation set and score them.

    For every batch the model is called in ``mode='test'`` with images and
    optional captions only. ``outputs['generated_signal']`` is compared with
    the preprocessed recorded signal using ``criterion`` and the correlation,
    cosine-similarity and synchronization-likelihood helpers in ``utils``.

    Predictions and targets are first collected from all processes with
    ``accelerator.gather_for_metrics``. The prepared validation loader is
    sharded across processes, so without this each rank would score (and
    pick the "best" model from) only its own shard. In a single-process run
    the gather returns its input unchanged.

    Returns:
        ``(avg_loss, avg_correlation, avg_cosine_sim, avg_sync_likelihood)``.
        ``avg_loss`` is the mean per-batch ``criterion`` value. The others
        are means over the values returned by the ``utils`` helpers
        (presumably one per sample; TODO confirm). Each is NaN when there is
        nothing to average.
    """
    model.eval()
    batch_losses = []
    correlations = []
    cosine_similarities = []
    sync_likelihoods = []

    # Batches carry a third element (captions) only when the dataset has them.
    has_captions = getattr(dataloader.dataset, 'captions', None) is not None

    with torch.no_grad():
        for batch_data in tqdm(dataloader, desc='Validating', disable=not accelerator.is_local_main_process):
            if has_captions:
                signal_data, image_data, captions = batch_data
            else:
                signal_data, image_data = batch_data
                captions = None

            signal_data = preprocess_signal_data(signal_data, args.modality)
            image_data = image_data.float()

            # Generate brain signals from images (and captions, if any).
            outputs = model(image_data, captions=captions, mode='test')
            generated_signal = outputs['generated_signal']

            # Score over the whole validation set rather than this rank's shard;
            # gather_for_metrics also drops samples duplicated to even out batches.
            generated_signal, signal_data = accelerator.gather_for_metrics(
                (generated_signal, signal_data)
            )

            batch_losses.append(criterion(generated_signal, signal_data).item())

            batch_correlations = utils.compute_correlation(generated_signal, signal_data)
            correlations.extend(batch_correlations.cpu().numpy())

            batch_cosine_sim = utils.compute_cosine_similarity(generated_signal, signal_data)
            cosine_similarities.extend(batch_cosine_sim.cpu().numpy())

            batch_sl = utils.compute_synchronization_likelihood(generated_signal, signal_data)
            sync_likelihoods.extend(batch_sl)

    def _mean(values):
        # Avoid ZeroDivisionError / "mean of empty slice" on an empty loader.
        return float(np.mean(values)) if len(values) > 0 else float('nan')

    return (
        _mean(batch_losses),
        _mean(correlations),
        _mean(cosine_similarities),
        _mean(sync_likelihoods),
    )


def _build_criterion(loss_type):
    """Map a ``--loss_type`` value to ``(loss_module, display_name)``.

    Raises:
        ValueError: if ``loss_type`` is neither ``'mse'`` nor ``'mae'``.
    """
    if loss_type == 'mse':
        return nn.MSELoss(), 'MSE'
    if loss_type == 'mae':
        return nn.L1Loss(), 'MAE'
    raise ValueError(f"Unsupported loss_type {loss_type!r}; expected 'mse' or 'mae'")


def _write_best_metadata(path, args, modality_upper, epoch, best, loss_name, model_filename):
    """Write a plain-text summary of the current best checkpoint to ``path``.

    ``best`` is a dict with keys ``loss``, ``correlation``, ``cosine_sim``
    and ``sync_likelihood``. ``epoch`` is zero-based.
    """
    lines = [
        f'Best {modality_upper} Model Info:',
        f'Modality: {modality_upper}',
        f'Training epochs: {epoch + 1}',
        f'Validation correlation: {best["correlation"]:.6f}',
        f'Validation loss: {best["loss"]:.6f}',
        f'Validation cosine similarity: {best["cosine_sim"]:.6f}',
        f'Validation sync likelihood: {best["sync_likelihood"]:.6f}',
        f'Loss function type: {loss_name}',
        f'Subject: {args.subject}',
        f'Model file: {model_filename}',
        '',
        'DiT Configuration:',
        f'  signal_channels: {args.signal_channels}',
        f'  signal_length: {args.signal_length}',
        f'  patch_size: ({args.patch_size_h}, {args.patch_size_w})',
        f'  depth: {args.depth}',
        f'  num_heads: {args.num_heads}',
        f'  hidden_dim: {args.hidden_dim}',
    ]
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')


def _write_final_results(path, args, modality_upper, best, loss_name):
    """Write the end-of-training summary of the best validation metrics to ``path``."""
    lines = [
        f'Modality: {modality_upper}',
        f'Loss function type: {loss_name}',
        f'Best validation loss: {best["loss"]:.6f}',
        f'Corresponding validation correlation: {best["correlation"]:.6f}',
        f'Corresponding validation cosine similarity: {best["cosine_sim"]:.6f}',
        f'Corresponding validation sync likelihood: {best["sync_likelihood"]:.6f}',
        f'Training epochs: {args.epochs}',
        f'Subject: {args.subject}',
    ]
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')


def main():
    """Parse CLI args, build data and models, train, and keep the best checkpoint.

    The "best" checkpoint is the epoch with the highest mean validation
    Pearson correlation. Its weights, a metadata text file, and a final
    results text file are written to ``args.save_dir`` by the local main
    process.

    NOTE(review): ``--device``, ``--gpu_id`` and ``--warmup_steps`` are parsed
    but not used here. Device placement is left to ``Accelerator`` and no LR
    scheduler is created. Confirm whether that is intended.
    """
    args = parse_args()

    accelerator = Accelerator()
    is_main = accelerator.is_local_main_process

    utils.seed_everything(args.seed)

    if is_main:
        os.makedirs(args.save_dir, exist_ok=True)

    modality = args.modality.lower()
    modality_upper = modality.upper()

    if is_main:
        print(f'\n{"="*60}')
        print(f'Image-to-{modality_upper} Training')
        print(f'{"="*60}')
        print(f'Modality: {modality_upper}')
        print(f'Signal channels: {args.signal_channels}')
        print(f'Signal length: {args.signal_length}')
        print(f'{"="*60}\n')

    # ---- Data ----
    if is_main:
        print('Loading data...')
        if args.caption_path:
            print(f'Using Captions file: {args.caption_path}')
    train_dl, val_dl = get_dataloaders(
        subject=args.subject,
        data_path=args.data_path,
        modality=modality,
        batch_size=args.batch_size,
        val_batch_size=args.val_batch_size,
        num_workers=args.num_workers,
        seed=args.seed,
        caption_path=args.caption_path
    )

    # ---- Models ----
    if is_main:
        print('Initializing CLIP model...')
    clip_model = Clipper(
        clip_variant=args.clip_variant,
        # Patch-level features; the original note gives [B, 257, 768], which
        # presumably holds for ViT-L/14 only (TODO confirm for other variants).
        hidden_state=True,
        device=accelerator.device
    )

    if is_main:
        print('Initializing DiT Diffusion model...')
        print(f'  - modality: {modality_upper}')
        print(f'  - signal_channels: {args.signal_channels}')
        print(f'  - signal_length: {args.signal_length}')
        print(f'  - patch_size: ({args.patch_size_h}, {args.patch_size_w})')
        print(f'  - depth: {args.depth}')
        print(f'  - num_heads: {args.num_heads}')
        print(f'  - hidden_dim: {args.hidden_dim}')

    diffusion_model = DiffusionBrainSignalModel(
        modality=modality,
        signal_channels=args.signal_channels,
        signal_length=args.signal_length,
        patch_size=(args.patch_size_h, args.patch_size_w),
        hidden_dim=args.hidden_dim,
        depth=args.depth,
        num_heads=args.num_heads,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
        num_train_timesteps=args.num_train_timesteps,
        learn_sigma=args.learn_sigma,
        device=accelerator.device
    )

    model = ImageToBrainSignalModel(clip_model, diffusion_model, modality=modality)

    # ---- Optimisation ----
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    criterion, loss_name = _build_criterion(args.loss_type)
    if is_main:
        print(f'Using loss function: {loss_name}')

    model, optimizer, train_dl, val_dl = accelerator.prepare(
        model, optimizer, train_dl, val_dl
    )

    # Start correlation at -inf so the first epoch with a finite correlation is
    # always saved. Pearson correlation can be negative early in training, and
    # starting at 0.0 could leave a run with no checkpoint at all.
    best = {
        'loss': float('inf'),
        'correlation': float('-inf'),
        'cosine_sim': 0.0,
        'sync_likelihood': 0.0,
    }
    best_epoch = None

    if is_main:
        print('Starting training...')

    for epoch in range(args.epochs):
        _, train_mse_loss = train_epoch(
            model, train_dl, optimizer, criterion, accelerator, args, epoch
        )

        val_loss, val_correlation, val_cosine_sim, val_sync_likelihood = validate_epoch(
            model, val_dl, criterion, accelerator, args
        )

        if is_main:
            print(f'Epoch {epoch+1}/{args.epochs}:')
            print(f'  Training MSE loss: {train_mse_loss:.4f}')
            print(f'  Validation loss ({loss_name}): {val_loss:.4f}')
            print(f'  Validation correlation: {val_correlation:.4f}')
            print(f'  Validation cosine similarity: {val_cosine_sim:.4f}')
            print(f'  Validation sync likelihood: {val_sync_likelihood:.4f}')

        # Keep the checkpoint with the highest mean Pearson correlation.
        # A NaN correlation compares False and therefore never wins.
        if val_correlation > best['correlation']:
            best.update(
                loss=val_loss,
                correlation=val_correlation,
                cosine_sim=val_cosine_sim,
                sync_likelihood=val_sync_likelihood,
            )
            best_epoch = epoch
            if is_main:
                model_filename = f'best_{modality}_model_{args.loss_type}.pth'
                best_model_path = os.path.join(args.save_dir, model_filename)
                accelerator.save(accelerator.unwrap_model(model).state_dict(), best_model_path)

                metadata_filename = f'best_{modality}_model_{args.loss_type}_metadata.txt'
                _write_best_metadata(
                    os.path.join(args.save_dir, metadata_filename),
                    args, modality_upper, epoch, best, loss_name, model_filename,
                )

                print(f'Saved best {modality_upper} model, validation correlation: {best["correlation"]:.4f}')

    if is_main:
        print('Training complete!')
        if best_epoch is None:
            print('Warning: no best model was saved (validation correlation never improved).')
        print(f'Best validation correlation: {best["correlation"]:.4f}')
        print(f'Corresponding validation loss ({loss_name}): {best["loss"]:.4f}')
        print(f'Corresponding validation cosine similarity: {best["cosine_sim"]:.4f}')
        print(f'Corresponding validation sync likelihood: {best["sync_likelihood"]:.4f}')
        print(f'Model saved at: {args.save_dir}')

        result_filename = f'best_{modality}_val_results_{args.loss_type}.txt'
        result_file = os.path.join(args.save_dir, result_filename)
        _write_final_results(result_file, args, modality_upper, best, loss_name)
        print(f'Best validation results ({loss_name}) saved to: {result_file}')


if __name__ == "__main__":
    # Script entry point: start training only when run directly, not on import.
    main()

