"""Tests for CategoricalStimulusDataset and its integration with build_env."""
import re

import numpy as np
import pytest

from meta_rg.s2b_import import ensure_s2b_importable
ensure_s2b_importable()

from symbolic_behaviour_benchmark.categorical_stimulus_dataset import (
    CategoricalStimulusDataset,
    CATEGORY_REGISTRY,
)


class TestCategoricalStimulusDataset:
    """Checks on ``CategoricalStimulusDataset``: construction, latent metadata,
    item encoding, text decoding and prototype sharing."""

    def _make_dataset(self, nbr_latents=3, min_val=2, max_val=5, train=True, prototype=None):
        """Build a dataset with the strategies most tests in this class share.

        ``min_val`` / ``max_val`` are forwarded as ``min_nbr_values_per_latent``
        and ``max_nbr_values_per_latent``; a single object-centric sample is
        always requested and no transform is applied.
        """
        config = dict(
            train=train,
            transform=None,
            sampling_strategy="component-focused-1shot",
            split_strategy="combinatorial2-40",
            nbr_latents=nbr_latents,
            min_nbr_values_per_latent=min_val,
            max_nbr_values_per_latent=max_val,
            nbr_object_centric_samples=1,
            prototype=prototype,
        )
        return CategoricalStimulusDataset(**config)

    def test_o_greater_than_1_accepted(self):
        """Construction with four object-centric samples succeeds and the
        private ``_adjectives`` collection ends up with four entries."""
        config = dict(
            train=True,
            transform=None,
            sampling_strategy=None,
            split_strategy=None,
            nbr_latents=3,
            min_nbr_values_per_latent=2,
            max_nbr_values_per_latent=5,
            nbr_object_centric_samples=4,
            prototype=None,
        )
        dataset = CategoricalStimulusDataset(**config)
        assert dataset is not None
        assert len(dataset._adjectives) == 4

    def test_latent_dims_has_distinct_categories(self):
        """Each of the three latent dimensions is given a different category."""
        ds = self._make_dataset(nbr_latents=3)
        cats = []
        for dim in range(3):
            cats.append(ds.latent_dims[dim]['category_name'])
        distinct = set(cats)
        assert len(distinct) == 3, f"Expected 3 distinct categories, got {cats}"

    def test_latent_dims_categories_in_registry(self):
        """Every assigned category name is a key of ``CATEGORY_REGISTRY``."""
        ds = self._make_dataset(nbr_latents=3)
        for dim in range(3):
            assigned = ds.latent_dims[dim]['category_name']
            assert assigned in CATEGORY_REGISTRY

    def test_section_names_are_valid_items(self):
        """Every section name belongs to the registry entry of its dimension's category."""
        ds = self._make_dataset(nbr_latents=3)
        for i in range(3):
            dim_info = ds.latent_dims[i]
            cat = dim_info['category_name']
            for s_idx, section in dim_info['sections'].items():
                assert section['name'] in CATEGORY_REGISTRY[cat], (
                    f"dim {i} section {s_idx} name '{section['name']}' "
                    f"not in {CATEGORY_REGISTRY[cat]}"
                )

    def test_getitem_experiences_dtype_and_shape(self):
        """The first item's ``experiences`` array is float32 with shape (1, 1, 3)."""
        exp = self._make_dataset(nbr_latents=3)[0]['experiences']
        assert exp.dtype == np.float32
        assert exp.shape == (1, 1, 3), f"Expected (1,1,3), got {exp.shape}"

    def test_getitem_experiences_are_nonnegative_integers(self):
        """Values in ``experiences`` are whole, non-negative numbers (first 10 items at most)."""
        ds = self._make_dataset(nbr_latents=3)
        checked = min(10, len(ds))
        for idx in range(checked):
            for v in ds[idx]['experiences'].flatten():
                assert v >= 0.0 and v == int(v), f"Expected non-negative integer, got {v}"

    def test_getitem_experiences_in_bounds(self):
        """No value exceeds ``size - 1`` of its latent dimension (first 20 items at most)."""
        ds = self._make_dataset(nbr_latents=3, max_val=5)
        checked = min(20, len(ds))
        for idx in range(checked):
            flat_values = ds[idx]['experiences'].flatten()
            for lidx, v in enumerate(flat_values):
                # Largest legal index for this latent dimension.
                max_idx = ds.latent_dims[lidx]['size'] - 1
                assert int(v) <= max_idx, f"Index {v} out of range for dim {lidx} (max {max_idx})"

    def test_latent_class_to_text_returns_correct_names(self):
        """A single stimulus decodes to the section names at the chosen indices."""
        ds = self._make_dataset(nbr_latents=3)
        # Known section index per latent dimension: [0, 1, 0]
        chosen = (0, 1, 0)
        flat = np.array(chosen, dtype=np.float32)
        result = ds.latent_class_to_text(flat)
        assert len(result) == 1, f"Expected 1 stimulus group, got {len(result)}"
        labels = result[0]
        assert len(labels) == 3
        for lidx, section_idx in enumerate(chosen):
            assert labels[lidx] == ds.latent_dims[lidx]['sections'][section_idx]['name']

    def test_latent_class_to_text_with_multiple_stimuli(self):
        """Six indices decode to two groups of three string labels drawn from the registry."""
        ds = self._make_dataset(nbr_latents=3)
        # Two stimuli back to back (e.g. target + 1 distractor): 3 + 3 indices
        flat = np.array([0.0] * 3 + [1.0] * 3, dtype=np.float32)
        result = ds.latent_class_to_text(flat)
        assert len(result) == 2, f"Expected 2 stimulus groups, got {len(result)}"
        for group_idx in range(len(result)):
            labels = result[group_idx]
            assert len(labels) == 3, f"Group {group_idx} has {len(labels)} labels, expected 3"
            for lidx, label in enumerate(labels):
                assert isinstance(label, str), f"Group {group_idx} dim {lidx}: expected str, got {type(label)}"
                cat = ds.latent_dims[lidx]['category_name']
                assert label in CATEGORY_REGISTRY[cat], (
                    f"Group {group_idx} dim {lidx}: '{label}' not in category '{cat}'"
                )

    def test_prototype_shares_category_assignments(self):
        """A test split built from a train prototype reuses its categories and section names."""
        train_ds = self._make_dataset(nbr_latents=3, train=True)
        test_ds = self._make_dataset(nbr_latents=3, train=False, prototype=train_ds)
        for dim in range(3):
            test_dim = test_ds.latent_dims[dim]
            train_dim = train_ds.latent_dims[dim]
            assert test_dim['category_name'] == train_dim['category_name']
            # Compare section by section, driven by the test split's section keys.
            for s_idx in test_dim['sections']:
                test_name = test_dim['sections'][s_idx]['name']
                assert test_name == train_dim['sections'][s_idx]['name']

    def test_registry_has_enough_items_for_max_val(self):
        """Each registry category lists at least ten items."""
        # Every category must have >= max_nbr_values_per_latent items
        max_val = 10
        for cat_name in CATEGORY_REGISTRY:
            items = CATEGORY_REGISTRY[cat_name]
            assert len(items) >= max_val, (
                f"Category '{cat_name}' has only {len(items)} items, need >= {max_val}"
            )

    def test_requires_enough_categories_for_latents(self):
        """Asking for more latents than registry categories raises ``ValueError``."""
        one_too_many = len(CATEGORY_REGISTRY) + 1
        with pytest.raises(ValueError, match="categories"):
            self._make_dataset(nbr_latents=one_too_many)


