import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.distributions.one_hot_categorical import OneHotCategorical
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical


def cat_softmax(probs, mode, tau=1, hard=False, dim=-1):
    '''
        Turn categorical probabilities into a (relaxed) one-hot tensor.

        probs: tensor of class probabilities. The torch distributions and
               F.one_hot used below all treat the LAST dimension as the
               class dimension.
        mode:  case-insensitive selector
                 'reinforce' / 'sample' : draw from OneHotCategorical
                                          (non-differentiable sample)
                 'argmax'               : float one-hot of the arg-max class
                 'gumbel'               : rsample from RelaxedOneHotCategorical
                                          with temperature `tau`
        tau:   temperature, only consulted in 'gumbel' mode.
        hard:  only consulted in 'gumbel' mode. When True the forward value is
               a one-hot vector while the gradient flows through the relaxed
               sample (straight-through estimator).
        dim:   dimension used for the max/scatter of the hard 'gumbel' path.

        Returns a tensor with the same shape as `probs`.
        Raises ValueError if `mode` is not one of the names above.
    '''
    mode_key = mode.lower()

    if mode_key in ('reinforce', 'sample'):
        return OneHotCategorical(probs=probs).sample()

    if mode_key == 'argmax':
        one_hot = F.one_hot(probs.argmax(-1), num_classes=probs.shape[-1])
        return one_hot.float()

    if mode_key != 'gumbel':
        # Previously an unknown mode fell through and crashed later with an
        # UnboundLocalError on y_soft; fail early with a readable message.
        raise ValueError(
            "Unknown mode %r; expected 'reinforce', 'sample', 'argmax' "
            "or 'gumbel'" % (mode,))

    y_soft = RelaxedOneHotCategorical(tau, probs=probs).rsample()
    if not hard:
        # Reparametrization trick: return the relaxed sample itself.
        return y_soft

    # Straight through: one-hot in the forward pass, soft gradient backward.
    index = y_soft.max(dim, keepdim=True)[1]
    # zeros_like already matches probs' device (including the GPU index).
    y_hard = torch.zeros_like(probs).scatter_(dim, index, 1.0)
    return y_hard - y_soft.detach() + y_soft


class MsgGenLSTM(nn.Module):
    '''
        Autoregressive message generator built on an nn.LSTMCell.

        Starting from an initial hidden/cell state (each [N_B, 1, hidden] or
        [N_B, hidden]) and an all-zero first input of size V, the cell is
        unrolled for L steps. At every step the hidden state is projected to
        V logits, a one-hot-like symbol is drawn with cat_softmax, and that
        symbol is fed back as the next input.

        For the mode, use 'gumbel' during interaction (or for task),
        use 'sample' during distillation. In eval mode (self.training is
        False) the symbol is always chosen by 'argmax'.

        Example (DataEncoderMLP is defined outside this file):
            t = DataEncoderMLP()
            h0 = t.forward(data_batch)
            msg_gen = MsgGenLSTM()
            message, logits = msg_gen.forward(h0, h0)

        message and logits are [NB, L, V], use message.argmax(-1) to get message
    '''
    def __init__(self, L=20, V=10, tau=1., hidden_size=128):
        super(MsgGenLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.msg_length = L
        self.input_size = V
        self.output_size = V
        # NOTE: stored for reference only; forward() uses its own `tau`
        # argument and does not read this attribute.
        self.tau = tau

        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, h_0, c_0, tau=1., mode='gumbel'):
        '''
            h_0, c_0: initial hidden and cell state. A singleton dimension 1
                      ([N_B, 1, hidden]) is squeezed away so the LSTMCell
                      receives [N_B, hidden].
            tau:      temperature forwarded to cat_softmax.
            mode:     sampling mode forwarded to cat_softmax while training.

            Returns (message, logits), both of shape [N_B, L, V].
        '''
        batch_size = h_0.size(0)
        # new_zeros keeps the first input on h_0's exact device and dtype
        # (a bare 'cuda' string would pick the current GPU, not h_0's).
        decoder_input = h_0.new_zeros((batch_size, self.input_size))
        decoder_hidden = h_0.squeeze(1)
        decoder_cell = c_0.squeeze(1)

        # The sampling rule does not change between steps: the caller's mode
        # while training, greedy arg-max otherwise.
        step_mode = mode if self.training else 'argmax'

        message = []
        logits = []
        for _ in range(self.msg_length):
            decoder_hidden, decoder_cell = self.lstm(
                decoder_input, (decoder_hidden, decoder_cell))

            logit = self.out(decoder_hidden)
            logits.append(logit)
            probs = F.softmax(logit, dim=1)

            predict = cat_softmax(probs, mode=step_mode, tau=tau,
                                  hard=True, dim=1)

            message.append(predict)
            # Feed the chosen symbol back in as the next step's input.
            decoder_input = predict

        # Stack the per-step [NB, V] tensors along a new dim 1 -> [NB, L, V].
        message = torch.stack(message, dim=1)
        logits = torch.stack(logits, dim=1)

        return message, logits
    

