import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import Sequential, Linear, ReLU, LSTM

import numpy as np

from torch_geometric.data import Data, Batch
from torch_geometric.nn import GINConv
    
class MatcherMLP(nn.Module):
    """Fully connected matcher over two flattened preference tensors.

    ``forward`` flattens both tensors of ``Ss`` per sample, concatenates them
    (so together they must hold ``2 * input_channels * N * M`` values) and runs
    the result through ``n_layers + 1`` dense blocks.  Each block is
    ``Linear(bias=False)`` -> optional ``BatchNorm1d`` (when ``norm == 'batch'``)
    -> ``PReLU`` (hidden blocks only).  The output is reshaped to
    ``(batch, N, M)`` using the last two dims of the first tensor.
    """

    def __init__(self, N, M, input_channels=1, mid_nc = 100, n_layers=3 ,norm='batch'):
        super().__init__()
        self.NxM = N * M
        self.input_nc = input_channels
        self.mid_nc = mid_nc
        self.n_layers = n_layers
        self.norm = norm
        self.build()

    def build(self):
        """Create ``self.layers``: ``n_layers + 1`` Sequential dense blocks."""
        self.layers = nn.ModuleList()
        width_in = 2 * self.input_nc * self.NxM
        width_out = self.mid_nc
        last_hidden = self.n_layers - 1
        # The extra (+1) block is a linear matching estimator on top of the stack.
        for idx in range(self.n_layers + 1):
            if idx == last_hidden:
                # From this block on, every block emits one value per (row, col) cell.
                width_out = self.NxM
            block = [nn.Linear(in_features=width_in, out_features=width_out, bias=False)]
            if self.norm == 'batch':
                block.append(nn.BatchNorm1d(width_out))
            if idx < last_hidden:
                block.append(nn.PReLU())
            self.layers.append(nn.Sequential(*block))
            width_in = width_out

    def forward(self, Ss):
        """Map ``Ss = (sab, sba)`` to a ``(batch, N, M)`` score tensor."""
        sab, sba = Ss
        n_batch = sab.shape[0]
        rows, cols = sab.shape[-2], sab.shape[-1]
        hidden = torch.cat((sab.view(n_batch, -1), sba.view(n_batch, -1)), dim=-1)
        for block in self.layers:
            hidden = block(hidden)
        return hidden.view(n_batch, rows, cols)

    
class StreamSelectiveLayer(nn.Module):
    """Keep ``n_streams`` independent copies of a layer and dispatch by index.

    Every copy is built as ``base_layer(*args, **kwargs)``; ``forward(x, str_id)``
    applies only copy number ``str_id`` (default 0) to ``x``.
    """

    def __init__(self,n_streams,base_layer,*args,**kwargs):
        super().__init__()
        self.layers = nn.ModuleList(base_layer(*args, **kwargs) for _ in range(n_streams))

    def forward(self,x,str_id=0):
        selected = self.layers[str_id]
        return selected(x)
        
