import torch
from torch import nn, optim
from torch.nn import functional as F
from utils.utils import *

class ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error (ECE) of a model whose
    prediction is obtained by decoding the arg-max token with a processor.

    The input to this loss is the logits of a model, NOT the softmax scores.

    The confidence outputs are divided into ``n_bins`` equal-width bins of
    the form (lower, upper] covering (0, 1]. In each bin we compute the
    confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    and return the average of the gaps weighted by the fraction of samples
    that fall in each bin.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels, processor):
        """
        Compute the ECE for one batch.

        Args:
            logits: raw model scores. Softmax and arg-max are both taken over
                the last dimension, and one prediction is decoded per entry of
                the first dimension, so this presumably is (batch, vocab) —
                TODO confirm against callers.
            labels: ground-truth integer labels compared element-wise with the
                decoded predictions. A tensor is moved to ``logits.device``.
            processor: object whose ``decode(token_ids, skip_special_tokens=True)``
                turns a token-id tensor into text.

        Returns:
            torch.Tensor: shape (1,), on ``logits.device``, holding the ECE.
        """
        # Normalise and take the arg-max over the SAME (last) axis; mixing
        # dim=1 and dim=-1 only agrees for 2-D inputs.
        softmaxes = F.softmax(logits, dim=-1)
        confidences, max_token_id = torch.max(softmaxes, dim=-1)

        # Decode every arg-max token and map the text to an integer label.
        # str2int / extract_prediction come from utils.utils; presumably they
        # parse the answer text into a class index — verify there.
        predictions = []
        for i in range(max_token_id.shape[0]):
            token_seq = max_token_id[i].unsqueeze(0)
            decoded = processor.decode(token_seq, skip_special_tokens=True)
            predictions.append(str2int(extract_prediction(decoded)))

        predictions = torch.tensor(predictions, device=logits.device)
        # Keep labels on the same device as the confidence masks; boolean
        # indexing below fails when tensors live on different devices.
        if torch.is_tensor(labels):
            labels = labels.to(logits.device)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Samples whose confidence lies in (bin_lower, bin_upper].
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

class ECELoss_MMB(nn.Module):
    """
    Calculates the Expected Calibration Error (ECE) of a model whose
    prediction is obtained by decoding the arg-max token with the
    processor's tokenizer (``processor.tokenizer.decode``).

    The input to this loss is the logits of a model, NOT the softmax scores.

    The confidence outputs are divided into ``n_bins`` equal-width bins of
    the form (lower, upper] covering (0, 1]. In each bin we compute the
    confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    and return the average of the gaps weighted by the fraction of samples
    that fall in each bin.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECELoss_MMB, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels, processor):
        """
        Compute the ECE for one batch.

        Args:
            logits: raw model scores. Softmax and arg-max are both taken over
                the last dimension, and one prediction is decoded per entry of
                the first dimension, so this presumably is (batch, vocab) —
                TODO confirm against callers.
            labels: ground-truth integer labels compared element-wise with the
                decoded predictions. A tensor is moved to ``logits.device``.
            processor: object exposing a ``tokenizer`` attribute whose
                ``decode(token_ids, skip_special_tokens=True)`` returns text.

        Returns:
            torch.Tensor: shape (1,), on ``logits.device``, holding the ECE.
        """
        # Normalise and take the arg-max over the SAME (last) axis; mixing
        # dim=1 and dim=-1 only agrees for 2-D inputs.
        softmaxes = F.softmax(logits, dim=-1)
        confidences, max_token_id = torch.max(softmaxes, dim=-1)

        # Decode every arg-max token and map the text to an integer label.
        # str2int / extract_prediction come from utils.utils; presumably they
        # parse the answer text into a class index — verify there.
        predictions = []
        for i in range(max_token_id.shape[0]):
            token_seq = max_token_id[i].unsqueeze(0)
            decoded = processor.tokenizer.decode(token_seq, skip_special_tokens=True)
            predictions.append(str2int(extract_prediction(decoded)))

        predictions = torch.tensor(predictions, device=logits.device)
        # Keep labels on the same device as the confidence masks; boolean
        # indexing below fails when tensors live on different devices.
        if torch.is_tensor(labels):
            labels = labels.to(logits.device)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Samples whose confidence lies in (bin_lower, bin_upper].
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece


class ECELoss_SM(nn.Module):
    """
    Calculates the Expected Calibration Error (ECE) from class probabilities.

    The input to this loss is ALREADY a probability distribution per sample
    (e.g. softmax scores); no softmax is applied here. The prediction is the
    arg-max over dim 1 and the confidence is the corresponding probability.

    Confidences are divided into ``n_bins`` equal-width bins of the form
    (lower, upper] covering (0, 1] (a confidence of exactly 0 lands in no
    bin). In each bin we compute the confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    and return the average of the gaps weighted by the fraction of samples
    that fall in each bin.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECELoss_SM, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        """
        Compute the ECE for one batch.

        Args:
            logits: despite the name, a tensor of PROBABILITIES with classes on
                dim 1; it is used as-is.
            labels: ground-truth class indices compared element-wise with the
                arg-max predictions. A tensor is moved to the predictions'
                device.

        Returns:
            torch.Tensor: shape (1,), on ``logits.device``, holding the ECE.
        """
        probs = logits  # already probabilities; see docstring
        confidences, predictions = torch.max(probs, 1)
        # Align devices so the boolean masks below can index ``accuracies``.
        if torch.is_tensor(labels):
            labels = labels.to(predictions.device)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=probs.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Samples whose confidence lies in (bin_lower, bin_upper].
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece


