#!/usr/bin/env python3
"""
Zero-shot LLM evaluation on S2B — reproduces Section 4.4 / Table 4.

The speaker is always the rule-based Posdis-Speaker.
The listener is an LLM queried zero-shot (no fine-tuning).

Usage examples:

  # OpenAI GPT-4o-mini
  python run_eval.py --config configs/eval/gpt4o_mini.yaml --o 1 --shots 1

  # Local Mixtral via vLLM (offline)
  python run_eval.py --config configs/eval/mixtral_vllm.yaml --o 4 --shots 2

  # Full Table 4 sweep (all 4 conditions, 5 seeds, 64 episodes each)
  python run_eval.py --config configs/eval/mixtral_vllm.yaml --table4

  # llama.cpp (CPU)
  python run_eval.py --config configs/eval/smollm_llamacpp.yaml --o 1 --shots 1 \\
      --n_seeds 1 --n_episodes 4
"""
from __future__ import annotations

import argparse
import json
import os
import random

import numpy as np
import yaml
from tqdm import tqdm

from meta_rg.agents.rule_based import build_rule_based_agents
from meta_rg.backends import build_backend
from meta_rg.env_utils import build_env_from_cfg
from meta_rg.game_loop import run_episode
from meta_rg.metrics import aggregate_seeds, print_table, save_results
from meta_rg.stats import token_stats

# Sentinel dict for unavailable token stats (impossible values → filterable)
_NULL_TOK = {"mean": -1.0, "std": -1.0, "min": -1, "max": -1, "n": -1}


# ── CLI ───────────────────────────────────────────────────────────────────────

def parse_args():
    """Parse the command line for the zero-shot S2B evaluation.

    Returns:
        argparse.Namespace holding the backend config path, optional task
        overrides (``None`` means "keep the YAML / paper default"), the
        evaluation-protocol sizes, prompt/elicitation switches and the output
        and tracking destinations.
    """
    parser = argparse.ArgumentParser(
        description="Zero-shot LLM evaluation on S2B (Table 4)")
    arg = parser.add_argument

    arg("--config", required=True, help="Path to backend YAML config")

    # Task overrides: a value left at None falls back to the YAML / defaults.
    arg("--o", type=int, default=None,
        help="nbr_object_centric_samples (O=1 or 4)")
    arg("--shots", type=int, default=None, help="Number of shots S (1 or 2)")
    for flag in ("--n_dim", "--v_min", "--v_max"):
        arg(flag, type=int, default=None)
    arg("--vocab_size", type=int, default=None,
        help="Override vocab_size from YAML (required when using --vocab_partition)")

    # Evaluation protocol sizes.
    for flag, default in (("--n_seeds", 5), ("--n_episodes", 64), ("--base_seed", 0)):
        arg(flag, type=int, default=default)

    # Sweep every (O, S) condition of Table 4 in a single run.
    arg("--table4", action="store_true",
        help="Run all 4 (O,S) conditions from Table 4")

    # Prompting / elicitation switches.
    arg("--prompt_strategy", type=str, default="none",
        choices=["none", "zero_shot_cot", "few_shot_cot",
                 "discussion_cot", "few_shot_discussion_cot"],
        help="Prompt strategy: none | zero_shot_cot | few_shot_cot | "
             "discussion_cot | few_shot_discussion_cot")
    arg("--n_few_shot_games", type=int, default=2,
        help="Warmup games per episode for few_shot_discussion_cot")
    arg("--discussion_mode", action="store_true",
        help="Discussion format: env emits per-step prompts; eval loop "
             "maintains multi-turn conversation history (incompatible with CoT)")
    arg("--max_new_tokens", type=int, default=None,
        help="Override backend max_new_tokens/max_tokens (needed for CoT)")
    arg("--vocab_partition", action="store_true",
        help="Speaker uses vocab-partitioned encoding: each latent gets a "
             "disjoint token range. Requires vocab_size > nbr_latents * max_val.")
    arg("--domain", type=str, default="SCS",
        choices=["SCS", "2D", "3D", "categorical", "pseudoword"],
        help="Stimulus domain. 'categorical' replaces Gaussian floats with "
             "named category items. 'pseudoword' uses dynamically generated "
             "(CV)+ pseudowords. Both require --o 1.")
    arg("--verbose_prompt", action="store_true",
        help="Include machine-readable tags and rhetorical filler in S2B prompts. "
             "Default: slim prompts (strips [NBR_QUESTIONS], [MAX_NBR_OPTIONS], "
             "and verbose preamble to reduce token count).")
    arg("--inductive_verbaliser", action="store_true",
        help="Use inductive verbalization: reason via inverse symbol map "
             "(value->symbol prediction) rather than the direct decode. "
             "Only meaningful with --prompt_strategy few_shot_discussion_cot.")
    arg("--elicitation_strategies",
        nargs="*",
        default=[],
        choices=["decoding_recipe", "cot_scaffold", "inline_hint"],
        metavar="STRATEGY",
        help="Listener/speaker prompt elicitation strategies (space-separated): "
             "decoding_recipe  cot_scaffold  inline_hint")

    # Output destinations and optional experiment tracking.
    arg("--output_dir", type=str, default="outputs/eval")
    arg("--wandb_project", type=str, default="meta-rg-s2b")
    arg("--weave_project", type=str, default=None,
        help="W&B Weave project for LLM-call and episode tracing (optional)")

    return parser.parse_args()


