from collections import OrderedDict
from dataclasses import dataclass, field
from typing import List, Tuple

import torch
import torch.nn as nn
from typing_extensions import Literal

from fs_mol.modules.graph_feature_extractor import (
    GraphFeatureExtractor,
    GraphFeatureExtractorConfig,
)
from fs_mol.data.protonet import ProtoNetBatch


# Length of the fingerprint vectors (input_batch.*.fingerprints); added to the
# FC input width when "ecfp" is part of HyProConfig.used_features.
FINGERPRINT_DIM = 2048
# Length of the physico-chemical descriptor vectors (input_batch.*.descriptors);
# added to the FC input width when "pc-descs" is part of HyProConfig.used_features.
PHYS_CHEM_DESCRIPTORS_DIM = 42

@dataclass(frozen=True)
class HyProConfig:
    """Immutable configuration of ``HyProModel``."""

    # Model configuration:
    # default_factory gives every config its own extractor config instead of one
    # instance shared by all HyProConfig objects (Python 3.11+ also rejects
    # unhashable objects, e.g. non-frozen dataclass instances, as plain defaults).
    graph_feature_extractor_config: GraphFeatureExtractorConfig = field(
        default_factory=GraphFeatureExtractorConfig
    )
    # Input features to concatenate; a trailing "+fc" enables the FC head and the
    # hyper-network encoders. The default must be one of the allowed values.
    used_features: Literal[
        "gnn",
        "ecfp",
        "pc-descs",
        "gnn+ecfp",
        "ecfp+fc",
        "pc-descs+fc",
        "gnn+ecfp+fc",
        "gnn+ecfp+pc-descs+fc",
    ] = "gnn+ecfp+fc"
    # How query embeddings are scored against the support set.
    distance_metric: Literal["mahalanobis", "euclidean", "cosine"] = "mahalanobis"
    # Presumably weights of auxiliary objectives — TODO confirm their consumer;
    # HyProModel.forward takes its own ``ood2`` argument.
    ood2: float = 0.1
    ood3: float = 0.1
    # Layers before and after pooling in each HyperSetEncoder2.
    hyper_layer_num: int = 2
    # Dropout probability inside each HyperSetEncoder2.
    hyper_dropout: float = 0.3
    # Sampled episode sizes are 2**sample_start ... 2**(sample_end - 1).
    sample_start: int = 2
    sample_end: int = 7
    # For episode size sz, int(sample_div / sz) episodes are sampled.
    sample_div: int = 128

def cosine_distance(
    query_embeddings,
    support_set_embeddings,
    k0,
    l2Norm=False,
    scaling='1/sqrt(N)',
    stabilizer=torch.tensor(1e-8).float(),
):
    """Score queries against a two-class support set by signed cosine similarity.

    The support set must be ordered so that its first ``k0`` rows form the first
    class and the remaining ``n - k0`` rows the second class. Each query's cosine
    similarities to first-class rows are weighted by ``-n / (2 * sqrt(k0))`` and
    those to second-class rows by ``n / (2 * sqrt(n - k0))``. The weighted sum
    ``s`` is turned into the logits ``[s, 1 - s]``.

    Args:
        query_embeddings: 2-D ``(q, d)`` query embeddings.
        support_set_embeddings: 2-D ``(n, d)`` support embeddings, first class first.
        k0: number of first-class support rows (int or 0-dim tensor).
        l2Norm: kept for backward compatibility. Both inputs are always
            L2-normalised, so this flag does not change the result.
        scaling: unused; kept for backward compatibility.
        stabilizer: small value added to each denominator so that an empty
            class does not divide by zero.

    Returns:
        ``(q, 2)`` tensor of logits.
    """
    k0 = int(k0)
    n = int(support_set_embeddings.size(0))

    # Unconditional L2 normalisation; this subsumes the optional l2Norm step
    # (zero vectors stay zero either way).
    support_set_embeddings = nn.functional.normalize(support_set_embeddings, p=2.0, dim=-1)
    query_embeddings = nn.functional.normalize(query_embeddings, p=2.0, dim=-1)
    similarities = query_embeddings @ torch.transpose(support_set_embeddings, 0, 1)  # (q, n)

    # k0 and n are Python ints: torch.sqrt only accepts tensors, so use ** 0.5.
    eps = float(stabilizer)
    first_coef = -n / (2.0 * k0 ** 0.5 + eps)
    second_coef = n / (2.0 * (n - k0) ** 0.5 + eps)
    coes = torch.cat(
        (
            torch.full_like(similarities[:, :k0], first_coef),
            torch.full_like(similarities[:, k0:], second_coef),
        ),
        -1,
    )
    similarity_sums = torch.sum(similarities * coes, -1).unsqueeze(-1)  # (q, 1)
    return torch.cat((similarity_sums, 1 - similarity_sums), -1)

class MLP(nn.Module):
    """Stack of ``num_layers`` linear layers with residual connections.

    Every layer except the last is followed by optional BatchNorm, a LeakyReLU
    and optional dropout. From the second layer on, the layer input is added to
    its output (residual), which works because all layers are ``hidden_dim`` wide.

    Args:
        inp_dim: size of the last dimension of the input.
        hidden_dim: width of every linear layer (and of the output).
        num_layers: number of linear layers.
        batch_norm: insert ``nn.BatchNorm1d`` after every non-final linear layer.
        dropout: dropout probability after every non-final activation (0 disables it).
    """

    def __init__(self, inp_dim, hidden_dim, num_layers, batch_norm=False, dropout=0.):
        super(MLP, self).__init__()
        self.nl = num_layers
        self.batch_norm = batch_norm
        self.dropout = dropout
        layers = OrderedDict()
        for l in range(num_layers):
            in_features = inp_dim if l == 0 else hidden_dim
            layers['fc{}'.format(l)] = nn.Linear(in_features, hidden_dim)
            if l < num_layers - 1:
                if batch_norm:
                    layers['norm{}'.format(l)] = nn.BatchNorm1d(num_features=hidden_dim)
                layers['relu{}'.format(l)] = nn.LeakyReLU()
                if dropout > 0:
                    layers['drop{}'.format(l)] = nn.Dropout(p=dropout)
        self.layer_list = nn.ModuleDict(layers)

    def forward(self, x):
        for l in range(self.nl):
            residual = x
            x = self.layer_list['fc{}'.format(l)](x)
            if l < self.nl - 1:
                if self.batch_norm:
                    x = self.layer_list['norm{}'.format(l)](x)
                x = self.layer_list['relu{}'.format(l)](x)
                if self.dropout > 0:
                    # nn.Dropout returns a new tensor; it must be assigned.
                    x = self.layer_list['drop{}'.format(l)](x)
            if l > 0:
                # Out-of-place add keeps autograd safe.
                x = x + residual
        return x
    
