"""Tests for the pseudoword stimulus domain."""
import re

import numpy as np
import pytest

from meta_rg.s2b_import import ensure_s2b_importable
ensure_s2b_importable()

from symbolic_behaviour_benchmark.pseudoword_stimulus_dataset import (
    PseudowordStimulusDataset, VOWELS, CONSONANTS, ALPHABET,
)
from meta_rg.env_utils import build_env, get_prompt_text, no_op_action, ensure_action_shape
from meta_rg.agents.rule_based import build_rule_based_agents


# ── Shared helpers ─────────────────────────────────────────────────────────────

def _make_ds(**kw):
    """Build a ``PseudowordStimulusDataset`` with small test-friendly settings.

    Any keyword argument given in ``kw`` overrides the matching default
    below, or is passed through as an extra constructor argument.
    """
    settings = {
        'nbr_latents': 3,
        'min_nbr_values_per_latent': 2,
        'max_nbr_values_per_latent': 4,
        'min_word_length': 2,
        'max_word_length': 6,
        # Caller-supplied values win over the defaults above.
        **kw,
    }
    return PseudowordStimulusDataset(**settings)


def _base_words(ds):
    return [s['name'] for ld in ds.latent_dims.values() for s in ld['sections'].values()]


# ── Grammar / format tests (O=1) ───────────────────────────────────────────────

class TestPseudowordBase:
    """Format and indexing checks for the pseudoword dataset with its default O."""

    def test_words_strictly_alternating_cv(self):
        """Every base word is even-length and alternates consonant, vowel."""
        dataset = _make_ds()
        consonants = set(CONSONANTS)
        vowels = set(VOWELS)
        for word in _base_words(dataset):
            letters = word.lower()
            assert len(letters) % 2 == 0, f"'{word}' has odd length"
            for i, ch in enumerate(letters):
                # Odd positions hold vowels, even positions hold consonants.
                if i % 2:
                    assert ch in vowels, f"pos {i} of '{word}' should be vowel"
                else:
                    assert ch in consonants, f"pos {i} of '{word}' should be consonant"

    def test_words_are_uppercase(self):
        """Base words are purely alphabetic and already upper-cased."""
        dataset = _make_ds()
        for word in _base_words(dataset):
            assert word == word.upper() and word.isalpha()

    def test_words_globally_unique_within_episode(self):
        """No base word is reused, within or across latent dimensions."""
        dataset = _make_ds(
            nbr_latents=4,
            min_nbr_values_per_latent=3,
            max_nbr_values_per_latent=5,
        )
        words = _base_words(dataset)
        assert len(words) == len(set(words)), f"duplicates: {words}"

    def test_latent_dims_sections_have_name_keys(self):
        """Each section dict of each latent dimension carries a 'name' entry."""
        dataset = _make_ds()
        for l_idx, latent_dim in dataset.latent_dims.items():
            for s_idx, section in latent_dim['sections'].items():
                assert 'name' in section, f"dim {l_idx} section {s_idx} missing 'name'"

    def test_getitem_shape_and_dtype(self):
        """Indexing yields float32 'experiences' of shape (1, 1, nbr_latents)."""
        dataset = _make_ds(nbr_latents=3)
        experiences = dataset[0]['experiences']
        assert experiences.dtype == np.float32
        assert experiences.shape == (1, 1, 3)

    def test_getitem_values_in_bounds(self):
        """Each sampled value lies within [0, size - 1] of its latent dimension."""
        dataset = _make_ds(nbr_latents=3)
        # Inspect at most the first ten items.
        nbr_checked = min(len(dataset), 10)
        for idx in range(nbr_checked):
            values = dataset[idx]['experiences'].flatten()
            for l_idx, value in enumerate(values):
                upper = dataset.latent_dims[l_idx]['size'] - 1
                assert 0 <= int(value) <= upper

    def test_latent_class_to_text_o1(self):
        """Decoding one item gives a single row of three upper-case labels."""
        dataset = _make_ds(nbr_latents=3)
        flat = dataset[0]['experiences'].flatten().astype(np.float32)
        decoded = dataset.latent_class_to_text(flat)
        assert len(decoded) == 1 and len(decoded[0]) == 3
        for label in decoded[0]:
            assert isinstance(label, str) and label == label.upper()
            # O=1: no suffix, so label IS the base word (even length from (CV)+)
            assert len(label) % 2 == 0

    def test_prototype_shares_words(self):
        """A dataset built from a prototype reuses the prototype's words."""
        train_ds = _make_ds()
        test_ds = PseudowordStimulusDataset(
            train=False,
            prototype=train_ds,
            nbr_latents=3,
            min_nbr_values_per_latent=2,
            max_nbr_values_per_latent=4,
        )
        for l_idx in train_ds.latent_dims:
            train_names = [
                sec['name'] for sec in train_ds.latent_dims[l_idx]['sections'].values()
            ]
            test_names = [
                sec['name'] for sec in test_ds.latent_dims[l_idx]['sections'].values()
            ]
            assert train_names == test_names

    def test_no_registry_limit_on_latents(self):
        """Fifteen latent dimensions can be requested and are all created."""
        dataset = PseudowordStimulusDataset(
            nbr_latents=15,
            min_nbr_values_per_latent=2,
            max_nbr_values_per_latent=3,
        )
        assert len(dataset.latent_dims) == 15




