import itertools
import random

import torch
import matplotlib.pyplot as plt
import numpy as np
from skimage.transform import resize
from torchvision.transforms import functional as TF
from matplotlib.colors import Normalize
import torch.nn.functional as F


def plot_conf_mat(conf_mat, class_names, fig_size, label_font_size=11,
                  show_num=False, show_color_bar=False, xrotate=45):
    """
    Make pyplot figure for a row-normalized confusion matrix.

    Each row of ``conf_mat`` is divided by its row sum and rounded to two
    decimals, so every cell shows the fraction of that row's (ground-truth)
    samples predicted as the column class. Rows whose sum is zero are shown
    as zeros.

    Args:
        conf_mat (np.array): the confusion matrix (rows: GT, columns: pred)
        class_names (Iterable): the class names; ``len()`` is taken, so a sized
            collection is required
        fig_size (Tuple[int, int]): the plot figure size
        label_font_size (int): the font size for the axis tick labels
        show_num (bool): whether to write each cell's value on the figure
        show_color_bar (bool): whether to show the color bar
        xrotate (int): rotation degree for x axis tick labels
    Returns:
        @rtype: plt.figure.Figure
    """
    fig, ax = plt.subplots(figsize=fig_size)
    row_sums = conf_mat.sum(axis=1)[:, np.newaxis]
    # All-zero rows produce 0/0 -> NaN; suppress the numpy warning here and
    # map the NaNs to 0 right after.
    with np.errstate(divide='ignore', invalid='ignore'):
        cm = np.around(conf_mat.astype('float') / row_sums, decimals=2)
    cm = np.nan_to_num(cm)
    im = ax.imshow(cm, interpolation='nearest', cmap='Blues')
    ax.set_title('Confusion Matrix')
    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_yticks(tick_marks)
    ax.set_xticklabels(class_names, fontdict={'fontsize': label_font_size})
    ax.set_yticklabels(class_names, fontdict={'fontsize': label_font_size})
    plt.setp(ax.get_xticklabels(), rotation=xrotate)
    plt.margins(0.5)

    if show_num:
        # Dark cells get white text, light cells black text.
        threshold = cm.max() / 2
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            color = 'white' if cm[i, j] > threshold else 'black'
            # imshow centres cell (i, j) on the integer coordinate (x=j, y=i);
            # centre the label on both axes so it sits inside its cell.
            ax.text(j, i, cm[i, j], horizontalalignment='center',
                    verticalalignment='center', color=color)

    ax.set_ylabel('GT')
    ax.set_xlabel('pred')
    if show_color_bar:
        fig.colorbar(im, ax=ax)
    plt.tight_layout()
    return fig


def images_denorm(images, mean=None, std=None):
    """
    De-normalize image array, i.e. invert ``(x - mean) / std`` per channel.

    Args:
        images (np.array): the image array in the shape [nb, 3, w, h]
        mean (Union[float, Iterable]): the mean values of each channel; a scalar
                                       (int, float or numpy scalar) is repeated
                                       for all the 3 channels. Defaults to the
                                       ImageNet mean [0.485, 0.456, 0.406].
        std (Union[float, Iterable]): the standard deviation of each channel; a
                                      scalar is repeated for all the 3 channels.
                                      Defaults to the ImageNet std
                                      [0.229, 0.224, 0.225].

    Returns:
        @rtype: np.array -- ``images * std + mean`` broadcast over axis 1
    """
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    if std is None:
        # ImageNet per-channel std (the companion of the mean default above).
        std = [0.229, 0.224, 0.225]

    # np.ndim(x) == 0 for any scalar, so ints (e.g. mean=0) broadcast too,
    # not only Python floats.
    if np.ndim(mean) == 0:
        mean = [mean] * 3
    if np.ndim(std) == 0:
        std = [std] * 3

    # Shape (1, 3, 1, 1) broadcasts over batch and spatial axes.
    mean = np.array(mean).reshape((1, 3, 1, 1))
    std = np.array(std).reshape((1, 3, 1, 1))
    return images * std + mean


def plot_attn_maps(images, attn_maps, size=2.5):
    """
    Make pyplot figure overlaying each attention map on its image.

    The grid has one row per image and one column per attention map. Every
    map is resized (nearest-neighbour, ``order=0``) to the image's spatial
    size and drawn with the 'hot' colormap at 40% opacity over the image.

    Args:
        images (np.array): the image array in the shape [nb, w, h, 3]; values
            are clipped to [0, 1] before display
        attn_maps (np.array): the attention array in the shape [nb, N, w, h]
        size (float): edge length of each panel in the figure

    Returns:
        @rtype: plt.figure.Figure
    """
    n_rows = images.shape[0]
    n_cols = attn_maps.shape[1]
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols,
                             figsize=(size * n_cols, size * n_rows),
                             squeeze=False)
    for row in range(n_rows):
        # The clipped background is the same for every column of this row.
        background = np.clip(images[row], 0, 1)
        for col in range(n_cols):
            overlay = resize(attn_maps[row, col], background.shape[:2], order=0)
            ax = axes[row, col]
            ax.imshow(background)
            ax.imshow(overlay, cmap='hot', interpolation='nearest', alpha=.4)
            ax.margins(.1)
            ax.axis('off')
    plt.tight_layout()
    return fig