class ADC(nn.Module):
    """Hyper-network that turns the two class prototypes of a support set into a
    linear two-class head.

    ``forward`` averages the first ``k0`` support rows and the remaining rows
    into two prototypes. It maps them through two residual MLP stacks to a weight
    matrix ``w`` of shape ``(inp_dim, 2)`` and a bias ``b`` of shape ``(1, 2)``,
    and returns ``q_emb @ w + b``.

    Args:
        inp_dim: embedding width of ``s_emb`` / ``q_emb``.
        hidden_dim: width of the hidden layers of both stacks.
        n_layer: number of linear layers per stack.
        batch_norm: insert ``nn.BatchNorm1d`` after every non-final layer.
        dropout: dropout probability after every non-final activation.
        gpu: stored as ``gpu_id``; not used by this class.
    """

    def __init__(self, inp_dim, hidden_dim, n_layer=4, batch_norm=False, dropout=0.0, gpu=0):
        super(ADC, self).__init__()
        self.nl = n_layer
        self.dropout = dropout
        self.batch_norm = batch_norm
        self.gpu_id = gpu

        # Weight stack ends at inp_dim (one column of w per class); bias stack ends at 1.
        self.layer_list_w = self._build_stack(inp_dim, [hidden_dim] * (n_layer - 1) + [inp_dim])
        self.layer_list_b = self._build_stack(inp_dim, [hidden_dim] * (n_layer - 1) + [1])

    def _build_stack(self, inp_dim, dims):
        """Create the fc/norm/relu/drop modules of one stack with output widths ``dims``."""
        layers = OrderedDict()
        for l in range(self.nl):
            in_features = inp_dim if l == 0 else dims[l - 1]
            layers['fc{}'.format(l)] = nn.Linear(in_features, dims[l])
            if l < self.nl - 1:
                if self.batch_norm:
                    layers['norm{}'.format(l)] = nn.BatchNorm1d(num_features=dims[l])
                layers['relu{}'.format(l)] = nn.LeakyReLU()
                if self.dropout > 0:
                    layers['drop{}'.format(l)] = nn.Dropout(p=self.dropout)
        return nn.ModuleDict(layers)

    def _run_stack(self, layers, x):
        """Apply one stack; residuals only on inner layers, where widths match."""
        for l in range(self.nl):
            residual = x
            x = layers['fc{}'.format(l)](x)
            if l < self.nl - 1:
                if self.batch_norm:
                    x = layers['norm{}'.format(l)](x)
                x = layers['relu{}'.format(l)](x)
                if self.dropout > 0:
                    # nn.Dropout returns a new tensor; it must be assigned.
                    x = layers['drop{}'.format(l)](x)
            if 0 < l < self.nl - 1:
                x = x + residual
        return x

    def forward(self, s_emb, k0, q_emb):
        # Two class prototypes: mean of the first k0 rows and of the rest -> (2, inp_dim).
        prototypes = torch.stack((s_emb[:k0, :].mean(0), s_emb[k0:, :].mean(0)), 0)
        w = self._run_stack(self.layer_list_w, prototypes).transpose(-1, -2)  # (inp_dim, 2)
        b = self._run_stack(self.layer_list_b, prototypes).transpose(-1, -2)  # (1, 2)
        return torch.mm(q_emb, w) + b
    
