import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import argparse
import os
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator

from clipper import Clipper
from model import DiffusionEEGModel, ImageToEEGModel
from data import get_eeg_dls
import utils


def parse_args():
    """Parse the command-line options of the Image-to-EEG training script.

    The options are declared in a single table (flag, ``add_argument``
    keyword arguments) and registered in order, so the generated ``--help``
    output lists them in the same sequence as the table.

    Returns:
        argparse.Namespace: one attribute per flag below (without the leading
        dashes). ``--data_path`` is the only required option.
    """
    option_table = (
        # Data parameters
        ('--data_path', dict(type=str, required=True, help='Dataset path')),
        ('--subject', dict(type=int, default=1, help='Subject number')),
        ('--batch_size', dict(type=int, default=16, help='Training batch size')),
        ('--val_batch_size', dict(type=int, default=32, help='Validation batch size')),
        ('--num_workers', dict(type=int, default=4, help='Data loader process count')),

        # Model parameters
        ('--clip_variant', dict(type=str, default='ViT-L/14',
                                choices=['RN50', 'ViT-L/14', 'ViT-B/32', 'RN50x64'],
                                help='CLIP model version')),
        ('--eeg_channels', dict(type=int, default=63, help='Number of EEG electrodes')),
        ('--eeg_length', dict(type=int, default=250, help='Number of EEG sampling points')),
        ('--hidden_dim', dict(type=int, default=768, help='Hidden layer dimension')),
        ('--num_train_timesteps', dict(type=int, default=1000,
                                       help='Diffusion training time steps')),

        # Training parameters
        ('--epochs', dict(type=int, default=100, help='Number of training epochs')),
        ('--lr', dict(type=float, default=1e-4, help='Learning rate')),
        ('--weight_decay', dict(type=float, default=1e-5, help='Weight decay')),
        ('--warmup_steps', dict(type=int, default=1000, help='Learning rate warmup steps')),
        ('--gradient_clip', dict(type=float, default=1.0, help='Gradient clipping')),
        ('--loss_type', dict(type=str, default='mse', choices=['mse'],
                             help='Loss function type (mse: MSE Loss)')),

        # Other parameters
        ('--device', dict(type=str, default='cuda', help='Device type')),
        ('--gpu_id', dict(type=int, default=None,
                          help='Specify GPU number (e.g.: 0, 1, 2)')),
        ('--seed', dict(type=int, default=42, help='Random seed')),
        ('--save_dir', dict(type=str, default='./checkpoints', help='Model save directory')),
        ('--log_interval', dict(type=int, default=100, help='Log output interval')),
    )

    cli = argparse.ArgumentParser(description='Image-to-EEG Training Script')
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def train_epoch(model, dataloader, optimizer, criterion, accelerator, args, epoch):
    """Train the model for one epoch and return the average losses.

    Args:
        model: Callable as ``model(image_data, eeg_data, mode='train')``; the
            result must be indexable with ``'noise_pred'`` and ``'noise'``.
        dataloader: Yields ``(eeg_data, image_data)`` batches; must support
            ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied to ``(noise_pred, noise)``.
        accelerator: ``Accelerator`` used for the backward pass, gradient
            clipping and main-process checks.
        args: Namespace providing ``epochs``, ``gradient_clip`` and
            ``log_interval``. A ``log_interval`` of 0 or less disables the
            periodic log line.
        epoch: Zero-based epoch index (displayed as ``epoch + 1``).

    Returns:
        tuple: ``(avg_loss, avg_diffusion_loss)``. The total loss currently
        consists of the diffusion loss only, so both describe the same
        quantity (sum / ``len(dataloader)`` and ``np.mean`` respectively).
    """
    model.train()
    total_loss = 0
    diffusion_losses = []

    # Loop invariants: MSE is the only supported loss, and the process rank
    # does not change during the epoch.
    loss_name = 'MSE'
    is_main = accelerator.is_local_main_process

    pbar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{args.epochs}', disable=not is_main)

    for batch_idx, (eeg_data, image_data) in enumerate(pbar):
        # Data preprocessing - make sure both inputs are float32
        eeg_data = eeg_data.float()
        image_data = image_data.float()

        # A 4-D EEG batch is reduced with utils.average_eeg_trials
        # (expected [batch_size, 4, 63, 250] -> [batch_size, 63, 250])
        if eeg_data.dim() == 4:
            eeg_data = utils.average_eeg_trials(eeg_data)

        # A 3-D EEG batch gets a singleton dimension inserted at index 1
        # (expected [batch_size, 63, 250] -> [batch_size, 1, 63, 250])
        if eeg_data.dim() == 3:
            eeg_data = eeg_data.unsqueeze(1)

        # Forward propagation
        outputs = model(image_data, eeg_data, mode='train')

        # Diffusion loss between predicted and actual noise; it is also the
        # total loss (no modality alignment term)
        diffusion_loss = criterion(outputs['noise_pred'], outputs['noise'])
        total_batch_loss = diffusion_loss

        # Backpropagation
        optimizer.zero_grad()
        accelerator.backward(total_batch_loss)

        # Gradient clipping
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), args.gradient_clip)

        optimizer.step()

        # Read the scalar once: every .item() call copies the value to the
        # host, which can force a device synchronisation.
        loss_value = total_batch_loss.item()
        total_loss += loss_value
        diffusion_losses.append(loss_value)

        # Update progress bar
        loss_text = f'{loss_value:.4f}'
        pbar.set_postfix({
            'Loss': loss_text,
            loss_name: loss_text
        })

        # Periodic log output (skipped entirely when log_interval <= 0, which
        # would otherwise raise ZeroDivisionError in the modulo)
        if is_main and args.log_interval > 0 and batch_idx % args.log_interval == 0:
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, '
                  f'Total Loss: {loss_text}, '
                  f'Diffusion Loss ({loss_name}): {loss_text}')

    avg_loss = total_loss / len(dataloader)
    avg_diffusion_loss = np.mean(diffusion_losses)

    return avg_loss, avg_diffusion_loss


