import torch

def pred_score_batch(xt, mu, sigma):
    """Closed-form score of an isotropic Gaussian mixture at a batch of points.

    The B1 rows of ``mu`` are treated as the means of an equally weighted
    mixture of isotropic Gaussians with standard deviation ``sigma``. For
    every query point in ``xt`` the softmax-normalised responsibilities of
    the components are computed and used to form the mixture score.

    Args:
        xt: Query points, shape (B2, channel, H, W).
        mu: Component means, shape (B1, channel, H, W). According to the
            original note on this function, callers pass
            ``sqrt(atbar) * x0``.
        sigma: Standard deviation; per the original note,
            ``sqrt(1 - atbar)``. May be a Python number, a 0-dim tensor, or
            a tensor holding either one value or one value per query point
            (any shape with B2 elements, e.g. (B2,), (B2, 1) or
            (B2, 1, 1, 1)).

    Returns:
        Tuple ``(score, eps, pre_x0)``, each of shape (B2, channel, H, W):
            score:  (sum_k w_k * mu_k - xt) / sigma**2
            eps:    -(sum_k w_k * mu_k - xt) / sigma
            pre_x0: sum_k w_k * mu_k, the responsibility-weighted mean of
                    ``mu``.
    """
    batch = xt.shape[0]
    out_shape = (batch, xt.shape[1], xt.shape[2], xt.shape[3])

    xt_flat = xt.reshape(batch, -1)         # [B2, D]
    mu_flat = mu.reshape(mu.shape[0], -1)   # [B1, D]

    # A per-sample sigma must line up with the B2 axis of both the [B2, B1]
    # distance matrix and the [B2, D] result, so normalise it to a column.
    # 0-dim tensors and Python numbers already broadcast as scalars and are
    # left untouched.
    if torch.is_tensor(sigma) and sigma.dim() > 0 and sigma.numel() in (1, batch):
        sigma = sigma.reshape(-1, 1)
    variance = sigma ** 2

    # Pairwise differences are needed both for the squared distances and for
    # the weighted sum, so build the [B2, B1, D] tensor only once.
    diff = mu_flat - xt_flat.unsqueeze(1)            # [B2, B1, D]
    sq_dist = torch.sum(diff ** 2, dim=-1)           # [B2, B1]

    # Responsibilities of each component for each query point.
    weights = torch.softmax(sq_dist / (-2 * variance), dim=1)   # [B2, B1]

    # Responsibility-weighted mean difference: sum_k w_k * (mu_k - xt).
    weighted_diff = torch.einsum('bc,bcd->bd', weights, diff)   # [B2, D]

    score = weighted_diff / variance
    eps = -weighted_diff / sigma
    # Mathematically score * sigma**2 + xt; adding the weighted difference
    # directly avoids a divide/multiply round trip that turns into NaN when
    # sigma**2 underflows to zero.
    pre_x0 = weighted_diff + xt_flat

    return (
        score.reshape(out_shape),
        eps.reshape(out_shape),
        pre_x0.reshape(out_shape),
    )

def pred_score_batch_eps(x0, xt, atbar):
    """Mixture score and noise estimate for noisy points against clean data.

    Each clean sample in ``x0`` is scaled by ``sqrt(atbar)`` to give the
    component means of an isotropic Gaussian mixture with variance
    ``1 - atbar``; the score of that mixture is evaluated at every ``xt``.

    Args:
        x0: Clean samples, shape (B1, channel, img_size, img_size).
        xt: Noisy query points, shape (B2, channel, img_size, img_size).
        atbar: One value per query point; reshaped here to (B2, 1).

    Returns:
        Tuple ``(score, eps, softmax_)``:
            score:    shape of ``xt``; weighted mean difference / (1 - atbar)
            eps:      shape of ``xt``; -weighted mean difference / sqrt(1 - atbar)
            softmax_: (B2, B1) responsibilities of each clean sample for
                      each query point.
    """
    out_shape = (xt.shape[0], xt.shape[1], xt.shape[2], xt.shape[3])

    xt_flat = xt.reshape(xt.shape[0], -1)      # [B2, D]
    x0_flat = x0.reshape(x0.shape[0], -1)      # [B1, D]
    alpha = atbar.reshape(atbar.shape[0], 1)   # [B2, 1]
    noise_var = 1 - alpha                      # [B2, 1]

    # Component means per query point: sqrt(atbar_b) * x0_k -> [B2, B1, D].
    means = torch.sqrt(alpha.unsqueeze(1)) * x0_flat
    offsets = means - xt_flat.unsqueeze(1)     # [B2, B1, D]

    # Squared distance from each query point to each component mean.
    sq_dist = (offsets ** 2).sum(dim=-1)       # [B2, B1]
    responsibilities = torch.softmax(sq_dist / (-2 * noise_var), dim=1)

    # Responsibility-weighted mean offset for every query point.
    weighted = torch.einsum('bc,bcd->bd', responsibilities, offsets)  # [B2, D]

    score = (weighted / noise_var).reshape(out_shape)
    eps = (-weighted / torch.sqrt(noise_var)).reshape(out_shape)

    return score, eps, responsibilities

def stf_targets(sigmas, perturbed_samples, ref):
    """Compute stable targets as a weighted average over a reference batch.

    For every perturbed sample, each reference sample receives a
    self-normalised weight proportional to
    ``exp(-||x_t - x_k||^2 / (2 * sigma^2))``; the stable target is the
    weighted sum of the flattened reference samples.

    Args:
        sigmas: Noise levels, one per perturbed sample, shape (B,).
        perturbed_samples: Perturbed samples with perturbation kernel
            N(0, sigmas**2); leading dimension B.
        ref: The reference batch; leading dimension N, with the same number
            of elements per sample as ``perturbed_samples``.

    Returns:
        Stable targets of shape (B, D), where D is the flattened per-sample
        size. Computed under ``torch.no_grad()``, so the result carries no
        autograd history.
    """
    with torch.no_grad():
        perturbed_vec = perturbed_samples.reshape((len(perturbed_samples), -1))  # [B, D]
        ref_vec = ref.reshape((len(ref), -1))                                    # [N, D]

        # Pairwise squared distances ||x_t - x_k||^2 via broadcasting: [B, N].
        sq_dist = torch.sum((perturbed_vec.unsqueeze(1) - ref_vec) ** 2, dim=[-1])

        # Log-weights: -||x_t - x_k||^2 / (2 * sigma^2).
        log_weights = -sq_dist / (2 * sigmas.unsqueeze(1) ** 2)

        # Subtract the row-wise maximum before exponentiating so that exp()
        # cannot overflow and at least one weight per row equals 1.
        log_weights = log_weights - torch.max(log_weights, dim=1, keepdim=True)[0]
        unnormalised = torch.exp(log_weights)[:, :, None]                        # [B, N, 1]

        # Self-normalise the per-sample weights over the reference batch.
        weights = unnormalised / torch.sum(unnormalised, dim=1, keepdim=True)    # [B, N, 1]

        # Broadcasting [B, N, 1] * [N, D] yields the same products as
        # repeating ref_vec B times, without materialising that copy.
        # NOTE(review): an original comment remarked that this target did
        # not seem to match the paper's formula - confirm against the
        # reference before relying on it.
        stable_targets = torch.sum(weights * ref_vec, dim=1)                     # [B, D]
        return stable_targets