from typing import Callable
import numpy as np
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, surrogate
from spikingjelly.activation_based.neuron import BaseNode


class CustomNeuron(BaseNode):
    """Leaky spiking neuron with a learnable per-unit decay and firing threshold.

    Every one of the ``n_rnn`` units owns a learnable decay logit ``tau_m`` and a
    learnable threshold ``v_threshold``. On each step the membrane is updated as
    ``v <- alpha * v + (1 - alpha) * x`` with ``alpha = sigmoid(tau_m)``. A spike is
    produced by ``surrogate_function(v - v_threshold)``, and the membrane is then
    reset by the inherited ``neuronal_reset``.

    Args:
        n_rnn: number of units (length of ``tau_m`` and ``v_threshold``).
        tau_minitializer: ``'uniform'`` draws ``tau_m`` from U(low_m, high_m);
            ``'constant'`` fills it with ``low_m``.
        low_m, high_m: bounds (uniform) or fill value (constant, ``low_m`` only)
            used to initialise ``tau_m``.
        v_threshold: initial value of every unit's learnable threshold.
        v_reset, surrogate_function, detach_reset, step_mode, backend,
        store_v_seq: forwarded unchanged to ``BaseNode.__init__``.

    Raises:
        ValueError: if ``tau_minitializer`` is neither ``'uniform'`` nor ``'constant'``.
    """

    def __init__(self,
                 n_rnn=256,
                 tau_minitializer='constant',
                 low_m=0, high_m=4,
                 v_threshold: float = 1.,
                 v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False,
                 step_mode='s',
                 backend='torch',
                 store_v_seq: bool = False
                 ):
        # Reject unknown initialisers up front instead of leaving tau_m as
        # uninitialised memory.
        if tau_minitializer not in ('uniform', 'constant'):
            raise ValueError(
                f"tau_minitializer must be 'uniform' or 'constant', got {tau_minitializer!r}")

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)

        # Per-unit decay logit; alpha = sigmoid(tau_m) is the share of v kept each step.
        self.tau_m = nn.Parameter(torch.empty(n_rnn))
        if tau_minitializer == 'uniform':
            nn.init.uniform_(self.tau_m, low_m, high_m)
        else:  # 'constant'
            nn.init.constant_(self.tau_m, low_m)

        # Override the scalar threshold handed to BaseNode with a learnable per-unit vector.
        self.v_threshold = nn.Parameter(torch.ones(n_rnn) * v_threshold)

    @property
    def supported_backends(self):
        # NOTE(review): this class defines no cupy kernel of its own; confirm that
        # multi-step 'cupy' really does something different from 'torch' here.
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def extra_repr(self):
        # The charge step v <- alpha*v + (1-alpha)*x equals v <- v + (x - v)/tau
        # with tau = 1/(1 - alpha) = 1/sigmoid(-tau_m). sigmoid(-tau_m) is used
        # directly so that large tau_m does not produce 1/0.
        with torch.no_grad():
            tau = 1. / torch.sigmoid(-self.tau_m)
        return super().extra_repr() + f', tau={tau}'

    def neuronal_charge(self, x: torch.Tensor):
        """Leaky integration: blend the previous membrane with the input by alpha."""
        alpha = self.tau_m.sigmoid()
        self.v = self.v * alpha + (1 - alpha) * x

    def neuronal_fire(self):
        """Return surrogate-differentiable spikes from the distance to the per-unit threshold."""
        return self.surrogate_function(self.v - self.v_threshold)

    def single_step_forward(self, x: torch.Tensor):
        """Run one step: ensure v is a tensor, charge, fire, reset; return the spikes."""
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike


def gen_ortho_matrix(dim, rng=None):
    """Return a random ``dim x dim`` orthogonal matrix.

    The matrix is built as a product of ``dim - 1`` embedded Householder
    reflections. Each reflection is generated from a fresh standard-normal
    vector of shrinking length.

    Args:
        dim: size of the square matrix.
        rng: object exposing ``normal(size=...)`` (e.g. ``np.random.RandomState``).
            When ``None``, the global ``np.random`` generator is used.

    Returns:
        A float64 ndarray of shape ``(dim, dim)``.
    """
    sampler = np.random if rng is None else rng
    result = np.eye(dim)
    for step in range(1, dim):
        length = dim - step + 1
        vec = sampler.normal(size=(length,))
        # Take the sign from the leading entry so that adding the norm never
        # cancels, which avoids round-off loss.
        sgn = np.sign(vec[0])
        vec[0] += sgn * np.sqrt((vec * vec).sum())
        reflector = -sgn * (np.eye(length) - 2. * np.outer(vec, vec) / (vec * vec).sum())
        # Embed the reflection in the lower-right (length x length) corner.
        embedded = np.eye(dim)
        embedded[step - 1:, step - 1:] = reflector
        result = np.dot(result, embedded)
    return result


