import re
import torch
import argparse
import json
import os
import shutil
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
from transformers.training_args import TrainingArguments
from trl import GRPOConfig, GRPOTrainer
from huggingface_hub import HfApi, login

# Import existing utilities from the MATH evaluation codebase
from data_loader import load_data
from parser import parse_question, parse_ground_truth
from reward_efficient import compare_answers_simple_batch
# Load and prep dataset

# System message placed first in every training prompt built by the dataset
# loaders in this file. It asks the model to wrap its work in
# <reasoning>/<answer> tags and to box the final answer. Surrounding
# whitespace of the literal is removed by .strip().
SYSTEM_PROMPT = """
Always respond in the following format, with only the final answer between the <answer> tags and always put your answer in boxed:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# Template with {reasoning} and {answer} placeholders laid out in the same tag
# structure. It is not referenced anywhere else in the visible part of this
# file.
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer> 
{answer}
</answer>
"""

def load_dataset_by_path(dataset_path: str) -> Dataset:
    """Build a chat-formatted training dataset from a JSON Lines file.

    Every non-blank line of the file must be a JSON object containing a
    ``prompt`` string and a ``final_answer`` value. Each object becomes one
    record whose ``prompt`` is a two-message conversation (the shared system
    prompt followed by the stripped user prompt) and whose ``answer`` is the
    ``final_answer`` value, unchanged.

    Args:
        dataset_path: Path to the JSONL file to read.

    Returns:
        A ``Dataset`` with the columns ``prompt`` and ``answer``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``prompt`` or ``final_answer``.
    """
    records = []
    # Read as UTF-8 explicitly: the platform default encoding is not
    # guaranteed to be UTF-8, and math text often contains non-ASCII symbols.
    with open(dataset_path, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (for example an extra newline at the end of the
            # file) are not records; json.loads would raise on them.
            if not line.strip():
                continue
            row = json.loads(line)
            records.append({
                'prompt': [
                    {'role': 'system', 'content': SYSTEM_PROMPT},
                    {'role': 'user', 'content': row['prompt'].strip()},
                ],
                'answer': row['final_answer'],
            })
    return Dataset.from_list(records)

def get_math_questions(split="train", data_dir="./data") -> Dataset:
    """Build the MATH training set in the prompt/answer layout used for GRPO.

    Examples are fetched with the evaluation code's ``load_data`` helper, and
    each one is turned into a record holding a system + user conversation
    under ``prompt`` and the parsed ground-truth answer under ``answer``.

    Args:
        split: Dataset split name forwarded to ``load_data``.
        data_dir: Data directory forwarded to ``load_data``.

    Returns:
        A ``Dataset`` with one record per loaded example.
    """

    def _to_record(raw_example):
        # Question text and ground truth come from the evaluation parsers;
        # only the second element of the ground-truth pair is kept.
        user_question = parse_question(raw_example, "math")
        _, ground_truth = parse_ground_truth(raw_example, "math")
        conversation = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': user_question},
        ]
        return {'prompt': conversation, 'answer': ground_truth}

    raw_examples = load_data("math", split, data_dir)
    return Dataset.from_list([_to_record(raw) for raw in raw_examples])

# Reward functions
# Number of times correctness_reward_func has been called so far (used only
# for the progress print).
step_count = 0
# Reward granted for a correct answer. The __main__ block assigns it from
# --correctness_reward_weight; it stays None if this module is only imported.
correctness_reward_weight = None


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose answer matches the ground truth.

    All responses are compared against ``answer`` in a single call to
    ``compare_answers_simple_batch``. Each response judged correct receives
    the configured correctness weight, every other response receives 0.0.
    The first question/response pair of the batch is printed for inspection.

    Args:
        prompts: Chat prompts; only the last message of the first prompt is
            read (for the debug print).
        completions: One list of messages per completion; the ``content`` of
            the first message is taken as the response text.
        answer: Ground-truth answers, passed to the comparator alongside the
            responses (presumably one per completion - confirm against
            ``compare_answers_simple_batch``).
        **kwargs: Ignored.

    Returns:
        One reward per comparison result.
    """
    global step_count

    # Fall back to 1.0 (the CLI default) when the weight was never configured,
    # so that a correct answer is never rewarded with None.
    if correctness_reward_weight is None:
        weight = 1.0
    else:
        weight = float(correctness_reward_weight)

    # Extract all responses
    responses = [completion[0]['content'] for completion in completions]

    # Process all comparisons in batch with a single comparator
    comparison_results = compare_answers_simple_batch(responses, answer)

    # Convert boolean results to rewards
    rewards = [weight if is_correct else 0.0 for is_correct in comparison_results]
    print('Step Count', step_count)
    step_count += 1
    if responses:
        q = prompts[0][-1]['content']
        print('-'*20, f"Question:\n{q}", f"\nGround Truth:\n{answer[0]}",
              f"\nResponse:\n{responses[0]}", f"\nCorrect: {comparison_results[0]}")

    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for completions that follow the exact line-by-line tag layout.

    The response must start with ``<reasoning>`` on its own line, close it with
    ``</reasoning>`` on its own line, continue directly with an ``<answer>``
    block laid out the same way, and end with a newline after ``</answer>``.
    Any other response scores 0.0.

    Args:
        completions: One list of messages per completion; the ``content`` of
            the first message is checked.
        **kwargs: Ignored.

    Returns:
        One score (0.5 or 0.0) per completion, in input order.
    """
    strict_layout = re.compile(
        r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$",
        flags=re.DOTALL,
    )
    scores = []
    for completion in completions:
        text = completion[0]["content"]
        if strict_layout.match(text) is None:
            scores.append(0.0)
        else:
            scores.append(0.5)
    return scores

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Score 0.5 for completions that begin with the tag layout, loosely matched.

    The response must start with a ``<reasoning>...</reasoning>`` block that
    is followed, after optional whitespace only, by an ``<answer>...</answer>``
    block. Newline placement inside the blocks is free and trailing text after
    ``</answer>`` is allowed. Matching is anchored at the start of the
    response, so any text before ``<reasoning>`` yields 0.0.

    Args:
        completions: One list of messages per completion; the ``content`` of
            the first message is checked.
        **kwargs: Ignored.

    Returns:
        One score (0.5 or 0.0) per completion, in input order.
    """
    loose_layout = re.compile(
        r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>",
        re.DOTALL,
    )

    def _score(text):
        return 0.5 if loose_layout.match(text) else 0.0

    return [_score(completion[0]["content"]) for completion in completions]

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score completions by which formatting markers they contain.

    Points are additive: 0.5 when both reasoning tags appear, 1.0 when both
    answer tags appear, and 0.25 when a ``\\boxed{`` marker appears. Only
    substring presence is checked; order and nesting are not.

    Args:
        completions: One list of messages per completion; the ``content`` of
            the first message is scored.
        **kwargs: Ignored.

    Returns:
        One score between 0.0 and 1.75 per completion, in input order.
    """
    # (markers that must all be present, points awarded)
    scoring_rules = (
        (("<reasoning>", "</reasoning>"), 0.5),
        (("<answer>", "</answer>"), 1.0),
        (("\\boxed{",), 0.25),
    )

    def _score(text):
        earned = (
            points
            for markers, points in scoring_rules
            if all(marker in text for marker in markers)
        )
        # Start from 0.0 so an unmatched text still yields a float.
        return sum(earned, 0.0)

    return [_score(completion[0]["content"]) for completion in completions]


if __name__ == "__main__":
    # Command-line entry point: parse options, build the dataset, configure
    # GRPO training with Hub checkpoint uploads, and run the trainer.

    def _str2bool(value: str) -> bool:
        """Parse a command-line boolean such as "true" or "False".

        A plain ``type=bool`` converts every non-empty string, including
        "False", to True, so the text is interpreted explicitly instead.
        """
        normalized = value.strip().lower()
        if normalized in {"true", "t", "yes", "y", "1"}:
            return True
        if normalized in {"false", "f", "no", "n", "0", ""}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)
    parser.add_argument("--dataset", type=str)
    parser.add_argument("--output_dir", type=str)
    parser.add_argument("--run_name", type=str)
    parser.add_argument("--max_steps", type=int)
    parser.add_argument("--save_steps", type=int)
    parser.add_argument("--learning_rate", type=float, default=5e-6)
    parser.add_argument("--max_completion_length", type=int, default=8192)
    parser.add_argument("--save_total_limit", type=int)
    parser.add_argument("--keep_local_checkpoints", type=int, default=1,
                        help="Number of local checkpoints to keep (default: 1)")
    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--num_generations", type=int, default=8)
    parser.add_argument("--max_prompt_length", type=int, default=1024)
    parser.add_argument("--correctness_reward_weight", type=float, default=1.0)
    parser.add_argument("--shuffle_dataset", type=_str2bool, default=False,
                        help="Shuffle the dataset before training (true/false)")
    parser.add_argument("--seed", type=int, default=42)

    args = parser.parse_args()

    # Validate with parser.error rather than assert: asserts are stripped
    # when Python runs with -O.
    if args.model is None or args.dataset is None:
        parser.error("Model and dataset names are required")

    api = HfApi()
    # Called only for its effect (the result is unused); presumably this
    # fails early when no valid Hugging Face login is available - confirm.
    api.whoami()
    model_name = args.model

    output_dir = args.output_dir
    run_name = args.run_name

    if run_name is None:
        # Derive the run name from the last component of the output
        # directory. normpath drops a trailing slash, which would otherwise
        # produce an empty name.
        if output_dir is None:
            parser.error("--run_name or --output_dir is required to name the run")
        run_name = os.path.basename(os.path.normpath(output_dir))

    # "math" selects the built-in MATH loader; any other value is treated as
    # a path to a JSONL file.
    if args.dataset == "math":
        dataset = get_math_questions()
    else:
        dataset = load_dataset_by_path(args.dataset)

    # Unset step options fall back to values derived from the dataset size:
    # one step per example, a single save at the last step, and no effective
    # cap on the number of retained checkpoints.
    max_steps = len(dataset) if args.max_steps is None else args.max_steps
    save_steps = max_steps if args.save_steps is None else args.save_steps
    if args.save_total_limit is None:
        save_total_limit = max_steps
    else:
        save_total_limit = args.save_total_limit

    # Module-level weight read by correctness_reward_func.
    correctness_reward_weight = args.correctness_reward_weight

    hub_model_id = run_name

    if args.shuffle_dataset:
        dataset = dataset.shuffle(seed=args.seed)

    training_args = GRPOConfig(
        output_dir=output_dir,
        run_name=run_name,
        learning_rate=args.learning_rate,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        logging_steps=1,
        bf16=True,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        num_generations=args.num_generations,
        max_prompt_length=args.max_prompt_length,
        max_completion_length=args.max_completion_length,

        # Step-based schedule and checkpointing.
        max_steps=max_steps,            # total number of update steps
        save_strategy="steps",          # save by step count
        save_steps=save_steps,          # checkpoint interval in steps
        save_only_model=True,
        save_total_limit=save_total_limit,  # resolved value (was computed but unused)

        max_grad_norm=0.1,
        report_to=None,
        log_on_each_node=False,
        loss_type='dr_grpo',

        push_to_hub=True,                   # Enable pushing to hub
        hub_model_id=hub_model_id,          # Your HF model repo
        hub_strategy="all_checkpoints",     # Push each checkpoint
        hub_token=None,                     # Uses logged-in token
        hub_private_repo=True,
        # hub_always_push=True

        # Shuffling, when requested, has already been applied to the dataset
        # above with the user-supplied seed.
        shuffle_dataset=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        # attn_implementation="flash_attention_2",
        device_map=None
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Reuse the end-of-sequence token for padding.
    tokenizer.pad_token = tokenizer.eos_token

    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=[
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            correctness_reward_func],
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    # print("Training completed. Pushing final model to hub...")
    # try:
    #     model.push_to_hub(
    #         hub_model_id,
    #         commit_message="Final model after training",
    #         safe_serialization=True,
    #     )
    #     tokenizer.push_to_hub(
    #         hub_model_id,
    #         commit_message="Final tokenizer",
    #     )
    #     print("Final model successfully pushed to hub")
    # except Exception as e:
    #     print(f"Failed to push final model: {e}")
    #     print("Saving final model locally...")
    #     trainer.save_model(os.path.join(output_dir, "final_model"))



