import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from src.models.larrp_multimodal import AdaptiveRankReducedAE
from src.models.larrp_unimodal import AdaptiveRankReducedAE as UnimodalAdaptiveRankReducedAE
from src.data.loading import MMSimData

def _resolve_latent_dims(latent_dim, n_modalities):
    """Expand ``latent_dim`` into the list of latent sizes passed to the model.

    An int, or a one-element sequence when there is more than one modality,
    is repeated ``n_modalities + 1`` times (the extra entry is for the shared
    space). Any other list/tuple is passed through as a list.

    Raises:
    - TypeError: if ``latent_dim`` is neither an int nor a list/tuple.
    """
    if isinstance(latent_dim, int):
        return [latent_dim] * (n_modalities + 1)
    if isinstance(latent_dim, (list, tuple)):
        if len(latent_dim) == 1 and n_modalities > 1:
            return [latent_dim[0]] * (n_modalities + 1)
        return list(latent_dim)
    raise TypeError(
        f"latent_dim must be an int or a list/tuple of ints, got {type(latent_dim).__name__}"
    )


def _split_modality_masks(mask, tensors):
    """Slice a column-concatenated mask into one mask per modality.

    Columns are consumed left to right, using each tensor's ``shape[1]`` as
    the width of its slice. Returns a list of ``None`` when ``mask`` is None.
    """
    if mask is None:
        return [None] * len(tensors)
    masks = []
    offset = 0
    for tensor in tensors:
        width = tensor.shape[1]
        masks.append(mask[:, offset:offset + width])
        offset += width
    return masks


def _masked_mse(x_hat_m, x_m, m_mask):
    """Return the MSE between reconstruction and input, restricted to ``m_mask`` if given."""
    if m_mask is None:
        return F.mse_loss(x_hat_m, x_m)
    # NOTE: when the mask selects no element the mean is taken over an empty
    # tensor and the result is NaN; callers are expected to check for that.
    return F.mse_loss(x_hat_m[m_mask], x_m[m_mask])