class HyperSetEncoder2(nn.Module):
    """Encode a labelled two-class set into a single vector.

    The first ``n_layer1`` linear layers are applied to every set element. The
    resulting rows are then mean-pooled separately over the first ``k0`` rows and
    over the remaining rows, and the two pooled vectors are concatenated. This is
    why layer ``n_layer1`` takes ``2 * hidden_dim`` inputs. The remaining
    ``n_layer2`` layers map the pooled vector to ``out_dim``.

    Args:
        inp_dim: input width (features plus the 2 label columns added by ``cat_label``).
        hidden_dim: width of the hidden layers.
        out_dim: width of the last layer.
        n_layer1: number of per-element layers before pooling.
        n_layer2: number of layers after pooling.
        batch_norm: insert ``nn.BatchNorm1d`` after non-final layers. NOTE(review):
            after pooling the activation is 1-D, which BatchNorm1d may reject —
            confirm before enabling.
        dropout: dropout probability after every non-final activation.
        gpu: stored as ``gpu_id``; labels are placed on the input's device instead.
    """

    def __init__(self, inp_dim, hidden_dim, out_dim, n_layer1=2, n_layer2=2, batch_norm=False, dropout=0, gpu=0):
        super(HyperSetEncoder2, self).__init__()
        self.n_layer1 = n_layer1
        self.n_layer2 = n_layer2
        self.nl = n_layer1 + n_layer2
        self.dropout = dropout
        self.batch_norm = batch_norm
        self.gpu_id = gpu

        # Output width of each layer; the last one produces out_dim.
        out_dims = [hidden_dim] * (self.nl - 1) + [out_dim]
        # in_dims[l] is the input width of layer l + 1. The layer right after
        # pooling receives the two concatenated pooled vectors.
        in_dims = out_dims[:]
        in_dims[n_layer1 - 1] = hidden_dim * 2
        layers = OrderedDict()
        for l in range(self.nl):
            in_features = inp_dim if l == 0 else in_dims[l - 1]
            layers['fc{}'.format(l)] = nn.Linear(in_features, out_dims[l])
            if l < self.nl - 1:
                if batch_norm:
                    layers['norm{}'.format(l)] = nn.BatchNorm1d(num_features=out_dims[l])
                layers['relu{}'.format(l)] = nn.LeakyReLU()
                if dropout > 0:
                    layers['drop{}'.format(l)] = nn.Dropout(p=dropout)
        self.layer_list = nn.ModuleDict(layers)

    def cat_label(self, all_emb, k0, ood=0):
        """Append a two-column one-hot class label to every embedding.

        First-class rows get ``[1, 0]`` and second-class rows get ``[0, 1]``.
        The labels are created on ``all_emb``'s device.

        * ``ood == 3``: ``all_emb`` is 3-D ``(n, n - 1, d)``. For ``i < k0``, row
          ``i`` is labelled with ``k0 - 1`` first-class entries followed by
          ``n - k0`` second-class entries. For ``i >= k0``, it gets ``k0``
          first-class entries followed by ``n - k0 - 1`` second-class entries.
          This presumably matches "element i left out" — confirm with the caller.
        * otherwise: ``all_emb`` is 2-D ``(n, d)``. The first ``k0`` rows are
          labelled first-class and the rest second-class.

        Returns:
            ``all_emb`` with 2 extra columns on its last dimension.
        """
        device = all_emb.device
        first = torch.tensor([1.0, 0.0], device=device)
        second = torch.tensor([0.0, 1.0], device=device)

        if ood in [3]:
            n, _, _ = all_emb.size()
            # Label sequences seen from a first-class / second-class element.
            from_first = torch.cat((first.repeat(k0 - 1, 1), second.repeat(n - k0, 1)), 0)
            from_second = torch.cat((first.repeat(k0, 1), second.repeat(n - k0 - 1, 1)), 0)
            labels = torch.cat(
                (from_first.repeat(k0, 1, 1), from_second.repeat(n - k0, 1, 1)), 0
            )
            return torch.cat((all_emb, labels), -1)

        n, _ = all_emb.size()
        labels = torch.cat((first.repeat(k0, 1), second.repeat(n - k0, 1)), 0)
        return torch.cat((all_emb, labels), -1)

    def forward(self, all_emb, k0):
        x = all_emb
        for l in range(self.nl):
            residual = x
            x = self.layer_list['fc{}'.format(l)](x)
            if l < self.nl - 1:
                if self.batch_norm:
                    x = self.layer_list['norm{}'.format(l)](x)
                x = self.layer_list['relu{}'.format(l)](x)
                if self.dropout > 0:
                    # nn.Dropout returns a new tensor; it must be assigned.
                    x = self.layer_list['drop{}'.format(l)](x)
            # Skip the residual on layers whose input width may differ from their
            # output width: the first, the last, and the one right after pooling.
            if 0 < l < self.nl - 1 and l != self.n_layer1:
                x = x + residual
            if l == self.n_layer1 - 1:
                pooled_first = torch.mean(x[0:k0], 0)
                pooled_rest = torch.mean(x[k0:], 0)
                x = torch.cat((pooled_first, pooled_rest), -1)
        return x

