import itertools

import numpy as np
import pytest
import torch
from scipy import integrate

from models import MPC, PC, SOS, ExpSOS, OrthogonalSOS
from tests.test_utils import generate_all_nary_samples


def check_normalized_log_scores(model: PC, x: torch.Tensor) -> torch.Tensor:
    """Assert that the exponentiated ``model.log_score`` values over ``x`` sum to one.

    The check is only meaningful when ``x`` covers the model's entire (discrete)
    support, since the total mass of the scores on ``x`` must equal one.

    Args:
        model: The circuit whose raw (unnormalized-API) log-scores are checked.
        x: Batch of inputs; one score per row is expected.

    Returns:
        The ``(len(x), 1)`` tensor of log-scores returned by the model.
    """
    log_scores = model.log_score(x)
    num_samples = len(x)
    assert log_scores.shape == (num_samples, 1)
    assert torch.isfinite(log_scores).all()
    # Sum in log-space first for numerical stability, then exponentiate once.
    total_mass = torch.logsumexp(log_scores, dim=0).exp()
    assert torch.allclose(total_mass, torch.tensor(1.0), atol=1e-15)
    return log_scores


def check_pmf(model: PC, num_states: int = 2):
    """Assert that ``model.log_likelihood`` defines a normalized PMF.

    The inputs come from ``generate_all_nary_samples(model.num_variables, num_states)``,
    presumably every assignment of ``num_states``-ary values to all variables (per the
    helper's name). The likelihoods over them must be finite, one per row, and sum to one.

    Args:
        model: The circuit under test.
        num_states: Number of values each variable can take.
    """
    all_assignments = torch.LongTensor(
        generate_all_nary_samples(model.num_variables, num_states)
    )
    log_probs = model.log_likelihood(all_assignments)
    assert log_probs.shape == (len(all_assignments), 1)
    assert torch.isfinite(log_probs).all()
    # Reduce in log-space for stability before converting back to a probability.
    total_prob = torch.logsumexp(log_probs, dim=0).exp()
    assert torch.allclose(total_prob, torch.tensor(1.0), atol=1e-15)


def check_pdf(model, interval: tuple[float, float] | None = None, assume_normalized: bool = False):
    if assume_normalized:
        pdf = lambda y, x: torch.exp(model.log_score(torch.Tensor([[x, y]])))
    else:
        pdf = lambda y, x: torch.exp(model.log_likelihood(torch.Tensor([[x, y]])))
    if interval is None:
        a, b = -64.0, 64.0
    else:
        a, b = interval
    ig, err = integrate.dblquad(pdf, a, b, a, b)
    assert np.isclose(ig, 1.0, atol=1e-15), ig


@pytest.mark.parametrize(
    "num_variables,num_components,num_units,region_graph,sd",
    list(
        itertools.product(
            [9, 12], [1, 4], [1, 3], ["rnd-bt", "rnd-lt", "qt"], [False, True]
        )
    ),
)
def test_discrete_monotonic_pc(
    num_variables, num_components, num_units, region_graph, sd
):
    """Check that a monotonic PC over binary variables yields a normalized PMF.

    ``"qt"`` is parametrized alongside the random region graphs, as in the SOS tests,
    so that its image-shape branch below is exercised. The shape is 3x3 for 9
    variables and 4x3 for 12.
    """
    if region_graph == "qt":
        if num_variables == 9:
            image_shape = (1, 3, 3)
        else:  # num_variables == 12
            image_shape = (1, 4, 3)
    else:
        image_shape = None
    model = MPC(
        num_variables,
        image_shape=image_shape,
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="categorical",
        input_layer_kwargs={"num_categories": 2},
        num_components=num_components,
        region_graph=region_graph,
        structured_decomposable=sd,
    )
    check_pmf(model)


@pytest.mark.parametrize(
    "num_variables,num_squares,num_units,region_graph,sd,input_layer",
    list(
        itertools.product(
            [9, 12],
            [1, 4],
            [1, 3],
            ["rnd-bt", "qt"],
            [False, True],
            ["categorical", "embedding"],
        )
    ),
)
def test_discrete_sos_pc(
    num_variables, num_squares, num_units, region_graph, sd, input_layer
):
    """Check that an SOS PC (no ``complex`` flag) over binary variables yields a normalized PMF.

    The ``"qt"`` region graph is given an image shape: 3x3 for 9 variables and
    4x3 otherwise (only 9 or 12 variables are parametrized).
    """
    if input_layer == "categorical":
        input_layer_kwargs = {"num_categories": 2}
    else:
        input_layer_kwargs = {"num_states": 2}

    image_shape = None
    if region_graph == "qt":
        image_shape = (1, 3, 3) if num_variables == 9 else (1, 4, 3)

    sos_kwargs = dict(
        image_shape=image_shape,
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer=input_layer,
        input_layer_kwargs=input_layer_kwargs,
        num_squares=num_squares,
        region_graph=region_graph,
        structured_decomposable=sd,
    )
    check_pmf(SOS(num_variables, **sos_kwargs))


