import torch
import wandb

def _apply_accumulated_update(model, optimizer, scheduler, max_grad_norm):
    """Clip the accumulated gradients, take one optimizer + scheduler step, then reset gradients."""
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def train(
    dataloader,
    model,
    loss_fn,
    optimizer,
    scheduler,
    device,
    acumulation_grad_steps=8,
    max_grad_norm=1.0,
    flush_partial_accumulation=False,
):
    """
    Execute one epoch of training for the model, with gradient accumulation.

    Each batch loss is divided by ``acumulation_grad_steps`` before
    ``backward()``, so after ``acumulation_grad_steps`` batches the accumulated
    gradient is their mean. At that point the gradients are clipped, one
    optimizer step and one scheduler step are taken, and gradients are reset.

    Args:
        dataloader (torch.utils.data.DataLoader): DataLoader yielding ``(x, y)`` training batches.
        model (torch.nn.Module): The model to train.
        loss_fn (callable): Loss function called as ``loss_fn(pred, y)``; must return a scalar tensor.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler. It is stepped
            once per optimizer step, not once per batch.
        device (torch.device or str): Device each batch is moved to before the forward pass.
        acumulation_grad_steps (int): Number of batches whose gradients are accumulated per
            optimizer step. Must be >= 1. (Spelling kept for backward compatibility.)
        max_grad_norm (float or None): Maximum total gradient norm used for clipping before each
            step; ``None`` disables clipping. Defaults to 1.0.
        flush_partial_accumulation (bool): What to do with trailing batches that do not fill a
            complete accumulation group at the end of the epoch. If True, their gradients are
            rescaled to a mean over the partial group and applied as one extra optimizer/scheduler
            step. If False (default), they are discarded, so the number of scheduler steps per
            epoch stays ``len(dataloader) // acumulation_grad_steps``.

    Returns:
        None: Updates model parameters in-place.

    Raises:
        ValueError: If ``acumulation_grad_steps`` is less than 1.

    Side effects:
        Logs ``learning_rate`` and ``train/loss`` to wandb after every batch. ``train/loss`` is the
        value returned by ``loss_fn`` (not divided by ``acumulation_grad_steps``). Prints progress
        roughly every 10% of the epoch when the dataloader and its dataset report a length.

    TODO:
    - Allow model to return additional losses, e.g., consistency loss, variance loss, etc.
    - It might be useful to also report the "unscaled" loss, i.e., the loss after
      inverse_transform is applied to the predictions and targets.
    """
    if acumulation_grad_steps < 1:
        raise ValueError(
            f"acumulation_grad_steps must be >= 1, got {acumulation_grad_steps}"
        )

    # Set model to training mode (enables dropout, batch norm, etc.)
    model.train()
    # Start from clean gradients: without this, leftovers from a previous call
    # would be folded into this epoch's first update.
    optimizer.zero_grad()

    # Progress reporting needs sized loaders; an iterable-style loader/dataset
    # may not define __len__, in which case progress printing is skipped.
    try:
        num_batches = len(dataloader)
        total_samples = len(dataloader.dataset)
    except (TypeError, AttributeError):
        num_batches = total_samples = None
    print_every = max(1, num_batches // 10) if num_batches else None

    pending_batches = 0  # batches whose gradients are accumulated but not yet applied
    seen_samples = 0

    for batch_idx, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        # Forward pass: compute model predictions
        pred = model(x)

        # Keep the raw loss for reporting; backprop a scaled copy so the
        # accumulated gradient is a mean over the accumulation group.
        batch_loss = loss_fn(pred, y)
        (batch_loss / acumulation_grad_steps).backward()
        pending_batches += 1

        if pending_batches == acumulation_grad_steps:
            _apply_accumulated_update(model, optimizer, scheduler, max_grad_norm)
            pending_batches = 0

        # Return unused blocks cached by PyTorch's CUDA allocator (does nothing if
        # CUDA is not initialized). This cannot free tensors that are still
        # referenced, so it is not a leak fix; it costs some time per batch.
        torch.cuda.empty_cache()

        # .item() synchronizes with the device, so call it once per batch.
        loss_value = batch_loss.item()
        # assumes dim 0 of x is the batch dimension — TODO confirm for all datasets
        seen_samples += len(x)

        # Log training metrics to wandb
        wandb.log(
            {
                "learning_rate": scheduler.get_last_lr()[0],
                "train/loss": loss_value,
            }
        )

        # Print training progress on the first batch and then every ~10% of batches
        if print_every is not None and batch_idx % print_every == 0:
            print(
                f"Samples: {seen_samples:>5d} / {total_samples:>5d}, "
                f"Train Loss: {loss_value:>7f}"
            )

    if pending_batches:
        if flush_partial_accumulation:
            # Tail losses were divided by acumulation_grad_steps; rescale so the
            # applied gradient is the mean over the pending_batches actually seen.
            scale = acumulation_grad_steps / pending_batches
            with torch.no_grad():
                for param in model.parameters():
                    if param.grad is not None:
                        param.grad.mul_(scale)
            _apply_accumulated_update(model, optimizer, scheduler, max_grad_norm)
        else:
            # Discard the incomplete group so it cannot leak into the next epoch.
            optimizer.zero_grad()
