#!/usr/bin/env python3
"""
Debug / smoke-test for the DSPy zero-shot-CoT (or few-shot-CoT) strategy
with Llama-3.1-8B-Instruct.

Each episode is a separate Weave trace (via the @weave.op on
run_verbose_episode), and each LLM call is logged inside it through
prompt_strategy._log_llm_call. Per-episode and summary metrics are also
logged to Weights & Biases. On top of that, this script prints verbose
console output so you can see exactly what is sent to and received from
the model.

Usage:
  python test_dspy_cot.py                           # zero_shot_cot, 2 episodes
  python test_dspy_cot.py --n_episodes 4
  python test_dspy_cot.py --strategy few_shot_cot
  python test_dspy_cot.py --weave_project my-project
"""
from __future__ import annotations

import argparse
import re
import random
import textwrap

import numpy as np
import yaml
import weave

import dspy

from meta_rg.agents.rule_based import build_rule_based_agents
from meta_rg.backends import build_backend
from meta_rg.env_utils import build_env_from_cfg, get_prompt_text, is_test_mode
from meta_rg.game_loop import parse_action, run_one_game
from meta_rg.prompt_strategy import BackendLM, ListenerSignature, _DEMOS, _make_response

# ── Defaults (each one is overridable from the command line in main()) ───────

CONFIG = "configs/eval/llama3_8b_hf.yaml"  # YAML with "backend" and "task" sections
N_EPISODES = 2
SEED = 0
STRATEGY = "zero_shot_cot"  # alternative: "few_shot_cot"
MAX_NEW_TOKENS = 384  # generation budget for the DSPy chain-of-thought call
WEAVE_PROJECT = "meta-rg-s2b"
WANDB_PROJECT = "meta-rg-s2b"


# ── Verbose BackendLM: overrides forward() to echo prompt and response ───────

class VerboseBackendLM(BackendLM):
    """BackendLM variant that echoes each DSPy prompt and raw model reply.

    The text actually handed to ``backend.generate()`` and the untouched
    completion are printed between horizontal rules. Token tracking,
    ``_log_llm_call`` logging and response construction mirror what the
    override does below.

    NOTE(review): overriding ``forward`` replaces the parent method, so any
    ``@weave.op`` applied to ``BackendLM.forward`` does not wrap this one;
    per-call tracing here relies on ``_log_llm_call`` — confirm intended.
    """

    def forward(self, prompt=None, messages=None, **kwargs):
        # Chat-style call: flatten all non-empty message bodies into one prompt.
        if not messages:
            sent = prompt or ""
        else:
            bodies = [msg["content"] for msg in messages if msg.get("content")]
            sent = "\n\n".join(bodies)

        rule = "─" * 72
        for line in (
            "\n" + rule,
            "  [DSPy → LLM]  full text sent to backend.generate():",
            rule,
            textwrap.indent(sent, "  "),
            rule,
        ):
            print(line)

        reply = self._backend.generate(sent)
        self._track_tokens(sent, reply)

        for line in (
            "\n  [LLM → DSPy]  raw response:",
            rule,
            textwrap.indent(reply, "  "),
            rule,
        ):
            print(line)

        # Imported lazily, exactly as before, at call time.
        from meta_rg.prompt_strategy import _log_llm_call
        _log_llm_call(prompt_text=sent, content=reply)
        return _make_response(reply, sent)


# ── Episode runner — each call is a Weave trace ───────────────────────────────
# Module-level hand-off for the non-primitive objects an episode needs (env,
# speaker, DSPy module, backend, BackendLM wrapper, game kwargs). main()
# populates it before calling run_verbose_episode, so the Weave op's recorded
# inputs stay limited to the primitive ep_index / ep_seed values.
_ep_state: dict = dict()


def _generate_capped(backend, prompt: str, max_tokens: int) -> str:
    """Call ``backend.generate(prompt)`` with its token budget temporarily capped.

    Whichever of ``max_new_tokens`` / ``max_tokens`` the backend exposes (checked
    in that order) is set to *max_tokens* and restored to its exact previous
    value in a ``finally``, so an exception inside ``generate`` cannot leave the
    backend stuck at the small budget. Backends exposing neither attribute are
    called unchanged.
    """
    attr = next(
        (name for name in ("max_new_tokens", "max_tokens") if hasattr(backend, name)),
        None,
    )
    if attr is None:
        return backend.generate(prompt)
    saved = getattr(backend, attr)
    setattr(backend, attr, max_tokens)
    try:
        return backend.generate(prompt)
    finally:
        setattr(backend, attr, saved)