class EncoderAttention(nn.Module):
    """Dot-product attention along one spatial axis of a 4-D feature map.

    ``x`` is fed to three 1x1 ``Conv2d`` layers (key, query, value), so it is a
    ``(B, input_dim, H, W)`` tensor and the output is ``(B, output_dim, H, W)``.

    For ``dim=2`` the attention runs along axis 2, independently for every
    index ``w`` of axis 3::

        out[b, :, j, w] = sum_i softmax_i(<query[b, :, i, w], key[b, :, j, w]>) * val[b, :, i, w]

    ``dim=3`` is the same with the roles of axes 2 and 3 swapped.  The result is
    optionally batch-normalised and then passed through ``activation``.  When
    ``asymmetric`` is set, a separate BatchNorm is used per aggregation axis.

    Args:
        input_dim: channels of ``x``.
        output_dim: channels produced by the value projection.
        kq_dim: channels of the key/query projections.
        use_batch_norm: add BatchNorm2d after the attention.
        activation: module applied last; a fresh ``nn.PReLU()`` per instance
            when ``None`` (a shared default module would tie the learnable
            slope across every encoder built without an explicit activation).
        asymmetric: keep one BatchNorm per aggregation axis (``dim`` 2 / 3).
    """

    # Softmax modules at class level; they hold no parameters.
    softmax2 = nn.Softmax(dim=2)
    softmax3 = nn.Softmax(dim=3)

    def __init__(self,input_dim,output_dim, kq_dim,use_batch_norm=True,activation=None, asymmetric=False):
        super(EncoderAttention,self).__init__()

        self.conv_key = nn.Conv2d(input_dim, kq_dim, kernel_size=1)
        self.conv_query = nn.Conv2d(input_dim, kq_dim, kernel_size=1)
        self.conv_val = nn.Conv2d(input_dim, output_dim, kernel_size=1)

        self.asymmetric = asymmetric
        if use_batch_norm:
            if asymmetric:
                self.bn = StreamSelectiveLayer(2,nn.BatchNorm2d,output_dim)
            else:
                self.bn = nn.BatchNorm2d(output_dim)
        else:
            self.bn = None
        if activation is None:
            activation = nn.PReLU()
        self.act = activation

    def forward(self,x,dim=2):
        """Attend along axis ``dim`` (2 or 3) of ``x``.

        Raises:
            ValueError: if ``dim`` is not 2 or 3.
        """
        if dim not in (2, 3):
            raise ValueError("dim must be 2 or 3, got {}".format(dim))
        key = self.conv_key(x)
        query = self.conv_query(x)
        val = self.conv_val(x)

        # Pairwise scores: query index on axis `dim`, key index on axis `dim + 1`;
        # normalised over the query index (axis `dim`).
        if dim == 2:
            weight = self.softmax2((key.unsqueeze(2) * query.unsqueeze(3)).sum(dim=1, keepdim=True))
        else:
            weight = self.softmax3((key.unsqueeze(3) * query.unsqueeze(4)).sum(dim=1, keepdim=True))
        # The values must vary along the normalised axis. Broadcasting them along
        # that axis instead would make the sum collapse to `val` itself, because
        # the softmax weights along it add up to one.
        z = (weight * val.unsqueeze(dim + 1)).sum(dim)

        if self.bn is not None:
            if self.asymmetric:
                z = self.bn(z,dim-2)
            else:
                z = self.bn(z)

        return self.act(z)

    
class EncoderAttention2stream(EncoderAttention):
    """Apply ``EncoderAttention`` to a two-stream batch in one pass.

    The first ``batch_size`` samples of ``x`` attend along axis 2 and the
    remaining samples along axis 3, each with its own score tensor.  The two
    results are concatenated back on the batch axis.  BatchNorm is shared, or
    per stream when ``asymmetric``; it is skipped entirely when the encoder was
    built with ``use_batch_norm=False``.
    """

    def __init__(self,*args,**kwargs):
        super(EncoderAttention2stream,self).__init__(*args,**kwargs)

    def forward(self,x,batch_size:int):
        key = self.conv_key(x)
        query = self.conv_query(x)
        val = self.conv_val(x)

        # Stream 1: query index on axis 2, key index on axis 3, normalised over axis 2.
        weight1 = self.softmax2((key[:batch_size].unsqueeze(2) * query[:batch_size].unsqueeze(3)).sum(dim=1, keepdim=True))
        # Stream 2: query index on axis 3, key index on axis 4, normalised over axis 3.
        weight2 = self.softmax3((key[batch_size:].unsqueeze(3) * query[batch_size:].unsqueeze(4)).sum(dim=1, keepdim=True))
        # Values are laid out along the normalised axis so each output position is
        # a weighted mix of other positions; each stream uses its OWN weights.
        z1 = (weight1 * val[:batch_size].unsqueeze(3)).sum(dim=2)
        z2 = (weight2 * val[batch_size:].unsqueeze(4)).sum(dim=3)

        if self.bn is None:
            z = torch.cat([z1, z2], dim=0)
        elif self.asymmetric:
            z = torch.cat([self.bn(z1,0),self.bn(z2,1)],dim=0)
        else:
            z = self.bn(torch.cat([z1, z2], dim=0))
        return self.act(z)
        
        
