from typing import Optional

import torch as t
import torch.autograd as autograd
from torch import nn

from auto_encoder import device


class RectangleFunction(autograd.Function):
    """Unit-width rectangle window centred on zero.

    Forward yields 1.0 where ``-0.5 < x < 0.5`` (both bounds exclusive) and
    0.0 elsewhere, always as a float32 tensor. Backward passes the incoming
    gradient through unchanged inside that window and zeroes it wherever
    ``x <= -0.5`` or ``x >= 0.5``.
    """

    @staticmethod
    def forward(ctx: autograd.Function, pre_activation_features_BSF: t.Tensor) -> t.Tensor:
        """Return the 0/1 window indicator for every element of the input."""
        ctx.save_for_backward(pre_activation_features_BSF)

        above_lower_edge = pre_activation_features_BSF > -0.5
        below_upper_edge = pre_activation_features_BSF < 0.5
        return t.logical_and(above_lower_edge, below_upper_edge).float()

    @staticmethod
    def backward(ctx: autograd.Function, grad_output_BSF: t.Tensor) -> t.Tensor:
        """Return a copy of the upstream gradient, zeroed outside the window."""
        (saved_input_BSF,) = ctx.saved_tensors

        # The "outside" mask is spelled out explicitly (rather than negating the
        # forward mask) so that elements comparing False on both sides keep
        # their gradient.
        outside_window = t.logical_or(saved_input_BSF <= -0.5, saved_input_BSF >= 0.5)

        # masked_fill is out-of-place, so the upstream gradient is never mutated.
        return grad_output_BSF.masked_fill(outside_window, 0)


class JumpReLUFunction(autograd.Function):
    """JumpReLU activation with hand-written (pseudo-)gradients.

    With ``threshold = exp(log_threshold) - jump_threshold_offset`` the forward
    pass computes ``x * (x > threshold)``.

    The backward pass returns:

    * for ``x``: the upstream gradient where ``x > threshold``, zero elsewhere;
    * for ``log_threshold``: the rectangle-kernel pseudo-derivative with respect
      to ``threshold``, ``-(threshold / bandwidth) * rect((x - threshold) / bandwidth)``,
      multiplied by ``d threshold / d log_threshold = exp(log_threshold)`` and
      summed down to the shape of ``log_threshold``;
    * ``None`` for the two float hyperparameters.

    The ``_BSF`` suffix presumably denotes (batch, sequence, feature) axes --
    TODO confirm; the code only relies on ``log_threshold`` broadcasting
    against the features.
    """

    @staticmethod
    def forward(
        ctx: autograd.Function,
        pre_activation_features_BSF: t.Tensor,
        log_threshold: t.Tensor,
        bandwidth: float,  # no gradients flow to the bandwidth so set as a float
        jump_threshold_offset: float,
    ) -> t.Tensor:
        """Zero every pre-activation that does not exceed the threshold.

        Args:
            ctx: autograd context used to stash values for ``backward``.
            pre_activation_features_BSF: pre-activation values.
            log_threshold: learnable log-threshold, broadcast against the features.
            bandwidth: width of the rectangle window used in ``backward``.
            jump_threshold_offset: constant subtracted from ``exp(log_threshold)``.

        Returns:
            ``pre_activation_features_BSF`` with sub-threshold entries set to zero.
        """
        # Only tensors go through save_for_backward. The float hyperparameters
        # are kept as plain attributes on ctx, which avoids creating tensors on
        # an unrelated global device and a host sync (``.item()``) in backward.
        ctx.save_for_backward(pre_activation_features_BSF, log_threshold)
        ctx.bandwidth = bandwidth
        ctx.jump_threshold_offset = jump_threshold_offset

        threshold = t.exp(log_threshold) - jump_threshold_offset

        features_BSF = (
            pre_activation_features_BSF * (pre_activation_features_BSF > threshold).float()
        )

        return features_BSF

    @staticmethod
    def backward(
        ctx: autograd.Function, grad_output_BSF: t.Tensor
    ) -> tuple[t.Tensor, t.Tensor, None, None]:
        """Return gradients for ``(x, log_threshold, bandwidth, jump_threshold_offset)``."""
        pre_activation_features_BSF, log_threshold = ctx.saved_tensors  # type: ignore
        bandwidth: float = ctx.bandwidth  # type: ignore
        jump_threshold_offset: float = ctx.jump_threshold_offset  # type: ignore

        exp_log_threshold = t.exp(log_threshold)
        threshold = exp_log_threshold - jump_threshold_offset

        # No straight-through estimator for x: gradient flows only where active.
        x_grad_BSF = (pre_activation_features_BSF > threshold).float() * grad_output_BSF

        # Rectangle window of width `bandwidth` centred on the threshold; same
        # open interval (-0.5, 0.5) as RectangleFunction's forward. Only the
        # window value is needed here, so it is computed directly.
        scaled_distance_BSF = (pre_activation_features_BSF - threshold) / bandwidth
        in_window_BSF = ((scaled_distance_BSF > -0.5) & (scaled_distance_BSF < 0.5)).float()

        # Pseudo-derivative of the output with respect to `threshold`.
        threshold_grad_BSF = -(threshold / bandwidth) * in_window_BSF * grad_output_BSF

        # The differentiable input is log_threshold, not threshold, so apply the
        # chain rule (d threshold / d log_threshold = exp(log_threshold)) and
        # reduce the broadcast axes so the gradient matches the input's shape.
        log_threshold_grad = (threshold_grad_BSF * exp_log_threshold).sum_to_size(
            log_threshold.shape
        )

        return x_grad_BSF, log_threshold_grad, None, None  # None for the float hyperparameters


