from collections import OrderedDict
from operator import mul

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
from pytorch_lightning.core.lightning import LightningModule
from torch import nn
from torchmetrics import Accuracy, ConfusionMatrix
from torchvision.models.vgg import vgg19
from math import comb
import itertools
# from cnn_embedding import ConvEmbedding

from util import plot_conf_mat, images_denorm, plot_attn_maps, plot_mu_prob, compute_pair_wise_diff, \
    compute_combined_protonet_scores, compute_pure_protonet_scores, compute_partbypart_combined_protonet_scores,compute_multiscale_partbypart_combined_protonet_scores, compute_pair_wise_diff_with_xycoord, compute_multiscale_partbypart_combined_protonet_scores_geo,compute_pair_wise_diff_with_xycoord_all, compute_multiscale_partbypart_combined_protonet_scores_geo_all

# Example code using 3 different scales and multi-part parsing.
# Channel count of the final feature map for each backbone name; FeatModel
# uses it to size the attention convolutions and the classifier input.
NUM_CHANNEL = dict(
    conv=64,
    vgg19=512,
    resnet12=640,
    resnet18=512,
    resnet18_feti=512,
    resnet34=512,
    resnet50=2048,
    resnet101=2048,
)


def create_backbone(backbone_model, pretrained):
    """
    Create the feature extractor from the backbone networks.

    Args:
        backbone_model (str): model of the backbone. One of 'conv', 'vgg19',
            'resnet12', 'resnet18', 'resnet18_feti', 'resnet34', 'resnet50',
            'resnet101'.
        pretrained (bool): whether the backbone is pretrained on ImageNet.

    Backbone_model:
        We support vgg19 and mainstream resnet variants. resnet18_feti is a special
        backbone used in mini-imagenet few-shot learning. It changes the kernel size
        of `conv1` to 5. The pretrained resnet18_feti backbone is usually pretrained
        on imagenet training set excluding classes from mini-imagenet.
        For FETI backbone details check https://github.com/peymanbateni/simple-cnaps.
        'resnet12' is accepted by the name check but no builder for it exists
        here (torchvision has no ResNet-12), so it raises NotImplementedError
        rather than silently building a different network.

    Returns:
        module of the backbone feature extractor.

    Raises:
        AssertionError: if `backbone_model` is not a supported name.
        NotImplementedError: if `backbone_model` is 'resnet12'.
    """
    supported_backbone = ['conv','vgg19', 'resnet12','resnet18', 'resnet18_feti', 'resnet34', 'resnet50', 'resnet101']
    assert backbone_model in supported_backbone, f'Backbone {backbone_model} not supported'

    if backbone_model == 'vgg19':
        # Drop the last layer of `features` (presumably the final max-pool) to
        # keep a larger spatial feature map.
        return vgg19(pretrained=pretrained).features[:-1]

    if backbone_model == 'conv':
        # NOTE(review): the `from cnn_embedding import ConvEmbedding` import at the
        # top of this file is commented out, so this raises NameError unless
        # ConvEmbedding is defined elsewhere in this module.
        return ConvEmbedding()

    if backbone_model == 'resnet12':
        # Previously this fell through to the ResNet-101 branch, which also
        # contradicts NUM_CHANNEL['resnet12'] == 640.
        raise NotImplementedError(
            'Backbone resnet12 is not available: torchvision provides no ResNet-12 implementation')

    resnet_builders = {
        'resnet18': models.resnet18,
        'resnet18_feti': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
    }
    tmp = resnet_builders[backbone_model](pretrained=pretrained)

    # change the conv1 filter size if using feti backbone
    if backbone_model == 'resnet18_feti':
        conv1 = nn.Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2), padding=(3, 3), bias=False)
    else:
        conv1 = tmp.conv1

    # Keep the convolutional trunk only (no global pooling / fc head) so the
    # output is a spatial feature map.
    return nn.Sequential(OrderedDict([
        ('conv1', conv1),
        ('bn1', tmp.bn1),
        ('relu', tmp.relu),
        ('maxpool', tmp.maxpool),
        ('layer1', tmp.layer1),
        ('layer2', tmp.layer2),
        ('layer3', tmp.layer3),
        ('layer4', tmp.layer4),
    ]))


def create_2l_classifier(input_dim, hidden_dim, output_dim):
    """
    Build a two-layer MLP classifier: Linear -> ReLU -> Linear.

    Args:
        input_dim (int): size of the input feature vector.
        hidden_dim (int): width of the hidden layer.
        output_dim (int): number of output logits.

    Returns:
        nn.Sequential with submodules named 'fc_1', 'relu_1' and 'fc_2'.
    """
    layers = OrderedDict()
    layers['fc_1'] = nn.Linear(input_dim, hidden_dim)
    layers['relu_1'] = nn.ReLU(inplace=True)
    layers['fc_2'] = nn.Linear(hidden_dim, output_dim)
    return nn.Sequential(layers)

def create_weightmap(input_dim, hidden_dim, output_dim, mode='relu'):
    """
    Create a small MLP that maps an input vector to a vector of weights.

    Args:
        input_dim (int): input feature dimension.
        hidden_dim (int or None): hidden layer width. When None, a single
            linear layer followed by a softmax over dim 1 is built.
        output_dim (int): number of weights produced.
        mode (str): currently unused; kept for interface compatibility.

    Returns:
        nn.Sequential. With `hidden_dim is None`: 'fc_1' -> 'softmax_1'
        (weights along dim 1 sum to 1). Otherwise: 'fc_1' -> 'relu_1' ->
        'fc_2' -> 'relu_2' (non-negative, unnormalized weights).
    """
    if hidden_dim is not None:
        return nn.Sequential(OrderedDict([
            ('fc_1', nn.Linear(input_dim, hidden_dim)),
            ('relu_1', nn.ReLU(inplace=True)),
            ('fc_2', nn.Linear(hidden_dim, output_dim)),
            ('relu_2', nn.ReLU(inplace=True)),
        ]))

    return nn.Sequential(OrderedDict([
        ('fc_1', nn.Linear(input_dim, output_dim)),
        ('softmax_1', nn.Softmax(dim=1)),
    ]))


