
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import GRPOConfig, GRPOTrainer
import json
import argparse
import random
import wandb
from datetime import datetime
from math500_eval import MATH500Evaluator

# Load and prep dataset

# System message placed first in every prompt built by this file. It instructs
# the model to put its working inside <reasoning> tags and only the numerical
# answer inside <answer> tags. The surrounding newlines of the literal are
# removed by .strip().
SYSTEM_PROMPT = """
Respond in the following format, with only the numerical answer between the <answer> tags:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# str.format template with {reasoning} and {answer} placeholders that renders
# a completion in the tagged layout. The leading backslash suppresses the
# newline right after the opening quotes.
# NOTE(review): the "<answer> " line carries a trailing space that the layout
# in SYSTEM_PROMPT does not have -- confirm whether that is intentional.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer> 
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    """Return the stripped text enclosed by the last ``<answer>`` tag.

    Everything up to and including the final ``<answer>`` is discarded, then
    everything from the first following ``</answer>`` onward is dropped. A
    missing tag simply skips the corresponding cut, so a string without any
    tags comes back stripped but otherwise unchanged.
    """
    _, _, after_open_tag = text.rpartition("<answer>")
    inner, _, _ = after_open_tag.partition("</answer>")
    return inner.strip()

def legacy_extract_answer(text: str) -> str | None:
    text = text.replace(",", "")
    answer_pattern = r"[Tt]he answer is:?\s*([+-]?\d+(?:\.\d+)?)"
    if m := re.search(answer_pattern, text):
        try:
            m_float = float(m.group(1))
            return m.group(1)
        except ValueError:
            pass

    if "####" in text:
        tail = text.split("####")[-1].strip()
        if m := re.search(r"([+-]?\d+(?:\.\d+)?)", tail):
            try:
                m_float = float(m.group(1))
                return m.group(1)
            except ValueError:
                pass

    # last-chunk heuristics
    parts = re.split(r"answer", text, flags=re.IGNORECASE)
    if len(parts) > 1:
        numbers = re.findall(r"([+-]?\d+(?:\.\d+)?)", parts[-1])
        if numbers:
            try:
                numbers_float = float(numbers[0])
                return numbers[0]
            except ValueError:
                pass

    lines = text.strip().splitlines()
    if lines:
        numbers = re.findall(r"([+-]?\d+(?:\.\d+)?)", lines[-1])
        if numbers:
            try:
                numbers_float = float(numbers[-1])
                return numbers[-1]
            except ValueError:
                pass
    return None

def extract_answer_from_text(text: str) -> str:
    """Extract the model's answer, preferring the tagged format.

    When the text contains an ``<answer>`` tag the tag contents are returned.
    Otherwise the legacy numeric heuristics are tried, and if they find
    nothing the input text itself is returned unchanged.
    """
    if '<answer>' in text:
        return extract_xml_answer(text)

    fallback = legacy_extract_answer(text)
    return text if fallback is None else fallback

# For 1-shot prompting, insert an example user/assistant exchange between the system and user messages built below.
def get_gsm8k_questions(dataset_path: str, start_idx: int | None, end_idx: int | None) -> list[dict]:
    """Load a JSONL file of questions and turn each row into a chat-style record.

    Each non-blank line of the file must be a JSON object with a ``prompt``
    string and a ``final_answer`` value.

    Args:
        dataset_path: Path to the JSONL file.
        start_idx: If given, rows before this index are dropped. It is applied
            after ``end_idx``, i.e. relative to the already truncated list.
        end_idx: If given, only rows before this index are kept.

    Returns:
        A list of ``{'prompt': [system message, user message], 'answer': ...}``
        dicts in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a row lacks ``prompt`` or ``final_answer``.
    """
    with open(dataset_path, "r", encoding="utf-8") as f:
        # Tolerate blank lines (e.g. a trailing newline at end of file).
        rows = [json.loads(line) for line in f if line.strip()]
    if end_idx is not None:
        rows = rows[:end_idx]
    if start_idx is not None:
        rows = rows[start_idx:]
    return [
        {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', 'content': row['prompt'].strip()},
            ],
            'answer': row['final_answer'],
        }
        for row in rows
    ]

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer numerically equals the reference.

    Args:
        prompts: Chat prompts for the batch; only the last message of the
            first prompt is read, for the debug print.
        completions: One entry per sample; ``completion[0]['content']`` is
            the generated text.
        answer: Reference answers aligned with ``completions``.
        **kwargs: Ignored.

    Returns:
        ``2.0`` for each completion whose extracted answer is within ``1e-9``
        of the reference once both are parsed as floats, otherwise ``0.0``
        (including when either side is not numeric).
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_answer_from_text(r) for r in responses]

    # Debug trace of the first sample; skipped for an empty batch so that
    # logging can never be what crashes the reward computation.
    if prompts and responses and len(answer) > 0:
        q = prompts[0][-1]['content']
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    rewards = []
    for r, a in zip(extracted_responses, answer):
        try:
            is_match = abs(float(r) - float(a)) < 1e-9
        except (TypeError, ValueError, OverflowError):
            # Non-numeric extraction or reference: treat as incorrect. Only
            # conversion failures are caught, so interrupts still propagate.
            is_match = False
        rewards.append(2.0 if is_match else 0.0)
    return rewards

def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to each completion whose tagged answer consists only of digits.

    The answer is taken from ``completion[0]['content']`` via
    ``extract_xml_answer`` and tested with ``str.isdigit``; anything else,
    including an empty answer, scores 0.0.
    """
    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if candidate.isdigit() else 0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for completions that follow the tag layout exactly.

    The whole completion must be ``<reasoning>``, body, ``</reasoning>``,
    ``<answer>``, body, ``</answer>``, each tag on its own line, ending with a
    newline after ``</answer>``. Everything else scores 0.0.
    """
    strict_layout = re.compile(
        r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$",
        flags=re.DOTALL,
    )
    rewards = []
    for completion in completions:
        hit = strict_layout.match(completion[0]["content"])
        rewards.append(0.5 if hit is not None else 0.0)
    return rewards

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for completions that start with a loosely formatted tag pair.

    The completion must begin with ``<reasoning>...</reasoning>``, optionally
    followed by whitespace, then ``<answer>...</answer>``. Text after the
    closing tag is allowed; text before ``<reasoning>`` is not, because the
    pattern is matched from the start of the string.
    """
    loose_layout = re.compile(
        r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>",
        flags=re.DOTALL,
    )
    rewards = []
    for completion in completions:
        hit = loose_layout.match(completion[0]["content"])
        rewards.append(0.5 if hit is not None else 0.0)
    return rewards

