from typing import Any, cast
from collections.abc import Mapping
from numbers import Number

import torch
import numpy as np
from torch import Tensor

from cirkit.backend.torch.layers import TorchInputLayer
from cirkit.utils.scope import Scope
from cirkit.symbolic.parameters import (
    Parameter,
    TensorParameter,
    ParameterFactory,
    ConstantParameter,
)
from cirkit.symbolic.layers import InputLayer
from cirkit.backend.torch.parameters.parameter import TorchParameter
from cirkit.backend.torch.semiring import Semiring, SumProductSemiring
from cirkit.symbolic.initializers import NormalInitializer, ConstantTensorInitializer


class FourierLayer(InputLayer):
    """Symbolic input layer encoding a set of periodic complex exponentials (Fourier bases)."""

    def __init__(
        self,
        scope: Scope,
        num_output_units: int,
        period: Number | Parameter,
        freqs: Parameter | None = None,
        bias: Parameter | None = None,
        freqs_factory: ParameterFactory | None = None,
        bias_factory: ParameterFactory | None = None,
    ):
        r"""Initializes a Fourier basis layer, which defines a set of complex exponentials that
        are periodic with period $T$. That is, given a vector of frequencies
        $\kappa\in\mathbb{Z}^K$ and a vector of real shifts $b\in\mathbb{R}^K$, where $K$ is
        the number of Fourier bases, this layer encodes the functions
        $$
        f_j(x) = \frac{1}{\sqrt{T}} e^{i2\pi\kappa_j (x + b_j) / T}
        $$
        where $1\leq j\leq K$.

        The normalization constant $1/\sqrt{T}$ corresponds to a log partition function equal
        to $-\frac{1}{2}\log T$. It makes the bases orthonormal over one period: the product
        of a basis function and its own conjugate integrates to 1, while the product of basis
        functions with different integer frequencies integrates to 0.

        If neither frequencies nor a frequencies factory are given, the symmetric integer
        frequencies $[-\lfloor K/2\rfloor, \ldots, \lfloor K/2\rfloor]$ are used, which
        requires $K$ to be odd. If neither a bias nor a bias factory are given, the bias is
        a tensor parameter of shape $(K,)$ initialized with ``NormalInitializer(0.0, 1.0)``.

        Args:
            scope: The variables scope the layer depends on. It must contain exactly one
                variable.
            num_output_units: The number of Fourier bases $K$.
            period: The period $T$, either a positive number or a parameter. A number is
                wrapped into a constant parameter of shape $(1,)$.
            freqs: The frequencies parameter of shape $(K,)$. If None, it is built with
                ``freqs_factory`` or, if that is None too, set to the default frequencies.
            bias: The per-basis shift parameter, of shape $(K,)$ when built by default or
                by ``bias_factory``. If None, it is built with ``bias_factory`` or, if that
                is None too, set to the default bias.
            freqs_factory: A factory called with shape $(K,)$ to construct the freqs
                parameter if it is not specified.
            bias_factory: A factory called with shape $(K,)$ to construct the bias
                parameter if it is not specified.

        Raises:
            ValueError: If the scope does not contain exactly one variable.
            ValueError: If the period is given as a non-positive number.
            ValueError: If the default frequencies are requested with an even number of
                output units.
            ValueError: If the frequencies parameter shape is not correct.
        """
        if len(scope) != 1:
            raise ValueError("The Fourier layer encodes univariate functions")
        super().__init__(scope, num_output_units)

        if not isinstance(period, Parameter):
            if period <= 0:
                raise ValueError("The period must be positive")
            # Wrap a plain number into a constant parameter of shape (1,).
            period = Parameter.from_input(ConstantParameter(1, value=period))
        # TODO check that a user-provided period parameter is not learnable

        if bias is None:
            if bias_factory is None:
                bias = Parameter.from_input(
                    TensorParameter(*self._freqs_shape, initializer=NormalInitializer(0.0, 1.0))
                )
            else:
                bias = bias_factory(self._freqs_shape)

        if freqs is None:
            if freqs_factory is None:
                # The symmetric range -K//2, ..., K//2 has exactly K elements only when K is
                # odd. Raise instead of asserting so the check also runs under `python -O`.
                if self.num_output_units % 2 != 1:
                    raise ValueError(
                        "Fourier layers require an odd number of frequencies, "
                        f"got {self.num_output_units} instead"
                    )
                half = self.num_output_units // 2
                freqs = Parameter.from_input(
                    ConstantParameter(
                        self.num_output_units,
                        value=np.arange(-half, half + 1),
                    )
                )
            else:
                freqs = freqs_factory(self._freqs_shape)
        if freqs.shape != self._freqs_shape:
            raise ValueError(
                f"Expected 'freqs' parameter shape {self._freqs_shape}, found {freqs.shape}"
            )

        self.bias = bias
        self.freqs = freqs
        self.period = period

    @property
    def _freqs_shape(self) -> tuple[int, ...]:
        # One frequency per output unit.
        return (self.num_output_units,)

    @property
    def config(self) -> Mapping[str, Any]:
        return {
            "scope": self.scope,
            "num_output_units": self.num_output_units,
        }

    @property
    def params(self) -> Mapping[str, Parameter]:
        return {"freqs": self.freqs, "bias": self.bias, "period": self.period}