@torch.jit.script
def max_pool_concat(x,z,dim:int):
    # Max-pool `z` along axis `dim`, broadcast the pooled values back to
    # z's full shape, and append them to `x` along the channel axis (dim 1).
    pooled, _ = torch.max(z, dim, keepdim=True)
    return torch.cat([x, pooled.expand_as(z)], dim=1)

class EncoderMaxPool(nn.Module):
    """Set encoder mixing each element with the max over one spatial axis.

    ``x`` goes through ``conv_max`` (1x1).  Its output is max-pooled along
    ``dim``, broadcast back, and concatenated to ``x`` on the channel axis.
    The result is projected by ``conv`` (1x1, no bias) to ``output_dim``
    channels, followed by optional BatchNorm2d and ``activation``.

    Args:
        input_dim: channels of ``x``.
        output_dim: channels of the output.
        output_dim_max: channels of ``conv_max``; ``3 * output_dim`` when None.
        use_batch_norm: add BatchNorm2d after ``conv``.
        activation: module applied last; a fresh ``nn.PReLU()`` per instance
            when ``None`` (a shared default module would tie the learnable
            slope across every encoder built without an explicit activation).
        asymmetric: keep one BatchNorm per aggregation axis (``dim`` 2 / 3).
    """

    def __init__(self,input_dim,output_dim,output_dim_max=None,use_batch_norm=True,activation=None, asymmetric=False):
        super(EncoderMaxPool,self).__init__()
        if output_dim_max is None:
            output_dim_max = output_dim * 3

        self.conv_max = nn.Conv2d(input_dim, output_dim_max, kernel_size=1)
        self.conv = nn.Conv2d(input_dim+output_dim_max, output_dim, kernel_size=1,bias=False)

        self.asymmetric = asymmetric
        if use_batch_norm:
            if asymmetric:
                self.bn = StreamSelectiveLayer(2,nn.BatchNorm2d,output_dim)
            else:
                self.bn = nn.BatchNorm2d(output_dim)
        else:
            self.bn = None
        if activation is None:
            activation = nn.PReLU()
        self.act = activation

    def forward(self,x,dim=2):
        """Encode ``x`` pooling along axis ``dim`` (2 or 3 in this file's callers)."""
        z = self.conv_max(x)
        z = max_pool_concat(x,z,dim)
        z = self.conv(z)
        if self.bn is not None:
            if self.asymmetric:
                # Stream index 0 for dim=2, 1 for dim=3.
                z = self.bn(z,dim-2)
            else:
                z = self.bn(z)
        return self.act(z)

class EncoderMaxPool2stream(EncoderMaxPool):
    """Apply ``EncoderMaxPool`` to a two-stream batch in one pass.

    The first ``batch_size`` samples of ``x`` pool along axis 2 and the rest
    along axis 3.  The 1x1 convolutions run on the whole batch at once, and
    BatchNorm is shared or, when ``asymmetric``, selected per stream.
    """

    def __init__(self,*args,**kwargs):
        super().__init__(*args, **kwargs)

    def forward(self,x,batch_size:int):
        pooled = self.conv_max(x)
        stream_a = max_pool_concat(x[:batch_size], pooled[:batch_size], 2)
        stream_b = max_pool_concat(x[batch_size:], pooled[batch_size:], 3)
        out = self.conv(torch.cat([stream_a, stream_b], dim=0))
        if self.bn is None:
            return self.act(out)
        if self.asymmetric:
            out = torch.cat([self.bn(out[:batch_size], 0),
                             self.bn(out[batch_size:], 1)], dim=0)
        else:
            out = self.bn(out)
        return self.act(out)
    
    
@torch.jit.script
def cross_concat(Z,batch_size:int):
    # Z stacks side A (first batch_size rows) on top of side B.  Return the
    # channel-wise pairs [A|B] for side A followed by [B|A] for side B.
    side_a = Z[:batch_size]
    side_b = Z[batch_size:]
    pair_ab = torch.cat([side_a, side_b], dim=1)
    pair_ba = torch.cat([side_b, side_a], dim=1)
    return torch.cat([pair_ab, pair_ba], dim=0)


