from typing import Any

import numpy as np

from cirkit.symbolic.circuit import CircuitBlock
from cirkit.symbolic.layers import ConstantValueLayer
from cirkit.symbolic.parameters import Parameter, OuterSumParameter, ConstantParameter, SumParameter, LogParameter
from cirkit.utils.scope import Scope
from layers import FourierLayer
from parameters import NegateParameter, ExpandParameter, FourierIntegralParameter


def conjugate_fourier_layer(sl: FourierLayer) -> CircuitBlock:
    """Build the complex conjugate of a Fourier layer.

    Conjugation negates the frequencies of the complex exponentials. The scope,
    the number of output units and the period are carried over unchanged.

    Args:
        sl: The Fourier layer to conjugate.

    Returns:
        A circuit block wrapping the conjugated Fourier layer.
    """
    # Using negative frequencies to conjugate the complex exponentials.
    freqs = Parameter.from_unary(NegateParameter(sl.freqs.shape), sl.freqs.ref())
    # Share the period through a reference, exactly like the frequencies above,
    # instead of handing the very same Parameter object to a second layer.
    period = sl.period.ref()
    conj_sl = FourierLayer(sl.scope, sl.num_output_units, period, freqs=freqs)
    return CircuitBlock.from_layer(conj_sl)


def multiply_fourier_layers(sl1: FourierLayer, sl2: FourierLayer) -> CircuitBlock:
    """Multiply two Fourier layers defined over the same scope.

    The resulting layer has ``sl1.num_output_units * sl2.num_output_units``
    output units. Its frequencies and its period are each obtained by combining
    the corresponding parameters of the two inputs with ``OuterSumParameter``.
    For exponentials sharing a period, the product of two complex exponentials
    is the complex exponential of the summed frequencies.

    Args:
        sl1: The left Fourier layer.
        sl2: The right Fourier layer.

    Returns:
        A circuit block wrapping the product Fourier layer.

    Raises:
        ValueError: If the two layers have different scopes.
    """
    if sl1.scope != sl2.scope:
        raise ValueError(
            f"Expected Fourier layers to have the same scope,"
            f" but found '{sl1.scope}' and '{sl2.scope}'"
        )

    def outer_sum(lhs: Parameter, rhs: Parameter) -> Parameter:
        # Combine two parameters via OuterSumParameter, referencing both inputs.
        return Parameter.from_binary(
            OuterSumParameter(lhs.shape, rhs.shape), lhs.ref(), rhs.ref()
        )

    # TODO: the check that both periods coincide (which summing the frequencies
    # relies on) is currently disabled. It was written with np.isclose and
    # presumably needs adapting to symbolic period parameters.
    # NOTE(review): the output period is the outer *sum* of the input periods.
    # If both inputs share a period P, this presumably yields 2P rather than P.
    # Confirm this is intended.
    period = outer_sum(sl1.period, sl2.period)
    freqs = outer_sum(sl1.freqs, sl2.freqs)

    num_units = sl1.num_output_units * sl2.num_output_units
    product = FourierLayer(sl1.scope, num_units, period, freqs=freqs)
    return CircuitBlock.from_layer(product)


def integrate_fourier_layer(sl: FourierLayer, *, scope: Scope) -> CircuitBlock:
    """Integrate a Fourier layer, producing a constant layer in log-space.

    For each output unit, the constant's log-value is computed from two terms:

    - the log of ``FourierIntegralParameter`` applied to the layer's
      frequencies and period;
    - the negated log of the period, expanded to every output unit.

    The two terms are combined with ``SumParameter``.

    Args:
        sl: The Fourier layer to integrate.
        scope: The integration scope.

    Returns:
        A circuit block wrapping a single ``ConstantValueLayer`` with
        ``log_space=True``.

    Raises:
        ValueError: If the layer's scope does not intersect the integration scope.
    """
    if len(sl.scope & scope) == 0:
        raise ValueError(
            f"The scope of the Fourier layer '{sl.scope}'"
            f" is expected to be a subset of the integration scope '{scope}'"
        )
    units = sl.num_output_units

    log_period = Parameter.from_unary(LogParameter(sl.period.shape), sl.period.ref())

    # Per-unit integral term, moved to log-space.
    integral = Parameter.from_binary(
        FourierIntegralParameter((units,), (1,)),
        sl.freqs.ref(),
        sl.period.ref(),
    )
    log_integral = Parameter.from_sequence(integral, LogParameter((units,)))

    # -log(period), expanded from shape (1,) to one entry per output unit.
    neg_log_period = Parameter.from_unary(
        NegateParameter(sl.period.shape), log_period.ref()
    )
    neg_log_period_per_unit = Parameter.from_unary(
        ExpandParameter((1,), shape=(units,)), neg_log_period
    )

    log_value = Parameter.from_binary(
        SumParameter((units,), (units,)), log_integral, neg_log_period_per_unit
    )
    constant = ConstantValueLayer(units, log_space=True, value=log_value)
    return CircuitBlock.from_layer(constant)


def zero_frequency_condition(x: Any) -> Any:
    """Return the result of comparing ``x`` with zero using ``==``.

    For Python scalars this is a plain boolean. For array-like inputs such as
    NumPy arrays, the comparison is element-wise and returns a boolean mask.
    """
    is_zero = x == 0
    return is_zero
