from enum import IntEnum, Enum

import torch
import torch.nn as nn
from torch import Tensor

# NOTE(review): this turns on autograd anomaly detection globally, as a side
# effect of merely importing this module. PyTorch documents anomaly detection
# as a debugging aid that slows down execution — confirm it is meant to stay
# enabled outside of debugging sessions.
torch.autograd.set_detect_anomaly(True)


class Dim(IntEnum):
    """
    Axis indices for tensors laid out as (batch, seq, features).

    Members are plain ints (IntEnum), so they can be passed straight to
    tensor methods that take a dimension, e.g. ``x.size(Dim.seq)``.
    """

    # Axis 0: batch axis.
    batch = 0
    # Axis 1: sequence axis.
    seq = 1
    # Axis 2: all features = n_units * unit_size.
    features = 2


class EGRUThresholdInit(Enum):
    """
    Selectable threshold-initialization options, identified by strings.

    The string values allow lookup from a config value, e.g.
    ``EGRUThresholdInit('zero-vector')``. The code that consumes these
    options is not in this section; the member names suggest scalar vs.
    vector thresholds with zero, random or constant starting values —
    confirm against the consumer.
    """

    zero_scalar = 'zero-scalar'
    zero_vector = 'zero-vector'
    rand_vector = 'random-vector'
    const_scalar = 'const-scalar'


class SpikeFunction(torch.autograd.Function):
    """
    Heaviside step in the forward pass, surrogate ("pseudo") derivative in the
    backward pass.

    Forward: ``z = H(u)`` — 1 where ``u > 0``, 0 where ``u < 0``; at exactly
    ``u == 0`` the output is ``u`` itself (i.e. zero).

    Backward: the true derivative of the step is zero almost everywhere, so
    the incoming gradient is multiplied by a surrogate ``dz/du`` instead:

    * ``exponential`` falsy (triangular):
      ``dampening_factor * max(0, 1 - pseudo_derivative_width * |u|)``
    * ``exponential`` truthy:
      ``dampening_factor * exp(-|u|)`` (``pseudo_derivative_width`` is
      ignored in this branch)

    Use as ``SpikeFunction.apply(u, dampening_factor, exponential,
    pseudo_derivative_width)``. Only ``u`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, inp, dampening_factor, exponential, pseudo_derivative_width):
        """
        Return the Heaviside step of ``inp`` and stash what backward needs.

        :param inp: pre-activation tensor ``u``.
        :param dampening_factor: multiplicative scale of the surrogate derivative.
        :param exponential: truthy selects the exponential surrogate, falsy the
            triangular one.
        :param pseudo_derivative_width: slope of the triangular surrogate; the
            surrogate is non-zero only for ``|u| < 1 / pseudo_derivative_width``.
        :return: tensor of the same shape as ``inp`` holding 0s and 1s.
        """
        # Non-tensor hyper-parameters live on ctx directly; save_for_backward
        # is reserved for tensors.
        ctx.dampening_factor = dampening_factor
        ctx.exponential = exponential
        ctx.pseudo_derivative_width = pseudo_derivative_width
        ctx.save_for_backward(inp)
        # ``values=inp`` makes the output at exactly 0 equal to inp, i.e. 0.
        return torch.heaviside(inp, inp)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Chain ``grad_output`` through the surrogate derivative.

        :param grad_output: gradient of the loss w.r.t. the spike output.
        :return: ``(grad w.r.t. inp, None, None, None)`` — the three
            hyper-parameters are not differentiated.
        """
        inp, = ctx.saved_tensors
        dampening_factor = ctx.dampening_factor
        pseudo_derivative_width = ctx.pseudo_derivative_width
        dE_dz = grad_output
        if not bool(ctx.exponential):
            # clamp(min=0) keeps inp's device and dtype; building a constant
            # zero tensor on the CPU and copying it to the device on every
            # backward call (and promoting low-precision inputs) is avoided.
            dz_du = dampening_factor * torch.clamp(
                1 - pseudo_derivative_width * torch.abs(inp), min=0
            )
        else:
            dz_du = dampening_factor * torch.exp(-torch.abs(inp))
        dE_dv = dE_dz * dz_du
        return dE_dv, None, None, None


class RNNStatefulWrapper(nn.Module):
    """
    Wrap an RNN so that its hidden state is carried from one forward call to
    the next.

    Each ``forward`` passes the previously returned hidden state back into the
    wrapped RNN and remembers the new output and hidden state. Call
    ``initHidden()`` to drop the carried state (needed after every batch).
    Does not work for ``AsNet`` or for ``EVNN``.
    """

    def __init__(self, rnn: nn.Module):
        super().__init__()
        self.rnn = rnn

        self.n_layers = 1
        # Output of the most recent forward call; None until the first call.
        self.rnn_out = None

        self.initHidden()

    def initHidden(self):
        """Forget the carried hidden state; the next forward passes ``None``."""
        self.hidden = None

    @property
    def hidden_size(self):
        """Hidden size reported by the wrapped RNN."""
        return self.rnn.hidden_size

    def forward(self, all_inputs):
        """
        Run the wrapped RNN on ``all_inputs``, seeded with the carried state.

        :return: ``(output, hidden)`` exactly as produced by the wrapped RNN;
            both are also stored on the wrapper.
        """
        output, new_hidden = self.rnn(all_inputs, self.hidden)
        self.rnn_out = output
        self.hidden = new_hidden
        return output, new_hidden

    def get_last_output(self):
        """
        Read-only: return the same ``(output, hidden)`` pair that the latest
        forward call returned, without recomputing anything.
        """
        return self.rnn_out, self.hidden


class RNNReadoutWrapper(nn.Module):
    """
    Pair a stateful RNN with a linear readout layer.

    The readout ``hidden2out`` (``hidden_size -> output_size``) is created and
    registered here, but ``forward`` does NOT apply it: it only returns the
    wrapped RNN's ``(rnn_out, hidden)``. Callers must apply ``hidden2out``
    themselves. For PyTorch's LSTM, ``rnn_out`` contains the hidden states of
    the last layer only. To reset the carried state, call the wrapped
    RNNStatefulWrapper's ``initHidden()``.
    """

    def __init__(self, rnn: RNNStatefulWrapper, output_size: int):
        super().__init__()
        self.rnn = rnn
        self.output_size = output_size

        # Plain linear readout with no sigmoid: the sigmoid is applied in the
        # loss function instead.
        self.hidden2out = nn.Linear(self.rnn.hidden_size, self.output_size)

    def forward(self, all_inputs):
        """Run the wrapped RNN on ``all_inputs``; the readout is left to the caller."""
        output, state = self.rnn(all_inputs)
        return output, state