def create_attn_module(num_channel, num_attn, conv_size):
    """
    Build the attention head: a square Conv2d producing one map per part, then ReLU.

    Args:
        num_channel (int): channels of the input feature map.
        num_attn (int): number of attention maps (output channels).
        conv_size (int): kernel side length; padding is conv_size // 2, so odd
            sizes preserve the spatial resolution.

    Returns:
        nn.Sequential with submodules 'attn_conv' and 'relu'.
    """
    conv = nn.Conv2d(num_channel, num_attn,
                     kernel_size=(conv_size, conv_size),
                     padding=conv_size // 2)
    return nn.Sequential(OrderedDict(attn_conv=conv, relu=nn.ReLU()))


class DisLoss(nn.Module):
    """Distance loss.

    Penalizes the mass of each map that lies far from that map's own peak:
    sum over pixels of M * squared distance to the argmax pixel, summed over
    maps and averaged over the batch.
    """

    def __init__(self, num_attn):
        super(DisLoss, self).__init__()
        self.num_attn = num_attn

    def forward(self, Ms):
        """
        Args:
            Ms (Tensor): (nb, na, w, h) maps; w and h need not be equal.

        Returns:
            scalar tensor with the batch-averaged distance loss.
        """
        w, h = Ms.shape[-2], Ms.shape[-1]
        M_tmp = Ms.view(Ms.shape[0], Ms.shape[1], -1)
        # find the peak pixel coordinates (no gradient through the argmax)
        peak = M_tmp.argmax(dim=-1, keepdim=True).detach()
        # The flattened index is row * h + col.
        tx = (peak // h).unsqueeze(2).float()
        ty = (peak % h).unsqueeze(3).float()
        # form the distance weights to peak; rows span w, columns span h
        # (previously both used h, which broke non-square maps)
        x = torch.arange(end=w, device=tx.device).reshape(1, 1, w, 1)
        y = torch.arange(end=h, device=ty.device).reshape(1, 1, 1, h)
        dis_weight = (x - tx).pow(2) + (y - ty).pow(2)
        dis_loss = (Ms * dis_weight).sum(dim=(1, 2, 3)).mean(dim=0)
        return dis_loss


class DivLoss(nn.Module):
    """Divergence loss.

    For each attention map i, measures its overlap with the element-wise
    maximum of the other maps as sum(sqrt(p_i * max_{j != i} p_j)) over the
    spatial positions. This is averaged over the batch and then over the maps.
    Lower values mean the maps cover more distinct regions.
    """

    def __init__(self, num_attn):
        super(DivLoss, self).__init__()
        self.num_attn = num_attn

    def forward(self, mu_probs):
        """
        Args:
            mu_probs (Tensor): (nb, num_attn, w, h) per-map probabilities.

        Returns:
            scalar tensor with the averaged overlap. It is zero when there are
            fewer than two maps.
        """
        if self.num_attn < 2:
            # No other map to compare against; max() over an empty dim would raise.
            return mu_probs.new_zeros(())
        loss = 0
        for i in range(self.num_attn):
            cur_prob = mu_probs[:, i]
            rest_mask = torch.ones(self.num_attn, dtype=torch.bool, device=mu_probs.device)
            rest_mask[i] = False
            rest_probs_max = mu_probs[:, rest_mask].max(dim=1)[0]
            loss += torch.sqrt(cur_prob * rest_probs_max).sum(dim=(1, 2)).mean()
        return loss / self.num_attn


class ProtoLoss(nn.Module):
    """Prototypical-network loss for few-shot episodes.

    The batch is laid out as support samples first, then query samples.
    Class prototypes are the means of the support embeddings. Queries are
    scored by cosine similarity to each prototype divided by `temperature`,
    and the loss is the mean cross-entropy over queries.
    """

    def __init__(self, num_support_tr, num_query_tr, classes_per_it_tr, temperature):
        super(ProtoLoss, self).__init__()
        self.num_support_tr = num_support_tr
        self.num_query_tr = num_query_tr
        self.classes_per_it_tr = classes_per_it_tr
        # number of leading rows of the batch that belong to the support set
        self.sup_batch_size = num_support_tr * classes_per_it_tr
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss(reduction='mean')

    def forward(self, Ps, target):
        """
        Args:
            Ps (Tensor): embeddings with the batch on dim 0 (flattened per row).
            target (Tensor): raw class labels for every row (support + query).

        Returns:
            scalar cross-entropy loss over the query rows.
        """
        n_sup = self.sup_batch_size
        flat = Ps.view(Ps.shape[0], -1)
        support, query = flat[:n_sup], flat[n_sup:]

        # Remap arbitrary labels to 0..C-1 so they index the prototype rows.
        labels, label_idx = torch.unique(target, return_inverse=True)
        support_idx = label_idx[:n_sup]
        prototypes = torch.stack(
            [support[support_idx == c].mean(dim=0) for c in range(len(labels))])

        prototypes = F.normalize(prototypes, dim=1)
        query = F.normalize(query, dim=1)
        logits = torch.einsum('ik,jk->ij', query, prototypes) / self.temperature
        return self.criterion(logits, label_idx[n_sup:])


class CombinedProtoLoss(ProtoLoss):
    """Proto-net loss that adds a pairwise-location term to the feature term.

    Total loss = CE(feature scores) + alpha * CE(pairwise-difference scores).
    Both sets of scores come from `compute_combined_protonet_scores` and are
    divided by `temperature`.
    """

    def __init__(self, num_support_tr, num_query_tr, classes_per_it_tr, temperature, num_attn, alpha):
        super(CombinedProtoLoss, self).__init__(
            num_support_tr, num_query_tr, classes_per_it_tr, temperature)
        self.alpha = alpha
        self.num_attn = num_attn

    def forward(self, Ps, mus, target):
        """
        Args:
            Ps (Tensor): part features, batch on dim 0 (flattened per row).
            mus (Tensor): part locations passed to `compute_pair_wise_diff`.
            target (Tensor): class labels for support + query rows.

        Returns:
            scalar combined loss.
        """
        flat = Ps.view(Ps.shape[0], -1)
        pair_diffs = compute_pair_wise_diff(mus, self.num_attn)
        feat_sims, diff_sims, _, uni_idx = compute_combined_protonet_scores(
            flat, pair_diffs, target, self.sup_batch_size)
        query_targets = uni_idx[self.sup_batch_size:]
        feat_term = self.criterion(feat_sims / self.temperature, query_targets)
        diff_term = self.criterion(diff_sims / self.temperature, query_targets)
        return feat_term + self.alpha * diff_term


class MultiscaleWeightedPairwiseCombinedProtoLoss(ProtoLoss):
    """Proto-net loss over multi-scale part features with per-part weights.

    Each part's per-query cross-entropy is weighted by the matching column of
    the query rows of `weights`. The weighted values are summed and divided by
    the number of query rows. `alpha` times the mean cross-entropy of the
    pairwise-difference scores is then added.
    """

    def __init__(self, num_support_tr, num_query_tr, classes_per_it_tr, temperature, num_attn, alpha):
        super(MultiscaleWeightedPairwiseCombinedProtoLoss, self).__init__(
            num_support_tr, num_query_tr, classes_per_it_tr, temperature)
        self.alpha = alpha
        self.num_attn = num_attn
        # per-sample losses so the part weights can be applied afterwards
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, Pss, Pss1, Pss2, mus, target, weights):
        """
        Args:
            Pss, Pss1, Pss2: part features at the three scales.
            mus: part locations passed to `compute_pair_wise_diff`.
            target (Tensor): class labels for support + query rows.
            weights (Tensor): 2-D per-row part weights; the rows after the
                support block are used, one column per part score.

        Returns:
            scalar loss.
        """
        n_sup = self.sup_batch_size
        diffs = compute_pair_wise_diff(mus, self.num_attn)
        part_scores, diff_sims, _, uni_idx = compute_multiscale_partbypart_combined_protonet_scores(
            [Pss, Pss1, Pss2], diffs, target, n_sup)

        query_targets = uni_idx[n_sup:]
        query_weights = weights[n_sup:]
        n_rows = query_weights.shape[0]

        feature_loss = 0
        for part_idx, scores in enumerate(part_scores):
            per_sample = self.criterion(scores / self.temperature, query_targets)
            feature_loss = feature_loss + torch.dot(per_sample, query_weights[:, part_idx]) / n_rows

        diff_loss = self.criterion(diff_sims / self.temperature, query_targets).mean()
        return feature_loss + self.alpha * diff_loss

class MultiscaleWeightedPairwiseGeometricalCombinedProtoLoss(ProtoLoss):
    """Proto-net loss over multi-scale part features plus weighted geometric edges.

    Part-feature scores and per-edge geometric scores come from
    `compute_multiscale_partbypart_combined_protonet_scores_geo`. Each score
    set yields a per-query cross-entropy. This is weighted by the matching
    column of `weights` (parts) or `weights_geo` (edges), restricted to the
    query rows, then summed and divided by the number of query rows.
    Total loss = sum of part terms + alpha * sum of edge terms.
    """

    def __init__(self, num_support_tr, num_query_tr, classes_per_it_tr, temperature, num_attn, alpha):
        super(MultiscaleWeightedPairwiseGeometricalCombinedProtoLoss, self).__init__(
            num_support_tr, num_query_tr, classes_per_it_tr, temperature)
        self.alpha = alpha
        self.num_attn = num_attn
        # Per-sample losses: the weights are applied after the softmax/CE
        # rather than to the logits.
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, Pss, Pss1, Pss2, mus, target, weights, weights_geo):
        """
        Args:
            Pss, Pss1, Pss2: part features at the three scales.
            mus: part locations passed to `compute_pair_wise_diff_with_xycoord`.
            target (Tensor): class labels for support + query rows.
            weights (Tensor): 2-D per-row part weights (support + query rows).
            weights_geo (Tensor): 2-D per-row edge weights (support + query rows).

        Returns:
            scalar loss.
        """
        diffs_list = compute_pair_wise_diff_with_xycoord(mus, self.num_attn)
        Pss_list = [Pss, Pss1, Pss2]
        parts_similarity_score, diff_similarity_score, unique_labels, uni_idx = compute_multiscale_partbypart_combined_protonet_scores_geo(
            Pss_list, diffs_list, target, self.sup_batch_size)

        target_query = uni_idx[self.sup_batch_size:]
        weight_query = weights[self.sup_batch_size:]
        weight_query_geo = weights_geo[self.sup_batch_size:]

        loss_features = 0
        for i, part_scores in enumerate(parts_similarity_score):
            loss_feat = self.criterion(part_scores / self.temperature, target_query)
            loss_feat = torch.dot(loss_feat, weight_query[:, i]) / weight_query.shape[0]
            loss_features = loss_features + loss_feat

        # one term per edge of the geometry
        loss_diffs = 0
        for i, edge_scores in enumerate(diff_similarity_score):
            loss_diff = self.criterion(edge_scores / self.temperature, target_query)
            loss_diff = torch.dot(loss_diff, weight_query_geo[:, i]) / weight_query_geo.shape[0]
            loss_diffs = loss_diffs + loss_diff
        # Use the accumulated edge loss; returning the loop variable kept only
        # the last edge (and raised NameError with no edges).
        return loss_features + self.alpha * loss_diffs



class MultiscaleWeightedPairwiseGeometricalCombinedProtoLossAll(ProtoLoss):
    """Multi-scale part + geometric proto-net loss using all pairwise edges.

    Same structure as the weighted geometrical loss, but the edges come from
    `compute_pair_wise_diff_with_xycoord_all` and the scores from
    `compute_multiscale_partbypart_combined_protonet_scores_geo_all`.
    `weights` and `weights_geo` are used without slicing off support rows.
    Total loss = sum of weighted part terms + alpha * sum of weighted edge terms.
    """

    def __init__(self, num_support_tr, num_query_tr, classes_per_it_tr, temperature, num_attn, alpha):
        super(MultiscaleWeightedPairwiseGeometricalCombinedProtoLossAll, self).__init__(
            num_support_tr, num_query_tr, classes_per_it_tr, temperature)
        self.alpha = alpha
        self.num_attn = num_attn
        # Per-sample losses: the weights are applied after the softmax/CE
        # rather than to the logits.
        self.criterion = nn.CrossEntropyLoss(reduction='none')

    def forward(self, Pss, Pss1, Pss2, mus, target, weights, weights_geo):
        """
        Args:
            Pss, Pss1, Pss2: part features at the three scales.
            mus: part locations passed to `compute_pair_wise_diff_with_xycoord_all`.
            target (Tensor): class labels for support + query rows.
            weights (Tensor): 2-D part weights. NOTE(review): not sliced here, so
                presumably the caller passes query rows only (torch.dot needs
                the row count to match the number of queries) — confirm.
            weights_geo (Tensor): 2-D edge weights, same row convention.

        Returns:
            scalar loss.
        """
        diffs_list = compute_pair_wise_diff_with_xycoord_all(mus, self.num_attn)
        Pss_list = [Pss, Pss1, Pss2]
        parts_similarity_score, diff_similarity_score, unique_labels, uni_idx = compute_multiscale_partbypart_combined_protonet_scores_geo_all(
            Pss_list, diffs_list, target, self.sup_batch_size)

        target_query = uni_idx[self.sup_batch_size:]
        weight_query = weights
        weight_query_geo = weights_geo

        loss_features = 0
        for i, part_scores in enumerate(parts_similarity_score):
            loss_feat = self.criterion(part_scores / self.temperature, target_query)
            loss_feat = torch.dot(loss_feat, weight_query[:, i]) / weight_query.shape[0]
            loss_features = loss_features + loss_feat

        # one term per edge of the geometry
        loss_diffs = 0
        for i, edge_scores in enumerate(diff_similarity_score):
            loss_diff = self.criterion(edge_scores / self.temperature, target_query)
            loss_diff = torch.dot(loss_diff, weight_query_geo[:, i]) / weight_query_geo.shape[0]
            loss_diffs = loss_diffs + loss_diff
        # Use the accumulated edge loss; returning the loop variable kept only
        # the last edge (and raised NameError with no edges).
        return loss_features + self.alpha * loss_diffs