class JumpReLU(nn.Module):
    """Module wrapper around ``JumpReLUFunction`` with a learnable log-threshold.

    One ``log_threshold`` entry is held per feature, initialised to zero, so the
    starting threshold used by ``JumpReLUFunction`` is
    ``exp(0) - jump_threshold_offset``.

    Args:
        num_features: length of the ``log_threshold`` parameter vector.
        bandwidth: forwarded unchanged to ``JumpReLUFunction`` on every call.
        jump_threshold_offset: forwarded unchanged to ``JumpReLUFunction``.
        device: device on which the parameter is allocated.
    """

    def __init__(
        self,
        num_features: int,
        bandwidth: float,
        jump_threshold_offset: float,
        device: str = device,
    ):
        super().__init__()

        # Plain Python hyperparameters; not registered as parameters or buffers.
        self.bandwidth = bandwidth
        self.jump_threshold_offset = jump_threshold_offset

        initial_log_threshold = t.zeros(num_features, device=device)
        self.log_threshold = nn.Parameter(initial_log_threshold)

    def forward(self, pre_activation_features_BSF: t.Tensor) -> t.Tensor:
        """Apply ``JumpReLUFunction`` using this module's threshold and hyperparameters."""
        activation_args = (
            pre_activation_features_BSF,
            self.log_threshold,
            self.bandwidth,
            self.jump_threshold_offset,
        )
        activated_BSF: t.Tensor = JumpReLUFunction.apply(*activation_args)  # type: ignore
        return activated_BSF


