import torch.nn as nn
import torch
import torch.nn.functional as F


class AutoGenNet(nn.Module):
    def __init__(self, network_size=(1, 100, 1), nonlinearity='relu',
                 bias_learning=True, bias_init='uniform', gain_init=(1, 1, 1)):
        """
        Input-free vanilla RNN followed by a bias-free linear readout.
        :param network_size: (tuple) (input size, hidden size, output size)
        :param nonlinearity: (string) Nonlinearity handed to ``nn.RNN``
        :param bias_learning: (bool) If True, the recurrent and readout weights are frozen too,
            which leaves the input-to-hidden bias as the only trainable parameter.
        :param bias_init: (string) 'uniform' or 'zero'; only consulted when the RNN uses ReLU
        :param gain_init: (tuple) Gains for (recurrent weights, readout weights, upper bound of uniform bias)
        """
        super().__init__()
        self.network_size = network_size
        n_in, n_hidden, n_out = network_size[0], network_size[1], network_size[2]
        self.recurrent_layer = nn.RNN(input_size=n_in, hidden_size=n_hidden, nonlinearity=nonlinearity)
        self.readout_layer = nn.Linear(n_hidden, n_out, bias=False)

        self._initialize_params(bias_init, gain_init)
        self._freeze_params(bias_learning)

    def _initialize_params(self, bias_init, gain_init):
        """Draw the initial weights and biases (random draws happen in a fixed order)."""
        rnn = self.recurrent_layer
        readout_weight = self.readout_layer.weight

        # -- Weights --
        # Recurrent matrix: Gaussian, std scaled by 1/sqrt(fan-in)
        recurrent_std = gain_init[0] / rnn.weight_hh_l0.shape[1] ** 0.5
        nn.init.normal_(rnn.weight_hh_l0, std=recurrent_std)
        # Input weights are zeroed: the network runs without external input
        nn.init.zeros_(rnn.weight_ih_l0)
        # Readout: Gaussian, std scaled by 1/fan-in
        readout_std = gain_init[1] / readout_weight.shape[1]
        nn.init.normal_(readout_weight, std=readout_std)

        # -- Biases --
        mode = rnn.mode
        if mode not in ('RNN_RELU', 'RNN_TANH'):
            return
        # The "hidden to hidden" bias is redundant with the input bias, so it is always zero
        nn.init.zeros_(rnn.bias_hh_l0)
        if mode == 'RNN_TANH' or bias_init == "zero":
            nn.init.zeros_(rnn.bias_ih_l0)
        elif bias_init == "uniform":
            nn.init.uniform_(rnn.bias_ih_l0, a=0., b=gain_init[2])
        else:
            raise ValueError(f"Unknown bias initialization '{bias_init}'. Select from 'uniform' or 'zero'.")

    def _freeze_params(self, bias_learning):
        """Switch off gradients for the parameters that must stay at their initial value."""
        rnn = self.recurrent_layer
        frozen = [rnn.bias_hh_l0, rnn.weight_ih_l0]
        if bias_learning:
            # Bias-only learning: the weight matrices are fixed as well
            frozen += [rnn.weight_hh_l0, self.readout_layer.weight]
        for param in frozen:
            param.requires_grad_(False)

    def forward(self, x, h0):
        """Run the RNN from ``h0`` and return ``(readout of the hidden states, hidden states)``."""
        hidden_states, _ = self.recurrent_layer(x, h0)
        return self.readout_layer(hidden_states), hidden_states


# TODO: Implement RNN dynamics closer to neuroscience theory.
class NoisyDecayRNN(nn.RNN):
    r"""
    Single-layer RNN with a decaying (leaky) state and additive Gaussian noise.

    The pre-activation state :math:`v` is advanced one step at a time with

    :math:`v_t = v_{t-1} + \alpha (-v_{t-1} + Wf(v_{t-1}) + U x_t + b_h + b_i) + \sigma \sqrt{\alpha} \xi_t`

    where :math:`\xi_t` is a standard Gaussian variate drawn independently for every
    time step, batch element and hidden unit, and :math:`f` is tanh or ReLU depending
    on ``nonlinearity``.

    :param alpha: (float, keyword) Step size / decay rate of the update. Default 1.
    :param sigma: (float, keyword) Noise amplitude. Default 0.
    All other arguments are forwarded unchanged to ``torch.nn.RNN``.
    """
    def __init__(self, *args, **kwargs):
        # Strip our own keywords before nn.RNN sees them (it rejects unknown kwargs).
        alpha = kwargs.pop('alpha', 1.)
        sigma = kwargs.pop('sigma', 0.)
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.sigma = sigma
        assert self.num_layers == 1, "No support for multiple layers"
        assert not self.bidirectional, "No support for bidirectionality for NoisyRNN"
        assert self.dropout == 0, "No support for dropout"

    def _activate(self, v):
        """Apply the nonlinearity selected at construction time (tanh or ReLU)."""
        return torch.tanh(v) if self.mode == 'RNN_TANH' else torch.relu(v)

    def _euler_step(self, v, x_t):
        """Advance the pre-activation state ``v`` by one step given the input slice ``x_t``."""
        drift = -v + self._activate(v) @ self.weight_hh_l0.T + x_t @ self.weight_ih_l0.T
        if self.bias:  # the bias parameters do not exist when the module is built with bias=False
            drift = drift + self.bias_ih_l0 + self.bias_hh_l0
        v = v + self.alpha * drift
        # randn_like gives per-sample noise on the same device and dtype as the state.
        return v + self.sigma * self.alpha ** 0.5 * torch.randn_like(v)

    def forward(self, x, v_prev=None):
        """
        Integrate the noisy dynamics.

        :param x: 3-D input laid out as (batch, time, input) if ``batch_first`` else
            (time, batch, input); or 2-D input of shape (batch, input), treated as a
            single time step.
        :param v_prev: Initial pre-activation state, broadcastable to (batch, hidden).
            Defaults to zeros.
        :return: For 3-D input, ``(output, last)`` where ``output`` holds the activations
            ``f(v_t)`` in the same layout as ``x`` and ``last`` is the final time slice of
            ``output``. Note that ``last`` is an activation, not the pre-activation state
            expected by ``v_prev``. For 2-D input, ``(v, None)`` where ``v`` is the updated
            pre-activation state.
        """
        assert (x.dim() in (2, 3)), f"RNN: Expected input to be 2-D or 3-D but received {x.dim()}-D tensor"
        assert self.mode == 'RNN_TANH' or self.mode == 'RNN_RELU'

        if x.dim() == 2:
            # x has shape (N, self.input_size): one step, the raw state v is returned.
            if v_prev is None:
                v_prev = torch.zeros(x.shape[0], self.hidden_size, dtype=x.dtype, device=x.device)
            return self._euler_step(v_prev, x), None

        # 3-D input: which axis is time depends on batch_first, as in nn.RNN.
        time_dim, batch_dim = (1, 0) if self.batch_first else (0, 1)
        if v_prev is None:
            v_prev = torch.zeros(x.shape[batch_dim], self.hidden_size, dtype=x.dtype, device=x.device)
        output = torch.empty((x.shape[0], x.shape[1], self.hidden_size), dtype=x.dtype, device=x.device)

        v = v_prev
        for t in range(x.shape[time_dim]):
            v = self._euler_step(v, x.select(time_dim, t))
            if self.batch_first:
                output[:, t, :] = self._activate(v)
            else:
                output[t, :, :] = self._activate(v)
        last = output[:, -1, :] if self.batch_first else output[-1, :, :]
        return output, last