def _reprompt_for_answer(backend, raw_s2b: str) -> str:
    """Ask the raw backend for a bare 4-integer answer; return "" if it gives none.

    Bypasses DSPy entirely. At most the first four integers found in the reply
    are kept, joined by single spaces.
    """
    reprompt = (raw_s2b +
                "\n\nNow state ONLY the final answer as 4 space-separated integers "
                "(decision token1 token2 token3). No explanation. "
                "Example: '0 2 3 1'.\nAnswer:")
    # 16 new tokens is enough for "D T1 T2 T3\n".
    raw_reprompt = _generate_capped(backend, reprompt, 16).strip()
    print(f"  Re-prompt raw  : {raw_reprompt!r}")
    return " ".join(re.findall(r"\d+", raw_reprompt)[:4])


@weave.op(name="episode")
def run_verbose_episode(*, ep_index: int, ep_seed: int) -> dict:
    """Run one episode with verbose diagnostics and return its metrics.

    Heavy dependencies (env, speaker, DSPy module, backend, BackendLM wrapper,
    game kwargs) are read from the module-level ``_ep_state`` dict, which the
    caller must populate first; only ``ep_index`` / ``ep_seed`` are op inputs.

    Each game follows the same fallback chain:
      1. The DSPy module answers the listener prompt.
      2. If that yields no answer (the adapter raised, or ``pred.answer`` was
         empty, presumably because the CoT was truncated), the raw backend is
         re-prompted for a bare 4-integer answer.
      3. If the re-prompt also yields nothing, ``"0 0 0 0"`` is used.

    Returns:
        dict with ``zsct_acc`` (test games) and ``support_acc`` (train games),
        game counts, error counters, prompt/completion token statistics, and a
        per-game log.
    """
    from meta_rg.stats import token_stats as _token_stats
    # Sentinel used when no LM calls were tracked for a token series.
    _NULL_TOK = {"mean": -1.0, "std": -1.0, "min": -1, "max": -1, "n": -1}

    env = _ep_state["env"]
    rb_speaker = _ep_state["rb_speaker"]
    dspy_module = _ep_state["dspy_module"]
    backend = _ep_state["backend"]
    backend_lm = _ep_state["backend_lm"]
    game_kw = _ep_state["game_kw"]

    backend_lm.reset_token_stats()
    obs, infos = env.reset()
    rb_speaker.reset()
    done = False
    train_correct = train_total = test_correct = test_total = 0
    n_truncated = n_adapter_errors = n_re_prompt_truncated = n_format_errors = 0
    game_idx = 0
    games_log = []

    while not done:
        is_test = is_test_mode(infos[0])
        phase = "TEST " if is_test else "train"

        raw_s2b = get_prompt_text(infos[1])

        print(f"\n{'━'*72}")
        print(f"  Game {game_idx}  [{phase}]  ep={ep_index}  seed={ep_seed}")
        print(f"  Raw S2B prompt ({len(raw_s2b)} chars, first 400 shown):")
        print(textwrap.indent(raw_s2b[:400] + ("…" if len(raw_s2b) > 400 else ""), "  "))

        # DSPy call — traced by Weave as llm_call child of this episode trace
        adapter_error = False
        try:
            pred = dspy_module(game_description=raw_s2b)
            print(f"\n  pred.answer    : {pred.answer!r}")
            if hasattr(pred, "reasoning"):
                print(f"  pred.reasoning : {str(pred.reasoning)[:200]!r}")
            answer = (pred.answer or "").strip()
        except Exception as exc:
            adapter_error = True
            n_adapter_errors += 1
            answer = ""
            print(f"\n  pred.answer    : <AdapterParseError: {type(exc).__name__}>")

        # ── Fallback when first trial failed (truncation or adapter error) ────
        # Both failure modes are re-prompted, which matches how n_re_prompted
        # (truncated + adapter errors) is used as the re-prompt denominator.
        error_tags: list[str] = []
        if not answer:
            if adapter_error:
                error_tags.append("adapter-error")
                print("\n  *** ADAPTER ERROR — no answer parsed, re-prompting ***")
            else:
                n_truncated += 1
                error_tags.append("truncated")
                print("\n  *** TRUNCATED — pred.answer empty, re-prompting ***")
            answer = _reprompt_for_answer(backend, raw_s2b)
            if not answer:
                n_re_prompt_truncated += 1
                error_tags.append("re-prompt-failed")
                answer = "0 0 0 0"
                print(f"  *** RE-PROMPT ALSO FAILED — using default '{answer}' ***")

        if len(re.findall(r"\d+", answer)) < 4:
            n_format_errors += 1
            error_tags.append("format-error")

        if error_tags:
            print(f"  Error types    : {', '.join(error_tags)}")

        a1 = parse_action(
            answer,
            agent_idx=1,
            nbr_distractors=game_kw["nbr_distractors"],
            descriptive=game_kw["descriptive"],
            vocab_size=game_kw["vocab_size"],
            max_sentence_length=game_kw["max_sentence_length"],
        )
        print(f"  Parsed  decision={a1['decision']}  "
              f"comm={a1['communication_channel'].tolist()}")

        # The listener's answer is already decided; hand it to the game loop
        # as a constant "generator" (bound via default arg, not late-bound).
        reward, done, obs, infos = run_one_game(
            env, obs, infos, rb_speaker,
            lm_generate=lambda _prompt, _ans=answer: _ans,
            **game_kw,
        )
        print(f"  Reward: {reward:+.0f}  →  {'✓ correct' if reward > 0 else '✗ wrong'}")

        if is_test:
            test_total += 1
            test_correct += int(reward > 0)
        else:
            train_total += 1
            train_correct += int(reward > 0)

        games_log.append({
            "game": game_idx, "phase": phase.strip(),
            "answer": answer,
            "decision": int(a1["decision"]),
            "comm": a1["communication_channel"].tolist(),
            "reward": float(reward),
            "error_tags": error_tags,
        })
        game_idx += 1

    zsct_acc = test_correct / max(test_total, 1)
    sup_acc = train_correct / max(train_total, 1)

    print(f"\n  Episode {ep_index} summary: "
          f"ZSCT={zsct_acc*100:.0f}% ({test_correct}/{test_total})  "
          f"support={sup_acc*100:.0f}%")
    total_games = test_total + train_total
    n_re_prompted = n_truncated + n_adapter_errors

    ptl = backend_lm.prompt_token_lengths
    ctl = backend_lm.completion_token_lengths
    tok_prompt = _token_stats(ptl) if ptl else _NULL_TOK
    tok_completion = _token_stats(ctl) if ctl else _NULL_TOK

    if n_truncated or n_adapter_errors or n_re_prompt_truncated or n_format_errors:
        print(f"  Error breakdown: "
              f"truncated={n_truncated}/{total_games}  "
              f"adapter-errors={n_adapter_errors}/{total_games}  "
              f"re-prompt-failed={n_re_prompt_truncated}/{max(n_re_prompted,1)}  "
              f"format-errors={n_format_errors}/{total_games}")
    if ptl:
        print(f"  Token stats  prompt: mean={tok_prompt['mean']:.0f} "
              f"max={tok_prompt['max']}  "
              f"completion: mean={tok_completion['mean']:.1f} "
              f"max={tok_completion['max']}  "
              f"n_calls={tok_prompt['n']}")

    return {
        "ep_index": ep_index,
        "ep_seed": ep_seed,
        "zsct_acc": zsct_acc,
        "support_acc": sup_acc,
        "n_test": test_total,
        "n_train": train_total,
        "n_truncated": n_truncated,
        "n_adapter_errors": n_adapter_errors,
        "n_re_prompt_truncated": n_re_prompt_truncated,
        "n_format_errors": n_format_errors,
        "prompt_tokens": tok_prompt,
        "completion_tokens": tok_completion,
        "games": games_log,
    }


