'''
This module contains utilities for training and evaluating models with different loss functions.
'''

import torch
from torch.nn import functional as F
from torch import nn
from tqdm import tqdm

from Losses.loss import cross_entropy, focal_loss, focal_loss_adaptive, dual_focal_loss
from Losses.loss import ce_soft_ece, focal_soft_ece, ce_soft_avuc, focal_soft_avuc
from Losses.loss import mmce, mmce_weighted
from Losses.loss import brier_score


# Registry mapping the ``loss_function`` name accepted by train_single_epoch and
# test_single_epoch to the corresponding callable imported from Losses.loss.
# Every entry is invoked as fn(logits, labels, gamma=..., lamda=..., device=...).
# Entries whose key contains 'mmce' or 'soft' have their result multiplied by
# the batch size inside the train/test helpers.
loss_function_dict = dict(
    # Likelihood / focal-style losses.
    cross_entropy=cross_entropy,
    focal_loss=focal_loss,
    focal_loss_adaptive=focal_loss_adaptive,
    dual_focal_loss=dual_focal_loss,
    # MMCE calibration losses (rescaled by batch size in the helpers).
    mmce=mmce,
    mmce_weighted=mmce_weighted,
    brier_score=brier_score,
    # Soft-ECE / soft-AvUC calibration losses (rescaled by batch size in the helpers).
    ce_soft_ece=ce_soft_ece,
    focal_soft_ece=focal_soft_ece,
    ce_soft_avuc=ce_soft_avuc,
    focal_soft_avuc=focal_soft_avuc,
)


def train_single_epoch(epoch, model, train_loader, optimizer, device, loss_function='cross_entropy', 
                       gamma=1.0, lamda=1.0, loss_mean=False, scheduler=None, scaler=None):
    '''
    Train ``model`` for one pass over ``train_loader`` and return the average
    per-sample training loss.

    Args:
        epoch: Current epoch index. Not used by the computation; kept for
            compatibility with existing callers.
        model: Network being trained. When ``scheduler`` is None the model's
            output is unpacked as ``(logits, _)``; otherwise the output is
            used directly as the logits.
        train_loader: Iterable of ``(data, labels)`` batches; must support
            ``len()`` (used for the progress bar).
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device each batch is moved to.
        loss_function: Key into ``loss_function_dict``.
        gamma, lamda: Hyper-parameters forwarded to the loss function.
        loss_mean: If True, the loss is divided by the batch size before
            backpropagation. The returned average is not affected.
        scheduler: Optional LR scheduler, stepped once after every batch.
        scaler: Optional GradScaler (e.g. ``torch.cuda.amp.GradScaler``). When
            given, backward/step go through it; when None, a plain backward
            pass and ``optimizer.step()`` are used.

    Returns:
        Sum of the batch losses divided by the number of samples seen.
        Losses whose name contains 'mmce' or 'soft' are multiplied by the batch
        size first (presumably they return a batch mean while the others
        return a batch sum -- confirm in Losses.loss).

    Raises:
        KeyError: If ``loss_function`` is not a key of ``loss_function_dict``.
        ZeroDivisionError: If ``train_loader`` yields no samples.
    '''
    # Look the loss up once so an unknown name fails before any training work.
    loss_fn = loss_function_dict[loss_function]
    scale_by_batch = ('mmce' in loss_function) or ('soft' in loss_function)

    model.train()
    train_loss = 0.0
    num_samples = 0
    for data, labels in tqdm(train_loader, total=len(train_loader), bar_format="{l_bar}{bar:30}{r_bar}"):
        data = data.to(device)
        labels = labels.to(device)
        batch_size = len(data)

        optimizer.zero_grad()

        # Same CUDA autocast as torch.cuda.amp.autocast(), without the
        # deprecation warning and explicitly disabled when CUDA is absent.
        with torch.autocast(device_type='cuda', enabled=torch.cuda.is_available()):
            if scheduler is not None:
                logits = model(data)
            else:
                logits, _ = model(data)
            batch_loss = loss_fn(logits, labels, gamma=gamma, lamda=lamda, device=device)
            if scale_by_batch:
                batch_loss = batch_size * batch_loss

        # Accumulate before the optional loss_mean normalisation so the
        # returned value is a per-sample average regardless of loss_mean.
        train_loss += batch_loss.item()
        loss = batch_loss / batch_size if loss_mean else batch_loss

        if scaler is not None:
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradients rather than the scaled ones.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        num_samples += batch_size

    return train_loss / num_samples


def test_single_epoch(epoch, model, test_val_loader, device, loss_function='cross_entropy', gamma=1.0, lamda=1.0, scheduler=None):
    '''
    Evaluate ``model`` on ``test_val_loader`` without gradient tracking and
    return the average per-sample loss.

    Args:
        epoch: Current epoch index. Not used by the computation; kept for
            compatibility with existing callers.
        model: Network to evaluate; put into eval mode. When ``scheduler`` is
            None the model's output is unpacked as ``(logits, _)``; otherwise
            the output is used directly as the logits.
        test_val_loader: Iterable of ``(data, labels)`` batches.
        device: Device each batch is moved to.
        loss_function: Key into ``loss_function_dict``.
        gamma, lamda: Hyper-parameters forwarded to the loss function.
        scheduler: Only its presence matters here (it selects how the model
            output is unpacked); it is never stepped.

    Returns:
        Sum of the batch losses divided by the number of samples seen.
        Losses whose name contains 'mmce' or 'soft' are multiplied by the
        batch size first (presumably they return a batch mean while the
        others return a batch sum -- confirm in Losses.loss).

    Raises:
        KeyError: If ``loss_function`` is not a key of ``loss_function_dict``.
        ZeroDivisionError: If ``test_val_loader`` yields no samples.
    '''
    # Resolve the loss once; an unknown name fails before any evaluation work.
    loss_fn = loss_function_dict[loss_function]
    scale_by_batch = ('mmce' in loss_function) or ('soft' in loss_function)

    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for data, labels in test_val_loader:
            data = data.to(device)
            labels = labels.to(device)
            batch_size = len(data)

            if scheduler is not None:
                logits = model(data)
            else:
                logits, _ = model(data)
            batch_loss = loss_fn(logits, labels, gamma=gamma, lamda=lamda, device=device).item()
            if scale_by_batch:
                batch_loss *= batch_size
            total_loss += batch_loss
            num_samples += batch_size

    return total_loss / num_samples