"""Tests for VocabPartitionedSpeakerAgent."""
import itertools

import numpy as np
import pytest

from meta_rg.agents.rule_based import VocabPartitionedSpeakerAgent, build_rule_based_agents


def _make_infos(latent_values: list) -> dict:
    # Shape (1, 1, nbr_latents): [0,0] → 1-D array, [sid] → scalar
    return {
        "speaker_exp_latents": np.array([[latent_values]]),
        "round_idx": 0,
    }


def test_encoding_correctness():
    """Each emitted token lands inside the block reserved for its latent."""
    # block_size = (vocab_size - 1) // nbr_latents = (13 - 1) // 3 = 4.
    # Expected token for latent i holding value v: i * block_size + v + 1
    # (the +1 offset means token 0 is never produced).
    block_size = 4
    latent_values = [0, 1, 2]
    agent = VocabPartitionedSpeakerAgent(
        action_space=None,
        vocab_size=13,
        max_sentence_length=3,
        nbr_communication_rounds=1,
        nbr_latents=3,
        max_nbr_values_per_latent=4,
    )
    agent.reset()
    action = agent.next_action(state=None, infos=_make_infos(latent_values))
    # The outer dim is stripped by the agent's _reg_comm_chan, so the channel
    # is 1-D with one entry per sentence position.
    emitted = action["communication_channel"].astype(int).tolist()

    # -> [1, 6, 11]
    expected = [idx * block_size + val + 1 for idx, val in enumerate(latent_values)]
    for pos, want in enumerate(expected):
        assert emitted[pos] == want


def test_tokens_disjoint_across_latents():
    """No token overlap between latent ranges regardless of value combination.

    Sweeps every combination of latent values and checks that:

    * within one utterance, no two positions share a token;
    * across all utterances, the tokens ever emitted for one latent (one
      sentence position) never appear for any other latent;
    * every token lies in ``[1, vocab_size - 1]``.
    """
    # vocab_size=16, nbr_latents=3 -> block_size = (16 - 1) // 3 = 5,
    # which leaves room for max_nbr_values_per_latent=4 values per latent.
    vocab_size, nbr_latents, max_vals = 16, 3, 4
    agent = VocabPartitionedSpeakerAgent(
        action_space=None, vocab_size=vocab_size, max_sentence_length=3,
        nbr_communication_rounds=1, nbr_latents=nbr_latents,
        max_nbr_values_per_latent=max_vals,
    )
    # Tokens observed at each sentence position, accumulated over all combos.
    per_position = [set() for _ in range(nbr_latents)]
    for values in itertools.product(range(max_vals), repeat=nbr_latents):
        agent.reset()
        action = agent.next_action(state=None, infos=_make_infos(list(values)))
        t = tuple(action["communication_channel"].astype(int)[:nbr_latents])
        # Tokens across positions within one utterance must be distinct.
        assert len(set(t)) == nbr_latents, (
            f"token collision for vals {','.join(map(str, values))}: {t}"
        )
        for pos, tok in enumerate(t):
            per_position[pos].add(tok)

    # The per-latent token ranges must be disjoint across *all* utterances,
    # not just within a single one.
    for i, j in itertools.combinations(range(nbr_latents), 2):
        overlap = per_position[i] & per_position[j]
        assert not overlap, f"latents {i} and {j} share tokens {sorted(overlap)}"

    all_tokens = set().union(*per_position)
    assert all(1 <= tok <= vocab_size - 1 for tok in all_tokens)


def test_rejects_small_vocab():
    """A vocab too small to give each latent a full block raises ValueError."""
    # (6 - 1) // 3 = 1 token per latent block, but 5 values per latent are needed.
    too_small = dict(
        action_space=None,
        vocab_size=6,
        max_sentence_length=3,
        nbr_communication_rounds=1,
        nbr_latents=3,
        max_nbr_values_per_latent=5,
    )
    with pytest.raises(ValueError, match="block_size"):
        VocabPartitionedSpeakerAgent(**too_small)


def test_error_message_names_required_size():
    """ValueError message must tell the user the minimum vocab_size needed."""
    nbr_latents, max_vals = 3, 5
    # Minimum vocab: nbr_latents blocks of max_vals tokens, plus one (token 0
    # is never emitted) -> 3 * 5 + 1 = 16.
    required = nbr_latents * max_vals + 1
    # `match` is a regex *search*; word boundaries stop e.g. "160" or "216"
    # elsewhere in the message from satisfying the check by accident.
    with pytest.raises(ValueError, match=rf"\b{required}\b"):
        VocabPartitionedSpeakerAgent(
            action_space=None, vocab_size=6, max_sentence_length=3,
            nbr_communication_rounds=1, nbr_latents=nbr_latents,
            max_nbr_values_per_latent=max_vals,
        )


def test_build_rule_based_agents_vocab_partition():
    """Opting into vocab_partition makes the factory build the partitioned speaker."""
    factory_kwargs = {
        "vocab_size": 16,
        "max_sentence_length": 3,
        "nbr_communication_rounds": 1,
        "nbr_latents": 3,
        "max_nbr_values_per_latent": 4,
        "vocab_partition": True,
    }
    # The factory returns a pair; only the first element (the speaker) is checked.
    speaker, _unused = build_rule_based_agents(**factory_kwargs)
    assert isinstance(speaker, VocabPartitionedSpeakerAgent)


def test_build_rule_based_agents_default_unchanged():
    """Without vocab_partition the factory keeps returning the upstream speaker."""
    from symbolic_behaviour_benchmark.rule_based_agents.positionally_disentangled_speaker_agent import (
        PositionallyDisentangledSpeakerAgent as UpstreamSpeaker,
    )

    speaker, _unused = build_rule_based_agents(vocab_size=6)
    # Exact type comparison: an isinstance check would also pass for any
    # subclass of the upstream speaker.
    assert type(speaker) is UpstreamSpeaker
