import os
import json
import argparse
from tqdm import tqdm
from datasets import load_dataset
from vllm import LLM, SamplingParams
from utils import extract_answer_math
from grader import math_equal
# Process-wide environment tweaks, applied in a single call:
#   NCCL_DEBUG=WARN              -> value handed to NCCL as its debug level
#   TOKENIZERS_PARALLELISM=false -> value read by the Hugging Face tokenizers
#                                   library (presumably to avoid its fork
#                                   warning when dataset.map uses workers)
os.environ.update(
    NCCL_DEBUG="WARN",
    TOKENIZERS_PARALLELISM="false",
)

from huggingface_hub import login
from transformers import AutoTokenizer

# Authenticate with the Hugging Face Hub using a token supplied through the
# environment, so that no credential is ever stored in this file. HF_TOKEN is
# checked first, then the older HUGGING_FACE_HUB_TOKEN name. When neither is
# set the explicit login is skipped; public models still load, and gated ones
# will fail at download time with the hub's own error.
_hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if _hf_token:
    login(_hf_token)

def prepare_data(example, prompt_key, tokenizer):
    """Add a chat-templated ``prompt`` field to a single dataset example.

    Args:
        example: Mapping for one dataset row; it is modified in place.
        prompt_key: Name of the field in ``example`` holding the problem text.
        tokenizer: Object exposing ``apply_chat_template``; it is called with
            the message list, ``add_generation_prompt=True`` and
            ``tokenize=False``.

    Returns:
        The same ``example`` object, with the template output stored under
        ``'prompt'``.
    """
    # NOTE(review): this is a plain (non-f) string, so the doubled braces are
    # passed through verbatim as ``\boxed{{}}`` -- confirm that is intended.
    system_instruction = (
        "You are given a math problem. Solve it step by step. "
        "Organize your thoughts using this format: Step 1: ..., Step 2: ..., Step 3: ..., and so on. "
        "Put your final answer within \\boxed{{}}. "
        "If you cannot solve the problem after 6 reasoning steps, stop reasoning and return: 'I need external assistance.' "
    )
    user_turn = f"{example[prompt_key]}\n Let's think step by step."

    conversation = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": user_turn},
    ]

    # Let the model's own chat template turn the conversation into one string.
    example['prompt'] = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=False,
    )
    return example

def _derive_model_name(model_path):
    """Pick a short label for the model out of ``model_path``.

    The path is split on ``/`` (a trailing slash is ignored). A path with a
    single component yields that component. Otherwise the components are
    scanned from last to first and the first one containing a known family
    keyword (case-insensitive) is returned; if none matches, the
    second-to-last component is used.
    """
    path_parts = model_path.rstrip('/').split('/')
    if len(path_parts) < 2:
        return path_parts[-1]
    family_keywords = ('qwen', 'llama', 'mistral', 'gemma', 'phi')
    for part in reversed(path_parts):
        lowered = part.lower()
        if any(keyword in lowered for keyword in family_keywords):
            return part
    return path_parts[-2]


def _resolve_column_keys(dataset_name, column_names):
    """Return ``(prompt_key, answer_key)`` for a dataset.

    Known benchmark names are matched as case-insensitive substrings of
    ``dataset_name``; anything else falls back to probing ``column_names``
    for a few common spellings.

    Raises:
        ValueError: If the fallback cannot find a prompt or an answer column.
    """
    lowered = dataset_name.lower()
    if "lighteval" in lowered:
        return "problem", "solution"
    if "amc23" in lowered or "math-500" in lowered:
        return "problem", "answer"
    if "aime" in lowered:
        return "Problem", "Answer"
    if "minervamath" in lowered:
        return "question", "answer"

    # Default fallback - try common keys, in priority order.
    prompt_key = next((key for key in ("problem", "question", "Problem") if key in column_names), None)
    if prompt_key is None:
        raise ValueError(f"Could not determine prompt key for dataset {dataset_name}. Available keys: {column_names}")
    answer_key = next((key for key in ("answer", "solution", "Answer") if key in column_names), None)
    if answer_key is None:
        raise ValueError(f"Could not determine answer key for dataset {dataset_name}. Available keys: {column_names}")
    return prompt_key, answer_key


