#!/usr/bin/env python3
"""
GRPO + LoRA fine-tuning on S2B — RLVR extension (future work).

Adapted from Regym/benchmark/GRPOApproach/SymbolicBehaviourBenchmark/train_grpo_s2b.py

Usage:
  python run_grpo.py --config configs/grpo/smollm.yaml
  python run_grpo.py --config configs/grpo/smollm.yaml --role listener --counterpart rule-based

Load and evaluate a checkpoint:
  python run_grpo.py --config configs/grpo/smollm.yaml --eval_only \\
      --checkpoint outputs/grpo/step_500
"""
from __future__ import annotations

import argparse
import os

import yaml

from meta_rg.training.grpo_trainer import GRPOConfig, GRPOTrainer


def parse_args():
    """Parse command-line flags for GRPO training or checkpoint evaluation.

    Every per-field override flag defaults to ``None`` so that the config merge
    step can tell "not given on the command line" apart from an explicit value
    and keep the YAML / dataclass value in that case.

    Returns:
        argparse.Namespace with the parsed flags. ``--o`` is stored under
        ``nbr_object_centric_samples``.

    Exits via ``ArgumentParser.error`` (status 2) when ``--eval_only`` is given
    without ``--checkpoint``.
    """
    p = argparse.ArgumentParser(description="GRPO+LoRA fine-tuning on S2B")
    p.add_argument("--config", required=True, help="Path to GRPO YAML config")

    # Override individual fields (None = keep the value from the YAML config)
    p.add_argument("--role", choices=["speaker", "listener", "both"], default=None)
    p.add_argument("--counterpart", choices=["rule-based", "lm"], default=None)
    p.add_argument("--seed", type=int, default=None)
    p.add_argument("--num_steps", type=int, default=None)
    p.add_argument("--output_dir", type=str, default=None)
    # default=None (not a project name) so a wandb_project set in the YAML is not
    # silently overridden when the flag is omitted.
    p.add_argument("--wandb_project", type=str, default=None,
                   help="W&B project name; overrides the YAML config when given")
    p.add_argument("--o", type=int, default=None, dest="nbr_object_centric_samples",
                   help="Override for nbr_object_centric_samples")
    p.add_argument("--shots", type=int, default=None,
                   help="Sets sampling_strategy to 'component-focused-<N>shot'")
    p.add_argument("--domain", type=str, default=None,
                   choices=["SCS", "2D", "3D", "categorical", "pseudoword"],
                   help="Stimulus domain. 'categorical' and 'pseudoword' require --o 1. "
                        "'pseudoword' uses dynamically generated (CV)+ pseudowords.")

    # Eval-only mode
    p.add_argument("--eval_only", action="store_true")
    p.add_argument("--checkpoint", type=str, default=None,
                   help="Checkpoint dir for --eval_only")
    p.add_argument("--n_eval_games", type=int, default=200,
                   help="Number of games to play in --eval_only mode")

    args = p.parse_args()
    # Fail fast on an inconsistent flag combination, before any config is loaded.
    if args.eval_only and args.checkpoint is None:
        p.error("--eval_only requires --checkpoint")
    return args


