import sys, types
# Compatibility shim, run once at import time: make sure the module name
# "torch._dynamo._trace_wrapped_higher_order_op" is importable before the
# heavier imports below execute.
# NOTE(review): presumably needed because a later import (e.g. transformers)
# references this module on torch builds that do not ship it — confirm.
try:
    import importlib
    importlib.import_module("torch._dynamo._trace_wrapped_higher_order_op")
except Exception:
    # The import failed for any reason: register an empty stand-in module that
    # exposes a placeholder ``TransformGetItemToIndex`` class under the same
    # dotted name, so subsequent imports resolve from ``sys.modules``.
    mod = types.ModuleType("torch._dynamo._trace_wrapped_higher_order_op")
    class TransformGetItemToIndex: pass
    mod.TransformGetItemToIndex = TransformGetItemToIndex
    sys.modules["torch._dynamo._trace_wrapped_higher_order_op"] = mod


import os, json, glob, argparse
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import time

from datetime import timedelta
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import AutoTokenizer, AutoModelForCausalLM


def ddp_setup():
    """Initialise distributed training when launched by a distributed launcher.

    Distributed mode is selected when both ``RANK`` and ``WORLD_SIZE`` are set
    in the environment; ``LOCAL_RANK`` defaults to 0 when absent.  Otherwise
    the process runs standalone on CUDA if available, else on CPU.

    Returns:
        Tuple ``(rank, world_size, local_rank, device)``.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        # Bind this process to its own GPU *before* creating the NCCL group,
        # so the communicator is not set up against the default device 0.
        torch.cuda.set_device(local_rank)
        # Re-entrant: a second call must not try to re-create the group.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(
                backend="nccl", timeout=timedelta(hours=2)
            )
        device = torch.device("cuda", local_rank)
    else:
        rank, world, local_rank = 0, 1, 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return rank, world, local_rank, device

def is_main():
    """Return True when this process is global rank 0 (``RANK`` unset counts as 0)."""
    rank_text = os.environ.get("RANK", "0")
    return not int(rank_text)

def barrier():
    """Synchronise all ranks; do nothing when no process group is active."""
    dist = torch.distributed
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    dist.barrier()


def load_json_file(fp: str) -> List[Dict[str, Any]]:
    """Parse *fp* as a single UTF-8 JSON document and return its top-level list.

    Raises:
        ValueError: if the document's top-level value is not a list.
    """
    with open(fp, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    if isinstance(payload, list):
        return payload
    raise ValueError(f"{fp} must be a list of samples")

def load_jsonl_file(fp: str) -> List[Dict[str, Any]]:
    """Parse *fp* as UTF-8 JSON Lines: one JSON value per non-blank line."""
    with open(fp, "r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]


class OfflineGRPODataset(Dataset):
    """Map-style dataset of pre-generated response groups for offline GRPO.

    Every stored item holds one prompt, its list of candidate responses and
    one reward per response.  ``__getitem__`` tokenizes a group into
    fixed-width rows of ``max_prompt_len + max_response_len`` tokens
    (prompt left-padded, response right-padded) and turns the rewards into
    group-normalised advantages.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: AutoTokenizer,
        max_prompt_len: int = 256,
        max_response_len: int = 256,
        prompt_field: str = "prompt",
        responses_field: str = "responses",
        reward_field: str = "reward",
        max_group: int = 4,
    ):
        """Load every record found at *data_path*.

        Args:
            data_path: A ``.json``/``.jsonl`` file, or a directory that is
                searched recursively for such files.
            tokenizer: Tokenizer used for the chat template and for encoding.
            max_prompt_len: Width of the (left-padded) prompt segment.
            max_response_len: Width of the (right-padded) response segment.
            prompt_field: Record key holding the prompt text.
            responses_field: Record key holding the list of responses.
            reward_field: Record key holding the list of rewards.
            max_group: Keep at most this many responses per prompt
                (``None`` keeps all).

        Raises:
            FileNotFoundError: directory contains no json/jsonl file.
            ValueError: no usable record was found.
        """
        self.tok = tokenizer
        self.maxP = max_prompt_len
        self.maxR = max_response_len
        self.prompt_field = prompt_field
        self.responses_field = responses_field
        self.reward_field = reward_field
        self.max_group = max_group
        self.items: List[Dict[str, Any]] = []

        if os.path.isdir(data_path):
            files = sorted(
                glob.glob(os.path.join(data_path, "**/*.json"), recursive=True) +
                glob.glob(os.path.join(data_path, "**/*.jsonl"), recursive=True)
            )
            if not files:
                raise FileNotFoundError(f"No json/jsonl in {data_path}")
            for fp in files:
                recs = load_json_file(fp) if fp.endswith(".json") else load_jsonl_file(fp)
                self._extend(recs, fp)
        else:
            # Only a *parse/shape* failure (json.JSONDecodeError and the
            # "must be a list" error are both ValueError) triggers the JSONL
            # fallback; I/O errors such as a missing file surface directly
            # instead of being masked by a second failing attempt.
            try:
                recs = load_json_file(data_path)
            except ValueError:
                recs = load_jsonl_file(data_path)
            self._extend(recs, data_path)

        if not self.items:
            raise ValueError("Empty dataset")

    def _extend(self, recs: List[Dict[str, Any]], src: str):
        """Validate *recs* (loaded from *src*) and append the usable ones.

        Records with fewer than two responses are skipped: a group of one
        has no within-group reward contrast.

        Raises:
            KeyError: a record lacks one of the configured fields.
            ValueError: rewards are not a list matching the responses.
        """
        for ex in recs:
            if self.prompt_field not in ex or self.responses_field not in ex or self.reward_field not in ex:
                raise KeyError(f"Missing '{self.prompt_field}', '{self.responses_field}', or '{self.reward_field}' in {src}")
            prompt = ex[self.prompt_field]
            responses = ex[self.responses_field]
            rewards = ex[self.reward_field]
            if not isinstance(responses, list) or len(responses) < 2:
                continue
            # A scalar/non-sequence reward used to die with a bare TypeError
            # from len(); report it as the same mismatch error, with source.
            if not isinstance(rewards, (list, tuple)) or len(responses) != len(rewards):
                raise ValueError(f"Length of responses and rewards do not match in {src}")
            self.items.append({"prompt": prompt, "responses": responses, "rewards": rewards})

    def __len__(self):
        """Number of usable prompt groups."""
        return len(self.items)

    def _prompt_text(self, user_prompt: str) -> str:
        """Render *user_prompt* as a single-turn chat ending in the generation prompt."""
        return self.tok.apply_chat_template(
            [{"role": "user", "content": user_prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )

    def _pad_id(self) -> int:
        """Return the id used for padding, falling back to EOS when no pad token is set."""
        pad_id = self.tok.pad_token_id
        if pad_id is None:
            pad_id = getattr(self.tok, "eos_token_id", None)
        if pad_id is None:
            raise ValueError("Tokenizer defines neither pad_token_id nor eos_token_id; cannot pad")
        return pad_id

    def _encode(self, text: str) -> List[int]:
        """Tokenize *text* without special tokens and return a plain list of ids."""
        return list(self.tok(text, add_special_tokens=False)["input_ids"])

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Tokenize group *idx*.

        Returns:
            Dict with ``prompt_response_ids``, ``attention_mask`` and
            ``action_mask`` (each ``[group, max_prompt_len + max_response_len]``,
            dtype long) and ``advantages`` (``[group]``, float32).  The action
            mask is 1 only on real response tokens.
        """
        ex = self.items[idx]
        responses = ex["responses"]
        rewards_list = ex["rewards"]
        if self.max_group is not None:
            responses = responses[: self.max_group]
            rewards_list = rewards_list[: self.max_group]

        pad_id = self._pad_id()

        p_tokens = self._encode(self._prompt_text(ex["prompt"]))
        if len(p_tokens) > self.maxP:
            # Keep the *last* maxP tokens.  Slice from an explicit start:
            # ``p_tokens[-0:]`` would keep everything when maxP == 0 and
            # desynchronise the ids from the fixed-width masks.
            p_tokens = p_tokens[len(p_tokens) - self.maxP:]

        # The prompt side is identical for every response in the group.
        p_pad = self.maxP - len(p_tokens)
        prompt_ids = [pad_id] * p_pad + p_tokens
        prompt_attn = [0] * p_pad + [1] * len(p_tokens)
        prompt_actions = [0] * self.maxP

        pr_list, attn_list, am_list = [], [], []
        for resp in responses:
            r_tokens = self._encode(resp)[: self.maxR]
            r_pad = self.maxR - len(r_tokens)
            resp_mask = [1] * len(r_tokens) + [0] * r_pad

            pr_list.append(torch.tensor(prompt_ids + r_tokens + [pad_id] * r_pad, dtype=torch.long))
            attn_list.append(torch.tensor(prompt_attn + resp_mask, dtype=torch.long))
            am_list.append(torch.tensor(prompt_actions + resp_mask, dtype=torch.long))

        # Group-relative advantage: z-score of the rewards within this group
        # (population std; epsilon guards the all-equal-rewards case).
        rewards = torch.tensor(rewards_list, dtype=torch.float32)
        advantages = (rewards - rewards.mean()) / (rewards.std(unbiased=False) + 1e-8)

        return {
            "prompt_response_ids": torch.stack(pr_list, dim=0),
            "attention_mask":     torch.stack(attn_list, dim=0),
            "action_mask":        torch.stack(am_list, dim=0),
            "advantages": advantages,
        }

def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Flatten a list of per-prompt groups into one batch.

    Each tensor field is concatenated along dim 0 across the samples, and
    ``group_sizes`` records how many rows each sample contributed.
    """
    tensor_keys = ("prompt_response_ids", "attention_mask", "action_mask", "advantages")
    merged: Dict[str, Any] = {
        key: torch.cat([sample[key] for sample in batch], dim=0)
        for key in tensor_keys
    }
    merged["group_sizes"] = [sample["prompt_response_ids"].shape[0] for sample in batch]
    return merged

@dataclass
class GRPOArguments:
    output_dir: str = "./offline_grpo_full_ckpts"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    lr: float = 1e-6
    save_steps: int = 200
    epoch: int = 1
    num_generations: int = 0   
    max_prompt_length: int = 192
    max_generate_length: int = 160  
    clip_eps: float = 0.2
    gradient_accumulation_steps: int = 2
    num_iterations: int = 1
    batch_size: int = 1
    beta: float = 0.01       
    seed: int = 42
    fp16: bool = False
    bf16: bool = True
    max_group: int = 4
    prompt_field: str = "prompt"
    responses_field: str = "responses"
    reward_field: str = "reward"
    num_workers: int = 2


class GRPOTrainerOfflineFull:
    """Full-parameter offline GRPO trainer, optionally data-parallel via DDP.

    A frozen deep copy of the initial model acts both as the "old" policy in
    the clipped probability ratio and as the reference for the KL penalty.
    Batches are buffered for ``gradient_accumulation_steps`` iterations and
    then optimised ``num_iterations`` times.
    """

    def __init__(
        self,
        model,
        args: GRPOArguments,
        train_dataset: Dataset,
        tokenizer: AutoTokenizer,
    ):
        """Set up devices, models, dataloader and optimizer.

        Args:
            model: Causal LM to train; moved to the selected device.
            args: Training configuration.
            train_dataset: Dataset yielding items accepted by ``collate_fn``.
            tokenizer: Tokenizer saved alongside checkpoints; its pad token
                is set to EOS when missing and padding side is forced left.

        Raises:
            ValueError: ``gradient_accumulation_steps`` is smaller than 1.
        """
        if args.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")

        self.args = args
        self.rank, self.world, self.local_rank, self.device = ddp_setup()
        torch.manual_seed(self.args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.args.seed)

        self._t0 = time.time()
        self._last_update_time = self._t0
        self._ema_step_time = None

        self.tokenizer = tokenizer
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = "left"

        self.model = model.to(self.device)
        if hasattr(self.model.config, "use_cache"):
            self.model.config.use_cache = False
        if hasattr(self.model, "gradient_checkpointing_enable"):
            self.model.gradient_checkpointing_enable()

        # Frozen snapshot of the starting weights (taken before DDP wrapping).
        from copy import deepcopy
        self.ref_model = deepcopy(self.model).to(self.device)
        for p in self.ref_model.parameters():
            p.requires_grad_(False)
        self.ref_model.eval()

        # DDP needs an initialised process group; a plain single-process run
        # has none, so wrap only when one exists.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            self.model = DDP(
                self.model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=False,
            )

        self.train_dataset = train_dataset
        self.sampler = DistributedSampler(self.train_dataset, shuffle=True, drop_last=False) if self.world > 1 else None
        self.dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            sampler=self.sampler,
            shuffle=(self.sampler is None),
            num_workers=self.args.num_workers,
            # Pinned host memory only helps (and only works) with CUDA.
            pin_memory=(self.device.type == "cuda"),
            persistent_workers=(self.args.num_workers > 0),
            prefetch_factor=(2 if self.args.num_workers > 0 else None),
            collate_fn=collate_fn,
        )

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)

        self.input_buffer = [None] * self.args.gradient_accumulation_steps
        self.update_steps = 0

        if is_main():
            os.makedirs(self.args.output_dir, exist_ok=True)

    def get_action_log_probs(self, model, input_ids, attention_mask, num_actions):
        """Return per-token log-probabilities of the last *num_actions* targets.

        Position ``t`` of the result scores token ``t + 1`` of *input_ids*
        given the preceding tokens, so at most ``seq_len - 1`` columns exist;
        *num_actions* is clamped to that.

        The model is called as given.  When it is DDP-wrapped the forward
        must go through the wrapper: DDP arms its gradient all-reduce inside
        its own ``forward``, so calling ``model.module`` directly would leave
        every rank training on unsynchronised gradients.
        """
        out = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = out.logits
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
        labels = input_ids[:, 1:]
        tok_logp = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        total = tok_logp.shape[1]
        num_actions = max(0, min(num_actions, total))
        # Explicit start index: ``[-0:]`` would return every column.
        return tok_logp[:, total - num_actions:]

    def generate_experiences(self, batch: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Move a collated batch to the device and attach old-policy log-probs.

        The "old" log-probabilities come from the frozen reference model,
        computed without gradients.
        """
        pr_ids = batch["prompt_response_ids"].to(self.device)
        attn   = batch["attention_mask"].to(self.device)
        am     = batch["action_mask"].to(self.device)
        adv    = batch["advantages"].to(self.device)
        A = am.shape[1]

        with torch.no_grad():
            old_lp = self.get_action_log_probs(self.ref_model, pr_ids, attn, A)

        return {
            "prompt_response_ids": pr_ids,
            "attention_mask": attn,
            "action_mask": am,
            "old_action_log_probs": old_lp,
            "advantages": adv,
        }

    def compute_loss(self, model, inputs):
        """Clipped-ratio policy loss plus optional KL penalty.

        Per-token losses are masked to response tokens, averaged per
        sequence over its valid tokens, then averaged over the batch.

        Returns:
            ``(loss, approx_kl)`` where ``approx_kl`` is the masked mean of
            ``old_logp - curr_logp`` as a Python float.
        """
        pr_ids = inputs["prompt_response_ids"]
        attn   = inputs["attention_mask"]
        am     = inputs["action_mask"]
        adv    = inputs["advantages"]
        old_lp = inputs["old_action_log_probs"]

        num_actions = am.shape[1]
        curr_lp = self.get_action_log_probs(model, pr_ids, attn, num_actions)

        ratio = torch.exp(curr_lp - old_lp)
        ratio_clipped = torch.clamp(ratio, 1 - self.args.clip_eps, 1 + self.args.clip_eps)
        per_tok1 = ratio * adv.unsqueeze(1)
        per_tok2 = ratio_clipped * adv.unsqueeze(1)
        per_tok = -torch.min(per_tok1, per_tok2)

        # Log-probs are shifted by one token, so align the mask to the same
        # trailing columns.
        valid = am[:, -per_tok.shape[1]:]
        per_tok = per_tok * valid

        if self.args.beta and self.args.beta > 0.0:
            # NOTE(review): this re-runs the same frozen model that produced
            # ``old_lp``; reusing ``old_lp`` would save a forward pass but can
            # differ numerically when autocast is active here — confirm.
            with torch.no_grad():
                ref_lp = self.get_action_log_probs(self.ref_model, pr_ids, attn, num_actions)
            log_ratio = ref_lp - curr_lp
            # k3 estimator: exp(x) - 1 - x, non-negative for every x.
            k3 = (log_ratio.exp() - 1.0 - log_ratio) * valid
            per_tok = per_tok + self.args.beta * k3

        denom = valid.sum(dim=1).clamp_min(1)
        loss = (per_tok.sum(dim=1) / denom).mean()

        with torch.no_grad():
            approx_kl = ((old_lp - curr_lp) * valid).sum() / denom.sum()

        if is_main() and not hasattr(self, "_dbg_once"):
            self._dbg_once = True
            print(f"[debug] valid_token_count={int(valid.sum().item())}, adv_mean={float(adv.mean().item()):.6f}, adv_std={float(adv.std(unbiased=False).item()):.6f}",flush=True)
        return loss, float(approx_kl.item())

    def train_step(self, model, inputs, optimizer, step, accum_steps: Optional[int] = None):
        """Run forward/backward on one buffered batch; step the optimizer when due.

        Args:
            model: Policy (possibly DDP-wrapped).
            inputs: Output of ``generate_experiences``.
            optimizer: Optimizer to step.
            step: Index of this micro-batch within the current update.
            accum_steps: Micro-batches per update; defaults to
                ``args.gradient_accumulation_steps``.  The loss is divided by
                this and the optimizer steps when ``(step + 1)`` is a
                multiple of it.
        """
        from contextlib import nullcontext

        if accum_steps is None:
            accum_steps = self.args.gradient_accumulation_steps

        model.train()
        amp_dtype = (torch.bfloat16 if self.args.bf16 else (torch.float16 if self.args.fp16 else None))
        if amp_dtype is not None and self.device.type == "cuda":
            ctx = torch.amp.autocast('cuda', dtype=amp_dtype)
        else:
            ctx = nullcontext()
        scaler = getattr(self, "_scaler", None)
        if scaler is None:
            self._scaler = torch.amp.GradScaler('cuda', enabled=(amp_dtype == torch.float16))
            scaler = self._scaler

        is_update = (step + 1) % accum_steps == 0
        # While accumulating, skip DDP's all-reduce; the final micro-batch
        # reduces the accumulated gradients in one go.
        if isinstance(model, DDP) and not is_update:
            sync_ctx = model.no_sync()
        else:
            sync_ctx = nullcontext()

        with sync_ctx:
            with ctx:
                loss, _ = self.compute_loss(model, inputs)
                loss = loss / accum_steps

            if amp_dtype == torch.float16:
                scaler.scale(loss).backward()
            else:
                loss.backward()

        if is_update:
            if amp_dtype == torch.float16:
                scaler.step(optimizer); scaler.update()
            else:
                optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            if is_main():
                total = getattr(self, "global_steps", "?")
                print(f"step: {self.update_steps}/{total}  grpo_loss: {loss.item():.8f}",flush=True)
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

    @staticmethod
    def _fmt_hms(sec):
        """Format a duration in seconds as ``HH:MM:SS``."""
        h = int(sec // 3600); m = int((sec % 3600) // 60); s = int(sec % 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    def _log_timing(self):
        """Print the duration of the last update, its running average and an ETA."""
        if self.device.type == "cuda":
            torch.cuda.synchronize()
        now = time.time()
        step_sec = now - self._last_update_time
        self._last_update_time = now
        if self._ema_step_time is None:
            self._ema_step_time = step_sec
        else:
            self._ema_step_time = 0.9 * self._ema_step_time + 0.1 * step_sec
        remain = max(0, self.global_steps - self.update_steps)
        eta_sec = self._ema_step_time * remain
        print(f"[time] step {self.update_steps}/{self.global_steps} "
            f"took {step_sec:.2f}s | avg {self._ema_step_time:.2f}s/step | ETA {self._fmt_hms(eta_sec)}",
            flush=True)

    def _apply_update(self, cached_batches):
        """Optimise ``num_iterations`` times over *cached_batches*, logging and checkpointing."""
        accum = len(cached_batches)
        for _ in range(self.args.num_iterations):
            for step, cached in enumerate(cached_batches):
                self.train_step(self.model, cached, self.optimizer, step, accum_steps=accum)
            self.update_steps += 1

            if is_main():
                self._log_timing()

            # save_steps <= 0 disables periodic checkpoints instead of
            # raising ZeroDivisionError.
            if is_main() and self.args.save_steps > 0 and self.update_steps % self.args.save_steps == 0:
                self.save_checkpoint(self.update_steps)

    def train(self):
        """Run the full training loop, save a final checkpoint and tear down DDP.

        Batches left over at the end of an epoch (when the number of
        iterations is not a multiple of ``gradient_accumulation_steps``) are
        trained on as a shorter final update, which is what the ``ceil`` in
        the step count already assumes.
        """
        start = time.time()
        import math
        accum = self.args.gradient_accumulation_steps
        iters_per_epoch_per_rank = len(self.dataloader)
        updates_per_epoch_per_rank = math.ceil(iters_per_epoch_per_rank / accum)
        self.global_steps = self.args.num_iterations * self.args.epoch * updates_per_epoch_per_rank

        if is_main():
            print(f"[info] len(dataset)={len(self.train_dataset)}, world_size={self.world}, "
                f"iters/epoch(per-rank)={iters_per_epoch_per_rank}, grad_accum={self.args.gradient_accumulation_steps}, "
                f"global_steps(per-rank)={self.global_steps}",flush=True)

        for ep in range(self.args.epoch):
            if self.sampler is not None:
                self.sampler.set_epoch(ep)
            pending = 0
            for batch in self.dataloader:
                self.input_buffer[pending] = self.generate_experiences(batch)
                pending += 1
                if pending == accum:
                    self._apply_update(self.input_buffer[:pending])
                    pending = 0
            if pending:
                # Flush the incomplete tail instead of silently dropping it.
                self._apply_update(self.input_buffer[:pending])
            # Drop references to the last buffered tensors.
            self.input_buffer = [None] * accum

        if is_main():
            self.save_checkpoint(self.update_steps)
        barrier()
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

        end = time.time()
        if is_main():
            elapsed = end - start
            print(f"[train] finished in {self._fmt_hms(elapsed)} (h:m:s)", flush=True)

    def save_checkpoint(self, step: int):
        """Write the unwrapped model and the tokenizer to ``output_dir/checkpoint_<step>``."""
        out = os.path.join(self.args.output_dir, f"checkpoint_{step}")
        os.makedirs(out, exist_ok=True)
        mdl = self.model.module if isinstance(self.model, DDP) else self.model
        mdl.save_pretrained(out)
        self.tokenizer.save_pretrained(out)


def build_parser():
    """Build the command-line parser for the offline GRPO training script."""
    parser = argparse.ArgumentParser()

    # Mandatory locations.
    for required_flag in ("--model_name_or_path", "--data_path"):
        parser.add_argument(required_flag, type=str, required=True)

    # Optional typed settings as (flag, type, default), in help order.
    typed_options = (
        ("--output_dir", str, "./offline_grpo_full_ckpts"),
        ("--epochs", int, 1),
        ("--batch_size", int, 1),
        ("--grad_accum_steps", int, 2),
        ("--num_iterations", int, 1),
        ("--save_steps", int, 200),
        ("--lr", float, 1e-6),
        ("--beta", float, 0.01),
        ("--clip_eps", float, 0.2),
        ("--max_prompt_len", int, 192),
        ("--max_response_len", int, 160),
        ("--max_group", int, 4),
        ("--prompt_field", str, "prompt"),
        ("--responses_field", str, "responses"),
        ("--reward_field", str, "reward"),
        ("--seed", int, 42),
        ("--num_workers", int, 0),
    )
    for flag, value_type, default_value in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    # Precision switches, off unless given.
    for switch in ("--bf16", "--fp16"):
        parser.add_argument(switch, action="store_true")

    return parser

def main():
    """Parse the CLI, load tokenizer/model/dataset and run offline GRPO training."""
    args_cli = build_parser().parse_args()

    args = GRPOArguments(
        output_dir=args_cli.output_dir,
        lr=args_cli.lr,
        save_steps=args_cli.save_steps,
        epoch=args_cli.epochs,
        max_prompt_length=args_cli.max_prompt_len,
        max_generate_length=args_cli.max_response_len,
        clip_eps=args_cli.clip_eps,
        gradient_accumulation_steps=args_cli.grad_accum_steps,
        num_iterations=args_cli.num_iterations,
        batch_size=args_cli.batch_size,
        beta=args_cli.beta,
        seed=args_cli.seed,
        bf16=args_cli.bf16,
        fp16=args_cli.fp16,
        max_group=args_cli.max_group,
        prompt_field=args_cli.prompt_field,
        responses_field=args_cli.responses_field,
        reward_field=args_cli.reward_field,
        num_workers=args_cli.num_workers,
    )

    # Only bf16 weights are loaded in reduced precision.  For --fp16 the
    # trainer relies on autocast plus a GradScaler, and the scaler rejects
    # fp16 *parameters* ("Attempting to unscale FP16 gradients"), so the
    # weights must stay in the checkpoint's default precision.
    torch_dtype = torch.bfloat16 if (args.bf16 and torch.cuda.is_available()) else None
    tokenizer = AutoTokenizer.from_pretrained(args_cli.model_name_or_path, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(args_cli.model_name_or_path, torch_dtype=torch_dtype)

    dataset = OfflineGRPODataset(
        data_path=args_cli.data_path,
        tokenizer=tokenizer,
        max_prompt_len=args.max_prompt_length,
        max_response_len=args.max_generate_length,
        prompt_field=args.prompt_field,
        responses_field=args.responses_field,
        reward_field=args.reward_field,
        max_group=args.max_group,
    )

    trainer = GRPOTrainerOfflineFull(
        model=model,
        args=args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )
    trainer.train()

# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