class TorchFourierLayer(TorchInputLayer):
    """PyTorch implementation of a layer of periodic complex exponentials (Fourier bases)."""

    def __init__(
        self,
        scope_idx: Tensor,
        num_output_units: int,
        *,
        period: TorchParameter,
        freqs: TorchParameter,
        bias: TorchParameter,
        semiring: Semiring | None = None,
    ) -> None:
        r"""Initialize a Fourier basis layer.

        For each fold and output unit $j$, the layer evaluates
        $$
        f_j(x) = \frac{1}{\sqrt{T}} e^{i2\pi\kappa_j (x + b_j) / T}
        $$
        where $T$ is the period, $\kappa$ are the frequencies and $b$ is the real part of
        the bias. The result is mapped from the sum-product semiring to the layer semiring.

        Args:
            scope_idx: A tensor of shape $(F, D)$, where $F$ is the number of folds, and
                $D$ is the number of variables on which the input layers in each fold are defined
                on. Alternatively, a tensor of shape $(D,)$ can be specified, which will be
                interpreted as a tensor of shape $(1, D)$, i.e., with $F = 1$.
            num_output_units: The number of output units, i.e., Fourier bases $K$.
            period: The period $T$ of the modelled functions.
            freqs: The frequencies parameter, having shape $(F, K)$, where $K$ is the number
                of output units.
            bias: The per-unit shift parameter, expected to broadcast against the frequencies
                (shape $(F, K)$ when compiled from the symbolic layer); only its real part
                is used.
            semiring: The evaluation semiring, forwarded to the base input layer.

        Raises:
            ValueError: If the scope does not contain exactly one variable.
            ValueError: If the number of folds or the shape of the frequencies parameter is
                not correct.
        """
        num_variables = scope_idx.shape[-1]
        if num_variables != 1:
            raise ValueError("The Fourier layer encodes a univariate distribution")
        super().__init__(
            scope_idx,
            num_output_units,
            semiring=semiring,
        )
        self.period = period
        if not self._valid_freqs_shape(freqs):
            raise ValueError(
                f"Expected number of folds {self.num_folds} "
                f"and shape {self._freqs_shape} for 'freqs', found "
                f"{freqs.num_folds} and {freqs.shape}, respectively"
            )
        self.freqs = freqs
        self.bias = bias
        # Constant i * 2 * pi used in the exponent of every basis function.
        self._i2pi = 1j * 2 * np.pi

    def _valid_freqs_shape(self, p: TorchParameter) -> bool:
        # A parameter is valid when it has one fold per layer fold and shape (K,) per fold.
        if p.num_folds != self.num_folds:
            return False
        return p.shape == self._freqs_shape

    @property
    def _freqs_shape(self) -> tuple[int, ...]:
        return (self.num_output_units,)

    @property
    def config(self) -> Mapping[str, Any]:
        return {"num_output_units": self.num_output_units}

    @property
    def params(self) -> Mapping[str, TorchParameter]:
        return {"freqs": self.freqs, "bias": self.bias, "period": self.period}

    @property
    def log_partition(self):
        # Log of the normalizer 1/sqrt(T), i.e., -0.5 * log(T); has the shape of period().
        return -0.5 * torch.log(self.period())

    def forward(self, x: Tensor) -> Tensor:
        # x: (F, B, D=1)
        # freqs: (F, K) -> (F, 1, K)
        freqs = self.freqs().unsqueeze(dim=1)
        # log_partition: shape of period() with an inserted batch axis; presumably
        # (F, 1) -> (F, 1, 1) when the period comes from a scalar — TODO confirm.
        log_partition = self.log_partition.unsqueeze(dim=1)
        # bias: (F, K) -> (F, 1, K); only the real part is used as a shift of x.
        bias = self.bias().unsqueeze(dim=1).real
        # period: same batch-axis insertion as log_partition so it broadcasts over (B, K).
        period = self.period().unsqueeze(dim=1)
        # y: (F, B, K), complex-valued.
        y = torch.exp(self._i2pi * freqs * (x + bias) / period + log_partition)
        return self.semiring.map_from(y, SumProductSemiring)


def compile_fourier_layer(compiler: "TorchCompiler", sl: FourierLayer) -> TorchFourierLayer:
    """Compile a symbolic Fourier layer into a TorchFourierLayer.

    Args:
        compiler: The compiler used to compile the layer parameters; its semiring is also
            handed to the compiled layer.
        sl: The symbolic Fourier layer to compile.

    Returns:
        A TorchFourierLayer over the same scope and number of output units, using the
        compiled bias, freqs and period parameters.
    """
    # Compile in the fixed order bias, freqs, period, in case the compiler's state depends
    # on the order in which parameters are compiled.
    compiled_params = {
        param_name: compiler.compile_parameter(getattr(sl, param_name))
        for param_name in ("bias", "freqs", "period")
    }
    scope_idx = torch.tensor(tuple(sl.scope))
    return TorchFourierLayer(
        scope_idx,
        sl.num_output_units,
        semiring=compiler.semiring,
        **compiled_params,
    )
