from typing import Callable
import numpy as np
import torch
import torch.nn as nn
from spikingjelly.activation_based import layer, functional, surrogate
from spikingjelly.activation_based.neuron import BaseNode


def gen_ortho_matrix(dim, rng=None):
    '''Generate a random ``dim x dim`` orthogonal matrix.

    The matrix is the product of ``dim - 1`` (negated) Householder reflections,
    the n-th acting on the trailing ``dim - n + 1`` coordinates, each built from
    a Gaussian random vector.

    Args:
        dim: Size of the square matrix.
        rng: Optional random generator exposing ``normal(size=...)``; when
            ``None`` the global ``np.random`` state is used.

    Returns:
        ``np.ndarray`` of shape ``(dim, dim)`` with orthonormal rows/columns.
    '''
    normal = np.random.normal if rng is None else rng.normal
    H = np.eye(dim)
    for n in range(1, dim):
        x = normal(size=(dim - n + 1,))
        # Sign taken from x[0]; fall back to +1 when x[0] is exactly 0, otherwise
        # D == 0 would zero the reflection below and make H singular.
        D = np.sign(x[0]) if x[0] != 0 else 1.
        x[0] += D * np.sqrt((x * x).sum())
        # Householder reflection I - 2 x x^T / |x|^2, scaled by -D.
        Hx = -D * (np.eye(dim - n + 1) - 2. * np.outer(x, x) / (x * x).sum())
        mat = np.eye(dim)
        mat[n - 1:, n - 1:] = Hx
        H = np.dot(H, mat)
    return H


class CustomNeuron(BaseNode):
    """Spiking neuron with a learnable per-neuron leak factor.

    Each step the membrane potential is updated as
    ``v <- alpha * v + (1 - alpha) * x`` with ``alpha = sigmoid(tau_m)``; a spike
    is emitted through ``surrogate_function(v - v_threshold)`` and ``v`` is then
    reset via ``BaseNode.neuronal_reset``.

    Args:
        n_rnn: Number of neurons, i.e. length of the learnable ``tau_m`` vector.
            ``tau_m`` broadcasts against the last dimension of the input.
        tau_minitializer: ``'uniform'`` draws ``tau_m`` from U(low_m, high_m);
            ``'constant'`` fills it with ``low_m``.
        low_m, high_m: Bounds / value used by ``tau_minitializer``.
        v_threshold, v_reset, surrogate_function, detach_reset, step_mode,
        backend, store_v_seq: Forwarded unchanged to ``BaseNode``.

    Raises:
        ValueError: if ``tau_minitializer`` is neither ``'uniform'`` nor
            ``'constant'``.
    """

    def __init__(self,
                 n_rnn=256,
                 tau_minitializer='constant',
                 low_m=0, high_m=4,
                 v_threshold: float = 1.,
                 v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = False,
                 step_mode='s',
                 backend='torch',
                 store_v_seq: bool = False
                 ):
        if tau_minitializer not in ('uniform', 'constant'):
            # Any other value would leave tau_m as uninitialised memory.
            raise ValueError(f"unknown tau_minitializer: {tau_minitializer!r}")

        super().__init__(v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)

        # One learnable leak parameter per neuron (alpha = sigmoid(tau_m)).
        self.tau_m = nn.Parameter(torch.empty(n_rnn))
        if tau_minitializer == 'uniform':
            nn.init.uniform_(self.tau_m, low_m, high_m)
        else:
            nn.init.constant_(self.tau_m, low_m)

        self.v_threshold = v_threshold

    @property
    def supported_backends(self):
        if self.step_mode == 's':
            return ('torch',)
        elif self.step_mode == 'm':
            # NOTE(review): no cupy kernel is defined in this class; confirm what
            # BaseNode does when backend='cupy' is selected.
            return ('torch', 'cupy')
        else:
            raise ValueError(self.step_mode)

    def extra_repr(self):
        with torch.no_grad():
            # neuronal_charge computes v <- alpha * v + (1 - alpha) * x with
            # alpha = sigmoid(tau_m); as a leaky integrator v <- v + (x - v) / tau
            # this corresponds to tau = 1 / (1 - alpha) = 1 / sigmoid(-tau_m).
            tau = 1. / torch.sigmoid(-self.tau_m)
        return super().extra_repr() + f', tau={tau}'

    def neuronal_charge(self, x: torch.Tensor):
        """Leaky integration: v <- alpha * v + (1 - alpha) * x, alpha = sigmoid(tau_m)."""
        alpha = self.tau_m.sigmoid()
        self.v = self.v * alpha + (1 - alpha) * x

    def neuronal_fire(self):
        """Spike via the surrogate Heaviside of (v - v_threshold)."""
        return self.surrogate_function(self.v - self.v_threshold)

    def single_step_forward(self, x: torch.Tensor):
        """One step: materialise v, charge, fire, reset; returns the spike tensor."""
        self.v_float_to_tensor(x)
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike


class customSNN(nn.Module):
    """Recurrent spiking network whose neurons integrate input over dendritic branches.

    Each of the ``num_rnn`` neurons owns ``num_branch`` branches. One dense layer maps
    ``[stimulus, recurrent state, zero padding]`` to one current per (neuron, branch);
    a fixed random mask gives each branch of a neuron a disjoint slice of that input.
    Branch currents are low-pass filtered with ``beta = sigmoid(tau_n)``, summed per
    neuron, perturbed with filtered Gaussian noise and fed to a ``CustomNeuron``
    layer. The recurrent state is a leaky trace of the spikes, and a linear readout
    maps it to ``num_rnn_out`` outputs.

    Args:
        config: Hyper-parameter dict. The constructor reads ``seed``, ``rng`` (needs
            ``randn`` and ``normal``, e.g. ``np.random.RandomState``), ``num_input``,
            ``num_rnn``, ``num_branch``, ``num_rnn_out``, ``tau_ninitializer``,
            ``low_n``, ``high_n``, ``batch_size``, ``image_shape``, ``fixationInput``,
            ``tau_minitializer``, ``low_m``, ``high_m`` and ``neuron_thr``.
        device: Device that per-call tensors are moved to.

    Raises:
        ValueError: if ``config['tau_ninitializer']`` is neither ``'uniform'`` nor
            ``'constant'``.
    """

    def __init__(self, config, device, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.config = config

        self.seed = config['seed']
        self.rng = config['rng']

        # Leak of the recurrent spike trace, leak of the noise process, noise scale.
        self.alpha = 0.01
        self.alpha_noise = 0.5
        self.sigma = 0.05

        n_input = config['num_input']
        self.n_input = n_input
        n_rnn = config['num_rnn']
        self.n_rnn = n_rnn
        n_branch = config['num_branch']
        self.n_branch = n_branch
        n_out = config['num_rnn_out']
        self.n_out = n_out
        # Fewest zero columns making the dense input width a multiple of n_branch,
        # so create_mask can split it into equal per-branch slices.
        pad = (-(n_input + n_rnn)) % n_branch
        self.pad = pad

        # Initial scales of the dense layer parameters.
        self.bias_start = 0.0
        self.w_in_start = 1.0
        self.w_rec_start = 1.0

        self.dense_in = n_input + n_rnn + pad
        self.dense_out = n_rnn * n_branch
        self.dense = nn.Linear(self.dense_in, self.dense_out)

        # Un-branched initial weights, one column per neuron: scaled Gaussian input
        # weights, orthogonal recurrent weights, zeros for the padding columns.
        w_in2rnn0 = np.float32(self.rng.randn(n_input, n_rnn) / np.sqrt(n_input) * self.w_in_start)
        w_rec0 = np.float32(self.w_rec_start * gen_ortho_matrix(n_rnn, rng=self.rng))
        w_pad = np.zeros((pad, n_rnn), dtype=np.float32)
        w_neuron = np.concatenate([w_in2rnn0, w_rec0, w_pad], axis=0)  # (dense_in, n_rnn)
        # _simulate reshapes the dense output to (n_rnn, n_branch): output row
        # i * n_branch + j is branch j of neuron i (the layout create_mask uses too).
        # np.repeat (not np.tile) places the n_branch copies of neuron i's column
        # consecutively to match. The branch masks of a neuron partition the inputs,
        # so the masked branches sum back to that neuron's full initial weight row.
        dense_w = torch.from_numpy(np.repeat(w_neuron, n_branch, axis=1).transpose())
        self.dense.weight.data.copy_(dense_w)
        self.dense.bias.data.fill_(self.bias_start)

        # Fixed random assignment of inputs to branches, re-applied on every call.
        # NOTE(review): a plain attribute drawn from the global torch RNG. It is not
        # in state_dict(), so a reloaded model regenerates a different mask.
        self.mask = self.create_mask()

        # Linear readout from the recurrent state, initialised to exactly zero.
        self.out = layer.Linear(n_rnn, n_out, bias=True)
        self.out.weight.data.zero_()
        self.out.bias.data.zero_()

        # Per-branch filter parameters of the dendritic currents (beta = sigmoid(tau_n)).
        self.tau_n = nn.Parameter(torch.empty(n_rnn, n_branch))
        tau_ninitializer = config['tau_ninitializer']
        if tau_ninitializer == 'uniform':
            nn.init.uniform_(self.tau_n, config['low_n'], config['high_n'])
        elif tau_ninitializer == 'constant':
            nn.init.constant_(self.tau_n, config['low_n'])
        else:
            # torch.empty leaves tau_n uninitialised; refuse rather than use garbage.
            raise ValueError(f"unknown tau_ninitializer: {tau_ninitializer!r}")

        # Initial recurrent state and noise, broadcast over the batch in _simulate.
        self.initS = torch.zeros(n_rnn)
        self.initNoise = torch.zeros(n_rnn)

        # Constant stimuli used outside the stimulus period: all zeros, or zeros on
        # the first prod(image_shape) channels and fixationInput on the rest.
        n_image = np.prod(config['image_shape'])
        self.zeroStims = torch.zeros(config['batch_size'], n_input)
        self.onlyFixStim = torch.concat([
            torch.zeros(config['batch_size'], n_image),
            np.float32(config['fixationInput']) * torch.ones(config['batch_size'],
                                                             n_input - n_image)], dim=1)

        self.activation = CustomNeuron(n_rnn=n_rnn,
                                       tau_minitializer=config['tau_minitializer'],
                                       low_m=config['low_m'], high_m=config['high_m'],
                                       v_threshold=config['neuron_thr'],
                                       v_reset=0.,
                                       surrogate_function=surrogate.Sigmoid(),
                                       detach_reset=True)

        functional.set_step_mode(self, step_mode='s')
        functional.set_backend(self, backend='torch')

    def create_mask(self):
        """Build the binary ``(n_rnn * n_branch, dense_in)`` branch connectivity mask.

        For each neuron, a fresh ``torch.randperm`` of the input columns is split into
        ``n_branch`` disjoint chunks. Branch ``j`` of neuron ``i`` (row
        ``i * n_branch + j``) receives chunk ``j``.
        """
        input_size = self.n_input + self.n_rnn + self.pad
        mask = torch.zeros(self.n_rnn * self.n_branch, input_size)

        for i in range(self.n_rnn):
            seq = torch.randperm(input_size)
            # Each branch of neuron i gets a disjoint random 1/n_branch of the inputs.
            for j in range(self.n_branch):
                mask[i * self.n_branch + j, seq[j * input_size // self.n_branch:(j + 1) * input_size // self.n_branch]] = 1
        return mask

    def apply_mask(self):
        """Zero the dense weights outside each branch's inputs (in place, untracked)."""
        self.dense.weight.data.mul_(self.mask.to(self.dense.weight.device))

    def _simulate(self, config, imageStims, damage_neuron_idx=None, apply_damage=False):
        """Run ``config['tdim']`` steps; shared body of ``forward`` and ``damage``.

        When ``apply_damage`` is true, the spikes of ``damage_neuron_idx`` (any index
        valid for the neuron axis) are forced to zero before they enter the
        recurrent state and the outputs.
        """
        self.apply_mask()

        batch_size = config['batch_size']
        padding = torch.zeros(batch_size, self.pad).to(self.device)
        d_input = torch.zeros(batch_size, self.n_rnn, self.n_branch).to(self.device)
        beta = torch.sigmoid(self.tau_n).to(self.device)

        states = self.initS.to(self.device).unsqueeze(0).expand(batch_size, -1)
        noise = self.initNoise.to(self.device).unsqueeze(0).expand(batch_size, -1)

        # Loop-invariant values, computed once per call.
        fix_stims = self.onlyFixStim.to(self.device)
        zero_stims = self.zeroStims.to(self.device)
        noise_std = np.sqrt(2.0 * self.alpha_noise) * self.sigma
        stim_start, stim_end = config['stimPeriod'][0], config['stimPeriod'][1]
        fix_end = config['fixationPeriod'][1]

        # Reset the neuron membrane once per sequence, not once per step, so that it
        # integrates across time steps with its learnable leak tau_m.
        functional.reset_net(self)

        in_rnn_ta = []
        output_rnn_ta = []
        spiking_rnn_ta = []
        for time in range(config['tdim']):
            # Time-specific input: image, fixation-only, or nothing.
            if stim_start <= time < stim_end:
                stims = imageStims
            elif time < fix_end:
                stims = fix_stims
            else:
                stims = zero_stims

            inputs = torch.cat([stims.float(), states, padding], 1)
            # Low-pass filter per branch, then sum the branches of each neuron.
            d_input = beta * d_input + (1 - beta) * self.dense(inputs).reshape(-1, self.n_rnn, self.n_branch)
            n_input = d_input.sum(dim=2)

            noise = ((1.0 - self.alpha_noise) * noise +
                     torch.normal(size=states.shape, mean=0, std=noise_std).to(self.device))
            rec_spiking = self.activation(n_input + noise)
            if apply_damage:
                # Out-of-place: the neuron keeps using the original spike tensor for
                # its membrane reset, which autograd may need.
                rec_spiking = rec_spiking.clone()
                rec_spiking[:, damage_neuron_idx] = 0
            states = (1.0 - self.alpha) * states + self.alpha * rec_spiking

            in_rnn_ta.append(stims)
            output_rnn_ta.append(states)
            spiking_rnn_ta.append(rec_spiking)

        # Leave the neuron in its reset state when the call returns.
        functional.reset_net(self)

        final_rnn_inputs = torch.stack(in_rnn_ta)
        final_rnn_outputs = torch.reshape(torch.stack(output_rnn_ta), (-1, config['num_rnn']))
        spiking_rnn_outputs = torch.reshape(torch.stack(spiking_rnn_ta), (-1, config['num_rnn']))

        y_hat = self.out(final_rnn_outputs)
        return (y_hat, final_rnn_outputs, final_rnn_inputs, spiking_rnn_outputs)

    def forward(self, config, imageStims):
        """Simulate the network over ``config['tdim']`` steps.

        Args:
            config: Reads ``batch_size``, ``tdim``, ``stimPeriod``, ``fixationPeriod``
                and ``num_rnn``.
            imageStims: Stimulus used during ``stimPeriod``. It is concatenated with
                the ``(batch_size, num_rnn)`` state along dim 1, so it must be
                ``(batch_size, num_input)``.

        Returns:
            Tuple ``(y_hat, rnn_outputs, rnn_inputs, spikes)``:
                ``y_hat``: ``(tdim * batch_size, num_rnn_out)``.
                ``rnn_outputs`` and ``spikes``: ``(tdim * batch_size, num_rnn)``.
                ``rnn_inputs``: ``(tdim, batch_size, num_input)``.
        """
        return self._simulate(config, imageStims)

    def damage(self, config, imageStims, damage_neuron_idx):
        """Same as ``forward`` but with the spikes of ``damage_neuron_idx`` silenced."""
        return self._simulate(config, imageStims, damage_neuron_idx, apply_damage=True)