class ECELoss_Top1(nn.Module):
    """
    Calculates the Expected Calibration Error (ECE) with externally supplied
    confidences.

    The prediction is the arg-max over dim 1 of ``softmaxes``; the confidence
    for each sample is taken from the separate ``confidences`` tensor (not
    recomputed from ``softmaxes``). No softmax is applied here.

    Confidences are divided into ``n_bins`` equal-width bins of the form
    (lower, upper] covering (0, 1] (a confidence of exactly 0 lands in no
    bin). In each bin we compute the confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    and return the average of the gaps weighted by the fraction of samples
    that fall in each bin.

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECELoss_Top1, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, softmaxes, confidences, labels):
        """
        Compute the ECE for one batch.

        Args:
            softmaxes: per-sample scores with classes on dim 1; only the
                arg-max is used.
            confidences: one confidence per sample, binned directly. Moved to
                ``softmaxes.device``.
            labels: ground-truth class indices. A tensor is moved to
                ``softmaxes.device``.

        Returns:
            torch.Tensor: shape (1,), on ``softmaxes.device``, holding the ECE.
        """
        device = softmaxes.device
        _, predictions = torch.max(softmaxes, 1)
        # Put every operand on one device so mask indexing works.
        confidences = confidences.to(device)
        if torch.is_tensor(labels):
            labels = labels.to(device)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Samples whose confidence lies in (bin_lower, bin_upper].
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece
    

class ECELoss_Text(nn.Module):
    """
    Calculates the Expected Calibration Error (ECE) of a model that only
    processes text (e.g. Qwen), decoding the arg-max token with a tokenizer.

    The input is the model's logits, NOT softmax scores. Confidences are
    binned into ``n_bins`` equal-width (lower, upper] intervals over (0, 1]
    and the ECE is the bin-size-weighted mean of
    | avg_confidence_in_bin - accuracy_in_bin |.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECELoss_Text, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels, tokenizer):
        """
        Compute the ECE for one batch.

        Args:
            logits: raw model scores; softmax and arg-max are taken over the
                last dimension and one prediction is decoded per entry of the
                first dimension (presumably (batch, vocab) — TODO confirm).
            labels: ground-truth integer labels. A tensor is moved to the
                device of ``logits``.
            tokenizer: object whose ``decode(token_ids, skip_special_tokens=True)``
                returns text; it receives CPU tensors.

        Returns:
            torch.Tensor: shape (1,), on ``logits.device``, holding the ECE.
        """
        # Normalise and take the arg-max over the SAME (last) axis.
        softmaxes = F.softmax(logits, dim=-1)
        confidences, max_token_id = torch.max(softmaxes, dim=-1)

        # One device transfer for the whole batch instead of one per sample.
        token_ids = max_token_id.cpu()
        predictions = []
        for i in range(token_ids.shape[0]):
            decoded = tokenizer.decode(token_ids[i], skip_special_tokens=True)
            # str2int / extract_prediction come from utils.utils; presumably
            # they parse the answer text into a class index — verify there.
            predictions.append(str2int(extract_prediction(decoded)))

        # Everything compared/indexed below must share the confidences'
        # device; mixing it with labels.device breaks ``accuracies[in_bin]``.
        device = confidences.device
        predictions = torch.tensor(predictions, device=device)
        if torch.is_tensor(labels):
            labels = labels.to(device)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Samples whose confidence lies in (bin_lower, bin_upper].
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece


class ECELoss_VQAv2(nn.Module):
    """
    Calculate the Expected Calibration Error (ECE) for VQAv2.

    Each item of ``input_list`` is decoded by ``model.decode_outputs`` into a
    (prediction, confidence) pair. A prediction counts as correct (1.0) when
    ``clean(prediction) == clean(answer)``, else 0.0. Confidences are binned
    into ``n_bins`` equal-width (lower, upper] intervals over (0, 1] and the
    ECE is the bin-size-weighted mean of |avg_confidence - accuracy|.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECELoss_VQAv2, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, input_list, answers, model, t):
        """
        Compute the ECE over a list of inputs.

        Args:
            input_list: iterable of dict-like items; only their tensor values
                are moved to ``model.device`` and passed on.
            answers: one reference answer per item, compared after ``clean``
                (``clean`` comes from utils.utils).
            model: object exposing ``device`` and
                ``decode_outputs(inputs, t) -> (prediction, confidence)``.
            t: forwarded unchanged to ``model.decode_outputs`` (presumably a
                temperature — confirm).

        Returns:
            torch.Tensor: shape (1,) holding the ECE; 0 for an empty
            ``input_list`` (previously NaN from averaging an empty tensor).
        """
        device = model.device
        predictions = []
        confidences = []
        for item in input_list:
            inputs = {k: v.to(device) for k, v in item.items() if isinstance(v, torch.Tensor)}
            pred, conf = model.decode_outputs(inputs, t)
            predictions.append(pred)
            # float() accepts Python numbers and 0-d tensors on any device.
            confidences.append(float(conf))
        confidences = torch.tensor(confidences, dtype=torch.float32, device=device)

        correct = torch.tensor(
            [clean(p) == clean(a) for p, a in zip(predictions, answers)],
            dtype=torch.float32,
            device=device,
        )

        bin_lowers = self.bin_lowers.to(device)
        bin_uppers = self.bin_uppers.to(device)

        # (n_samples, n_bins) membership: lower < conf <= upper.
        conf_col = confidences.unsqueeze(1)
        in_bins = ((conf_col > bin_lowers) & (conf_col <= bin_uppers)).float()

        count_per_bin = in_bins.sum(dim=0)  # (n_bins,)
        # Guard the divisor so an empty batch yields 0 instead of NaN.
        prop_in_bin = count_per_bin / max(confidences.numel(), 1)
        correct_sum = (correct.unsqueeze(1) * in_bins).sum(dim=0)
        conf_sum = (conf_col * in_bins).sum(dim=0)

        # Empty bins have zero sums and zero weight; clamping just avoids 0/0.
        safe_count = count_per_bin.clamp(min=1)
        accuracy_in_bin = correct_sum / safe_count
        avg_conf_in_bin = conf_sum / safe_count

        ece = (torch.abs(avg_conf_in_bin - accuracy_in_bin) * prop_in_bin).sum()
        return ece.unsqueeze(0)


class ECELoss_VizWiz(nn.Module):
    """
    Calculate the Expected Calibration Error (ECE) for VizWiz.

    Each item of ``input_list`` is decoded by ``model.decode_outputs`` into a
    (prediction, confidence) pair. A prediction's correctness is
    ``min(1, matches / 3)``, where ``matches`` counts the reference answers
    equal to it after ``clean``. Confidences are binned into ``n_bins``
    equal-width (lower, upper] intervals over (0, 1] and the ECE is the
    bin-size-weighted mean of |avg_confidence - accuracy|.
    """
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECELoss_VizWiz, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, input_list, answers, model, t):
        """
        Compute the ECE over a list of inputs.

        Args:
            input_list: iterable of dict-like items; only their tensor values
                are moved to ``model.device`` and passed on.
            answers: for each item, an iterable of reference answers
                (compared after ``clean``, which comes from utils.utils).
            model: object exposing ``device`` and
                ``decode_outputs(inputs, t) -> (prediction, confidence)``.
            t: forwarded unchanged to ``model.decode_outputs`` (presumably a
                temperature — confirm).

        Returns:
            torch.Tensor: shape (1,) holding the ECE; 0 for an empty
            ``input_list`` (previously NaN from averaging an empty tensor).
        """
        device = model.device
        predictions = []
        confidences = []
        for item in input_list:
            inputs = {k: v.to(device) for k, v in item.items() if isinstance(v, torch.Tensor)}
            pred, conf = model.decode_outputs(inputs, t)
            predictions.append(pred)
            # float() accepts Python numbers and 0-d tensors on any device.
            confidences.append(float(conf))
        confidences = torch.tensor(confidences, dtype=torch.float32, device=device)

        scores = []
        for pred, ans_list in zip(predictions, answers):
            cleaned_pred = clean(pred)  # normalise once, not per reference
            matches = sum(1 for ans in ans_list if cleaned_pred == clean(ans))
            scores.append(min(1.0, matches / 3))
        correct = torch.tensor(scores, dtype=torch.float32, device=device)

        bin_lowers = self.bin_lowers.to(device)
        bin_uppers = self.bin_uppers.to(device)

        # (n_samples, n_bins) membership: lower < conf <= upper.
        conf_col = confidences.unsqueeze(1)
        in_bins = ((conf_col > bin_lowers) & (conf_col <= bin_uppers)).float()

        count_per_bin = in_bins.sum(dim=0)  # (n_bins,)
        # Guard the divisor so an empty batch yields 0 instead of NaN.
        prop_in_bin = count_per_bin / max(confidences.numel(), 1)
        correct_sum = (correct.unsqueeze(1) * in_bins).sum(dim=0)
        conf_sum = (conf_col * in_bins).sum(dim=0)

        # Empty bins have zero sums and zero weight; clamping just avoids 0/0.
        safe_count = count_per_bin.clamp(min=1)
        accuracy_in_bin = correct_sum / safe_count
        avg_conf_in_bin = conf_sum / safe_count

        ece = (torch.abs(avg_conf_in_bin - accuracy_in_bin) * prop_in_bin).sum()
        return ece.unsqueeze(0)
