
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
import json
import argparse
import random
import wandb
from datetime import datetime
import warnings
import time


# System message placed at the start of every prompt. It instructs the model
# to put its reasoning inside <reasoning> tags and only the numerical answer
# inside <answer> tags. The surrounding blank lines of the literal are removed
# by .strip(), so the final string starts with "Respond" and ends with
# "</answer>" (no trailing newline).
SYSTEM_PROMPT = """
Respond in the following format, with only the numerical answer between the <answer> tags:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# str.format template for a reasoning/answer completion; takes the keyword
# fields {reasoning} and {answer} and ends with a newline after </answer>.
# NOTE(review): the "<answer> " line carries a trailing space, which differs
# from the "<answer>\n" the format reward functions look for, and nothing in
# the visible part of this file references the template — confirm before
# relying on it.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer> 
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Return the stripped text inside the last ``<answer>`` block of *text*.

    Everything up to and including the final ``<answer>`` tag is dropped, then
    everything from the first following ``</answer>`` tag onward.  A missing
    tag leaves that side untrimmed, so text without any tags comes back
    whitespace-stripped but otherwise unchanged.
    """
    _, _, after_open_tag = text.rpartition("<answer>")
    inside, _, _ = after_open_tag.partition("</answer>")
    return inside.strip()

def legacy_extract_answer(text: str) -> str | None:
    text = text.replace(",", "")
    answer_pattern = r"[Tt]he answer is:?\s*([+-]?\d+(?:\.\d+)?)"
    if m := re.search(answer_pattern, text):
        try:
            m_float = float(m.group(1))
            return m.group(1)
        except ValueError:
            pass

    if "####" in text:
        tail = text.split("####")[-1].strip()
        if m := re.search(r"([+-]?\d+(?:\.\d+)?)", tail):
            try:
                m_float = float(m.group(1))
                return m.group(1)
            except ValueError:
                pass

    # last-chunk heuristics
    parts = re.split(r"answer", text, flags=re.IGNORECASE)
    if len(parts) > 1:
        numbers = re.findall(r"([+-]?\d+(?:\.\d+)?)", parts[-1])
        if numbers:
            try:
                numbers_float = float(numbers[0])
                return numbers[0]
            except ValueError:
                pass

    lines = text.strip().splitlines()
    if lines:
        numbers = re.findall(r"([+-]?\d+(?:\.\d+)?)", lines[-1])
        if numbers:
            try:
                numbers_float = float(numbers[-1])
                return numbers[-1]
            except ValueError:
                pass
    return None

def extract_answer_from_text(text: str) -> str:
    """Extract the answer portion of a model response.

    Text containing an ``<answer>`` tag is handled by ``extract_xml_answer``.
    Otherwise ``legacy_extract_answer`` is tried, and if that finds nothing
    the input text is returned unchanged.
    """
    if '<answer>' in text:
        return extract_xml_answer(text)

    fallback = legacy_extract_answer(text)
    return text if fallback is None else fallback

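# Builds zero-shot chat prompts (system message + user question). For 1-shot
# prompting, add an example user/assistant exchange between those two messages.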
def get_gsm8k_questions(
    dataset_path: str,
    start_idx: int | None,
    end_idx: int | None,
) -> list[dict]:
    """Load a JSONL file and turn each record into a chat-style training row.

    Each non-blank line must be a JSON object with a ``prompt`` string and a
    ``final_answer`` value.  Blank lines are skipped so that stray empty lines
    in the file do not abort loading.

    Args:
        dataset_path: Path to the JSONL file (read as UTF-8).
        start_idx: If not ``None``, drop records before this index.  Applied
            after ``end_idx``, to the already-truncated list.
        end_idx: If not ``None``, keep only records before this index.

    Returns:
        A plain list of dicts, each with a ``'prompt'`` key (system message
        followed by the stripped user question) and an ``'answer'`` key (the
        record's ``final_answer``).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``prompt`` or ``final_answer``.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]

    # Slicing with None is a no-op, so this matches "truncate, then skip".
    records = records[:end_idx][start_idx:]

    return [
        {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': x['prompt'].strip()},
            ],
            'answer': x['final_answer'],
        }
        for x in records
    ]

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer numerically equals the target.

    Args:
        prompts: Batch of chat prompts; the last message of the first prompt
            is printed as the question for debugging.
        completions: Batch of completions; the ``'content'`` of each
            completion's first message is scored.
        answer: Reference answers, paired positionally with ``completions``.
        **kwargs: Ignored.

    Returns:
        One reward per completion: ``2.0`` when both the extracted response
        and the reference parse as floats and differ by less than ``1e-9``,
        otherwise ``0.0``.  An empty batch yields an empty list.
    """
    responses = [completion[0]['content'] for completion in completions]
    if not responses:
        # Nothing to score, and nothing to index for the debug print below.
        return []

    q = prompts[0][-1]['content']
    extracted_responses = [extract_answer_from_text(r) for r in responses]
    # Debug trace of the first sample in the batch.
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    rewards = []
    for r, a in zip(extracted_responses, answer):
        try:
            is_correct = abs(float(r) - float(a)) < 1e-9
        except (TypeError, ValueError, OverflowError):
            # Unparseable response or reference: treat as incorrect, but let
            # KeyboardInterrupt / SystemExit propagate.
            is_correct = False
        rewards.append(2.0 if is_correct else 0.0)
    return rewards

def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to each completion whose ``<answer>`` content is all digits.

    The answer is taken with ``extract_xml_answer`` from the ``'content'`` of
    each completion's first message and checked with ``str.isdigit``; anything
    else scores 0.0.
    """
    texts = [completion[0]['content'] for completion in completions]
    candidates = [extract_xml_answer(body) for body in texts]
    rewards = []
    for candidate in candidates:
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards

def float_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to each completion whose extracted answer parses as a float.

    The answer is taken with ``extract_answer_from_text`` from the
    ``'content'`` of each completion's first message; answers that ``float``
    rejects with ``ValueError`` score 0.0.
    """
    def _score(candidate) -> float:
        # Only parseability matters; the parsed value itself is discarded.
        try:
            float(candidate)
        except ValueError:
            return 0.0
        return 0.5

    texts = [completion[0]['content'] for completion in completions]
    candidates = [extract_answer_from_text(body) for body in texts]
    return [_score(candidate) for candidate in candidates]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact line-by-line XML layout.

    The whole completion must be ``<reasoning>``, the reasoning, and
    ``</reasoning>``, then ``<answer>``, the answer, and ``</answer>``, with
    each tag on its own line and nothing before or after.  A newline after
    the closing ``</answer>`` is optional: the format shown in the system
    prompt ends right at ``</answer>``, so requiring a trailing newline would
    give 0 to a completion that follows the prompt exactly.

    Args:
        completions: Batch of completions; the ``"content"`` of each
            completion's first message is checked.
        **kwargs: Ignored.

    Returns:
        ``0.5`` for each matching completion, ``0.0`` otherwise.
    """
    # "\n?$" accepts "</answer>" at the very end as well as the previously
    # accepted "</answer>\n" ("$" itself tolerates one final newline).
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$"
    responses = [completion[0]["content"] for completion in completions]
    return [0.5 if re.match(pattern, r, flags=re.DOTALL) else 0.0 for r in responses]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with a reasoning block then an answer block.

    The ``"content"`` of each completion's first message must begin with
    ``<reasoning>...</reasoning>``, optional whitespace, then
    ``<answer>...</answer>``.  Text after the closing tag is allowed; text
    before the opening tag is not.  Matches score 0.5, everything else 0.0.
    """
    soft_layout = re.compile(
        r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>",
        flags=re.DOTALL,
    )
    texts = [completion[0]["content"] for completion in completions]
    rewards = []
    for body in texts:
        rewards.append(0.5 if soft_layout.match(body) is not None else 0.0)
    return rewards

def count_xml(text) -> float:
    """Score how closely *text* uses the four expected XML tag lines.

    Each of ``"<reasoning>\\n"``, ``"\\n</reasoning>\\n"``, ``"\\n<answer>\\n"``
    and ``"\\n</answer>"`` adds 0.125 when it occurs exactly once.  The two
    answer-tag checks also subtract 0.001 per character found after the
    closing answer tag (after ``"\\n</answer>\\n"`` for the opening-tag check;
    after ``"\\n</answer>"``, minus one character, for the closing-tag check).
    When the marker is absent the whole text counts as trailing characters.
    """
    tag_bonus = 0.125
    per_char_penalty = 0.001

    def chars_after_last(marker) -> int:
        # Length of whatever follows the final occurrence of marker
        # (the full text when marker does not occur).
        return len(text.rpartition(marker)[2])

    score = 0.0
    for reasoning_tag in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(reasoning_tag) == 1:
            score += tag_bonus

    if text.count("\n<answer>\n") == 1:
        score += tag_bonus
        score -= chars_after_last("\n</answer>\n") * per_char_penalty

    if text.count("\n</answer>") == 1:
        score += tag_bonus
        score -= (chars_after_last("\n</answer>") - 1) * per_char_penalty

    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion's first-message ``"content"`` with ``count_xml``."""
    texts = [completion[0]["content"] for completion in completions]
    scores = []
    for body in texts:
        scores.append(count_xml(body))
    return scores


def main() -> None:
    """Parse command-line options and run GRPO training.

    Steps, in order: parse arguments, start a wandb run, load the prompt
    dataset, fill in step/checkpoint defaults from the dataset size, build
    the ``GRPOConfig``, load model and tokenizer, pick the reward functions
    and call ``GRPOTrainer.train()``.

    ``--model``, ``--dataset`` and ``--output_dir`` are required; leaving
    them out is reported by argparse instead of failing later on a ``None``
    value.  Defaults derived at runtime:

    * ``run_name``: last path component of ``output_dir`` plus a timestamp;
    * ``max_steps``: number of loaded samples;
    * ``save_steps`` and ``save_total_limit``: ``max_steps``.
    """
    parser = argparse.ArgumentParser(description="GRPO training on a GSM8K-style JSONL dataset")
    parser.add_argument("--model", required=True, help="HF repo or local path")
    parser.add_argument("--dataset", required=True, help="Path to dataset")
    parser.add_argument("--start_idx", type=int, default=None)
    parser.add_argument("--end_idx", type=int, default=None)
    parser.add_argument("--output_dir", required=True, help="Path to output directory")
    parser.add_argument("--run_name", help="Name of the run", default=None)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--max_steps", type=int, default=None)
    parser.add_argument("--shuffle_dataset", action="store_true", default=False)
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--num_generations", type=int, default=16)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--max_prompt_length", type=int, default=512)
    parser.add_argument("--max_completion_length", type=int, default=1536)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save_steps", type=int, default=None)
    parser.add_argument("--save_total_limit", type=int, default=None)
    parser.add_argument("--logging_steps", type=int, default=1)
    parser.add_argument("--float_reward_func", action="store_true", default=False)
    args = parser.parse_args()

    if args.run_name is None:
        # Strip trailing slashes first so "runs/exp1/" still yields "exp1".
        stem = args.output_dir.rstrip("/").split("/")[-1]
        args.run_name = stem + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")

    wandb.init(
        project="long-horizon-reasoning",
        name=args.run_name,
    )

    dataset = get_gsm8k_questions(args.dataset, args.start_idx, args.end_idx)

    if args.max_steps is None:
        args.max_steps = len(dataset)

    if args.max_steps > len(dataset):
        warnings.warn(f"max steps is greater than the number of samples in the dataset. max steps: {args.max_steps}, dataset size: {len(dataset)}")

    if args.save_steps is None:
        args.save_steps = args.max_steps

    if args.save_total_limit is None:
        args.save_total_limit = args.max_steps

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        run_name=args.run_name,
        learning_rate=args.learning_rate,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        bf16=True,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,

        max_steps=args.max_steps,
        save_strategy="steps",
        save_steps=args.save_steps,

        save_total_limit=args.save_total_limit,
        save_only_model=True,

        max_grad_norm=0.1,
        report_to="wandb",
        log_on_each_node=False,
        loss_type='dr_grpo',
        shuffle_dataset=args.shuffle_dataset,
        seed=args.seed,
    )

    # NOTE(review): retries forever, sleeping 60 s between attempts, on ANY
    # loading error — including permanent ones such as a wrong path.
    # Presumably meant to wait out transient storage/network failures;
    # confirm whether a retry cap is wanted.
    while True:
        try:
            model = AutoModelForCausalLM.from_pretrained(
                args.model,
                torch_dtype=torch.bfloat16,
                # attn_implementation="flash_attention_2",
                device_map=None
            ).to("cuda")
            break
        except Exception as e:
            print(f"Error loading model: {e}")
            time.sleep(60)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # Use EOS as the padding token (overrides any pad token already set).
    tokenizer.pad_token = tokenizer.eos_token

    # The two reward sets differ only in how the answer's numeric shape is
    # rewarded: float-parseable vs. digits-only.
    answer_shape_reward = float_reward_func if args.float_reward_func else int_reward_func
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        answer_shape_reward,
        correctness_reward_func,
    ]

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()


# Start training only when this file is executed as a script, so importing
# the module (e.g. to reuse the reward functions) has no side effects.
if __name__ == "__main__":
    main()