#!/usr/bin/env python3
"""
Verbose debug / smoke-test for the discussion-format eval mode
with Llama-3.1-8B-Instruct (or any HF backend).

Each game step shows:
  - The step_prompt (current game only, no accumulated history)
  - How many turns are now in the conversation history
  - The full user turn sent to the model
  - The raw model response

All metrics (ZSCT, token stats, format errors) are logged to W&B once per
episode (plus a final summary); Weave records a trace for every LLM call and
for every episode.

Usage:
  python test_discussion_mode.py                    # 2 episodes, O=1 S=1
  python test_discussion_mode.py --n_episodes 4
  python test_discussion_mode.py --config configs/eval/llama32_1b_hf.yaml
  python test_discussion_mode.py --wandb_project my-project
"""
from __future__ import annotations

import argparse
import re
import random
import textwrap

import numpy as np
import yaml
import weave

from meta_rg.agents.rule_based import build_rule_based_agents
from meta_rg.backends import build_backend
from meta_rg.env_utils import (
    build_env_from_cfg,
    get_intro_prompt_text,
    get_step_prompt_text,
    is_test_mode,
)
from meta_rg.game_loop import parse_action, run_one_game
from meta_rg.stats import token_stats as _token_stats

# ── Defaults ──────────────────────────────────────────────────────────────────

# Default eval config; overridable with --config.
CONFIG = "configs/eval/llama3_8b_hf.yaml"
N_EPISODES = 2
SEED = 0
MAX_NEW_TOKENS = 384
# Weave traces and W&B metrics go to the same project unless overridden.
WEAVE_PROJECT = WANDB_PROJECT = "meta-rg-s2b"

# Horizontal rule used to frame the verbose console output.
_SEP = "─" * 72
# Token-stat placeholder (every field -1) used when no token counts were collected.
_NULL_TOK = dict(mean=-1.0, std=-1.0, min=-1, max=-1, n=-1)


# ── Episode runner ────────────────────────────────────────────────────────────
# Heavy objects stored here; weave op only sees primitive seed/index params.
_ep_state: dict = dict()