def plot_mu_prob(images, mu_prob, size=2.5):
    """
    Make pyplot figure for mu probability distributions.

    Each row shows one image in the first column, followed by one panel per
    probability map. Maps are resized (nearest-neighbour, ``order=0``) to the
    image's spatial size and drawn with the 'hot' colormap on a fixed [0, 1]
    colour scale.

    Args:
        images (np.array): the image array in the shape of [nb, w, h, 3];
            values are clipped to [0, 1] before display
        mu_prob (np.array): the probability distribution in the shape of [nb, N, w, h]
        size (float): edge length of each panel in the figure

    Returns:
        @rtype: plt.figure.Figure
    """
    n_rows = images.shape[0]
    n_maps = mu_prob.shape[1]
    n_cols = n_maps + 1  # leading column holds the image itself
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols,
                             figsize=(size * n_cols, size * n_rows),
                             squeeze=False)
    for row in range(n_rows):
        picture = np.clip(images[row], 0, 1)
        panels = axes[row]
        panels[0].imshow(picture)
        for col in range(n_maps):
            prob_map = resize(mu_prob[row, col], picture.shape[:2], order=0)
            panels[col + 1].imshow(prob_map, cmap='hot', interpolation='nearest',
                                   norm=Normalize(0, 1))
        for ax in panels:
            ax.margins(.1)
            ax.axis('off')
    plt.tight_layout()
    return fig


# PyTorch has issues with gradient calculation in logdet() and det(), so we manually implement
# the determinant for low-dimensional cases in the following functions
# group start
def det2(tensor):
    """
    Determinant of a batch of 2 x 2 matrices.

    Args:
        tensor (Tensor): shape [nb, 2, 2]; the batch runs along the first dim.

    Returns:
        Tensor of shape [nb] holding ``a*d - b*c`` for each matrix [[a, b], [c, d]].
    """
    assert tensor.dim() == 3
    assert tensor.shape[1:] == (2, 2)
    a, b = tensor[:, 0, 0], tensor[:, 0, 1]
    c, d = tensor[:, 1, 0], tensor[:, 1, 1]
    return a * d - b * c


def det3(tensor):
    """
    Determinant of a batch of 3 x 3 matrices (batch along the first dim).

    Uses cofactor expansion along the first row, delegating each 2 x 2 minor
    to ``det2``.

    Args:
        tensor (Tensor): shape [nb, 3, 3]

    Returns:
        Tensor of shape [nb].
    """
    assert tensor.dim() == 3
    assert tensor.shape[1:] == (3, 3)
    lower_rows = tensor[:, 1:]
    total = 0
    for col, sign in enumerate((1, -1, 1)):
        kept = [c for c in range(3) if c != col]
        total = total + sign * tensor[:, 0, col] * det2(lower_rows[:, :, kept])
    return total


def det4(tensor):
    """
    Determinant of a batch of 4 x 4 matrices (batch along the first dim).

    Uses cofactor expansion along the first row, delegating each 3 x 3 minor
    to ``det3``.

    Args:
        tensor (Tensor): shape [nb, 4, 4]

    Returns:
        Tensor of shape [nb].
    """
    assert tensor.dim() == 3
    assert tensor.shape[1:] == (4, 4)
    lower_rows = tensor[:, 1:]
    total = 0
    for col, sign in enumerate((1, -1, 1, -1)):
        kept = [c for c in range(4) if c != col]
        total = total + sign * tensor[:, 0, col] * det3(lower_rows[:, :, kept])
    return total


def det5(tensor):
    """
    Determinant of a batch of 5 x 5 matrices (batch along the first dim).

    Uses cofactor expansion along the first row, delegating each 4 x 4 minor
    to ``det4``.

    Args:
        tensor (Tensor): shape [nb, 5, 5]

    Returns:
        Tensor of shape [nb].
    """
    assert tensor.dim() == 3
    assert tensor.shape[1:] == (5, 5)
    lower_rows = tensor[:, 1:]
    total = 0
    for col, sign in enumerate((1, -1, 1, -1, 1)):
        kept = [c for c in range(5) if c != col]
        total = total + sign * tensor[:, 0, col] * det4(lower_rows[:, :, kept])
    return total