class FeatureWeavingLayer(torch.nn.Module):
    """One feature-weaving step over a two-sided batch.

    ``Z`` holds side-A samples in its first half and side-B samples in its
    second half (``batch_size = Z.shape[0] // 2``).  ``cross_concat`` pairs
    each side with the other on the channel axis, which doubles the channel
    count.  The encoder then processes the two streams.  ``input_dim``
    therefore refers to the channels after that concatenation.

    The encoder is built as
    ``encoder(input_dim, output_dim, inner_dim, use_batch_norm, activation, asymmetric)``.
    When ``activation`` is None a fresh ``nn.PReLU()`` is created for this
    layer. A shared default module would otherwise tie one learnable slope
    across every layer of a network.
    """

    def __init__(self, input_dim, output_dim, inner_dim,
                 encoder=EncoderMaxPool2stream, use_batch_norm=True,
                 activation=None, asymmetric=False):

        super(FeatureWeavingLayer,self).__init__()
        if activation is None:
            activation = nn.PReLU()
        self.E = encoder(input_dim,output_dim,inner_dim,use_batch_norm,activation,asymmetric)

    def forward(self,Z):
        batch_size = Z.shape[0]//2
        Z = cross_concat(Z,batch_size)
        return self.E(Z,batch_size)
    
    
class MatcherWeaveNet(torch.nn.Module):
    """Two-stream WeaveNet matcher.

    ``Ss = (Za, Zb)`` with ``Za`` of shape ``(B, N, M)``; ``Zb`` is transposed
    with ``permute(0, 2, 1)`` and viewed as ``(B, 1, N, M)``, so it is expected
    as ``(B, M, N)``.  Both sides are stacked on the batch axis and passed
    through ``L`` FeatureWeavingLayers (max-pool encoders).  With
    ``use_resnet``, a residual is added at every even layer index.  A final
    1x1 conv + BatchNorm scores each cross-concatenated side, and the two
    sides' scores are averaged into a ``(B, N, M)`` output.

    ``asymmetric`` appends a constant side-identifier channel (0 for A, 1 for B)
    and requests per-stream BatchNorm in the encoders.

    ``depth``: the loop stops after layer ``i`` as soon as ``i > depth``
    (``-1`` runs every layer); intended for incremental curriculum learning.
    """

    def __init__(self, L, D=64, inner_conv_out_channels=64, use_resnet=True, asymmetric=False):
        super().__init__()
        self.use_resnet = use_resnet
        self.asymmetric = asymmetric
        self.L = L
        assert D % 2 == 0
        self.D = D

        self.inner_conv_out_channels = inner_conv_out_channels

        self.build()

    def build(self):
        """Create the encoder stack followed by the 1x1 scoring head."""
        self.encoders, channels = self.build_encoders()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
        )

    def build_encoders(self):
        """Return ``(ModuleList of L layers, channels fed to the head)``."""
        layers = torch.nn.ModuleList()
        # After cross-concatenation: 1 (or 2 with side id) channels per side, doubled.
        in_channels = 4 if self.asymmetric else 2
        for _ in range(self.L):
            layers.append(
                FeatureWeavingLayer(in_channels, self.D // 2,
                                    self.inner_conv_out_channels,
                                    EncoderMaxPool2stream,
                                    asymmetric=self.asymmetric)
            )
            in_channels = self.D
        return layers, in_channels

    def forward(self, Ss, depth=-1):
        Za, Zb = Ss
        batch_size, N, M = Za.shape
        Za = Za.view(batch_size, 1, N, M)
        Zb = Zb.permute(0, 2, 1).view(batch_size, 1, N, M)
        if self.asymmetric:
            side_id = Za.new_zeros(batch_size, 1, N, M)
            Za = torch.cat([Za, side_id], dim=1)
            Zb = torch.cat([Zb, side_id + 1], dim=1)
        Z = torch.cat([Za, Zb], dim=0)

        if depth == -1:
            depth = len(self.encoders)

        skip = None
        for i, layer in enumerate(self.encoders):
            Z = layer(Z)
            if self.use_resnet and not i % 2:
                if skip is not None:
                    Z = Z + skip
                skip = Z
            if i > depth:
                # Early exit for the incremental (curriculum) training schedule.
                break

        Z = self.conv1x1(cross_concat(Z, batch_size))
        score_a, score_b = Z[:batch_size], Z[batch_size:]
        return ((score_a + score_b) / 2).view(batch_size, N, M)
    
