import torch
import numpy as np



def postHocLogits(transforamtion,logits_loader,device,num_classes,mask):
    """Apply a post-hoc transformation to pre-computed logits.

    The transformation module is switched to eval mode and moved to
    ``device`` (both in place, so the caller's module is modified). It is
    then applied batch by batch, under ``torch.no_grad()``, to every logits
    batch yielded by ``logits_loader``.

    Args:
        transforamtion: module applied to each batch of logits.
        logits_loader: iterable of ``(batch_logits, targets)`` pairs that
            exposes a ``.dataset`` whose length is the total sample count.
        device: device the transformation and each input batch are moved to.
        num_classes: width of the output logits when ``mask`` is ``None``.
        mask: ``None`` or an ``np.ndarray``. When an array is given, only its
            length is used, and it overrides ``num_classes``.
            NOTE(review): the mask's contents are never applied to the
            logits; the loader presumably already yields masked logits of
            width ``len(mask)`` -- confirm with callers.

    Returns:
        ``torch.utils.data.TensorDataset`` of ``(logits, labels)`` stored on
        the CPU. Logits use the default float dtype; labels are int64.

    Raises:
        NotImplementedError: if ``mask`` is neither ``None`` nor an ndarray.
    """
    transforamtion.eval()
    transforamtion.to(device)

    if isinstance(mask, np.ndarray):
        # Output width follows the mask length.
        num_classes = len(mask)
    elif mask is not None:
        raise NotImplementedError(
            "mask must be None or a numpy.ndarray, got %r" % type(mask)
        )

    num_samples = len(logits_loader.dataset)
    # Preallocate CPU buffers for the transformed logits and the labels.
    logits = torch.zeros((num_samples, num_classes))
    labels = torch.zeros((num_samples,))
    offset = 0
    with torch.no_grad():
        for batch_logits, targets in logits_loader:
            batch_logits = transforamtion(batch_logits.to(device))
            batch_size = batch_logits.shape[0]
            logits[offset:offset + batch_size, :] = batch_logits
            labels[offset:offset + batch_size] = targets.cpu()
            offset += batch_size
    return torch.utils.data.TensorDataset(logits, labels.long())



import torch.nn as nn 


class PostHoc(nn.Module):
    """Base class for post-hoc logit transformations.

    The default ``forward`` is the identity transform: it hands back the
    very same input tensor object, without copying. Subclasses override
    ``forward`` to implement a concrete transformation.
    """

    def forward(self, batch_logits):
        unchanged = batch_logits
        return unchanged


class LogitNormalization(PostHoc):
    """Scale each logit vector to (approximately) unit L2 norm.

    Every vector along the last dimension is divided by its Euclidean norm
    plus a small ``eps``, which guards against division by zero for
    all-zero vectors. The module has no learnable parameters.
    """

    def __init__(self, eps=1e-7):
        """
        Args:
            eps: constant added to each L2 norm before dividing
                (default ``1e-7``).
        """
        super().__init__()
        self.eps = eps

    def forward(self, batch_logits):
        norms = torch.norm(batch_logits, p=2, dim=-1, keepdim=True) + self.eps
        return torch.div(batch_logits, norms)
    

    
    

    
class InputyAtypicality(PostHoc):
    """Placeholder for an input-atypicality-based logit rescaling.

    NOTE(review): ``forward`` currently returns the logits unchanged. The
    learnable ``weight`` (3 values, uniform random init) and ``num_class``
    are stored but not used yet.

    TODO: define how a norm-dependent factor ``phi`` should rescale the
    logits. The candidate forms are, with ``n = ||z||_2 + 1e-7`` and
    ``w = self.weight``:
    ``w[0] * n**2 + w[1] * n + w[2]`` or ``(w[0] * n + w[1])**2``.
    """

    def __init__(self,num_class) -> None:
        super().__init__()
        self.num_class = num_class
        # Kept so the parameter count, state_dict and RNG consumption
        # are unchanged until phi is actually wired in.
        self.weight = nn.Parameter(torch.rand(3))

    def forward(self,batch_logits):
        # Identity for now; the previously computed phi was never used.
        return batch_logits
    


class OptimalTeamperatureScaling(PostHoc):
    """Temperature scaling with a learnable temperature.

    Logits are divided by a single temperature held in a one-element
    ``nn.Parameter``. Because it is a parameter, it is trainable;
    presumably it is fitted elsewhere to find the optimal value.
    """

    def __init__(self, temperature=1) -> None:
        super().__init__()
        initial = torch.ones(1) * temperature
        self.temperature = nn.Parameter(initial)

    def forward(self, batch_logits):
        # The (1,)-shaped temperature broadcasts over every logit.
        scaled = batch_logits / self.temperature
        return scaled
    
    
class FixedTeamperatureScaling(PostHoc):
    """Temperature scaling with a fixed, non-learnable temperature.

    Logits are divided by a constant one-element temperature tensor. The
    temperature is registered as a non-persistent buffer. That means it
    follows the module through ``.to(device)`` / ``.cuda()``, is not a
    trainable parameter, and is not written to the ``state_dict``.
    Previously CUDA was hard-coded, which made the module unusable on CPU.
    """

    def __init__(self, temperature=1) -> None:
        super().__init__()
        self.register_buffer(
            "temperature", torch.ones(1) * temperature, persistent=False
        )

    def forward(self, batch_logits):
        # Align with the input's device (a no-op when already there) so
        # CPU and GPU inputs both work without per-call CUDA copies.
        return batch_logits / self.temperature.to(batch_logits.device)
    
    
    