def count_xml(text) -> float:
    """Give partial credit for each expected tag that appears exactly once.

    Each of the four tag markers (with its surrounding newlines) is worth
    0.125 when it occurs exactly once. For the two answer markers, 0.001 per
    character is deducted for text that follows the closing answer tag, which
    discourages trailing content.
    """
    tag_bonus = 0.125
    char_penalty = 0.001
    score = 0.0

    for marker in ("<reasoning>\n", "\n</reasoning>\n"):
        if text.count(marker) == 1:
            score += tag_bonus

    if text.count("\n<answer>\n") == 1:
        score += tag_bonus
        after_close = text.rsplit("\n</answer>\n", 1)[-1]
        score -= len(after_close) * char_penalty

    if text.count("\n</answer>") == 1:
        score += tag_bonus
        after_close = text.rsplit("\n</answer>", 1)[-1]
        # One character is free: the newline that normally follows the tag.
        score -= (len(after_close) - 1) * char_penalty

    return score

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Apply ``count_xml`` to the text of every completion, in order."""
    scores = []
    for completion in completions:
        scores.append(count_xml(completion[0]["content"]))
    return scores

from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

class ValidationCallback(TrainerCallback):
    """Trainer callback that periodically measures answer accuracy on held-out sets.

    Every ``validation_interval`` global steps the callback greedily decodes
    one completion per validation prompt with ``model.generate``, extracts the
    answer with ``extract_answer_from_text`` and compares it numerically with
    the reference. Per-dataset and aggregate accuracy are printed and, when a
    wandb run is active, logged under ``validation/*``.

    Args:
        val_datasets: Sequence of validation sets. Each one must support
            ``len()`` and integer indexing and yield records of the form
            ``{'prompt': <chat messages>, 'answer': <reference>}``.
        tokenizer: Tokenizer used for chat templating, encoding and decoding.
        model: The model being trained. It is put in eval mode for the
            duration of a validation pass and restored afterwards.
        validation_interval: Validate every this many global steps; a
            non-positive value disables validation.
        max_new_tokens: Generation budget per validation prompt.
        batch_size: Number of prompts decoded per ``generate`` call.
    """

    def __init__(self, val_datasets, tokenizer, model, validation_interval=10,
                 max_new_tokens=512, batch_size=8):
        self.val_datasets = val_datasets
        self.tokenizer = tokenizer
        self.model = model
        self.validation_interval = validation_interval
        self.max_new_tokens = max_new_tokens
        self.batch_size = max(1, batch_size)

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Run validation when the global step is a positive multiple of the interval."""
        step = state.global_step
        if self.validation_interval > 0 and step > 0 and step % self.validation_interval == 0:
            print(f"\n{'='*50}")
            print(f"Running custom validation at step {step}")
            print(f"{'='*50}")
            self.run_validation(step)

        return control

    @staticmethod
    def _is_correct(predicted, expected) -> bool:
        """Return True when both values parse as floats and differ by less than 1e-9."""
        try:
            return abs(float(predicted) - float(expected)) < 1e-9
        except (TypeError, ValueError, OverflowError):
            return False

    def _generate(self, conversations) -> list[str]:
        """Greedily decode one completion for each chat-message list."""
        texts = [
            self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            for messages in conversations
        ]
        # Batched generation needs left padding so that every prompt ends
        # right where generation starts; restore the caller's setting after.
        previous_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            encoded = self.tokenizer(
                texts,
                return_tensors="pt",
                padding=True,
                # The chat template already inserted any special tokens.
                add_special_tokens=False,
            )
        finally:
            self.tokenizer.padding_side = previous_side

        device = next(self.model.parameters()).device
        encoded = encoded.to(device)
        generated = self.model.generate(
            **encoded,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # generate() returns prompt + continuation; keep the continuation only.
        prompt_length = encoded["input_ids"].shape[1]
        return self.tokenizer.batch_decode(
            generated[:, prompt_length:], skip_special_tokens=True
        )

    def run_validation(self, step):
        """Score every validation set and report per-set and overall accuracy.

        Args:
            step: Global training step, recorded alongside the metrics.
        """
        was_training = self.model.training
        self.model.eval()

        total_correct = 0
        total_samples = 0
        metrics = {}
        try:
            with torch.no_grad():
                for index, val_dataset in enumerate(self.val_datasets):
                    size = len(val_dataset)
                    correct = 0
                    for start in range(0, size, self.batch_size):
                        stop = min(start + self.batch_size, size)
                        # Index one record at a time so both plain lists and
                        # dataset objects with integer indexing work.
                        batch = [val_dataset[i] for i in range(start, stop)]
                        completions = self._generate([record["prompt"] for record in batch])
                        for record, completion in zip(batch, completions):
                            predicted = extract_answer_from_text(completion)
                            if self._is_correct(predicted, record["answer"]):
                                correct += 1

                    dataset_accuracy = correct / size if size else 0.0
                    print(f"Validation set {index}: {dataset_accuracy:.2%} ({correct}/{size})")
                    metrics[f"validation/dataset_{index}/accuracy"] = dataset_accuracy
                    total_correct += correct
                    total_samples += size
        finally:
            # Put the model back the way we found it even if generation fails.
            if was_training:
                self.model.train()

        accuracy = total_correct / total_samples if total_samples else 0.0
        metrics.update({
            "validation/accuracy": accuracy,
            "validation/correct": total_correct,
            "validation/total": total_samples,
            "validation/step": step
        })

        # Log to wandb only when a run has been initialised.
        if wandb.run is not None:
            wandb.log(metrics)

        print(f"Validation Accuracy: {accuracy:.2%} ({total_correct}/{total_samples})")

def main() -> None:
    """Parse CLI arguments and run GRPO training with periodic validation.

    The function loads the training questions, builds a ``GRPOConfig``, loads
    the model and tokenizer, loads five validation sets from ``datasets/``,
    and trains with the reward functions defined in this file.

    Raises:
        ValueError: If ``--max_steps`` exceeds the number of training samples.
    """
    parser = argparse.ArgumentParser(description="GRPO training on GSM8K-style data")
    parser.add_argument("--model", required=True, help="HF repo or local path")
    parser.add_argument("--dataset", required=True, help="Path to dataset")
    parser.add_argument("--start_idx", type=int, default=None)
    parser.add_argument("--end_idx", type=int, default=None)
    parser.add_argument("--output_dir", help="Path to output directory")
    parser.add_argument("--run_name", help="Name of the run", default=None)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--max_steps", type=int, default=None)
    parser.add_argument("--shuffle_dataset", action="store_true", default=False)
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--num_generations", type=int, default=16)
    parser.add_argument("--warmup_ratio", type=float, default=0.1)
    parser.add_argument("--max_prompt_length", type=int, default=512)
    parser.add_argument("--max_completion_length", type=int, default=1536)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save_steps", type=int, default=None)
    parser.add_argument("--save_total_limit", type=int, default=None)
    parser.add_argument("--logging_steps", type=int, default=10)
    # One or more floats, e.g. `--target_accuracies 0.5 0.75`. (Passing
    # `list[float]` as `type` would split the raw string into characters.)
    parser.add_argument("--target_accuracies", type=float, nargs="+", default=None)
    parser.add_argument("--val_datasets_size", type=int, default=500)
    args = parser.parse_args()

    if args.run_name is None:
        # Derive a name from the last path component of the output directory,
        # tolerating a trailing slash or an omitted --output_dir.
        base_name = (args.output_dir or "grpo").rstrip("/").split("/")[-1]
        args.run_name = base_name + "_" + datetime.now().strftime("%Y%m%d_%H%M%S")

    # NOTE(review): the parsed target accuracies are not consumed anywhere in
    # this function; the check is kept to preserve the existing CLI contract.
    if args.target_accuracies is None:
        print("No target accuracies provided. Please provide a list of target accuracies.")
        return

    wandb.init(
        project="long-horizon-reasoning",
        name=args.run_name,
    )

    dataset = get_gsm8k_questions(args.dataset, args.start_idx, args.end_idx)

    if args.max_steps is None:
        args.max_steps = len(dataset)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.max_steps > len(dataset):
        raise ValueError(
            f"max steps must be less than or equal to the number of samples in the dataset. max steps: {args.max_steps}, dataset size: {len(dataset)}"
        )

    if args.save_steps is None:
        args.save_steps = args.max_steps

    if args.save_total_limit is None:
        args.save_total_limit = args.max_steps

    training_args = GRPOConfig(
        output_dir=args.output_dir,
        run_name=args.run_name,
        learning_rate=args.learning_rate,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=args.warmup_ratio,
        lr_scheduler_type="cosine",
        logging_steps=args.logging_steps,
        bf16=True,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,

        # ──────── checkpointing ────────
        max_steps=args.max_steps,  # defaults to the number of training samples
        save_strategy="steps",  # save by step count
        save_steps=args.save_steps,  # defaults to max_steps, i.e. one save at the end

        # defaults to max_steps, so in practice no checkpoint is pruned
        save_total_limit=args.save_total_limit,
        save_only_model=True,   # skip optimizer/scheduler states
        # ───────────────────────────────

        max_grad_norm=0.1,
        report_to="wandb",
        log_on_each_node=False,
        loss_type='dr_grpo',
        shuffle_dataset=args.shuffle_dataset,
        seed=args.seed,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16,
        device_map=None
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # Reuse EOS as the padding token for batched encoding.
    tokenizer.pad_token = tokenizer.eos_token

    val_dataset_names = [
        "gsm8k_test_processed",
        "gsm8k_test_1000_2_sub",
        "gsm8k_test_1000_3_sub",
        "gsm8k_test_1000_4_sub",
        "gsm8k_test_1000_5_sub",
    ]
    val_datasets = [
        get_gsm8k_questions(f"datasets/{dataset_name}.jsonl", 0, args.val_datasets_size)
        for dataset_name in val_dataset_names
    ]

    validation_callback = ValidationCallback(
        val_datasets=val_datasets,
        tokenizer=tokenizer,
        model=model,
        validation_interval=10
    )

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func],
        args=training_args,
        train_dataset=dataset,
        callbacks=[validation_callback],
    )
    trainer.train()


# Run training only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()