'''
This module contains methods for training models with different loss functions.
'''

import torch
from torch.nn import functional as F
from torch import nn
from tqdm import tqdm
import numpy as np

from Losses.loss import cross_entropy, focal_loss, focal_loss_adaptive, dual_focal_loss
from Losses.loss import ce_soft_ece, focal_soft_ece, ce_soft_avuc, focal_soft_avuc
from Losses.loss import mmce, mmce_weighted
from Losses.loss import brier_score
from Losses.soft_ece import SoftBinnedECE, SoftAvUCLoss
from Losses.loss import bsce_gra



# Registry mapping the loss-name strings accepted by the `loss_function`
# argument of train_single_epoch / test_single_epoch to the loss callables
# imported above.
# Callers in this module invoke every entry as
#     fn(logits, labels, gamma=..., lamda=..., device=...)
# and multiply the result by the batch size when the key contains 'mmce' or
# 'soft'.
loss_function_dict = {
    'cross_entropy': cross_entropy,
    'focal_loss': focal_loss,
    'focal_loss_adaptive': focal_loss_adaptive,
    'dual_focal_loss': dual_focal_loss,
    'mmce': mmce,
    'mmce_weighted': mmce_weighted,
    'brier_score': brier_score,
    'ce_soft_ece': ce_soft_ece,
    'focal_soft_ece': focal_soft_ece,
    'ce_soft_avuc': ce_soft_avuc,  # Alias for backward compatibility
    'focal_soft_avuc': focal_soft_avuc,  # Alias for backward compatibility
    # 'soft_binned_ece':SoftBinnedECE,
    'bsce_gra': bsce_gra,
}

def compute_gradient_conflict_and_update(model, loss_main, loss_calib, optimizer, scaler=None):
    '''
    Populate ``.grad`` with the main-loss gradient, removing from the
    classifier head (``model.fc``) any component that conflicts with the
    calibration-loss gradient.

    Let ``G_main`` and ``G_calib`` be the gradients of ``loss_main`` and
    ``loss_calib`` w.r.t. the trainable parameters of ``model.fc``, flattened
    into vectors. If ``G_main . G_calib < 0`` the head gradient becomes

        G_final = G_main - ((G_main . G_calib) / (||G_calib||^2 + 1e-8)) * G_calib

    otherwise ``G_final = G_main``. Parameters outside ``model.fc`` keep the
    gradient of ``loss_main`` produced by the backward pass. This function
    does not call ``optimizer.step()``; the caller is expected to do so.

    Args:
        model: module exposing an ``fc`` sub-module.
        loss_main: scalar tensor, the primary training loss.
        loss_calib: scalar tensor, the calibration loss; may share its
            autograd graph with ``loss_main``.
        optimizer: optimizer whose ``zero_grad()`` clears stale gradients.
        scaler: unused; accepted for interface compatibility.

    Returns:
        float: the dot product ``G_main . G_calib``, or ``0.0`` when
        ``model.fc`` has no trainable parameters (in that case nothing is
        zeroed or back-propagated).
    '''
    trainable_params = [p for p in model.fc.parameters() if p.requires_grad]
    if not trainable_params:
        print("No trainable parameters found in the model.")
        return 0.0

    # Main-loss gradients for the whole model; keep the graph alive because
    # loss_calib may be built on the same forward pass.
    optimizer.zero_grad()
    loss_main.backward(retain_graph=True)
    main_grads = [
        p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
        for p in trainable_params
    ]

    # Calibration gradients are needed for the head only. torch.autograd.grad
    # returns them without touching any parameter's .grad, so parameters
    # outside the head keep their main-loss gradients for the optimizer step.
    calib_raw = torch.autograd.grad(loss_calib, trainable_params, allow_unused=True)
    calib_grads = [
        g.detach() if g is not None else torch.zeros_like(p)
        for g, p in zip(calib_raw, trainable_params)
    ]

    main_flat = torch.cat([g.reshape(-1) for g in main_grads])
    calib_flat = torch.cat([g.reshape(-1) for g in calib_grads])
    dot_product = torch.dot(main_flat, calib_flat)

    if dot_product < 0:
        # Conflict: subtract the projection of G_main onto G_calib.
        calib_norm_sq = torch.dot(calib_flat, calib_flat)
        projection_coeff = dot_product / (calib_norm_sq + 1e-8)
        final_grad_flat = main_flat - projection_coeff * calib_flat
    else:
        final_grad_flat = main_flat

    # Scatter the flat vector back into per-parameter gradients. Each
    # parameter gets its own storage rather than a view of the flat buffer.
    offset = 0
    for p in trainable_params:
        count = p.numel()
        p.grad = final_grad_flat[offset:offset + count].view_as(p).detach().clone()
        offset += count

    return dot_product.item()