def validate_epoch(model, dataloader, criterion, accelerator, args):
    """Run one validation pass and return the mean loss and mean correlation.

    For every ``(eeg_data, image_data)`` batch the model is called in
    ``mode='test'`` with the images only; its ``'generated_eeg'`` output is
    compared against the recorded EEG with ``criterion`` and with
    ``utils.compute_correlation``. Gradients are disabled for the whole pass.

    Args:
        model: Callable as ``model(image_data, mode='test')``; the result must
            be indexable with ``'generated_eeg'``.
        dataloader: Yields ``(eeg_data, image_data)`` batches; must support
            ``len()``.
        criterion: Loss callable applied to ``(generated_eeg, eeg_data)``.
        accelerator: Only consulted to decide whether the progress bar is shown.
        args: Unused here; kept so the call signature matches the caller.

    Returns:
        tuple: ``(avg_loss, avg_correlation)`` - the summed per-batch loss
        divided by ``len(dataloader)``, and ``np.mean`` over every value
        returned by ``utils.compute_correlation``.
    """
    model.eval()
    running_loss = 0
    collected_corr = []

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validating',
                        disable=not accelerator.is_local_main_process)
        for eeg_data, image_data in progress:
            # Cast both inputs to float32
            target_eeg = eeg_data.float()
            images = image_data.float()

            # 4-D EEG batches are reduced with utils.average_eeg_trials
            if target_eeg.dim() == 4:
                target_eeg = utils.average_eeg_trials(target_eeg)

            # 3-D EEG batches get a singleton dimension at index 1
            # (expected [batch_size, 63, 250] -> [batch_size, 1, 63, 250])
            if target_eeg.dim() == 3:
                target_eeg = target_eeg.unsqueeze(1)

            # Generate EEG signals from the images alone
            generated_eeg = model(images, mode='test')['generated_eeg']

            # Reconstruction loss against the recorded EEG
            running_loss += criterion(generated_eeg, target_eeg).item()

            # Correlation values for this batch
            batch_corr = utils.compute_correlation(generated_eeg, target_eeg)
            collected_corr.extend(batch_corr.cpu().numpy())

    return running_loss / len(dataloader), np.mean(collected_corr)


def _resolve_loss(loss_type):
    """Return ``(criterion, display_name)`` for ``loss_type``; raise ValueError if unsupported."""
    if loss_type == 'mse':
        return nn.MSELoss(), 'MSE'
    raise ValueError(f'Unsupported loss type: {loss_type!r}')


def _save_best_checkpoint(accelerator, model, args, epoch, loss_name,
                          val_loss, val_correlation):
    """Save the unwrapped model weights plus a metadata text file into ``args.save_dir``."""
    model_filename = f'best_model_{args.loss_type}.pth'
    best_model_path = os.path.join(args.save_dir, model_filename)
    accelerator.save(accelerator.unwrap_model(model).state_dict(), best_model_path)

    # Save model metadata
    metadata_filename = f'best_model_{args.loss_type}_metadata.txt'
    metadata_path = os.path.join(args.save_dir, metadata_filename)
    with open(metadata_path, 'w', encoding='utf-8') as f:
        f.write('Best Model Information:\n')
        f.write(f'Training epoch: {epoch + 1}\n')
        f.write(f'Validation loss: {val_loss:.6f}\n')
        f.write(f'Validation correlation: {val_correlation:.6f}\n')
        f.write(f'Loss function type: {loss_name}\n')
        f.write(f'Subject: {args.subject}\n')
        f.write(f'Model file: {model_filename}\n')

    print(f'Saved best model ({loss_name}), validation loss: {val_loss:.4f}, corresponding validation correlation: {val_correlation:.4f}')