def _evaluate_dataset(llm, tokenizer, sampling_params, dataset_name, model_name, args):
    """Generate, grade and log answers for one dataset.

    Loads ``dataset_name`` at ``args.split``, skips the first
    ``args.resume_id`` rows, builds prompts with ``prepare_data`` and
    generates in batches of ``args.batch_size``. Every generation is written
    as one JSON line to a file inside ``args.output_dir``. An example counts
    as solved when at least one of its generations is labelled true, and the
    running accuracy is printed after every batch.

    Raises:
        FileExistsError: On local rank 0, when starting from index 0 and the
            output file is already present.
        RuntimeError: If a prompt comes back with a number of generations
            different from ``args.num_generation``.
    """
    dataset = load_dataset(dataset_name, split=args.split)
    print(f"Starting from index {args.resume_id} out of {len(dataset)} examples.")
    dataset = dataset.select(range(args.resume_id, len(dataset)))

    # Debug: print available keys
    print(f"Available keys in dataset: {dataset.column_names}")
    prompt_key, answer_key = _resolve_column_keys(dataset_name, dataset.column_names)
    print(f"Using prompt_key: {prompt_key}, answer_key: {answer_key}")
    dataset = dataset.map(lambda x: prepare_data(x, prompt_key, tokenizer), num_proc=args.dataset_num_proc)

    output_file = f"{model_name}-{dataset_name.split('/')[-1]}-{args.split}-temp_{args.temperature}-top_p_{args.top_p}-top_k_{args.top_k}{args.comment}.jsonl"
    output_path = os.path.join(args.output_dir, output_file)
    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    # A fresh run must not clobber earlier results; a resumed run appends.
    if local_rank == 0 and args.resume_id == 0 and os.path.exists(output_path):
        raise FileExistsError(f"Output file {output_file} already exists.")

    # Only the lighteval datasets have their reference passed through
    # extract_answer_math; every other dataset's answer column is used as-is.
    # (The original second branch was an always-true `elif`, i.e. this same
    # two-way choice.)
    gold_needs_extraction = "lighteval" in dataset_name.lower()

    correct = 0
    total = 0
    with open(output_path, 'w' if args.resume_id == 0 else 'a', encoding="utf-8") as f:
        for i in tqdm(range(0, len(dataset), args.batch_size)):
            batch = dataset[i:i + args.batch_size]
            inputs = batch["prompt"]

            # Generate the answers: one list of texts per prompt.
            outputs = llm.generate(inputs, sampling_params=sampling_params, use_tqdm=True)
            results = [[completion.text for completion in request.outputs] for request in outputs]
            for generations in results:
                if len(generations) != args.num_generation:
                    raise RuntimeError(
                        f"Number of generations is not equal to {args.num_generation}, got {len(generations)}"
                    )

            # Grade and log every generation of every example in the batch.
            for j, (inp, q, a, r) in enumerate(zip(inputs, batch[prompt_key], batch[answer_key], results)):
                gold_answer = extract_answer_math(a) if gold_needs_extraction else a
                solved = False
                for k, response in enumerate(r):
                    pred_answer = extract_answer_math(response)
                    label = math_equal(pred_answer, gold_answer, timeout=True)
                    # A response containing the give-up sentence is labelled
                    # true regardless of the grader's verdict.
                    if "I need external assistance." in response:
                        label = True
                    qa_pair = {
                        "prompt": inp,
                        "vanilla_response": response,
                        "question": q,
                        "answer": a,
                        # Index of the example in the full, un-resumed split.
                        "question_id": args.resume_id + i + j,
                        "generation_id": k,
                        "response": response,
                        "label": label,
                        "gold_answer": gold_answer,
                        "pred_answer": pred_answer,
                    }
                    f.write(json.dumps(qa_pair) + '\n')
                    if label:
                        solved = True
                total += 1
                if solved:
                    correct += 1
            print(f"Accuracy for {dataset_name}: {correct}/{total} = {correct/total:.4f}")
            # Flush per batch so an interrupted run keeps its finished batches.
            f.flush()


def main():
    """Command-line entry point: evaluate a model on one or more datasets.

    Parses the arguments, loads the model once with vLLM, then evaluates each
    dataset named in the comma-separated ``--datasets`` value in turn.
    ``--model_path``, ``--datasets``, ``--batch_size`` and ``--output_dir``
    are mandatory; argparse exits with status 2 when one is missing or when
    ``--batch_size`` is not positive, before any model is loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--datasets", type=str, required=True)
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--batch_size", type=int, required=True)
    parser.add_argument("--max_tokens", type=int)
    parser.add_argument("--num_gpus", type=int, default=1)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--num_generation", type=int, default=1)
    parser.add_argument("--dataset_num_proc", type=int, default=1)
    parser.add_argument("--resume_id", type=int, default=0)
    parser.add_argument("--comment", type=str, default="")
    args = parser.parse_args()

    # Reject unusable values now rather than after the model has been loaded.
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    dataset_names = [name.strip() for name in args.datasets.split(",") if name.strip()]
    if not dataset_names:
        parser.error("--datasets must name at least one dataset")

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    model_path = args.model_path
    llm = LLM(model_path, tensor_parallel_size=args.num_gpus, dtype="bfloat16", gpu_memory_utilization=0.9, trust_remote_code=True)
    sampling_params = SamplingParams(
        n=args.num_generation,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
    )

    # Both depend only on the model, so compute them once for all datasets.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model_name = _derive_model_name(model_path)

    for dataset_name in dataset_names:
        _evaluate_dataset(llm, tokenizer, sampling_params, dataset_name, model_name, args)

# Run the evaluation only when this file is executed as a script, so that
# importing it does not call main().
if __name__ == "__main__":
    main()
