import torch
import torch.nn as nn
import wandb

def train(model, device, train_loader, optimizer, loss_fn, LOG=True):
    '''
    Run one pass of optimisation over every batch in the loader.

    For each batch the tensors are moved to ``device``, gradients are reset,
    the model is evaluated, the loss is back-propagated and the optimizer
    takes one step.

    :param model: torch object containing the model architecture
    :param device: cuda or cpu device to work on
    :param train_loader: data loader yielding (x_batch, y_batch) pairs; must support len()
    :param optimizer: torch optimizer
    :param loss_fn: error metric, called as loss_fn(y_batch, y_hat_batch)
    :param LOG: whether to log loss and predictions with wandb
    :return: average batch loss (over epoch)
    '''
    # Put the model in training mode before touching any batch.
    model.train()

    running_total = 0
    for inputs, targets in train_loader:
        # Move the batch onto the working device.
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss (targets first, predictions second).
        predictions = model(inputs)
        batch_loss = loss_fn(targets, predictions)

        # Read the scalar once and reuse it for the total and the log.
        loss_value = batch_loss.item()
        running_total += loss_value

        if LOG:
            metrics = {
                'Train batch loss': loss_value,
                'Train Predictions': predictions,
            }
            wandb.log(metrics)

        # Backward pass followed by the parameter update.
        batch_loss.backward()
        optimizer.step()

    n_batches = len(train_loader)
    return running_total / n_batches


def validate(model, device, val_loader, loss_fn, LOG=True):
    '''
    Validation loop

    Evaluates the model on every batch of ``val_loader`` with gradient
    tracking disabled and the model switched to evaluation mode. The
    training/evaluation mode the model had on entry is restored before
    returning, so the call leaves the model's mode unchanged.

    :param model: torch object containing the model architecture
    :param device: cuda or cpu device to work on
    :param val_loader: data loader yielding (x_batch, y_batch) pairs; must support len()
    :param loss_fn: error metric, called as loss_fn(y_batch, y_hat_batch)
    :param LOG: whether to log loss and predictions with wandb
    :return: average batch loss (over epoch)
    :raises ZeroDivisionError: if val_loader has length zero
    '''
    # Remember the caller's mode so it can be put back afterwards.
    was_training = model.training

    # Evaluation mode: layers such as dropout and batch-norm must not behave
    # as in training, otherwise the validation loss is not representative.
    model.eval()
    try:
        cumu_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in val_loader:
                # Transfer tensors to GPU
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)

                # Forward pass
                y_hat_batch = model(x_batch)

                # Compute loss (read the scalar once; reused below)
                batch_loss = loss_fn(y_batch, y_hat_batch).item()

                # Store batch errors
                cumu_loss += batch_loss

                # Log
                if LOG:
                    wandb.log({'Val batch loss'    : batch_loss,
                               'Val predictions'   : y_hat_batch,
                            })

        return cumu_loss / len(val_loader)
    finally:
        # Restore whichever mode the model was in when we were called.
        model.train(was_training)