
from .rnn_models import RNNModel #rnn_models
import torch
import torch.nn as nn


class BlocksWrapper(nn.Module):
    """Adapter that drives a blocks-configured ``RNNModel`` through a GRU-like call.

    The wrapped core is always created as a single-layer ``"GRU"`` ``RNNModel``
    with ``do_gru=True``, ``discrete_input=False``, ``use_adaptive_softmax=False``,
    ``tie_weights=False`` and ``use_cudnn_version=False``. The remaining
    constructor arguments are forwarded to it. ``update_topk`` is passed as
    ``topk`` and ``num_blocks_read_input`` as ``num_modules_read_input``; see
    ``RNNModel`` for what each of them controls.

    Calling the module returns ``(output, hidden)``. ``hidden`` packs the two
    recurrent state tensors exchanged with ``RNNModel`` side by side on the last
    axis: the first ``nhid`` channels are ``hx`` and the remaining ones are ``cx``.
    """

    def __init__(self, ntokens, nhid, nout=None, dropout=0.0, num_blocks=6, update_topk=6, num_blocks_read_input=4,
                 use_decoder=True, blocks_att_out=320, inp_key_size=64, key_size=32, batch_first=False):
        super().__init__()
        # Keyword options for the core, listed in the same order RNNModel received them originally.
        rnn_kwargs = dict(
            nout=nout,
            nlayers=1,
            dropout=dropout,
            tie_weights=False,
            use_cudnn_version=False,
            use_adaptive_softmax=False,
            cutoffs=[10000],
            discrete_input=False,
            num_blocks=num_blocks,
            topk=update_topk,  # RNNModel names this option `topk`
            do_gru=True,
            num_modules_read_input=num_blocks_read_input,  # renamed for RNNModel as well
            blocks_att_out=blocks_att_out,
            use_decoder=use_decoder,
            key_size=key_size,
            inp_key_size=inp_key_size,
        )
        # nhid is passed as both the third and the fourth positional argument.
        self.myrnn = RNNModel("GRU", ntokens, nhid, nhid, **rnn_kwargs)
        self.nhid = nhid
        self.batch_first = batch_first
        print('using blocks wrapper!')

    def forward(self, inp, h=None):
        """Run the wrapped core over a whole input sequence.

        Args:
            inp: 3-D input tensor. It is handed to the core time-major. When
                ``batch_first`` is set, its first two axes are swapped on the way
                in. The assumed layouts are ``(T, N, F)`` and ``(N, T, F)``;
                the code itself only swaps axes 0 and 1.
            h: optional packed state whose last axis holds ``hx`` (the first
                ``nhid`` channels) followed by ``cx`` (the rest). When omitted,
                zeros of shape ``(1, N, 2 * nhid)`` are used, with ``inp``'s
                dtype and device. ``N`` is axis 1 of the time-major input.

        Returns:
            ``(output, hidden)``.
            - ``output`` is the core's output, swapped back to batch-first when
              ``batch_first`` is set.
            - ``hidden`` is the core's new ``hx`` and ``cx`` concatenated along
              axis 2.
        """
        width = self.nhid
        seq = inp.permute(1, 0, 2) if self.batch_first else inp
        if h is None:
            h = torch.zeros(1, seq.shape[1], 2 * width, dtype=seq.dtype, device=seq.device)
        # contiguous() so the core receives dense tensors rather than strided views of h.
        state_in = (h[:, :, :width].contiguous(), h[:, :, width:].contiguous())
        out, (h_new, c_new) = self.myrnn(seq, state_in)
        packed = torch.cat((h_new, c_new), dim=2)
        if self.batch_first:
            out = out.permute(1, 0, 2)
        return out, packed


if __name__ == "__main__":
    # Smoke test: compare BlocksWrapper with a plain torch.nn.GRU of the same
    # hidden size, printing parameter counts and output/state shapes.
    # NOTE: because of the relative import at the top of this file, run it as a
    # module (python -m <package>.<this_module>), not as a standalone script.
    nhid = 120
    ntokens = 144
    T = 128
    N = 1

    def count_params(module):
        """Return the total number of scalar parameters in *module*."""
        return sum(p.numel() for p in module.parameters())

    blocks = BlocksWrapper(ntokens, nhid, nout=nhid, use_decoder=False, update_topk=4)
    print("Blocks nparams: ", count_params(blocks))
    gru = torch.nn.GRU(ntokens, nhid)
    print("Gru nparams: ", count_params(gru))

    # Time-major input (T, N, features): neither module is built with batch_first here.
    x = torch.randn(T, N, ntokens)

    # GRU state is (1, N, nhid); BlocksWrapper packs hx and cx side by side, hence nhid * 2.
    h0 = torch.randn(1, N, nhid)
    h0_blocks = torch.randn(1, N, nhid * 2)

    og, hg = gru(x, h0)
    print('gru of x: o,h', og.shape, hg.shape)

    ob, hb = blocks(x, h0_blocks)
    print('block res: o,h', ob.shape, hb.shape)