# ── Main ──────────────────────────────────────────────────────────────────────

# Task-config defaults, applied only for keys missing from cfg["task"].
_TASK_DEFAULTS = {
    "nbr_latents": 3,
    "nbr_distractors": 0,
    "vocab_size": 6,
    "max_sentence_length": 3,
    "min_nbr_values_per_latent": 2,
    "max_nbr_values_per_latent": 5,
    "nbr_communication_rounds": 1,
    "descriptive": True,
    "provide_listener_feedback": True,
    "nbr_object_centric_samples": 1,
    "sampling_strategy": "component-focused-1shot",
}

# Task-config keys forwarded to the per-game helpers (parse_action / run_one_game).
_GAME_KW_KEYS = (
    "nbr_communication_rounds", "nbr_distractors", "descriptive",
    "vocab_size", "max_sentence_length", "provide_listener_feedback",
)


def _rate(num, den) -> float:
    """Return ``num / den``, treating a zero denominator as 1."""
    return num / max(den, 1)


def _aggregate_token_stats(per_episode: list) -> dict:
    """Combine per-episode token-stat dicts into run-level ``mean``/``max``/``n``.

    Entries whose ``n`` is not positive (the ``-1`` null sentinel used when an
    episode tracked no LM calls) are skipped. ``mean`` is weighted by each
    episode's call count and ``max`` is the true maximum over episodes.
    Returns ``{"mean": -1.0, "max": -1, "n": -1}`` when nothing remains.
    """
    valid = [s for s in per_episode if s.get("n", -1) > 0]
    if not valid:
        return {"mean": -1.0, "max": -1, "n": -1}
    total_n = sum(s["n"] for s in valid)
    return {
        "mean": sum(s["mean"] * s["n"] for s in valid) / total_n,
        "max": max(s["max"] for s in valid),
        "n": total_n,
    }