@weave.op(name="discussion_episode")
def run_verbose_discussion_episode(*, ep_index: int, ep_seed: int) -> dict:
    """
    Run one full episode in discussion format and return its metrics.

    Heavy objects (env, rule-based speaker, backend, game kwargs) are read
    from the module-level ``_ep_state`` dict, so the Weave op only records the
    primitive ``ep_index`` / ``ep_seed`` inputs.

    Maintains a per-episode conversation history. Each game step:
      1. Gets the step_prompt (current game only) from the env.
      2. Builds the next user turn (first turn prepends the intro).
      3. Calls backend.generate_chat(history).
      4. Appends the response as an assistant turn.

    All LLM calls are traced by Weave via generate_chat.

    Args:
        ep_index: Episode number; used for console output and the result dict.
        ep_seed: Seed the caller applied to the env; used for console output
            and the result dict only (this function does not seed anything).

    Returns:
        Dict with ``zsct_acc`` (accuracy on test games), ``support_acc``
        (accuracy on train games), game counts, format-error count, final
        conversation length, prompt/completion token stats (all -1 when the
        backend has no ``tokenize_fn``) and a per-game log.
    """
    env = _ep_state["env"]
    rb_speaker = _ep_state["rb_speaker"]
    backend = _ep_state["backend"]
    game_kw = _ep_state["game_kw"]
    tokenize_fn = backend.tokenize_fn
    obs, infos = env.reset()
    rb_speaker.reset()

    # The intro is taken from agent 1's info (the LM-driven agent) and is sent
    # only once, prepended to the first user turn.
    intro_text = get_intro_prompt_text(infos[1])

    conv_history: list[dict] = []
    first_turn = True
    prompt_token_lengths: list[int] = []
    completion_token_lengths: list[int] = []

    done = False
    train_correct = train_total = test_correct = test_total = 0
    n_format_errors = 0
    game_idx = 0
    games_log: list[dict] = []
    _is_test = [False]  # mutable cell so verbose_generate can read the outer value

    # Per-call state set by verbose_generate, consumed after run_one_game.
    # Cleared before every game so a game in which the LM is never called
    # cannot inherit the previous game's response / parse / error tags.
    _last_call: dict = {}

    def verbose_generate(step_text: str) -> str:
        """
        Called by run_one_game at the listener round (round_idx=1, step_id=1).
        At this point infos[1] already contains the speaker's message, so
        step_text correctly includes 'Your partner has sent you the following message'.
        """
        nonlocal first_turn

        phase = "TEST " if _is_test[0] else "train"
        print(f"\n{'━'*72}")
        print(f"  Game {game_idx}  [{phase}]  ep={ep_index}  seed={ep_seed}")
        print(f"  Step prompt ({len(step_text)} chars, first 300 shown):")
        print(textwrap.indent(step_text[:300] + ("…" if len(step_text) > 300 else ""), "  "))

        content = (intro_text + "\n\n" + step_text) if first_turn else step_text
        first_turn = False
        conv_history.append({"role": "user", "content": content})

        print(f"\n{_SEP}")
        print(f"  [Turn #{len(conv_history)} → LLM]  "
              f"conversation has {len(conv_history)} user/assistant turns")
        if len(conv_history) == 1:
            print(f"  (first turn: includes intro + step)")
        print(f"  Full user content ({len(content)} chars, first 500 shown):")
        print(textwrap.indent(content[:500] + ("…" if len(content) > 500 else ""), "  "))
        print(_SEP)

        # LLM call — traced as Weave child via @weave.op on generate_chat
        response = backend.generate_chat(conv_history)
        conv_history.append({"role": "assistant", "content": response})

        print(f"\n  [LLM → turn #{len(conv_history)}]  raw response:")
        print(f"  {response!r}")

        # Token tracking. The prompt length is approximated by joining all
        # prior turn contents with blank lines, so chat-template tokens are
        # presumably not counted. tokenize_fn is assumed to return a count.
        if tokenize_fn is not None:
            full_prompt = "\n\n".join(m["content"] for m in conv_history[:-1])
            prompt_token_lengths.append(tokenize_fn(full_prompt))
            completion_token_lengths.append(tokenize_fn(response))

        # Heuristic format check: fewer than 4 integers in the reply is
        # treated as a format error.
        error_tags: list[str] = []
        if len(re.findall(r"\d+", response)) < 4:
            error_tags.append("format-error")
        if error_tags:
            print(f"  Error types    : {', '.join(error_tags)}")

        a1 = parse_action(
            response,
            agent_idx=1,
            nbr_distractors=game_kw["nbr_distractors"],
            descriptive=game_kw["descriptive"],
            vocab_size=game_kw["vocab_size"],
            max_sentence_length=game_kw["max_sentence_length"],
        )
        print(f"  Parsed  decision={a1['decision']}  "
              f"comm={a1['communication_channel'].tolist()}")

        _last_call["response"] = response
        _last_call["a1"] = a1
        _last_call["error_tags"] = error_tags
        return response

    while not done:
        _is_test[0] = is_test_mode(infos[0])
        # Forget the previous game's LM call before starting a new game.
        _last_call.clear()

        reward, done, obs, infos = run_one_game(
            env, obs, infos, rb_speaker,
            lm_generate=verbose_generate,
            use_step_prompt=True,
            **game_kw,
        )

        response   = _last_call.get("response", "")
        a1         = _last_call.get("a1")
        error_tags = _last_call.get("error_tags", [])
        n_format_errors += len(error_tags)
        phase = "TEST" if _is_test[0] else "train"

        print(f"  Reward: {reward:+.0f}  →  {'✓ correct' if reward > 0 else '✗ wrong'}")

        if _is_test[0]:
            test_total += 1
            test_correct += int(reward > 0)
        else:
            train_total += 1
            train_correct += int(reward > 0)

        games_log.append({
            "game": game_idx,
            "phase": phase,
            "conv_turns": len(conv_history),
            "response": response,
            "decision": int(a1["decision"]) if a1 is not None else -1,
            "comm": a1["communication_channel"].tolist() if a1 is not None else [],
            "reward": float(reward),
            "error_tags": error_tags,
        })
        game_idx += 1

    # max(..., 1) keeps the accuracy at 0.0 (rather than raising) when a
    # phase had no games.
    zsct_acc = test_correct / max(test_total, 1)
    sup_acc  = train_correct / max(train_total, 1)

    # Copy the sentinel so callers cannot mutate the shared module-level dict.
    tok_prompt     = _token_stats(prompt_token_lengths) if prompt_token_lengths else dict(_NULL_TOK)
    tok_completion = _token_stats(completion_token_lengths) if completion_token_lengths else dict(_NULL_TOK)
    total_games = test_total + train_total

    print(f"\n  Episode {ep_index} summary: "
          f"ZSCT={zsct_acc*100:.0f}% ({test_correct}/{test_total})  "
          f"support={sup_acc*100:.0f}%")
    print(f"  Format errors    : {n_format_errors}/{total_games}"
          f"  ({n_format_errors/max(total_games,1):.1%})")
    if prompt_token_lengths:
        print(f"  Token stats  prompt: mean={tok_prompt['mean']:.0f} "
              f"max={tok_prompt['max']}  "
              f"completion: mean={tok_completion['mean']:.1f} "
              f"max={tok_completion['max']}  "
              f"n_calls={tok_prompt['n']}")

    return {
        "ep_index": ep_index,
        "ep_seed": ep_seed,
        "zsct_acc": zsct_acc,
        "support_acc": sup_acc,
        "n_test": test_total,
        "n_train": train_total,
        "n_format_errors": n_format_errors,
        "n_conv_turns_final": len(conv_history),
        "prompt_tokens": tok_prompt,
        "completion_tokens": tok_completion,
        "games": games_log,
    }


