import torch
from architectures.activations import (
    MexicanHat, MexicanHatStandard, HardSigmoid, StraightThroughEstimator, HardSoftmax
)

class RNNCell(torch.nn.Module):
    """One recurrent step: ``activation(in2hidden(x) + hidden2hidden(hidden))``.

    Args:
        n_inputs: Size of the input feature dimension.
        n_hidden: Size of the hidden state.
        nonlinearity: Name of the activation. One of ``'sigmoid'``,
            ``'mexicanhat'``, ``'mexicanhatstd'``, ``'relu'``,
            ``'hard_sigmoid'``, ``'step'``, ``'softmax'`` (taken over dim 1)
            or ``'hard_softmax'``. Any other value prints a warning and
            falls back to the identity.
        dropouts: Indexable pair of dropout probabilities; ``dropouts[0]`` is
            applied to the input projection and ``dropouts[1]`` to the
            recurrent projection. A probability that is not ``> 0`` disables
            the corresponding dropout in ``forward``.
        input_bias: Whether the input projection has a bias term.
        hidden_bias: Whether the recurrent projection has a bias term.
    """

    def __init__(self, n_inputs, n_hidden, nonlinearity, dropouts, input_bias, hidden_bias):
        super().__init__()

        # (name, factory) pairs, scanned in order. Factories are lazy so only
        # the requested activation is ever instantiated.
        known_activations = (
            ('sigmoid', lambda: torch.sigmoid),
            ('mexicanhat', lambda: MexicanHat()),
            ('mexicanhatstd', lambda: MexicanHatStandard()),
            ('relu', lambda: torch.relu),
            ('hard_sigmoid', lambda: HardSigmoid()),
            ('step', lambda: StraightThroughEstimator()),
            ('softmax', lambda: torch.nn.Softmax(dim=1)),
            ('hard_softmax', lambda: HardSoftmax(n_hidden)),
        )
        factory = next(
            (make for name, make in known_activations if nonlinearity == name),
            None,
        )
        if factory is None:
            print("[!!!] WARNING: activation function not recognized, using identity")
            chosen_activation = torch.nn.Identity()
        else:
            chosen_activation = factory()

        # Projections of the current input and of the previous hidden state.
        self.in2hidden = torch.nn.Linear(n_inputs, n_hidden, bias=input_bias)
        self.hidden2hidden = torch.nn.Linear(n_hidden, n_hidden, bias=hidden_bias)

        # Dropout layers always exist; the boolean flags decide whether
        # forward() actually routes through them.
        input_p = dropouts[0]
        self.add_in2h_do = input_p > 0
        self.in2h_do = torch.nn.Dropout(input_p)
        hidden_p = dropouts[1]
        self.add_h2h_do = hidden_p > 0
        self.h2h_do = torch.nn.Dropout(hidden_p)

        self.activation_fn = chosen_activation

    def forward(self, x, hidden):
        """Return the next hidden state computed from ``x`` and ``hidden``."""
        from_input = self.in2hidden(x)
        if self.add_in2h_do:
            from_input = self.in2h_do(from_input)

        from_hidden = self.hidden2hidden(hidden)
        if self.add_h2h_do:
            from_hidden = self.h2h_do(from_hidden)

        pre_activation = from_input + from_hidden
        return self.activation_fn(pre_activation)


class RNNModule(torch.nn.Module):
    """Unroll an ``RNNCell`` over the time axis of a batched sequence.

    Args:
        device: Device on which the output buffer and the default initial
            hidden state are allocated.
        n_inputs: Size of the input feature dimension.
        n_hidden: Size of the hidden state.
        nonlinearity: Activation name, forwarded to ``RNNCell``.
        dropouts: Dropout configuration, forwarded to ``RNNCell``.
        input_bias: Whether the input projection has a bias term.
        hidden_bias: Whether the recurrent projection has a bias term.
    """

    def __init__(
        self, device, n_inputs, n_hidden, nonlinearity, dropouts,
        input_bias=True, hidden_bias=True
    ):
        super(RNNModule, self).__init__()

        self.rnn_cell = RNNCell(
            n_inputs, n_hidden, nonlinearity, dropouts, input_bias, hidden_bias
        )
        self.n_hidden = n_hidden

        self.device = device

    def forward(self, x, hidden=None):
        """Run the cell over every time step of ``x``.

        Args:
            x: Input of shape ``[BATCH SIZE, TIME, N_FEATURES]``.
            hidden: Optional initial hidden state of shape
                ``[BATCH SIZE, N_HIDDEN]``. Defaults to zeros.

        Returns:
            A tuple ``(output, h_out)``: ``output`` has shape
            ``[BATCH SIZE, TIME, N_HIDDEN]`` and holds the hidden state after
            each step; ``h_out`` is the hidden state after the last step (or
            the initial state if the sequence is empty).
        """
        batch_size, window_size = x.shape[0], x.shape[1]

        # Allocate directly on the target device and in the input's dtype.
        # A float32-only buffer would silently downcast the hidden states of
        # a double-precision model when they are written into it.
        output = torch.zeros(
            batch_size, window_size, self.n_hidden,
            dtype=x.dtype, device=self.device,
        )

        if hidden is None:
            # Default initial hidden state: zeros matching the input dtype.
            h_out = torch.zeros(
                batch_size, self.n_hidden, dtype=x.dtype, device=self.device
            )
        else:
            h_out = hidden

        # loop over time
        for t in range(window_size):
            h_out = self.rnn_cell(x[:, t, ...], h_out)
            output[:, t, ...] = h_out

        # return all outputs, and the last hidden state
        return output, h_out
    