from typing import Callable, Any, cast

import torch
from torch import Tensor

from cirkit.backend.torch.parameters.nodes import TorchEntrywiseParameterOp, TorchUnaryParameterOp, TorchBinaryParameterOp
from cirkit.symbolic.parameters import EntrywiseParameterOp, UnaryParameterOp, BinaryParameterOp


class NegateParameter(EntrywiseParameterOp):
    """Symbolic entry-wise operator that flips the sign of every entry of its input.

    Being an entry-wise operator, it introduces no state of its own; it is compiled
    to :class:`TorchNegateParameter` by ``compile_negate_parameter``.
    """


class FourierIntegralParameter(BinaryParameterOp):
    """Symbolic binary operator for integrating a Fourier-parameterized tensor.

    It is compiled to :class:`TorchFourierIntegralParameter` by
    ``compile_fourier_integral_parameter``. The torch counterpart keeps entries of the
    second input where the first input is zero and zeroes the rest. Presumably the first
    input holds frequencies and the second holds coefficients — TODO confirm with callers.
    """

    def __init__(self, in_shape1: tuple[int, ...], in_shape2: tuple[int, ...]):
        # No extra state: both input shapes are simply handed to the base operator.
        super().__init__(in_shape1, in_shape2)

    @property
    def shape(self) -> tuple[int, ...]:
        """Output shape, defined to be the shape of the first input."""
        first_input_shape = self.in_shape1
        return first_input_shape


class ExpandParameter(UnaryParameterOp):
    """Symbolic unary operator that expands (broadcasts) its input to a target shape.

    Args:
        in_shape: Shape of the input parameter tensor.
        shape: Target output shape (keyword-only).
    """

    def __init__(self, in_shape: tuple[int, ...], *, shape: tuple[int, ...]):
        super().__init__(in_shape)
        # Stored privately and exposed read-only through the ``shape`` property.
        self._shape = shape

    @property
    def shape(self) -> tuple[int, ...]:
        """The target output shape given at construction time."""
        return self._shape

    @property
    def config(self) -> dict[str, Any]:
        """The base operator's configuration extended with the target ``shape``."""
        cfg = super().config
        cfg.update(shape=self.shape)
        return cfg


class TorchNegateParameter(TorchEntrywiseParameterOp):
    """Torch entry-wise operator returning the element-wise negation of its input."""

    def forward(self, x: Tensor) -> Tensor:
        # Element-wise sign flip; equivalent to the unary ``-`` operator.
        return torch.neg(x)


class TorchFourierIntegralParameter(TorchBinaryParameterOp):
    """Torch operator integrating a Fourier-parameterized tensor.

    Entries whose value in the first input ``x`` is non-zero are annihilated (set to
    zero). Entries where ``x`` is zero keep the corresponding value of the second input
    ``y``. Presumably ``x`` holds frequencies and ``y`` coefficients — TODO confirm.

    Args:
        in_shape1: Shape of the first input (``x``).
        in_shape2: Shape of the second input (``y``).
        num_folds: Number of folds, forwarded to the base operator.
    """

    def __init__(
        self,
        in_shape1: tuple[int, ...],
        in_shape2: tuple[int, ...],
        *,
        num_folds: int = 1,
    ):
        super().__init__(in_shape1, in_shape2, num_folds=num_folds)

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        """Return ``y`` where ``x == 0`` and ``0`` elsewhere.

        The result has the broadcast shape of ``x`` and ``y``, as produced by
        :func:`torch.where`.
        """
        # Non-zero frequency implies integral annihilation.
        return torch.where(x == 0, y, 0.0)

    @property
    def shape(self) -> tuple[int, ...]:
        """Output shape, defined to be the shape of the first input."""
        # The torch parameter-op base stores its input shapes in ``in_shapes``.
        # ``in_shape1`` is the symbolic op's attribute name and is not guaranteed to
        # exist here — NOTE(review): confirm against TorchBinaryParameterOp.
        return self.in_shapes[0]


class TorchExpandParameter(TorchUnaryParameterOp):
    """Torch operator broadcasting its input to a fixed target shape.

    The leading dimension of the input (the fold dimension, given ``-1`` in
    :meth:`forward`) is kept as is. The remaining dimensions are expanded to ``shape``
    with :meth:`torch.Tensor.expand`.

    Args:
        in_shape: Shape of the input parameter tensor (excluding the fold dimension).
        shape: Target shape of the non-fold dimensions.
        num_folds: Number of folds, forwarded to the base operator.
    """

    def __init__(
        self,
        in_shape: tuple[int, ...],
        shape: tuple[int, ...],
        *,
        num_folds: int = 1,
    ):
        super().__init__(in_shape, num_folds=num_folds)
        # Kept privately; exposed read-only through the ``shape`` property.
        self._shape = shape

    @property
    def shape(self) -> tuple[int, ...]:
        """The target shape given at construction time."""
        return self._shape

    @property
    def config(self) -> dict[str, Any]:
        """The base operator's configuration extended with the target ``shape``."""
        cfg = super().config
        cfg.update(shape=self.shape)
        return cfg

    def forward(self, x: Tensor) -> Tensor:
        # ``-1`` leaves the leading (fold) dimension untouched.
        target_size = (-1, *self._shape)
        return x.expand(size=target_size)


def compile_negate_parameter(
    compiler: "TorchCompiler", p: NegateParameter
) -> TorchNegateParameter:
    """Compile a symbolic :class:`NegateParameter` into a :class:`TorchNegateParameter`.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic negation operator; it must have exactly one input shape.

    Returns:
        The torch negation operator built from the single input shape.
    """
    [only_in_shape] = p.in_shapes
    return TorchNegateParameter(only_in_shape)


def compile_fourier_integral_parameter(
    compiler: "TorchCompiler", p: FourierIntegralParameter
) -> TorchFourierIntegralParameter:
    """Compile a symbolic :class:`FourierIntegralParameter` into its torch counterpart.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic operator; it must have exactly two input shapes.

    Returns:
        A :class:`TorchFourierIntegralParameter` built from both input shapes, in order.
    """
    [first_shape, second_shape] = p.in_shapes
    return TorchFourierIntegralParameter(first_shape, second_shape)


def compile_expand_parameter(
    compiler: "TorchCompiler", p: ExpandParameter
) -> TorchExpandParameter:
    """Compile a symbolic :class:`ExpandParameter` into a :class:`TorchExpandParameter`.

    Args:
        compiler: The torch compiler (not used by this rule).
        p: The symbolic expand operator; it must have exactly one input shape.

    Returns:
        The torch expand operator targeting ``p.shape``.
    """
    [only_in_shape] = p.in_shapes
    return TorchExpandParameter(only_in_shape, shape=p.shape)