# ── Main ──────────────────────────────────────────────────────────────────────

def _build_task_cfg(cfg: dict) -> dict:
    """Return the ``task`` section of *cfg* with discussion-mode defaults filled in.

    Values present in the config win over the defaults below, except
    ``discussion_mode``, which is always forced to True because this script
    only exercises that mode. A missing or null ``task`` section yields the
    defaults alone.
    """
    task_cfg = dict(cfg.get("task") or {})
    defaults = {
        "nbr_latents": 3,
        "nbr_distractors": 0,
        "vocab_size": 6,
        "max_sentence_length": 3,
        "min_nbr_values_per_latent": 2,
        "max_nbr_values_per_latent": 5,
        "nbr_communication_rounds": 1,
        "descriptive": True,
        "provide_listener_feedback": True,
        "nbr_object_centric_samples": 1,
        "sampling_strategy": "component-focused-1shot",
    }
    for key, value in defaults.items():
        task_cfg.setdefault(key, value)
    task_cfg["discussion_mode"] = True
    return task_cfg


def _episode_log_payload(ep: int, ep_seed: int, result: dict) -> dict:
    """Build the per-episode W&B metrics dict from an episode result."""
    total_ep_games = result["n_test"] + result["n_train"]
    return {
        "episode": ep,
        "ep_seed": ep_seed,
        "zsct_acc": result["zsct_acc"],
        "support_acc": result["support_acc"],
        "n_test": result["n_test"],
        "n_train": result["n_train"],
        "n_conv_turns_final": result["n_conv_turns_final"],
        "tokens/prompt_mean":      result["prompt_tokens"]["mean"],
        "tokens/prompt_max":       result["prompt_tokens"]["max"],
        "tokens/completion_mean":  result["completion_tokens"]["mean"],
        "tokens/completion_max":   result["completion_tokens"]["max"],
        "tokens/n_lm_calls":       result["prompt_tokens"]["n"],
        "errors/n_format_errors":  result["n_format_errors"],
        "errors/format_error_rate":
            result["n_format_errors"] / max(total_ep_games, 1),
        # Sentinel -1 for metrics that don't apply to discussion mode
        "errors/n_truncated":           -1,
        "errors/n_adapter_errors":      -1,
        "errors/n_re_prompt_truncated": -1,
        "errors/truncation_rate":       -1.0,
        "errors/adapter_error_rate":    -1.0,
        "errors/re_prompt_failure_rate":-1.0,
    }


