# -*- coding: utf-8 -*-
import os
import math
import argparse
from dataclasses import dataclass
from typing import List

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOTrainer, GRPOConfig
import rewrite_reward


def build_gen_prompt(raw_prompt: str) -> str:
    """Return the generation prompt: the raw prompt with outer whitespace removed.

    Falsy input (``None`` or an empty string) yields an empty string.
    """
    if not raw_prompt:
        return ""
    return raw_prompt.strip()


def make_reward_fn():
    """Build the reward callable passed to the trainer in ``reward_funcs``.

    The returned function scores each output with
    ``rewrite_reward.compute_score`` and returns exactly one float per output.

    Returns:
        A function ``reward_fn(samples=None, prompts=None, completions=None,
        **kwargs) -> List[float]``. ``completions`` is used when non-empty,
        otherwise ``samples``. ``prompts`` is aligned to the outputs: a single
        prompt is broadcast, a short list is padded with ``""`` and a long one
        is truncated.

    Scoring is best-effort: an output whose scoring raises, or whose score is
    not a finite number (NaN or +/-inf), receives a reward of ``0.0``. Failures
    are logged instead of being swallowed silently, so a broken scorer does not
    go unnoticed as an all-zero reward signal.
    """
    import logging  # local import: keeps this block self-contained

    logger = logging.getLogger(__name__)
    # The full traceback is logged only for the first failure to avoid flooding
    # the log when the scorer fails on every sample.
    traceback_logged = False

    def reward_fn(samples=None, prompts=None, completions=None, **kwargs) -> List[float]:
        nonlocal traceback_logged

        has_completions = completions is not None and len(completions) > 0
        outs = completions if has_completions else (samples or [])
        n_outs = len(outs)

        # Copy into a list so tuples (or other sequences) can be padded safely.
        prompts = list(prompts) if prompts is not None else []
        if len(prompts) == 1 and n_outs > 1:
            prompts = prompts * n_outs
        elif len(prompts) < n_outs:
            prompts = prompts + [""] * (n_outs - len(prompts))
        else:
            prompts = prompts[:n_outs]

        rewards: List[float] = []
        for out_str, prompt_text in zip(outs, prompts):
            try:
                score = float(
                    rewrite_reward.compute_score(
                        data_source="math-rewrite",
                        solution_str=out_str,
                        ground_truth="",
                        extra_info={"prompt_text": prompt_text},
                    )
                )
            except Exception:
                # Deliberate best-effort: one bad sample must not abort training.
                if traceback_logged:
                    logger.debug("Reward computation failed; using 0.0", exc_info=True)
                else:
                    logger.warning(
                        "Reward computation failed; using 0.0 "
                        "(further tracebacks are logged at DEBUG level)",
                        exc_info=True,
                    )
                    traceback_logged = True
                score = 0.0
            else:
                if not math.isfinite(score):
                    # NaN/inf would poison the per-group reward statistics.
                    logger.warning("Non-finite reward %r replaced by 0.0", score)
                    score = 0.0
            rewards.append(score)

        # One reward is appended per (output, prompt) pair and prompts were
        # aligned to the outputs above, so len(rewards) == len(outs) here.
        return rewards

    return reward_fn


def get_model_and_tokenizer(model_name: str, load_in_8bit: bool = False, bf16: bool = True):
    """Load a causal language model and its tokenizer.

    Args:
        model_name: Hugging Face model name or local path.
        load_in_8bit: When true, load the weights with 8-bit quantization via
            ``BitsAndBytesConfig``.
        bf16: Prefer bfloat16 weights when running on a CUDA device that
            supports it.

    Returns:
        A ``(model, tokenizer)`` tuple. If the tokenizer has no pad token, its
        EOS token is reused as the pad token.
    """
    # Half precision is only chosen on CUDA: float16 is poorly supported on
    # CPU, so fall back to float32 there. bfloat16 additionally requires
    # hardware support.
    if not torch.cuda.is_available():
        torch_dtype = torch.float32
    elif bf16 and torch.cuda.is_bf16_supported():
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    load_kwargs = {"device_map": None, "torch_dtype": torch_dtype}
    if load_in_8bit:
        # Passing ``load_in_8bit`` directly to ``from_pretrained`` is
        # deprecated; the quantization settings go through a config object.
        from transformers import BitsAndBytesConfig

        load_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    return model, tokenizer


def build_peft_config(r: int = 8, alpha: int = 16, dropout: float = 0.05):
    """Create the ``LoraConfig`` used for causal-LM fine-tuning.

    Args:
        r: LoRA rank, forwarded as ``r``.
        alpha: Forwarded as ``lora_alpha``.
        dropout: Forwarded as ``lora_dropout``.

    Returns:
        A ``LoraConfig`` with no bias training that targets the projection
        modules named below.
    """
    # NOTE(review): these module names look like a LLaMA-style layout —
    # confirm they exist in the model being tuned.
    adapted_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    lora_settings = {
        "r": r,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        "target_modules": adapted_modules,
    }
    return LoraConfig(**lora_settings)


