"""S2B environment factory and game-step helpers."""
from __future__ import annotations

import numpy as np

from meta_rg.s2b_import import ensure_s2b_importable

ensure_s2b_importable()

from symbolic_behaviour_benchmark.envs import generate_receptive_constructive_test_env  # noqa: E402
from symbolic_behaviour_benchmark.utils.utils import BT2STR  # noqa: E402


# ── Environment factory ────────────────────────────────────────────────────────

def build_env(
    nbr_latents: int = 3,
    nbr_distractors: int = 0,
    vocab_size: int = 6,
    max_sentence_length: int = 3,
    min_nbr_values_per_latent: int = 2,
    max_nbr_values_per_latent: int = 5,
    nbr_communication_rounds: int = 1,
    descriptive: bool = True,
    nbr_object_centric_samples: int = 1,
    provide_listener_feedback: bool = True,
    sampling_strategy: str | None = "component-focused-1shot",
    seed: int = 0,
    discussion_mode: bool = False,
    domain: str = 'SCS',
    verbose_prompts: bool = False,
    allow_cot_response: bool = False,
    elicitation_strategies: list | None = None,
):
    """Instantiate an S2B ReceptiveConstructiveTest environment without gym.make.

    Every argument is handed to ``generate_receptive_constructive_test_env``
    under the same keyword name, so the semantics of each option are defined
    by that generator. ``include_prompts`` is not configurable here and is
    always passed as ``True``.

    Returns:
        The object returned by ``generate_receptive_constructive_test_env``.
    """
    generator_kwargs = {
        "nbr_latents": nbr_latents,
        "nbr_distractors": nbr_distractors,
        "vocab_size": vocab_size,
        "max_sentence_length": max_sentence_length,
        "min_nbr_values_per_latent": min_nbr_values_per_latent,
        "max_nbr_values_per_latent": max_nbr_values_per_latent,
        "nbr_communication_rounds": nbr_communication_rounds,
        "descriptive": descriptive,
        # Hard-wired: this factory always asks the generator for prompts.
        "include_prompts": True,
        "seed": seed,
        "nbr_object_centric_samples": nbr_object_centric_samples,
        "provide_listener_feedback": provide_listener_feedback,
        "sampling_strategy": sampling_strategy,
        "discussion_mode": discussion_mode,
        "domain": domain,
        "verbose_prompts": verbose_prompts,
        "allow_cot_response": allow_cot_response,
        "elicitation_strategies": elicitation_strategies,
    }
    return generate_receptive_constructive_test_env(**generator_kwargs)


def build_env_from_cfg(cfg: dict, seed: int = 0):
    """Build an environment from a flat config dict.

    Only the keys listed below are read from ``cfg``; each one that is present
    is forwarded to ``build_env`` under the same name, and absent ones fall
    back to ``build_env``'s defaults. Any other key in ``cfg`` is ignored, and
    the seed always comes from the ``seed`` argument.

    Args:
        cfg: Mapping whose relevant keys match ``build_env`` keyword names.
        seed: Seed forwarded to ``build_env``.

    Returns:
        The environment returned by ``build_env``.
    """
    recognised_keys = (
        "nbr_latents",
        "nbr_distractors",
        "vocab_size",
        "max_sentence_length",
        "min_nbr_values_per_latent",
        "max_nbr_values_per_latent",
        "nbr_communication_rounds",
        "descriptive",
        "nbr_object_centric_samples",
        "provide_listener_feedback",
        "sampling_strategy",
        "discussion_mode",
        "domain",
        "verbose_prompts",
        "allow_cot_response",
        "elicitation_strategies",
    )
    forwarded = {}
    for key in recognised_keys:
        if key in cfg:
            forwarded[key] = cfg[key]
    return build_env(seed=seed, **forwarded)


# ── Low-level action helpers ───────────────────────────────────────────────────

def no_op_action(max_sentence_length: int) -> dict:
    """Return an action dict with decision 0 and an all-zero message.

    Args:
        max_sentence_length: Number of columns in the communication channel.

    Returns:
        A dict with ``"decision"`` set to ``0`` and ``"communication_channel"``
        set to an int64 zero array of shape ``(1, max_sentence_length)``.
    """
    empty_message = np.zeros(shape=(1, max_sentence_length), dtype=np.int64)
    action = dict(decision=0)
    action["communication_channel"] = empty_message
    return action


