import torch.nn as nn
import torch


class RNNCell(nn.RNNCell):
    r"""RNN cell with optional antithetic Gaussian noise on the pre-activation.

    math: h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})

    With ``add_noise=True`` the pre-activation is perturbed by
    ``epsilon * noise_std`` where ``epsilon`` is standard normal and the second
    half of the batch mirrors the first. The (input, hx, epsilon) of every noisy
    step is buffered. ``backward(loss)`` turns the buffers into perturbation-based
    gradient estimates and writes them to ``.grad``; it does not use autograd.
    """

    def __init__(self, input_size, hidden_size, init_std=1e-0, bias=True, nonlinearity="tanh", device=None, dtype=None):
        super().__init__(input_size, hidden_size, bias, nonlinearity, device, dtype)
        # Per-unit noise scale. float() keeps an int init_std from producing an
        # integer tensor, and dtype matches the weights so the grad assigned in
        # backward() has the parameter's dtype.
        self.noise_std = nn.Parameter(torch.full((hidden_size,), float(init_std), device=device, dtype=dtype))
        self.input_buf, self.hx_buf, self.epsilon_buf = [], [], []

    @staticmethod
    def _antithetic_noise(batch_size, dim, device, dtype):
        """Return (batch_size, dim) standard-normal noise with rows [half, 2*half) = -rows [0, half).

        For odd batch_size, the last unpaired row gets an independent draw.
        """
        half = batch_size // 2
        epsilon = torch.empty((batch_size, dim), device=device, dtype=dtype)
        epsilon[:half] = torch.randn((half, dim), device=device)
        epsilon[half:2 * half] = -epsilon[:half]
        if batch_size % 2:
            epsilon[-1] = torch.randn(dim, device=device)
        return epsilon

    def forward(self, input, hx=None, add_noise=False):
        """
        input: (bs, input_size) or unbatched (input_size,)
        hx: (bs, hidden_size) or unbatched (hidden_size,); zeros when None
        add_noise: perturb the pre-activation and buffer (input, hx, epsilon) for backward()
        returns: (bs, hidden_size), or (hidden_size,) for unbatched input
        """
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            hx = torch.zeros(input.shape[0], self.hidden_size, dtype=input.dtype, device=input.device)
        elif not is_batched:
            hx = hx.unsqueeze(0)

        output = torch.matmul(input, self.weight_ih.T) + torch.matmul(hx, self.weight_hh.T)
        if self.bias_ih is not None:  # bias=False registers both biases as None
            output = output + self.bias_ih + self.bias_hh

        if add_noise:
            epsilon = self._antithetic_noise(output.shape[0], self.hidden_size, self.noise_std.device, output.dtype)
            output = output + epsilon * self.noise_std
            self.store_buf(input, hx, epsilon)

        if self.nonlinearity == "tanh":
            output = torch.tanh(output)
        elif self.nonlinearity == "relu":
            output = torch.relu(output)
        else:
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")

        if not is_batched:
            output = output.squeeze(0)
        return output

    def backward(self, loss):
        """Write perturbation-based gradient estimates to ``.grad`` and clear the buffers.

        loss: (bs,) per-sample loss. The same value is paired with every noisy
        step buffered since the last call.
        The grads overwrite (do not accumulate into) ``.grad``. They are computed
        under no_grad, so they carry no autograd graph.
        """
        with torch.no_grad():
            batch_size = self.input_buf[0].shape[0]
            loss = loss.reshape(1, -1, 1)  # (1, bs, 1), broadcast over buffered steps
            scale = self.noise_std * batch_size

            i, h, e = torch.stack(self.input_buf), torch.stack(self.hx_buf), torch.stack(self.epsilon_buf)
            self.weight_hh.grad = torch.einsum('tni,tnj->ji', h * loss, e) / scale.unsqueeze(-1)
            self.weight_ih.grad = torch.einsum('tni,tnj->ji', i * loss, e) / scale.unsqueeze(-1)
            if self.bias_ih is not None:
                # Computed separately so the two biases never share one grad tensor.
                self.bias_hh.grad = torch.einsum('tni,tnj->j', loss, e) / scale
                self.bias_ih.grad = torch.einsum('tni,tnj->j', loss, e) / scale
            self.noise_std.grad = torch.einsum('tni,tnj->j', loss, e ** 2 - 1) / scale

        self.clear_buf()

    def clear_buf(self):
        self.input_buf, self.hx_buf, self.epsilon_buf = [], [], []

    def store_buf(self, input, hx, epsilon):
        # Detach so the buffers do not keep the forward autograd graph alive.
        self.input_buf.append(input.detach())
        self.hx_buf.append(hx.detach())
        self.epsilon_buf.append(epsilon.detach())


