import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from utils.utils import *

class MCELoss_MMB(nn.Module):
    """Maximum Calibration Error (MCE) computed from answer-token logits.

    Per sample, the confidence is the largest softmax probability and the
    prediction is the arg-max token decoded with ``processor`` and mapped to an
    integer via ``extract_prediction`` / ``str2int`` (from ``utils.utils``).
    Predictions are compared against ``labels``; the result is the largest
    |mean confidence - accuracy| gap over the non-empty confidence bins.
    """

    # Processors whose own ``decode`` is used; any other processor decodes
    # through its ``tokenizer`` attribute.
    # NOTE(review): these look like placeholders for real model paths — confirm.
    PROCESSOR_DECODE_PATHS = ('Qwen-VL-Chat model path', 'InternVL3-8B model path')

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of equal-width confidence bins covering [0, 1]
        """
        super(MCELoss_MMB, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels, processor):
        """Return the MCE of the batch as a tensor of shape (1,).

        Args:
            logits: scores whose last axis is the vocabulary. The decode loop
                treats each row as one predicted token, so this presumably is
                (batch, vocab) — TODO confirm with callers.
            labels: integer targets, tensor or sequence, one per row; moved to
                ``logits.device`` before comparison.
            processor: object exposing ``name_or_path`` and either ``decode``
                (for names in ``PROCESSOR_DECODE_PATHS``) or ``tokenizer.decode``.

        Samples whose confidence lies outside (0, 1] fall in no bin and do not
        contribute. Returns 0 when every bin is empty (e.g. an empty batch).
        """
        # Normalise and take the max over the same (last) axis so the confidence
        # and the predicted token always come from the same distribution.
        softmaxes = F.softmax(logits, dim=-1)
        confidences, max_token_id = torch.max(softmaxes, dim=-1)

        # The choice of decoder depends only on the processor, not on the row.
        if processor.name_or_path in self.PROCESSOR_DECODE_PATHS:
            decode = processor.decode
        else:
            decode = processor.tokenizer.decode

        predictions = []
        for token_id in max_token_id:
            decoded = decode(token_id.unsqueeze(0), skip_special_tokens=True)
            predictions.append(str2int(extract_prediction(decoded)))

        predictions = torch.tensor(predictions, device=logits.device)
        labels = torch.as_tensor(labels, device=logits.device)
        accuracies = predictions.eq(labels)

        mce = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Bins are half-open (lower, upper].
            in_bin = confidences.gt(bin_lower.item()) & confidences.le(bin_upper.item())
            if in_bin.any():
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                gap = torch.abs(avg_confidence_in_bin - accuracy_in_bin)
                mce = torch.max(mce, gap)

        return mce



class MCELoss_VizWiz(nn.Module):
    """Maximum Calibration Error (MCE) using VizWiz-style soft accuracy.

    A sample's correctness is ``min(1, #reference answers matching the
    prediction / 3)`` after normalising both sides with ``clean`` (from
    ``utils.utils``). The result is the largest |mean confidence - mean
    correctness| gap over the non-empty confidence bins.
    """

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of equal-width confidence bins covering [0, 1]
        """
        super(MCELoss_VizWiz, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    @staticmethod
    def _binned_max_gap(confidences, correct, bin_lowers, bin_uppers):
        """Return max over non-empty bins of |mean confidence - mean correct| (0-dim)."""
        conf_col = confidences.unsqueeze(1)  # (n_samples, 1)
        # (n_samples, n_bins): sample i falls in bin j iff lower_j < conf_i <= upper_j
        in_bins_mask = (conf_col > bin_lowers) & (conf_col <= bin_uppers)

        count_per_bin = in_bins_mask.sum(dim=0).float()  # (n_bins,)
        non_empty = count_per_bin > 0
        correct_sum = (correct.unsqueeze(1) * in_bins_mask).sum(dim=0)  # (n_bins,)
        conf_sum = (conf_col * in_bins_mask).sum(dim=0)  # (n_bins,)

        # Empty bins keep 0 for both terms, so they contribute a gap of 0.
        accuracy_in_bin = torch.zeros_like(count_per_bin)
        avg_conf_in_bin = torch.zeros_like(count_per_bin)
        accuracy_in_bin[non_empty] = correct_sum[non_empty] / count_per_bin[non_empty]
        avg_conf_in_bin[non_empty] = conf_sum[non_empty] / count_per_bin[non_empty]
        return torch.abs(avg_conf_in_bin - accuracy_in_bin).max()

    def forward(self, input_list, answers, model, t):
        """Run ``model`` over ``input_list`` and return the MCE as a tensor of shape (1,).

        Args:
            input_list: iterable of dict-like items; only their ``torch.Tensor``
                values are moved to ``model.device`` and passed on.
            answers: one iterable of reference answer strings per input item.
            model: object with a ``device`` attribute and a
                ``decode_outputs(inputs, t)`` method returning
                ``(prediction, confidence)``.
            t: forwarded unchanged to ``model.decode_outputs`` (presumably a
                temperature — TODO confirm).

        Raises:
            ValueError: if ``answers`` has fewer entries than ``input_list``.
                Without this check, ``zip`` would truncate and the per-bin
                arithmetic could silently broadcast misaligned tensors.
        """
        device = model.device
        predictions = []
        confidences = []
        for item in input_list:
            inputs = {k: v.to(device) for k, v in item.items() if isinstance(v, torch.Tensor)}
            pred, conf = model.decode_outputs(inputs, t)
            predictions.append(pred)
            confidences.append(conf)
        # torch.tensor already builds a fresh, grad-free tensor; no clone/detach needed.
        confidences = torch.tensor(confidences).to(device)

        scores = []
        for pred, ans_list in zip(predictions, answers):
            target = clean(pred)
            n_match = sum(1 for ans in ans_list if target == clean(ans))
            scores.append(min(1.0, n_match / 3))
        if len(scores) != len(predictions):
            raise ValueError(
                f"answers has {len(scores)} entries but input_list has {len(predictions)}"
            )
        correct = torch.tensor(scores, dtype=torch.float32, device=device)

        mce = self._binned_max_gap(
            confidences, correct, self.bin_lowers.to(device), self.bin_uppers.to(device)
        )
        return mce.unsqueeze(0)


class MCELoss_VQAv2(nn.Module):
    """Maximum Calibration Error (MCE) for VQAv2 using exact-match correctness.

    A sample is correct when ``clean(prediction) == clean(answer)`` (``clean``
    comes from ``utils.utils``). The result is the largest
    |mean confidence - accuracy| gap over the non-empty confidence bins.
    """

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of equal-width confidence bins covering [0, 1]
        """
        super(MCELoss_VQAv2, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    @staticmethod
    def _binned_max_gap(confidences, correct, bin_lowers, bin_uppers):
        """Return max over non-empty bins of |mean confidence - mean correct| (0-dim)."""
        conf_col = confidences.unsqueeze(1)  # (n_samples, 1)
        # (n_samples, n_bins): sample i falls in bin j iff lower_j < conf_i <= upper_j
        in_bins_mask = (conf_col > bin_lowers) & (conf_col <= bin_uppers)

        count_per_bin = in_bins_mask.sum(dim=0).float()  # (n_bins,)
        non_empty = count_per_bin > 0
        correct_sum = (correct.unsqueeze(1) * in_bins_mask).sum(dim=0)  # (n_bins,)
        conf_sum = (conf_col * in_bins_mask).sum(dim=0)  # (n_bins,)

        # Empty bins keep 0 for both terms, so they contribute a gap of 0.
        accuracy_in_bin = torch.zeros_like(count_per_bin)
        avg_conf_in_bin = torch.zeros_like(count_per_bin)
        accuracy_in_bin[non_empty] = correct_sum[non_empty] / count_per_bin[non_empty]
        avg_conf_in_bin[non_empty] = conf_sum[non_empty] / count_per_bin[non_empty]
        return torch.abs(avg_conf_in_bin - accuracy_in_bin).max()

    def forward(self, input_list, answers, model, t):
        """Run ``model`` over ``input_list`` and return the MCE as a tensor of shape (1,).

        Args:
            input_list: iterable of dict-like items; only their ``torch.Tensor``
                values are moved to ``model.device`` and passed on.
            answers: one reference answer per input item.
            model: object with a ``device`` attribute and a
                ``decode_outputs(inputs, t)`` method returning
                ``(prediction, confidence)``.
            t: forwarded unchanged to ``model.decode_outputs`` (presumably a
                temperature — TODO confirm).

        Raises:
            ValueError: if ``answers`` has fewer entries than ``input_list``.
                Without this check, ``zip`` would truncate and the per-bin
                arithmetic could silently broadcast misaligned tensors.
        """
        device = model.device
        predictions = []
        confidences = []
        for item in input_list:
            inputs = {k: v.to(device) for k, v in item.items() if isinstance(v, torch.Tensor)}
            pred, conf = model.decode_outputs(inputs, t)
            predictions.append(pred)
            confidences.append(conf)
        # torch.tensor already builds a fresh, grad-free tensor; no clone/detach needed.
        confidences = torch.tensor(confidences).to(device)

        matches = [clean(p) == clean(a) for p, a in zip(predictions, answers)]
        if len(matches) != len(predictions):
            raise ValueError(
                f"answers has {len(matches)} entries but input_list has {len(predictions)}"
            )
        correct = torch.tensor(matches, dtype=torch.float32, device=device)

        mce = self._binned_max_gap(
            confidences, correct, self.bin_lowers.to(device), self.bin_uppers.to(device)
        )
        return mce.unsqueeze(0)
