"""Implements a full residual block around a black box layer.

Configurable options include:
normalization position: prenorm or postnorm
normalization type: batchnorm, layernorm etc.
subsampling/pooling
residual options: feedforward, residual, affine scalars, depth-dependent scaling, etc.
"""
import random
from functools import partial
import torch
from torch import nn
from src.models.nn import LinearActivation, Activation, DropoutNd

from src.models.nn import Normalization, StochasticDepth, DropoutNd
from src.models.sequence import SequenceModule
from src.models.sequence.modules.pool import registry as pool_registry
from src.models.nn.residual import registry as residual_registry
import src.utils as utils
import src.utils.registry as registry
from spikingjelly.clock_driven.neuron import MultiStepLIFNode, MultiStepIFNode
from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer
from torch.autograd import Function


def create_random_mask(size, sparsity=0.9):
    """Return a random binary mask with a given fraction of zeros.

    Args:
        size: Shape of the mask. Anything ``torch.zeros`` accepts as a shape
            works (``torch.Size``, tuple, list or int).
        sparsity: Fraction of entries left at 0; must lie in ``[0, 1]``.

    Returns:
        A tensor of shape ``size`` (default dtype of ``torch.zeros``) in which
        ``floor(numel * (1 - sparsity))`` randomly chosen entries are 1 and
        every other entry is 0.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]``. A value above 1
            would otherwise produce a negative slice bound and silently
            select entries from the wrong end of the permutation.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

    mask = torch.zeros(size)
    # Taking the count from the tensor itself means ``size`` does not have to
    # be a ``torch.Size`` (plain tuples have no ``numel``).
    total = mask.numel()

    # Round away float noise before truncating: 10 * (1 - 0.9) evaluates to
    # 0.9999999999999998, which a bare int() would turn into 0 instead of 1.
    num_non_zero = int(round(total * (1.0 - sparsity), 6))

    # Randomly select the flat positions to switch on.
    non_zero_indices = torch.randperm(total)[:num_non_zero]
    mask.view(-1)[non_zero_indices] = 1

    return mask

class SequenceResidualBlock(SequenceModule):
    """Flexible residual block design. See model.py for meaning of options.

    The block wraps a black-box sequence layer. In ``forward`` the layer's
    input and output are turned into binary spikes by comparing them with
    uniform noise; gradients flow through the clamped pre-spike values via
    ``Replace``. A linear projection and GELU follow, then the optional
    residual connection, post-norm and pooling.
    """

    class Replace(Function):
        """Straight-through substitution used for the spike encoding.

        The forward pass returns ``z1_r`` and ignores the value of ``z1``.
        The backward pass hands the incoming gradient unchanged to both
        inputs, so ``z1`` receives a gradient as if it had been the output.
        """

        @staticmethod
        def forward(ctx, z1, z1_r):
            return z1_r

        @staticmethod
        def backward(ctx, grad):
            return (grad, grad)

    def __init__(
        self,
        d_input,
        i_layer=None,  # Only needs to be passed into certain residuals like Decay
        prenorm=True,
        bidirectional=False,
        dropout=0.0,
        tie_dropout=False,
        transposed=False,
        layer=None,  # Config for black box module
        residual=None,  # Config for residual function
        norm=None,  # Config for normalization layer
        pool=None,
        drop_path=0.,
    ):
        """Build the block's sub-modules.

        Args:
            d_input: Input feature dimension; forwarded to the layer, the
                residual and (for pre-norm) the normalization.
            i_layer: Layer index, forwarded to the residual constructor.
            prenorm: Apply normalization before the layer if True, after the
                residual otherwise.
            bidirectional: Also instantiate a second ("reverse") layer and a
                linear projection that merges both outputs.
            dropout: Dropout probability applied to the layer output before
                the residual; 0 disables dropout.
            tie_dropout: Use ``DropoutNd`` instead of ``nn.Dropout``.
            transposed: Forwarded to normalization, pooling, ``DropoutNd``
                and the residual function.
            layer: Config for the black-box layer.
            residual: Config for the residual function, or None for no
                residual.
            norm: Normalization name (str), config mapping, or None.
            pool: Config for the pooling module.
            drop_path: Stochastic-depth probability; 0 disables it.
        """
        super().__init__()

        self.i_layer = i_layer
        self.d_input = d_input
        self.prenorm = prenorm
        self.bidirectional = bidirectional
        self.transposed = transposed
        self.activation = Activation('gelu')
        self.layer = utils.instantiate(registry.layer, layer, d_input)

        if self.bidirectional:
            self.reverse_layer = utils.instantiate(registry.layer, layer, d_input)
            self.bidirectional_linear = nn.Linear(2*self.layer.d_output, self.layer.d_output)
            # NOTE(review): the two projections below are not used by forward()
            # in bidirectional mode; they are kept so the set of registered
            # parameters (and therefore checkpoints) stays the same.
            self.bidirectional_linear1 = nn.Linear(self.layer.d_output, int(self.layer.d_output/2))
            self.bidirectional_linear2 = nn.Linear(self.layer.d_output, int(self.layer.d_output/2))
        else:
            self.bidirectional_linear1 = nn.Linear(self.layer.d_output, self.layer.d_output)

        # Residual
        # d_residual is the output dimension after residual
        if residual is None:
            self.residual = None
            self.d_residual = self.layer.d_output
        else:
            self.residual = utils.instantiate(residual_registry, residual, i_layer, d_input, self.layer.d_output)
            self.d_residual = self.residual.d_output

        # Normalization
        d_norm = d_input if self.prenorm else self.d_residual
        # We don't use config to directly instantiate since Normalization has some special cases
        if norm is None:
            self.norm = None
        elif isinstance(norm, str):
            self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm)
        else:
            self.norm = Normalization(d_norm, transposed=self.transposed, **norm)

        # Pool
        self.pool = utils.instantiate(pool_registry, pool, self.d_residual, transposed=self.transposed)

        # Dropout
        dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout
        self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity()

        # Stochastic depth
        self.drop_path = StochasticDepth(drop_path, mode='row') if drop_path > 0.0 else nn.Identity()

        # Built once here instead of on every forward() call. forward() does
        # not use it; the attribute is kept for code that accesses it.
        # Registered last so the order of the other sub-modules is unchanged.
        self.PoissonEncoder = encoding.PoissonEncoder()

    @property
    def d_output(self):
        """Output dimension: the pool's if a pool exists, else ``d_residual``."""
        return self.pool.d_output if self.pool is not None else self.d_residual

    @property
    def d_state(self):
        """State dimension of the wrapped layer."""
        return self.layer.d_state

    @property
    def state_to_tensor(self):
        """The wrapped layer's ``state_to_tensor``."""
        return self.layer.state_to_tensor

    def default_state(self, *args, **kwargs):
        """Delegate to the wrapped layer's ``default_state``."""
        return self.layer.default_state(*args, **kwargs)

    def _stochastic_spikes(self, rates, noise):
        """Binarize ``rates`` against ``noise`` with a straight-through gradient.

        An entry spikes (1.0) where ``rates > noise`` and is 0.0 elsewhere.
        The backward pass routes the gradient through ``rates`` clamped to
        ``[0, 1]``, so entries outside that range receive no gradient.
        """
        spikes = torch.where(rates > noise, torch.ones_like(rates), torch.zeros_like(rates))
        return self.Replace.apply(torch.clamp(rates, 0, 1), spikes)

    def forward(self, x, state=None, **kwargs):
        """Run the block on a full sequence.

        Args:
            x: Input tensor (layout is whatever the wrapped layer expects;
                presumably (batch, length, d_input) when not transposed —
                TODO confirm).
            state: Optional state forwarded to the wrapped layer(s).
            **kwargs: Forwarded to the wrapped layer(s).

        Returns:
            Tuple ``(y, state)`` with the block output and the ``state``
            argument as it was passed in.
        """
        y = x

        # Pre-norm
        if self.norm is not None and self.prenorm:
            y = self.norm(y)

        # Both noise tensors are always drawn, in this order, so the global
        # RNG stream does not depend on whether the block is bidirectional.
        random_uniform = torch.rand_like(y)
        random_uniform_2 = torch.rand_like(y)

        # Spike-encode the (normalized) input.
        y = self._stochastic_spikes(y, random_uniform)

        # Black box layer
        # NOTE(review): new_state is discarded and the incoming `state` is
        # returned below — confirm this is intended.
        y_for, new_state = self.layer(y, state=state, **kwargs)

        # NOTE(review): the output spikes below reuse `random_uniform`, which
        # has the shape of the block input, so the layer output must be
        # broadcast-compatible with it (e.g. d_output == d_input).
        if self.bidirectional:
            # NOTE(review): the reverse branch encodes the raw input `x`
            # (not the pre-normalized tensor) and does not flip the sequence.
            y = self._stochastic_spikes(x, random_uniform_2)
            y_rev, _ = self.reverse_layer(y, state=state, **kwargs)

            y_for = self._stochastic_spikes(y_for, random_uniform)
            y_rev = self._stochastic_spikes(y_rev, random_uniform)

            # Merge both directions along the last dimension and project back.
            y = torch.cat([y_for, y_rev], dim=-1)
            y = self.bidirectional_linear(y)
            y = self.activation(y)
        else:
            y = self._stochastic_spikes(y_for, random_uniform)
            y = self.bidirectional_linear1(y)
            y = self.activation(y)

        # Residual can be turned off for MNIST AND SC10: MIGHT GIVE BETTER RESULTS
        if self.residual is not None:
            y = self.residual(x, self.drop_path(self.drop(y)), self.transposed)

        # Post-norm
        if self.norm is not None and not self.prenorm:
            y = self.norm(y)

        # Pool can be turned off for MNIST AND SC10: MIGHT GIVE BETTER RESULTS
        if self.pool is not None:
            y, _ = self.pool(y)

        return y, state

    def step(self, x, state, **kwargs):
        """Run the block for a single step (unidirectional only).

        NOTE(review): unlike forward(), this path applies no spike encoding,
        linear projection, activation, dropout or stochastic depth.

        Args:
            x: Input for the current step.
            state: State passed to the wrapped layer's ``step``.
            **kwargs: Forwarded to the wrapped layer's ``step``.

        Returns:
            Tuple ``(y, state)`` with the step output and the state returned
            by the wrapped layer.
        """
        assert not self.bidirectional
        y = x

        # Pre-norm
        if self.norm is not None and self.prenorm:
            y = self.norm.step(y)

        # Black box layer
        y, state = self.layer.step(y, state, **kwargs)

        # Residual
        if self.residual is not None: y = self.residual(x, y, transposed=False) # NOTE this would not work with concat residual function (catformer)

        # Post-norm
        if self.norm is not None and not self.prenorm:
            y = self.norm.step(y)

        # Pool
        if self.pool is not None: y, _ = self.pool(y)

        return y, state