class GRUCell(nn.GRUCell):
    r"""GRU cell with optional antithetic Gaussian noise and a perturbation-based ``backward``.

    math:
        \begin{array}{ll}
        r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
        z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
        n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
        h' = (1 - z) * n + z * h
        \end{array}

    With ``add_noise=True``, ``epsilon * noise_std[:3*hidden_size]`` is added to
    ``W_i x + b_i``, and ``epsilon_hn * noise_std[3*hidden_size:]`` is added to
    ``W_hn h + b_hn``.
    """

    def __init__(self, input_size, hidden_size, init_std=1e-0, bias=True, device=None, dtype=None):
        super().__init__(input_size, hidden_size, bias, device, dtype)
        # [:3*hidden_size] scales the noise on wx; the last hidden_size entries scale
        # the noise on the n-gate part of wh. dtype matches the weights so the
        # grad assigned in backward() has the parameter's dtype.
        self.noise_std = nn.Parameter(torch.full((4 * hidden_size,), float(init_std), device=device, dtype=dtype))
        self.input_buf, self.hx_buf, self.epsilon_buf, self.epsilon_buf_hn = [], [], [], []

    @staticmethod
    def _antithetic_noise(batch_size, dim, device, dtype):
        """Return (batch_size, dim) standard-normal noise with rows [half, 2*half) = -rows [0, half).

        For odd batch_size, the last unpaired row gets an independent draw.
        """
        half = batch_size // 2
        epsilon = torch.empty((batch_size, dim), device=device, dtype=dtype)
        epsilon[:half] = torch.randn((half, dim), device=device)
        epsilon[half:2 * half] = -epsilon[:half]
        if batch_size % 2:
            epsilon[-1] = torch.randn(dim, device=device)
        return epsilon

    def forward(self, input, hx=None, add_noise=False):
        """
        input: (bs, input_size) or unbatched (input_size,)
        hx: (bs, hidden_size) or unbatched (hidden_size,); zeros when None
        add_noise: perturb the pre-activations and buffer (input, hx, epsilon, epsilon_hn)
        returns: next hidden state, unbatched when input is unbatched
        """
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            hx = torch.zeros(input.shape[0], self.hidden_size, dtype=input.dtype, device=input.device)
        elif not is_batched:
            hx = hx.unsqueeze(0)

        H = self.hidden_size
        wx = torch.matmul(input, self.weight_ih.T)
        wh = torch.matmul(hx, self.weight_hh.T)
        if self.bias_ih is not None:  # bias=False registers both biases as None
            wx = wx + self.bias_ih
            wh = wh + self.bias_hh
        wh_rz, wh_n = wh[:, :2 * H], wh[:, 2 * H:]

        if add_noise:
            bs = wx.shape[0]
            device = self.noise_std.device
            epsilon = self._antithetic_noise(bs, 3 * H, device, wx.dtype)
            epsilon_hn = self._antithetic_noise(bs, H, device, wh.dtype)
            wx = wx + epsilon * self.noise_std[:3 * H]
            wh_n = wh_n + epsilon_hn * self.noise_std[3 * H:]
            self.store_buf(input, hx, epsilon, epsilon_hn)

        r, z = torch.sigmoid(wx[:, :2 * H] + wh_rz).chunk(2, 1)
        n = torch.tanh(wx[:, 2 * H:] + r * wh_n)
        hy = (1 - z) * n + z * hx

        if not is_batched:
            hy = hy.squeeze(0)
        return hy

    def backward(self, loss):
        """Write perturbation-based gradient estimates to ``.grad`` and clear the buffers.

        loss: (bs,) per-sample loss. The same value is paired with every noisy
        step buffered since the last call.
        The grads overwrite ``.grad`` and are computed under no_grad.
        """
        H = self.hidden_size
        with torch.no_grad():
            batch_size = self.input_buf[0].shape[0]
            loss = loss.reshape(1, -1, 1)  # (1, bs, 1), broadcast over buffered steps

            i = torch.stack(self.input_buf)
            h = torch.stack(self.hx_buf)
            e = torch.stack(self.epsilon_buf)      # noise on W_i x + b_i, 3*H wide
            eh = torch.stack(self.epsilon_buf_hn)  # noise on W_hn h + b_hn, H wide

            std_x = self.noise_std[:3 * H]
            self.weight_ih.grad = torch.einsum('tni,tnj->ji', i * loss, e) / (std_x.unsqueeze(-1) * batch_size)
            if self.bias_ih is not None:
                self.bias_ih.grad = torch.einsum('tni,tnj->j', loss, e) / (std_x * batch_size)
            self.noise_std.grad = torch.einsum('tni,tnj->j', loss, torch.cat((e, eh), dim=2) ** 2 - 1) / (self.noise_std * batch_size)

            # The r/z noise perturbs the sum wx + wh, so W_h{r,z} reuse it; the n rows
            # of W_hh use the separate epsilon_hn noise.
            e_h = torch.cat((e[..., :2 * H], eh), dim=2)
            std_h = torch.cat((self.noise_std[:2 * H], self.noise_std[3 * H:]))
            self.weight_hh.grad = torch.einsum('tni,tnj->ji', h * loss, e_h) / (std_h.unsqueeze(-1) * batch_size)
            if self.bias_hh is not None:
                self.bias_hh.grad = torch.einsum('tni,tnj->j', loss, e_h) / (std_h * batch_size)

        self.clear_buf()

    def clear_buf(self):
        self.input_buf, self.hx_buf, self.epsilon_buf, self.epsilon_buf_hn = [], [], [], []

    def store_buf(self, input, hx, epsilon, epsilon_hn):
        # Detach so the buffers do not keep the forward autograd graph alive.
        self.input_buf.append(input.detach())
        self.hx_buf.append(hx.detach())
        self.epsilon_buf.append(epsilon.detach())
        self.epsilon_buf_hn.append(epsilon_hn.detach())