def main():
    """Run ``--n_episodes`` verbose discussion-mode episodes and log the results.

    The config file is read before any W&B run is created, so a bad path
    fails without leaving an orphan run. Once the run exists, it is always
    finished (with exit code 1 if anything raised), and the backend is always
    closed if it was built.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config",         default=CONFIG)
    p.add_argument("--n_episodes",     type=int, default=N_EPISODES)
    p.add_argument("--seed",           type=int, default=SEED)
    p.add_argument("--max_new_tokens", type=int, default=MAX_NEW_TOKENS)
    p.add_argument("--weave_project",  default=WEAVE_PROJECT)
    p.add_argument("--wandb_project",  default=WANDB_PROJECT)
    args = p.parse_args()
    if args.n_episodes < 1:
        # The summary averages over episodes; zero would divide by zero.
        p.error("--n_episodes must be >= 1")

    with open(args.config) as f:
        # safe_load returns None for an empty document.
        cfg = yaml.safe_load(f) or {}

    backend_cfg = dict(cfg.get("backend") or {})
    backend_cfg["max_new_tokens"] = args.max_new_tokens

    task_cfg = _build_task_cfg(cfg)
    game_kw = {k: task_cfg[k] for k in (
        "nbr_communication_rounds", "nbr_distractors", "descriptive",
        "vocab_size", "max_sentence_length", "provide_listener_feedback",
    )}

    weave.init(args.weave_project)

    import wandb
    wandb.init(
        project=args.wandb_project,
        config={
            "config": args.config,
            "mode": "discussion",
            "n_episodes": args.n_episodes,
            "seed": args.seed,
            "max_new_tokens": args.max_new_tokens,
        },
        tags=["debug", "discussion"],
    )

    backend = None
    try:
        print(f"Config        : {args.config}")
        print("Mode          : discussion")
        print(f"Episodes      : {args.n_episodes}   Seed: {args.seed}")
        print(f"max_new_tokens: {args.max_new_tokens}")
        print(f"Weave project : {args.weave_project}")

        print("\nLoading model…")
        backend = build_backend(backend_cfg)
        backend.set_seed(args.seed)

        # Wrap generate_chat with Weave so each LLM call is a child trace.
        _orig_chat = backend.generate_chat

        @weave.op(name="lm_call_chat")
        def _traced_chat(messages: list) -> str:
            return _orig_chat(messages)

        backend.generate_chat = _traced_chat

        random.seed(args.seed)
        np.random.seed(args.seed)

        env = build_env_from_cfg(task_cfg, seed=args.seed)
        rb_speaker, _ = build_rule_based_agents(
            vocab_size=task_cfg["vocab_size"],
            max_sentence_length=task_cfg["max_sentence_length"],
            nbr_communication_rounds=task_cfg["nbr_communication_rounds"],
            nbr_latents=task_cfg["nbr_latents"],
        )

        # The episode op reads these from module state (see _ep_state); they
        # do not change between episodes, so set them once.
        _ep_state.update({"env": env, "rb_speaker": rb_speaker,
                          "backend": backend, "game_kw": game_kw})

        ep_results = []
        for ep in range(args.n_episodes):
            ep_seed = args.seed + ep * 1000
            env.seed(ep_seed)
            # Same LM seed every episode; only the env seed varies.
            backend.set_seed(args.seed)

            print(f"\n{'='*72}")
            print(f"  EPISODE {ep}  (env_seed={ep_seed})")
            print(f"{'='*72}")

            result = run_verbose_discussion_episode(ep_index=ep, ep_seed=ep_seed)
            ep_results.append(result)
            wandb.log(_episode_log_payload(ep, ep_seed, result))

        # ── Summary ──────────────────────────────────────────────────────────
        zsct_values = [r["zsct_acc"] * 100 for r in ep_results]
        mean_zsct = sum(zsct_values) / len(zsct_values)  # n_episodes >= 1
        total_games = sum(r["n_test"] + r["n_train"] for r in ep_results)
        total_fmt   = sum(r["n_format_errors"] for r in ep_results)

        print(f"\n{'='*72}")
        print(f"  SUMMARY  ({args.n_episodes} episodes, mode=discussion)")
        print(f"{'='*72}")
        print(f"  Per-episode ZSCT : {[f'{v:.0f}%' for v in zsct_values]}")
        print(f"  Mean ZSCT        : {mean_zsct:.1f}%")
        print(f"  Format errors    : {total_fmt}/{total_games}"
              f"  ({total_fmt/max(total_games,1):.1%})")

        # Episodes without token counts carry the -1 sentinel; skip them.
        all_ptl = [r["prompt_tokens"]["mean"] for r in ep_results
                   if r["prompt_tokens"]["mean"] >= 0]
        all_ctl = [r["completion_tokens"]["mean"] for r in ep_results
                   if r["completion_tokens"]["mean"] >= 0]
        if all_ptl:
            sp = _token_stats(all_ptl)
            sc = _token_stats(all_ctl)
            print(f"  Prompt tokens    : mean={sp['mean']:.0f}  max={sp['max']}")
            print(f"  Completion tokens: mean={sc['mean']:.0f}  max={sc['max']}")

        wandb.log({
            "summary/mean_zsct":               mean_zsct,
            "summary/n_format_errors":         total_fmt,
            "summary/format_error_rate":       total_fmt / max(total_games, 1),
            "summary/tokens_prompt_mean":      _token_stats(all_ptl)["mean"] if all_ptl else -1.0,
            "summary/tokens_completion_mean":  _token_stats(all_ctl)["mean"] if all_ctl else -1.0,
            # Sentinel -1 for CoT-only metrics (keeps dashboard schema consistent)
            "summary/n_truncated":             -1,
            "summary/n_adapter_errors":        -1,
            "summary/n_re_prompt_truncated":   -1,
            "summary/truncation_rate":         -1.0,
            "summary/adapter_error_rate":      -1.0,
            "summary/re_prompt_failure_rate":  -1.0,
        })
    except BaseException:
        # Mark the W&B run as failed instead of leaving it dangling.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()
    finally:
        if backend is not None:
            backend.close()


if __name__ == "__main__":
    # Script entry point: run only when executed directly, not when imported.
    main()