def det6(tensor):
    """
    Determinant of a batch of 6 x 6 matrices (batch along the first dim).

    Uses cofactor expansion along the first row, delegating each 5 x 5 minor
    to ``det5``.

    Args:
        tensor (Tensor): shape [nb, 6, 6]

    Returns:
        Tensor of shape [nb].
    """
    assert tensor.dim() == 3
    assert tensor.shape[1:] == (6, 6)
    lower_rows = tensor[:, 1:]
    total = 0
    for col, sign in enumerate((1, -1, 1, -1, 1, -1)):
        kept = [c for c in range(6) if c != col]
        total = total + sign * tensor[:, 0, col] * det5(lower_rows[:, :, kept])
    return total
# group end


class SelectAngleRotate(object):
    """
    Rotate an input by an angle chosen uniformly at random from a fixed list.

    Args:
        angles (Sequence): candidate angles; one is drawn with ``random.choice``
            on every call and passed to torchvision's ``TF.rotate``.
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        """Return ``x`` rotated by a randomly selected angle from ``self.angles``."""
        return TF.rotate(x, random.choice(self.angles))


def compute_pair_wise_diff(all_locs, num):
    """
    Compute the pairwise differences between all locations.

    Args:
        all_locs (Tensor): a torch tensor containing all locations in a stacked tensor with shape
                            [nb x (num*2)]: the first ``num`` columns are x, the rest y
        num (int): number of locations

    Returns:
        Tensor of shape [nb, num*(num-1)]: for each pair (a, b) with a < b, in
        ``itertools.combinations`` order, the (dx, dy) of location a minus
        location b, concatenated along dim 1.
    """
    points = torch.stack((all_locs[:, :num], all_locs[:, num:]), dim=2)  # [nb, num, 2]
    pair_diffs = [points[:, a] - points[:, b]
                  for a, b in itertools.combinations(range(num), 2)]
    return torch.cat(pair_diffs, dim=1)

def compute_pair_wise_diff_with_xycoord(all_locs, num):
    """
    Compute the pairwise differences between all locations, kept per edge.

    Args:
        all_locs (Tensor): a torch tensor containing all locations in a stacked tensor with shape
                            [nb x (num*2)]: the first ``num`` columns are x, the rest y
        num (int): number of locations

    Returns:
        list of Tensor: one [nb, 2] tensor of (dx, dy) per location pair (a, b),
        a < b, in ``itertools.combinations`` order; the list length is the
        number of edges, num*(num-1)/2.
    """
    points = torch.stack((all_locs[:, :num], all_locs[:, num:]), dim=2)  # [nb, num, 2]
    return [points[:, a] - points[:, b]
            for a, b in itertools.combinations(range(num), 2)]

def compute_pair_wise_diff_with_xycoord_all(all_locs, num):
    """
    Compute the per-edge pairwise location differences for several location sets.

    Args:
        all_locs (Sequence[Tensor]): location tensors, each in the stacked layout
                            [nb x (num*2)] (first ``num`` columns x, rest y)
        num (int): number of locations

    Returns:
        list of lists: one inner list per element of ``all_locs``; each inner
        list holds one [nb, 2] (dx, dy) tensor per location pair (a, b), a < b,
        in ``itertools.combinations`` order (e.g. 3 location sets with 4
        locations give 3 lists of 6 tensors).
    """
    edges = list(itertools.combinations(range(num), 2))
    all_diffs_list = []
    for locs in all_locs:
        points = torch.stack((locs[:, :num], locs[:, num:]), dim=2)  # [nb, num, 2]
        all_diffs_list.append([points[:, a] - points[:, b] for a, b in edges])
    return all_diffs_list


def compute_combined_protonet_scores(features, diffs, target, sup_batch_size):
    """
    Compute the protonet score given both features and part location differences.

    The first ``sup_batch_size`` rows of ``features``/``diffs`` are support
    samples, the rest are queries. Each class prototype is the L2-normalised
    sum of its support rows (the same direction as the mean). Scores are
    cosine similarities between the L2-normalised queries and the prototypes.

    Args:
        features (Tensor): the feature vectors, shape [nb, d]
        diffs (Tensor): the pairwise differences between all locations, shape [nb, e]
        target (Tensor): the target labels for support + query samples
        sup_batch_size (int): number of support samples in each batch

    Returns:
        (feat_sims, diff_sims, unique_labels, uni_idx): the two similarity
        matrices are [num_query, num_classes], with column c corresponding to
        ``unique_labels[c]``; ``unique_labels``/``uni_idx`` come from
        ``torch.unique(target, return_inverse=True)``.
    """
    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    support_idx = uni_idx[:sup_batch_size]
    class_masks = [support_idx == c for c in range(len(unique_labels))]

    support_feats = features[:sup_batch_size]
    support_diffs = diffs[:sup_batch_size]
    feat_protos = F.normalize(
        torch.stack([support_feats[m].sum(dim=0) for m in class_masks]), dim=1)
    diff_protos = F.normalize(
        torch.stack([support_diffs[m].sum(dim=0) for m in class_masks]), dim=1)

    query_feats = F.normalize(features[sup_batch_size:], dim=1)
    query_diffs = F.normalize(diffs[sup_batch_size:], dim=1)

    diff_sims = torch.einsum('ik,jk->ij', query_diffs, diff_protos)
    feat_sims = torch.einsum('ik,jk->ij', query_feats, feat_protos)
    return feat_sims, diff_sims, unique_labels, uni_idx

def compute_combined_protonetpp_scores(features, diffs, target, sup_batch_size, temperature=1):
    """
    Compute per-support-sample log-probabilities for features and location differences.

    Unlike ``compute_combined_protonet_scores``, no class prototypes are
    built. Every query is compared to every individual support sample.
    ``features`` and ``diffs`` are L2-normalised along dim 1 and split into
    support rows ``[:sup_batch_size]`` and query rows ``[sup_batch_size:]``.
    The query-support cosine similarities are divided by ``temperature``,
    then log-softmax is taken over the support axis.

    Args:
        features (Tensor): the feature vectors, shape [nb, d]
        diffs (Tensor): the pairwise differences between all locations, shape [nb, e]
        target (Tensor): the target labels for support + query samples
        sup_batch_size (int): number of support samples in each batch
        temperature (float): divisor applied to the similarities before the softmax

    Returns:
        (feat_sims_log_p_y, diff_sims_log_p_y, unique_labels, uni_idx): the two
        log-probability matrices have shape [num_query, sup_batch_size];
        ``unique_labels``/``uni_idx`` come from
        ``torch.unique(target, return_inverse=True)``.
    """
    features = F.normalize(features, dim=1)
    diffs = F.normalize(diffs, dim=1)
    feat_support = features[:sup_batch_size]
    feat_query = features[sup_batch_size:]
    diffs_support = diffs[:sup_batch_size]
    diffs_query = diffs[sup_batch_size:]

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)

    feat_logits = torch.einsum('ik,jk->ij', feat_query, feat_support) / temperature
    feat_sims_log_p_y = F.log_softmax(feat_logits, dim=1)

    diff_logits = torch.einsum('ik,jk->ij', diffs_query, diffs_support) / temperature
    diff_sims_log_p_y = F.log_softmax(diff_logits, dim=1)
    return feat_sims_log_p_y, diff_sims_log_p_y, unique_labels, uni_idx

def compute_partbypart_combined_protonet_scores(Pss, diffs, target, sup_batch_size):
    """
    Compute per-part protonet scores plus a location-difference protonet score.

    Args:
        Pss (Iterable[Tensor]): one feature tensor per part; each is scored with
            ``compute_pure_protonet_scores`` (support rows first, then queries)
        diffs (Tensor): the pairwise differences between all locations, support rows first
        target (Tensor): the target labels for support + query samples
        sup_batch_size (int): number of support samples in each batch

    Returns:
        (parts_similarity_score, diff_sims, unique_labels, uni_idx):
        ``parts_similarity_score`` is a list with one [num_query, num_classes]
        similarity matrix per part. ``diff_sims`` is the cosine similarity of
        the normalised query diffs to the normalised per-class sums of support
        diffs. ``unique_labels``/``uni_idx`` come from
        ``torch.unique(target, return_inverse=True)``.
    """
    # Only the similarity matrix (first return value) is kept for each part.
    parts_similarity_score = [
        compute_pure_protonet_scores(part_feats, target, sup_batch_size)[0]
        for part_feats in Pss
    ]

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    support_idx = uni_idx[:sup_batch_size]
    support_diffs = diffs[:sup_batch_size]
    prototypes = torch.stack([support_diffs[support_idx == c].sum(dim=0)
                              for c in range(len(unique_labels))])
    prototypes = F.normalize(prototypes, dim=1)
    queries = F.normalize(diffs[sup_batch_size:], dim=1)
    diff_sims = torch.einsum('ik,jk->ij', queries, prototypes)

    return parts_similarity_score, diff_sims, unique_labels, uni_idx

def compute_multiscale_partbypart_combined_protonet_scores(Pss_list, diffs, target, sup_batch_size):
    """
    Compute cross-scale per-part protonet scores plus a location-difference score.

    Args:
        Pss_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner index
            the part; every tensor holds support rows followed by query rows
        diffs (Tensor): the pairwise differences between all locations, support rows first
        target (Tensor): the target labels for support + query samples
        sup_batch_size (int): number of support samples in each batch

    Returns:
        (parts_similarity_score, diff_sims, unique_labels, uni_idx):
        ``parts_similarity_score`` lists ``len(Pss_list)**2 * num_parts``
        similarity matrices. They are ordered by support scale, then query
        scale, then part: part i of the query scale is compared with
        prototypes built from part i of the support scale. ``diff_sims`` is
        the prototype similarity of the location differences.
    """
    num_parts = len(Pss_list[0])
    combos = itertools.product(Pss_list, Pss_list, range(num_parts))
    parts_similarity_score = [
        compute_pure_protonet_scores_from_two_resolution(
            support_Ps[i], query_Ps[i], target, sup_batch_size)[0]
        for support_Ps, query_Ps, i in combos
    ]

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    support_idx = uni_idx[:sup_batch_size]
    support_diffs = diffs[:sup_batch_size]
    prototypes = torch.stack([support_diffs[support_idx == c].sum(dim=0)
                              for c in range(len(unique_labels))])
    prototypes = F.normalize(prototypes, dim=1)
    queries = F.normalize(diffs[sup_batch_size:], dim=1)
    diff_sims = torch.einsum('ik,jk->ij', queries, prototypes)

    return parts_similarity_score, diff_sims, unique_labels, uni_idx

def compute_multiscale_partbypart_combined_protonet_scores_geo(Pss_list, diffs_list, target, sup_batch_size):
    """
    Compute cross-scale per-part protonet scores plus one protonet score per edge.

    Args:
        Pss_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner index
            the part; every tensor holds support rows followed by query rows
        diffs_list (Iterable[Tensor]): one location-difference tensor per edge,
            each with support rows first (presumably [nb, 2] x/y differences)
        target (Tensor): the target labels for support + query samples
        sup_batch_size (int): number of support samples in each batch

    Returns:
        (parts_similarity_score, diff_similarity_score, unique_labels, uni_idx):
        ``parts_similarity_score`` is ordered by support scale, then query
        scale, then part. ``diff_similarity_score`` holds one
        [num_query, num_classes] cosine-similarity matrix per edge, in
        ``diffs_list`` order.
    """
    num_parts = len(Pss_list[0])
    combos = itertools.product(Pss_list, Pss_list, range(num_parts))
    parts_similarity_score = [
        compute_pure_protonet_scores_from_two_resolution(
            support_Ps[i], query_Ps[i], target, sup_batch_size)[0]
        for support_Ps, query_Ps, i in combos
    ]

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    support_idx = uni_idx[:sup_batch_size]
    class_masks = [support_idx == c for c in range(len(unique_labels))]

    diff_similarity_score = []
    for edge_diffs in diffs_list:
        edge_support = edge_diffs[:sup_batch_size]
        prototypes = F.normalize(
            torch.stack([edge_support[m].sum(dim=0) for m in class_masks]), dim=1)
        queries = F.normalize(edge_diffs[sup_batch_size:], dim=1)
        diff_similarity_score.append(torch.einsum('ik,jk->ij', queries, prototypes))

    return parts_similarity_score, diff_similarity_score, unique_labels, uni_idx

def compute_multiscale_partbypart_combined_protonet_scores_geo_all(Pss_list, diffs_list, target, sup_batch_size, ifsupport=False):
    """
    Compute cross-scale protonet scores for every part and for every edge.

    Args:
        Pss_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner index
            the part; every tensor holds support rows followed by query rows
        diffs_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner
            index the edge; per-edge location-difference tensors
        target (Tensor): the target labels
        sup_batch_size (int): number of support samples in each batch
        ifsupport (bool): if True, no support/query split is made. The full
            tensors serve as both support and query, and every label in
            ``target`` is treated as a support label.

    Returns:
        (parts_similarity_score, diff_similarity_score, unique_labels, uni_idx):
        both score lists are ordered by support scale, then query scale, then
        part (or edge) index.
    """
    num_parts = len(Pss_list[0])
    part_combos = itertools.product(Pss_list, Pss_list, range(num_parts))
    parts_similarity_score = [
        compute_pure_protonet_scores_from_two_resolution(
            support_Ps[i], query_Ps[i], target, sup_batch_size, ifsupport)[0]
        for support_Ps, query_Ps, i in part_combos
    ]

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    support_idx = uni_idx if ifsupport else uni_idx[:sup_batch_size]
    class_masks = [support_idx == c for c in range(len(unique_labels))]

    num_edges = len(diffs_list[0])
    diff_similarity_score = []
    for support_diffs, query_diffs, k in itertools.product(diffs_list, diffs_list, range(num_edges)):
        if ifsupport:
            edge_support, edge_query = support_diffs[k], query_diffs[k]
        else:
            edge_support = support_diffs[k][:sup_batch_size]
            edge_query = query_diffs[k][sup_batch_size:]
        prototypes = F.normalize(
            torch.stack([edge_support[m].sum(dim=0) for m in class_masks]), dim=1)
        queries = F.normalize(edge_query, dim=1)
        diff_similarity_score.append(torch.einsum('ik,jk->ij', queries, prototypes))

    return parts_similarity_score, diff_similarity_score, unique_labels, uni_idx


def compute_multiscale_partbypart_combined_protonet_scores_geo_all_maximum_each_part(Pss_list, diffs_list, target, sup_batch_size, ifsupport=False):
    """
    Compute cross-scale per-part protonet scores, grouped part by part.

    Same feature scoring as
    ``compute_multiscale_partbypart_combined_protonet_scores_geo_all``, but
    the output list is ordered part-major: for each part i, then each support
    scale, then each query scale. The geometric (edge) branch is not computed
    in this variant.

    Args:
        Pss_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner index
            the part; every tensor holds support rows followed by query rows
        diffs_list: accepted for signature compatibility; not used here
        target (Tensor): the target labels
        sup_batch_size (int): number of support samples in each batch
        ifsupport (bool): forwarded to
            ``compute_pure_protonet_scores_from_two_resolution``

    Returns:
        (parts_similarity_score, diff_similarity_score, unique_labels, uni_idx):
        ``diff_similarity_score`` is always an empty list; ``unique_labels``
        and ``uni_idx`` come from ``torch.unique(target, return_inverse=True)``.
    """
    num_parts = len(Pss_list[0])
    parts_similarity_score = []
    for i in range(num_parts):
        for support_Ps in Pss_list:
            for query_Ps in Pss_list:
                feat_sims, _, _ = compute_pure_protonet_scores_from_two_resolution(
                    support_Ps[i], query_Ps[i], target, sup_batch_size, ifsupport)
                parts_similarity_score.append(feat_sims)

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)

    # The edge-similarity computation is disabled in this variant, so nothing
    # is produced for it. Returning an empty list (instead of an undefined
    # name, which raised NameError on every call) keeps the 4-tuple shape
    # callers unpack.
    diff_similarity_score = []

    return parts_similarity_score, diff_similarity_score, unique_labels, uni_idx


def compute_multiscale_partbypart_combined_protonet_scores_geo_all_1(Pss_list, diffs_list, target,inner_weight, sup_batch_size, ifsupport=False):
    """
    Compute cross-scale weighted per-part protonet scores and per-edge protonet scores.

    Args:
        Pss_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner index
            the part; every tensor holds support rows followed by query rows
        diffs_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner
            index the edge; per-edge location-difference tensors
        target (Tensor): the target labels
        inner_weight (Tensor): per-sample prototype weights. Column ``col`` is
            used for the ``col``-th (support scale, query scale, part)
            combination, in iteration order.
        sup_batch_size (int): number of support samples in each batch
        ifsupport (bool): if True, no support/query split is made and every
            label in ``target`` is treated as a support label

    Returns:
        (parts_similarity_score, diff_similarity_score, unique_labels, uni_idx):
        both score lists are ordered by support scale, then query scale, then
        part (or edge) index.
    """
    num_parts = len(Pss_list[0])
    part_combos = itertools.product(Pss_list, Pss_list, range(num_parts))
    parts_similarity_score = []
    for col, (support_Ps, query_Ps, i) in enumerate(part_combos):
        feat_sims, _, _ = compute_weighted_protonet_scores_from_two_resolution(
            support_Ps[i], query_Ps[i], target, sup_batch_size, inner_weight[:, col], ifsupport)
        parts_similarity_score.append(feat_sims)

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    support_idx = uni_idx if ifsupport else uni_idx[:sup_batch_size]
    class_masks = [support_idx == c for c in range(len(unique_labels))]

    num_edges = len(diffs_list[0])
    diff_similarity_score = []
    for support_diffs, query_diffs, edge in itertools.product(diffs_list, diffs_list, range(num_edges)):
        if ifsupport:
            edge_support, edge_query = support_diffs[edge], query_diffs[edge]
        else:
            edge_support = support_diffs[edge][:sup_batch_size]
            edge_query = query_diffs[edge][sup_batch_size:]
        prototypes = F.normalize(
            torch.stack([edge_support[m].sum(dim=0) for m in class_masks]), dim=1)
        queries = F.normalize(edge_query, dim=1)
        diff_similarity_score.append(torch.einsum('ik,jk->ij', queries, prototypes))

    return parts_similarity_score, diff_similarity_score, unique_labels, uni_idx

def compute_multiscale_partbypart_pairwise_image_smilarity_score(Pss_list, diffs_list, target, sup_batch_size, weight, weight_geo):
    """
    Compute weighted, class-aggregated pairwise similarity scores across scales.

    For every (support scale, query scale, part) combination, and likewise for
    every (support scale, query scale, edge) combination, this delegates to
    ``compute_pure_pairwise_scores_weighted_from_two_resolution`` and keeps
    only its class-wise score.

    Args:
        Pss_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner index
            the part; each tensor holds support rows followed by query rows
        diffs_list (Sequence[Sequence[Tensor]]): outer index is the scale, inner
            index the edge; per-edge location-difference tensors
        target (Tensor): the target labels
        sup_batch_size (int): number of support samples in each batch
        weight (Tensor): sliced as ``weight[:, :, idx]`` for the idx-th part
            combination; each slice multiplies a [num_query, sup_batch_size]
            similarity matrix (presumably [num_query, sup_batch_size, S*S*num_parts]
            -- TODO confirm)
        weight_geo (Tensor): same role as ``weight`` for the edge combinations

    Returns:
        (parts_similarity_score, diff_similarity_score): lists of
        [num_query, num_classes] class-wise scores, each ordered by support
        scale, then query scale, then part (or edge).
    """
    num_parts = len(Pss_list[0])
    part_combos = itertools.product(Pss_list, Pss_list, range(num_parts))
    parts_similarity_score = []
    for idx, (support_Ps, query_Ps, i) in enumerate(part_combos):
        *_, class_wise = compute_pure_pairwise_scores_weighted_from_two_resolution(
            support_Ps[i], query_Ps[i], target, sup_batch_size, weight[:, :, idx])
        parts_similarity_score.append(class_wise)

    num_edges = len(diffs_list[0])
    edge_combos = itertools.product(diffs_list, diffs_list, range(num_edges))
    diff_similarity_score = []
    for idx, (support_diffs, query_diffs, k) in enumerate(edge_combos):
        *_, geo_class_wise = compute_pure_pairwise_scores_weighted_from_two_resolution(
            support_diffs[k], query_diffs[k], target, sup_batch_size, weight_geo[:, :, idx])
        diff_similarity_score.append(geo_class_wise)

    return parts_similarity_score, diff_similarity_score

def compute_pure_pairwise_scores_weighted_from_two_resolution(featuress,featuresq,target, sup_batch_size,weight):
    """
    Compute weighted query-vs-support cosine similarities and sum them per class.

    Args:
        featuress (Tensor): tensor whose first ``sup_batch_size`` rows are the support vectors
        featuresq (Tensor): tensor whose rows from ``sup_batch_size`` on are the query vectors
        target (Tensor): the target labels for support + query samples
        sup_batch_size (int): number of support samples in each batch
        weight (Tensor): multiplied element-wise into the
            [num_query, sup_batch_size] similarity matrix (must broadcast to it)

    Returns:
        (feat_sims, unique_labels, uni_idx, class_wise_score): ``feat_sims`` is
        the weighted [num_query, sup_batch_size] similarity matrix.
        ``class_wise_score[q, c]`` sums row q of ``feat_sims`` over the
        support samples of class ``unique_labels[c]``, giving shape
        [num_query, num_classes].
    """
    support = F.normalize(featuress[:sup_batch_size], dim=1)
    query = F.normalize(featuresq[sup_batch_size:], dim=1)
    feat_sims = torch.einsum('ik,jk->ij', query, support) * weight

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    support_idx = uni_idx[:sup_batch_size]
    per_class = [feat_sims[:, support_idx == c].sum(dim=1)
                 for c in range(len(unique_labels))]
    class_wise_score = torch.stack(per_class, dim=1)

    return feat_sims, unique_labels, uni_idx, class_wise_score

def compute_pure_protonet_scores(features, target, sup_batch_size):
    """
    Compute the protonet score from features alone.

    The first ``sup_batch_size`` rows are support samples and the rest are
    queries. Each class prototype is the L2-normalised sum of that class's
    support rows (the same direction as the mean). Scores are cosine
    similarities between the L2-normalised queries and the prototypes.

    Args:
        features (Tensor): the feature vectors, shape [nb, d]
        target (Tensor): the target labels for support + query samples
        sup_batch_size (int): number of support samples in each batch

    Returns:
        (feat_sims, unique_labels, uni_idx): ``feat_sims`` is
        [num_query, num_classes] with column c matching ``unique_labels[c]``;
        ``unique_labels``/``uni_idx`` come from
        ``torch.unique(target, return_inverse=True)``.
    """
    feat_support = features[:sup_batch_size]
    feat_query = features[sup_batch_size:]

    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    target_support = uni_idx[:sup_batch_size]
    mean_features = []
    for i in range(len(unique_labels)):
        mask = target_support == i
        mean_features.append(feat_support[mask].sum(dim=0))
    mean_features = F.normalize(torch.stack(mean_features), dim=1)
    feat_query = F.normalize(feat_query, dim=1)

    feat_sims = torch.einsum('ik,jk->ij', feat_query, mean_features)
    return feat_sims, unique_labels, uni_idx

def compute_pure_protonet_scores_from_two_resolution(featuress,featuresq, target, sup_batch_size, ifsupport=False):
    """
    Compute protonet scores where support and query vectors come from different tensors.

    Prototypes are built from ``featuress`` and compared with queries from
    ``featuresq`` (e.g. two different scales). Each class prototype is the
    L2-normalised sum of its support rows; scores are cosine similarities.

    Args:
        featuress (Tensor): source of the support vectors
        featuresq (Tensor): source of the query vectors
        target (Tensor): the target labels
        sup_batch_size (int): number of support samples in each batch
        ifsupport (bool): if False, support = ``featuress[:sup_batch_size]``
            and query = ``featuresq[sup_batch_size:]``. If True, both tensors
            are used whole and every entry of ``target`` labels a support row.

    Returns:
        (feat_sims, unique_labels, uni_idx): ``feat_sims`` is
        [num_query, num_classes]; ``unique_labels``/``uni_idx`` come from
        ``torch.unique(target, return_inverse=True)``.
    """
    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    if ifsupport:
        support, query, support_idx = featuress, featuresq, uni_idx
    else:
        support = featuress[:sup_batch_size]
        query = featuresq[sup_batch_size:]
        support_idx = uni_idx[:sup_batch_size]

    prototypes = torch.stack([support[support_idx == c].sum(dim=0)
                              for c in range(len(unique_labels))])
    prototypes = F.normalize(prototypes, dim=1)
    feat_sims = torch.einsum('ik,jk->ij', F.normalize(query, dim=1), prototypes)
    return feat_sims, unique_labels, uni_idx

def compute_weighted_protonet_scores_from_two_resolution(featuress,featuresq, target, sup_batch_size, inner_weight, ifsupport=False):
    """
    Compute protonet scores with weighted class prototypes.

    For each class, the weights of its support samples are L2-normalised
    (``F.normalize(..., dim=0)``). The prototype is the weighted sum of those
    samples' vectors, then L2-normalised. Scores are cosine similarities
    between the normalised queries and the prototypes.

    Args:
        featuress (Tensor): source of the support vectors, shape [n, d]
        featuresq (Tensor): source of the query vectors
        target (Tensor): the target labels
        sup_batch_size (int): number of support samples in each batch
        inner_weight (Tensor): 1-D per-support-sample weights; must be
            boolean-indexable by the support label mask (same length as the
            support set)
        ifsupport (bool): if False, support = ``featuress[:sup_batch_size]``
            and query = ``featuresq[sup_batch_size:]``. If True, both tensors
            are used whole and every entry of ``target`` labels a support row.

    Returns:
        (feat_sims, unique_labels, uni_idx): ``feat_sims`` is
        [num_query, num_classes]; ``unique_labels``/``uni_idx`` come from
        ``torch.unique(target, return_inverse=True)``.
    """
    unique_labels, uni_idx = torch.unique(target, return_inverse=True)
    if ifsupport:
        support, query, support_idx = featuress, featuresq, uni_idx
    else:
        support = featuress[:sup_batch_size]
        query = featuresq[sup_batch_size:]
        support_idx = uni_idx[:sup_batch_size]

    prototypes = []
    for c in range(len(unique_labels)):
        members = support_idx == c
        member_weights = F.normalize(inner_weight[members], dim=0)
        prototypes.append(torch.einsum('ij,i->j', support[members], member_weights))
    prototypes = F.normalize(torch.stack(prototypes), dim=1)

    feat_sims = torch.einsum('ik,jk->ij', F.normalize(query, dim=1), prototypes)
    return feat_sims, unique_labels, uni_idx

# unique_labels, uni_idx = torch.unique(target, return_inverse=True)
# target_support = uni_idx[:sup_batch_size]
# for i in range(len(unique_labels)):
#     mask = target_support == i
#     inner_weight = inner_support_weight[mask, :]
#     inner_weight = F.normalize(inner_weight, p=1, dim=0)
#     inner_support_weight[mask, :] = inner_weight