import numpy as np
import torch
import torch.nn as nn


def gen_ortho_matrix(dim, rng=None):
    '''Generate a random orthogonal matrix.

    Adapted from scipy.stats.ortho_group and copied here for compatibility
    with older versions of scipy. The matrix is built as a product of
    ``dim - 1`` Householder reflections, each acting on a trailing
    sub-block of shrinking size.

    Args:
        dim: Size of the square matrix to generate.
        rng: Optional object exposing ``normal(size=...)`` (for example a
            ``numpy.random.RandomState``). When ``None``, the global
            ``numpy.random`` state is used.

    Returns:
        A ``(dim, dim)`` float64 array ``H`` with ``H @ H.T`` equal to the
        identity up to rounding error.
    '''
    draw_normal = np.random.normal if rng is None else rng.normal

    H = np.eye(dim)
    for n in range(1, dim):
        x = draw_normal(size=(dim - n + 1,))
        # random sign, 50/50, but chosen carefully to avoid roundoff error.
        # np.sign(0.0) is 0, which would zero out the whole reflector and
        # make the result singular, so fall back to +1 in that case.
        D = np.sign(x[0]) if x[0] != 0 else 1.0
        x[0] += D * np.sqrt((x * x).sum())
        # Householder transformation
        Hx = -D * (np.eye(dim - n + 1)
                   - 2. * np.outer(x, x) / (x * x).sum())
        # Embed the reflector in the lower-right corner of an identity so it
        # only mixes the trailing (dim - n + 1) coordinates.
        mat = np.eye(dim)
        mat[n - 1:, n - 1:] = Hx
        H = np.dot(H, mat)
    return H


