



'''
Goal1: a GRU where the weight matrices have a block structure so that information flow is constrained

Data is assumed to come in [block1, block2, ..., block_n].

Goal2: Dynamic parameter sharing between blocks (RIMs)

'''

import torch
import torch.nn as nn
from .GroupLinearLayer import GroupLinearLayer
from .sparse_attn import Sparse_attention


'''
Given an N x N matrix, and a grouping of size, set all elements off the block diagonal to 0.0
'''
def zero_matrix_elements(matrix, k):
    assert matrix.shape[0] % k == 0
    assert matrix.shape[1] % k == 0
    g1 = matrix.shape[0] // k
    g2 = matrix.shape[1] // k
    new_mat = torch.zeros_like(matrix)
    for b in range(0,k):
        new_mat[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2] += matrix[b*g1 : (b+1)*g1, b*g2 : (b+1)*g2]

    matrix *= 0.0
    matrix += new_mat


class BlockGRU(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ninp, nhid, k):
        super(BlockGRU, self).__init__()

        assert ninp % k == 0
        assert nhid % k == 0

        self.k = k
        self.gru = nn.GRUCell(ninp, nhid)
        self.nhid = nhid
        self.ninp = ninp

    def blockify_params(self):
        pl = self.gru.parameters()

        for p in pl:
            p = p.data
            if p.shape == torch.Size([self.nhid*3]):
                pass
                '''biases, don't need to change anything here'''
            if p.shape == torch.Size([self.nhid*3, self.nhid]) or p.shape == torch.Size([self.nhid*3, self.ninp]):
                for e in range(0,4):
                    zero_matrix_elements(p[self.nhid*e : self.nhid*(e+1)], k=self.k)

    def forward(self, input, h):

        #self.blockify_params()

        hnext = self.gru(input, h)

        return hnext, None

class Identity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input * 1.0
    def backward(ctx, grad_output):
        print(torch.sqrt(torch.sum(torch.pow(grad_output,2))))
        print('-----------')
        return grad_output * 1.0

class SharedBlockGRU(nn.Module):
   """Dynamic sharing of parameters between blocks(RIM's)"""
   def __init__(self, ninp, nhid, k, n_templates):
     super(SharedBlockGRU, self).__init__()
     assert ninp % k == 0
     assert nhid % k == 0
     self.k = k 
     self.m = nhid // self.k
     self.n_templates = n_templates
     print("input to template is ", ninp%k)
     self.templates = nn.ModuleList([nn.GRUCell(ninp//k,self.m) for _ in range(0,self.n_templates)])
     self.nhid = nhid
     self.ninp = ninp
     self.gll_write = GroupLinearLayer(self.m,16, self.n_templates)
     self.gll_read = GroupLinearLayer(self.m,16,1)
     self.sa = Sparse_attention(1)
     print("Using Gumble sparsity")
   def blockify_params(self):
     return
   def forward(self, input, h): 
     #self.blockify_params()
     bs = h.shape[0]
     h = h.reshape((h.shape[0], self.k, self.m)).reshape((h.shape[0]*self.k, self.m))
     input = input.reshape(input.shape[0], 1, input.shape[1])
     #input = input.repeat(1,self.k,1)
     #input = input.reshape(input.shape[0]*self.k, input.shape[2])
     input = input.reshape(input.shape[0]*self.k, -1) 
     #print("input shape is", input.shape)
     h_read = self.gll_read((h*1.0).reshape((h.shape[0], 1, h.shape[1])))
     hnext_stack = []
     for template in self.templates:
       hnext_l = template(input, h)
       hnext_l = hnext_l.reshape((hnext_l.shape[0], 1, hnext_l.shape[1]))
       hnext_stack.append(hnext_l)
     hnext = torch.cat(hnext_stack, 1)
     write_key = self.gll_write(hnext)
     ''' 
     sm = nn.Softmax(2)
     att = sm(torch.bmm(h_read, write_key.permute(0, 2, 1))).squeeze(1)
     att = self.sa(att).unsqueeze(1)
     '''
     att = torch.nn.functional.gumbel_softmax(torch.bmm(h_read, write_key.permute(0, 2, 1)), tau=1, hard=True)
     #att = att*0.0 + 0.25
     hnext = torch.bmm(att, hnext)
     hnext = hnext.mean(dim=1)
     hnext = hnext.reshape((bs, self.k, self.m)).reshape((bs, self.k*self.m))
     #print('shapes', hnext.shape, cnext.shape)
     return hnext, att.data.reshape(bs,self.k,self.n_templates)



if __name__ == "__main__":

    Blocks = BlockGRU(2, 6, k=2)
    opt = torch.optim.Adam(Blocks.parameters())

    pl = Blocks.gru.parameters()

    inp = torch.randn(100,2)
    h = torch.randn(100,6)

    h2 = Blocks(inp,h)

    L = h2.sum()**2

    #L.backward()
    #opt.step()
    #opt.zero_grad()


    pl = Blocks.gru.parameters()
    for p in pl:
        print(p.shape)
        #print(torch.Size([Blocks.nhid*4]))
        if p.shape == torch.Size([Blocks.nhid*3]):
            print(p.shape, 'a')
            #print(p)
            '''biases, don't need to change anything here'''
        if p.shape == torch.Size([Blocks.nhid*3, Blocks.nhid]) or p.shape == torch.Size([Blocks.nhid*3, Blocks.ninp]):
            print(p.shape, 'b')
            for e in range(0,4):
                print(p[Blocks.nhid*e : Blocks.nhid*(e+1)])
