import os
import json
import argparse
import sys
from tqdm import tqdm
from datasets import load_dataset
from together import Together

# Add eval directory to Python path for imports
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from utils import extract_answer_math
from grader import math_equal

def prepare_data(example):
    """Build the single-turn chat message list for one dataset example.

    The text under ``example["question"]`` is followed by one space and a
    fixed step-by-step instruction, then wrapped as a single ``user``
    message.

    Args:
        example: Mapping that provides the question text under the
            ``"question"`` key.

    Returns:
        A one-element list holding a ``{"role": ..., "content": ...}`` dict.
    """
    suffix = "Let's think step by step and output the final answer."
    # str.join (rather than an f-string) keeps the TypeError for non-string
    # questions that plain concatenation would raise.
    prompt = " ".join((example["question"], suffix))
    return [{"role": "user", "content": prompt}]

def main():
    """Evaluate DeepSeek-R1 on the MinervaMath test split via the Together API.

    Command-line options:
        --output_dir: directory that receives the JSONL results file; it is
            created if missing (default ``./eval/deepseek_r1``).
        --max_tokens: ``max_tokens`` value passed to every chat completion
            request (default 4096).

    For each dataset example one chat completion is requested, a predicted
    answer is extracted with ``extract_answer_math`` and compared against the
    gold answer with ``math_equal``. Exactly one JSON line per example is
    written to the results file and flushed right away. An example whose
    request, grading or serialization raises is recorded as an incorrect
    entry carrying an ``error`` field and still counts towards the total.

    Raises:
        FileExistsError: if the results file already exists, so that earlier
            results are never overwritten.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, default="./eval/deepseek_r1")
    parser.add_argument("--max_tokens", type=int, default=4096)
    args = parser.parse_args()

    # Initialize Together client
    client = Together()  # auth defaults to os.environ.get("TOGETHER_API_KEY")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the MinervaMath dataset
    dataset = load_dataset("math-ai/minervamath", split="test")
    print(f"Loaded MinervaMath dataset with {len(dataset)} examples.")

    # Uncomment to smoke-test with the first 5 examples only.
    #dataset = dataset.select(range(5))
    #print(f"Testing with first 5 examples.")
    num_examples = len(dataset)

    # Setup output file
    output_file = "deepseek-r1-minervamath-test-all.jsonl"
    output_path = os.path.join(args.output_dir, output_file)

    # Exclusive-create mode ('x') makes "refuse to overwrite" atomic: a
    # separate exists() check followed by open(..., 'w') could truncate a
    # file created in between by another run.
    try:
        results_file = open(output_path, "x", encoding="utf-8")
    except FileExistsError:
        raise FileExistsError(f"Output file {output_file} already exists.") from None

    # Evaluation loop
    correct = 0
    total = 0

    with results_file:
        for i, example in enumerate(tqdm(dataset)):
            print(f"\nProcessing example {i+1}/{num_examples}...")

            # Prepare the prompt
            messages = prepare_data(example)
            question = example["question"]
            answer = example["answer"]

            print(f"Question: {question[:100]}...")

            # Only work that may legitimately fail per example (API call,
            # answer extraction, grading, serialization) lives in the try.
            # The file write happens once, afterwards, so a failure can never
            # produce two records for the same example.
            try:
                print("Making API call to DeepSeek-R1...")
                # Generate response using DeepSeek-R1
                response = client.chat.completions.create(
                    model="deepseek-ai/DeepSeek-R1",
                    messages=messages,
                    max_tokens=args.max_tokens,
                )

                print("API call successful!")

                # Extract the response text
                response_text = response.choices[0].message.content
                print(f"Response: {response_text[:100]}...")

                # Evaluate the answer
                pred_answer = extract_answer_math(response_text)
                label = math_equal(pred_answer, answer, timeout=True)

                record = {
                    "prompt": messages[0]["content"],
                    "response": response_text,
                    "question": question,
                    "answer": answer,
                    "question_id": i,
                    "generation_id": 0,
                    "label": label,
                    "gold_answer": answer,
                    "pred_answer": pred_answer,
                }

                print(f"Predicted: {pred_answer}, Gold: {answer}, Correct: {label}")

                line = json.dumps(record)
                succeeded = True
            except Exception as e:
                # Best-effort: keep going and record the failure as an
                # incorrect answer so the output has one row per example.
                print(f"Error processing example {i}: {e}")
                record = {
                    "prompt": messages[0]["content"],
                    "response": f"ERROR: {str(e)}",
                    "question": question,
                    "answer": answer,
                    "question_id": i,
                    "generation_id": 0,
                    "label": False,
                    "gold_answer": answer,
                    "pred_answer": "ERROR",
                    "error": str(e),
                }
                line = json.dumps(record)
                succeeded = False

            # Write to file and flush after every example.
            results_file.write(line + '\n')
            results_file.flush()

            # Update accuracy; failed examples count as incorrect.
            total += 1
            if succeeded:
                if record["label"]:
                    correct += 1
                print(f"Current accuracy: {correct}/{total} = {correct/total:.4f}")

    # Final results
    print("Final Results:")
    print(f"Total examples processed: {total}")
    print(f"Correct answers: {correct}")
    if total:
        print(f"Accuracy: {correct/total:.4f}")
    else:
        # Avoid ZeroDivisionError when the dataset (or selection) is empty.
        print("Accuracy: n/a (no examples processed)")
    print(f"Results saved to: {output_path}")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()