class customRNN(nn.Module):
    """Leaky recurrent network with softplus units, filtered noise and a
    linear readout.

    Per timestep (see ``call``) the network computes::

        noise  <- (1 - alpha_noise) * noise + N(0, sqrt(2 * alpha_noise) * sigma)
        act     = softplus(stims @ in2rnn + states @ kernel + bias + noise)
        states <- (1 - alpha) * states + alpha * act

    and the readout is ``states @ w_rnn_out + b_rnn_out``.

    Config keys read by ``__init__``: ``seed``, ``rng`` (must expose
    ``randn`` and ``normal``), ``num_input``, ``num_rnn``, ``num_rnn_out``,
    ``batch_size``, ``image_shape`` and ``fixationInput``.
    """

    def __init__(self, config, device, *args, **kwargs):
        """Create parameters and the constant helper tensors.

        Args:
            config: Mapping with the keys listed in the class docstring.
            device: Device on which the helper tensors are created and to
                which freshly sampled noise is moved.
            *args, **kwargs: Forwarded to ``nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.device = device
        self.seed = config['seed']
        self.rng = config['rng']

        n_input = config['num_input']
        n_rnn = config['num_rnn']
        n_out = config['num_rnn_out']
        batch_size = config['batch_size']

        self.activation = torch.nn.functional.softplus
        self.bias_start = 0.0
        self.w_in_start = 1.0
        self.w_rec_start = 1.0
        self.alpha = 0.01
        self.alpha_noise = 0.5
        self.sigma = 0.05

        # Draw order matters for seeded reproducibility: input weights
        # first, then the orthogonal recurrent matrix.
        w_in2rnn0 = np.float32(self.rng.randn(n_input, n_rnn) / np.sqrt(n_input) * self.w_in_start)
        w_rec0 = np.float32(self.w_rec_start * gen_ortho_matrix(n_rnn, rng=self.rng))

        self.in2rnn = nn.Parameter(torch.from_numpy(w_in2rnn0), requires_grad=True)
        self.kernel = nn.Parameter(torch.from_numpy(w_rec0), requires_grad=True)
        # Explicit float32: building the bias from ``np.ones`` would give a
        # float64 parameter and silently promote every pre-activation (and
        # the softplus) to double precision.
        self.bias = nn.Parameter(torch.full((n_rnn,), self.bias_start, dtype=torch.float32),
                                 requires_grad=True)
        self.initS = nn.Parameter(torch.zeros(batch_size, n_rnn))

        self.w_rnn_out = nn.Parameter(torch.zeros(n_rnn, n_out), requires_grad=True)
        self.b_rnn_out = nn.Parameter(torch.zeros(n_out), requires_grad=True)

        # Helper tensors
        self.initNoise = torch.zeros(batch_size, n_rnn).to(self.device)
        self.zeroStims = torch.zeros(batch_size, n_input).to(self.device)
        # Fixation-only input: the first prod(image_shape) columns are zero,
        # the remaining input columns hold the fixation value.
        n_pixels = int(np.prod(config['image_shape']))
        fixation_cols = np.float32(config['fixationInput']) * torch.ones(batch_size, n_input - n_pixels)
        self.onlyFixStim = torch.cat([torch.zeros(batch_size, n_pixels), fixation_cols],
                                     dim=1).to(self.device)
        self.config = config

    def call(self, stims, states0, OUnoise):
        """Advance the network by a single timestep.

        Args:
            stims: Input for this timestep; multiplied by ``in2rnn``.
            states0: Current hidden state; multiplied by ``kernel``.
            OUnoise: Filtered noise carried over from the previous step.

        Returns:
            Tuple ``(new_noise, new_states, act)`` where ``act`` is the
            softplus activation and ``new_states`` the leaky-integrated
            state. ``new_noise`` and ``new_states`` are cast to float32.
        """
        noise_std = float(np.sqrt(2.0 * self.alpha_noise) * self.sigma)
        # Sampled with the default (CPU) generator and then moved, so the
        # random stream does not depend on the target device.
        fresh = torch.normal(mean=0.0, std=noise_std, size=states0.shape).to(self.device)
        new_noise = (1.0 - self.alpha_noise) * OUnoise + fresh

        incs = torch.matmul(stims, self.in2rnn)
        act = self.activation(incs + torch.matmul(states0, self.kernel) + self.bias + new_noise)
        new_states = (1.0 - self.alpha) * states0 + self.alpha * act

        return new_noise.float(), new_states.float(), act

    def shared_raster_out(self, config, shared_act):
        """Read out from externally supplied activations.

        Starting from ``initS``, the state is leaky-integrated over
        ``shared_act[t]`` for ``t`` in ``range(config['tdim'])`` and the
        stacked states are passed through the linear readout.

        No noise is sampled here: the activations are given, so there is
        nothing for it to perturb, and the global RNG state is left
        untouched.

        Args:
            config: Mapping providing ``tdim`` and ``num_rnn``.
            shared_act: Tensor indexed by time along its first axis.

        Returns:
            Readout ``y_hat`` with the time and remaining leading axes
            flattened into rows of width ``num_rnn_out``.
        """
        states = self.initS
        output_rnn_ta = []
        for t in range(config['tdim']):
            states = (1.0 - self.alpha) * states + self.alpha * shared_act[t, :]
            output_rnn_ta.append(states)

        final_rnn_outputs = torch.reshape(torch.stack(output_rnn_ta), (-1, config['num_rnn']))
        y_hat = torch.matmul(final_rnn_outputs.float(), self.w_rnn_out) + self.b_rnn_out
        return y_hat

    def forward(self, config, imageStims):
        """Run the network for ``config['tdim']`` timesteps.

        The input at each step is ``imageStims`` inside
        ``config['stimPeriod']`` (half-open interval), otherwise the
        fixation-only input while ``t < config['fixationPeriod'][1]``,
        otherwise all zeros.

        Args:
            config: Mapping providing ``tdim``, ``stimPeriod``,
                ``fixationPeriod`` and ``num_rnn``.
            imageStims: Stimulus presented during the stimulus period.

        Returns:
            Tuple ``(y_hat, final_rnn_outputs, final_rnn_inputs,
            rec_rnn_outputs)``: the readout, the states and the
            activations (both reshaped to rows of width ``num_rnn``), and
            the inputs stacked along a new leading time axis.
        """
        # Reinitialize state
        states = self.initS.to(self.device)
        noise = self.initNoise.to(self.device)

        stim_start, stim_end = config['stimPeriod'][0], config['stimPeriod'][1]
        fixation_end = config['fixationPeriod'][1]

        in_rnn_ta = []
        output_rnn_ta = []
        rec_rnn_ta = []

        for t in range(config['tdim']):
            # Set time-specific inputs
            if stim_start <= t < stim_end:
                stims = imageStims
            elif t < fixation_end:
                stims = self.onlyFixStim
            else:
                stims = self.zeroStims

            noise, states, rec_act = self.call(stims, states, noise)

            in_rnn_ta.append(stims)
            output_rnn_ta.append(states)
            rec_rnn_ta.append(rec_act)

        final_rnn_inputs = torch.stack(in_rnn_ta)
        final_rnn_outputs = torch.reshape(torch.stack(output_rnn_ta), (-1, config['num_rnn']))
        rec_rnn_outputs = torch.reshape(torch.stack(rec_rnn_ta), (-1, config['num_rnn']))

        y_hat = torch.matmul(final_rnn_outputs, self.w_rnn_out) + self.b_rnn_out
        return (y_hat, final_rnn_outputs, final_rnn_inputs, rec_rnn_outputs)