class ProtoPlusLoss(ProtoLoss):
    """Loss for proto-net++ FSL training.

    Instead of class prototypes, each query is compared with every individual
    support embedding (cosine similarity / temperature, softmax over supports).
    The loss is the negative log-probability mass on same-class supports,
    averaged over supports-per-class and queries.
    """

    def __init__(self, num_support_tr, num_query_tr, classes_per_it_tr, temperature):
        super(ProtoPlusLoss, self).__init__(num_support_tr, num_query_tr,
                                            classes_per_it_tr, temperature)

    def forward(self, Ps, target):
        """
        Args:
            Ps (Tensor): embeddings, batch on dim 0 (support rows first).
            target (Tensor): class labels for every row.

        Returns:
            scalar loss.
        """
        n_sup = self.sup_batch_size
        emb = F.normalize(Ps.view(Ps.shape[0], -1), p=2, dim=1)

        _, label_idx = torch.unique(target, return_inverse=True)
        sup_labels, qry_labels = label_idx[:n_sup], label_idx[n_sup:]

        logits = torch.einsum('ik,jk->ij', emb[n_sup:], emb[:n_sup]) / self.temperature
        log_p_y = F.log_softmax(logits, dim=1)
        same_class = qry_labels[:, None] == sup_labels[None, :]

        return -log_p_y[same_class].sum() / self.num_support_tr / len(qry_labels)


class FeatLoss(nn.Module):
    """Feature module loss.

    A weighted sum of the attention divergence loss (DivLoss) and a
    cross-entropy classification loss. The weights are normalized to sum to 1.
    The individual weighted terms of the last call are kept in `self.losses`.
    """

    def __init__(self, num_attn, loss_weights):
        super(FeatLoss, self).__init__()
        self.num_attn = num_attn
        total = sum(loss_weights)
        assert total > 0
        self.loss_weights = [w / total for w in loss_weights]
        self.div_loss_func = DivLoss(num_attn)
        self.cls_loss_func = nn.CrossEntropyLoss(reduction='mean')
        self.losses = []

    def forward(self, mu_probs, pred_logits, labels, weight=None):
        """
        Args:
            mu_probs: attention probability maps for DivLoss.
            pred_logits: classifier logits.
            labels: class labels.
            weight: unused.

        Returns:
            weighted sum of [divergence loss, classification loss].
        """
        terms = [self.div_loss_func(mu_probs), self.cls_loss_func(pred_logits, labels)]
        self.losses = [w * term for w, term in zip(self.loss_weights, terms)]
        return sum(self.losses)


class FSLFeatLoss(FeatLoss):
    """FSL loss for the feature module.

    A weighted sum (normalized `loss_weights`) of two terms:
      * the DivLoss averaged over the per-scale attention probability maps;
      * the multi-scale weighted geometrical proto-net loss.
    """

    def __init__(self, num_attn, loss_weights, num_support_tr, num_query_tr,
                 classes_per_it_tr, temp_proto, alpha):
        super(FSLFeatLoss, self).__init__(num_attn, loss_weights)
        self.num_support_tr = num_support_tr
        self.num_query_tr = num_query_tr
        # Backward-compatible alias for the previously misspelled attribute.
        self.num_qeury_tr = num_query_tr
        self.classes_per_it_tr = classes_per_it_tr
        self.temp_proto = temp_proto
        self.alpha = alpha
        self.proto_loss_func = MultiscaleWeightedPairwiseGeometricalCombinedProtoLossAll(
            num_support_tr, num_query_tr, classes_per_it_tr, temp_proto, num_attn, alpha)

    def forward(self, mus, mu_probs, labels, Pss, Pss1, Pss2, weight, weight_geo):
        """
        Args:
            mus: part locations for the proto loss.
            mu_probs: sequence of per-scale attention probability maps.
            labels: class labels for support + query rows.
            Pss, Pss1, Pss2: part features at the three scales.
            weight, weight_geo: part and edge weights for the proto loss.

        Returns:
            weighted sum of [divergence loss, proto loss]; the individual
            weighted terms are stored in `self.losses`.
        """
        # Average the divergence loss over however many scales are provided
        # (identical to the previous hard-coded three-scale average).
        loss_div = sum(self.div_loss_func(p) for p in mu_probs) / len(mu_probs)
        loss_proto = self.proto_loss_func(Pss, Pss1, Pss2, mus, labels, weight, weight_geo)
        loss_list = list(map(mul, self.loss_weights, [loss_div, loss_proto]))
        self.losses = loss_list
        return sum(loss_list)


