import torch
import torch.nn as nn
import wandb

'''
Training and validation functions for the least-core model
'''

def train(model, loader, optimizer, loss_fn, device, LOG=False):
    '''
    Run one pass over ``loader``, updating the model after every batch
    :param model: torch module; called on each input batch and unpacked into
                  a (payoffs, epsilon) pair of predictions
    :param loader: iterable of (inputs, targets) batches that supports len()
    :param optimizer: torch optimizer stepped once per batch
    :param loss_fn: callable taking (prediction, target) and returning a scalar loss tensor
    :param device: device the batches are moved to before the forward pass
    :param LOG: whether to log inputs, predictions and batch loss with wandb
    :return: sum of the batch losses divided by len(loader)
    '''

    # Put the model in training mode before touching any batch
    model.train()

    batch_losses = []
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # The two prediction heads are joined so they can be compared
        # against the single stacked target tensor.
        payoffs, epsilon = model(inputs)
        prediction = torch.hstack((payoffs, epsilon))

        batch_loss = loss_fn(prediction, targets)
        loss_value = batch_loss.item()
        batch_losses.append(loss_value)

        if LOG:
            metrics = {
                'Inputs [train]': inputs,
                'Pred payoffs [train]': payoffs,
                'Pred epsilon [train]': epsilon,
                'Batch loss   [train]': loss_value,
            }
            wandb.log(metrics)

        # Clear stale gradients, backpropagate, then update the parameters
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    return sum(batch_losses) / len(loader)


def validate(model, loader, loss_fn, device, LOG=False):
    '''
    Validation loop
    The model is switched to evaluation mode for the duration of the loop and
    put back into whatever mode it was in when the function returns.
    :param model: torch object containing the model architecture
    :param loader: data loader (must support len())
    :param loss_fn: error metric
    :param device: cuda or cpu device to work on
    :param LOG: whether to log loss and predictions with wandb
    :return: average batch loss (over epoch)
    '''

    # Remember the caller's mode so validation does not leave a side effect
    was_training = model.training

    # Evaluation mode: layers such as dropout / batch norm must not behave
    # stochastically or update running statistics while validating.
    model.eval()

    cumu_loss = 0
    try:
        with torch.no_grad():
            for x_stack, sol_stack in loader:

                # Move data to device
                x_stack, sol_stack = x_stack.to(device), sol_stack.to(device)

                # Forward pass
                y_pred, eps_pred = model(x_stack)

                # Stack and compute loss
                sol_pred = torch.hstack((y_pred, eps_pred))
                loss = loss_fn(sol_pred, sol_stack)
                batch_loss = loss.item()
                cumu_loss += batch_loss

                if LOG:
                    wandb.log({'Inputs [val]'        : x_stack,
                               'Pred payoffs [val]'  : y_pred,
                               'Pred epsilon [val]'  : eps_pred,
                               'Batch loss   [val]'  : batch_loss
                            })
    finally:
        # Restore the original mode even if a batch raised
        model.train(was_training)

    return cumu_loss / len(loader)