def load_cfg(path: str) -> dict:
    """Read a YAML config file and return its top-level mapping.

    An empty file (or one containing only ``null``) yields ``{}``.

    Args:
        path: Path to the YAML file. It is read as UTF-8 regardless of the
            platform's default locale encoding.

    Raises:
        TypeError: if the document's top level is not a mapping (e.g. a list);
            the config merge step expects a dict of sections.
    """
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise TypeError(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def build_grpo_config(raw: dict, args) -> GRPOConfig:
    """Merge a parsed YAML config and CLI overrides into a ``GRPOConfig``.

    Precedence, lowest to highest:
      1. ``GRPOConfig`` dataclass defaults;
      2. keys from the YAML ``task``, ``model`` and ``training`` sections (a later
         section wins on a key clash), then the top-level ``seed``,
         ``output_dir``, ``wandb_project`` and ``weave_project`` keys;
      3. CLI flags whose value is not ``None``.

    A ``shots`` key (YAML) or ``--shots`` flag (CLI) is translated into
    ``sampling_strategy = "component-focused-<shots>shot"``. YAML keys that are
    not ``GRPOConfig`` fields are dropped silently.

    ``wandb_project`` resolves as: explicit CLI value, else the YAML value, else
    ``"meta-rg-s2b-grpo"``.

    Args:
        raw: Top-level mapping loaded from the YAML file. Sections written with
            no body (YAML ``null``) are treated as empty.
        args: Parsed CLI namespace.

    Returns:
        The merged ``GRPOConfig``.
    """
    flat = {}

    # Flatten the task / model / training sections. `or {}` makes an empty section
    # (`task:` with nothing under it parses to None) a no-op instead of a crash.
    for section in ("task", "model", "training"):
        flat.update(raw.get(section) or {})
    # Top-level keys win over same-named section keys.
    for k in ("seed", "output_dir", "wandb_project", "weave_project"):
        if k in raw:
            flat[k] = raw[k]

    # Shots → sampling_strategy
    shots = flat.pop("shots", None)
    if shots is not None:
        flat["sampling_strategy"] = f"component-focused-{shots}shot"

    # Build dataclass (only recognised fields)
    valid = {f.name for f in GRPOConfig.__dataclass_fields__.values()}
    cfg = GRPOConfig(**{k: v for k, v in flat.items() if k in valid})

    # CLI overrides: only flags that were actually given (non-None) apply.
    for field_name in ("role", "counterpart", "seed", "num_steps", "output_dir",
                       "nbr_object_centric_samples", "domain"):
        val = getattr(args, field_name, None)
        if val is not None:
            setattr(cfg, field_name, val)

    # wandb_project: explicit CLI value > YAML value > script-level default.
    # If args.wandb_project is never None (e.g. the parser supplies a default),
    # the CLI value simply always wins.
    cli_wandb = getattr(args, "wandb_project", None)
    if cli_wandb is not None:
        cfg.wandb_project = cli_wandb
    elif "wandb_project" not in flat:
        cfg.wandb_project = "meta-rg-s2b-grpo"

    if args.shots is not None:
        cfg.sampling_strategy = f"component-focused-{args.shots}shot"

    return cfg


def main():
    """Script entry point: merge CLI flags with the YAML config, then train or evaluate.

    Without ``--eval_only`` a ``GRPOTrainer`` is built from the merged config and
    trained. With ``--eval_only`` the checkpoint given by ``--checkpoint`` is
    evaluated over ``--n_eval_games`` games instead.

    Raises:
        ValueError: if ``--eval_only`` is set but ``--checkpoint`` is missing.
    """
    args = parse_args()
    cfg = build_grpo_config(load_cfg(args.config), args)

    if not args.eval_only:
        GRPOTrainer(cfg).train()
        return

    # Evaluation-only path: a checkpoint directory is mandatory.
    if args.checkpoint is None:
        raise ValueError("--eval_only requires --checkpoint")
    _eval_checkpoint(cfg, args.checkpoint, args.n_eval_games)


def _eval_checkpoint(cfg: GRPOConfig, checkpoint_dir: str, n_games: int) -> None:
    """Load a saved LoRA checkpoint and print its train and ZSCT accuracy.

    Args:
        cfg: Merged run configuration. Supplies the base model id, the seed, the
            counterpart type and the task parameters used to rebuild the env.
        checkpoint_dir: Directory from which both the tokenizer and the PEFT
            adapter are loaded via ``from_pretrained``.
        n_games: Number of evaluation games, forwarded to ``_eval``.

    Heavy dependencies (torch, transformers, peft, project modules) are imported
    lazily so that training-only runs do not pay for them here.
    """
    import copy
    import random

    import numpy as np
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from meta_rg.agents.rule_based import build_rule_based_agents
    from meta_rg.env_utils import build_env_from_cfg
    from meta_rg.training.grpo_trainer import _eval

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[eval] checkpoint={checkpoint_dir}  device={device}")

    # Seed every RNG source touched during evaluation.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Base weights come from cfg.model_id (always bfloat16, whatever the device);
    # the checkpoint directory only provides the LoRA adapter on top.
    base = AutoModelForCausalLM.from_pretrained(cfg.model_id, torch_dtype=torch.bfloat16)
    model = PeftModel.from_pretrained(base, checkpoint_dir).to(device).eval()

    # For an LM counterpart, use a frozen copy of the evaluated model.
    counterpart_model = None
    if cfg.counterpart == "lm":
        counterpart_model = copy.deepcopy(model)
        counterpart_model.requires_grad_(False)
        counterpart_model.eval()

    task_cfg = {k: getattr(cfg, k) for k in (
        "nbr_latents", "nbr_distractors", "vocab_size", "max_sentence_length",
        "min_nbr_values_per_latent", "max_nbr_values_per_latent",
        "nbr_communication_rounds", "descriptive", "nbr_object_centric_samples",
        "provide_listener_feedback", "sampling_strategy",
    )}
    # Forward the stimulus domain (YAML `domain` or --domain) to the env builder;
    # otherwise the eval env is built without it. Only added when set, so
    # configs without a domain produce the same task_cfg as before.
    # NOTE(review): assumes build_env_from_cfg reads a "domain" key — confirm.
    domain = getattr(cfg, "domain", None)
    if domain is not None:
        task_cfg["domain"] = domain
    env = build_env_from_cfg(task_cfg, seed=cfg.seed)
    rb_speaker, rb_listener = build_rule_based_agents(
        vocab_size=cfg.vocab_size,
        max_sentence_length=cfg.max_sentence_length,
        nbr_communication_rounds=cfg.nbr_communication_rounds,
        nbr_latents=cfg.nbr_latents,
    )

    # Offset the eval seed from the run seed.
    tr, te = _eval(env, model, tokenizer, cfg, device,
                   rb_speaker, rb_listener, counterpart_model,
                   n_games=n_games, base_seed=cfg.seed + 77777)

    print(f"\nResults over {n_games} games:")
    print(f"  Train accuracy : {tr:.1f}%")
    print(f"  ZSCT accuracy  : {te:.1f}%  (compositional generalisation)")


# Run the CLI only when executed as a script; importing this module has no side effects.
if __name__ == "__main__":
    main()