class LSTMCell(nn.LSTMCell):
    r"""LSTM cell with optional antithetic Gaussian noise on the gate pre-activations.

    math:
        \begin{array}{ll}
        i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
        f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\
        g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\
        o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\
        c' = f * c + i * g \\
        h' = o * \tanh(c') \\
        \end{array}

    With ``add_noise=True``, ``epsilon * noise_std`` is added to the summed 4*hidden_size
    pre-activation, and ``backward(loss)`` writes perturbation-based gradient
    estimates to ``.grad``.
    """

    def __init__(self, input_size, hidden_size, init_std=1e-0, bias=True, device=None, dtype=None):
        super().__init__(input_size, hidden_size, bias, device, dtype)
        # dtype matches the weights so the grad assigned in backward() has the parameter's dtype.
        self.noise_std = nn.Parameter(torch.full((4 * hidden_size,), float(init_std), device=device, dtype=dtype))
        # NOTE(review): epsilon_buf_c is not read anywhere in this class; kept for compatibility.
        self.input_buf, self.hx_buf, self.epsilon_buf, self.epsilon_buf_c = [], [], [], []

    @staticmethod
    def _antithetic_noise(batch_size, dim, device, dtype):
        """Return (batch_size, dim) standard-normal noise with rows [half, 2*half) = -rows [0, half).

        For odd batch_size, the last unpaired row gets an independent draw.
        """
        half = batch_size // 2
        epsilon = torch.empty((batch_size, dim), device=device, dtype=dtype)
        epsilon[:half] = torch.randn((half, dim), device=device)
        epsilon[half:2 * half] = -epsilon[:half]
        if batch_size % 2:
            epsilon[-1] = torch.randn(dim, device=device)
        return epsilon

    def forward(self, input, hx=None, add_noise=False):
        """
        input: (bs, input_size) or unbatched (input_size,)
        hx: (h, c), each (bs, hidden_size) or unbatched; zeros when None
        add_noise: perturb the gate pre-activations and buffer (input, h, epsilon)
        returns: (h', c'), unbatched when input is unbatched
        """
        is_batched = input.dim() == 2
        if not is_batched:
            input = input.unsqueeze(0)
        if hx is None:
            zeros = torch.zeros(input.shape[0], self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        elif not is_batched:
            hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0))

        # Same summation order as W_ih x + b_ih + W_hh h + b_hh; biases are None when bias=False.
        wxh = torch.matmul(input, self.weight_ih.T)
        if self.bias_ih is not None:
            wxh = wxh + self.bias_ih
        wxh = wxh + torch.matmul(hx[0], self.weight_hh.T)
        if self.bias_hh is not None:
            wxh = wxh + self.bias_hh

        if add_noise:
            epsilon = self._antithetic_noise(wxh.shape[0], 4 * self.hidden_size, self.noise_std.device, wxh.dtype)
            wxh = wxh + epsilon * self.noise_std
            self.store_buf(input, hx[0], epsilon)

        i, f, g, o = wxh.chunk(4, 1)
        c_ = torch.sigmoid(f) * hx[1] + torch.sigmoid(i) * torch.tanh(g)
        h_ = torch.sigmoid(o) * torch.tanh(c_)

        if not is_batched:
            return h_.squeeze(0), c_.squeeze(0)
        return h_, c_

    def backward(self, loss):
        """Write perturbation-based gradient estimates to ``.grad`` and clear the buffers.

        loss: (bs,) per-sample loss. The same value is paired with every noisy
        step buffered since the last call.
        The grads overwrite ``.grad`` and are computed under no_grad.
        """
        with torch.no_grad():
            batch_size = self.input_buf[0].shape[0]
            loss = loss.reshape(1, -1, 1)  # (1, bs, 1), broadcast over buffered steps
            scale = self.noise_std * batch_size

            i, h, e = torch.stack(self.input_buf), torch.stack(self.hx_buf), torch.stack(self.epsilon_buf)
            self.weight_ih.grad = torch.einsum('tni,tnj->ji', i * loss, e) / scale.unsqueeze(-1)
            self.weight_hh.grad = torch.einsum('tni,tnj->ji', h * loss, e) / scale.unsqueeze(-1)
            if self.bias_ih is not None:
                # Computed separately so the two biases never share one grad tensor.
                self.bias_ih.grad = torch.einsum('tni,tnj->j', loss, e) / scale
                self.bias_hh.grad = torch.einsum('tni,tnj->j', loss, e) / scale
            self.noise_std.grad = torch.einsum('tni,tnj->j', loss, e ** 2 - 1) / scale

        self.clear_buf()

    def clear_buf(self):
        self.input_buf, self.hx_buf, self.epsilon_buf = [], [], []

    def store_buf(self, input, hx, epsilon):
        # Detach so the buffers do not keep the forward autograd graph alive.
        self.input_buf.append(input.detach())
        self.hx_buf.append(hx.detach())
        self.epsilon_buf.append(epsilon.detach())