class MatcherWeaveNet_A(MatcherWeaveNet):
    """WeaveNet matcher whose feature-weaving layers use attention encoders.

    Identical to ``MatcherWeaveNet`` except that each layer uses
    ``EncoderAttention2stream`` with ``key_query_channels`` key/query channels.
    ``asymmetric`` is now forwarded to every layer, as ``MatcherWeaveNet`` does,
    so it also selects per-stream BatchNorm rather than only adding the
    side-identifier input channels.
    """

    def __init__(self, L, D=64, key_query_channels=32, use_resnet=True, asymmetric=False):
        # key_query_channels is stored as self.inner_conv_out_channels by the base class.
        super(MatcherWeaveNet_A,self).__init__(L,D, key_query_channels, use_resnet,asymmetric)

    def build_encoders(self):
        """Return ``(ModuleList of L attention layers, channels fed to the head)``."""
        encoders = torch.nn.ModuleList()

        input_dim = 4 if self.asymmetric else 2

        for _ in range(self.L):
            output_dim = self.D
            encoders.append(
                # self.inner_conv_out_channels is "key_query_channels"
                FeatureWeavingLayer(input_dim, output_dim // 2,
                                    self.inner_conv_out_channels,
                                    EncoderAttention2stream,
                                    asymmetric=self.asymmetric)
            )
            input_dim = output_dim
        return encoders, input_dim

class MatcherWeaveNetDual(MatcherWeaveNet):
    """WeaveNet matcher with separate encoder stacks for the two sides.

    Side A's stream is produced by ``encoders_A`` and side B's by
    ``encoders_B``.  Everything else (input layout, residuals, ``depth``,
    scoring head) follows ``MatcherWeaveNet``.
    """

    def __init__(self, L, D=64, inner_conv_out_channels=64, use_resnet=True):
        # This model is innately asymmetric (separate weights per side), so the
        # side-identifier channels enabled by `asymmetric` are not required.
        super(MatcherWeaveNetDual,self).__init__(L,D, inner_conv_out_channels, use_resnet,asymmetric=False)

    def build(self):
        """Create the two encoder stacks and the 1x1 scoring head."""
        self.encoders_A, dim = self.build_encoders()
        self.encoders_B, _ = self.build_encoders()
        assert(len(self.encoders_A)==len(self.encoders_B))
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(dim, 1, kernel_size=1,bias=False),
            nn.BatchNorm2d(1)
        )

    def forward(self,Ss,depth=-1):
        Za, Zb = Ss
        batch_size,N,M = Za.shape
        Za = Za.view(batch_size,1,N,M)
        Zb = Zb.permute(0,2,1).view(batch_size,1,N,M)
        Z = torch.cat([Za,Zb],dim=0)
        Z_keep = None

        if depth==-1:
            depth = len(self.encoders_A)

        for i, (FWLayerA,FWLayerB) in enumerate(zip(self.encoders_A,self.encoders_B)):
            # A FeatureWeavingLayer splits its input at Z.shape[0] // 2 and pairs
            # the halves, so it must see the full two-sided batch. Feeding it only
            # Z[:batch_size] would pair side-A samples with other side-A samples.
            # Keep side A's stream from layer A and side B's stream from layer B.
            # NOTE: this evaluates both layers on both streams (2x compute), and in
            # training mode each layer's BatchNorm statistics cover both streams.
            Z = torch.cat([FWLayerA(Z)[:batch_size], FWLayerB(Z)[batch_size:]], dim=0)
            if self.use_resnet and i%2==0:
                # use residual network
                if Z_keep is not None:
                    Z = Z + Z_keep
                Z_keep = Z

            if depth < i:
                # depth parameter is used to flexibly decide the depth of the network.
                # This is intended for boosting the training speed with the incremental curriculum learning.
                break
        Z = self.conv1x1(cross_concat(Z,batch_size))

        m = (Z[:batch_size]+Z[batch_size:])/2
        return m.view(batch_size,N,M)

