import itertools

import pytest
import torch

from tests.test_circuit_model import build_circuit
from mtp.models.gpt import GPT
from mtp.models.mtp import MultiTokenLM
from mtp.models.lm import LM



@pytest.fixture(params=[('cp', 1), ('cp', 2), ('hmm', 2)])
def mtp(request):
    """Yield a tiny ``MultiTokenLM`` for each parametrised circuit setup.

    ``request.param`` is a ``(kind, n_component)`` pair that is forwarded to
    ``build_circuit``; every other hyper-parameter is fixed to a small value
    so that the model is cheap to build.
    """
    kind, n_component = request.param

    # Sizes shared by the backbone, the circuit and the multi-token head.
    vocab_size, n_token = 2, 2
    n_embd, n_layer, n_head = 8, 1, 2

    backbone = GPT(vocab_size, n_embd, n_layer, n_head)
    lm = LM(lm=backbone, ref_enc='encoder', ref_head='head', encoder_only=False)
    circuit = build_circuit(vocab_size, n_token, n_component, kind=kind)
    head_kwargs = {
        'n_embd': n_embd,
        'n_head': n_head,
        'transformer_n_layer': 1,
    }
    yield MultiTokenLM(lm, circuit, mt_head_kwargs=head_kwargs)


def test_mtp_forward(mtp: MultiTokenLM):
    """Check that a forward pass returns a finite, non-negative loss.

    The model is called on random tokens drawn from ``{0, 1}``, using the
    sequence shifted by one position as the target.
    """
    batch_size, seq_length = 8, 15
    # Draw one extra token so inputs and shifted targets share the same length.
    tokens = torch.randint(high=2, size=(batch_size, seq_length + 1))
    inputs, targets = tokens[:, :-1], tokens[:, 1:]

    loss = mtp(inputs, targets)['loss']
    assert torch.isfinite(loss)
    assert loss >= 0.0


def test_mtp_generate(mtp: MultiTokenLM):
    """Compare empirical sampling frequencies with the model's likelihoods.

    Many short sequences are sampled with ``mtp.generate`` and the frequency
    of each distinct continuation is measured. The same continuations are then
    enumerated exhaustively and scored with a forward pass; both distributions
    must sum to one and agree within a relative tolerance.

    The index encoding below is base 2, i.e. it assumes tokens in ``{0, 1}``.
    """
    # TODO: which value is the "beginning of sentence"?
    BOS = 1
    num_steps = 2
    num_seqs = 2 ** 19
    max_seq_length = mtp.n_token * num_steps + 1
    # Number of generated tokens per sequence (everything but the BOS column).
    num_generated = max_seq_length - 1

    # --- Empirical estimate: sample continuations starting from BOS ---------
    seqs = torch.full(size=(num_seqs, 1), fill_value=BOS, dtype=torch.int64)
    for _ in range(num_steps):
        toks = mtp.generate(seqs)['tokens']
        assert toks.shape == (num_seqs, mtp.n_token)
        seqs = torch.cat([seqs, toks], dim=1)
    valid_tokens = torch.tensor(list(range(mtp.lm.encoder.vocab_size)))
    assert torch.all(torch.isin(seqs, valid_tokens))

    # Encode each sampled continuation as an integer, most significant token
    # first; the BOS column gets weight 0 so it does not affect the index.
    # seqs_idx: (num_seqs,)
    place_values = torch.tensor(
        [0] + [2 ** i for i in range(num_generated - 1, -1, -1)]
    )
    seqs_idx = (seqs * place_values).sum(dim=-1)

    # Relative frequency of every observed continuation, sorted by index.
    counts = torch.unique(seqs_idx, return_counts=True)[1]
    ratios = counts / num_seqs
    # Every possible continuation must have been sampled at least once.
    assert len(ratios) == 2 ** num_generated

    # --- Exact likelihoods: enumerate all continuations ---------------------
    # itertools.product yields them in the same order as the integer encoding.
    worlds = torch.tensor(list(itertools.product([0, 1], repeat=num_generated)))
    bos_column = torch.full(
        size=(worlds.shape[0], 1), fill_value=BOS, dtype=torch.int64
    )
    worlds_seqs = torch.cat([bos_column, worlds], dim=1)
    xx = worlds_seqs[:, :-1].contiguous()
    yy = worlds_seqs[:, 1:].contiguous()
    results = mtp(xx, yy, return_log_probs=True)
    log_probs = results['log_probs'].view(worlds.shape[0], -1)
    assert log_probs.shape[1] == num_generated

    # Only the positions where a generation step starts are summed.
    # NOTE(review): presumably each position holds the joint log-probability
    # of the next ``n_token`` tokens -- confirm against MultiTokenLM.
    step_starts = [step * mtp.n_token for step in range(num_steps)]
    worlds_log_probs = log_probs[:, step_starts].sum(dim=1)
    worlds_probs = torch.exp(worlds_log_probs)

    # --- Both distributions are normalised and agree with each other --------
    assert torch.isclose(torch.sum(ratios), torch.tensor(1.0))
    assert torch.isclose(torch.sum(worlds_probs), torch.tensor(1.0))
    assert torch.allclose(ratios, worlds_probs, rtol=4e-2), \
        (worlds_probs / ratios - 1.0).abs().max()