class FeatModel(LightningModule):
    """
    The lightning module for the main model.

    A backbone produces a feature map. An attention head scores each spatial
    location for `num_attn` parts. Each part is reduced to a centre (expected
    location) and a Gaussian map, which drives attention pooling of the
    backbone features; the pooled features are classified with a 2-layer MLP.
    The extra heads (`feat_module1/2`, `attn_conv1/2`, `weight_map`,
    `geo_weight_map`) are built here but not used by this class's own
    `feat_forward`.
    """

    def __init__(self, hparams, *args, **kwargs):
        super(FeatModel, self).__init__()
        self.hparams.update(vars(hparams))
        self.save_hyperparameters(hparams)
        nc = NUM_CHANNEL[self.hparams.backbone_model]
        self.nc = nc
        self.feat_extractor = create_backbone(self.hparams.backbone_model, self.hparams.pretrain_backbone)
        # matching mu_conv_size to attn_size
        self.feat_module = create_attn_module(nc, self.hparams.num_attn, self.hparams.attn_conv_size)
        self.feat_module1 = create_attn_module(nc, self.hparams.num_attn, self.hparams.attn_conv_size1)
        self.feat_module2 = create_attn_module(nc, self.hparams.num_attn, self.hparams.attn_conv_size2)

        self.attn_conv = nn.ModuleList([nn.Conv2d(1, nc * self.hparams.num_theta,
                                                  kernel_size=self.hparams.attn_conv_size,
                                                  padding=self.hparams.attn_conv_size // 2) for _ in range(self.hparams.num_attn)])
        self.attn_conv1 = nn.ModuleList([nn.Conv2d(1, nc * self.hparams.num_theta,
                                                   kernel_size=self.hparams.attn_conv_size1,
                                                   padding=self.hparams.attn_conv_size1 // 2) for _ in
                                         range(self.hparams.num_attn)])
        self.attn_conv2 = nn.ModuleList([nn.Conv2d(1, nc * self.hparams.num_theta,
                                                   kernel_size=self.hparams.attn_conv_size2,
                                                   padding=self.hparams.attn_conv_size2 // 2) for _ in
                                         range(self.hparams.num_attn)])
        self.weight_map = create_weightmap(self.hparams.num_attn * 3 * 2, None, self.hparams.num_attn*9)
        self.geo_weight_map = create_weightmap(self.hparams.num_attn * 3 *2, None, comb(self.hparams.num_attn,2)*3*3)
        self.clf = create_2l_classifier(nc * self.hparams.num_attn, 1024, self.hparams.num_class)
        self.acc_metric = Accuracy()
        self.cm_metric = ConfusionMatrix(self.hparams.num_class)
        self.criterion = FeatLoss(self.hparams.num_attn, self.hparams.loss_weights)
        self.mrg = self.hparams.shrinkage_mrg

    def feat_forward(self, x):
        """
        Extract part features from a batch of images.

        Returns:
            mus: (nb, 2 * na) part centres, all x-coordinates then all y-coordinates.
            mu_prob: (nb, na, w, h) per-part spatial probabilities (each map sums to 1).
            Ms: (nb, na, w, h) isotropic Gaussian maps centred at each part centre.
            Ps: (nb, nc * num_theta, na) attention-pooled part features after shrinkage.
        """
        feat_maps = self.feat_extractor(x)
        nb = feat_maps.shape[0]
        na = self.hparams.num_attn
        w, h = feat_maps.shape[2:4]

        # predict mu's probability: zero out responses not above
        # cut_th * (per-map maximum), then take a clipped temperature softmax
        # over spatial positions (the clip guards exp against overflow).
        temp = self.feat_module(feat_maps)
        peak = temp.view(temp.shape[0], temp.shape[1], -1).max(dim=-1, keepdim=True)[0]
        thresh = (self.hparams.cut_th * peak).unsqueeze(3)
        temp = torch.where(temp > thresh, temp, torch.zeros_like(temp))
        mu_prob = torch.exp(torch.clip(temp / self.hparams.mu_softmax_temp, max=20))
        mu_prob = mu_prob / mu_prob.sum(dim=(2, 3), keepdim=True)

        # Pixel-centre coordinate grids of shape (w, h). They broadcast against
        # the (nb, na, w, h) maps, so one grid serves both uses below.
        grid_x = torch.arange(w, device=feat_maps.device) + .5
        grid_y = torch.arange(h, device=feat_maps.device) + .5
        maps_x, maps_y = torch.meshgrid(grid_x, grid_y)

        # compute mu's mean (expected location under mu_prob)
        mu_x = (mu_prob * maps_x).sum(dim=(-1, -2))  # nb x na
        mu_y = (mu_prob * maps_y).sum(dim=(-1, -2))  # nb x na
        mus = torch.cat([mu_x, mu_y], dim=1)

        # form the Gaussian distribution of mu
        mu_x = mu_x.view(nb, na, 1, 1)
        mu_y = mu_y.view(nb, na, 1, 1)
        sigma = self.hparams.sigma
        Ms = torch.exp(-.5 * (((maps_x - mu_x) / sigma).pow(2) + ((maps_y - mu_y) / sigma).pow(2))) / sigma ** 2

        # apply convolution and attention pooling
        feats = []
        for i in range(na):
            attn_map = torch.sigmoid(self.attn_conv[i](Ms[:, i].unsqueeze(1)))
            attn_map = attn_map.view(nb, self.nc, self.hparams.num_theta, w, h)
            attn_norm = torch.norm(attn_map, p='fro', dim=(3, 4), keepdim=True)
            attn_map = attn_map / attn_norm
            feat = (feat_maps.unsqueeze(dim=2) * attn_map).sum(dim=(-1, -2))  # [nb x nc x num_theta]
            feats.append(feat.view(nb, -1))
        Ps = torch.stack(feats, dim=2)

        # soft-threshold Ps by shrinkage_mrg times the per-part maximum over dim 1
        mrg, _ = Ps.max(dim=1, keepdim=True)
        mrg = mrg * self.mrg
        Ps = torch.sign(Ps) * torch.clip(Ps.abs() - mrg, min=0)
        return mus, mu_prob, Ms, Ps

    def forward(self, x):
        """Return (mus, mu_prob, Ms, class logits) for a batch of images."""
        nb = x.shape[0]  # batch size
        mus, mu_prob, Ms, Ps = self.feat_forward(x)
        total_feats = Ps.view(nb, -1)  # fuse the features
        cls_total = self.clf(total_feats)
        return mus, mu_prob, Ms, cls_total

    def training_step(self, batch, batch_idx):
        data, target = batch
        mus, mu_prob, Ms, logits = self(data)
        loss = self.criterion(mu_prob, logits, target)
        self.log('train/div_loss', self.criterion.losses[0], prog_bar=True)
        self.log('train/cls_loss', self.criterion.losses[1], prog_bar=True)
        self.log('train/total_loss', loss)
        if (batch_idx + 1) % 20 == 0:
            self.vis_attn_maps(data, Ms, mu_prob, 2)
        return loss

    def validation_step(self, batch, batch_idx):
        data, target = batch
        preds = self(data)[-1].argmax(dim=1)
        acc = self.acc_metric(preds, target)
        self.cm_metric(preds, target)
        self.log('val_acc', acc, prog_bar=True)
        self.log('hp_metric', acc)
        return preds

    def validation_epoch_end(self, outputs):
        conf_mat = self.cm_metric.compute()
        tb = self.logger.experiment
        conf_mat_fig = plot_conf_mat(conf_mat.cpu().detach().numpy(),
                                     np.arange(self.hparams.num_class),
                                     fig_size=(17, 15), label_font_size=5, show_color_bar=True)
        tb.add_figure('Conf Mat/val', conf_mat_fig, self.global_step)
        self.cm_metric.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr,
                                     weight_decay=self.hparams.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=self.hparams.scheduler_step_size,
                                                    gamma=self.hparams.scheduler_gamma)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}

    @staticmethod
    def add_model_arch_args(parent_parser):
        def str2bool(value):
            """Parse a CLI boolean. argparse's `type=bool` maps any non-empty
            string (including 'False') to True, so it cannot be used here."""
            if isinstance(value, bool):
                return value
            lowered = value.strip().lower()
            if lowered in ('1', 'true', 't', 'yes', 'y'):
                return True
            if lowered in ('', '0', 'false', 'f', 'no', 'n'):
                return False
            # argparse reports ValueError from `type` as an invalid-value error.
            raise ValueError(f'expected a boolean value, got {value!r}')

        parser = parent_parser.add_argument_group('model_arch')
        parser.add_argument('--backbone_model', type=str, default='resnet34')
        parser.add_argument('--pretrain_backbone', type=str2bool, default=False)
        parser.add_argument('--num_class', type=int, default=200)
        parser.add_argument('--input_size', type=int, default=224)
        parser.add_argument('--num_attn', type=int, default=3)
        parser.add_argument('--mode', type=str, default='a')
        parser.add_argument('--attn_conv_size', type=int, default=1)
        parser.add_argument('--attn_conv_size1', type=int, default=3)
        parser.add_argument('--attn_conv_size2', type=int, default=5)
        parser.add_argument('--mu_conv_size', type=int, default=5)
        parser.add_argument('--mu_softmax_temp', type=float, default=0.05)
        parser.add_argument('--cut_th', type=float, default=0.9)
        parser.add_argument('--shrinkage_mrg', type=float, default=0.05)
        parser.add_argument('--sigma', type=float, default=0.5)
        parser.add_argument('--num_theta', type=int, default=1)
        return parent_parser

    @staticmethod
    def add_model_loss_args(parent_parser):
        parser = parent_parser.add_argument_group('model_loss')
        parser.add_argument('--loss_weights', nargs=2, type=float, default=[1., 2.])
        return parent_parser

    @staticmethod
    def add_model_train_args(parent_parser):
        parser = parent_parser.add_argument_group('model_train')
        parser.add_argument('--lr', type=float, default=1e-4)
        parser.add_argument('--weight_decay', type=float, default=1e-5)
        parser.add_argument('--scheduler_step_size', type=int, default=8)
        parser.add_argument('--scheduler_gamma', type=float, default=0.5)
        return parent_parser

    def vis_attn_maps(self, data, Ms, mu_prob, n):
        """Log attention maps and mu probabilities for the first `n` samples to TensorBoard."""
        data_vis = data.detach().cpu().numpy()
        maps_vis = Ms.detach().cpu().numpy()
        mu_prob_vis = mu_prob.detach().cpu().numpy()
        if n < data.shape[0]:
            data_vis = data_vis[:n]
            maps_vis = maps_vis[:n]
            mu_prob_vis = mu_prob_vis[:n]
        data_vis = images_denorm(data_vis)
        # transpose to numpy image format
        data_vis = data_vis.transpose((0, 2, 3, 1))
        attn_fig = plot_attn_maps(data_vis, maps_vis)
        mu_prob_fig = plot_mu_prob(data_vis, mu_prob_vis)
        tb = self.logger.experiment
        tb.add_figure('Train/Attn_maps', attn_fig, self.global_step)
        tb.add_figure('Mu_prob', mu_prob_fig, self.global_step)


class FSLFeatModel(FeatModel):
    """
    Episodic few-shot variant of `FeatModel` using three attention scales.

    Each scale (`feat_module`/`attn_conv`, `feat_module1`/`attn_conv1`,
    `feat_module2`/`attn_conv2`) predicts `num_attn` part locations and pools one
    part feature per location. Queries are scored against the support set by
    combining part-feature similarities with geometric (part-centre offset)
    similarities, each weighted per query by `weight_map` / `geo_weight_map`.

    Within an episode batch the first `sup_batch_size` samples are the support
    set and the remaining samples are the queries.
    """

    def __init__(self, hparams, *args, **kwargs):
        super(FSLFeatModel, self).__init__(hparams)
        self.hparams.update(vars(hparams))
        self.save_hyperparameters(hparams)
        self.num_support_val = self.hparams.num_support_val
        self.num_query_val = self.hparams.num_query_val
        # number of support samples at the front of an evaluation / training episode
        self.sup_batch_size = self.num_support_val * self.hparams.classes_per_it_val
        self.sup_batch_size_train = self.hparams.num_support_tr * self.hparams.classes_per_it_tr
        # attention variant used by `compute_attention` ('a'/'f1' or 'b'/'f2')
        self.mode = hparams.mode
        if self.hparams.train_mode == 'episode':
            self.criterion = FSLFeatLoss(self.hparams.num_attn,
                                         self.hparams.loss_weights,
                                         self.hparams.num_support_tr,
                                         self.hparams.num_query_tr,
                                         self.hparams.classes_per_it_tr,
                                         self.hparams.temp_proto,
                                         self.hparams.alpha)

    def update_evaluate_param(self, num_support_val, classes_per_it_val):
        """Set the evaluation support-set size to `num_support_val * classes_per_it_val`."""
        self.num_support_val = num_support_val
        self.sup_batch_size = self.num_support_val * classes_per_it_val

    def set_alpha(self, alpha):
        """
        Set the weight of the geometric similarity term used by `pred`.

        `pred` reads `self.hparams.alpha`, so the value is stored there; it is also
        kept in `self.alpha` for backward compatibility.
        NOTE(review): the `criterion` built in `__init__` received alpha at
        construction time and is not updated here.
        """
        self.alpha = alpha
        self.hparams.alpha = alpha

    def feat_forward(self, x, sup_batch_size):
        """
        Extract multi-scale part features and per-query scale weights.

        Args:
            x: episode batch; the first `sup_batch_size` samples are the support set,
               the remaining samples are queries.
            sup_batch_size: number of support samples at the front of `x`.

        Returns:
            mus_list, mu_prob_list, Ms_list: the per-scale outputs of
                `compute_attention` for scales 0, 1 and 2.
            Pss, Pss1, Pss2: per-scale lists of `num_attn` pooled part features
                (see `decompute_attention`).
            mu_weight, geo_weight: outputs of `weight_map` / `geo_weight_map` for
                every (query, support) pair, summed over the supports and
                L2-normalised, i.e. one row per query (assuming `x` holds at least
                `sup_batch_size` samples).
        """
        feat_maps = self.feat_extractor(x)
        nb = feat_maps.shape[0]
        na = self.hparams.num_attn
        w, h = feat_maps.shape[2:4]

        mu_prob, mus, Ms = self.compute_attention(feat_maps, self.feat_module, padding_size=0, mode=self.mode)
        mu_prob1, mus1, Ms1 = self.compute_attention(feat_maps, self.feat_module1, padding_size=1, mode=self.mode)
        mu_prob2, mus2, Ms2 = self.compute_attention(feat_maps, self.feat_module2, padding_size=2, mode=self.mode)
        mu_prob_list = [mu_prob, mu_prob1, mu_prob2]
        mus_list = [mus, mus1, mus2]
        Ms_list = [Ms, Ms1, Ms2]

        # entropy of each part's location distribution, all scales side by side
        mu_entropy = torch.cat(
            [(-p * torch.log(p)).sum(dim=(-1, -2)) for p in mu_prob_list], dim=1)

        # Pair every query with every support, query-major: row q * n_support + s
        # holds [entropy(query q), entropy(support s)] -- the same order as
        # itertools.product(range(n_query), range(n_support)).
        mu_entropy_support = mu_entropy[:sup_batch_size]
        mu_entropy_query = mu_entropy[sup_batch_size:]
        n_support = mu_entropy_support.shape[0]
        n_query = mu_entropy_query.shape[0]
        pair_entropy = torch.cat([
            mu_entropy_query.repeat_interleave(n_support, dim=0),
            mu_entropy_support.repeat(n_query, 1),
        ], dim=1)

        mu_weight = self.weight_map(pair_entropy)
        geo_weight = self.geo_weight_map(pair_entropy)
        mu_weight = self._aggregate_pair_weights(mu_weight, sup_batch_size)
        geo_weight = self._aggregate_pair_weights(geo_weight, sup_batch_size)

        # Modes f1/f2 tie each part's attention conv to the (just normalised)
        # kernels of the corresponding feat_module.
        if self.mode in ('f1', 'f2'):
            self._tie_attn_conv_weights(na)

        # apply convolution and attention pooling
        Pss = self.decompute_attention(self.attn_conv, na, nb, w, h, feat_maps, Ms)
        Pss1 = self.decompute_attention(self.attn_conv1, na, nb, w, h, feat_maps, Ms1)
        Pss2 = self.decompute_attention(self.attn_conv2, na, nb, w, h, feat_maps, Ms2)

        return mus_list, mu_prob_list, Ms_list, Pss, Pss1, Pss2, mu_weight, geo_weight

    @staticmethod
    def _aggregate_pair_weights(pair_weight, sup_batch_size):
        """Sum query-major (query, support) pair weights over supports, then L2-normalise each row."""
        pair_weight = pair_weight.view(-1, sup_batch_size, pair_weight.shape[1])
        return F.normalize(torch.sum(pair_weight, dim=1), p=2, dim=1)

    def _tie_attn_conv_weights(self, na):
        """Copy kernel i of each `feat_moduleK.attn_conv` into `attn_convK[i]` for every part i."""
        pairs = ((self.feat_module, self.attn_conv),
                 (self.feat_module1, self.attn_conv1),
                 (self.feat_module2, self.attn_conv2))
        for feat_module, attn_conv in pairs:
            source = feat_module.attn_conv.weight.data
            for i in range(na):
                attn_conv[i].weight.data = source[i].clone().view(attn_conv[i].weight.shape)

    def decompute_attention(self, attn_conv, na, nb, w, h, feat_maps, Ms):
        """
        Pool `feat_maps` under per-part attention maps and soft-threshold the result.

        For each part i, `attn_conv[i]` turns the Gaussian map `Ms[:, i]` into
        `self.nc * num_theta` sigmoid maps; each map is divided by its Frobenius
        norm over the two spatial dims and used as weights for a spatial sum of
        `feat_maps`. The pooled values are then shrunk towards zero by
        `self.mrg` times the per-sample, per-part maximum over dim 1
        (sign(P) * max(|P| - margin, 0)).

        Returns:
            list of `na` tensors, the pooled feature matrix (`nb` rows) of each part.
        """
        feats = []
        for i in range(na):
            attn_map = torch.sigmoid(attn_conv[i](Ms[:, i].unsqueeze(1)))
            attn_map = attn_map.view(nb, self.nc, self.hparams.num_theta, w, h)
            attn_map = attn_map / torch.norm(attn_map, p='fro', dim=(3, 4), keepdim=True)
            # [nb x nc x num_theta] -> [nb x (nc * num_theta)]
            feat = (feat_maps.unsqueeze(dim=2) * attn_map).sum(dim=(-1, -2))
            feats.append(feat.view(nb, -1))
        Ps = torch.stack(feats, dim=2)
        mrg = Ps.max(dim=1, keepdim=True)[0] * self.mrg
        Ps = torch.sign(Ps) * torch.clip(Ps.abs() - mrg, min=0)
        return [Ps[:, :, i] for i in range(na)]

    def _thresholded_softmax(self, response):
        """
        Turn per-part response maps into spatial probability maps.

        Values not strictly above `cut_th` times the map's maximum are zeroed, the
        rest are divided by `mu_softmax_temp` and clipped at 20, then a softmax
        over the two spatial dims is applied.

        Returns:
            (logits, mu_prob): the clipped logits and the normalised probabilities.
        """
        peak = response.view(response.shape[0], response.shape[1], -1).max(dim=-1, keepdim=True)[0]
        threshold = (self.hparams.cut_th * peak).unsqueeze(3)
        response = torch.where(response > threshold, response, torch.zeros_like(response))
        logits = torch.clip(response / self.hparams.mu_softmax_temp, max=20)
        mu_prob = torch.exp(logits)
        mu_prob = mu_prob / mu_prob.sum(dim=(2, 3), keepdim=True)
        return logits, mu_prob

    def _gaussian_parts(self, mu_prob, nb, na, w, h, device):
        """
        Compute the expected (x, y) location of each part and a Gaussian map around it.

        Grid cell centres are at integer + 0.5 coordinates; the Gaussian is
        isotropic with std `hparams.sigma` and scaled by 1 / sigma**2.

        Returns:
            (mu_x, mu_y, Ms): mu_x and mu_y have shape [nb x na]; Ms has shape
            [nb x na x w x h].
        """
        x = torch.arange(w, device=device) + .5
        y = torch.arange(h, device=device) + .5
        maps_x, maps_y = torch.meshgrid(x, y)
        maps_x = maps_x.repeat(nb, na, 1, 1)
        maps_y = maps_y.repeat(nb, na, 1, 1)
        mu_x = (mu_prob * maps_x).sum(dim=(-1, -2))
        mu_y = (mu_prob * maps_y).sum(dim=(-1, -2))
        sigma = self.hparams.sigma
        dx = (maps_x - mu_x.view(nb, na, 1, 1)) / sigma
        dy = (maps_y - mu_y.view(nb, na, 1, 1)) / sigma
        Ms = torch.exp(-.5 * (dx.pow(2) + dy.pow(2))) / sigma ** 2
        return mu_x, mu_y, Ms

    def compute_attention(self, feat_maps, feat_module, padding_size=0, mode='a'):
        """
        Predict a location distribution and a Gaussian attention map for every part.

        Side effect: the kernels of `feat_module.attn_conv` are rescaled in place on
        every call so that each has (approximately) unit norm over its two spatial
        dims.

        Args:
            feat_maps: backbone feature maps; dims 2 and 3 are spatial.
            feat_module: module exposing `attn_conv`; its output has one response
                map per part.
            padding_size: padding of the all-ones convolution used by modes 'b'/'f2'.
            mode: 'a'/'f1' use the squared response; 'b'/'f2' additionally subtract
                the squared response of an all-ones kernel.

        Returns:
            mu_prob: per-part spatial probability maps (sum to 1 over dims 2, 3).
            mus: part centres, all x coordinates followed by all y coordinates.
            Ms: Gaussian attention maps centred on the part centres.

        Raises:
            ValueError: if `mode` is not one of 'a', 'f1', 'b', 'f2'.
        """
        if mode not in ('a', 'f1', 'b', 'f2'):
            raise ValueError('no selected mode! got %r' % (mode,))
        nb = feat_maps.shape[0]
        na = self.hparams.num_attn
        w, h = feat_maps.shape[2:4]

        conv_weight = feat_module.attn_conv.weight.detach()
        norm_theta = torch.linalg.norm(conv_weight, dim=(2, 3), keepdim=True)
        feat_module.attn_conv.weight.data = feat_module.attn_conv.weight.data / (norm_theta + 1e-6)
        response = torch.square(feat_module(feat_maps))
        if mode in ('b', 'f2'):
            const_conv = torch.ones_like(feat_module.attn_conv.weight.data)
            response_c = F.conv2d(feat_maps, const_conv, padding=padding_size)
            response = response - torch.square(response_c)

        _, mu_prob = self._thresholded_softmax(response)
        mu_x, mu_y, Ms = self._gaussian_parts(mu_prob, nb, na, w, h, feat_maps.device)
        mus = torch.cat([mu_x, mu_y], dim=1)
        return mu_prob, mus, Ms

    def compute_attention_test(self, feat_maps, feat_module):
        """
        Variant of `compute_attention` on the raw `feat_module` response (no kernel
        normalisation, no squaring) that also returns diagnostics.

        Returns:
            mu_prob, mus, Ms: as in `compute_attention`.
            raw_mu: the thresholded, temperature-scaled, clipped logits before the softmax.
            mu_round: part centres rounded to the nearest integer, stacked as
                [nb x 2 x na] (x first, then y).
        """
        nb = feat_maps.shape[0]
        na = self.hparams.num_attn
        w, h = feat_maps.shape[2:4]

        raw_mu, mu_prob = self._thresholded_softmax(feat_module(feat_maps))
        mu_x, mu_y, Ms = self._gaussian_parts(mu_prob, nb, na, w, h, feat_maps.device)
        mu_round = torch.round(torch.stack([mu_x, mu_y], dim=1))
        mus = torch.cat([mu_x, mu_y], dim=1)
        return mu_prob, mus, Ms, raw_mu, mu_round

    def pred(self, data, target):
        """
        Predict labels for the query part of an episode.

        The similarity of each query to each class combines the per-scale
        part-feature similarities (weighted by the query's `mu_weight`) with
        `hparams.alpha` times the geometric similarities (weighted by
        `geo_weight`); the class with the highest score wins.

        Args:
            data: episode batch, support samples first (`self.sup_batch_size` of them).
            target: labels for the whole batch (query labels are only used by the
                helper if it needs them -- NOTE(review): confirm in util).

        Returns:
            predicted labels taken from the `unique_labels` returned by the scoring helper.
        """
        mus_list, _, _, Pss, Pss1, Pss2, weight, weight_geo = self.feat_forward(data, self.sup_batch_size)

        alpha = self.hparams.alpha
        na = self.hparams.num_attn
        # pair-wise differences between part centres, per scale
        diffs_list = compute_pair_wise_diff_with_xycoord_all(mus_list, na)

        feat_sims, diff_sims, unique_labels, _ = compute_multiscale_partbypart_combined_protonet_scores_geo_all(
            [Pss, Pss1, Pss2], diffs_list, target, self.sup_batch_size)
        # weight the stacked per-component similarities with the per-query weights
        feat_sims_w = torch.einsum('ijk,ik->ij', torch.stack(feat_sims, dim=2), weight)
        diff_sims_w = torch.einsum('ijk,ik->ij', torch.stack(diff_sims, dim=2), weight_geo)
        final_sims = feat_sims_w + alpha * diff_sims_w

        return unique_labels[final_sims.argmax(dim=1)]

    def training_step(self, batch, batch_idx):
        """
        Run one training step.

        In 'batch' train mode this defers to `FeatModel.training_step`. Otherwise
        the batch is an episode whose first `sup_batch_size_train` samples are the
        support set and `self.criterion` is applied. Attention maps of the first
        two samples of every scale are logged every 50 batches.
        """
        if self.hparams.train_mode == 'batch':
            return super(FSLFeatModel, self).training_step(batch, batch_idx)

        data, target = batch
        mus_list, mu_prob_list, Ms_list, Pss, Pss1, Pss2, weight, weight_geo = self.feat_forward(
            data, self.sup_batch_size_train)
        loss = self.criterion(mus_list, mu_prob_list, target, Pss, Pss1, Pss2, weight, weight_geo)
        self.log('train/div_loss', self.criterion.losses[0], prog_bar=True)
        self.log('train/proto_loss', self.criterion.losses[1], prog_bar=True)
        self.log('train/total_loss', loss)
        if (batch_idx + 1) % 50 == 0:
            for Ms_scale, mu_prob_scale in zip(Ms_list, mu_prob_list):
                self.vis_attn_maps(data, Ms_scale, mu_prob_scale, 2)
        return loss

    def validation_step(self, batch, batch_idx):
        """Predict the queries of an episode and log their accuracy as 'val_acc'."""
        data, target = batch
        preds = self.pred(data, target)
        query_target = target[self.sup_batch_size:]
        acc = self.acc_metric(preds, query_target)
        self.log('val_acc', acc, prog_bar=True)
        self.log('hp_metric', acc)
        return preds

    def validation_epoch_end(self, outputs):
        """Intentionally a no-op."""
        pass

    def test_step(self, batch, batch_idx, ifvisulalize=True):
        """
        Predict the queries of an episode and log their accuracy as 'test_acc'.

        `ifvisulalize` is accepted for backward compatibility and currently unused.
        """
        data, target = batch
        preds = self.pred(data, target)
        query_target = target[self.sup_batch_size:]
        acc = self.acc_metric(preds, query_target)
        self.log('test_acc', acc, prog_bar=True)
        self.log('hp_metric', acc)
        return preds

    @staticmethod
    def _sub_state_dict(state_dict, prefix):
        """
        Return the entries of `state_dict` whose first dotted key component is exactly
        `prefix`, with the leading `prefix.` removed from the keys.
        """
        sub = OrderedDict()
        for key, value in state_dict.items():
            head, _, rest = key.partition('.')
            if head == prefix:
                sub[rest] = value
        return sub

    def load_all(self, baseline_ckpt, feti_pretrained=False):
        """
        Load pretrained weights for the backbone and all attention components.

        Args:
            baseline_ckpt: path to the checkpoint. Unless `feti_pretrained` is set it
                must contain a 'state_dict' whose keys are prefixed with
                `feat_extractor`, `weight_map`, `feat_module`, `feat_module1`,
                `feat_module2`, `geo_weight_map`, `attn_conv`, `attn_conv1` and
                `attn_conv2`; each submodule is loaded from its own entries.
            feti_pretrained: if True, `baseline_ckpt` is a plain FETI backbone state
                dict and only `feat_extractor` is loaded.
        """
        # map_location='cpu': load_state_dict copies into the existing parameters,
        # so the tensors need not be restored onto the device they were saved from.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        if feti_pretrained:
            self.feat_extractor.load_state_dict(torch.load(baseline_ckpt, map_location='cpu'))
            return
        state_dict = torch.load(baseline_ckpt, map_location='cpu')['state_dict']
        for name in ('feat_extractor', 'weight_map', 'feat_module', 'feat_module1', 'feat_module2',
                     'geo_weight_map', 'attn_conv', 'attn_conv1', 'attn_conv2'):
            getattr(self, name).load_state_dict(self._sub_state_dict(state_dict, name))

    def load_backbone(self, baseline_ckpt, feti_pretrained=False):
        """
        Load the pretrained backbone weight.

        Args:
            baseline_ckpt: path to the baseline/backbone model checkpoint; unless
                `feti_pretrained` is set, its 'state_dict' entries prefixed with
                `feat_extractor.` are loaded.
            feti_pretrained: if True, the `baseline_ckpt` is the weight FETI pretrained,
                load directly.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        if feti_pretrained:
            self.feat_extractor.load_state_dict(torch.load(baseline_ckpt, map_location='cpu'))
            return
        state_dict = torch.load(baseline_ckpt, map_location='cpu')['state_dict']
        self.feat_extractor.load_state_dict(self._sub_state_dict(state_dict, 'feat_extractor'))

    @staticmethod
    def add_model_loss_args(parent_parser):
        """Register the few-shot loss options (--loss_weights, --temp_proto, --alpha)."""
        parser = parent_parser.add_argument_group('model_loss')
        parser.add_argument('--loss_weights', nargs=2, type=float, default=[1., 1.])
        parser.add_argument('--temp_proto', type=float, default=0.01)
        parser.add_argument('--alpha', type=float, default=0.01)
        return parent_parser


class BaselineModel(FeatModel):
    """
    The baseline model (standard CNN) module.

    Backbone -> global average pooling -> linear classifier, trained with
    cross-entropy. At test time it performs few-shot classification by cosine
    similarity to per-class support prototypes (`fsl_pred`); call
    `update_evaluate_param` first so that `sup_batch_size` is set.
    """
    def __init__(self, hparams, *args, **kwargs):
        # Skips FeatModel.__init__ and calls the next __init__ in the MRO.
        # NOTE(review): presumably to avoid building FeatModel's attention modules.
        super(FeatModel, self).__init__()
        self.hparams.update(vars(hparams))
        self.save_hyperparameters(hparams)
        self.feat_extractor = create_backbone(self.hparams.backbone_model, self.hparams.pretrain_backbone)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.clf = nn.Linear(NUM_CHANNEL[self.hparams.backbone_model], self.hparams.num_class)
        self.acc_metric = Accuracy()
        self.cm_metric = ConfusionMatrix(self.hparams.num_class)
        self.criterion = nn.CrossEntropyLoss()

    def _pooled_features(self, x):
        """Backbone features globally average-pooled and flattened to one row per sample."""
        return torch.flatten(self.avg_pool(self.feat_extractor(x)), 1)

    def forward(self, x):
        """Return classification logits for the batch `x`."""
        return self.clf(self._pooled_features(x))

    def training_step(self, batch, batch_idx):
        """Cross-entropy training step; logs 'train/ce_loss'."""
        data, target = batch
        logits = self(data)
        loss = self.criterion(logits, target)
        self.log('train/ce_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log classification accuracy as 'val_acc' and update the confusion matrix."""
        data, target = batch
        preds = self(data).argmax(dim=1)
        acc = self.acc_metric(preds, target)
        self.cm_metric(preds, target)
        self.log('val_acc', acc, prog_bar=True)
        self.log('hp_metric', acc)
        return preds

    @staticmethod
    def add_model_arch_args(parent_parser):
        """
        Register the architecture options in a 'model_arch' group.

        `--pretrain_backbone` takes an explicit boolean value ('true'/'false',
        'yes'/'no', '1'/'0', ...). A plain `type=bool` would turn any non-empty
        string, including 'False', into True.
        """
        def str2bool(value):
            """Parse a command-line boolean string."""
            text = str(value).strip().lower()
            if text in ('true', 't', 'yes', 'y', '1'):
                return True
            if text in ('false', 'f', 'no', 'n', '0', ''):
                return False
            raise ValueError('expected a boolean, got %r' % (value,))

        parser = parent_parser.add_argument_group('model_arch')
        parser.add_argument('--backbone_model', type=str, default='resnet34')
        parser.add_argument('--pretrain_backbone', type=str2bool, default=False)
        parser.add_argument('--num_class', type=int, default=200)
        parser.add_argument('--input_size', type=int, default=224)
        return parent_parser

    def update_evaluate_param(self, num_support_val, classes_per_it_val):
        """Set the evaluation support-set size to `num_support_val * classes_per_it_val`."""
        self.num_support_val = num_support_val
        self.sup_batch_size = self.num_support_val * classes_per_it_val

    def fsl_pred(self, data, target):
        """
        Nearest-prototype prediction for the queries of an episode.

        The first `self.sup_batch_size` samples of `data` are the support set. Each
        class prototype is the L2-normalised sum of its support features, and each
        query takes the label of the prototype with the highest cosine similarity.

        Returns:
            predicted labels for `data[self.sup_batch_size:]`.
        """
        feats = self._pooled_features(data)
        feats_support = feats[:self.sup_batch_size]
        feats_query = feats[self.sup_batch_size:]
        target_support = target[:self.sup_batch_size]

        unique_labels, uni_idx = torch.unique(target_support, return_inverse=True)
        prototypes = torch.stack(
            [feats_support[uni_idx == i].sum(dim=0) for i in range(len(unique_labels))])
        prototypes = F.normalize(prototypes, dim=1)
        feats_query = F.normalize(feats_query, dim=1)
        sim_matrix = torch.einsum('ik,jk->ij', feats_query, prototypes)
        return unique_labels[sim_matrix.argmax(dim=1)]

    def load_backbone(self, baseline_ckpt, feti_pretrained=False):
        """
        Load the pretrained backbone weight.

        Args:
            baseline_ckpt: path to the baseline/backbone model checkpoint; unless
                `feti_pretrained` is set, its 'state_dict' entries whose first dotted
                key component is `feat_extractor` are loaded (prefix removed).
            feti_pretrained: if True, the `baseline_ckpt` is the weight FETI pretrained,
                load directly.
        """
        # map_location='cpu': load_state_dict copies into the existing parameters.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        if feti_pretrained:
            self.feat_extractor.load_state_dict(torch.load(baseline_ckpt, map_location='cpu'))
            return
        state_dict = torch.load(baseline_ckpt, map_location='cpu')['state_dict']
        backbone_state = OrderedDict()
        for key, value in state_dict.items():
            # exact prefix match; str.lstrip would strip a *character set*, not a prefix
            head, _, rest = key.partition('.')
            if head == 'feat_extractor':
                backbone_state[rest] = value
        self.feat_extractor.load_state_dict(backbone_state)

    def test_step(self, batch, batch_idx):
        """Few-shot test step; logs query accuracy as 'test_acc'."""
        data, target = batch
        preds = self.fsl_pred(data, target)
        query_target = target[self.sup_batch_size:]
        acc = self.acc_metric(preds, query_target)
        self.log('test_acc', acc, prog_bar=True)
        self.log('hp_metric', acc)
        return preds


class FSLFeatModelvis(FSLFeatModel):

    def __init__(self, hparams, *args, **kwargs):
        """
        Build the visualisation/analysis variant of ``FSLFeatModel``.

        Only ``hparams`` is forwarded to the parent constructor; extra
        positional and keyword arguments are accepted but ignored.

        Args:
            hparams: hyper-parameter namespace. Must provide ``num_support_val``,
                ``num_query_val``, ``classes_per_it_val`` and ``train_mode``. When
                ``train_mode == 'episode'`` it must also provide the loss settings
                passed to ``FSLFeatLoss`` below.
        """
        super().__init__(hparams)
        hp = self.hparams
        self.num_support_val = hp.num_support_val
        self.num_query_val = hp.num_query_val
        # Number of leading support samples in each evaluation episode.
        self.sup_batch_size = self.num_support_val * hp.classes_per_it_val
        if hp.train_mode != 'episode':
            return
        # The episodic loss is only built (and self.criterion only exists)
        # in episode training mode.
        self.criterion = FSLFeatLoss(
            hp.num_attn,
            hp.loss_weights,
            hp.num_support_tr,
            hp.num_query_tr,
            hp.classes_per_it_tr,
            hp.temp_proto,
            hp.alpha,
        )

    def update_evaluate_param(self, num_support_val, classes_per_it_val):
        """
        Reconfigure the layout of evaluation episodes.

        Args:
            num_support_val: support samples per class.
            classes_per_it_val: classes per episode.

        Only ``num_support_val`` and ``sup_batch_size`` are updated;
        ``num_query_val`` is left untouched.
        """
        self.sup_batch_size = num_support_val * classes_per_it_val
        self.num_support_val = num_support_val

    def set_alpha(self, alpha):
        """Store ``alpha`` on the instance as ``self.alpha``."""
        setattr(self, 'alpha', alpha)

    def pred(self, data, target):
        """
        Predict query labels for one episode, printing intermediate scores.

        Combines the prototype similarity of the part features with the
        prototype similarity of the pair-wise part-location differences:
        ``final_sims = feat_sims + alpha * diff_sims``.

        Args:
            data: episode images; the first ``self.sup_batch_size`` samples are
                the support set.
            target: labels of all samples in ``data``.

        Returns:
            tuple ``(pred_labels, final_sims)``. ``pred_labels`` maps the argmax
            over dim 1 of ``final_sims`` back to the support labels.
            ``final_sims`` is presumably ``[num_query, num_classes]`` — confirm
            against ``compute_combined_protonet_scores``.
        """
        mus, _, _, Ps, _ = self.feat_forward(data)
        Ps = Ps.view(data.shape[0], -1)

        # Honour set_alpha(); fall back to the configured hyper-parameter when
        # it has not been called.
        alpha = getattr(self, 'alpha', self.hparams.alpha)
        # add difference between mus to feature
        diffs = compute_pair_wise_diff(mus, self.hparams.num_attn)

        feat_sims, diff_sims, unique_labels, _ = compute_combined_protonet_scores(
            Ps, diffs, target, self.sup_batch_size)
        print(feat_sims)
        print(diff_sims)
        final_sims = feat_sims + alpha * diff_sims
        print(final_sims)
        pred_idx = final_sims.argmax(dim=1)

        print(pred_idx)
        pred_labels = unique_labels[pred_idx]
        print(unique_labels)
        print(pred_labels)
        return pred_labels, final_sims

    def parts_similarity(self, data, target, ifprint = True):
        """
        Compute a prototype-similarity score separately for every part.

        Args:
            data: episode images; the first ``self.sup_batch_size`` samples are
                the support set.
            target: labels of all samples in ``data``.
            ifprint: if True, print how many per-part scores were produced.

        Returns:
            tuple ``(parts_similarity_score, entropy)``:
              * a list with one ``compute_pure_protonet_scores`` similarity
                result per part feature from ``parts_feature``;
              * the per-part attention entropy for the query samples only
                (rows from ``self.sup_batch_size`` onwards).
        """
        part_feats, entropy = self.parts_feature(data)
        nb = data.shape[0]
        parts_similarity_score = []
        for part_feat in part_feats:
            feat_sims, _, _ = compute_pure_protonet_scores(
                part_feat.view(nb, -1), target, self.sup_batch_size)
            parts_similarity_score.append(feat_sims)
        if ifprint:
            print(len(parts_similarity_score))

        return parts_similarity_score, entropy[self.sup_batch_size:]

    def training_step(self, batch, batch_idx):
        """
        Run one training step.

        In ``'batch'`` mode this resolves ``training_step`` starting *after*
        ``FSLFeatModel`` in the MRO, i.e. it bypasses
        ``FSLFeatModel.training_step``. Otherwise it computes the episodic
        ``self.criterion`` loss (built in ``__init__`` only for
        ``train_mode == 'episode'``), logs its components, and every 50
        batches calls ``vis_attn_maps``.

        Returns:
            the loss tensor.
        """
        if self.hparams.train_mode == 'batch':
            return super(FSLFeatModel, self).training_step(batch, batch_idx)

        data, target = batch
        # feat_forward in this class also returns the backbone feature maps,
        # which the loss does not use; unpacking only four values raised
        # "too many values to unpack".
        mus, mu_prob, Ms, Ps, _ = self.feat_forward(data)
        loss = self.criterion(mus, mu_prob, Ps, target)
        self.log('train/div_loss', self.criterion.losses[0], prog_bar=True)
        self.log('train/proto_loss', self.criterion.losses[1], prog_bar=True)
        self.log('train/total_loss', loss)
        if (batch_idx + 1) % 50 == 0:
            self.vis_attn_maps(data, Ms, mu_prob, 2)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Evaluate one validation episode and log the query accuracy.

        Returns:
            the predicted labels of the query samples.
        """
        data, target = batch
        # self.pred returns (pred_labels, final_sims); passing the whole tuple
        # to the accuracy metric was a bug. Only the labels are needed here.
        preds, _ = self.pred(data, target)
        query_target = target[self.sup_batch_size:]
        acc = self.acc_metric(preds, query_target)
        self.log('val_acc', acc, prog_bar=True)
        self.log('hp_metric', acc)
        return preds

    def feat_forward(self, x):
        """
        Extract attention-pooled part features from a batch of images.

        The pipeline has five stages:
          1. backbone feature maps;
          2. per-part spatial distribution ``mu_prob``;
          3. expected part locations ``mus``;
          4. Gaussian maps ``Ms`` centred on those locations;
          5. learned attention maps (``self.attn_conv``), which attention-pool
             the part features ``Ps``, with a margin (``self.mrg``) applied.

        Args:
            x: input batch for ``self.feat_extractor``.

        Returns:
            tuple ``(mus, mu_prob, Ms, Ps, feat_maps)``:
              * mus: ``[nb, 2 * na]``; all x-coordinates, then all y-coordinates.
              * mu_prob: per-part distribution, normalised over dims 2 and 3.
              * Ms: ``[nb, na, d2, d3]`` Gaussian maps (``d2, d3`` are
                ``feat_maps.shape[2:4]``).
              * Ps: ``[nb, nc * num_theta, na]`` part features after the margin.
              * feat_maps: the backbone output.
        """
        feat_maps = self.feat_extractor(x)
        nb = feat_maps.shape[0]
        na = self.hparams.num_attn
        # NOTE: these are feat_maps.shape[2] and shape[3]; for NCHW tensors that
        # is (height, width) despite the names.
        w, h = feat_maps.shape[2:4]

        # predict mu's probability: keep only responses above cut_th times the
        # per-map maximum, then take a temperature softmax over space.
        temp = self.feat_module(feat_maps)
        thresh = temp.view(temp.shape[0], temp.shape[1], -1)
        thresh = self.hparams.cut_th * thresh.max(dim=-1, keepdim=True)[0]
        thresh = thresh.unsqueeze(3)
        temp = torch.where(temp > thresh, temp, torch.zeros_like(temp))
        # Clip before exp so the exponential stays bounded.
        mu_prob = torch.exp(torch.clip(temp / self.hparams.mu_softmax_temp, max=20))
        mu_prob = mu_prob / mu_prob.sum(dim=(2, 3), keepdim=True)

        # compute mu's mean using cell-centre coordinates (i + 0.5). The grids
        # are built once and reused for the Gaussian maps below.
        coord_x = torch.arange(w, device=feat_maps.device) + .5
        coord_y = torch.arange(h, device=feat_maps.device) + .5
        maps_x, maps_y = torch.meshgrid(coord_x, coord_y)
        maps_x = maps_x.repeat(nb, na, 1, 1)
        maps_y = maps_y.repeat(nb, na, 1, 1)
        mu_x = (mu_prob * maps_x).sum(dim=(-1, -2))  # nb x na
        mu_y = (mu_prob * maps_y).sum(dim=(-1, -2))  # nb x na
        mus = torch.cat([mu_x, mu_y], dim=1)

        # form the isotropic Gaussian around each part location
        sigma = self.hparams.sigma
        mu_x = mu_x.view(nb, na, 1, 1)
        mu_y = mu_y.view(nb, na, 1, 1)
        Ms = torch.exp(-.5 * (((maps_x - mu_x) / sigma).pow(2) + ((maps_y - mu_y) / sigma).pow(2))) / sigma ** 2

        # apply convolution and attention pooling
        feats = []
        for i in range(na):
            # Attention maps for part i, reshaped to [nb, nc, num_theta, d2, d3]
            # and Frobenius-normalised over the spatial dims.
            attn_map = torch.sigmoid(self.attn_conv[i](Ms[:, i].unsqueeze(1)))
            attn_map = attn_map.view(nb, self.nc, self.hparams.num_theta, w, h)
            attn_map = attn_map / torch.norm(attn_map, p='fro', dim=(3, 4), keepdim=True)
            feat = (feat_maps.unsqueeze(dim=2) * attn_map).sum(dim=(-1, -2))  # [nb x nc x num_theta]
            feats.append(feat.view(nb, -1))
        Ps = torch.stack(feats, dim=2)

        # Shrink every feature towards zero by self.mrg times the per-part
        # maximum over the feature dimension (values below it become 0).
        mrg = Ps.max(dim=1, keepdim=True)[0] * self.mrg
        Ps = torch.sign(Ps) * torch.clip(Ps.abs() - mrg, min=0)
        return mus, mu_prob, Ms, Ps, feat_maps

    def parts_feature(self, x):
        """
        Compute attention-pooled features for each part plus attention entropy.

        Uses the same distribution / Gaussian / attention-pooling steps as
        ``feat_forward``. It returns the part features as a list and does NOT
        apply the ``self.mrg`` margin.

        Args:
            x: input batch for ``self.feat_extractor``.

        Returns:
            tuple ``(Pss, mu_entropy)``:
              * Pss: list of ``na`` tensors, each ``[nb, nc * num_theta]``.
              * mu_entropy: ``[nb, na]`` entropy of each part's spatial
                distribution ``mu_prob``.
        """
        feat_maps = self.feat_extractor(x)
        nb = feat_maps.shape[0]
        na = self.hparams.num_attn
        # NOTE: feat_maps.shape[2] and shape[3]; (height, width) for NCHW.
        w, h = feat_maps.shape[2:4]

        # predict mu's probability
        temp = self.feat_module(feat_maps)
        thresh = temp.view(temp.shape[0], temp.shape[1], -1)
        thresh = self.hparams.cut_th * thresh.max(dim=-1, keepdim=True)[0]
        thresh = thresh.unsqueeze(3)
        temp = torch.where(temp > thresh, temp, torch.zeros_like(temp))
        mu_prob = torch.exp(torch.clip(temp / self.hparams.mu_softmax_temp, max=20))
        mu_prob = mu_prob / mu_prob.sum(dim=(2, 3), keepdim=True)

        # Entropy of each part distribution. mu_prob can underflow to exactly 0
        # (exp of a large negative value), and 0 * log(0) is NaN; clamping
        # inside the log makes those terms contribute 0 instead.
        safe_prob = mu_prob.clamp(min=torch.finfo(mu_prob.dtype).tiny)
        mu_entropy = (-mu_prob * torch.log(safe_prob)).sum(dim=(-1, -2))

        # compute mu's mean (cell-centre coordinates); grids are reused below
        coord_x = torch.arange(w, device=feat_maps.device) + .5
        coord_y = torch.arange(h, device=feat_maps.device) + .5
        maps_x, maps_y = torch.meshgrid(coord_x, coord_y)
        maps_x = maps_x.repeat(nb, na, 1, 1)
        maps_y = maps_y.repeat(nb, na, 1, 1)
        mu_x = (mu_prob * maps_x).sum(dim=(-1, -2)).view(nb, na, 1, 1)
        mu_y = (mu_prob * maps_y).sum(dim=(-1, -2)).view(nb, na, 1, 1)

        # form the Gaussian distribution of mu
        sigma = self.hparams.sigma
        Ms = torch.exp(-.5 * (((maps_x - mu_x) / sigma).pow(2) + ((maps_y - mu_y) / sigma).pow(2))) / sigma ** 2

        # apply convolution and attention pooling
        Pss = []
        for i in range(na):
            attn_map = torch.sigmoid(self.attn_conv[i](Ms[:, i].unsqueeze(1)))
            attn_map = attn_map.view(nb, self.nc, self.hparams.num_theta, w, h)
            attn_map = attn_map / torch.norm(attn_map, p='fro', dim=(3, 4), keepdim=True)
            feat = (feat_maps.unsqueeze(dim=2) * attn_map).sum(dim=(-1, -2))  # [nb x nc x num_theta]
            Pss.append(feat.view(nb, -1))

        return Pss, mu_entropy

    def validation_epoch_end(self, outputs):
        """Do nothing at the end of a validation epoch; ``outputs`` is ignored."""

    def test_step(self, batch, batch_idx):
        """
        Evaluate one test episode and print diagnostics on misclassified queries.

        Args:
            batch: ``(data, target, index)``; the first ``self.sup_batch_size``
                samples are the support set. ``index`` is printed shifted by 2
                (presumably to match row numbers in an external annotation
                file; TODO confirm).
            batch_idx: unused.

        Returns:
            the predicted labels of the query samples.
        """
        data, target, index = batch
        # Per-part similarities/entropies. Only computed here (parts_similarity
        # prints the part count); the values are not used further.
        parts_sim_score, entropy = self.parts_similarity(data, target)
        print('-----------part entropy --------------')

        preds, sim_scores = self.pred(data, target)
        sup = self.sup_batch_size
        query_target = target[sup:]
        support_target = target[:sup]
        query_index = index[sup:]
        support_index = index[:sup]
        print('query index', query_index + 2)
        print('support index', support_index + 2)
        print('support labels', support_target)
        # True where the prediction is WRONG.
        is_wrong = preds != query_target
        print('----------now, printing part to part similarities')
        print(is_wrong)
        if is_wrong.any():
            print('---------wrong-----------')
            incorrect_labels = query_target[is_wrong].tolist()
            incorrect_query_index = (query_index[is_wrong] + 2).tolist()
            print('incorrect query index', incorrect_query_index)

            incorrect_preds = preds[is_wrong].tolist()
            print('incorrect label', incorrect_preds)
            # The true label misclassified most often, and the label it is
            # most often predicted as.
            most_mistaken_label = max(set(incorrect_labels), key=incorrect_labels.count)
            preds_for_label = [p for lbl, p in zip(incorrect_labels, incorrect_preds)
                               if lbl == most_mistaken_label]
            most_common_pred = max(set(preds_for_label), key=preds_for_label.count)
            print(most_mistaken_label, 'usually wrongly interpreated as', most_common_pred)
        acc = self.acc_metric(preds, query_target)
        self.log('test_acc', acc, prog_bar=True)
        self.log('hp_metric', acc)

        return preds

    def load_backbone(self, baseline_ckpt, feti_pretrained=False):
        """
        Load pretrained weights into ``self.feat_extractor``.

        Args:
            baseline_ckpt: path to the baseline/backbone model checkpoint.
            feti_pretrained: if True, ``baseline_ckpt`` holds a bare state dict
                (FETI-pretrained weights) and is loaded directly. Otherwise the
                file is treated as a checkpoint whose ``'state_dict'`` entry
                contains keys prefixed with ``feat_extractor.``; those entries are
                extracted, the prefix is removed, and the result is loaded.
        """
        # Load onto CPU first: load_state_dict copies every tensor onto the
        # device of the matching parameter, so GPU-saved checkpoints also load
        # on CPU-only machines.
        ckpt = torch.load(baseline_ckpt, map_location='cpu')
        if feti_pretrained:
            self.feat_extractor.load_state_dict(ckpt)
            return

        # str.lstrip() strips a *set of characters*, not a prefix, so slice the
        # exact prefix off instead. Requiring the trailing '.' also keeps keys of
        # sibling modules such as 'feat_extractor_xxx.*' out of the dict.
        prefix = 'feat_extractor.'
        feat_extractor_dict = OrderedDict(
            (key[len(prefix):], value)
            for key, value in ckpt['state_dict'].items()
            if key.startswith(prefix)
        )
        self.feat_extractor.load_state_dict(feat_extractor_dict)

    @staticmethod
    def add_model_loss_args(parent_parser):
        """
        Register the model/loss command-line options on ``parent_parser``.

        The options go into an argument group named ``'model_loss'``:
          * ``--loss_weights``: two floats, default ``[1., 1.]``
          * ``--temp_proto``: float, default ``0.01``
          * ``--alpha``: float, default ``0.01``

        Returns:
            ``parent_parser`` itself, so calls can be chained.
        """
        group = parent_parser.add_argument_group('model_loss')
        group.add_argument('--loss_weights', nargs=2, type=float, default=[1., 1.])
        for flag in ('--temp_proto', '--alpha'):
            group.add_argument(flag, type=float, default=0.01)
        return parent_parser