# Single Stream WeaveNet (=DBM+set encoder)
class MatcherSSWN(torch.nn.Module):
    """Single-stream WeaveNet: stacked set encoders on a 2-channel score map.

    ``Ss = (sab, sba)`` with ``sab`` of shape ``(B, N, M)``; ``sba`` is
    transposed with ``permute(0, 2, 1)`` (so it is expected as ``(B, M, N)``).
    Both become the two input channels.  Layer ``i`` pools along axis
    ``2 + i % 2``, alternating rows and columns.  With ``use_resnet``, a
    residual is added at every even layer index.  A 1x1 conv + BatchNorm
    produces the ``(B, N, M)`` output.

    ``depth``: the loop stops after layer ``i`` as soon as ``i > depth``
    (``-1`` runs every layer); intended for incremental curriculum learning.
    """

    def __init__(self, L, D=64, inner_conv_out_channels=64, use_resnet=True):
        super().__init__()
        self.use_resnet = use_resnet
        self.L = L
        self.D = D

        self.inner_conv_out_channels = inner_conv_out_channels

        self.build()

    def build(self):
        """Create the encoder stack followed by the 1x1 scoring head."""
        self.encoders, channels = self.build_encoders()
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
        )

    def build_encoders(self):
        """Return ``(ModuleList of L max-pool encoders, output channels)``."""
        encoders = torch.nn.ModuleList()
        in_channels = 2
        for _ in range(self.L):
            encoders.append(EncoderMaxPool(in_channels, self.D, self.inner_conv_out_channels))
            in_channels = self.D
        return encoders, in_channels

    def forward(self, Ss, depth=-1):
        sab, sba = Ss
        batch_size, N, M = sab.shape
        m = torch.cat([sab.view(batch_size, 1, N, M),
                       sba.permute(0, 2, 1).view(batch_size, 1, N, M)], dim=1)

        if depth == -1:
            depth = len(self.encoders)

        residual = None
        for i, encoder in enumerate(self.encoders):
            # Alternate row-wise (axis 2) and column-wise (axis 3) aggregation.
            m = encoder(m, dim=3 if i % 2 else 2)
            if self.use_resnet and not i % 2:
                if residual is not None:
                    m = m + residual
                residual = m
            if i > depth:
                # Early exit for the incremental (curriculum) training schedule.
                break

        return self.conv1x1(m).view(batch_size, N, M)

# Attention-based DBM
class MatcherDBM_A(MatcherSSWN):
    """Single-stream matcher whose encoders are ``EncoderAttention`` layers.

    ``key_query_channels`` is stored as ``self.inner_conv_out_channels`` by the
    base class and used as the key/query width of every attention layer.
    """

    def __init__(self, L, D=64, key_query_channels=64, use_resnet=True):
        super().__init__(L, D, inner_conv_out_channels=key_query_channels, use_resnet=use_resnet)

    def build_encoders(self):
        """Return ``(ModuleList of L attention encoders, output channels)``."""
        encoders = torch.nn.ModuleList()
        widths = [2] + [self.D] * self.L
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            encoders.append(EncoderAttention(c_in, c_out, self.inner_conv_out_channels))
        return encoders, widths[-1]
    
