"""
This module contains helpers for training and evaluating models, and for building model save names.
"""

import torch
from torch.nn import functional as F
from torch import nn
import numpy as np


def train_single_epoch(epoch, model, train_loader, optimizer, device, delta=0.5):
    """
    Util method for training a crenet model for a single epoch.

    The model output is split in half along the last dimension. The first half
    is the lower head (``preds_lo``) and the second half is the upper head
    (``preds_up``). Both heads are trained with cross-entropy. For the lower
    head, only the ``delta`` fraction of samples with the largest per-sample
    loss in each batch contribute to the loss.

    Args:
        epoch: Current epoch index. Not used by the computation; kept for
            API compatibility.
        model: Module whose output last dimension holds the lower logits
            followed by the upper logits.
        train_loader: Iterable yielding ``(data, labels)`` batches.
        optimizer: Optimizer stepping the model's parameters.
        device: Device each batch is moved to.
        delta: Fraction of each batch, ranked by lower-head loss (highest
            first), used for the lower-head loss. The count is clamped so that
            at least one and at most all samples are used. Samples tied with
            the cut-off loss are also included. Defaults to 0.5.

    Returns:
        Tuple ``(mean_upper_loss, mean_lower_loss)`` averaged over batches.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    train_loss_up = 0.0
    train_loss_lo = 0.0
    num_batches = 0

    for data, labels in train_loader:
        data = data.to(device)
        labels = labels.to(device)
        batch_size = len(data)

        optimizer.zero_grad()
        preds = model(data)

        # Split the logits into the lower and upper probability heads.
        half = preds.shape[-1] // 2
        preds_lo = preds[:, :half]
        preds_up = preds[:, half:]

        # Per-sample lower-head loss, used to rank samples by difficulty.
        loss_lo = F.cross_entropy(preds_lo, labels, reduction='none')

        # Keep the k hardest samples, with k clamped to [1, batch_size].
        # Without the lower clamp, a batch where delta * batch_size < 1 gives
        # index -1. That picks the *smallest* loss as the threshold and
        # silently selects the entire batch.
        k = min(batch_size, max(1, int(np.floor(delta * batch_size))))
        bound_value = torch.topk(loss_lo.detach(), k).values[-1]
        # ">=" also keeps samples tied with the k-th largest loss.
        choose_index = loss_lo >= bound_value

        loss_lo_mod = F.cross_entropy(
            preds_lo[choose_index], labels[choose_index], reduction='mean'
        )
        loss_up = F.cross_entropy(preds_up, labels, reduction='mean')

        loss = loss_lo_mod + loss_up
        loss.backward()
        optimizer.step()

        train_loss_up += loss_up.item()
        train_loss_lo += loss_lo_mod.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    return train_loss_up / num_batches, train_loss_lo / num_batches


def test_single_epoch(epoch, model, test_val_loader, device, loss_function="cross_entropy"):
    """
    Util method for testing a model for a single epoch.

    Evaluates the lower and upper heads of ``model`` with cross-entropy
    (gradients disabled). The output is split in half along its last
    dimension: lower logits first, then upper logits, matching the split used
    in training.

    Args:
        epoch: Current epoch index. Not used by the computation.
        model: Module whose output last dimension holds the lower logits
            followed by the upper logits.
        test_val_loader: Iterable yielding ``(data, labels)`` batches.
        device: Device each batch is moved to.
        loss_function: Currently ignored; cross-entropy is always used.

    Returns:
        Tuple ``(mean_upper_loss, mean_lower_loss)`` averaged over batches.

    Raises:
        ValueError: If ``test_val_loader`` yields no batches.
    """
    model.eval()
    loss_lo = 0.0
    loss_up = 0.0
    num_batches = 0

    with torch.no_grad():
        for data, labels in test_val_loader:
            data = data.to(device)
            labels = labels.to(device)

            preds = model(data)

            # Split on the model's output width, as in training. Splitting on
            # labels.size(-1) would use the batch size for class-index
            # targets instead of the class count.
            half = preds.shape[-1] // 2
            preds_lo = preds[:, :half]
            preds_up = preds[:, half:]

            loss_lo += F.cross_entropy(preds_lo, labels, reduction='mean').item()
            loss_up += F.cross_entropy(preds_up, labels, reduction='mean').item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test_val_loader yielded no batches")

    print("======> Test set upper loss: {:.4f} and lower loss: {:.4f}".format(loss_up / num_batches, loss_lo / num_batches))

    return loss_up / num_batches, loss_lo / num_batches


# This is an evaluation utility, not a pytest test. Its "test_" prefix would
# otherwise make pytest try to collect it when it is imported into a test module.
test_single_epoch.__test__ = False


def model_save_name(model_name, seed):
    """Return the save name for a model: its name and seed joined by an underscore."""
    name_parts = (str(model_name), str(seed))
    return "_".join(name_parts)