def int_comm_to_ohe(comm_int: np.ndarray, vocab_size: int, max_sentence_length: int) -> np.ndarray:
    """
    Convert integer communication channel (1, L) → OHE (1, L*(vocab_size+1)).
    Required by PositionallyDisentangledListenerAgent.

    Each sentence position owns a contiguous slot of ``vocab_size + 1``
    entries; the token at position ``i`` sets entry ``i * (vocab_size + 1) +
    token`` to 1.0. Only the first row of ``comm_int`` is read. Positions for
    which no token is supplied stay all-zero.

    Args:
        comm_int: Integer token ids; ``comm_int[0]`` is iterated as the sentence.
        vocab_size: Largest valid token id (ids range over ``0..vocab_size``).
        max_sentence_length: Number of positions ``L`` in the encoded output.

    Returns:
        A float64 array of shape ``(1, max_sentence_length * (vocab_size + 1))``.

    Raises:
        ValueError: If a token id lies outside ``[0, vocab_size]``.
        IndexError: If the sentence holds more than ``max_sentence_length``
            tokens (raised by the out-of-bounds array write).
    """
    slot_width = vocab_size + 1
    ohe = np.zeros((1, max_sentence_length * slot_width), dtype=np.float64)
    for position, raw_token in enumerate(comm_int[0]):
        token = int(raw_token)
        # An unchecked id would either spill into a neighbouring position's
        # slot (too large) or wrap around to the end of the vector (negative),
        # silently producing a corrupted encoding.
        if not 0 <= token < slot_width:
            raise ValueError(
                f"token {token} at position {position} is outside the valid "
                f"range [0, {vocab_size}]"
            )
        ohe[0, position * slot_width + token] = 1.0
    return ohe


def ensure_action_shape(action: dict, max_sentence_length: int) -> dict:
    """Return a copy of ``action`` whose 1-D message array has a leading batch axis.

    A missing ``"communication_channel"`` is replaced by an int64 zero array of
    shape ``(1, max_sentence_length)``. A 1-D NumPy array is reshaped to
    ``(1, -1)``. Any other value (arrays of a different rank, non-array
    objects) is passed through unchanged. The input dict is not modified.

    Args:
        action: Action dict, optionally holding ``"communication_channel"``.
        max_sentence_length: Width of the zero message used when none is given.

    Returns:
        A new dict with every entry of ``action`` plus the adjusted
        ``"communication_channel"``.
    """
    if "communication_channel" in action:
        channel = action["communication_channel"]
    else:
        channel = np.zeros((1, max_sentence_length), dtype=np.int64)

    is_flat_array = isinstance(channel, np.ndarray) and channel.ndim == 1
    if is_flat_array:
        channel = channel.reshape(1, -1)

    shaped = dict(action)
    shaped["communication_channel"] = channel
    return shaped


def get_prompt_text(infos_agent: dict) -> str:
    """Return the agent's decoded prompt, or ``""`` when there is none.

    Reads ``infos_agent["prompt"]``; when it is missing or ``None`` the empty
    string is returned. Otherwise the value is decoded with ``BT2STR`` and the
    first element of the result is returned.
    """
    encoded_prompt = infos_agent.get("prompt")
    return "" if encoded_prompt is None else BT2STR(encoded_prompt)[0]


def get_step_prompt_text(infos_agent: dict) -> str:
    """Return the decoded current-step prompt (discussion mode).

    Reads ``infos_agent["step_prompt"]``. When it is present and not ``None``
    it is decoded with ``BT2STR`` and the first element of the result is
    returned; otherwise the full prompt from ``get_prompt_text`` is used.
    """
    encoded_step = infos_agent.get("step_prompt")
    if encoded_step is not None:
        return BT2STR(encoded_step)[0]
    # No step-specific prompt available: fall back to the full prompt.
    return get_prompt_text(infos_agent)


def get_intro_prompt_text(infos_agent: dict) -> str:
    """Return the decoded static intro (game rules) used in discussion mode.

    Reads ``infos_agent["intro_prompt"]``; when it is missing or ``None`` the
    empty string is returned. Otherwise the value is decoded with ``BT2STR``
    and the first element of the result is returned.
    """
    encoded_intro = infos_agent.get("intro_prompt")
    if encoded_intro is not None:
        return BT2STR(encoded_intro)[0]
    return ""


def is_test_mode(infos_agent: dict) -> bool:
    """Report whether the agent info describes a test-mode step.

    The ``"mode"`` entry (defaulting to ``"train"`` when absent) is converted
    with ``str`` and checked for the substring ``"test"``.
    """
    mode_label = str(infos_agent.get("mode", "train"))
    return mode_label.find("test") != -1