# Deep Bipartite Matching (DBM) with max-pool message passing
class MatcherDBM_P(torch.nn.Module):
    """Deep Bipartite Matching network with max-pool message passing.

    ``Ss = (sab, sba)`` with ``sab`` of shape ``(B, N, M)``; ``sba`` is
    transposed with ``permute(0, 2, 1)`` (so it is expected as ``(B, M, N)``).
    Both are stacked as two channels.  Layer ``i`` concatenates the features
    with their max along axis ``2 + i % 2``, broadcast back, so row-wise and
    column-wise aggregation alternate.  It then applies 1x1 conv, BatchNorm
    and PReLU.  The last layer outputs a single channel, reshaped to
    ``(B, N, M)``.

    With ``use_resnet``, a residual is added at every even layer index whenever
    the shapes agree.  It is skipped for the 1-channel last layer, which
    previously broadcast to D channels and crashed the final ``view`` for odd L.

    ``depth``: the loop stops after layer ``i`` as soon as ``i > depth``.
    NOTE(review): stopping before the last (1-channel) layer leaves ``D``
    channels and the final ``view`` fails; confirm how ``depth`` is meant to be
    used with this model.
    """

    def __init__(self, L, D=64, use_resnet=True):
        super(MatcherDBM_P,self).__init__()
        self.use_resnet=use_resnet
        self.L = L
        self.D = D
        self.build()

    def build(self):
        """Create ``L`` conv blocks; the last one outputs a single channel."""
        self.encoders = torch.nn.ModuleList()

        input_dim = 2

        for i in range(self.L):
            output_dim = 1 if i == self.L - 1 else self.D
            self.encoders.append(
                nn.Sequential(
                    # input_dim * 2: features concatenated with their pooled max.
                    nn.Conv2d(input_dim*2, output_dim, kernel_size=1,bias=False),
                    nn.BatchNorm2d(output_dim),
                    nn.PReLU()
                )
            )
            input_dim = output_dim

    def forward(self,Ss,depth=-1):
        sab, sba = Ss
        batch_size,N,M = sab.shape

        m = torch.cat([sab.view(batch_size,1,N,M),sba.permute(0,2,1).view(batch_size,1,N,M)], dim=1)

        assert(m.shape[1]==2)
        m_keep = None

        if depth==-1:
            depth = len(self.encoders)

        for i, E in enumerate(self.encoders):
            dim = 2+i%2 # alternately do the row/column-wise message passing
            # Broadcast the max back to m's full shape. Repeating it N times broke
            # column-wise passing (axis 3 has size M) whenever N != M.
            m_max = m.max(dim,keepdim=True)[0].expand_as(m)
            m = E(torch.cat([m,m_max],dim=1))

            if self.use_resnet and i%2==0:
                # use residual network (only between layers of equal shape)
                if m_keep is not None and m_keep.shape == m.shape:
                    m = m + m_keep
                m_keep = m

            if depth < i:
                # depth parameter is used to flexibly decide the depth of the network.
                # This is intended for boosting the training speed with the incremental curriculum learning.
                break

        return m.view(batch_size,N,M)
    
    