class MsgDecoderLSTM(nn.Module):
    '''
        Reads a message with an nn.LSTMCell and returns the final hidden state.

        The message is expected as [NB, L, V]; the cell is stepped once per
        position along L, starting from all-zero hidden and cell states.

        Example (DataEncoderMLP is defined outside this file):
            t = DataEncoderMLP()
            h0 = t.forward(data_batch)
            msg_gen = MsgGenLSTM()
            msg, logits = msg_gen.forward(h0, h0)
            msg_de = MsgDecoderLSTM()
            last_hidden = msg_de.forward(msg)
    '''
    def __init__(self,L=20, V=10, hidden_size=128):
        # `L` is accepted for signature compatibility; the number of steps is
        # taken from the message passed to forward().
        super(MsgDecoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = V

        self.lstm = nn.LSTMCell(self.input_size, self.hidden_size)
        # Plain tensors (not parameters/buffers): Module.to()/.cuda() does not
        # move them, so forward() places them next to the incoming message.
        self.init_hidden = self.init_hidden_and_cell()
        self.init_cell = self.init_hidden_and_cell()

    def forward(self, msg):
        '''
            msg: tensor [NB, L, V].
            Returns the hidden state after the last step, [NB, hidden_size].
            For an empty message (L == 0) the zero initial state is returned.
        '''
        msg = msg.transpose(0,1)       # [NB, L, V] to [L, NB, V]
        max_len = msg.size(0)
        batch_size = msg.size(1)

        # Follow msg's exact device (including GPU index) and dtype instead
        # of a bare 'cuda'/'cpu' string, which could pick a different GPU.
        last_hidden = self.init_hidden.expand(batch_size, -1).contiguous().to(
            device=msg.device, dtype=msg.dtype)
        last_cell = self.init_cell.expand(batch_size, -1).contiguous().to(
            device=msg.device, dtype=msg.dtype)

        for t in range(max_len):
            last_hidden, last_cell = self.lstm(msg[t], (last_hidden, last_cell))

        return last_hidden

    def init_hidden_and_cell(self):
        '''Return a fresh all-zero state row of shape [1, hidden_size].'''
        return torch.zeros(1, self.hidden_size)


if __name__ == '__main__':
    # Smoke test: generate a batch of messages and decode them again.
    V = 20      # vocabulary size
    L = 10      # message length
    NB = 3      # batch size

    # Use the GPU when there is one, otherwise run on CPU instead of
    # crashing on an unconditional .cuda() call.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    h0 = torch.rand((NB,128)).to(device)

    # V and L were defined but never handed to the modules; pass them so the
    # shapes printed below match the settings above.
    msg_gen = MsgGenLSTM(L=L, V=V).to(device)
    msg_gen.train()
    msg, logits = msg_gen.forward(h0, h0,mode='argmax')
    msg_de = MsgDecoderLSTM(L=L, V=V).to(device)
    last_hidden = msg_de.forward(msg)

    print(msg.argmax(-1))
    print(logits.argmax(-1))