def pretrain_overcomplete_ae(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         patience=10, verbose=True, recon_loss_balancing=False, paired=False):
    """
    Train a multimodal autoencoder for mm_sim pretraining (no rank reduction).
    
    Parameters:
    - data: Sequence with one 2D array/tensor per modality (rows are samples)
    - n_samples_train: Number of samples to use for training; the rest is validation
    - latent_dim: Latent size, either an int or a list/tuple of sizes
      (an int is repeated for every modality plus one shared space)
    - device: Device the model and the batches are moved to
    - args: Namespace-like object; `multi_gpu`, `gpu_ids` and `num_workers`
      are read from it when present
    - epochs: Maximum number of training epochs
    - early_stopping: Window (in epochs) used by the validation-based early stopping
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank
    - patience: Currently unused; kept for interface compatibility
    - verbose: Print progress
    - recon_loss_balancing: Adaptive loss balancing across modalities
    - paired: If True the rows are randomly permuted before the train/val split;
      otherwise the first `n_samples_train` rows are used for training as-is
    
    Returns:
    - model: The trained model (wrapped in nn.DataParallel in multi-GPU mode),
      left in eval mode
    - [train_losses, val_losses]: Per-epoch average losses
    - data_indices: The row permutation used when `paired` is True, else None
    """
    multi_gpu = getattr(args, 'multi_gpu', False)
    
    # Only the rank-based AE is supported
    input_dims = [d.shape[1] for d in data]
    latent_dims = _resolve_latent_dims(latent_dim, len(input_dims))
    model = AdaptiveRankReducedAE(
        input_dims, latent_dims, depth=ae_depth, width=ae_width, 
        dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
        min_rank=min_rank
    ).to(device)
    if verbose:
        print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    if multi_gpu:
        gpu_ids = getattr(args, 'gpu_ids', None)
        num_gpus = len(gpu_ids.split(',')) if gpu_ids else torch.cuda.device_count()
        
        # Keep the batch size a non-zero multiple of the number of GPUs
        if num_gpus > 0 and batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = max(num_gpus, (batch_size // num_gpus) * num_gpus)
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
        
        try:
            if gpu_ids:
                device_ids = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
            else:
                # No explicit list given: use every visible CUDA device
                device_ids = list(range(torch.cuda.device_count()))
            
            # This setup keeps the source module on cuda:0, so it must be listed
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
            
            # Module.to moves parameters and buffers in place before wrapping
            model = nn.DataParallel(model.to(torch.device('cuda:0')), device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {gpu_ids if gpu_ids else device_ids}")
        except Exception as e:
            # Best-effort: any failure falls back to the single-device model
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Train/validation split.
    # Careful with the non-paired data because of how it is concatenated:
    # rows are only shuffled when the modalities are paired.
    if paired:
        data_indices = torch.randperm(data[0].shape[0])
        train_indices = data_indices[:n_samples_train]
        val_indices = data_indices[n_samples_train:]
    else:
        data_indices = None
        train_indices = slice(0, n_samples_train)
        val_indices = slice(n_samples_train, None)
    
    num_workers = getattr(args, 'num_workers', 0)
    train_data = MMSimData([d[train_indices] for d in data])
    data_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers
    )
    val_data = MMSimData([d[val_indices] for d in data])
    val_data_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=num_workers
    )
    
    train_losses = []
    val_losses = []
    best_loss = float('inf')
    
    # Static per-modality loss weights (all ones), used when balancing is off
    loss_scales = torch.ones(len(data), device=device)
    
    # Running averages of each modality's loss, used for adaptive balancing
    modality_loss_emas = [None] * len(data)
    ema_decay = 0.9
    
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        n_train_batches = 0
        
        for x, mask in data_loader:
            x = [x_m.to(device, non_blocking=True) for x_m in x]
            
            # Forward pass
            x_hat, _ = model(x)
            modality_masks = _split_modality_masks(mask, x)
            
            # Per-modality MSE; modalities whose loss is NaN are left out of
            # the objective instead of being replaced by a constant zero.
            valid_losses = []
            for i, (x_m, x_hat_m) in enumerate(zip(x, x_hat)):
                m_loss = _masked_mse(x_hat_m, x_m, modality_masks[i])
                if torch.isnan(m_loss):
                    if verbose:
                        print(f"Warning: NaN loss detected for modality {i}")
                    continue
                valid_losses.append((i, m_loss))
            
            # Nothing differentiable to optimize: calling backward() on a
            # constant would raise, so skip this batch entirely.
            if not valid_losses:
                continue
            
            if recon_loss_balancing:
                # Update the exponential moving averages with real losses only
                for i, m_loss in valid_losses:
                    value = m_loss.item()
                    previous = modality_loss_emas[i]
                    if previous is None:
                        modality_loss_emas[i] = value
                    else:
                        modality_loss_emas[i] = ema_decay * previous + (1 - ema_decay) * value
                
                # Scale every modality towards the smallest positive EMA
                min_ema = min(
                    (ema for ema in modality_loss_emas if ema is not None and ema > 0),
                    default=None,
                )
                terms = []
                for i, m_loss in valid_losses:
                    ema = modality_loss_emas[i]
                    if min_ema is not None and ema > 0:
                        terms.append((min_ema / ema) * m_loss)
                    else:
                        terms.append(m_loss)
            else:
                # Standard loss computation without balancing
                terms = [loss_scales[i] * m_loss for i, m_loss in valid_losses]
            
            loss = torch.stack(terms).sum()
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            n_train_batches += 1
        
        # Average over the batches that actually contributed a loss
        train_loss = train_loss / n_train_batches if n_train_batches else float('nan')
        train_losses.append(train_loss)
        
        # Validation phase (eval mode so dropout is disabled)
        model.eval()
        val_loss = 0.0
        n_val_batches = 0
        with torch.no_grad():
            for x_val, mask in val_data_loader:
                x_val = [x_m.to(device, non_blocking=True) for x_m in x_val]
                x_val_hat, _ = model(x_val)
                modality_masks = _split_modality_masks(mask, x_val)
                
                # Sum of the non-NaN modality losses, averaged over modalities
                val_batch_loss = 0.0
                for i, (x_m, x_hat_m) in enumerate(zip(x_val, x_val_hat)):
                    m_loss = _masked_mse(x_hat_m, x_m, modality_masks[i])
                    if not torch.isnan(m_loss):
                        val_batch_loss += m_loss.item()
                
                val_loss += val_batch_loss / len(x_val)
                n_val_batches += 1
        
        # An empty validation set yields NaN instead of a ZeroDivisionError
        val_loss = val_loss / n_val_batches if n_val_batches else float('nan')
        val_losses.append(val_loss)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
        
        # Update progress bar
        pbar.set_postfix({
            'loss': round(train_loss, 6),
            'val_loss': round(val_loss, 6),
            'best_loss': round(best_loss, 6),
        })
        
        # Stop once the best validation loss is older than the last
        # `early_stopping` epochs
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break
    
    return model, [train_losses, val_losses], data_indices


def pretrain_overcomplete_ae_unimodal(data, n_samples_train, latent_dim, device, args, epochs=100, early_stopping=50, 
                         lr=0.001, batch_size=128, ae_depth=2, ae_width=0.5, dropout=0.0, wd=1e-5, 
                         initial_rank_ratio=1.0, min_rank=10, 
                         patience=10, verbose=True):
    """
    Train a unimodal autoencoder for pretraining (no rank reduction).
    
    Parameters:
    - data: Input data tensor (single modality); rows are samples
    - n_samples_train: Number of leading rows used for training; the rest is validation
    - latent_dim: Dimension of the latent space
    - device: Device the model and the batches are moved to
    - args: Namespace-like object; `multi_gpu` and `gpu_ids` are read from it when present
    - epochs: Maximum number of training epochs
    - early_stopping: Window (in epochs) used by the validation-based early stopping
    - lr: Learning rate
    - batch_size: Batch size for training
    - ae_depth: Depth of the autoencoder
    - ae_width: Width multiplier for hidden layers
    - dropout: Dropout rate
    - wd: Weight decay
    - initial_rank_ratio: Initial rank ratio (1.0 = full rank)
    - min_rank: Minimum rank
    - patience: Currently unused; kept for interface compatibility
    - verbose: Print progress
    
    Returns:
    - model: The trained model (wrapped in nn.DataParallel in multi-GPU mode),
      left in eval mode
    - [train_losses, val_losses]: Per-epoch average losses
    """
    multi_gpu = getattr(args, 'multi_gpu', False)
    
    # Create model with adaptive rank reduction for unimodal data
    model = UnimodalAdaptiveRankReducedAE(
        data.shape[1], latent_dim, depth=ae_depth, width=ae_width, 
        dropout=dropout, initial_rank_ratio=initial_rank_ratio, 
        min_rank=min_rank
    ).to(device)
    
    if verbose:
        print(f"Model is on device: {next(model.parameters()).device}")
    
    # Handle multi-GPU setup
    if multi_gpu:
        gpu_ids = getattr(args, 'gpu_ids', None)
        num_gpus = len(gpu_ids.split(',')) if gpu_ids else torch.cuda.device_count()
        
        # Keep the batch size a non-zero multiple of the number of GPUs
        if num_gpus > 0 and batch_size % num_gpus != 0:
            original_batch_size = batch_size
            batch_size = max(num_gpus, (batch_size // num_gpus) * num_gpus)
            if verbose:
                print(f"Adjusted batch size from {original_batch_size} to {batch_size} to be divisible by {num_gpus} GPUs")
        
        try:
            if gpu_ids:
                device_ids = [int(gpu_id) for gpu_id in gpu_ids.split(',')]
            else:
                # No explicit list given: use every visible CUDA device
                device_ids = list(range(torch.cuda.device_count()))
            
            # This setup keeps the source module on cuda:0, so it must be listed
            if 0 not in device_ids:
                raise RuntimeError("DataParallel requires cuda:0 which is not available.")
            
            # Module.to moves parameters and buffers in place before wrapping
            model = nn.DataParallel(model.to(torch.device('cuda:0')), device_ids=device_ids)
            if verbose:
                print(f"Using DataParallel across GPUs: {gpu_ids if gpu_ids else device_ids}")
        except Exception as e:
            # Best-effort: any failure falls back to the single-device model
            print(f"Failed to use DataParallel: {e}")
            print(f"Falling back to single GPU mode on {device}")
            multi_gpu = False
            model = model.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    
    # Create data loaders for unimodal data (the tensor itself acts as the dataset)
    data_loader = torch.utils.data.DataLoader(
        data[:n_samples_train], batch_size=batch_size, shuffle=True
    )
    val_data_loader = torch.utils.data.DataLoader(
        data[n_samples_train:], batch_size=batch_size, shuffle=False
    )
    
    train_losses = []
    val_losses = []
    best_loss = float('inf')
    
    pbar = tqdm.tqdm(range(epochs))
    for epoch in pbar:
        # Training phase
        model.train()
        train_loss = 0.0
        n_train_batches = 0
        
        for x in data_loader:
            x = x.to(device, non_blocking=True)
            
            # Forward pass and reconstruction loss
            x_hat = model(x)
            loss = F.mse_loss(x_hat, x)
            
            # A NaN loss cannot be optimized; replacing it with a constant
            # would make backward() raise, so the batch is skipped instead.
            if torch.isnan(loss):
                if verbose:
                    print(f"Warning: NaN loss detected")
                continue
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            n_train_batches += 1
        
        # Average over the batches that actually contributed a loss
        train_loss = train_loss / n_train_batches if n_train_batches else float('nan')
        train_losses.append(train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        n_val_batches = 0
        with torch.no_grad():
            for x_val in val_data_loader:
                x_val = x_val.to(device, non_blocking=True)
                val_batch_loss = F.mse_loss(model(x_val), x_val)
                if not torch.isnan(val_batch_loss):
                    val_loss += val_batch_loss.item()
                n_val_batches += 1
        
        # An empty validation set yields NaN instead of a ZeroDivisionError
        val_loss = val_loss / n_val_batches if n_val_batches else float('nan')
        val_losses.append(val_loss)
        
        # Update best loss
        if train_loss < best_loss:
            best_loss = train_loss
        
        # Update progress bar
        pbar.set_postfix({
            'loss': round(train_loss, 6),
            'val_loss': round(val_loss, 6),
            'best_loss': round(best_loss, 6),
        })
        
        # Stop once the best validation loss is older than the last
        # `early_stopping` epochs
        if epoch > early_stopping and min(val_losses[-early_stopping:]) > min(val_losses):
            if verbose:
                print(f"Early stopping at epoch {epoch} with best loss {best_loss}")
            break
    
    return model, [train_losses, val_losses]
