import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from tqdm import tqdm


class DistanseAwareCalibration():
    """Distance-aware confidence calibration of new-class logits.

    ``fit`` measures, for every current (new) class, how far its text feature
    lies from the nearest base-class text features. It does this once in the
    zero-shot CLIP feature space and once in the few-shot tuned feature space,
    and stores one multiplicative confidence per current class in
    ``self.class_confidence``. ``predict`` then scales each row of a logit
    matrix by the confidence of that row's arg-max class.
    """

    def __init__(self):
        # Per-current-class confidence factors; populated by ``fit``.
        self.class_confidence = None

    @staticmethod
    def _nearest_distances(base_features, query_feature, k):
        """Return the (at most) ``k`` smallest Euclidean distances, ascending,
        from ``query_feature`` to the rows of ``base_features``."""
        distances = np.linalg.norm(base_features - query_feature, axis=1)
        return np.sort(distances)[:k]

    def fit(self, base_text_features_zs, current_text_features_zs,
            base_text_features_tuned, current_text_features_tuned, k,
            base_threshold=0.05):
        """Compute one confidence factor per current class.

        base_text_features_zs: base class text features generated by zero-shot
            CLIP, one row per base class
        current_text_features_zs: new class text features generated by
            zero-shot CLIP, one row per current class
        base_text_features_tuned: base class text features generated by
            few-shot CLIP, one row per base class
        current_text_features_tuned: new class text features generated by
            few-shot CLIP; must have as many rows as current_text_features_zs
        k: number of nearest base features averaged per class (>= 1); if there
            are fewer than k base classes, all of them are used
        base_threshold: if a current class's tuned feature lies closer than
            this to some tuned base feature, the class is treated as a base
            class and gets confidence 1.0 (default 0.05)

        For every other class the confidence is exp(-d_fs) / exp(-d_zs), where
        d_zs / d_fs is the mean distance to the k nearest base features in the
        zero-shot / tuned space. It is > 1 when the class is nearer to the base
        classes in the tuned space than in the zero-shot space, and < 1
        otherwise.

        Sets ``self.class_confidence`` to a float64 array of shape
        (num_current_classes,).

        Raises ValueError if k < 1, if there are no base classes, or if the
        zero-shot and tuned current features have different row counts.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")

        base_zs = np.asarray(base_text_features_zs)
        current_zs = np.asarray(current_text_features_zs)
        base_tuned = np.asarray(base_text_features_tuned)
        current_tuned = np.asarray(current_text_features_tuned)

        if len(base_zs) == 0 or len(base_tuned) == 0:
            raise ValueError("base text features must contain at least one class")
        if len(current_zs) != len(current_tuned):
            raise ValueError(
                "current_text_features_zs and current_text_features_tuned must "
                f"have the same number of rows, got {len(current_zs)} and "
                f"{len(current_tuned)}"
            )

        class_confidence = []
        for query_zs, query_tuned in zip(current_zs, current_tuned):
            zs_nearest = self._nearest_distances(base_zs, query_zs, k)
            tuned_nearest = self._nearest_distances(base_tuned, query_tuned, k)

            if tuned_nearest[0] < base_threshold:
                # Base class aware: the tuned feature (almost) coincides with
                # a base class, so leave its logits unscaled.
                confidence = 1.0
            else:
                # New class calibration. Average over the neighbours actually
                # found (may be fewer than k), not a fixed divisor of k.
                # exp(-d_fs) / exp(-d_zs) == exp(d_zs - d_fs); the latter form
                # cannot underflow into 0 / 0 when both distances are large.
                zs_mean = np.mean(zs_nearest)
                fs_mean = np.mean(tuned_nearest)
                confidence = float(np.exp(zs_mean - fs_mean))
            class_confidence.append(confidence)

        self.class_confidence = np.array(class_confidence, dtype=np.float64)

    def predict(self, logits, device=None):
        """Scale each row of ``logits`` by the confidence of its arg-max class.

        logits: 2-D array-like (numpy array, torch tensor or nested list) of
            shape (num_samples, num_classes). The arg-max column index is used
            to index ``class_confidence``, so column j must refer to the j-th
            current class passed to ``fit``.
        device: torch device used for the computation; defaults to "cuda"
            when available, otherwise "cpu"

        Returns a new float32 numpy array with the same shape as ``logits``;
        the input is not modified.

        Raises AttributeError if ``fit`` has not been called.
        """
        class_confidence = getattr(self, "class_confidence", None)
        if class_confidence is None:
            raise AttributeError(
                "class_confidence is not set; call fit() before predict()"
            )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        logits_t = torch.as_tensor(logits).float().to(device)
        confidences = torch.as_tensor(class_confidence).float().to(device)

        pred = logits_t.max(dim=1).indices
        # Out-of-place multiply: on CPU, ``torch.as_tensor(...).float()`` may
        # share memory with the caller's float32 array, and an in-place
        # ``*=`` would overwrite it.
        scaled = logits_t * confidences[pred].unsqueeze(1)
        return scaled.cpu().numpy()



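# CPU reference version of DistanseAwareCalibration.predict (commented out, not executed).
# Note: it scales `logits` in place, one row at a time.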
    
# def difficulity_aware_calibrator(logits, class_confidences):

#     pred = logits.max(1)[1]

#     for i in range(logits.shape[0]): 
#         label = pred[i].item() 
#         confidence = class_confidences[label]
#         logits[i] *= confidence


#     return logits