# ── Object-centric sampling (O>1) ─────────────────────────────────────────────

class TestPseudowordOC:
    """Checks for pseudoword datasets built with ``nbr_object_centric_samples`` > 1.

    The tests read the dataset's ``_suffixes`` list and treat a composite
    observation value as ``base_index * O + oc_index``.
    """

    def test_o_greater_than_1_accepted(self):
        """Constructing the dataset with O=4 succeeds."""
        ds = _make_ds(nbr_object_centric_samples=4)
        assert ds is not None

    def test_suffix_count_equals_o(self):
        """The dataset holds exactly one suffix per object-centric sample."""
        for O in (2, 4, 8, 16):
            ds = _make_ds(nbr_object_centric_samples=O)
            assert len(ds._suffixes) == O, f"O={O}: expected {O} suffixes, got {len(ds._suffixes)}"

    def test_suffixes_are_uppercase(self):
        """Every suffix is alphabetic and upper-case."""
        ds = _make_ds(nbr_object_centric_samples=4)
        for s in ds._suffixes:
            assert s == s.upper() and s.isalpha(), f"suffix '{s}' not UPPERCASE alpha"

    def test_suffixes_globally_unique(self):
        """No suffix appears twice."""
        ds = _make_ds(nbr_object_centric_samples=8)
        assert len(ds._suffixes) == len(set(ds._suffixes))

    def test_suffix_length_scales_with_o(self):
        """Suffixes are one character at O=16 and two characters at O=17."""
        # O<=16: suffix_len=1; O=17..256: suffix_len=2
        # (only the 16/17 boundary is exercised here)
        ds16 = _make_ds(nbr_object_centric_samples=16)
        ds17 = _make_ds(nbr_object_centric_samples=17)
        assert all(len(s) == 1 for s in ds16._suffixes)
        assert all(len(s) == 2 for s in ds17._suffixes)

    def test_suffixes_shared_across_latents(self):
        """All latent values use the exact same suffix set."""
        O = 4
        ds = _make_ds(nbr_object_centric_samples=O)
        # The suffix set is stored once, not per latent.
        assert hasattr(ds, '_suffixes')
        assert len(ds._suffixes) == O
        # Decode the first two base values of every latent. _make_ds guarantees
        # at least two values per latent via min_nbr_values_per_latent=2.
        # For a given OC index, every label must carry the same shared suffix.
        for base_idx in (0, 1):
            base_indices = np.full(3, base_idx, dtype=np.float32)
            for oc_idx in range(O):
                labels = ds.latent_class_to_text(base_indices * O + oc_idx)[0]
                shared_suffix = ds._suffixes[oc_idx]
                assert all(label.endswith(shared_suffix) for label in labels), (
                    f"base {base_idx}, oc_idx={oc_idx}: labels {labels} "
                    f"do not all end with '{shared_suffix}'"
                )

    def test_prototype_shares_suffixes(self):
        """A dataset built from a prototype reuses the prototype's suffixes."""
        train_ds = _make_ds(nbr_object_centric_samples=4)
        test_ds = PseudowordStimulusDataset(
            train=False, prototype=train_ds,
            nbr_latents=3, min_nbr_values_per_latent=2, max_nbr_values_per_latent=4,
            nbr_object_centric_samples=4,
        )
        assert train_ds._suffixes == test_ds._suffixes

    def test_composite_encoding_and_decoding(self):
        """Composite values decode to a known base word plus the OC suffix."""
        O = 4
        ds = _make_ds(nbr_object_centric_samples=O)
        # The word list does not depend on the loop variables, so build it once.
        known_words = set(_base_words(ds))
        # Manually construct composite observations and check latent_class_to_text:
        # base_idx=0 for each of the 3 latents, with oc_idx swept over range(O).
        base_indices = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        for oc_idx in range(O):
            encoded = base_indices * O + oc_idx
            result = ds.latent_class_to_text(encoded)
            assert len(result) == 1 and len(result[0]) == 3
            # label = base_word + suffix[oc_idx]
            expected_suffix = ds._suffixes[oc_idx]
            for label in result[0]:
                assert label.endswith(expected_suffix), (
                    f"oc_idx={oc_idx}: '{label}' should end with '{expected_suffix}'"
                )
                # Base word part (strip suffix); guard against an empty suffix,
                # where label[:-0] would wrongly give ''.
                base_part = label[:-len(expected_suffix)] if expected_suffix else label
                assert base_part in known_words, f"base part '{base_part}' not in word list"

    def test_all_oc_samples_same_base_different_suffix(self):
        """OC samples of one latent class share a base word and differ by suffix."""
        O = 4
        ds = _make_ds(nbr_object_centric_samples=O)
        latent_class = np.zeros(3, dtype=np.float32)  # first value of each dim
        labels_per_oc = []
        for oc_idx in range(O):
            obs = ds.generate_object_centric_observations(latent_class, oc_idx)
            result = ds.latent_class_to_text(obs)
            labels_per_oc.append(result[0])
        # The suffix for each OC index is the same for every dimension.
        suffixes = [ds._suffixes[oc] for oc in range(O)]
        # All OC samples have same base word but different suffixes
        for dim_idx in range(3):
            texts = [labels_per_oc[oc][dim_idx] for oc in range(O)]
            for oc, (text, suf) in enumerate(zip(texts, suffixes)):
                assert text.endswith(suf), f"dim {dim_idx} OC {oc}: '{text}' missing suffix '{suf}'"
            # Bases are identical across OC samples
            bases = [t[:-len(s)] if s else t for t, s in zip(texts, suffixes)]
            assert len(set(bases)) == 1, f"dim {dim_idx} base words differ across OC: {bases}"

    def test_dataset_size_includes_o(self):
        """``dataset_size`` is the product of the latent sizes times O."""
        O = 4
        ds = _make_ds(nbr_object_centric_samples=O, min_nbr_values_per_latent=2,
                      max_nbr_values_per_latent=2, nbr_latents=2)
        # dataset_size = 2 * 2 * O = 16 (before train/test split)
        assert ds.dataset_size == 2 * 2 * O


