
import torch
import torch.nn as nn

from .attention import MultiHeadAttention
from .BlockLSTM import BlockLSTM
from .BlockGRU import BlockGRU
from .sparse_grad_attn import blocked_grad
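'''
Core blocks module.  Each call processes a single time step.  Takes:
    inp: (mb, nhid)
    hx: (mb, nhid)
    cx: (mb, nhid)
    step: index of the current time step (only used for debug printing)

    output:
    hx, cx, mask, block_mask

'''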
from torch.distributions.categorical import Categorical

class BlocksCore(nn.Module):
    """Single recurrent step over a hidden state that is split into blocks.

    The hidden state of size ``nhid`` is viewed as ``num_blocks_out`` blocks
    of ``nhid // num_blocks_out`` features.  At every step each block attends
    over the input (plus an all-zero "null" slot), the blocks that put the
    most attention on the null slot are masked out, and only the remaining
    blocks have their state replaced by the recurrent update.

    Args:
        nhid: total hidden size.
        num_blocks_in: number of blocks the input vector is split into.
        num_blocks_out: number of blocks the hidden state is split into.
        topkval: number of hidden blocks that stay active per step
            (``num_blocks_out - topkval`` blocks are masked out).
        step_att: if true, run ``self.mha`` across the hidden blocks after
            the recurrent update.
        do_gru: if true use ``BlockGRU``, otherwise ``BlockLSTM``.
        num_modules_read_input: forwarded to the input attention as
            ``num_blocks_write``; the input blocks are repeated
            ``num_modules_read_input - 1`` times before the null slot is
            prepended.
        att_out: per-block output size of the input attention.
        inp_key_size: key size (``d_k``) of the input attention.
        key_size: key/value size (``d_k``, ``d_v``) of the block-to-block
            attention.
    """

    def __init__(self, nhid, num_blocks_in, num_blocks_out, topkval, step_att, do_gru, num_modules_read_input=2,
                 att_out=320, inp_key_size=64, key_size=32):
        super(BlocksCore, self).__init__()
        self.nhid = nhid
        self.num_blocks_in = num_blocks_in
        self.num_blocks_out = num_blocks_out
        self.block_size_in = nhid // num_blocks_in
        self.block_size_out = nhid // num_blocks_out
        self.topkval = topkval
        self.step_att = step_att
        self.do_gru = do_gru
        self.num_modules_read_input = num_modules_read_input
        self.key_size = key_size
        self.inp_key_size = inp_key_size

        print('bs in', self.block_size_in)
        print('bs out', self.block_size_out)
        print('key size', self.key_size)
        print('inp key size', self.inp_key_size)

        # Attention between the hidden blocks themselves (used when step_att is set).
        self.mha = MultiHeadAttention(
            n_head=4,
            d_model_read=self.block_size_out,
            d_model_write=self.block_size_out,
            d_model_out=self.block_size_out,
            d_k=key_size,
            d_v=key_size,
            num_blocks_read=self.num_blocks_out,
            num_blocks_write=self.num_blocks_out,
            topk=self.num_blocks_out,
            grad_sparse=False,
        )

        # Attention from the hidden blocks (queries) onto the input slots (keys/values).
        self.att_out = att_out
        self.inp_att = MultiHeadAttention(
            n_head=1,
            d_model_read=self.block_size_out,
            d_model_write=self.block_size_in,
            d_model_out=self.att_out,
            d_k=inp_key_size,
            d_v=self.att_out,
            num_blocks_read=num_blocks_out,
            num_blocks_write=num_modules_read_input,
            residual=False,
            topk=self.num_blocks_in+1,
            grad_sparse=False,
            skip_write=True,
        )
        print('att out', self.att_out)

        # The attribute is called block_lstm in both cases so callers do not
        # have to care which recurrent cell is in use.
        rnn_input_size = self.att_out*self.num_blocks_out
        if do_gru:
            self.block_lstm = BlockGRU(rnn_input_size, self.nhid, k=self.num_blocks_out)
        else:
            self.block_lstm = BlockLSTM(rnn_input_size, self.nhid, k=self.num_blocks_out)

    def blockify_params(self):
        """Delegate to the recurrent cell's ``blockify_params``."""
        self.block_lstm.blockify_params()

    def forward(self, inp, hx, cx, step, do_print=False):
        """Advance the blocked hidden state by one time step.

        Args:
            inp: input for this step; reshaped to
                ``(batch, num_blocks_in, block_size_in)``.
            hx: hidden state; reshaped to
                ``(batch, num_blocks_out, block_size_out)`` for attention.
            cx: cell state, same layout as ``hx``.  It is only passed to the
                recurrent cell in LSTM mode, but is always blended with the
                mask, so it must be a tensor in GRU mode as well.
            step: step index, only used in the debug print.
            do_print: if true, print the block mask of the first sample.

        Returns:
            Tuple ``(hx, cx, mask, block_mask)`` where ``mask`` has shape
            ``(batch, num_blocks_out * block_size_out)`` and ``block_mask``
            has shape ``(batch, num_blocks_out, 1)``; both hold 1.0 for
            active blocks and 0.0 for masked-out blocks.
        """
        # Split the input into blocks and prepend an all-zero "null" slot at
        # index 0 that blocks can attend to instead of the real input.
        inp_use = inp.reshape((inp.shape[0], self.num_blocks_in, self.block_size_in))
        inp_use = inp_use.repeat(1, self.num_modules_read_input-1, 1)
        inp_use = torch.cat([torch.zeros_like(inp_use[:, 0:1, :]), inp_use], dim=1)

        hx_blocks = hx.reshape((hx.shape[0], self.num_blocks_out, self.block_size_out))
        inp_use, iatt, _ = self.inp_att(hx_blocks, inp_use, inp_use)
        inp_use = inp_use.reshape((inp_use.shape[0], self.att_out*self.num_blocks_out))

        # iatt[:, :, 0] is each block's attention weight on the null slot.
        # The (num_blocks_out - topkval) blocks with the largest null
        # attention are switched off; scatter_ zeroes them for the whole
        # batch in one call instead of looping over samples.
        null_att = iatt[:, :, 0]
        dropped_blocks = torch.topk(null_att, dim=1,
                                    sorted=True, largest=True,
                                    k=self.num_blocks_out - self.topkval)[1]
        mask = torch.ones_like(null_att).scatter_(1, dropped_blocks, 0.0)

        block_mask = mask.reshape((inp_use.shape[0], self.num_blocks_out, 1))

        if do_print:
            print('mask at', step, mask[0])

        # Expand the per-block mask to one entry per hidden unit.
        mask = block_mask.repeat((1, 1, self.block_size_out)).reshape(
            (inp_use.shape[0], self.num_blocks_out*self.block_size_out))
        mask = mask.detach()

        if self.do_gru:
            hx_new = self.block_lstm(inp_use, hx)
            cx_new = hx_new
        else:
            hx_new, cx_new = self.block_lstm(inp_use, hx, cx)

        if self.step_att:
            hx_new, cx_new, _ = self.step_attention(hx_new, cx_new, mask)

        # Active blocks take the new state, masked-out blocks keep the old one.
        hx = mask*hx_new + (1-mask)*hx
        cx = mask*cx_new + (1-mask)*cx

        return hx, cx, mask, block_mask

    def step_attention(self, hx_new, cx_new, mask):
        """Let the hidden blocks attend to each other and add the result.

        Args:
            hx_new: updated hidden state, reshaped here to
                ``(batch, num_blocks_out, block_size_out)``.
            cx_new: updated cell state; returned unchanged.
            mask: per-unit block mask, reshaped here to
                ``(batch, num_blocks_out, block_size_out)``.

        Returns:
            Tuple ``(hx_new, cx_new, extra_loss)`` where ``hx_new`` has shape
            ``(batch, nhid)`` and ``extra_loss`` is the third value returned
            by ``self.mha``.
        """
        hx_new = hx_new.reshape((hx_new.shape[0], self.num_blocks_out, self.block_size_out))
        block_shaped_mask = mask.reshape((mask.shape[0], self.num_blocks_out, self.block_size_out))
        # NOTE(review): blocked_grad presumably passes values through and
        # restricts gradients to the unmasked blocks - confirm in sparse_grad_attn.
        hx_new_grad_mask = blocked_grad.apply(hx_new, block_shaped_mask)
        hx_new_att, attn_out, extra_loss_att = self.mha(hx_new_grad_mask, hx_new_grad_mask, hx_new_grad_mask)
        hx_new = hx_new + hx_new_att
        hx_new = hx_new.reshape((hx_new.shape[0], self.nhid))
        return hx_new, cx_new, extra_loss_att

if __name__ == "__main__":
    # Smoke test for a single step.  Because of the package-relative imports
    # at the top of this file it has to be run as a module of its package
    # (python -m <package>.<module>), not as a stand-alone script.
    # step_att and do_gru have no defaults, so they must be given explicitly.
    bc = BlocksCore(512, 1, 4, 4, step_att=True, do_gru=False)

    inp = torch.randn(10, 512)
    hx = torch.randn(10, 512)
    cx = torch.randn(10, 512)

    # forward requires the step index and returns (hx, cx, mask, block_mask).
    hx, cx, mask, block_mask = bc(inp, hx, cx, step=0)

    print('hx cx shape', hx.shape, cx.shape)





