import itertools
import pytest
import torch

from cirkit.backend.torch.layers.input import TorchConstantLayer, TorchConstantValueLayer
from layers import TorchFourierLayer
from models import OrthogonalSOS


@pytest.mark.slow
@pytest.mark.parametrize(
    "num_squares,num_units,region_graph",
    list(itertools.product([1, 4], [7], ["rnd-bt"])),
)
def test_continuous_osos_pc_integral_product_fourier(num_squares, num_units, region_graph):
    """Check the Fourier inputs and the integrated squared circuit of an OrthogonalSOS model.

    A complex-valued model over four variables is built with Fourier input
    layers whose period equals the width of the interval [-64, 64]. The test
    then verifies two things:

    * every ``TorchFourierLayer`` in ``model._circuit`` reports the integer
      frequencies ``-(num_units // 2), ..., num_units // 2``;
    * every ``TorchConstantValueLayer`` in ``model._int_sq_circuit`` is flagged
      as log-space and its value, viewed as
      ``(num_variables * num_squares, num_units, num_units)``, is close to
      ``log(I)`` (0 on the diagonal, ``-inf`` elsewhere).

    NOTE(review): the ``log(I)`` check presumably encodes orthonormality of the
    Fourier features over one full period -- confirm against the layer docs.
    Both loops pass trivially if no layer of the expected type is present.
    """
    num_variables = 4
    lower, upper = -64.0, 64.0
    # One period per variable, each spanning the whole interval.
    periods = [upper - lower] * num_variables
    model = OrthogonalSOS(
        num_variables,
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="fourier",
        input_layer_kwargs={"period": periods},
        num_squares=num_squares,
        region_graph=region_graph,
        complex=True,
    )

    # Frequencies are expected to be symmetric around zero.
    half_width = num_units // 2
    expected_freqs = torch.arange(-half_width, half_width + 1)
    fourier_layers = (
        candidate
        for candidate in model._circuit.layers
        if isinstance(candidate, TorchFourierLayer)
    )
    for fourier in fourier_layers:
        assert torch.all(fourier.freqs() == expected_freqs)

    # log of the identity matrix: 0 on the diagonal, -inf off the diagonal.
    log_identity = torch.log(torch.eye(num_units))
    stacked_shape = (num_variables * num_squares, num_units, num_units)
    for candidate in model._int_sq_circuit.layers:
        if not isinstance(candidate, TorchConstantValueLayer):
            continue
        assert candidate.log_space
        stacked = candidate.value().view(*stacked_shape)
        assert torch.allclose(stacked, log_identity)
