import torch
import torch.nn as nn

class RNN(nn.Module):
    """Leaky (Euler-discretised) recurrent network with optional state re-encoding.

    At every time step ``t`` the network computes::

        u_t = W_rec h_{t-1} + W_in (x_t + sigma_in * eps_in)
        h_t = (1 - dt/tau) * h_{t-1} + (dt/tau) * act(u_t) + sigma_rec * eps_rec
        y_t = W_out h_t + sigma_out * eps_out

    where every ``eps`` term is an independent standard-normal sample.

    Args:
        task: Stored as ``self.task``; not used by this class itself.
        n_in: Input feature size per time step.
        n_rec: Number of recurrent units.
        n_out: Output feature size per time step.
        n_init: Feature size of the ``init_state`` / ``feedback`` vectors that
            the ``encoder`` layer maps into the hidden state.
        sigma_in: Std of Gaussian noise added to each input slice (0 disables).
        sigma_rec: Std of Gaussian noise added to the hidden state (0 disables).
        sigma_out: Std of Gaussian noise added to the output (0 disables).
        dt: Integration step; together with ``tau`` sets the leak rate dt/tau.
        tau: Time constant.
        feedback_freq: When > 0 and ``feedback`` is passed to ``forward``, the
            hidden state is re-encoded from ``feedback`` every ``feedback_freq``
            steps. Values <= 0 (the default 0) disable re-encoding.
        bias: Whether the linear layers have bias terms.
        activation_fn: ``"relu"`` or ``"tanh"``; any other value selects the
            identity (linear) activation.
        device: Device the module is moved to at construction time.
    """

    def __init__(
        self,
        task,
        n_in=2,
        n_rec=512,
        n_out=512,
        n_init=512,
        sigma_in=0,
        sigma_rec=0,
        sigma_out=0,
        dt=0.2,
        tau=1,
        feedback_freq=0,
        bias=False,
        activation_fn="relu",
        device="cuda",
    ):
        super().__init__()

        self.task = task
        self.n_in = n_in
        self.n_rec = n_rec
        self.n_out = n_out
        self.n_init = n_init
        self.sigma_in = sigma_in
        self.sigma_rec = sigma_rec
        self.sigma_out = sigma_out
        self.dt = dt
        self.tau = tau
        self.feedback_freq = feedback_freq
        self.bias = bias
        self.activation_fn = activation_fn
        self.device = device

        # Maps an n_init-dimensional vector (init_state / feedback) to a hidden state.
        self.encoder = nn.Linear(self.n_init, self.n_rec, bias=self.bias)
        self.w_in = nn.Linear(self.n_in, self.n_rec, bias=self.bias)
        self.w_rec = nn.Linear(self.n_rec, self.n_rec, bias=self.bias)
        self.w_out = nn.Linear(self.n_rec, self.n_out, bias=self.bias)

        if self.activation_fn == "relu":
            self.activation = torch.relu
        elif self.activation_fn == "tanh":
            self.activation = torch.tanh
        else:
            # Unrecognised names fall back to a linear (identity) activation.
            self.activation = nn.Identity()

        self.to(self.device)

    def forward(self, x, init_state=None, feedback=None):
        """Run the network over every time step of ``x``.

        Args:
            x: Tensor indexed as ``x[:, t, :]``; each slice must reshape to
                ``(batch, n_in)``. Dimension 1 is the time axis.
            init_state: Optional tensor reshaped to ``(batch, n_init)`` and
                passed through ``encoder`` to form the initial hidden state.
                When ``None`` the hidden state starts at zero.
            feedback: Optional tensor indexed as ``feedback[:, t - 1, :]``, each
                slice reshapable to ``(batch, n_init)``. Only used when
                ``feedback_freq > 0``: at every step ``t > 0`` with
                ``t % feedback_freq == 0`` the hidden state is replaced by
                ``encoder(feedback[:, t - 1, :])``.

        Returns:
            Tuple ``(u_1t, h_1t, y_1t)`` of pre-activations
            ``(batch, T, n_rec)``, hidden states ``(batch, T, n_rec)`` and
            outputs ``(batch, T, n_out)``. The final step's ``u``, ``h`` and
            ``y`` are also left on ``self.u``, ``self.h`` and ``self.y``.
        """
        batch_size, timesteps = x.shape[0], x.shape[1]

        # Allocate with the parameters' dtype/device so the module keeps working
        # after .double() / .to(...) calls made after construction.
        weight = self.w_rec.weight
        factory = {"dtype": weight.dtype, "device": weight.device}

        if init_state is not None:
            self.h = self.encoder(init_state.reshape(batch_size, self.n_init))
        else:
            self.h = torch.zeros(batch_size, self.n_rec, **factory)

        self.u_1t = torch.zeros(batch_size, timesteps, self.n_rec, **factory)
        self.h_1t = torch.zeros(batch_size, timesteps, self.n_rec, **factory)
        self.y_1t = torch.zeros(batch_size, timesteps, self.n_out, **factory)

        alpha = self.dt / self.tau
        # feedback_freq <= 0 means "never re-encode" (also avoids t % 0).
        use_feedback = feedback is not None and self.feedback_freq > 0

        for t in range(timesteps):
            if use_feedback and t != 0 and t % self.feedback_freq == 0:
                self.h = self.encoder(
                    feedback[:, t - 1, :].reshape(batch_size, self.n_init)
                )

            x_t = x[:, t, :].reshape(batch_size, self.n_in)
            # Zero-mean Gaussian noise, consistent with the output noise below;
            # only sampled when the corresponding sigma is non-zero.
            if self.sigma_in:
                x_t = x_t + self.sigma_in * torch.randn_like(x_t)
            self.u = self.w_rec(self.h) + self.w_in(x_t)

            h_next = (1 - alpha) * self.h + alpha * self.activation(self.u)
            if self.sigma_rec:
                h_next = h_next + self.sigma_rec * torch.randn_like(h_next)
            self.h = h_next

            self.y = self.w_out(self.h)
            if self.sigma_out:
                self.y = self.y + self.sigma_out * torch.randn_like(self.y)

            self.u_1t[:, t, :] = self.u
            self.h_1t[:, t, :] = self.h
            self.y_1t[:, t, :] = self.y

        return self.u_1t, self.h_1t, self.y_1t
