
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.init as init
import numpy as np
import math

# Module-wide default device: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Attend2Pack(nn.Module):
    """Container for the item encoder, frontier encoder and the two decoders.

    Args:
        originaldim: width of the raw per-item features passed to ``clacb``.
        embedingdim: shared embedding width used by every submodule.
        num_head: attention heads for the encoder layers and the sequence decoder.
        FFNdim: hidden width of each encoder feed-forward sublayer.
        AFFNnum: number of stacked attention/feed-forward encoder layers.
        inchannel, outchannel: conv channels of the frontier encoder.
        C: logit clipping constant forwarded to ``DecodingSequence``.
        outdim: number of outputs of ``DecodingPosition``.
        L, W: spatial size of the frontier map given to ``FrontierEmbedding``.
    """

    def __init__(self, originaldim, embedingdim, num_head, FFNdim, AFFNnum, inchannel, outchannel, C, outdim, L, W):
        super().__init__()
        self.AFFNnum = AFFNnum
        self.AFFNlist = nn.ModuleList()
        # Moved to `device` like every other submodule; previously it was the only
        # one left on CPU, so `clacb` mixed devices unless the caller ran model.to().
        self.embeding = nn.Linear(originaldim, embedingdim).to(device)
        for _ in range(AFFNnum):
            self.AFFNlist.append(AttentionFFN(embedingdim, num_head, FFNdim).to(device))
        self.FrontierEmbedding = FrontierEmbedding(inchannel, outchannel, embedingdim, L, W).to(device)
        self.DecodingSequence = DecodingSequence(embedingdim, num_head, C).to(device)
        self.DecodingPosition = DecodingPosition(embedingdim, outdim).to(device)

    def clacb(self, input, scale):
        """Embed raw item features and run them through the encoder stack.

        Args:
            input: raw item features whose last axis has size ``originaldim``.
            scale: attention logit scale forwarded to every encoder layer.

        Returns:
            The encoded item embeddings (last axis ``embedingdim``).
        """
        x = self.embeding(input)
        for layer in self.AFFNlist:
            x = layer(x, scale)
        return x
        