@dataclass
class ScriptArgs:
    """Run configuration for the GRPO fine-tuning script.

    ``parse_args`` fills an instance from the command line. For the numeric
    options it reads the defaults below as class attributes
    (``ScriptArgs.<field>``), so these values are also the CLI defaults.
    """

    # Paths. Empty here; the matching CLI flags are declared as required.
    model_name: str = ""
    dataset_path: str = ""
    output_dir: str = ""
    # Optimizer / schedule settings, passed to GRPOConfig under the same names.
    learning_rate: float = 3e-5
    weight_decay: float = 0.0
    warmup_ratio: float = 0.03
    max_steps: int = 500
    logging_steps: int = 10
    save_steps: int = 10
    # Passed to GRPOConfig as per_device_train_batch_size.
    per_device_batch_size: int = 2
    gradient_accumulation_steps: int = 4
    # Passed to GRPOConfig as num_generations.
    group_size: int = 4
    max_prompt_length: int = 512
    # Upper bound on generated tokens (max_new_tokens for generation).
    max_new_tokens: int = 4096
    # Sampling settings.
    temperature: float = 0.7
    top_p: float = 0.95
    # Exposed on the CLI as --kl_coef.
    kl_coef: float = 0.0
    # LoRA settings, passed to build_peft_config.
    lora_r: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.05
    # Model-loading switches, passed to get_model_and_tokenizer.
    load_in_8bit: bool = False
    # Set to the negation of the --no_bf16 CLI flag.
    bf16: bool = True


def parse_args() -> ScriptArgs:
    """Parse the command line into a ``ScriptArgs`` instance.

    ``--model_name``, ``--dataset_path`` and ``--output_dir`` are required.
    Every other option defaults to the corresponding ``ScriptArgs`` class
    attribute. ``--no_bf16`` is stored inverted in the ``bf16`` field.
    """
    parser = argparse.ArgumentParser()

    # Required paths (no personal info; must be provided)
    required_paths = (
        ("model_name", "HF model name or local path"),
        ("dataset_path", "Path to JSONL dataset with 'prompt' field"),
        ("output_dir", "Output directory"),
    )
    for option, help_text in required_paths:
        parser.add_argument(f"--{option}", type=str, required=True, help=help_text)

    # Training knobs first, then the LoRA settings; order matches --help output.
    typed_options = (
        ("learning_rate", float),
        ("weight_decay", float),
        ("warmup_ratio", float),
        ("max_steps", int),
        ("logging_steps", int),
        ("save_steps", int),
        ("per_device_batch_size", int),
        ("gradient_accumulation_steps", int),
        ("group_size", int),
        ("max_prompt_length", int),
        ("max_new_tokens", int),
        ("temperature", float),
        ("top_p", float),
        ("kl_coef", float),
        ("lora_r", int),
        ("lora_alpha", int),
        ("lora_dropout", float),
    )
    for option, caster in typed_options:
        parser.add_argument(f"--{option}", type=caster, default=getattr(ScriptArgs, option))

    # Precision / memory
    for flag in ("load_in_8bit", "no_bf16"):
        parser.add_argument(f"--{flag}", action="store_true")

    parsed = vars(parser.parse_args())
    settings = ScriptArgs()
    for key in parsed:
        value = parsed[key]
        if key != "no_bf16":
            setattr(settings, key, value)
            continue
        # The CLI flag disables bf16; the dataclass stores the positive form.
        settings.bf16 = not value
    return settings


def main():
    """Run GRPO fine-tuning end to end.

    Steps: parse the CLI arguments, load the model/tokenizer and the LoRA
    config, load the JSON dataset and normalise its ``prompt`` column, train
    with ``GRPOTrainer``, save the model and tokenizer to ``output_dir``, and
    finally print one sampled completion for the first prompt.

    Raises:
        ValueError: If the dataset contains no rows.
    """
    args = parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    model, tokenizer = get_model_and_tokenizer(
        model_name=args.model_name, load_in_8bit=args.load_in_8bit, bf16=args.bf16
    )

    peft_config = build_peft_config(r=args.lora_r, alpha=args.lora_alpha, dropout=args.lora_dropout)

    ds = load_dataset("json", data_files=args.dataset_path, split="train")

    def _map_build_prompt(ex):
        return {"prompt": build_gen_prompt(ex["prompt"])}

    # Keep only the "prompt" column; every other column is dropped.
    ds = ds.map(_map_build_prompt, remove_columns=[c for c in ds.column_names if c != "prompt"])

    # Fail before the expensive training run rather than at ds[0] afterwards.
    if len(ds) == 0:
        raise ValueError(f"Dataset at {args.dataset_path!r} contains no rows")

    reward_fn = make_reward_fn()

    config = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        warmup_ratio=args.warmup_ratio,
        max_steps=args.max_steps,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        importance_sampling_level="sequence",
        per_device_train_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_prompt_length=args.max_prompt_length,
        # Keep the trainer's own completion budget in sync with the
        # max_new_tokens override passed in generation_kwargs below.
        max_completion_length=args.max_new_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        num_generations=args.group_size,
        # Wire --kl_coef through; it was previously parsed but never used.
        beta=args.kl_coef,
        report_to=[],
        remove_unused_columns=False,
        generation_kwargs=dict(
            max_new_tokens=args.max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        ),
    )

    trainer = GRPOTrainer(
        model=model,
        args=config,          # if your TRL version expects 'config=', switch accordingly
        reward_funcs=[reward_fn],
        train_dataset=ds,
        peft_config=peft_config,
    )
    trainer.train()

    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    # Sanity check: sample one completion for the first prompt.
    # Switch to eval mode so dropout is disabled while sampling.
    model.eval()
    first_prompt = ds[0]["prompt"]
    inputs = tokenizer(first_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        gen_ids = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=args.temperature,
            top_p=args.top_p,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    print("\n===== SAMPLE OUTPUT =====")
    print(tokenizer.decode(gen_ids[0], skip_special_tokens=True))


# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