def train_single_epoch(epoch, model, train_loader, optimizer, device, loss_function='cross_entropy',
                       gamma=1.0, lamda=1.0, loss_mean=False,
                       use_gradient_correction=False, gradient_func='softece', num_classes=10, scaler=None):
    '''
    Train ``model`` for one pass over ``train_loader``.

    Each batch is either ``(data, labels)`` or ``(data, labels, is_corrupted)``.
    In the three-element form, samples flagged in ``is_corrupted`` are excluded
    from the calibration loss; the main loss always uses the full batch.

    Args:
        epoch: unused in this function.
        model: module whose forward pass returns a ``(logits, extra)`` pair.
        train_loader: iterable of batches supporting ``len()``.
        optimizer: optimizer stepped once per batch.
        device: device the batch tensors are moved to; also forwarded to the
            loss function.
        loss_function: key into ``loss_function_dict``.
        gamma, lamda: forwarded to the selected loss function.
        loss_mean: when true, divide the main loss by the batch size (and the
            calibration loss by the number of clean samples).
        use_gradient_correction: when true and the batch has at least one
            clean sample, gradients are produced by
            ``compute_gradient_conflict_and_update`` instead of a plain
            backward pass on the main loss.
        gradient_func: ``'softece'`` selects ``SoftBinnedECE`` as the
            calibration loss; any other value selects ``SoftAvUCLoss``.
        num_classes: unused in this function.
        scaler: only forwarded to ``compute_gradient_conflict_and_update``.

    Returns:
        The sum of the per-batch main-loss values divided by the number of
        samples seen.
    '''
    model.train()
    train_loss = 0
    num_samples = 0

    # Losses whose name contains 'mmce' or 'soft' are rescaled by batch size.
    scale_by_batch = ('mmce' in loss_function) or ('soft' in loss_function)
    calib_loss_fn = SoftBinnedECE if gradient_func == 'softece' else SoftAvUCLoss
    # Autocast is enabled only when training on a CUDA device; elsewhere the
    # context manager is a no-op.
    # NOTE(review): no GradScaler is applied to the backward pass — confirm
    # that reduced-precision gradients are acceptable here.
    use_amp = torch.device(device).type == 'cuda'

    for batch_data in tqdm(train_loader, total=len(train_loader), bar_format="{l_bar}{bar:30}{r_bar}"):
        if len(batch_data) == 3:  # (data, labels, is_corrupted)
            data, labels, is_corrupted = batch_data
            # Coerce to bool first: `~` on an integer tensor is a bitwise NOT
            # and would yield negative indices instead of a mask.
            clean_mask = ~torch.as_tensor(is_corrupted, dtype=torch.bool)
        else:
            data, labels = batch_data
            clean_mask = torch.ones(len(data), dtype=torch.bool)

        batch_size = len(data)
        # Count clean samples once, before the mask is moved to the device.
        num_clean = int(clean_mask.sum().item())
        apply_correction = bool(use_gradient_correction) and num_clean > 0

        data = data.to(device)
        labels = labels.to(device)
        clean_mask = clean_mask.to(device)

        with torch.autocast(device_type='cuda', enabled=use_amp):
            logits, _ = model(data)

            loss_main = loss_function_dict[loss_function](
                logits, labels, gamma=gamma, lamda=lamda, device=device)
            if scale_by_batch:
                loss_main = batch_size * loss_main

            # The calibration loss is only consumed by the gradient
            # correction step, so skip it when that step will not run.
            if apply_correction:
                loss_calib = calib_loss_fn(logits[clean_mask], labels[clean_mask])

        if loss_mean:
            loss_main = loss_main / batch_size
            if apply_correction:
                loss_calib = loss_calib / num_clean

        if apply_correction:
            # Leaves the corrected gradients in .grad; stepping is done here.
            compute_gradient_conflict_and_update(model, loss_main, loss_calib, optimizer, scaler)
        else:
            optimizer.zero_grad()
            loss_main.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
        optimizer.step()

        train_loss += loss_main.item()
        num_samples += batch_size

    return train_loss / num_samples


def test_single_epoch(epoch, model, test_val_loader, device, loss_function='cross_entropy', gamma=1.0, lamda=1.0, scheduler=None,
                      cal_loss_function='soft_binned_ece', cal_gamma=1.0, cal_lamda=1.0, num_classes=10):
    '''
    Evaluate ``model`` over ``test_val_loader`` without gradient tracking.

    The loss selected by ``loss_function`` is accumulated over all batches
    (multiplied by the batch size when the name contains 'mmce' or 'soft'),
    together with the ``SoftBinnedECE`` and ``SoftAvUCLoss`` values. All three
    totals are divided by the number of samples, printed, and the averaged
    loss is returned.

    ``scheduler`` acts only as a switch here: when it is not None the model
    output is used directly as the logits, otherwise the output is unpacked
    as a ``(logits, extra)`` pair. ``cal_loss_function``, ``cal_gamma``,
    ``cal_lamda`` and ``num_classes`` are not used by this function.
    '''
    model.eval()
    loss, num_samples = 0, 0
    softECE_train = softAvUC_trian = 0

    with torch.no_grad():
        for data, labels in test_val_loader:
            data, labels = data.to(device), labels.to(device)

            if scheduler is None:
                logits, _ = model(data)
            else:
                logits = model(data)

            batch_loss = loss_function_dict[loss_function](
                logits, labels, gamma=gamma, lamda=lamda, device=device).item()
            if ('mmce' in loss_function) or ('soft' in loss_function):
                batch_loss = len(data) * batch_loss
            loss += batch_loss

            # NOTE(review): these per-batch values are summed and later divided
            # by the sample count — confirm this matches the reduction the two
            # metrics use internally.
            softECE_train += SoftBinnedECE(logits, labels).item()
            softAvUC_trian += SoftAvUCLoss(logits, labels).item()
            num_samples += len(data)

    print(f'====> Epoch: {epoch} test loss: {(loss / num_samples):.4f}, softECE: {(softECE_train / num_samples):.4f}, softAvUC: {(softAvUC_trian / num_samples):.4f}')
    return loss / num_samples