class HyProModel(nn.Module):
    def __init__(self, config: HyProConfig):
        """Build the optional GNN, the two hyper-encoders and the FC head.

        Args:
            config: model configuration. ``used_features`` decides which raw
                features are concatenated and, via a "+fc" suffix, whether the
                hyper-encoders and the FC head are created.
        """
        super().__init__()
        self.config = config
        features = self.config.used_features

        # Graph features need a GNN encoder.
        if features.startswith("gnn"):
            self.graph_feature_extractor = GraphFeatureExtractor(
                config.graph_feature_extractor_config
            )

        self.use_fc = features.endswith("+fc")
        if not self.use_fc:
            return

        # Width of the concatenated raw feature vector.
        gnn_dim = (
            self.config.graph_feature_extractor_config.readout_config.output_dim
            if "gnn" in features
            else 0
        )
        ecfp_dim = FINGERPRINT_DIM if "ecfp" in features else 0
        descs_dim = PHYS_CHEM_DESCRIPTORS_DIM if "pc-descs" in features else 0
        fc_in_dim = gnn_dim + ecfp_dim + descs_dim

        # Both hyper-encoders share one architecture; their input carries the
        # 2 label columns appended by HyperSetEncoder2.cat_label.
        encoder_kwargs = dict(
            inp_dim=fc_in_dim + 2,
            hidden_dim=fc_in_dim,
            out_dim=fc_in_dim,
            n_layer1=self.config.hyper_layer_num,
            n_layer2=self.config.hyper_layer_num,
            dropout=self.config.hyper_dropout,
        )
        self.gen_w0 = HyperSetEncoder2(**encoder_kwargs)
        self.gen_b0 = HyperSetEncoder2(**encoder_kwargs)
        self.fc = nn.Sequential(
            nn.Linear(fc_in_dim, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
        )

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter (StopIteration if there are none)."""
        parameters = self.parameters()
        first_parameter = next(parameters)
        return first_parameter.device

    def forward(self, input_batch: ProtoNetBatch,ood=0,query_labels=None,ood2=0):
        support_features: List[torch.Tensor] = []
        query_features: List[torch.Tensor] = []

        if "gnn" in self.config.used_features:
            support_features.append(self.graph_feature_extractor(input_batch.support_features))
            query_features.append(self.graph_feature_extractor(input_batch.query_features))
        if "ecfp" in self.config.used_features:
            support_features.append(input_batch.support_features.fingerprints)
            query_features.append(input_batch.query_features.fingerprints)
        if "pc-descs" in self.config.used_features:
            support_features.append(input_batch.support_features.descriptors)
            query_features.append(input_batch.query_features.descriptors)
        support_features_flat = torch.cat(support_features, dim=1)
        query_features_flat = torch.cat(query_features, dim=1)

        if ood in [1,2,3,4,5,6]:
            if ood in [1,2,3,4,5,6]:
                
                labels=torch.unique(input_batch.support_labels)
                class_features_0 = torch.index_select(support_features_flat, 0, self._extract_class_indices(input_batch.support_labels,labels[0]))
                class_features_1 = torch.index_select(support_features_flat, 0, self._extract_class_indices(input_batch.support_labels, labels[1]))
                support_features_flat0=torch.cat((class_features_0,class_features_1),0)
                label_0 = torch.index_select(input_batch.support_labels, 0, self._extract_class_indices(input_batch.support_labels,labels[0]))
                label_1 = torch.index_select(input_batch.support_labels, 0, self._extract_class_indices(input_batch.support_labels, labels[1]))
                new_labels=torch.cat((label_0,label_1),0)
                k0=int(class_features_0.size(0))
                #nf=self.gen_w0.cat_label(support_features_flat0,k0)
                #w=self.gen_w0(nf,k0).unsqueeze(0)#[1,2560]
                #b=self.gen_b0(nf,k0).unsqueeze(0)
                #support_features_flat0=support_features_flat0*(1+w)#+b
                #query_features_flat0=query_features_flat*(1+w)#+b
                if ood in [1,2,3,4,5]:
                    nf=self.gen_w0.cat_label(support_features_flat0,k0)
                    w0=self.gen_w0(nf,k0).unsqueeze(0)#[1,2560]
                    b0=self.gen_b0(nf,k0).unsqueeze(0)
                    support_features_flat0=support_features_flat0*(1+w0)+b0
                    query_features_flat0=query_features_flat*(1+w0)+b0
                support_features_flat0 = self.fc(support_features_flat0)
                query_features_flat0 = self.fc(query_features_flat0)

                '''nf=self.gen_w1.cat_label(support_features_flat0,k0)
                w1=self.gen_w1(nf,k0).unsqueeze(0)
                b1=self.gen_b1(nf,k0).unsqueeze(0)
                support_features_flat0=support_features_flat0*(1+w1)+b1
                query_features_flat0=query_features_flat0*(1+w1)+b1
                support_features_flat0 = self.fc[1](support_features_flat0)
                query_features_flat0 = self.fc[1](query_features_flat0)
                support_features_flat0 = self.fc[2](support_features_flat0)
                query_features_flat0 = self.fc[2](query_features_flat0)
                nf=self.gen_w2.cat_label(support_features_flat0,k0)
                w2=self.gen_w2(nf,k0).unsqueeze(0)
                b2=self.gen_b2(nf,k0).unsqueeze(0)
                support_features_flat0=support_features_flat0*(1+w2)+b2
                query_features_flat0=query_features_flat0*(1+w2)+b2'''
                #support_features_flat0=torch.cat((support_features_flat0,w.repeat(support_features_flat0.size(0),1)),-1)
                #query_features_flat0=torch.cat((query_features_flat,w.repeat(query_features_flat.size(0),1)),-1)
                #support_features_flat0 = self.fc(support_features_flat0)
                #query_features_flat0 = self.fc(query_features_flat0)
                if self.config.distance_metric == "mahalanobis":
                    class_means, class_precision_matrices = self.compute_class_means_and_precisions(
                        support_features_flat0, new_labels
                    )

                    # grabbing the number of classes and query examples for easier use later
                    number_of_classes = class_means.size(0)
                    number_of_targets = query_features_flat0.size(0)

                    """
                    Calculating the Mahalanobis distance between query examples and the class means
                    including the class precision estimates in the calculations, reshaping the distances
                    and multiplying by -1 to produce the sample logits
                    """
                    repeated_target = query_features_flat0.repeat(1, number_of_classes).view(-1, class_means.size(1))
                    repeated_class_means = class_means.repeat(number_of_targets, 1)
                    repeated_difference = repeated_class_means - repeated_target
                    repeated_difference = repeated_difference.view(number_of_targets, number_of_classes, repeated_difference.size(1)).permute(1, 0, 2)
                    first_half = torch.matmul(repeated_difference, class_precision_matrices)
                    logits0 = torch.mul(first_half, repeated_difference).sum(dim=2).transpose(1, 0) * -1
                elif self.config.distance_metric == "cosine":
                    n=int(support_features_flat0.size(0))
                    class_means=torch.cat((support_features_flat0[0:k].mean(0,keep_dim=True),support_features_flat0[k:].mean(0,keep_dim=True)),0)
                    class_means=nn.functional.normalize(class_means, p=2.0, dim=-1)
                    query_features_flat0=nn.functional.normalize(query_features_flat0, p=2.0, dim=-1)
                    similarities = query_features_flat0 @ torch.transpose(class_means, 0, 1)#q,2
                    logits=similarities
                else:  # euclidean
                    logits0 = self._protonets_euclidean_classifier(support_features_flat0,query_features_flat0,new_labels,)
                if ood in [5]:
                    loss0=self.compute_loss(logits0,query_labels)
                    #loss0.backward()
                    #del support_features_flat0,query_features_flat0
                #logits0 = self._siamese_distance(query_features_flat0,support_features_flat0,k0,metric='cosine',ood=ood)#cosine, euclidean,dot
            #print(input_batch.support_labels.size(),query_labels.size(),query_features_flat.size())#64,93,93

            support_features_flat=torch.cat((support_features_flat,query_features_flat),0)
            n=int(support_features_flat.size(0))
            
            new_labels=torch.cat((input_batch.support_labels,query_labels))
            labels=torch.unique(new_labels)
            class_features_0 = torch.index_select(support_features_flat, 0, self._extract_class_indices(new_labels,labels[0]))
            class_features_1 = torch.index_select(support_features_flat, 0, self._extract_class_indices(new_labels, labels[1]))
            support_features_flat=torch.cat((class_features_0,class_features_1),0)
            label_0 = torch.index_select(new_labels, 0, self._extract_class_indices(new_labels,labels[0]))
            label_1 = torch.index_select(new_labels, 0, self._extract_class_indices(new_labels, labels[1]))
            new_labels=torch.cat((label_0,label_1),0)
            k0=int(class_features_0.size(0))


        else:
            labels=torch.unique(input_batch.support_labels)
            class_features_0 = torch.index_select(support_features_flat, 0, self._extract_class_indices(input_batch.support_labels,labels[0]))
            class_features_1 = torch.index_select(support_features_flat, 0, self._extract_class_indices(input_batch.support_labels, labels[1]))
            support_features_flat=torch.cat((class_features_0,class_features_1),0)
            label_0 = torch.index_select(input_batch.support_labels, 0, self._extract_class_indices(input_batch.support_labels,labels[0]))
            label_1 = torch.index_select(input_batch.support_labels, 0, self._extract_class_indices(input_batch.support_labels, labels[1]))
            new_labels=torch.cat((label_0,label_1),0)
            k0=int(class_features_0.size(0))
        
        if self.use_fc:

            if ood in [1,2,3,4,5,6]:
                t=[]
                loss=0
                bz=0
                for szi in range(self.config.sample_start,self.config.sample_end):#[6]:#
                    sz=2**szi
                    if sz>2*min(k0,n-k0):
                        #sz=2*min(k0,n-k0)
                        break
                    
                    for i in range(int(self.config.sample_div/sz)):#/sz
                        bz+=1
                        k=int(sz/2)
                        s0=support_features_flat[torch.randperm(int(k0))]
                        s1=support_features_flat[k0+torch.randperm(int(n-k0))]
                        support_samples=torch.cat((s0[:k],s1[:k]),0)
                        query_samples=torch.cat((s0[k:],s1[k:]),0)
                        if ood in [1,2,3,4,6]:
                            #nf=self.gen_w0.cat_label(support_samples,k)
                            #wt=self.gen_w0(nf,k).unsqueeze(0)
                            #bt=self.gen_b0(nf,k).unsqueeze(0)
                            #task_embedding=torch.cat((wt,bt),-1)
                            if ood in [1,2,3,4]:
                                nf=self.gen_w0.cat_label(support_samples,k)
                                w0=self.gen_w0(nf,k).unsqueeze(0)#[1,2560]
                                b0=self.gen_b0(nf,k).unsqueeze(0)
                                support_samples=support_samples*(1+w0)+b0
                                query_samples=query_samples*(1+w0)+b0
                                task_embedding=torch.cat((w0,b0),-1)
                            elif ood in [6]:
                                task_embedding=torch.mean(support_samples,0,keepdim=True)
                            support_samples = self.fc(support_samples)
                            query_samples = self.fc(query_samples)
                        '''nf=self.gen_w1.cat_label(support_samples,k)
                        w1=self.gen_w1(nf,k).unsqueeze(0)
                        b1=self.gen_b1(nf,k).unsqueeze(0)
                        support_samples=support_samples*(1+w1)+b1
                        query_samples=query_samples*(1+w1)+b1
                        support_samples = self.fc[1](support_samples)
                        query_samples = self.fc[1](query_samples)
                        support_samples = self.fc[2](support_samples)
                        query_samples = self.fc[2](query_samples)
                        nf=self.gen_w2.cat_label(support_samples,k)
                        w2=self.gen_w2(nf,k).unsqueeze(0)
                        b2=self.gen_b2(nf,k).unsqueeze(0)
                        support_samples=support_samples*(1+w2)+b2
                        query_samples=query_samples*(1+w2)+b2
                        task_embedding=torch.cat((w0,b0,w1,b1,w2,b2),-1)'''
                        
                        if len(t)==0:
                           t=task_embedding
                        else:
                            t=torch.cat((t,task_embedding),0)
                        #support_samples=support_samples*(1+wt)#+bt
                        #query_samples=query_samples*(1+wt)#+bt
                        #support_samples=torch.cat((support_samples,wt.repeat(support_samples.size(0),1)),-1)
                        #query_samples=torch.cat((query_samples,wt.repeat(query_samples.size(0),1)),-1)
                        #support_samples = self.fc(support_samples)
                        #query_samples = self.fc(query_samples)
                        if ood in [1,2,3,4,6]:
                            s_labels=torch.cat((torch.zeros(k),torch.ones(k))).cuda()
                            if self.config.distance_metric == "mahalanobis":
                                class_means, class_precision_matrices = self.compute_class_means_and_precisions(
                                    support_samples, s_labels
                                )

                                # grabbing the number of classes and query examples for easier use later
                                number_of_classes = class_means.size(0)
                                number_of_targets = query_samples.size(0)

                                """
                                Calculating the Mahalanobis distance between query examples and the class means
                                including the class precision estimates in the calculations, reshaping the distances
                                and multiplying by -1 to produce the sample logits
                                """
                                repeated_target = query_samples.repeat(1, number_of_classes).view(-1, class_means.size(1))
                                repeated_class_means = class_means.repeat(number_of_targets, 1)
                                repeated_difference = repeated_class_means - repeated_target
                                repeated_difference = repeated_difference.view(number_of_targets, number_of_classes, repeated_difference.size(1)).permute(1, 0, 2)
                                first_half = torch.matmul(repeated_difference, class_precision_matrices)
                                logits = torch.mul(first_half, repeated_difference).sum(dim=2).transpose(1, 0) * -1
                            elif self.config.distance_metric == "cosine":
                                n=int(support_samples.size(0))
                                class_means=torch.cat((support_samples[0:k].mean(0,keep_dim=True),support_samples[k:].mean(0,keep_dim=True)),0)
                                class_means=nn.functional.normalize(class_means, p=2.0, dim=-1)
                                query_samples=nn.functional.normalize(query_samples, p=2.0, dim=-1)
                                similarities = query_samples @ torch.transpose(class_means, 0, 1)#q,2
                                logits=similarities
                            else:  # euclidean
                                logits = self._protonets_euclidean_classifier(support_samples,query_samples,s_labels,)
                            #logits = self._siamese_distance(query_samples,query_samples,k,metric='cosine',ood=ood)#cosine, euclidean,dot
                            q_labels=torch.cat((torch.zeros(k0-k),torch.ones(n-k0-k))).cuda()
                            loss+=self.compute_loss(logits,q_labels)
                            del s_labels,q_labels
                if ood in [3,4,5]:
                    nf=self.gen_w0.cat_label(support_features_flat,k0)
                    w=self.gen_w0(nf,k0).unsqueeze(0)#[1,2560]
                    b=self.gen_b0(nf,k0).unsqueeze(0)
                    task_emb0=torch.cat((w,b),-1)
                if ood in [6]:
                    task_emb0=torch.mean(support_features_flat,0,keepdim=True)
                if ood in [2,3,6]:
                    task_emb=nn.functional.normalize(task_emb0, p=2.0, dim=-1)
                    loss-=ood2*torch.sum(nn.functional.normalize(t, p=2.0, dim=-1)@torch.transpose(task_emb,0,1))
                if bz>0:
                    loss/=bz
                
                if ood in [4,5]:
                    #in_p=torch.sum(t@torch.transpose(t.mean(dim=0,keepdim=True),0,1))
                    task_emb=nn.functional.normalize(task_emb0, p=2.0, dim=-1)
                    if bz>0:
                        in_p=torch.sum(nn.functional.normalize(t, p=2.0, dim=-1)@torch.transpose(task_emb,0,1))
                    else:
                        in_p=0
                
                
                if ood in [4]:
                    return logits0,loss,task_emb0,in_p
                if ood in [5]:
                    return logits0,loss0,task_emb0,in_p
                


                return logits0,loss,task_emb0,None
            
 

            '''g=torch.normal(torch.zeros_like(w),torch.ones_like(w))
            b=torch.abs(b)
            gw=w+b*g
            support_features_flat=support_features_flat*(1+gw)
            query_features_flat=query_features_flat*(1+gw)
            #support_features_flat=torch.cat((support_features_flat,gw.repeat(support_features_flat.size(0),1)),-1)
            #query_features_flat=torch.cat((query_features_flat,gw.repeat(query_features_flat.size(0),1)),-1)

            support_features_flat = self.fc(support_features_flat)
            query_features_flat = self.fc(query_features_flat)'''
            
            #nf=self.gen_w0.cat_label(support_features_flat,k0)
            #w=self.gen_w0(nf,k0).unsqueeze(0)#[1,2560]
            #b=self.gen_b0(nf,k0).unsqueeze(0)
            #support_features_flat=support_features_flat*(1+w)#+b
            #query_features_flat=query_features_flat*(1+w)#+b
            #support_features_flat=torch.cat((support_features_flat,w.repeat(support_features_flat.size(0),1)),-1)
            #query_features_flat=torch.cat((query_features_flat,w.repeat(query_features_flat.size(0),1)),-1)
            #support_features_flat = self.fc(support_features_flat)
            #query_features_flat = self.fc(query_features_flat)
            nf=self.gen_w0.cat_label(support_features_flat,k0)
            w0=self.gen_w0(nf,k0).unsqueeze(0)#[1,2560]
            b0=self.gen_b0(nf,k0).unsqueeze(0)
            support_features_flat=support_features_flat*(1+w0)+b0
            query_features_flat=query_features_flat*(1+w0)+b0
            support_features_flat = self.fc(support_features_flat)
            query_features_flat = self.fc(query_features_flat)
            '''nf=self.gen_w1.cat_label(support_features_flat,k0)
            w1=self.gen_w1(nf,k0).unsqueeze(0)
            b1=self.gen_b1(nf,k0).unsqueeze(0)
            support_features_flat=support_features_flat*(1+w1)+b1
            query_features_flat=query_features_flat*(1+w1)+b1
            support_features_flat = self.fc[1](support_features_flat)
            query_features_flat = self.fc[1](query_features_flat)
            support_features_flat = self.fc[2](support_features_flat)
            query_features_flat = self.fc[2](query_features_flat)
            nf=self.gen_w2.cat_label(support_features_flat,k0)
            w2=self.gen_w2(nf,k0).unsqueeze(0)
            b2=self.gen_b2(nf,k0).unsqueeze(0)
            support_features_flat=support_features_flat*(1+w2)+b2
            query_features_flat=query_features_flat*(1+w2)+b2'''
            '''nf=self.gen_w0.cat_label(support_features_flat,k0)
            w=self.gen_w0(nf,k0).unsqueeze(0)
            support_features_flat=self.fc(torch.cat((support_features_flat,w.repeat(support_features_flat.size(0),1)),-1))
            query_features_flat=self.fc(torch.cat((query_features_flat,w.unsqueeze(0).repeat(query_features_flat.size(0),1)),-1))'''
            
        #logits=self.adc(support_features_flat,k0,query_features_flat)
        '''if ood in [3,4,5]:
            #query_features_flat=support_features_flat
            d=int(support_features_flat.size(1))
            query_features_flat=support_features_flat.unsqueeze(1)
            eye = torch.eye(n)
            mask=(eye==0)
            support_features_flat=support_features_flat.unsqueeze(0).repeat(n,1,1)#n,n,d
            support_features_flat=support_features_flat[mask].view(n,n-1,d)'''
            
        #logits = self._siamese_distance(query_features_flat,support_features_flat,k0,metric='cosine',ood=ood)#cosine, euclidean,dot
        
        '''class_means, class_precision_matrices = self.compute_class_means_and_precisions(support_features_flat, new_support_label)
        torch.nn.functional.normalize(class_means,-1)#2,d
        torch.nn.functional.normalize(query_features_flat,-1)#q,d
        similarities = query_features_flat @ torch.transpose(class_means, 0, 1)#q,2
        logits=similarities/0.04'''
        if self.config.distance_metric == "mahalanobis":
            class_means, class_precision_matrices = self.compute_class_means_and_precisions(
                support_features_flat, new_labels
            )

            # grabbing the number of classes and query examples for easier use later
            number_of_classes = class_means.size(0)
            number_of_targets = query_features_flat.size(0)

            """
            Calculating the Mahalanobis distance between query examples and the class means
            including the class precision estimates in the calculations, reshaping the distances
            and multiplying by -1 to produce the sample logits
            """
            repeated_target = query_features_flat.repeat(1, number_of_classes).view(
                -1, class_means.size(1)
            )
            repeated_class_means = class_means.repeat(number_of_targets, 1)
            repeated_difference = repeated_class_means - repeated_target
            repeated_difference = repeated_difference.view(
                number_of_targets, number_of_classes, repeated_difference.size(1)
            ).permute(1, 0, 2)
            first_half = torch.matmul(repeated_difference, class_precision_matrices)
            logits = torch.mul(first_half, repeated_difference).sum(dim=2).transpose(1, 0) * -1
        elif self.config.distance_metric == "cosine":
            n=int(support_samples.size(0))
            class_means=torch.cat((support_features_flat[0:k].mean(0,keep_dim=True),support_features_flat[k:].mean(0,keep_dim=True)),0)
            class_means=nn.functional.normalize(class_means, p=2.0, dim=-1)
            query_features_flat=nn.functional.normalize(query_features_flat, p=2.0, dim=-1)
            similarities = query_features_flat @ torch.transpose(class_means, 0, 1)#q,2
            logits=similarities
        else:  # euclidean
            logits = self._protonets_euclidean_classifier(
                support_features_flat,
                query_features_flat,
                new_labels,
            )
        if ood in [1,2]:
            return logits,t
        elif ood in [3]:
            return logits,new_labels
        elif ood in [4,5]:
            return logits,new_labels,w,b
        else:
            return logits

    def compute_class_means_and_precisions(
        self, features: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute per-class means and regularised precision (inverse covariance) matrices.

        Classes are visited in ``torch.unique(labels)`` order (sorted ascending). For each
        class the mean is the average of its feature rows. The precision is the inverse of a
        shrinkage mix of the class covariance and the covariance of all of ``features``,
        plus ``0.1 * I`` as a ridge term.

        Args:
            features: feature matrix whose rows are examples (rows are selected with
                ``index_select`` on dim 0 and ``size(1)`` is used as the feature dimension).
            labels: class labels aligned with the rows of ``features``.

        Returns:
            ``(means, precisions)`` with shapes ``(num_classes, feature_dim)`` and
            ``(num_classes, feature_dim, feature_dim)``.
        """
        means = []
        precisions = []
        feature_dim = features.size(1)
        # Ridge term built once, directly on the features' device, instead of a CPU
        # allocation plus device copy on every loop iteration.
        ridge = 0.1 * torch.eye(feature_dim, feature_dim, device=features.device)
        task_covariance_estimate = self._estimate_cov(features)
        for c in torch.unique(labels):
            # filter out feature vectors which have class c
            class_features = torch.index_select(features, 0, self._extract_class_indices(labels, c))
            # Plain mean over examples keeps the feature axis even when feature_dim == 1
            # (keepdim + squeeze would collapse it to a scalar there).
            means.append(torch.mean(class_features, dim=0))
            # Shrinkage weight n / (n + 1), capped at 0.1 so the task-level covariance dominates.
            lambda_k_tau = class_features.size(0) / (class_features.size(0) + 1)
            lambda_k_tau = min(lambda_k_tau, 0.1)
            precisions.append(
                torch.inverse(
                    (lambda_k_tau * self._estimate_cov(class_features))
                    + ((1 - lambda_k_tau) * task_covariance_estimate)
                    + ridge
                )
            )

        return torch.stack(means), torch.stack(precisions)

    @staticmethod
    def _estimate_cov(
        examples: torch.Tensor, rowvar: bool = False, inplace: bool = False
    ) -> torch.Tensor:
        """
        SCM: Function based on the suggested implementation of Modar Tensai
        and his answer as noted in:
        https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/5

        Estimate a covariance matrix given data.

        Covariance indicates the level to which two variables vary together.
        If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
        then the covariance matrix element `C_{ij}` is the covariance of
        `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

        Args:
            examples: A 1-D or 2-D array containing multiple variables and observations.
                Each row of `m` represents a variable, and each column a single
                observation of all those variables.
            rowvar: If `rowvar` is True, then each row represents a
                variable, with observations in the columns. Otherwise, the
                relationship is transposed: each column represents a variable,
                while the rows contain observations.

        Returns:
            The covariance matrix of the variables.
        """
        if examples.dim() > 2:
            raise ValueError("m has more than 2 dimensions")
        if examples.dim() < 2:
            examples = examples.view(1, -1)
        if not rowvar and examples.size(0) != 1:
            examples = examples.t()
        factor = 1.0 / (examples.size(1) - 1)
        if inplace:
            examples -= torch.mean(examples, dim=1, keepdim=True)
        else:
            examples = examples - torch.mean(examples, dim=1, keepdim=True)
        examples_t = examples.t()
        return factor * examples.matmul(examples_t).squeeze()

    @staticmethod
    def _extract_class_indices(labels: torch.Tensor, which_class: torch.Tensor) -> torch.Tensor:
        class_mask = torch.eq(labels, which_class)  # binary mask of labels equal to which_class
        class_mask_indices = torch.nonzero(class_mask)  # indices of labels equal to which class
        return torch.reshape(class_mask_indices, (-1,))  # reshape to be a 1D vector

    @staticmethod
    def compute_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between class logits and labels.

        ``labels`` are cast to int64 class indices first, so float-typed labels are accepted.
        The loss uses PyTorch's default reduction (mean over the batch).
        """
        targets = labels.to(torch.long)
        return nn.functional.cross_entropy(logits, targets)

    def _protonets_euclidean_classifier(
        self,
        support_features: torch.Tensor,
        query_features: torch.Tensor,
        support_labels: torch.Tensor,
    ) -> torch.Tensor:
        class_prototypes = self._compute_class_prototypes(support_features, support_labels)
        logits = self._euclidean_distances(query_features, class_prototypes)
        return logits

    def _compute_class_prototypes(
        self, support_features: torch.Tensor, support_labels: torch.Tensor
    ) -> torch.Tensor:
        means = []
        for c in torch.unique(support_labels):
            # filter out feature vectors which have class c
            class_features = torch.index_select(
                support_features, 0, self._extract_class_indices(support_labels, c)
            )
            means.append(torch.mean(class_features, dim=0))
        return torch.stack(means)

    def _euclidean_distances(
        self, query_features: torch.Tensor, class_prototypes: torch.Tensor
    ) -> torch.Tensor:
        num_query_features = query_features.shape[0]
        num_prototypes = class_prototypes.shape[0]

        distances = (
            (
                query_features.unsqueeze(1).expand(num_query_features, num_prototypes, -1)
                - class_prototypes.unsqueeze(0).expand(num_query_features, num_prototypes, -1)
            )
            .pow(2)
            .sum(dim=2)
        )

        return -distances
    
    def _siamese_distance(self,query_embeddings, support_set_embeddings, k0,l2Norm=False,metric='euclidean',stabilizer=torch.tensor(1e-8).float(),ood=0):#euclidean, cosine
        # L2-Norm
        #(q,s)=(q,d)(d,s)
        n=support_set_embeddings.size(0)
        if (ood in [3,4,5]):
            if metric=='cosine':
                support_set_embeddings=nn.functional.normalize(support_set_embeddings, p=2.0, dim=-1)
                query_embeddings=nn.functional.normalize(query_embeddings, p=2.0, dim=-1)
                similarities = torch.bmm(support_set_embeddings,torch.transpose(query_embeddings,-1,-2)).squeeze(-1)
                logits0=torch.cat((torch.sum(similarities[:k0,:k0-1],-1,True)/(k0-1+stabilizer),torch.sum(similarities[:k0,k0-1:],-1,True)/(n-k0+stabilizer)),-1)/0.04
                logits1=torch.cat((torch.sum(similarities[k0:,:k0],-1,True)/(k0+stabilizer),torch.sum(similarities[k0:,k0:],-1,True)/(n-k0-1+stabilizer)),-1)/0.04
                logits=torch.cat((logits0,logits1),0)
            elif metric=='euclidean':
                similarities = (query_embeddings.unsqueeze(1).repeat(1,n,1) - support_set_embeddings.unsqueeze(0)).pow(2).sum(dim=-1)
                logits=-torch.cat((torch.sum(similarities[:,:k0],-1,True)/(k0+stabilizer),torch.sum(similarities[:,k0:],-1,True)/(n-k0+stabilizer)),-1)
            elif metric=='dot':
                similarities = query_embeddings @ torch.transpose(support_set_embeddings, 0, 1)
                coes=(1/n)*torch.cat((-torch.ones_like(similarities[:,:k0])*n/(2.*torch.sqrt(torch.Tensor([k0]).cuda())+stabilizer),torch.ones_like(similarities[:,k0:])*n/(2.*torch.sqrt(torch.Tensor([n-k0]).cuda())+stabilizer)),-1)
                similarity_sums=nn.functional.sigmoid(torch.sum(similarities*coes,-1,keepdim=True)*0.04)
                logits=torch.cat((1-similarity_sums,similarity_sums),-1)
        else:
            if metric=='cosine':
                support_set_embeddings=nn.functional.normalize(support_set_embeddings, p=2.0, dim=-1)
                query_embeddings=nn.functional.normalize(query_embeddings, p=2.0, dim=-1)
                similarities = query_embeddings @ torch.transpose(support_set_embeddings, 0, 1)
                logits=torch.cat((torch.sum(similarities[:,:k0],-1,True)/(k0+stabilizer),torch.sum(similarities[:,k0:],-1,True)/(n-k0+stabilizer)),-1)/0.04
            elif metric=='euclidean':
                similarities = (query_embeddings.unsqueeze(1).repeat(1,n,1) - support_set_embeddings.unsqueeze(0)).pow(2).sum(dim=-1)
                logits=-torch.cat((torch.sum(similarities[:,:k0],-1,True)/(k0+stabilizer),torch.sum(similarities[:,k0:],-1,True)/(n-k0+stabilizer)),-1)
            elif metric=='dot':
                similarities = query_embeddings @ torch.transpose(support_set_embeddings, 0, 1)
                coes=(1/n)*torch.cat((-torch.ones_like(similarities[:,:k0])*n/(2.*torch.sqrt(torch.Tensor([k0]).cuda())+stabilizer),torch.ones_like(similarities[:,k0:])*n/(2.*torch.sqrt(torch.Tensor([n-k0]).cuda())+stabilizer)),-1)
                similarity_sums=nn.functional.sigmoid(torch.sum(similarities*coes,-1,keepdim=True)*0.04)
                logits=torch.cat((1-similarity_sums,similarity_sums),-1)
        #coes=(1/n)*torch.cat((-torch.ones_like(similarities[:,:k0])*n/(torch.sqrt(2.*torch.Tensor([k0]).cuda())+stabilizer),torch.ones_like(similarities[:,k0:])*n/(torch.sqrt(2.*torch.Tensor([n-k0]).cuda())+stabilizer)),-1)
        #similarity_sums=nn.functional.sigmoid(torch.sum(similarities*coes,-1).unsqueeze(-1)/0.04)
        #logits=torch.cat((1-similarity_sums,similarity_sums),-1)
        
        # Masking: Remove padded support set artefacts
        '''mask = torch.zeros_like(similarities)
        for task_idx in range(support_set_embeddings.shape[0]):
            real_size = support_set_size[task_idx]
            if real_size > 0:
                mask[task_idx, :, :real_size] = torch.ones_like(mask[task_idx, :, :real_size])

        # Compute similarity values
        similarities = similarities * mask
        similarity_sums = similarities.sum(dim=2)  # For every query molecule: Sum over support set molecules

        # Scaling
        if scaling == '1/N':
            stabilizer = torch.tensor(1e-8).float()
            similarity_sums = 1/(2.*support_set_size.reshape(-1, 1) + stabilizer) * similarity_sums
        elif scaling == '1/sqrt(N)':
            stabilizer = torch.tensor(1e-8).float()
            similarity_sums = 1 / (2.*torch.sqrt(support_set_size.reshape(-1, 1).float()) + stabilizer) * similarity_sums'''
        return logits