def _write_final_results(args, loss_name, best_val_loss, best_val_correlation):
    """Write the end-of-training summary file into ``args.save_dir`` and return its path."""
    result_filename = f'best_val_results_{args.loss_type}.txt'
    result_file = os.path.join(args.save_dir, result_filename)
    with open(result_file, 'w', encoding='utf-8') as f:
        f.write(f'Loss function type: {loss_name}\n')
        f.write(f'Best validation loss: {best_val_loss:.6f}\n')
        f.write(f'Corresponding validation correlation: {best_val_correlation:.6f}\n')
        f.write(f'Number of training epochs: {args.epochs}\n')
        f.write(f'Subject: {args.subject}\n')
    return result_file


def main():
    """Entry point: train the image-to-EEG model and keep the best checkpoint.

    Steps:
      1. Parse CLI arguments, create the ``Accelerator`` and seed the RNGs.
      2. Build the data loaders, the CLIP wrapper, the diffusion model and the
         combined ``ImageToEEGModel``; create an AdamW optimizer.
      3. For ``args.epochs`` epochs run ``train_epoch`` then ``validate_epoch``.
         Whenever the validation loss improves, the local main process saves
         the weights and a metadata file to ``args.save_dir``.
      4. After training, the local main process writes a summary file.

    Raises:
        ValueError: if ``args.loss_type`` has no registered loss function.

    NOTE(review): ``--warmup_steps``, ``--device`` and ``--gpu_id`` are parsed
    but not referenced in this function - confirm whether they are meant to be
    wired in (scheduler / device selection).
    NOTE(review): validation metrics are not gathered across processes here;
    under a multi-process launch each process presumably tracks the best loss
    of its own data shard - confirm this is intended.
    """
    args = parse_args()

    # Fail fast on an unsupported loss type, before any data or model is loaded.
    criterion, loss_name = _resolve_loss(args.loss_type)

    # Initialize Accelerator
    accelerator = Accelerator()
    is_main = accelerator.is_local_main_process

    # Set random seed
    utils.seed_everything(args.seed)

    # Create save directory
    if is_main:
        os.makedirs(args.save_dir, exist_ok=True)

    # Load data
    if is_main:
        print('Loading data...')
    train_dl, val_dl = get_eeg_dls(
        subject=args.subject,
        data_path=args.data_path,
        batch_size=args.batch_size,
        val_batch_size=args.val_batch_size,
        num_workers=args.num_workers,
        seed=args.seed
    )

    # Initialize CLIP model
    if is_main:
        print('Initializing CLIP model...')
    clip_model = Clipper(
        clip_variant=args.clip_variant,
        device=accelerator.device
    )

    # Initialize Diffusion model
    if is_main:
        print('Initializing Diffusion model...')
    diffusion_model = DiffusionEEGModel(
        eeg_channels=args.eeg_channels,
        eeg_length=args.eeg_length,
        hidden_dim=args.hidden_dim,
        num_train_timesteps=args.num_train_timesteps,
        device=accelerator.device
    )

    # Initialize complete model
    model = ImageToEEGModel(clip_model, diffusion_model)

    # Optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    if is_main:
        print(f'Using loss function: {loss_name}')

    # Let the accelerator wrap the model, optimizer and data loaders
    model, optimizer, train_dl, val_dl = accelerator.prepare(
        model, optimizer, train_dl, val_dl
    )

    # Best-so-far tracking (the correlation stored is the one measured in the
    # epoch that produced the best validation loss)
    best_val_loss = float('inf')
    best_val_correlation = 0.0

    if is_main:
        print('Starting training...')
    for epoch in range(args.epochs):
        # Training
        train_loss, train_diff_loss = train_epoch(
            model, train_dl, optimizer, criterion, accelerator, args, epoch
        )

        # Validation
        val_loss, val_correlation = validate_epoch(
            model, val_dl, criterion, accelerator, args
        )

        if is_main:
            print(f'Epoch {epoch+1}/{args.epochs}:')
            print(f'  Training loss ({loss_name}): {train_loss:.4f} (Diffusion: {train_diff_loss:.4f})')
            print(f'  Validation loss ({loss_name}): {val_loss:.4f}')
            print(f'  Validation correlation: {val_correlation:.4f}')

        # Save best model (based on minimum validation loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_correlation = val_correlation
            if is_main:
                _save_best_checkpoint(
                    accelerator, model, args, epoch, loss_name,
                    best_val_loss, best_val_correlation
                )

    # Training completed
    if is_main:
        print('Training completed!')
        print(f'Best validation loss ({loss_name}): {best_val_loss:.4f}')
        print(f'Corresponding validation correlation: {best_val_correlation:.4f}')
        print(f'Model saved in: {args.save_dir}')

        # Save best validation loss and corresponding validation correlation to file
        result_file = _write_final_results(
            args, loss_name, best_val_loss, best_val_correlation
        )
        print(f'Best validation results ({loss_name}) saved to: {result_file}')


# Start training only when this file is run as a script, not when it is imported.
if __name__ == '__main__':
    main()