class MHA(nn.Module):
    """Multi-head self-attention over a (batch, n, embedingdim) sequence.

    Args:
        embedingdim: model width; must be divisible by ``num_head``.
        num_head: number of attention heads.

    Raises:
        ValueError: if ``embedingdim`` is not divisible by ``num_head``.
    """

    def __init__(self, embedingdim, num_head):
        super().__init__()
        # Fail at construction instead of with an opaque reshape error in forward().
        if embedingdim % num_head != 0:
            raise ValueError(
                f"embedingdim ({embedingdim}) must be divisible by num_head ({num_head})"
            )
        self.embedingdim = embedingdim
        self.num_head = num_head
        self.edn = embedingdim // num_head  # per-head width
        self.Q = nn.Linear(embedingdim, embedingdim, bias=False)
        self.K = nn.Linear(embedingdim, embedingdim, bias=False)
        self.V = nn.Linear(embedingdim, embedingdim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.O = nn.Linear(embedingdim, embedingdim, bias=False)

    def forward(self, input, scale=None):
        """Apply self-attention.

        Args:
            input: tensor of shape (batch, n, embedingdim).
            scale: multiplier for the q·k logits. ``None`` (new default) uses
                the standard ``1 / sqrt(head_dim)``.

        Returns:
            Tensor of shape (batch, n, embedingdim).
        """
        if scale is None:
            scale = 1.0 / math.sqrt(self.edn)
        batch, n, dim = input.shape
        nh = self.num_head
        edn = self.edn
        # Split the model width into heads: (batch, heads, n, edn).
        q = self.Q(input).reshape(batch, n, nh, edn).transpose(1, 2)
        k = self.K(input).reshape(batch, n, nh, edn).transpose(1, 2)
        v = self.V(input).reshape(batch, n, nh, edn).transpose(1, 2)
        dist = torch.matmul(q, k.transpose(2, 3)) * scale
        dist = self.softmax(dist)
        att = torch.matmul(dist, v)
        # Merge heads back into the model width.
        att = att.transpose(1, 2).reshape(batch, n, dim)
        return self.O(att)
                  
class AttentionFFN(nn.Module):
    """One encoder layer: multi-head attention sublayer then a two-layer FFN.

    Args:
        embedingdim: model width.
        num_head: attention heads for the ``MHA`` sublayer.
        lineardim: hidden width of the feed-forward sublayer.
    """

    def __init__(self, embedingdim, num_head, lineardim):
        super().__init__()
        # Submodules are created in this order so parameter initialisation
        # consumes the RNG identically.
        self.mha = MHA(embedingdim, num_head)
        # NOTE(review): this single LayerNorm is applied twice in forward(); a
        # standard transformer layer uses a separate norm per sublayer.
        self.norm = nn.LayerNorm(embedingdim)
        self.linear1 = nn.Linear(embedingdim, lineardim)
        self.linear2 = nn.Linear(lineardim, embedingdim)

    def forward(self, input, scale):
        """Return the layer output; same shape as ``input``.

        ``scale`` is forwarded unchanged to the attention sublayer.
        """
        # NOTE(review): both residuals are taken around the *normalised* tensor,
        # not the raw input, and ReLU is applied to the FFN output before the
        # residual add — confirm this is intended.
        normed = self.norm(input)
        hidden = normed + self.mha(normed, scale)
        hidden = self.norm(hidden)
        ffn_hidden = torch.relu(self.linear1(hidden))
        return hidden + torch.relu(self.linear2(ffn_hidden))
    
class FrontierEmbedding(nn.Module):
    """Encode an (N, inchannel, L, W) map into an ``embedingdim`` vector per sample.

    Three 3x3 stride-2 padding-1 convolutions each map a spatial side ``n`` to
    ``ceil(n / 2)``, so the final feature map is ``ceil(L/8) x ceil(W/8)``; it is
    layer-normalised, ReLU'd, flattened and projected by two linear layers.
    """

    def __init__(self, inchannel, outchannel, embedingdim, L, W):
        super().__init__()
        conv_opts = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(inchannel, outchannel, **conv_opts)
        self.conv2 = nn.Conv2d(outchannel, outchannel, **conv_opts)
        self.conv3 = nn.Conv2d(outchannel, outchannel, **conv_opts)
        # Ceiling division by 8 (three halvings that round up).
        side_l = -(-L // 8)
        side_w = -(-W // 8)
        self.norm1 = nn.LayerNorm([outchannel, side_l, side_w])
        self.linear1 = nn.Linear(outchannel * side_l * side_w, embedingdim)
        self.LinearFrontier = nn.Linear(embedingdim, embedingdim, bias=False)

    def forward(self, input):
        """Return a tensor of shape (N, embedingdim)."""
        # NOTE(review): no activation between the convolutions, so the three
        # convs compose to a single linear map — confirm this is intended.
        features = input
        for conv in (self.conv1, self.conv2, self.conv3):
            features = conv(features)
        features = torch.relu(self.norm1(features))
        flat = features.flatten(start_dim=1)
        return self.LinearFrontier(self.linear1(flat))
    
class DecodingSequence(nn.Module):
    """Choose the next item index from a set of item embeddings.

    A query is formed from the mean projection of the still-available items and
    the frontier embedding, refined by a multi-head glimpse over all items, and
    scored against every item with a ``C * tanh`` clipped compatibility. Items
    listed in ``maskseq`` get a ``-inf`` logit and are never chosen.

    Args:
        embedingdim: item embedding width; assumed divisible by ``num_head``.
        num_head: number of glimpse heads.
        C: logit clipping constant applied as ``C * tanh(...)``.
    """

    def __init__(self, embedingdim, num_head, C):
        super().__init__()
        self.embedingdim = embedingdim
        self.num_head = num_head
        self.C = C
        self.LinearS_leftover = nn.Linear(embedingdim, embedingdim, bias=False)
        self.Wk = nn.Linear(embedingdim, embedingdim, bias=False)
        self.Wv = nn.Linear(embedingdim, embedingdim, bias=False)
        self.Lineark = nn.Linear(embedingdim, embedingdim, bias=False)
        self.softmax = nn.Softmax(dim=1)
        self.LinearQ = nn.Linear(embedingdim, embedingdim, bias=False)

    @staticmethod
    def _mask_from_indices(indices, batch_size, num_n, device):
        """Return a (batch_size, num_n) bool tensor, True at ``indices[b]`` in row b."""
        idx = np.asarray(indices, dtype=np.int64)
        if idx.ndim != 2 or idx.shape[0] != batch_size:
            raise ValueError(
                "maskseq must contain one equal-length index list per batch element"
            )
        mask = torch.zeros(batch_size, num_n, dtype=torch.bool, device=device)
        rows = torch.arange(batch_size, device=device).unsqueeze(1)
        mask[rows, torch.as_tensor(idx, device=device)] = True
        return mask

    def forward(self, inputb, inputf, maskseq, IsGreedy):
        """Score every item and pick one.

        Args:
            inputb: item embeddings of shape (batch, num_n, embedingdim).
            inputf: embedding added to the query; assumed (batch, embedingdim).
            maskseq: per-batch lists of item indices that may not be chosen;
                every row must have the same length.
            IsGreedy: if true take the argmax, otherwise sample from the policy.

        Returns:
            Greedy: ``seqidx`` of shape (batch,).
            Sampling: ``(seqidx, policypro)``, both of shape (batch,), where
            ``policypro`` is the probability of each sampled index.

        Raises:
            ValueError: if ``maskseq`` is malformed or masks every item.
        """
        batch_size, num_n, dim = inputb.shape
        edn = self.embedingdim // self.num_head
        rows = list(range(batch_size))
        masked = self._mask_from_indices(maskseq, batch_size, num_n, inputb.device)
        num_left = int((~masked).sum()) // batch_size
        if num_left == 0:
            raise ValueError("every item is masked; there is nothing left to select")

        # Query: average of the projected leftover items and the frontier input.
        # Boolean indexing is row-major, so the reshape regroups rows per batch.
        leftover = inputb[~masked].reshape(batch_size, num_left, dim)
        leftover = self.LinearS_leftover(leftover)
        inputq = (torch.mean(leftover, dim=1) + inputf) / 2
        inputq = inputq.reshape(batch_size, self.num_head, edn, 1)

        # Multi-head glimpse over all items; masked items get zero weight.
        k = self.Wk(inputb).reshape(batch_size, num_n, self.num_head, edn).transpose(1, 2)
        v = self.Wv(inputb).reshape(batch_size, num_n, self.num_head, edn).transpose(1, 2)
        ct = torch.matmul(k, inputq) / math.sqrt(edn)  # (batch, heads, num_n, 1)
        # Squeeze only the trailing axis: a bare squeeze() also dropped size-1
        # batch/head axes and crashed for batch_size == 1 or num_head == 1.
        ct = ct.squeeze(-1).transpose(1, 2)  # (batch, num_n, heads)
        ct = ct.masked_fill(masked.unsqueeze(-1), float('-inf'))
        ct = self.softmax(ct)  # normalise over items (dim=1)
        ct = ct.transpose(1, 2).unsqueeze(2)  # (batch, heads, 1, num_n)
        q2 = torch.matmul(ct, v).reshape(batch_size, 1, dim)

        # Clipped compatibility logits, one per item.
        k2 = self.Lineark(inputb)
        ct2 = self.C * torch.tanh(
            torch.matmul(self.LinearQ(q2), k2.transpose(1, 2)) / math.sqrt(self.embedingdim)
        )
        ct2 = ct2.squeeze(1).masked_fill(masked, float('-inf'))  # (batch, num_n)
        policy = self.softmax(ct2)
        if IsGreedy:
            return torch.max(policy, 1)[1]
        # squeeze(1) keeps shape (batch,) even when batch_size == 1.
        seqidx = torch.multinomial(policy, 1).squeeze(1)
        policypro = policy[rows, seqidx]
        return seqidx, policypro

class DecodingPosition(nn.Module):
    """Choose an output index (one of ``out_dim``) for an already-selected item.

    The query averages the projected selected item, the frontier input and, when
    any remain, the mean projection of the items not listed in ``maskseq``.
    Output indices listed in ``maskposition`` get a ``-inf`` logit.

    Args:
        embedingdim: item embedding width.
        out_dim: number of candidate output indices.
    """

    def __init__(self, embedingdim, out_dim):
        super().__init__()
        self.embedingdim = embedingdim
        self.LinearW_selected = nn.Linear(embedingdim, embedingdim, bias=False)
        self.LinearP_leftover = nn.Linear(embedingdim, embedingdim, bias=False)
        self.LinearOut = nn.Linear(embedingdim, out_dim, bias=False)
        self.softmax = nn.Softmax(dim=1)

    @staticmethod
    def _mask_from_indices(indices, batch_size, num_n, device):
        """Return a (batch_size, num_n) bool tensor, True at ``indices[b]`` in row b."""
        idx = np.asarray(indices, dtype=np.int64)
        if idx.ndim != 2 or idx.shape[0] != batch_size:
            raise ValueError(
                "maskseq must contain one equal-length index list per batch element"
            )
        mask = torch.zeros(batch_size, num_n, dtype=torch.bool, device=device)
        rows = torch.arange(batch_size, device=device).unsqueeze(1)
        mask[rows, torch.as_tensor(idx, device=device)] = True
        return mask

    def forward(self, selectb, inputb, inputf, maskseq, maskposition, IsGreedy):
        """Score every output index and pick one.

        Args:
            selectb: per-batch index of the selected item in ``inputb``;
                assumed shape (batch,).
            inputb: item embeddings of shape (batch, num_n, embedingdim).
            inputf: embedding added to the query; assumed (batch, embedingdim).
            maskseq: per-batch lists of item indices excluded from the leftover
                mean; every row must have the same length.
            maskposition: per-batch lists (may differ in length) of output
                indices that may not be chosen.
            IsGreedy: if true take the argmax, otherwise sample from the policy.

        Returns:
            Greedy: index tensor of shape (batch,).
            Sampling: ``(posidx, policypro)``, both of shape (batch,).

        Raises:
            ValueError: if ``maskseq`` is malformed.
        """
        batch_size, num_n, dim = inputb.shape[0], inputb.shape[1], inputb.shape[2]
        rows = list(range(batch_size))
        masked = self._mask_from_indices(maskseq, batch_size, num_n, inputb.device)
        num_left = int((~masked).sum()) // batch_size

        inputb_select = inputb[rows, selectb, :]
        bst = torch.squeeze(self.LinearW_selected(inputb_select), 1)
        if num_left > 0:
            # Boolean indexing is row-major, so the reshape regroups rows per batch.
            leftover = inputb[~masked].reshape(batch_size, num_left, dim)
            leftover = self.LinearP_leftover(leftover)
            inputq = (bst + torch.mean(leftover, dim=1) + inputf) / 3
        else:
            # No leftover items: a mean over an empty set would be NaN.
            inputq = (bst + inputf) / 2

        ct2 = self.LinearOut(inputq)
        blocked = torch.zeros_like(ct2, dtype=torch.bool)
        for i in range(batch_size):
            blocked[i, maskposition[i]] = True
        ct2 = ct2.masked_fill(blocked, float('-inf'))
        policy = self.softmax(ct2)
        if IsGreedy:
            return torch.max(policy, 1)[1]
        # squeeze(1) keeps shape (batch,) like the greedy path, even for batch 1.
        posidx = torch.multinomial(policy, 1).squeeze(1)
        policypro = policy[rows, posidx]
        return posidx, policypro
        