class TestBuildEnvCategoricalDomain:
    """Integration tests for ``build_env`` called with ``domain='categorical'``.

    ``meta_rg`` helpers are imported inside each test rather than at module
    level, mirroring the original layout of this file.
    """

    def test_build_env_categorical_succeeds(self):
        """``build_env`` returns an environment for one object-centric sample."""
        from meta_rg.env_utils import build_env
        env = build_env(
            nbr_latents=3,
            min_nbr_values_per_latent=2,
            max_nbr_values_per_latent=5,
            nbr_object_centric_samples=1,
            sampling_strategy="component-focused-1shot",
            domain='categorical',
            seed=0,
        )
        assert env is not None

    def test_build_env_categorical_o4_accepted(self):
        """``build_env`` accepts ``nbr_object_centric_samples=4``; the remaining
        arguments are left at whatever defaults ``build_env`` defines."""
        from meta_rg.env_utils import build_env
        env = build_env(
            nbr_latents=3,
            nbr_object_centric_samples=4,
            domain='categorical',
            seed=0,
        )
        assert env is not None

    def test_categorical_prompt_contains_names_not_floats(self):
        """After env.reset(), the listener prompt must contain category names.

        One step is played (speaker acts, second agent sends a no-op) and the
        prompt taken from ``infos2[1]`` is checked twice: the line following
        the ``stimulus:`` marker must be free of decimal numbers, and the
        prompt must mention at least one ``CATEGORY_REGISTRY`` item.
        """
        from meta_rg.env_utils import build_env, get_prompt_text, no_op_action, ensure_action_shape
        from meta_rg.agents.rule_based import build_rule_based_agents

        env = build_env(
            nbr_latents=3,
            min_nbr_values_per_latent=2,
            max_nbr_values_per_latent=5,
            nbr_object_centric_samples=1,
            sampling_strategy="component-focused-1shot",
            domain='categorical',
            seed=0,
        )
        # Only the speaker agent is driven here; the second agent is not needed.
        speaker, _ = build_rule_based_agents()
        obs, infos = env.reset()
        speaker.reset()

        # Send a speaker action so the listener gets a prompt with the stimulus
        a0 = speaker.next_action(state=obs[0], infos=infos[0])
        a0 = ensure_action_shape(a0, max_sentence_length=env.max_sentence_length)
        a1 = no_op_action(max_sentence_length=env.max_sentence_length)
        _, _, _, infos2 = env.step([a0, a1])

        prompt_text = get_prompt_text(infos2[1])

        # Without the marker, split() would hand back the whole prompt and the
        # float check below would silently inspect an unrelated first line.
        marker = "stimulus:"
        assert marker in prompt_text, (
            f"Prompt has no '{marker}' marker; cannot locate the stimulus part.\n"
            f"Prompt:\n{prompt_text}"
        )

        # Stimulus part must NOT contain decimal-formatted floats (e.g. -0.321)
        stim_part = prompt_text.split(marker)[-1].split("\n")[0]
        assert not re.search(r"-?\d+\.\d+", stim_part), (
            f"Prompt stimulus part contains float values: {stim_part}"
        )

        # The prompt (checked as a whole, not only the stimulus line) MUST
        # contain at least one item from the category registry.
        all_items = [item for items in CATEGORY_REGISTRY.values() for item in items]
        assert any(item in prompt_text for item in all_items), (
            f"Prompt does not contain any category items.\nPrompt:\n{prompt_text}"
        )