def _episode_metrics(ep: int, ep_seed: int, result: dict) -> dict:
    """Flatten one ``run_verbose_episode`` result into the per-episode W&B payload."""
    total = result["n_test"] + result["n_train"]
    re_prompted = result["n_truncated"] + result["n_adapter_errors"]
    return {
        "episode": ep,
        "ep_seed": ep_seed,
        "zsct_acc": result["zsct_acc"],
        "support_acc": result["support_acc"],
        "n_test": result["n_test"],
        "n_train": result["n_train"],
        "tokens/prompt_mean":      result["prompt_tokens"]["mean"],
        "tokens/prompt_max":       result["prompt_tokens"]["max"],
        "tokens/completion_mean":  result["completion_tokens"]["mean"],
        "tokens/completion_max":   result["completion_tokens"]["max"],
        "tokens/n_lm_calls":       result["prompt_tokens"]["n"],
        "errors/n_truncated":           result["n_truncated"],
        "errors/n_adapter_errors":      result["n_adapter_errors"],
        "errors/n_re_prompt_truncated": result["n_re_prompt_truncated"],
        "errors/n_format_errors":       result["n_format_errors"],
        "errors/truncation_rate":       _rate(result["n_truncated"], total),
        "errors/adapter_error_rate":    _rate(result["n_adapter_errors"], total),
        "errors/re_prompt_failure_rate":
            _rate(result["n_re_prompt_truncated"], re_prompted),
        "errors/format_error_rate":     _rate(result["n_format_errors"], total),
    }