class StepFunction(autograd.Function):
    """Heaviside step ``(x > threshold)`` with a pseudo-gradient for the threshold.

    With ``threshold = exp(log_threshold) - jump_threshold_offset`` the forward
    pass returns 1.0 where ``x > threshold`` and 0.0 elsewhere.

    The backward pass returns:

    * for ``x``: zeros;
    * for ``log_threshold``: the rectangle-kernel pseudo-derivative with respect
      to ``threshold``, ``-(1 / bandwidth) * rect((x - threshold) / bandwidth)``,
      multiplied by ``d threshold / d log_threshold = exp(log_threshold)`` and
      summed down to the shape of ``log_threshold``;
    * ``None`` for the two float hyperparameters.
    """

    @staticmethod
    def forward(
        ctx: autograd.Function,
        features_BSF: t.Tensor,
        log_threshold: t.Tensor,
        bandwidth: float,
        jump_threshold_offset: float,
    ) -> t.Tensor:
        """Return a float 0/1 indicator of which features exceed the threshold.

        Args:
            ctx: autograd context used to stash values for ``backward``.
            features_BSF: values compared against the threshold.
            log_threshold: log-threshold, broadcast against the features.
            bandwidth: width of the rectangle window used in ``backward``.
            jump_threshold_offset: constant subtracted from ``exp(log_threshold)``.
        """
        # Only tensors go through save_for_backward. The float hyperparameters
        # are kept as plain attributes on ctx, which avoids creating tensors on
        # an unrelated global device and a host sync (``.item()``) in backward.
        ctx.save_for_backward(features_BSF, log_threshold)
        ctx.bandwidth = bandwidth
        ctx.jump_threshold_offset = jump_threshold_offset

        threshold = t.exp(log_threshold) - jump_threshold_offset

        step_output = (features_BSF > threshold).float()
        return step_output

    @staticmethod
    def backward(
        ctx: autograd.Function, grad_output_BSF: t.Tensor
    ) -> tuple[t.Tensor, t.Tensor, None, None]:
        """Return gradients for ``(x, log_threshold, bandwidth, jump_threshold_offset)``."""
        features_BSF, log_threshold = ctx.saved_tensors  # type: ignore
        bandwidth: float = ctx.bandwidth  # type: ignore
        jump_threshold_offset: float = ctx.jump_threshold_offset  # type: ignore

        exp_log_threshold = t.exp(log_threshold)
        threshold = exp_log_threshold - jump_threshold_offset

        # The step is flat in x almost everywhere; no gradient is sent to the features.
        x_grad = t.zeros_like(features_BSF)

        # Rectangle window of width `bandwidth` centred on the threshold; same
        # open interval (-0.5, 0.5) as RectangleFunction's forward. Only the
        # window value is needed here, so it is computed directly.
        scaled_distance_BSF = (features_BSF - threshold) / bandwidth
        in_window_BSF = ((scaled_distance_BSF > -0.5) & (scaled_distance_BSF < 0.5)).float()

        # Pseudo-derivative of the step with respect to `threshold`.
        threshold_grad_BSF = -(1.0 / bandwidth) * in_window_BSF * grad_output_BSF

        # The differentiable input is log_threshold, not threshold, so apply the
        # chain rule (d threshold / d log_threshold = exp(log_threshold)) and
        # reduce the broadcast axes so the gradient matches the input's shape.
        log_threshold_grad = (threshold_grad_BSF * exp_log_threshold).sum_to_size(
            log_threshold.shape
        )

        return x_grad, log_threshold_grad, None, None  # None for the float hyperparameters


class Step(nn.Module):
    """Module wrapper around ``StepFunction``.

    Holds its own per-feature ``log_threshold`` (initialised to zero), but
    ``forward`` can instead be given an external log-threshold so that the
    threshold is shared with another module.

    Args:
        num_features: length of the ``log_threshold`` parameter vector.
        bandwidth: forwarded unchanged to ``StepFunction`` on every call.
        jump_threshold_offset: forwarded unchanged to ``StepFunction``.
        device: device on which the parameter is allocated.
    """

    def __init__(
        self,
        num_features: int,
        bandwidth: float,
        jump_threshold_offset: float,
        device: str = device,
    ):
        super(Step, self).__init__()
        self.log_threshold = nn.Parameter(t.zeros(num_features, device=device))
        self.bandwidth = bandwidth
        self.jump_threshold_offset = jump_threshold_offset

    def forward(
        self, features_BSF: t.Tensor, jump_log_threshold: Optional[t.Tensor] = None
    ) -> t.Tensor:
        """Return the 0/1 step indicator of ``features_BSF`` against the threshold.

        Args:
            features_BSF: values compared against the threshold.
            jump_log_threshold: optional log-threshold to use instead of this
                module's own ``log_threshold`` parameter.
        """
        # Typically we share the log_threshold from the JumpReLU which is being backpropped through.
        log_threshold = (
            jump_log_threshold if jump_log_threshold is not None else self.log_threshold
        )

        step_output: t.Tensor = StepFunction.apply(features_BSF, log_threshold, self.bandwidth, self.jump_threshold_offset)  # type: ignore
        return step_output