@pytest.mark.parametrize(
    "num_variables,num_squares,num_units,region_graph,sd,input_layer",
    list(
        itertools.product(
            [9, 12],
            [1, 4],
            [1, 3],
            ["rnd-bt", "qt"],
            [False, True],
            ["categorical", "embedding"],
        )
    ),
)
def test_discrete_complex_sos_pc(
    num_variables, num_squares, num_units, region_graph, sd, input_layer
):
    """Check that an SOS PC built with ``complex=True`` yields a normalized PMF over binary variables.

    The ``"qt"`` region graph is given an image shape: 3x3 for 9 variables and
    4x3 otherwise.
    """
    use_categorical = input_layer == "categorical"
    input_layer_kwargs = {"num_categories": 2} if use_categorical else {"num_states": 2}

    if region_graph != "qt":
        image_shape = None
    elif num_variables == 9:
        image_shape = (1, 3, 3)
    else:  # num_variables == 12
        image_shape = (1, 4, 3)

    sos_kwargs = dict(
        image_shape=image_shape,
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer=input_layer,
        input_layer_kwargs=input_layer_kwargs,
        num_squares=num_squares,
        region_graph=region_graph,
        structured_decomposable=sd,
        complex=True,
    )
    check_pmf(SOS(num_variables, **sos_kwargs))


@pytest.mark.parametrize(
    "num_variables,num_units,region_graph,sd,input_layer",
    list(
        itertools.product(
            [9, 12],
            [1, 3],
            ["rnd-bt", "qt"],
            [False, True],
            ["categorical", "embedding"],
        )
    ),
)
def test_discrete_exp_sos_pc(num_variables, num_units, region_graph, sd, input_layer):
    """Check that an ExpSOS PC (``complex=True``, 2-unit monotonic part) yields a normalized PMF.

    The ``"qt"`` region graph is given an image shape: 3x3 for 9 variables and
    4x3 otherwise.
    """
    if input_layer == "categorical":
        input_layer_kwargs = {"num_categories": 2}
    else:
        input_layer_kwargs = {"num_states": 2}

    image_shape = (
        {9: (1, 3, 3)}.get(num_variables, (1, 4, 3)) if region_graph == "qt" else None
    )

    model = ExpSOS(
        num_variables,
        image_shape=image_shape,
        num_input_units=num_units,
        num_sum_units=num_units,
        mono_num_input_units=2,
        mono_num_sum_units=2,
        input_layer=input_layer,
        input_layer_kwargs=input_layer_kwargs,
        region_graph=region_graph,
        structured_decomposable=sd,
        complex=True,
    )
    check_pmf(model)


@pytest.mark.slow
@pytest.mark.parametrize(
    "num_components,num_units,region_graph",
    list(itertools.product([1], [2], ["rnd-bt"])),
)
def test_continuous_monotonic_pc(num_components, num_units, region_graph):
    """Check that a bivariate monotonic PC with Gaussian inputs integrates to one."""
    mpc_kwargs = dict(
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="gaussian",
        num_components=num_components,
        region_graph=region_graph,
    )
    bivariate = 2
    check_pdf(MPC(bivariate, **mpc_kwargs))


@pytest.mark.slow
@pytest.mark.parametrize(
    "num_squares,num_units,region_graph",
    list(itertools.product([1], [2], ["rnd-bt"])),
)
def test_continuous_sos_pc(num_squares, num_units, region_graph):
    """Check that a bivariate SOS PC with Gaussian inputs integrates to one."""
    sos_kwargs = dict(
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="gaussian",
        num_squares=num_squares,
        region_graph=region_graph,
    )
    bivariate = 2
    check_pdf(SOS(bivariate, **sos_kwargs))


@pytest.mark.slow
@pytest.mark.parametrize(
    "num_squares,num_units,region_graph",
    list(itertools.product([1], [5], ["rnd-bt"])),
)
def test_continuous_sos_pc_fourier(num_squares, num_units, region_graph):
    """Check that a bivariate SOS PC with Fourier inputs integrates to one.

    Each variable's Fourier period equals the length of the integration interval.
    """
    num_variables = 2
    lower, upper = -64.0, 64.0
    period = upper - lower
    model = SOS(
        num_variables,
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="fourier",
        input_layer_kwargs=dict(period=[period] * num_variables),
        num_squares=num_squares,
        region_graph=region_graph,
    )
    check_pdf(model, interval=(lower, upper))