# ── Config loading ─────────────────────────────────────────────────────────────

def load_config(path: str) -> dict:
    """Load the evaluation YAML config at *path*.

    The file is parsed with ``yaml.safe_load``, so arbitrary Python objects
    cannot be constructed from it. An empty file yields ``{}`` instead of
    ``None``, which lets callers use ``.get`` on the result unconditionally.

    Raises:
        ValueError: if the top-level YAML node is not a mapping.
    """
    with open(path, encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    if cfg is None:
        # Empty document: treat as "no settings" rather than crashing callers.
        return {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config {path!r} must be a YAML mapping at top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg


def make_task_cfg(base_cfg: dict, o: int, shots: int, domain: str = 'SCS') -> dict:
    """Build the task config for one (O, S) evaluation condition.

    Starts from a shallow copy of ``base_cfg["task"]`` and never mutates
    *base_cfg*. A missing or null ``task`` section is treated as empty. Any key
    the YAML leaves unset is filled with the paper's Table 4 defaults. The
    condition-specific keys are then always overwritten.

    Args:
        base_cfg: full parsed YAML config.
        o: ``nbr_object_centric_samples`` (O).
        shots: number of shots S, encoded into ``sampling_strategy``.
        domain: stimulus domain. It always replaces any ``task.domain`` from
            the YAML.

    Returns:
        A new task-config dict.
    """
    # Paper defaults for Table 4 (applied only where the YAML is silent).
    table4_defaults = {
        "nbr_latents": 3,
        "nbr_distractors": 0,
        "vocab_size": 6,
        "max_sentence_length": 3,
        "min_nbr_values_per_latent": 2,
        "max_nbr_values_per_latent": 5,
        "nbr_communication_rounds": 1,
        "descriptive": True,
        "provide_listener_feedback": True,
    }
    # `or {}` also covers a present-but-empty `task:` key (parsed as None).
    cfg = dict(base_cfg.get("task") or {})
    for key, value in table4_defaults.items():
        cfg.setdefault(key, value)
    cfg["nbr_object_centric_samples"] = o
    cfg["sampling_strategy"] = f"component-focused-{shots}shot"
    cfg["domain"] = domain
    return cfg


# ── Single condition evaluation ────────────────────────────────────────────────

def _normalise_episode_metrics(result: dict, cot_gen) -> None:
    """Ensure an episode result carries the full error/token metric set (in place).

    With a CoT strategy object (``cot_gen`` not None), this copies its
    per-episode error counters and converts its recorded token lengths with
    ``token_stats``. Without one, all error counts are set to ``-1`` (N/A).
    Token stats still missing afterwards are filled with a *copy* of
    ``_NULL_TOK``, so results never share one mutable sentinel dict.
    """
    if cot_gen is not None:
        result["n_truncated"] = cot_gen.n_truncated
        result["n_adapter_errors"] = cot_gen.n_adapter_errors
        result["n_re_prompt_truncated"] = cot_gen.n_re_prompt_truncated
        result["n_format_errors"] = cot_gen.n_format_errors
        ptl = cot_gen.prompt_token_lengths
        if ptl:
            result["prompt_tokens"] = token_stats(ptl)
            result["completion_tokens"] = token_stats(cot_gen.completion_token_lengths)
            return
    else:
        for key in ("n_truncated", "n_adapter_errors",
                    "n_re_prompt_truncated", "n_format_errors"):
            result[key] = -1
    result.setdefault("prompt_tokens", dict(_NULL_TOK))
    result.setdefault("completion_tokens", dict(_NULL_TOK))


def _collect_api_token_stats(backend) -> dict:
    """Read cumulative API token counters from *backend* (e.g. HFAPIBackend).

    Returns ``{}`` when the backend has no ``total_prompt_tokens`` attribute
    (or it is None). Missing or None completion/cached counters count as 0.
    """
    prompt_tok = getattr(backend, "total_prompt_tokens", None)
    if prompt_tok is None:
        return {}
    cached_tok = getattr(backend, "total_cached_tokens", None) or 0
    compl_tok = getattr(backend, "total_completion_tokens", None) or 0
    return {
        "api_prompt_tokens": prompt_tok,
        "api_completion_tokens": compl_tok,
        "api_cached_tokens": cached_tok,
        "api_cache_hit_rate": cached_tok / prompt_tok if prompt_tok > 0 else 0.0,
    }


def eval_condition(
    task_cfg: dict,
    backend_cfg: dict,
    n_seeds: int,
    n_episodes: int,
    base_seed: int,
    pbar_desc: str = "",
    prompt_strategy_name: str = "none",
    discussion_mode: bool = False,
    vocab_partition: bool = False,
    use_wandb: bool = False,
    verbose_prompt: bool = False,
    elicitation_strategies: list | None = None,
    n_few_shot_games: int = 0,
    slim_verbaliser: bool = True,
    inductive_verbaliser: bool = False,
) -> dict:
    """Evaluate one (O, S) condition and return the ``aggregate_seeds`` result.

    Builds the LLM backend once. For each seed ``base_seed + s`` it seeds
    ``random``/``numpy``, builds a fresh env and rule-based agents, and plays
    ``n_episodes`` episodes through ``run_episode``. The backend is closed even
    if an episode raises.

    Args:
        task_cfg: task config (as built by ``make_task_cfg``); must contain the
            game keys read into ``game_kw`` below.
        backend_cfg: backend section of the YAML, passed to ``build_backend``.
        n_seeds, n_episodes, base_seed: evaluation protocol sizes.
        pbar_desc: label for the outer progress bar (default "seeds").
        prompt_strategy_name: "none" or a strategy name accepted by
            ``build_prompt_strategy``. The discussion CoT variants force
            discussion mode.
        discussion_mode: use multi-turn ``generate_chat`` episodes.
        vocab_partition, slim_verbaliser, inductive_verbaliser: forwarded to
            ``build_rule_based_agents``.
        use_wandb: forwarded to the per-episode step logger.
        verbose_prompt: stored as ``verbose_prompts`` in the env config.
        elicitation_strategies: stored in the env config (None -> []).
        n_few_shot_games: warm-up games per episode, used only by
            ``few_shot_discussion_cot``.

    Returns:
        The dict from ``aggregate_seeds``. When the backend exposes
        ``total_prompt_tokens`` it also carries ``api_prompt_tokens``,
        ``api_completion_tokens``, ``api_cached_tokens`` and
        ``api_cache_hit_rate``.
    """
    discussion_strategies = ("discussion_cot", "few_shot_discussion_cot")
    game_kw = dict(
        nbr_communication_rounds=task_cfg["nbr_communication_rounds"],
        nbr_distractors=task_cfg["nbr_distractors"],
        descriptive=task_cfg["descriptive"],
        vocab_size=task_cfg["vocab_size"],
        max_sentence_length=task_cfg["max_sentence_length"],
        provide_listener_feedback=task_cfg["provide_listener_feedback"],
    )
    # Warm-up games are only played by the few-shot discussion strategy.
    few_shot_games = (n_few_shot_games
                      if prompt_strategy_name == "few_shot_discussion_cot" else 0)

    backend = build_backend(backend_cfg)
    try:
        lm_generate = backend.generate
        tokenize_fn = backend.tokenize_fn
        cot_gen = None             # strategy object exposing per-episode error/token counters
        discussion_backend = None  # object whose generate_chat drives discussion episodes

        if prompt_strategy_name != "none":
            from meta_rg.prompt_strategy import build_prompt_strategy
            strategy = build_prompt_strategy(prompt_strategy_name, backend)
            cot_gen = strategy     # saved before any Weave wrapping of lm_generate
            # DSPy / BackendLM manage LM calls internally; raw token metrics unavailable.
            tokenize_fn = None
            if prompt_strategy_name in discussion_strategies:
                discussion_backend = strategy  # generate_chat runs DSPy CoT
                discussion_mode = True         # discussion CoT implies discussion mode
            else:
                lm_generate = strategy

        if discussion_mode and discussion_backend is None:
            discussion_backend = backend  # plain discussion mode: use generate_chat directly

        from meta_rg.step_logger import make_step_logger
        global_step: list[int] = [0]
        current_seed: list[int] = [base_seed]
        current_ep: list[int] = [0]

        try:
            import weave as _weave
        except ImportError:
            _weave = None

        if _weave is not None:
            # Trace every raw LM call (covers the 'none', CoT and discussion paths).
            _orig_lm = lm_generate

            @_weave.op(name="lm_call")
            def lm_generate(prompt_text: str) -> str:  # noqa: F811
                return _orig_lm(prompt_text)

            # Log the full messages list so Weave renders the Chat tab.
            # Per-call size is O(N) for call N (~2 KB/game x N), well under the 3.6 MB limit.
            # NOT a wrapper around generate_chat itself: avoids serialising __self__.
            @_weave.op(name="lm_call_chat")
            def _log_chat_call(messages: list, response: str) -> str:
                return response

            if discussion_backend is not None:
                _orig_chat = discussion_backend.generate_chat

                def _traced_chat(msgs: list) -> str:
                    resp = _orig_chat(msgs)
                    _log_chat_call(messages=msgs, response=resp)
                    return resp

                discussion_backend.generate_chat = _traced_chat

        # Heavy per-seed objects (env, agents) and the per-episode step callback
        # live in ep_state, so the episode function's only arguments are
        # seed/episode_idx, which keeps env/agents/backend out of Weave traces.
        ep_state: dict = {}

        def _run_episode(seed: int, episode_idx: int) -> dict:
            return run_episode(
                ep_state["env"],
                ep_state["rb_speaker"],
                lm_generate,
                rb_listener=ep_state.get("rb_listener"),
                game_kwargs=game_kw,
                tokenize_fn=tokenize_fn,
                discussion_backend=discussion_backend,
                step_callback=ep_state.get("step_callback"),
                n_few_shot_games=few_shot_games,
                few_shot_domain=task_cfg.get("domain", "SCS"),
                few_shot_nbr_object_centric=task_cfg.get("nbr_object_centric_samples", 1),
            )

        if _weave is not None:
            _run_episode = _weave.op(name="episode")(_run_episode)

        seed_results = []
        outer = tqdm(range(n_seeds), desc=pbar_desc or "seeds", unit="seed")
        for s in outer:
            seed = base_seed + s
            random.seed(seed)
            np.random.seed(seed)

            env_cfg = dict(task_cfg)
            if discussion_mode:
                env_cfg["discussion_mode"] = True
            env_cfg["verbose_prompts"] = verbose_prompt
            env_cfg["allow_cot_response"] = prompt_strategy_name in discussion_strategies
            env_cfg["elicitation_strategies"] = elicitation_strategies or []
            env = build_env_from_cfg(env_cfg, seed=seed)
            rb_speaker, rb_listener = build_rule_based_agents(
                vocab_size=task_cfg["vocab_size"],
                max_sentence_length=task_cfg["max_sentence_length"],
                nbr_communication_rounds=task_cfg["nbr_communication_rounds"],
                nbr_latents=task_cfg["nbr_latents"],
                max_nbr_values_per_latent=task_cfg.get("max_nbr_values_per_latent", 5),
                vocab_partition=vocab_partition,
                use_hypothesis_listener=(prompt_strategy_name == "few_shot_discussion_cot"),
                slim_verbaliser=slim_verbaliser,
                inductive_verbaliser=inductive_verbaliser,
            )
            ep_state["env"] = env
            ep_state["rb_speaker"] = rb_speaker
            ep_state["rb_listener"] = rb_listener

            ep_results = []
            ep_bar = tqdm(range(n_episodes), desc=f"  seed {s}", leave=False, unit="ep")
            for ep in ep_bar:
                env.seed(seed + ep * 1000)
                backend.set_seed(seed)
                if cot_gen is not None:
                    cot_gen.reset_stats()
                current_seed[0] = seed
                current_ep[0] = ep
                reset_snap, step_cb = make_step_logger(
                    backend=backend,
                    cot_gen=cot_gen,
                    use_wandb=use_wandb,
                    global_step=global_step,
                    get_current_seed=lambda: current_seed[0],
                    get_current_ep=lambda: current_ep[0],
                )
                reset_snap()
                ep_state["step_callback"] = step_cb
                result = _run_episode(seed=seed, episode_idx=ep)

                # Every episode result always carries all metrics.
                _normalise_episode_metrics(result, cot_gen)

                ep_results.append(result)
                ep_bar.set_postfix({"zsct": f"{result['zsct_acc']*100:.1f}%"})

            seed_results.append(ep_results)
            seed_zsct = np.mean([r["zsct_acc"] * 100 for r in ep_results])
            outer.set_postfix({"seed_zsct": f"{seed_zsct:.1f}%"})

        # Prefix-cache counters must be read before the backend is closed.
        api_stats = _collect_api_token_stats(backend)
    finally:
        backend.close()

    agg = aggregate_seeds(seed_results)
    agg.update(api_stats)
    return agg


# ── Main ───────────────────────────────────────────────────────────────────────

def _disable_weave_dspy_autopatch() -> None:
    """Remove Weave's DSPy auto-instrumentation (best effort).

    WeaveCallback and SymbolPatchers inject 'self' and 'lm' into every DSPy
    trace. Our own @weave.op decorators (lm_call, lm_call_chat, episode,
    env_step, env_reset) cover everything we need.
    """
    try:
        import dspy as _dspy
        from weave.integrations.dspy.dspy_callback import WeaveCallback as _WeaveCB
        from weave.integrations.dspy import dspy_sdk as _dspy_sdk

        _dspy.settings.configure(
            callbacks=[c for c in _dspy.settings.get("callbacks", [])
                       if not isinstance(c, _WeaveCB)]
        )
        if _dspy_sdk._dspy_patcher is not None:
            _dspy_sdk._dspy_patcher.undo_patch()
    except ImportError:
        pass  # DSPy or Weave's DSPy integration is absent: nothing to undo.
    except Exception as exc:  # relies on private Weave internals that may change
        print(f"Warning: could not disable Weave DSPy auto-instrumentation ({exc!r})")


def _init_weave(project: str) -> None:
    """Initialise Weave tracing for *project*; warn and skip if weave is missing."""
    try:
        import weave
    except ImportError:
        print("Warning: weave not installed; skipping Weave tracing. "
              "Install with: pip install 'meta-rg-s2b[weave]'")
        return
    weave.init(project)
    _disable_weave_dspy_autopatch()


def _print_condition_summary(result: dict) -> None:
    """Print accuracy, token, API-cache, overflow and error summaries for one condition."""
    print(f"\n  ZSCT accuracy: {result['mean']:.1f} ± {result['std']:.1f}%")
    print(f"  Per-seed: {result['per_seed']}")
    if "prompt_tokens" in result:
        p, c = result["prompt_tokens"], result["completion_tokens"]
        n_ep = len(result.get("per_episode_tokens", []))
        print(f"  Prompt tokens     : mean={p['mean']:.1f}  std={p['std']:.1f}"
              f"  min={p['min']}  max={p['max']}  (over {n_ep} episodes)")
        print(f"  Completion tokens : mean={c['mean']:.1f}  std={c['std']:.1f}"
              f"  min={c['min']}  max={c['max']}  (over {n_ep} episodes)")
        print(f"  LM calls total    : {result.get('n_lm_calls', -1)}")
    # `or 0` tolerates counters a backend reports as None.
    if (result.get("api_prompt_tokens") or 0) > 0:
        pt = result["api_prompt_tokens"]
        ct = result.get("api_completion_tokens") or 0
        ca = result.get("api_cached_tokens") or 0
        hr = result.get("api_cache_hit_rate") or 0.0
        print(f"  API prompt tokens : {pt:,}  completion: {ct:,}  total: {pt+ct:,}")
        if ca > 0:
            print(f"  Prefix cache hits : {ca:,} / {pt:,} prompt tokens  ({hr:.1%})")
        else:
            print("  Prefix cache hits : 0 (provider may not support caching)")
    if result.get("n_context_overflow", 0) > 0:
        print(f"  Context overflow  : {result['n_context_overflow']}"
              f"  (rate={result['context_overflow_rate']:.1%})")
    if "n_truncated" in result:
        print(f"  Truncated (1st)   : {result['n_truncated']}"
              f"  (rate={result['truncation_rate']:.1%})")
        print(f"  Adapter errors    : {result['n_adapter_errors']}"
              f"  (rate={result['adapter_error_rate']:.1%})")
        print(f"  Re-prompt failures: {result['n_re_prompt_truncated']}"
              f"  (of re-prompts: {result['re_prompt_failure_rate']:.1%})")
        print(f"  Format errors     : {result['n_format_errors']}"
              f"  (rate={result['format_error_rate']:.1%})")


def _wandb_log_condition(result: dict, o: int, shots: int) -> None:
    """Log one condition's accuracy, token, API-cache and error metrics to W&B."""
    import wandb

    tag = f"O{o}_S{shots}"
    log_dict = {
        f"zsct/{tag}/mean": result["mean"],
        f"zsct/{tag}/std": result["std"],
    }
    # Token stats may be absent from the aggregate; fall back to the sentinel.
    for prefix, stats in (
        (f"tokens/prompt/{tag}", result.get("prompt_tokens") or _NULL_TOK),
        (f"tokens/completion/{tag}", result.get("completion_tokens") or _NULL_TOK),
    ):
        log_dict.update({f"{prefix}/{k}": v for k, v in stats.items()})
    log_dict[f"tokens/n_lm_calls/{tag}"] = result.get("n_lm_calls", -1)
    if (result.get("api_prompt_tokens") or 0) > 0:
        apfx = f"api_tokens/{tag}"
        log_dict.update({
            f"{apfx}/prompt":         result["api_prompt_tokens"],
            f"{apfx}/completion":     result["api_completion_tokens"],
            f"{apfx}/cached":         result["api_cached_tokens"],
            f"{apfx}/cache_hit_rate": result["api_cache_hit_rate"],
        })
    pfx = f"errors/{tag}"
    log_dict.update({
        f"{pfx}/n_truncated":            result.get("n_truncated", -1),
        f"{pfx}/n_adapter_errors":       result.get("n_adapter_errors", -1),
        f"{pfx}/n_re_prompt_truncated":  result.get("n_re_prompt_truncated", -1),
        f"{pfx}/n_format_errors":        result.get("n_format_errors", -1),
        f"{pfx}/truncation_rate":        result.get("truncation_rate", -1.0),
        f"{pfx}/adapter_error_rate":     result.get("adapter_error_rate", -1.0),
        f"{pfx}/re_prompt_failure_rate": result.get("re_prompt_failure_rate", -1.0),
        f"{pfx}/format_error_rate":      result.get("format_error_rate", -1.0),
        f"{pfx}/n_context_overflow":     result.get("n_context_overflow", 0),
        f"{pfx}/context_overflow_rate":  result.get("context_overflow_rate", 0.0),
    })
    wandb.log(log_dict)


def main():
    """Run one (O, S) condition or the full Table 4 sweep from the command line.

    Loads the YAML config and applies the CLI overrides. It evaluates every
    condition with ``eval_condition`` and prints per-condition summaries.
    Metrics are optionally logged to W&B and traced with Weave. All aggregates
    are written to ``<output_dir>/<model><strategy_tag>_results.json``.

    Raises:
        ValueError: if ``--discussion_mode`` is combined with ``zero_shot_cot``
            or ``few_shot_cot``.
    """
    args = parse_args()
    cfg = load_config(args.config) or {}

    # `or {}` also covers a present-but-null `backend:` section.
    backend_cfg = cfg.get("backend") or {}
    model_name = backend_cfg.get("model", cfg.get("name", "model"))

    # Apply max_new_tokens override (needed when using CoT strategies)
    if args.max_new_tokens is not None:
        for key in ("max_new_tokens", "max_tokens"):
            if key in backend_cfg:
                backend_cfg[key] = args.max_new_tokens
        if "max_new_tokens" not in backend_cfg and "max_tokens" not in backend_cfg:
            backend_cfg["max_new_tokens"] = args.max_new_tokens

    # Both discussion CoT strategies switch discussion mode on inside
    # eval_condition, so an explicit --discussion_mode is redundant for them,
    # not a conflict.
    if args.discussion_mode and args.prompt_strategy not in (
        "none", "discussion_cot", "few_shot_discussion_cot"
    ):
        raise ValueError(
            "--discussion_mode is only compatible with --prompt_strategy none, "
            "discussion_cot or few_shot_discussion_cot"
        )

    os.makedirs(args.output_dir, exist_ok=True)

    # --wandb_project has a non-None default, so the old `is not None` test was
    # always true; truthiness lets `--wandb_project ""` turn W&B logging off.
    use_wandb = bool(args.wandb_project)
    if use_wandb:
        import wandb
        wandb.init(project=args.wandb_project, config={"config": cfg, "args": vars(args)})

    if args.weave_project:
        _init_weave(args.weave_project)

    task_yaml = cfg.get("task") or {}
    if args.table4:
        conditions = [(1, 1), (1, 2), (4, 1), (4, 2)]
    else:
        o = (args.o if args.o is not None
             else task_yaml.get("nbr_object_centric_samples", 1))
        conditions = [(o, args.shots if args.shots is not None else 1)]

    # CLI task overrides; None means "not given on the command line".
    overrides = (
        ("nbr_latents", args.n_dim),
        ("min_nbr_values_per_latent", args.v_min),
        ("max_nbr_values_per_latent", args.v_max),
        ("vocab_size", args.vocab_size),
    )

    all_results = {}
    table_results = {model_name: {}}

    for (o, shots) in conditions:
        task_cfg = make_task_cfg(cfg, o, shots, domain=args.domain)
        for key, value in overrides:
            if value is not None:
                task_cfg[key] = value

        desc = f"O={o} S={shots}"
        print(f"\n{'='*60}")
        print(f"  Condition: {desc}")
        print(f"  Model: {model_name}")
        print(f"  Seeds: {args.n_seeds}  Episodes: {args.n_episodes}")
        print(f"{'='*60}")

        result = eval_condition(
            task_cfg=task_cfg,
            backend_cfg=backend_cfg,
            n_seeds=args.n_seeds,
            n_episodes=args.n_episodes,
            base_seed=args.base_seed,
            pbar_desc=desc,
            prompt_strategy_name=args.prompt_strategy,
            discussion_mode=args.discussion_mode,
            vocab_partition=args.vocab_partition,
            use_wandb=use_wandb,
            verbose_prompt=args.verbose_prompt,
            elicitation_strategies=args.elicitation_strategies,
            n_few_shot_games=args.n_few_shot_games,
            slim_verbaliser=not args.verbose_prompt,
            inductive_verbaliser=args.inductive_verbaliser,
        )
        all_results[(o, shots)] = result
        table_results[model_name][(o, shots)] = result

        _print_condition_summary(result)
        if use_wandb:
            _wandb_log_condition(result, o, shots)

    # Summary table
    if args.table4:
        print("\n\nTable 4 results:")
        print_table(table_results)

    # Save JSON (tuple keys are not JSON-serialisable, hence str()).
    strategy_tag = f"_{args.prompt_strategy}" if args.prompt_strategy != "none" else ""
    if args.discussion_mode:
        strategy_tag += "_discussion"
    out_path = os.path.join(
        args.output_dir,
        f"{model_name.replace('/', '_')}{strategy_tag}_results.json",
    )
    serialisable = {str(k): v for k, v in all_results.items()}
    save_results(out_path, serialisable)
    print(f"\nResults saved to {out_path}")

    if use_wandb:
        wandb.finish()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