def main():
    """Parse CLI args, load model and env, run the episodes, and report metrics.

    Per-episode results and a final summary are printed and logged to W&B;
    each episode is also a Weave trace via ``run_verbose_episode``. The backend
    is closed even if an episode raises.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--config",         default=CONFIG)
    p.add_argument("--strategy",       default=STRATEGY,
                   choices=["zero_shot_cot", "few_shot_cot"])
    p.add_argument("--n_episodes",     type=int, default=N_EPISODES)
    p.add_argument("--seed",           type=int, default=SEED)
    p.add_argument("--max_new_tokens", type=int, default=MAX_NEW_TOKENS)
    p.add_argument("--weave_project",  default=WEAVE_PROJECT)
    p.add_argument("--wandb_project",  default=WANDB_PROJECT)
    args = p.parse_args()
    # The summary averages over episodes; reject 0 before loading anything
    # instead of dividing by zero after the model is up.
    if args.n_episodes < 1:
        p.error("--n_episodes must be >= 1")

    weave.init(args.weave_project)

    import wandb
    wandb.init(
        project=args.wandb_project,
        config={
            "config": args.config,
            "strategy": args.strategy,
            "n_episodes": args.n_episodes,
            "seed": args.seed,
            "max_new_tokens": args.max_new_tokens,
        },
        tags=["debug", "cot", args.strategy],
    )

    # ── Config ───────────────────────────────────────────────────────────────
    with open(args.config, encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}  # an empty YAML file loads as None

    backend_cfg = dict(cfg.get("backend") or {})
    backend_cfg["max_new_tokens"] = args.max_new_tokens

    task_cfg = {**_TASK_DEFAULTS, **(cfg.get("task") or {})}
    game_kw = {k: task_cfg[k] for k in _GAME_KW_KEYS}

    print(f"Config        : {args.config}")
    print(f"Strategy      : {args.strategy}")
    print(f"Episodes      : {args.n_episodes}   Seed: {args.seed}")
    print(f"max_new_tokens: {args.max_new_tokens}")
    print(f"Weave project : {args.weave_project}")

    # ── Backend + DSPy module ────────────────────────────────────────────────
    print("\nLoading model…")
    backend = build_backend(backend_cfg)
    try:
        backend.set_seed(args.seed)

        verbose_lm = VerboseBackendLM(backend)
        dspy.configure(lm=verbose_lm)

        # Both strategies use the same CoT module; few-shot also attaches demos.
        module = dspy.ChainOfThought(ListenerSignature)
        if args.strategy == "few_shot_cot":
            module.demos = list(_DEMOS)

        # ── Env + speaker ────────────────────────────────────────────────────
        random.seed(args.seed)
        np.random.seed(args.seed)

        env = build_env_from_cfg(task_cfg, seed=args.seed)
        rb_speaker, _ = build_rule_based_agents(
            vocab_size=task_cfg["vocab_size"],
            max_sentence_length=task_cfg["max_sentence_length"],
            nbr_communication_rounds=task_cfg["nbr_communication_rounds"],
            nbr_latents=task_cfg["nbr_latents"],
        )

        # Heavy objects are passed via _ep_state so the Weave op only records
        # primitive inputs; they are the same for every episode.
        _ep_state.update({"env": env, "rb_speaker": rb_speaker,
                          "dspy_module": module, "backend": backend,
                          "backend_lm": verbose_lm, "game_kw": game_kw})

        # ── Run episodes ─────────────────────────────────────────────────────
        ep_results = []
        for ep in range(args.n_episodes):
            ep_seed = args.seed + ep * 1000
            env.seed(ep_seed)
            # Only the env seed varies per episode; LM sampling is reseeded
            # identically each time.
            backend.set_seed(args.seed)

            print(f"\n{'='*72}")
            print(f"  EPISODE {ep}  (env_seed={ep_seed})")
            print(f"{'='*72}")

            result = run_verbose_episode(ep_index=ep, ep_seed=ep_seed)
            ep_results.append(result)
            wandb.log(_episode_metrics(ep, ep_seed, result))

        # ── Summary ──────────────────────────────────────────────────────────
        zsct_values    = [r["zsct_acc"] * 100 for r in ep_results]
        mean_zsct      = sum(zsct_values) / len(zsct_values)
        total_games    = sum(r["n_test"] + r["n_train"] for r in ep_results)
        total_trunc    = sum(r["n_truncated"] for r in ep_results)
        total_adapter  = sum(r["n_adapter_errors"] for r in ep_results)
        total_reprompt = sum(r["n_re_prompt_truncated"] for r in ep_results)
        total_fmt      = sum(r["n_format_errors"] for r in ep_results)
        total_re_prompted = total_trunc + total_adapter

        print(f"\n{'='*72}")
        print(f"  SUMMARY  ({args.n_episodes} episodes, strategy={args.strategy})")
        print(f"{'='*72}")
        print(f"  Per-episode ZSCT : {[f'{v:.0f}%' for v in zsct_values]}")
        print(f"  Mean ZSCT        : {mean_zsct:.1f}%")
        print(f"  Truncated (1st)  : {total_trunc}/{total_games}"
              f"  ({_rate(total_trunc, total_games):.1%})")
        print(f"  Adapter errors   : {total_adapter}/{total_games}"
              f"  ({_rate(total_adapter, total_games):.1%})")
        print(f"  Re-prompt failed : {total_reprompt}/{max(total_re_prompted,1)}"
              f"  ({_rate(total_reprompt, total_re_prompted):.1%} of re-prompts)")
        print(f"  Format errors    : {total_fmt}/{total_games}"
              f"  ({_rate(total_fmt, total_games):.1%})")

        # Aggregate over all tracked LM calls (call-weighted mean, true max)
        # rather than over per-episode means.
        summary_prompt = _aggregate_token_stats(
            [r["prompt_tokens"] for r in ep_results])
        summary_completion = _aggregate_token_stats(
            [r["completion_tokens"] for r in ep_results])

        wandb.log({
            "summary/mean_zsct": mean_zsct,
            "summary/tokens_prompt_mean":     summary_prompt["mean"],
            "summary/tokens_prompt_max":      summary_prompt["max"],
            "summary/tokens_completion_mean": summary_completion["mean"],
            "summary/tokens_completion_max":  summary_completion["max"],
            "summary/n_truncated": total_trunc,
            "summary/n_adapter_errors": total_adapter,
            "summary/n_re_prompt_truncated": total_reprompt,
            "summary/n_format_errors": total_fmt,
            "summary/truncation_rate": _rate(total_trunc, total_games),
            "summary/adapter_error_rate": _rate(total_adapter, total_games),
            "summary/re_prompt_failure_rate": _rate(total_reprompt, total_re_prompted),
            "summary/format_error_rate": _rate(total_fmt, total_games),
        })
        wandb.finish()
    finally:
        # Release the model/backend even if an episode or logging call raised.
        backend.close()


if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