class Linear(nn.Linear):
    """Linear layer with optional antithetic Gaussian output noise and a perturbation-based ``backward``."""

    def __init__(self, in_features, out_features, init_std=1e-0, bias=True, device=None, dtype=None):
        """
        weight: (out_features, in_features)
        bias: (out_features,), or None when bias=False
        input_buf: list of (bs, in_features) inputs from noisy forward passes
        epsilon_buf: list of (bs, out_features) standard-normal noise draws
        noise_std: (out_features,), created with the same dtype argument as the weight
        """
        super().__init__(in_features, out_features, bias, device, dtype)
        # float() keeps an int init_std from producing an integer (non-differentiable) tensor.
        self.noise_std = nn.Parameter(torch.full((out_features,), float(init_std), device=device, dtype=dtype))
        self.input_buf, self.epsilon_buf = [], []

    @staticmethod
    def _antithetic_noise(batch_size, dim, device, dtype):
        """Return (batch_size, dim) standard-normal noise with rows [half, 2*half) = -rows [0, half).

        For odd batch_size, the last unpaired row gets an independent draw.
        """
        half = batch_size // 2
        epsilon = torch.empty((batch_size, dim), device=device, dtype=dtype)
        epsilon[:half] = torch.randn((half, dim), device=device)
        epsilon[half:2 * half] = -epsilon[:half]
        if batch_size % 2:
            epsilon[-1] = torch.randn(dim, device=device)
        return epsilon

    def forward(self, input, add_noise=False):
        """
        input: (bs, in_features)
        add_noise: add ``epsilon * noise_std`` to the output and buffer (input, epsilon)
        returns: (bs, out_features)
        """
        logit_output = super().forward(input)
        if not add_noise:
            return logit_output
        bs, out_features = logit_output.shape
        epsilon = self._antithetic_noise(bs, out_features, self.noise_std.device, logit_output.dtype)
        self.store_buf(input, epsilon)
        return logit_output + epsilon * self.noise_std

    def backward(self, loss):
        """Write perturbation-based gradient estimates to ``.grad`` and clear the buffers.

        loss: (bs,) per-sample loss. The same value is paired with every noisy
        pass buffered since the last call.
        The grads overwrite ``.grad`` and are computed under no_grad.
        """
        with torch.no_grad():
            batch_size = self.input_buf[0].shape[0]
            loss = loss.reshape(1, -1, 1)  # (1, bs, 1), broadcast over buffered passes
            scale = self.noise_std * batch_size

            i, e = torch.stack(self.input_buf), torch.stack(self.epsilon_buf)
            self.weight.grad = torch.einsum('tni,tnj->ji', i * loss, e) / scale.unsqueeze(-1)
            if self.bias is not None:  # nn.Linear sets bias to None when bias=False
                self.bias.grad = torch.einsum('tni,tnj->j', loss, e) / scale
            self.noise_std.grad = torch.einsum('tni,tnj->j', loss, e ** 2 - 1) / scale

        self.clear_buf()

    def clear_buf(self):
        self.input_buf, self.epsilon_buf = [], []

    def store_buf(self, input, epsilon):
        # Detach so the buffers do not keep the forward autograd graph alive.
        self.input_buf.append(input.detach())
        self.epsilon_buf.append(epsilon.detach())