# ── Env-level tests ────────────────────────────────────────────────────────────

class TestPseudowordEnv:
    """Environment-level smoke tests for the 'pseudoword' domain."""

    @staticmethod
    def _prompt_after_speaker_turn(nbr_object_centric_samples):
        """Play one speaker turn and return the prompt text from ``infos[1]``.

        Builds a seeded pseudoword env, lets the rule-based speaker act once
        while the other agent sends a no-op, then extracts the prompt text
        from the second entry of the step's infos.
        """
        env = build_env(
            nbr_latents=3, nbr_object_centric_samples=nbr_object_centric_samples,
            sampling_strategy='component-focused-1shot',
            domain='pseudoword', seed=42,
        )
        speaker, _ = build_rule_based_agents()
        obs, infos = env.reset()
        speaker.reset()
        speaker_action = ensure_action_shape(
            speaker.next_action(state=obs[0], infos=infos[0]),
            env.max_sentence_length,
        )
        joint_action = [speaker_action, no_op_action(env.max_sentence_length)]
        _, _, _, step_infos = env.step(joint_action)
        return get_prompt_text(step_infos[1])

    def test_build_env_pseudoword_o1(self):
        """The env builds with the default number of object-centric samples."""
        env = build_env(domain='pseudoword', seed=0)
        assert env is not None

    def test_build_env_pseudoword_o4(self):
        """The env builds with four object-centric samples."""
        env = build_env(domain='pseudoword', nbr_object_centric_samples=4, seed=0)
        assert env is not None

    def test_pseudoword_prompt_uppercase_not_floats(self):
        """With O=1 the prompt shows upper-case words and no float literals."""
        prompt = self._prompt_after_speaker_turn(1)
        assert not re.search(r'\d+\.\d+', prompt), "prompt contains float literals"
        assert re.search(r'[A-Z]{2,}', prompt), "prompt contains no uppercase words"

    def test_pseudoword_o4_prompt_has_suffixed_words(self):
        """With O=4 the prompt shows upper-case words and no float literals.

        NOTE(review): despite the name, nothing here asserts that the words
        carry a suffix; the checks are the same as for the O=1 prompt test.
        """
        prompt = self._prompt_after_speaker_turn(4)
        assert not re.search(r'\d+\.\d+', prompt), "prompt contains float literals"
        assert re.search(r'[A-Z]{2,}', prompt), "prompt contains no uppercase words"

    def test_reproducibility_same_seed(self):
        """Equal seeds give equal word lists; a different seed gives different ones."""
        def _words_for(seed):
            # Build, seed and reset the env, then read the section names of
            # the nested train dataset, one list per latent dimension.
            env = build_env(domain='pseudoword', nbr_latents=3, seed=seed)
            env.seed(seed)
            env.reset()
            train_ds = env.datasets['train'].datasets['train']
            per_dim = []
            for latent_dim in train_ds.latent_dims.values():
                per_dim.append([sec['name'] for sec in latent_dim['sections'].values()])
            return per_dim

        first = _words_for(7)
        repeat = _words_for(7)
        other = _words_for(99)

        assert first == repeat, "same seed must produce same words"
        assert first != other, "different seeds should produce different words"


