import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
from typing import List
from torch import Tensor

class scriptGRUCell(jit.ScriptModule):
    """Single-step GRU-style recurrent cell compiled with TorchScript.

    ``weight_ih``, ``weight_hh``, ``bias_ih`` and ``bias_hh`` stack three gate
    blocks row-wise, each ``hidden_size`` rows, in the order
    ``[reset, update, candidate]``; ``forward`` splits them with
    ``chunk(3, 1)`` in that order.

    Args:
        input_size: Number of input features per step.
        hidden_size: Number of hidden units.
        nonlinearity: Activation applied to the candidate state: ``'tanh'``
            (default), ``'relu'`` or ``'linear'`` (identity).
        noise: Multiplier for standard-normal noise added to the candidate
            pre-activation on every forward call (train and eval mode alike,
            since ``forward`` never checks ``self.training``).
        weight_init: Xavier gain for ``weight_ih`` and multiplier applied to
            each orthogonally initialised hidden-to-hidden gate block.

    Raises:
        ValueError: If ``nonlinearity`` is not one of the supported names.
    """

    def __init__(self, input_size, hidden_size, nonlinearity='tanh', noise=0.0, weight_init=1):
        super(scriptGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Uninitialised storage; every tensor is filled explicitly below.
        self.weight_ih = Parameter(torch.empty(3 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.empty(3 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.empty(3 * hidden_size))
        self.bias_hh = Parameter(torch.empty(3 * hidden_size))
        self.noise = noise

        if nonlinearity == 'relu':
            self.activ_func = torch.relu
        elif nonlinearity == 'linear':
            self.activ_func = nn.Identity()
        elif nonlinearity == 'tanh':
            self.activ_func = nn.Tanh()
        else:
            # Otherwise self.activ_func is never assigned and forward() cannot
            # be compiled, which hides the real cause of the failure.
            raise ValueError(
                "unsupported nonlinearity {!r}; expected 'relu', 'linear' or 'tanh'"
                .format(nonlinearity))

        # Xavier-uniform input weights (gain=weight_init); each hidden-to-hidden
        # gate block gets its own orthogonal init, scaled by weight_init.
        torch.nn.init.xavier_uniform_(self.weight_ih, gain=weight_init)
        torch.nn.init.zeros_(self.bias_ih)
        with torch.no_grad():
            for gate_block in self.weight_hh.chunk(3, dim=0):
                torch.nn.init.orthogonal_(gate_block)
                gate_block.mul_(weight_init)
        torch.nn.init.zeros_(self.bias_hh)

    @jit.script_method
    def forward(self, input: Tensor, hx: Tensor) -> Tensor:
        """Advance the hidden state by one step.

        Args:
            input: 2-D tensor ``(N, input_size)``.
            hx: 2-D tensor ``(N, hidden_size)`` holding the previous state.

        Returns:
            The new hidden state, same shape as ``hx``.
        """
        igates = torch.mm(input, self.weight_ih.t()) + self.bias_ih
        hgates = torch.mm(hx, self.weight_hh.t()) + self.bias_hh
        r_in, z_in, c_in = igates.chunk(3, 1)
        r_hid, z_hid, c_hid = hgates.chunk(3, 1)

        r_gate = torch.sigmoid(r_in + r_hid)
        z_gate = torch.sigmoid(z_in + z_hid)
        # The reset gate scales only the recurrent part of the candidate.
        pre_c = c_in + r_gate * c_hid

        noise = torch.randn_like(hx) * self.noise
        c = self.activ_func(pre_c + noise)
        # z_gate weights the old state, (1 - z_gate) weights the candidate.
        return (1 - z_gate) * c + z_gate * hx

# Low-rank RNN cell: the recurrent weight matrix is factored as U @ V.
class LowRankRNNCell(jit.ScriptModule):
    """Single-step RNN cell whose recurrent weight matrix is factored as ``U @ V``.

    ``h_new = act(input_proj(x) + (h @ U) @ V + bias)``

    Args:
        input_size: Number of input features per step.
        hidden_size: Number of hidden units.
        rank: Inner dimension of the factorisation; ``U`` is
            ``(hidden_size, rank)`` and ``V`` is ``(rank, hidden_size)``.
        nonlinearity: ``'relu'`` (default), ``'linear'`` (identity) or ``'tanh'``.

    Raises:
        ValueError: If ``nonlinearity`` is not one of the supported names.
    """

    def __init__(self, input_size, hidden_size, rank, nonlinearity='relu'):
        super(LowRankRNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.rank = rank

        if nonlinearity == 'relu':
            self.activ_func = torch.relu
        elif nonlinearity == 'linear':
            self.activ_func = nn.Identity()
        elif nonlinearity == 'tanh':
            self.activ_func = nn.Tanh()
        else:
            # Otherwise self.activ_func is never assigned and forward() cannot
            # be compiled, which hides the real cause of the failure.
            raise ValueError(
                "unsupported nonlinearity {!r}; expected 'relu', 'linear' or 'tanh'"
                .format(nonlinearity))

        # NOTE: nn.Linear carries its own bias, so together with self.bias the
        # pre-activation has two additive bias terms.
        self.input_proj = nn.Linear(input_size, hidden_size)
        # Gaussian factors scaled by 1/sqrt(rank) and 1/sqrt(hidden_size).
        self.U = nn.Parameter(torch.randn(hidden_size, rank) / (rank ** 0.5))
        self.V = nn.Parameter(torch.randn(rank, hidden_size) / (hidden_size ** 0.5))
        self.bias = nn.Parameter(torch.zeros(hidden_size))

    @jit.script_method
    def forward(self, input: Tensor, hidden: Tensor) -> Tensor:
        """Advance the hidden state by one step.

        Args:
            input: Tensor whose last dim is ``input_size``.
            hidden: 2-D tensor ``(N, hidden_size)`` (``torch.mm`` requires 2-D).

        Returns:
            The new hidden state ``act(input_proj(input) + hidden @ U @ V + bias)``.
        """
        input_part = self.input_proj(input)
        # Multiply by U first so the intermediate is only (N, rank).
        recurrent_part = torch.mm(hidden, self.U)
        recurrent_part = torch.mm(recurrent_part, self.V)
        gate_input = input_part + recurrent_part + self.bias
        return self.activ_func(gate_input)

class scriptVanillaCell(jit.ScriptModule):
    """Single-step vanilla RNN cell compiled with TorchScript.

    ``h_new = act(x @ W_ih^T + b_ih + h @ W_hh^T + b_hh + noise)``

    Args:
        input_size: Number of input features per step.
        hidden_size: Number of hidden units.
        nonlinearity: ``'tanh'`` (default), ``'relu'`` or ``'linear'`` (identity).
        noise: Multiplier for standard-normal noise added to the
            pre-activation on every forward call (train and eval mode alike,
            since ``forward`` never checks ``self.training``).
        weight_init: Xavier gain for ``weight_ih`` and multiplier applied to
            the orthogonally initialised ``weight_hh``.

    Raises:
        ValueError: If ``nonlinearity`` is not one of the supported names.
    """

    def __init__(self, input_size, hidden_size, nonlinearity='tanh', noise=0.0, weight_init=1):
        super(scriptVanillaCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Uninitialised storage; every tensor is filled explicitly below.
        self.weight_ih = Parameter(torch.empty(hidden_size, input_size))
        self.weight_hh = Parameter(torch.empty(hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.empty(hidden_size))
        self.bias_hh = Parameter(torch.empty(hidden_size))
        self.noise = noise

        if nonlinearity == 'relu':
            self.activ_func = torch.relu
        elif nonlinearity == 'linear':
            self.activ_func = nn.Identity()
        elif nonlinearity == 'tanh':
            self.activ_func = nn.Tanh()
        else:
            # Otherwise self.activ_func is never assigned and forward() cannot
            # be compiled, which hides the real cause of the failure.
            raise ValueError(
                "unsupported nonlinearity {!r}; expected 'relu', 'linear' or 'tanh'"
                .format(nonlinearity))

        # Xavier for input, orthogonal for hidden (both scaled by weight_init).
        torch.nn.init.xavier_uniform_(self.weight_ih, gain=weight_init)
        torch.nn.init.zeros_(self.bias_ih)
        torch.nn.init.orthogonal_(self.weight_hh)
        with torch.no_grad():
            # In-place scale without going through `.data`, which bypasses
            # autograd's version tracking.
            self.weight_hh.mul_(weight_init)
        torch.nn.init.zeros_(self.bias_hh)

    @jit.script_method
    def forward(self, input: Tensor, hx: Tensor) -> Tensor:
        """Advance the hidden state by one step.

        Args:
            input: 2-D tensor ``(N, input_size)``.
            hx: 2-D tensor ``(N, hidden_size)`` holding the previous state.

        Returns:
            The new hidden state, same shape as ``hx``.
        """
        ih = torch.mm(input, self.weight_ih.t()) + self.bias_ih
        hh = torch.mm(hx, self.weight_hh.t()) + self.bias_hh
        noise = torch.randn_like(hx) * self.noise
        return self.activ_func(ih + hh + noise)

class scriptRNNLayer(jit.ScriptModule):
    """Unroll one recurrent cell across the time axis with per-step gating.

    Args:
        cell: Module class instantiated as ``cell(*cell_args)``; the instance
            must be callable as ``cell(x_t, h) -> h_new``.
        *cell_args: Positional arguments forwarded to the cell constructor.
    """

    def __init__(self, cell, *cell_args):
        super(scriptRNNLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, hx: Tensor, gating: Tensor):
        """Run the cell over every slice along dim 1 of ``input``.

        At step ``t`` the state becomes ``cell(input[:, t], state) * gating[:, t]``.
        That gated state is both recorded and carried into the next step.

        Returns:
            ``(outputs, last)``. ``outputs`` stacks the per-step states along a
            new leading dim 0, and ``last`` is the final state.
        """
        steps = input.unbind(1)
        gates = gating.unbind(1)
        history = jit.annotate(List[Tensor], [])
        state = hx
        for t in range(len(steps)):
            candidate = self.cell(steps[t], state)
            state = candidate * gates[t]
            history.append(state)
        return torch.stack(history), state

def init_stacked_RNN(num_layers, layer, first_layer_args, other_layer_args):
    """Build a ModuleList of stacked layers.

    The first layer is ``layer(*first_layer_args)``, followed by
    ``num_layers - 1`` instances of ``layer(*other_layer_args)``. A
    ``num_layers`` of 1 or less still yields the single first layer.
    """
    stacked = [layer(*first_layer_args)]
    for _ in range(num_layers - 1):
        stacked.append(layer(*other_layer_args))
    return nn.ModuleList(stacked)

class ScriptRNN(jit.ScriptModule):
    """Stack of ``scriptRNNLayer`` layers built from GRU or vanilla cells.

    Args:
        input_dim: Feature size of the input to the first layer.
        hidden_dim: Hidden size of every layer, and the input size of every
            layer after the first.
        num_layers: Number of stacked layers.
        nonlinearity: Passed through to every cell.
        rnn_class: ``'gru'`` (default) or ``'vanilla'``.
        batch_first: Stored on the module. NOTE(review): ``forward`` never
            reads it, and each layer's output is transposed back
            unconditionally. Confirm before relying on ``batch_first=False``.

    Raises:
        ValueError: If ``rnn_class`` is not ``'gru'`` or ``'vanilla'``.
    """
    __constants__ = ['num_layers']

    def __init__(self, input_dim, hidden_dim, num_layers, nonlinearity='relu', rnn_class='gru', batch_first=True):
        super(ScriptRNN, self).__init__()
        if rnn_class == 'gru':
            rnn_cell = scriptGRUCell
            self.n_chunks = 3  # gate blocks stacked in each weight matrix
        elif rnn_class == 'vanilla':
            rnn_cell = scriptVanillaCell
            self.n_chunks = 1
        else:
            # Previously this fell through to an UnboundLocalError on rnn_cell.
            raise ValueError(
                "unsupported rnn_class {!r}; expected 'gru' or 'vanilla'"
                .format(rnn_class))

        self.layers = init_stacked_RNN(num_layers, scriptRNNLayer,
                                       [rnn_cell, input_dim, hidden_dim, nonlinearity],
                                       [rnn_cell, hidden_dim, hidden_dim, nonlinearity])

        self.num_layers = num_layers
        self.batch_first = batch_first
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.__weights_init__()

    def __weights_init__(self):
        """Re-initialise the input and recurrent weights of every layer.

        Each gate block of a ``weight_ih`` is drawn from N(0, 1/fan_in), where
        fan_in is that weight's column count (``input_dim`` for the first
        layer, ``hidden_dim`` for the others). Each gate block of a
        ``weight_hh`` is set to ``0.5 * I``. Biases are not touched.
        """
        half_eye = torch.eye(self.hidden_dim) * 0.5
        with torch.no_grad():
            for name, param in self.named_parameters():
                if 'weight_ih' in name:
                    # Per-layer fan-in; the old code used input_dim for every
                    # layer even though deeper layers consume hidden_dim inputs.
                    std = param.shape[1] ** -0.5
                    for block in param.chunk(self.n_chunks, 0):
                        torch.nn.init.normal_(block, std=std)
                elif 'weight_hh' in name:
                    for block in param.chunk(self.n_chunks, 0):
                        block.copy_(half_eye)

    @jit.script_method
    def forward(self, input: Tensor, layers_hx: Tensor, gating: Tensor):
        """Run every layer in sequence, feeding each layer's output to the next.

        Args:
            input: Input sequence; each layer unrolls over its dim 1.
            layers_hx: Initial states; ``layers_hx[i]`` seeds layer ``i``.
            gating: Per-step gating passed unchanged to every layer.

        Returns:
            ``(output, states)``. ``output`` is the last layer's per-step
            outputs, transposed so that time is dim 1 again. ``states`` stacks
            each layer's final state along dim 0.
        """
        output_states = jit.annotate(List[Tensor], [])
        output = input
        for i, rnn_layer in enumerate(self.layers):
            output, out_state = rnn_layer(output, layers_hx[i], gating)
            # Layers stack their outputs along dim 0; swap back before the next layer.
            output = output.transpose(0, 1)
            output_states.append(out_state)
        return output, torch.stack(output_states)