import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init
import torch.backends.cudnn as cudnn
import layers

class EmbedderMoG(nn.Module):
    """Encode a padded batch of point sets into mixture-of-Gaussians parameters.

    Pipeline (see ``__init__``): a per-element feature map ``RFFNet`` lifts each
    ``dim_input``-dimensional element to ``dim_feat`` features, ``layers.DIEM``
    pools the (masked) set into a fixed-size embedding, and ``OutputNet`` maps
    that embedding to ``num_outputs`` mixture components in ``dim_input``
    dimensions.
    """

    def __init__(self, dim_input, num_outputs, out_type='param_cat', 
                 num_proto=5, num_ems=2, dim_feat=128, num_heads=3, tau=10.0,
                 distr_emb_args=None):
        super(EmbedderMoG, self).__init__()
        self.rff = RFFNet(dim_input, dim_feat)
        self.diem = layers.DIEM(dim_feat, H=num_heads, p=num_proto, L=num_ems, 
            tau=tau, out=out_type, distr_emb_args=distr_emb_args)
        self.outnet = OutputNet(self.diem.outdim, num_outputs, dim_input)

    def forward(self, X, cards=None):
        """Encode a padded batch of sets.

        Args:
            X: 3-D tensor indexed as ``(B, N_max, d0)``; set ``i`` occupies the
                first ``cards[i]`` rows of ``X[i]``, the rest is padding.
            cards: optional per-set cardinalities (length-``B`` sequence,
                array or integer tensor). Position ``j`` of set ``i`` is kept
                iff ``j < cards[i]``. When ``None`` every position is kept.

        Returns:
            The output of ``self.outnet`` applied to the pooled embedding.
        """
        # Unpacking enforces a 3-D input; the feature size itself is not needed here.
        B, N_max, _ = X.shape
        S = self.rff(X)
        if cards is None:
            # Same dtype/device as S, like the original ``torch.ones(...).to(S)``.
            mask = S.new_ones(B, N_max)
        else:
            # Vectorised equivalent of zeroing mask[i, cards[i]:] for every i.
            # It avoids a Python loop over the batch and per-element scalar reads of `cards`.
            lengths = torch.as_tensor(cards, device=S.device).reshape(B, 1)
            positions = torch.arange(N_max, device=S.device)
            mask = (positions < lengths).to(S.dtype)

        FS = self.diem(S, mask)
        return self.outnet(FS)
    
class RFFNet(nn.Module):
    """Per-element feature map: one affine layer from ``d0`` to ``d`` features.

    The layer is wrapped in an ``nn.Sequential`` stored as ``net``, so the
    parameters are registered as ``net.0.weight`` / ``net.0.bias``.
    NOTE(review): despite the name, no random Fourier features are computed
    here; it is a plain ``nn.Linear``.
    """

    def __init__(self, d0, d):
        super().__init__()
        projection = nn.Linear(d0, d)
        self.net = nn.Sequential(projection)

    def forward(self, X):
        """Apply the affine map over the last dimension of ``X``."""
        features = self.net(X)
        return features

class OutputNet(nn.Module):
    """MLP head mapping a pooled embedding to mixture-of-Gaussians parameters.

    For ``K`` components in ``d0`` dimensions the MLP emits ``2*K + K*d0``
    numbers per embedding. They are split in this order into:

    - ``K`` mixture logits;
    - ``K*d0`` mean coordinates;
    - ``K`` scale pre-activations, one scalar scale per component.
    """

    def __init__(self, din, K, d0):
        super(OutputNet, self).__init__()
        dout = 2*K + K*d0
        # Hidden width is halfway between the input and output widths.
        dhid = (din+dout)//2
        self.net = nn.Sequential(
            nn.Linear(din, dhid),
            nn.ReLU(),
            nn.Linear(dhid, dout),
        )

        self.K = K
        self.d0 = d0

    def forward(self, FS):
        """Predict mixture parameters from embeddings.

        Args:
            FS: tensor whose last dimension has size ``din``. Any leading
                dimensions are treated as batch dimensions; the usual case is
                ``(B, din)``.

        Returns:
            ``(pi, (mu, sigma))`` where, with ``...`` the leading dims of ``FS``:

            - ``pi`` has shape ``(..., K)`` and holds softmax mixture weights.
            - ``mu`` has shape ``(..., K, d0)`` and holds component means.
            - ``sigma`` has shape ``(..., K, 1)`` and holds per-component
              scales passed through softplus.
        """
        K, d0 = self.K, self.d0
        lead_shape = FS.shape[:-1]
        MoGvecs = self.net(FS)
        # Split along the feature axis only, so any number of leading dims works.
        logits, mu_flat, sigma_raw = torch.split(MoGvecs, [K, K*d0, K], dim=-1)
        pi = torch.softmax(logits, -1)
        mu = mu_flat.reshape(*lead_shape, K, d0)
        sigma = F.softplus(sigma_raw).reshape(*lead_shape, K, 1)
        return pi, (mu, sigma)