class MatcherGIN(torch.nn.Module):
    """Graph Isomorphism Network (GIN) matcher.

    With ``graphtype == 'Normal'`` (the only value set here) the graph has
    ``2N`` nodes.  Nodes ``0..N-1`` come from the rows of ``sab`` and nodes
    ``N..2N-1`` from the rows of ``sba``, zero-padded to length ``N + M``.
    Every node on one side is connected to every node on the other side.
    This layout requires ``N == M`` (asserted).

    ``forward`` runs ``L`` GINConv layers of width ``D`` (ReLU then
    BatchNorm1d after each) and concatenates all layer outputs per node.  It
    then flattens them per graph and maps them through Linear + BatchNorm1d
    to a ``(B, N, M)`` output.
    """

    def __init__(self, N, M, device, L, D):
        super(MatcherGIN,self).__init__()

        self.device = device

        self.graphtype = 'Normal' # 'Normal', 'Line', 'Bipartite'
        if self.graphtype == 'Normal':
            assert N==M # if N!=M: Line or Bipartite
            self.n_node = 2*N
            self.n_feature = 2*N
            # Edges i -> N..2N-1 for i < N, and i -> 0..N-1 for i >= N.
            self.edge_index = torch.stack([torch.tensor([[i]*N for i in range(2*N)]).view(-1),torch.tensor([*range(N,2*N)]*N+[*range(N)]*N)])
        elif self.graphtype == 'Line':
            self.n_node = N*M
            self.n_feature = 2
            # calculate edge_index: cells sharing a row or a column are connected
            ind = (torch.tensor([*range(N*M)]).view(N,M)).tolist()
            ind2 = (torch.tensor([*range(N*M)]).view(N,M).t()).tolist()
            nbs = [ind[i] for i in range(N)] + [ind2[i] for i in range(M)] # neighbor nodes
            r1,r2=[],[]
            for nb in nbs:
                n = len(nb)
                r1 += (torch.tensor([[nb[i]]*(n-1) for i in range(n)]).view(-1)).tolist()
                r2 += (torch.tensor([nb[0:i]+nb[i+1:n] for i in range(n)]).view(-1)).tolist()
            self.edge_index=torch.tensor([r1,r2])
        else:
            raise NotImplementedError("graphtype {!r} is not supported".format(self.graphtype))

        self.edge_index = self.to_cuda(self.edge_index)
        self.num_gc_layers = L
        self.dim = D

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()

        for i in range(self.num_gc_layers):
            # The first layer reads raw node features; later layers read D-wide embeddings.
            in_features = self.dim if i else self.n_feature
            nn1 = nn.Sequential(
                nn.Linear(in_features, self.dim),
                nn.ReLU(),
                nn.Linear(self.dim, self.dim)
                )

            conv = GINConv(nn1)
            bn = torch.nn.BatchNorm1d(self.dim)

            self.convs.append(conv)
            self.bns.append(bn)

        # Every node contributes dim * num_gc_layers features per graph. This
        # equals N + M nodes for 'Normal' (N == M) and N * M nodes for 'Line'.
        self.linear = nn.Sequential(
            nn.Linear(self.dim*self.num_gc_layers*self.n_node, N*M,bias=False),
            nn.BatchNorm1d(N*M)
        )

    def to_cuda(self,x):
        """Move ``x`` to ``self.device`` (no-op when the device is None)."""
        if self.device is None:
            return x
        return x.to(self.device)

    def input2graph(self, Ss):
        """Build one torch_geometric ``Data`` graph per sample and batch them."""
        sab, sba = Ss
        batch_size = sab.shape[0]
        N = sab.shape[-2]
        M = sab.shape[-1]
        # float32 zero padding, created directly on the inputs' device. This avoids
        # a numpy allocation plus host->device copy each call, and it also works
        # on non-CUDA accelerators.
        lup = torch.zeros(batch_size, M, N, dtype=torch.float32, device=sab.device)
        sab10 = torch.cat((lup,sab),dim=2)
        sba10 = torch.cat((sba,lup),dim=2)

        if self.graphtype == 'Normal':
            datalist = [Data(x=torch.cat([sab10[b],sba10[b]]),
                             edge_index=self.edge_index) for b in range(batch_size)]
        elif self.graphtype == 'Line':
            datalist = [Data(x=torch.stack([sab[b].reshape(-1),(sba[b].t()).reshape(-1)]).t(),
                             edge_index=self.edge_index) for b in range(batch_size)]
        batch = Batch.from_data_list(datalist)
        return batch

    def forward(self, Ss):
        sab, sba = Ss
        batch_size = sab.shape[0]
        N = sab.shape[-2] # n_male
        M = sab.shape[-1] # n_female

        batch = self.input2graph(Ss)
        x, edge_index = batch.x, batch.edge_index

        xs = []
        for i in range(self.num_gc_layers):
            x = F.relu(self.convs[i](x, edge_index))
            x = self.bns[i](x)
            xs.append(x)

        # Nodes of each graph are contiguous in the batch, so this groups them per sample.
        x = torch.cat(xs,1)
        x = x.view(batch_size,-1)
        x = self.linear(x)
        x = x.view(batch_size,N,M)

        return x