@pytest.mark.slow
@pytest.mark.parametrize(
    "num_squares,num_units,region_graph,use_tucker",
    list(itertools.product([3], [7], ["rnd-bt"], [False, True])),
)
def test_continuous_osos_pc_fourier(num_squares, num_units, region_graph, use_tucker):
    """Check that a complex OrthogonalSOS with Fourier inputs integrates to one.

    The check is run twice: first on ``log_likelihood``, then on the raw
    ``log_score``, which for this model is expected to be normalized already.
    """
    num_variables = 2
    lower, upper = -64.0, 64.0
    period = upper - lower
    model = OrthogonalSOS(
        num_variables,
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="fourier",
        input_layer_kwargs=dict(period=[period] * num_variables),
        num_squares=num_squares,
        region_graph=region_graph,
        complex=True,
        use_tucker=use_tucker,
    )
    for normalized in (False, True):
        check_pdf(model, interval=(lower, upper), assume_normalized=normalized)


@pytest.mark.slow
@pytest.mark.parametrize(
    "num_units,region_graph",
    list(itertools.product([2], ["rnd-bt"])),
)
def test_continuous_exp_sos_pc(num_units, region_graph):
    """Check that a bivariate ExpSOS PC with Gaussian inputs integrates to one."""
    exp_sos_kwargs = dict(
        num_input_units=num_units,
        num_sum_units=num_units,
        mono_num_input_units=2,
        mono_num_sum_units=2,
        input_layer="gaussian",
        region_graph=region_graph,
    )
    bivariate = 2
    check_pdf(ExpSOS(bivariate, **exp_sos_kwargs))


@pytest.mark.parametrize(
    "num_squares,num_units,structured,use_tucker",
    list(itertools.product([1, 3], [2, 4], [False, True], [False, True]))
)
def test_discrete_osos_pc(num_squares, num_units, structured, use_tucker):
    """Check that a complex OrthogonalSOS with embedding inputs yields a normalized PMF.

    Uses 4 variables, each taking 4 states, on a random binary-tree region graph.
    """
    num_variables, num_states = 4, 4
    osos_kwargs = dict(
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="embedding",
        input_layer_kwargs=dict(num_states=num_states),
        num_squares=num_squares,
        region_graph="rnd-bt",
        structured_decomposable=structured,
        complex=True,
        use_tucker=use_tucker,
    )
    check_pmf(OrthogonalSOS(num_variables, **osos_kwargs), num_states=num_states)


@pytest.mark.parametrize(
    "region_graph,num_squares,num_repetitions_units,use_tucker",
    list(itertools.product(['rnd-bt', 'rnd-qt-2'], [1, 3], [(2, 2), (3, 2), (2, 3)], [False, True]))
)
def test_discrete_non_structured_osos_pc(region_graph, num_squares, num_repetitions_units, use_tucker):
    """Check embedding orthonormality and PMF normalization of a non-structured OrthogonalSOS.

    For each variable ``v``, the embedding weights of the folds selected below are
    flattened into a matrix ``W`` (one row per unit, one column per state). The test
    asserts ``W @ W^H`` has zero imaginary part and real part equal to the identity,
    i.e. the rows of ``W`` are orthonormal. It then checks the model's PMF sums to one.
    """
    num_variables, image_shape = 4, (1, 2, 2)
    num_states = 6
    num_repetitions, num_units = num_repetitions_units
    model = OrthogonalSOS(
        num_variables,
        image_shape=image_shape,
        num_input_units=num_units,
        num_sum_units=num_units,
        input_layer="embedding",
        input_layer_kwargs=dict(num_states=num_states),
        num_squares=num_squares,
        region_graph=region_graph,
        num_repetitions=num_repetitions,
        complex=True,
        use_tucker=use_tucker
    )

    # Assumes the first layer yielded by model.layers() is the embedding input layer.
    embedding_layer = list(model.layers())[0]
    scope_idx = embedding_layer.scope_idx.squeeze(dim=1)
    scope_idx_stride = model.num_repetitions * model.num_variables
    weight = embedding_layer.weight()
    # NOTE(review): chunk ``v`` (of ``scope_idx_stride`` folds) is searched only for
    # variable ``v``. If the folds are laid out as one chunk per square, this pairs the
    # square index with the variable index. For ``v >= num_squares`` the slice would
    # then be empty and the check vacuous. Confirm the embedding layer's fold layout.
    for v in range(model.num_variables):
        chunk_start = v * scope_idx_stride
        v_scope_idx = scope_idx[chunk_start:chunk_start + scope_idx_stride]
        v_idx = chunk_start + torch.argwhere(v_scope_idx == v).squeeze(dim=1)
        emb_weight = weight[v_idx].view(-1, weight.shape[-1])
        weight_mm_conj_transposed = torch.mm(emb_weight, emb_weight.conj().T)
        assert torch.allclose(weight_mm_conj_transposed.imag, torch.tensor(0.0))
        assert torch.allclose(weight_mm_conj_transposed.real, torch.eye(emb_weight.shape[0]))

    check_pmf(model, num_states=num_states)