# ── Categorical O>1 tests ──────────────────────────────────────────────────────

class TestCategoricalOC:
    """Checks for categorical datasets built with ``nbr_object_centric_samples`` > 1.

    The tests read the dataset's ``_adjectives`` list and treat a composite
    observation value as ``base_index * O + oc_index``.
    """

    def _make_cat_ds(self, O, **kw):
        """Build a ``CategoricalStimulusDataset`` with O object-centric samples.

        Keyword arguments in ``kw`` override the small defaults used here.
        The dataset class is imported lazily inside the helper.
        """
        from symbolic_behaviour_benchmark.categorical_stimulus_dataset import CategoricalStimulusDataset
        defaults = dict(nbr_latents=3, min_nbr_values_per_latent=2,
                        max_nbr_values_per_latent=4, nbr_object_centric_samples=O)
        defaults.update(kw)
        return CategoricalStimulusDataset(**defaults)

    def test_o_greater_than_1_accepted(self):
        """Constructing the dataset with O=4 succeeds."""
        ds = self._make_cat_ds(4)
        assert ds is not None

    def test_adjective_count_equals_o(self):
        """The dataset holds exactly one adjective per object-centric sample."""
        for O in (2, 4, 8):
            ds = self._make_cat_ds(O)
            assert len(ds._adjectives) == O, f"O={O}: expected {O} adjectives"

    def test_adjectives_globally_unique(self):
        """No adjective appears twice."""
        ds = self._make_cat_ds(8)
        assert len(ds._adjectives) == len(set(ds._adjectives))

    def test_adjectives_are_lowercase_strings(self):
        """Every adjective is a lower-case alphabetic string."""
        ds = self._make_cat_ds(4)
        for adj in ds._adjectives:
            assert (
                isinstance(adj, str) and adj.isalpha() and adj == adj.lower()
            ), f"adjective '{adj}' not a lowercase alpha string"

    def test_prototype_shares_adjectives(self):
        """A dataset built from a prototype reuses the prototype's adjectives."""
        from symbolic_behaviour_benchmark.categorical_stimulus_dataset import CategoricalStimulusDataset
        train_ds = self._make_cat_ds(4)
        test_ds = CategoricalStimulusDataset(
            train=False, prototype=train_ds,
            nbr_latents=3, min_nbr_values_per_latent=2, max_nbr_values_per_latent=4,
            nbr_object_centric_samples=4,
        )
        assert train_ds._adjectives == test_ds._adjectives

    def test_composite_encoding_and_decoding(self):
        """Composite values decode to labels containing the OC adjective."""
        O = 4
        ds = self._make_cat_ds(O)
        base_indices = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        for oc_idx in range(O):
            encoded = base_indices * O + oc_idx
            result = ds.latent_class_to_text(encoded)
            assert len(result) == 1 and len(result[0]) == 3
            adj = ds._adjectives[oc_idx]
            for label in result[0]:
                assert adj in label, f"adjective '{adj}' missing from label '{label}'"

    def test_adjective_and_category_word_in_label(self):
        """Repeated renders always contain both the category word and adjective."""
        O = 4
        ds = self._make_cat_ds(O)
        encoded = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        # Expected word and adjective for base index 0 / OC index 0 are fixed,
        # so look them up once rather than on every render.
        word = ds.latent_dims[0]['sections'][0]['name']
        adj = ds._adjectives[0]
        for _ in range(10):
            label = ds.latent_class_to_text(encoded)[0][0]
            assert word in label and adj in label, (
                f"label '{label}' missing word '{word}' or adj '{adj}'"
            )

    def test_adjective_position_varies(self):
        """Across 50 renders the adjective appears both first and not first."""
        O = 4
        ds = self._make_cat_ds(O)
        encoded = np.array([0.0, 0.0, 0.0], dtype=np.float32)
        adj = ds._adjectives[0]
        forms = set()
        for _ in range(50):
            label = ds.latent_class_to_text(encoded)[0][0]
            # Without this check, a label lacking the adjective altogether
            # would be counted as the 'suffix' form below.
            assert adj in label, f"adjective '{adj}' missing from label '{label}'"
            forms.add('prefix' if label.startswith(adj) else 'suffix')
        assert len(forms) == 2, f"adjective position never varied across 50 renders (only: {forms})"

    def test_dataset_size_includes_o(self):
        """``dataset_size`` is the product of the latent sizes times O."""
        O = 4
        ds = self._make_cat_ds(O, min_nbr_values_per_latent=2, max_nbr_values_per_latent=2,
                               nbr_latents=2)
        assert ds.dataset_size == 2 * 2 * O

    def test_o_exceeds_pool_raises(self):
        """Requesting more samples than ``ADJECTIVE_POOL`` holds raises ValueError."""
        from symbolic_behaviour_benchmark.categorical_stimulus_dataset import ADJECTIVE_POOL
        with pytest.raises(ValueError, match="ADJECTIVE_POOL"):
            self._make_cat_ds(len(ADJECTIVE_POOL) + 1)

    def test_build_env_categorical_o4(self):
        """The categorical env builds with four object-centric samples."""
        env = build_env(domain='categorical', nbr_object_centric_samples=4, seed=0)
        assert env is not None

    def test_reproducibility_same_seed(self):
        """Equal seeds give equal adjectives; a different seed gives different ones."""
        def _get_adjectives(seed):
            # Build, seed and reset the env, then read the adjectives of the
            # nested train dataset.
            env = build_env(domain='categorical', nbr_latents=3, nbr_object_centric_samples=4,
                            seed=seed)
            env.seed(seed)
            env.reset()
            return env.datasets['train'].datasets['train']._adjectives

        adjs_a = _get_adjectives(7)
        adjs_b = _get_adjectives(7)
        adjs_c = _get_adjectives(99)
        assert adjs_a == adjs_b, "same seed must produce same adjectives"
        assert adjs_a != adjs_c, "different seeds should produce different adjectives"
