"""
GRPO + LoRA training on the Symbolic Behaviour Benchmark.

Adapted from Regym/benchmark/GRPOApproach/SymbolicBehaviourBenchmark/train_grpo_s2b.py
with the following changes:
  - Self-contained class interface (GRPOTrainer) instead of top-level functions
  - Backends decoupled: trainer holds the HF model + tokenizer directly
  - Shared game-loop primitives from meta_rg.game_loop
  - Config passed as a dataclass instead of argparse.Namespace

The algorithm:
  For each training step:
    1. Rollout B×G games (B groups of G trajectories each)
    2. Normalize rewards within each group → advantages
    3. Policy gradient loss: -advantage * Σ log π(token)
    4. Optional KL penalty against frozen reference model
    5. AdamW update with gradient-norm clipping
"""
from __future__ import annotations

import copy
import os
import random
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from meta_rg.s2b_import import ensure_s2b_importable
from meta_rg.env_utils import build_env_from_cfg
from meta_rg.agents.rule_based import build_rule_based_agents
from meta_rg.game_loop import parse_action, run_one_game, is_test_mode
from meta_rg.env_utils import (
    get_prompt_text,
    no_op_action,
    ensure_action_shape,
    int_comm_to_ohe,
)

ensure_s2b_importable()


# ── Data structures ────────────────────────────────────────────────────────────

@dataclass
class Trajectory:
    """Record of the trained model's turns in one S2B game.

    Entry ``i`` of ``prompts`` is the prompt text the model answered on its
    i-th turn, and entry ``i`` of ``gen_ids`` holds the token ids it generated
    for that turn. ``reward`` is the reward read after the game's last
    environment step.
    """

    prompts: List[str]  # prompt text per model turn, in play order
    gen_ids: List[torch.Tensor]  # generated token ids per turn, aligned with `prompts`
    reward: float  # scalar reward the game finished with


@dataclass
class GRPOConfig:
    """Hyper-parameters for GRPO + LoRA training on the Symbolic Behaviour Benchmark.

    Fields are grouped into S2B task settings (forwarded to the environment
    builder), model / LoRA settings, training settings and run bookkeeping.
    ``__post_init__`` rejects values the trainer cannot act on: an unknown
    ``role`` (which would leave the model with no turns to train on) and
    non-positive batch sizes or intervals (which would make the training loop
    empty or raise ``ZeroDivisionError`` on ``step % interval``).

    Raises:
        ValueError: if ``role`` is not one of ``"speaker"``, ``"listener"``,
            ``"both"``, or if ``B``, ``G``, ``log_interval``, ``eval_interval``
            or ``save_interval`` is smaller than 1.
    """

    # S2B task
    nbr_latents: int = 3
    nbr_distractors: int = 0
    vocab_size: int = 6
    max_sentence_length: int = 3
    min_nbr_values_per_latent: int = 2
    max_nbr_values_per_latent: int = 5
    nbr_communication_rounds: int = 1
    descriptive: bool = True
    nbr_object_centric_samples: int = 1
    provide_listener_feedback: bool = True
    sampling_strategy: Optional[str] = "component-focused-1shot"
    domain: str = 'SCS'

    # Model
    model_id: str = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])

    # Training
    role: str = "listener"  # which seat(s) the trained model plays: speaker / listener / both
    counterpart: str = "rule-based"  # "lm" uses a frozen LM copy for the other seat
    B: int = 4  # number of groups per step
    G: int = 4  # trajectories per group
    num_steps: int = 1000
    lr: float = 1e-4
    weight_decay: float = 0.0
    max_grad_norm: float = 1.0
    beta_kl: float = 0.0  # > 0 enables the KL penalty against a frozen reference model
    temperature: float = 0.7
    max_new_tokens: int = 256

    # Misc
    seed: int = 0
    output_dir: str = "outputs/grpo"
    log_interval: int = 10
    eval_interval: int = 100
    save_interval: int = 500
    wandb_project: Optional[str] = "meta-rg-s2b-grpo"
    weave_project: Optional[str] = None

    def __post_init__(self) -> None:
        # A typo here would silently give the model no turns at all, so no
        # trajectory would carry any training signal.
        if self.role not in ("speaker", "listener", "both"):
            raise ValueError(
                f"role must be 'speaker', 'listener' or 'both', got {self.role!r}"
            )
        # The trainer uses these as loop bounds and as `step % interval`
        # divisors; non-positive values are never meaningful.
        for name in ("B", "G", "log_interval", "eval_interval", "save_interval"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value!r}")