class customSNN(nn.Module):
    """Recurrent spiking network driven by time-gated stimuli.

    At every time step the input projection (``in2rnn``, no bias), the
    recurrent projection of the previous smoothed state (``rec``, with bias)
    and an OU-style noise term are summed and fed to a ``CustomNeuron`` layer.
    The resulting spikes are low-pass filtered into ``states`` with rate
    ``alpha``. A linear readout (``out``) maps the states to the outputs.

    Args:
        config: dict. Keys read here: ``seed``, ``rng`` (object with ``randn``
            and ``normal``, e.g. ``np.random.RandomState``), ``num_input``,
            ``num_rnn``, ``num_rnn_out``, ``tau_minitializer``, ``low_m``,
            ``high_m``, ``neuron_thr``, ``batch_size``, ``image_shape`` and
            ``fixationInput``.
        device: torch device on which all layers and constant inputs are placed.
    """

    def __init__(self, config, device, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.seed = config['seed']
        self.rng = config['rng']
        n_input = config['num_input']
        n_rnn = config['num_rnn']

        self.alpha = 0.01        # trace update: s <- (1 - alpha) * s + alpha * spike
        self.alpha_noise = 0.5   # noise update: n <- (1 - alpha_noise) * n + N(0, sqrt(2*alpha_noise)*sigma)
        self.sigma = 0.05
        self.bias_start = 0.0
        self.w_in_start = 1.0
        self.w_rec_start = 1.0

        # Input weights ~ N(0, 1) / sqrt(n_input) * w_in_start, drawn from the config RNG.
        self.in2rnn = layer.Linear(n_input, n_rnn, bias=False)
        w_in2rnn0 = np.float32(self.rng.randn(n_input, n_rnn) / np.sqrt(n_input) * self.w_in_start)
        self.in2rnn.weight.data.copy_(torch.from_numpy(w_in2rnn0.transpose()))

        # Recurrent weights: random orthogonal matrix scaled by w_rec_start; constant bias.
        self.rec = layer.Linear(n_rnn, n_rnn, bias=True)
        w_rec0 = np.float32(self.w_rec_start * gen_ortho_matrix(n_rnn, rng=self.rng))
        self.rec.weight.data.copy_(torch.from_numpy(w_rec0.transpose()))
        b_rec0 = np.float32(np.ones(n_rnn) * self.bias_start)
        self.rec.bias.data.copy_(torch.from_numpy(b_rec0))

        self.in2rnn = self.in2rnn.to(self.device)
        self.rec = self.rec.to(self.device)

        # The neuron owns per-unit parameters (tau_m, v_threshold). Place it on
        # the same device as the linear layers so forward() works on a non-CPU
        # device without an extra model.to(device) call.
        self.activation = CustomNeuron(n_rnn=n_rnn,
                                       tau_minitializer=config['tau_minitializer'],
                                       low_m=config['low_m'], high_m=config['high_m'],
                                       v_threshold=config['neuron_thr'],
                                       v_reset=0.,
                                       surrogate_function=surrogate.Sigmoid(),
                                       detach_reset=True).to(self.device)

        # Readout starts at exactly zero (weights and bias).
        self.out = layer.Linear(config['num_rnn'], config['num_rnn_out'], bias=True)
        self.out.weight.data.zero_()
        self.out.bias.data.zero_()
        self.out = self.out.to(device)

        # Initial trace state and initial noise are both all zeros (kept on CPU and
        # moved to self.device when used).
        self.initS = torch.zeros(config['batch_size'], n_rnn)
        self.initNoise = torch.zeros(config['batch_size'], n_rnn)

        # Constant inputs: all-zero, and "fixation only" (image channels zero, the
        # remaining n_input - prod(image_shape) channels set to fixationInput).
        self.zeroStims = torch.zeros(config['batch_size'], n_input).to(device)
        self.onlyFixStim = torch.concat([
            torch.zeros(config['batch_size'], np.prod(config['image_shape']), device=device),
            np.float32(config['fixationInput']) * torch.ones(config['batch_size'],
                                                             n_input - np.prod(config['image_shape']), device=device)],
            dim=1)
        self.config = config

        functional.set_step_mode(self, step_mode='s')
        functional.set_backend(self, backend='torch')

    def _ou_noise_step(self, noise, shape):
        """Advance the OU-style noise by one step and return the new noise tensor.

        The Gaussian sample is drawn without an explicit device (so on the default
        generator) and then moved to ``self.device``.
        """
        return (1.0 - self.alpha_noise) * noise + torch.normal(
            size=shape, mean=0, std=np.sqrt(2.0 * self.alpha_noise) * self.sigma).to(self.device)

    def call(self, stims, states0, OUnoise):
        """Advance the network by one time step.

        Args:
            stims: input for this step, fed to ``in2rnn``.
            states0: previous smoothed spike-trace state.
            OUnoise: previous noise value.

        Returns:
            ``(new_noise, new_states, rec_spiking)``. Noise and states are cast to
            float; ``rec_spiking`` is the raw ``CustomNeuron`` output.
        """
        states = states0

        incs = self.in2rnn(stims)
        new_noise = self._ou_noise_step(OUnoise, states.shape)
        rec_spiking = self.activation(incs + self.rec(states) + new_noise)
        new_states = (1.0 - self.alpha) * states + self.alpha * rec_spiking

        return new_noise.float(), new_states.float(), rec_spiking

    def forward(self, config, imageStims):
        """Unroll the network for ``config['tdim']`` steps.

        Input schedule per step: ``imageStims`` while ``stimPeriod[0] <= t < stimPeriod[1]``;
        otherwise ``onlyFixStim`` while ``t < fixationPeriod[1]``; otherwise ``zeroStims``.

        Returns:
            ``(y_hat, final_rnn_outputs, final_rnn_inputs, spiking_rnn_outputs)``:
            readout of the states, states reshaped to ``(-1, num_rnn)``, stimuli
            stacked along a new leading time axis, and spikes reshaped to
            ``(-1, num_rnn)``.
        """
        # Start from the (zero) initial state and noise.
        states = self.initS.to(self.device)
        noise = self.initNoise.to(self.device)

        stim_start, stim_end = config['stimPeriod'][0], config['stimPeriod'][1]
        fix_end = config['fixationPeriod'][1]

        in_rnn_ta = []
        output_rnn_ta = []
        spiking_rnn_ta = []

        for time in range(config['tdim']):
            if stim_start <= time < stim_end:
                stims = imageStims
            elif time < fix_end:
                stims = self.onlyFixStim
            else:
                stims = self.zeroStims

            new_noise, new_states, rec_spiking = self.call(stims, states, noise)

            # NOTE(review): this resets spikingjelly memories (the neuron membrane)
            # after *every* step. If the reset restores v to its initial value, tau_m
            # only scales the input and no membrane integration happens across steps.
            # Confirm this is intended rather than a single reset per sequence.
            functional.reset_net(self)

            in_rnn_ta.append(stims)
            output_rnn_ta.append(new_states)
            spiking_rnn_ta.append(rec_spiking)

            states = new_states
            noise = new_noise

        final_rnn_inputs = torch.stack(in_rnn_ta)
        final_rnn_outputs = torch.reshape(torch.stack(output_rnn_ta), (-1, config['num_rnn']))
        spiking_rnn_outputs = torch.reshape(torch.stack(spiking_rnn_ta), (-1, config['num_rnn']))

        y_hat = self.out(final_rnn_outputs)
        return (y_hat, final_rnn_outputs, final_rnn_inputs, spiking_rnn_outputs)

    def shared_raster_out(self, config, shared_raster):
        """Compute readout outputs from an externally supplied spike raster.

        ``shared_raster[t, :]`` replaces the network's own spikes at step ``t``.
        It presumably has a leading time axis of length ``config['tdim']``
        and must broadcast against the ``(batch_size, num_rnn)`` state; confirm
        with callers.
        """
        output_rnn_ta = []
        # Move the initial tensors to the model device, as forward() does;
        # otherwise the noise update mixes CPU and device tensors.
        states = self.initS.to(self.device)
        noise = self.initNoise.to(self.device)
        for time in range(config['tdim']):
            rec_spiking = shared_raster[time, :]
            # NOTE(review): noise never feeds into states here; it is still advanced
            # (consuming torch RNG draws) so the random stream stays unchanged.
            noise = self._ou_noise_step(noise, states.shape)
            states = (1.0 - self.alpha) * states + self.alpha * rec_spiking
            output_rnn_ta.append(states)
        final_rnn_outputs = torch.reshape(torch.stack(output_rnn_ta), (-1, config['num_rnn']))
        y_hat = self.out(final_rnn_outputs)
        return y_hat