# ── Model helpers ──────────────────────────────────────────────────────────────

def _build_model_and_tokenizer(cfg: GRPOConfig, device: torch.device):
    """Load the tokenizer and a LoRA-wrapped causal LM for ``cfg.model_id``.

    The base model is loaded in bfloat16, wrapped with a LoRA adapter built
    from the ``cfg.lora_*`` fields, and moved to ``device``.

    Returns:
        ``(model, tokenizer, ref_model)`` where ``ref_model`` is a frozen
        (``requires_grad=False``, eval-mode) deep copy of the freshly wrapped
        model when ``cfg.beta_kl > 0``, and ``None`` otherwise.
    """
    # Imported lazily so the module can be imported without these heavy deps.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(cfg.model_id)
    # Fall back to EOS for padding when the checkpoint defines no pad token.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base_lm = AutoModelForCausalLM.from_pretrained(cfg.model_id, torch_dtype=torch.bfloat16)
    adapter_spec = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=cfg.lora_target_modules,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    policy = get_peft_model(base_lm, adapter_spec)
    policy.to(device)

    # Snapshot taken before any optimisation step, used as the KL anchor.
    frozen_ref = copy.deepcopy(policy) if cfg.beta_kl > 0.0 else None
    if frozen_ref is not None:
        frozen_ref.requires_grad_(False)
        frozen_ref.eval()
        frozen_ref.to(device)

    return policy, tok, frozen_ref


# ── Generation ────────────────────────────────────────────────────────────────

def _format_prompt(tokenizer, prompt_text: str) -> str:
    """Render *prompt_text* as one user turn via the tokenizer's chat template.

    The template is rendered to text (not token ids) with the assistant
    generation prompt appended, ready to be tokenized for generation or scoring.
    """
    conversation = [dict(role="user", content=prompt_text)]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered


def _generate(model, tokenizer, prompt_text: str, cfg: GRPOConfig, device: torch.device):
    """Sample one completion for *prompt_text* without tracking gradients.

    Sampling uses ``cfg.temperature`` and at most ``cfg.max_new_tokens`` new
    tokens.

    Returns:
        ``(text, ids)``: the decoded completion with special tokens stripped,
        and the completion's token ids (prompt tokens removed) as a CPU tensor.
    """
    encoded = tokenizer(
        _format_prompt(tokenizer, prompt_text), return_tensors="pt", truncation=True
    ).to(device)
    n_prompt_tokens = encoded.input_ids.shape[1]

    sampling_kwargs = dict(
        max_new_tokens=cfg.max_new_tokens,
        do_sample=True,
        temperature=cfg.temperature,
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        sequences = model.generate(**encoded, **sampling_kwargs)

    # `generate` returns prompt + completion; keep only the new tokens.
    completion_ids = sequences[0, n_prompt_tokens:].cpu()
    return tokenizer.decode(completion_ids, skip_special_tokens=True), completion_ids


# ── Rollout ────────────────────────────────────────────────────────────────────

def _rollout_one_game(
    env, model, tokenizer, cfg: GRPOConfig, device: torch.device,
    rb_speaker, rb_listener, counterpart_model, seed: int
) -> Trajectory:
    """Play one seeded S2B game and record the trained model's turns.

    Python, NumPy and torch RNGs plus the environment are all seeded with
    *seed* before the reset. Player 0 is the speaker and player 1 the
    listener. The trained *model* occupies the seat(s) named by ``cfg.role``:
    it generates on its active round and emits a no-op on the others. The
    other seat is played by the frozen *counterpart_model* on its active round
    (when ``cfg.counterpart == "lm"``), and by the rule-based agent otherwise.
    A step whose ``round_idx`` is -1 is a feedback step in which both players
    send no-ops.

    Returns:
        A ``Trajectory`` holding every (prompt, generated ids) pair produced by
        *model* and the reward of player 0 after the final step.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    env.seed(seed)
    obs, infos = env.reset()
    rb_speaker.reset()
    rb_listener.reset()

    trains_speaker = cfg.role in ("speaker", "both")
    trains_listener = cfg.role in ("listener", "both")
    lm_counterpart = cfg.counterpart == "lm" and counterpart_model is not None

    def lm_turn(lm, info, player):
        # Prompt -> sampled completion -> parsed env action for `player`.
        prompt = get_prompt_text(info)
        text, ids = _generate(lm, tokenizer, prompt, cfg, device)
        action = parse_action(text, player, cfg.nbr_distractors, cfg.descriptive,
                              cfg.vocab_size, cfg.max_sentence_length)
        return prompt, ids, action

    prompts: List[str] = []
    gen_ids: List[torch.Tensor] = []
    final_reward = 0.0

    total_steps = cfg.nbr_communication_rounds + 1 + int(cfg.provide_listener_feedback)
    for _ in range(total_steps):
        round_idx = infos[0]["round_idx"]

        if round_idx == -1:
            # Feedback step: both players stay silent; only the reward matters.
            obs, rewards, _done, infos = env.step([
                no_op_action(cfg.max_sentence_length),
                no_op_action(cfg.max_sentence_length),
            ])
            final_reward = float(rewards[0])
            continue

        speaker_turn = round_idx == 0
        listener_turn = round_idx == cfg.nbr_communication_rounds

        # Player 0 (speaker).
        if trains_speaker:
            if speaker_turn:
                prompt, ids, speaker_action = lm_turn(model, infos[0], 0)
                prompts.append(prompt)
                gen_ids.append(ids)
            else:
                speaker_action = no_op_action(cfg.max_sentence_length)
        elif lm_counterpart and speaker_turn:
            _, _, speaker_action = lm_turn(counterpart_model, infos[0], 0)
        else:
            speaker_action = rb_speaker.next_action(state=obs[0], infos=infos[0])

        # Player 1 (listener).
        if trains_listener:
            if listener_turn:
                prompt, ids, listener_action = lm_turn(model, infos[1], 1)
                prompts.append(prompt)
                gen_ids.append(ids)
            else:
                listener_action = no_op_action(cfg.max_sentence_length)
        elif lm_counterpart and listener_turn:
            _, _, listener_action = lm_turn(counterpart_model, infos[1], 1)
        else:
            # The rule-based listener reads the channel one-hot encoded.
            rb_infos = dict(infos[1])
            rb_infos["communication_channel"] = int_comm_to_ohe(
                obs[1]["communication_channel"], cfg.vocab_size, cfg.max_sentence_length
            )
            listener_action = rb_listener.next_action(state=obs[1], infos=rb_infos)

        obs, rewards, _done, infos = env.step([
            ensure_action_shape(speaker_action, cfg.max_sentence_length),
            ensure_action_shape(listener_action, cfg.max_sentence_length),
        ])
        final_reward = float(rewards[0])

    return Trajectory(prompts=prompts, gen_ids=gen_ids, reward=final_reward)


# ── GRPO loss ─────────────────────────────────────────────────────────────────

def _token_log_probs(model, tokenizer, prompt_text: str, gen_ids: torch.Tensor, device):
    """Score a completion token-by-token under *model*.

    The prompt is chat-formatted and tokenized the same way ``_generate`` does
    it, so that ``gen_ids`` line up with the positions they were sampled at.
    Gradients flow through *model* unless the caller disables them.

    Args:
        model: causal LM whose ``__call__`` returns an object with ``.logits``.
        tokenizer: tokenizer used for generation.
        prompt_text: raw prompt text (before chat formatting).
        gen_ids: 1-D tensor of completion token ids.
        device: device on which to run the forward pass.

    Returns:
        1-D float32 tensor with the log-probability of each completion token.
        For an empty completion, returns a single zero (shape ``(1,)``) whose
        ``requires_grad`` mirrors ``model.training`` so ``sum``/``mean`` stay
        well-defined downstream.
    """
    formatted = _format_prompt(tokenizer, prompt_text)
    inputs = tokenizer(formatted, return_tensors="pt", truncation=True).to(device)
    prompt_len = inputs.input_ids.shape[1]
    gen_ids_dev = gen_ids.to(device)
    gen_len = gen_ids_dev.shape[0]

    if gen_len == 0:
        return torch.tensor(0.0, device=device, requires_grad=model.training).unsqueeze(0)

    full_ids = torch.cat([inputs.input_ids, gen_ids_dev.unsqueeze(0)], dim=1)
    logits = model(full_ids).logits
    # Logits at position t predict token t+1, so completion tokens are scored
    # by positions prompt_len-1 .. prompt_len+gen_len-2. Slice *before* the
    # float32 upcast + log-softmax: the result is identical (log-softmax is
    # row-wise), but the float32 copy no longer spans the prompt x vocab.
    gen_logits = logits[0, prompt_len - 1: prompt_len + gen_len - 1, :].float()
    gen_log_probs = F.log_softmax(gen_logits, dim=-1)
    return gen_log_probs.gather(-1, gen_ids_dev.view(-1, 1)).squeeze(-1)


def _grpo_loss(trajectories, B, G, model, tokenizer, ref_model, cfg, device):
    """Compute the GRPO loss for one batch of rollouts.

    Rewards are standardised within each group (mean 0, std 1) to form
    advantages. Groups with fewer than two members, or with (near-)zero reward
    spread, get zero advantage. Each model turn contributes
    ``-advantage * sum(log pi(token))``. When ``cfg.beta_kl > 0`` and a
    reference model is given, each turn also adds
    ``beta_kl * mean(k3)``, where the per-token term is
    ``k3 = exp(ref_lp - lp) - (ref_lp - lp) - 1``. This is the non-negative KL
    estimator used by GRPO. The naive ``lp - ref_lp`` estimator has a
    backpropagated gradient with zero expectation on on-policy samples, so it
    does not actually regularise.

    Args:
        trajectories: list of ``(group_index, Trajectory)`` pairs; only groups
            ``0 .. B-1`` are used.
        B, G: number of groups and group size; the summed loss is divided by
            ``B * G``.
        model: trainable policy, scored via ``_token_log_probs``.
        tokenizer: tokenizer used during generation.
        ref_model: frozen reference model or ``None``.
        cfg: config; only ``beta_kl`` is read.
        device: device for the loss tensors.

    Returns:
        Scalar loss tensor. If no term contributes (no model turns, or every
        advantage is zero with the KL penalty off), a zero leaf tensor with
        ``requires_grad=True`` is returned so ``backward()`` remains valid.
    """
    use_kl = cfg.beta_kl > 0.0 and ref_model is not None

    # Bucket trajectories by group once, preserving rollout order.
    groups = {}
    for idx, traj in trajectories:
        groups.setdefault(idx, []).append(traj)

    total = torch.tensor(0.0, device=device)
    n_terms = 0
    for b in range(B):
        group = groups.get(b)
        if not group:
            continue

        rewards = torch.tensor([t.reward for t in group], dtype=torch.float32, device=device)
        advs = torch.zeros_like(rewards)
        # std() of a single value is NaN (0 degrees of freedom): no baseline.
        if len(group) > 1:
            std = rewards.std()
            if std > 1e-8:
                advs = (rewards - rewards.mean()) / (std + 1e-8)

        for traj, adv in zip(group, advs.tolist()):
            # A zero advantage contributes exactly zero policy gradient; skip
            # the forward pass unless the KL term still needs the log-probs.
            if adv == 0.0 and not use_kl:
                continue
            for prompt, gids in zip(traj.prompts, traj.gen_ids):
                lp = _token_log_probs(model, tokenizer, prompt, gids, device)
                total = total + (-adv * lp.sum())
                n_terms += 1
                if use_kl:
                    with torch.no_grad():
                        ref_lp = _token_log_probs(ref_model, tokenizer, prompt, gids, device)
                    log_ratio = ref_lp - lp
                    kl = torch.exp(log_ratio) - log_ratio - 1.0
                    total = total + cfg.beta_kl * kl.mean()
                    n_terms += 1

    if n_terms == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)
    return total / (B * G)


# ── Evaluation ────────────────────────────────────────────────────────────────

def _eval(env, model, tokenizer, cfg, device, rb_speaker, rb_listener, counterpart_model,
          n_games=100, base_seed=99999):
    """Play *n_games* evaluation games and report accuracy per split.

    The environment is seeded with *base_seed* and reset at the start, and
    again whenever ``run_one_game`` reports ``done``. A game counts as correct
    when its reward is positive. Games are bucketed into train and test splits
    by ``is_test_mode(infos[0])`` taken before the game is played.

    The trained model is wired in only as the listener (``cfg.role`` of
    ``"listener"`` or ``"both"``); the speaker seat is always ``rb_speaker``.
    For ``role="speaker"``, the model plays no seat and the numbers describe
    the rule-based agents alone, so a ``RuntimeWarning`` is emitted.
    NOTE(review): ``counterpart_model`` is currently unused here. Seating the
    model as speaker depends on ``run_one_game``'s signature, which appears to
    be (speaker agent, LM generate fn, listener agent) — TODO confirm.

    Returns:
        ``(train_acc, test_acc)`` as percentages; a split with no games
        reports 0.0.
    """
    import warnings

    model.eval()
    n_correct_train = n_train = n_correct_test = n_test = 0

    model_listens = cfg.role in ("listener", "both")
    if not model_listens:
        warnings.warn(
            f"_eval: role={cfg.role!r} is not wired into evaluation; reported "
            "accuracies reflect the rule-based agents only.",
            RuntimeWarning,
            stacklevel=2,
        )

    def lm_gen(pt):
        gt, _ = _generate(model, tokenizer, pt, cfg, device)
        return gt

    # Everything below is invariant across games, so build it once.
    game_kw = dict(
        nbr_communication_rounds=cfg.nbr_communication_rounds,
        nbr_distractors=cfg.nbr_distractors,
        descriptive=cfg.descriptive,
        vocab_size=cfg.vocab_size,
        max_sentence_length=cfg.max_sentence_length,
        provide_listener_feedback=cfg.provide_listener_feedback,
    )
    lm_gen_fn = lm_gen if model_listens else None
    rb_lis = None if model_listens else rb_listener

    env.seed(base_seed)
    obs, infos = env.reset()
    rb_speaker.reset()
    rb_listener.reset()

    for _ in range(n_games):
        is_test = is_test_mode(infos[0])
        reward, done, obs, infos = run_one_game(
            env, obs, infos, rb_speaker, lm_gen_fn, rb_lis, **game_kw
        )
        correct = int(reward > 0)
        if is_test:
            n_test += 1
            n_correct_test += correct
        else:
            n_train += 1
            n_correct_train += correct

        if done:
            env.seed(base_seed)
            obs, infos = env.reset()
            rb_speaker.reset()
            rb_listener.reset()

    return n_correct_train / max(n_train, 1) * 100, n_correct_test / max(n_test, 1) * 100


# ── GRPOTrainer ───────────────────────────────────────────────────────────────

class GRPOTrainer:
    """GRPO + LoRA trainer for one S2B task configuration.

    ``__init__`` does the following:

    - seeds the Python, NumPy and torch RNGs;
    - builds the environment and the rule-based agents;
    - loads the LoRA-wrapped policy, plus a frozen reference model when
      ``cfg.beta_kl > 0``;
    - optionally snapshots a frozen LM counterpart
      (``cfg.counterpart == "lm"``);
    - creates an AdamW optimizer over the parameters that require gradients.

    ``train`` runs ``cfg.num_steps`` rollout/update steps with periodic
    logging, evaluation and checkpointing.
    """

    def __init__(self, cfg: GRPOConfig) -> None:
        self.cfg = cfg
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        random.seed(cfg.seed)
        np.random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)

        task_cfg = {k: getattr(cfg, k) for k in (
            "nbr_latents", "nbr_distractors", "vocab_size", "max_sentence_length",
            "min_nbr_values_per_latent", "max_nbr_values_per_latent",
            "nbr_communication_rounds", "descriptive", "nbr_object_centric_samples",
            "provide_listener_feedback", "sampling_strategy", "domain",
        )}
        self.env = build_env_from_cfg(task_cfg, seed=cfg.seed)
        self.rb_speaker, self.rb_listener = build_rule_based_agents(
            vocab_size=cfg.vocab_size,
            max_sentence_length=cfg.max_sentence_length,
            nbr_communication_rounds=cfg.nbr_communication_rounds,
            nbr_latents=cfg.nbr_latents,
        )
        self.model, self.tokenizer, self.ref_model = _build_model_and_tokenizer(cfg, self.device)
        self.model.print_trainable_parameters()

        self.counterpart_model = None
        if cfg.counterpart == "lm":
            # Frozen copy of the policy as it is before any training step.
            self.counterpart_model = copy.deepcopy(self.model)
            self.counterpart_model.requires_grad_(False)
            self.counterpart_model.eval()

        # Optimise and clip only parameters that can receive gradients; frozen
        # base weights never get a .grad, so tracking them is pure overhead.
        self.trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.AdamW(
            self.trainable_params, lr=cfg.lr, weight_decay=cfg.weight_decay
        )

    def train(self) -> None:
        """Run the training loop; always checkpoints the final step."""
        cfg = self.cfg
        use_wandb = cfg.wandb_project is not None
        if use_wandb:
            import wandb
            wandb.init(project=cfg.wandb_project, config=cfg.__dict__)

        os.makedirs(cfg.output_dir, exist_ok=True)
        pbar = tqdm(range(1, cfg.num_steps + 1), desc="GRPO", unit="step")

        for step in pbar:
            trajectories = self._collect_rollouts(step)
            loss_val = self._update(trajectories)
            mean_reward = float(np.mean([t.reward for (_, t) in trajectories]))
            pbar.set_postfix({"loss": f"{loss_val:.4f}", "reward": f"{mean_reward:.3f}"})

            if step % cfg.log_interval == 0:
                tqdm.write(f"[step {step:05d}] loss={loss_val:.4f}  mean_reward={mean_reward:.3f}")
                if use_wandb:
                    wandb.log({"loss": loss_val, "mean_reward": mean_reward, "step": step})

            if step % cfg.eval_interval == 0:
                self._run_eval(step, use_wandb)

            # Also save the last step so a run whose length is not a multiple
            # of save_interval does not lose its final weights.
            if step % cfg.save_interval == 0 or step == cfg.num_steps:
                self._save_checkpoint(step)

        tqdm.write("[train] done.")
        if use_wandb:
            wandb.finish()

    def _collect_rollouts(self, step: int) -> List[Tuple[int, Trajectory]]:
        """Play ``B * G`` seeded games in eval mode; return ``(group, traj)`` pairs."""
        cfg = self.cfg
        self.model.eval()
        trajectories = []
        for b in range(cfg.B):
            for g in range(cfg.G):
                # NOTE(review): each trajectory gets its own seed, so if the
                # seed determines the sampled game, members of a group play
                # different games and the within-group baseline is a batch
                # baseline rather than GRPO's same-prompt one. Sharing a game
                # per group needs separate env/sampling seeds in
                # _rollout_one_game — TODO confirm intent.
                seed = cfg.seed + step * cfg.B * cfg.G + b * cfg.G + g
                traj = _rollout_one_game(
                    self.env, self.model, self.tokenizer, cfg, self.device,
                    self.rb_speaker, self.rb_listener, self.counterpart_model, seed,
                )
                trajectories.append((b, traj))
        return trajectories

    def _update(self, trajectories) -> float:
        """Apply one clipped AdamW step on the GRPO loss; return the loss value."""
        cfg = self.cfg
        self.model.train()
        self.optimizer.zero_grad()
        loss = _grpo_loss(
            trajectories, cfg.B, cfg.G,
            self.model, self.tokenizer, self.ref_model, cfg, self.device,
        )
        if loss.requires_grad:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.trainable_params, cfg.max_grad_norm)
            self.optimizer.step()
        return loss.item()

    def _run_eval(self, step: int, use_wandb: bool) -> None:
        """Evaluate on a step-dependent seed and report train / ZSCT accuracy."""
        cfg = self.cfg
        n_eval = max(cfg.B * cfg.G * 4, 20)
        tr, te = _eval(
            self.env, self.model, self.tokenizer, cfg, self.device,
            self.rb_speaker, self.rb_listener, self.counterpart_model,
            n_games=n_eval, base_seed=cfg.seed + 1_000_000 + step,
        )
        tqdm.write(f"  [eval {step}] train={tr:.1f}%  zsct={te:.1f}%")
        if use_wandb:
            import wandb
            wandb.log({"eval/train_acc": tr, "eval/zsct_acc": te, "step": step})

    def _save_checkpoint(self, step: int) -> None:
        """Save the adapter weights and tokenizer under ``output_dir/step_<step>``."""
        ckpt = os.path.join(self.cfg.output_dir, f"step_{step}")
        self.model.save_pretrained(ckpt)
        self.tokenizer.save_pretrained(ckpt)
        tqdm.write(f"  [save] → {ckpt}")
