import json
import torch
import argparse
import random
import os
import numpy as np
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# ---------------- Argument Parsing ----------------
def parse_args(argv=None):
    """Parse command-line options for the reward-model evaluation script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps normal CLI behaviour;
            passing a list allows tests and programmatic callers to reuse
            the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``model_path``, ``input_file``,
        ``output_dir``, ``sample_size`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description="Evaluate solution quality using Qwen2.5-Math-RM.")

    parser.add_argument(
        "--model_path",
        type=str,
        default="Qwen/Qwen2.5-Math-RM-72B",
        help="Path or HuggingFace ID of the Reward Model."
    )
    parser.add_argument(
        "--input_file",
        type=str,
        required=True,
        help="Path to the input JSON file containing 'problem' and 'solution' fields."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="results",
        help="Directory to save the scored results."
    )
    parser.add_argument(
        "--sample_size",
        type=int,
        default=1000,
        help="Number of samples to evaluate (set to -1 for all data)."
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility."
    )

    # argparse treats argv=None as "read sys.argv[1:]", preserving CLI usage.
    return parser.parse_args(argv)

# ---------------- Main Processing Function ----------------
def _load_entries(input_file):
    """Read the evaluation entries from a JSON file.

    Args:
        input_file: Path to a JSON file whose top-level value is expected to
            be a list of objects carrying 'problem' and 'solution' keys.

    Returns:
        The decoded list, or ``None`` (after printing an error) if the file
        does not exist or its top-level JSON value is not a list.
    """
    if not os.path.exists(input_file):
        print(f"Error: Input file not found at {input_file}")
        return None

    print(f"Loading data from {input_file}...")
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # random.sample() and the per-entry .get() calls both require a list of
    # objects; fail with a clear message instead of a confusing traceback.
    if not isinstance(data, list):
        print(f"Error: Expected a JSON list in {input_file}, got {type(data).__name__}")
        return None
    return data


def _select_samples(data, sample_size, seed):
    """Return ``sample_size`` random entries of *data*, or all of *data*.

    Any non-positive *sample_size* (the CLI documents -1) means "use all".
    Sampling draws from the module-level ``random`` generator, so it is
    reproducible only if the caller seeded it first (``main`` does).
    *seed* is used only for the progress message.
    """
    if sample_size > 0 and len(data) > sample_size:
        print(f"Randomly sampling {sample_size} items (Seed: {seed})...")
        return random.sample(data, sample_size)
    print(f"Using all {len(data)} items...")
    return data


def _load_reward_model(model_path):
    """Load the reward model (bfloat16, ``device_map="auto"``, eval mode) and its tokenizer.

    Returns:
        ``(model, tokenizer)``. Any exception from ``from_pretrained``
        propagates to the caller.
    """
    model = AutoModel.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer


def _score_solution(model, tokenizer, problem, solution):
    """Return the reward model's scalar score for one (problem, solution) pair.

    The pair is rendered with the tokenizer's chat template (system prompt,
    user problem, assistant solution) and run through the model without
    gradient tracking.
    """
    # Qwen2.5-Math-RM expects this specific system prompt.
    chat = [
        {"role": "system", "content": "Please reason step by step, and put your final answer within \\boxed{}."},
        {"role": "user", "content": problem},
        {"role": "assistant", "content": solution}
    ]
    conversation_str = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=False
    )
    # The template already inserts the special tokens it needs.
    input_ids = tokenizer.encode(
        conversation_str,
        return_tensors="pt",
        add_special_tokens=False
    ).to(model.device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    # NOTE(review): assumes outputs[0] holds a single-element reward tensor
    # (as in the model card example); .item() raises otherwise.
    return outputs[0].item()


def _print_statistics(scores, solution_lengths):
    """Print summary statistics for reward scores and solution token lengths."""
    print("\n" + "=" * 50)
    print("📊 Evaluation Statistics")
    print("=" * 50)

    if scores:
        print("--- Reward Score Stats ---")
        print(f"Mean Score      : {np.mean(scores):.4f}")
        print(f"Std Dev         : {np.std(scores):.4f}")
        print(f"Median          : {np.median(scores):.4f}")
        print(f"Max Score       : {np.max(scores):.4f}")
        print(f"Min Score       : {np.min(scores):.4f}")

    if solution_lengths:
        print("\n--- Solution Token Length Stats ---")
        print(f"Mean Length     : {np.mean(solution_lengths):.2f}")
        print(f"Std Dev         : {np.std(solution_lengths):.2f}")
        print(f"Median          : {np.median(solution_lengths):.2f}")
        print(f"Max Length      : {np.max(solution_lengths)}")
        print(f"Min Length      : {np.min(solution_lengths)}")

    print("=" * 50 + "\n")


def _output_path(input_file, output_dir):
    """Return ``<output_dir>/<input stem>_scored.json``, creating *output_dir* if needed."""
    os.makedirs(output_dir, exist_ok=True)
    # splitext strips only the final extension; str.replace('.json', '')
    # would also delete '.json' occurring elsewhere in the file name.
    base_name = os.path.splitext(os.path.basename(input_file))[0]
    return os.path.join(output_dir, f"{base_name}_scored.json")


def main():
    """Score every (problem, solution) pair in the input file with the reward model.

    Parses CLI args, seeds the Python/NumPy/torch RNGs, loads and optionally
    subsamples the input entries, loads the reward model, then scores each
    entry that has a non-empty 'problem' and 'solution'. Each scored entry
    is the original dict, mutated in place with 'solution_len' (solution
    token count, excluding special tokens) and 'reward_score'. Summary
    statistics are printed and the scored entries are written to
    ``<output_dir>/<input stem>_scored.json``.

    Raises:
        SystemExit: with status 1 if the input file is missing or not a JSON
            list, or if the model/tokenizer fails to load.
    """
    args = parse_args()

    # Set random seed for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Validate and sample the (cheap) input before loading the (expensive)
    # model so a bad path or malformed file fails immediately.
    data = _load_entries(args.input_file)
    if data is None:
        raise SystemExit(1)
    target_data = _select_samples(data, args.sample_size, args.seed)

    print(f"Loading Reward Model from: {args.model_path}...")
    try:
        model, tokenizer = _load_reward_model(args.model_path)
    except Exception as e:
        print(f"Failed to load model: {e}")
        # Non-zero exit so shell pipelines notice the failure.
        raise SystemExit(1) from e
    print("Model loaded successfully.")

    results = []
    solution_lengths = []
    skipped = 0

    print("Starting inference...")
    for entry in tqdm(target_data, desc="Scoring"):
        # Ensure keys match your dataset format
        problem = entry.get('problem', '')
        solution = entry.get('solution', '')
        if not problem or not solution:
            skipped += 1
            continue

        # Count only the solution's text tokens (no special tokens).
        sol_len = len(tokenizer.encode(solution, add_special_tokens=False))
        entry['solution_len'] = sol_len
        solution_lengths.append(sol_len)

        entry['reward_score'] = _score_solution(model, tokenizer, problem, solution)
        results.append(entry)

    if skipped:
        print(f"Skipped {skipped} entries with an empty or missing 'problem'/'solution'.")

    if results:
        _print_statistics([item['reward_score'] for item in results], solution_lengths)
    else:
        print("Warning: no entries were scored; writing an empty result list.")

    output_file = _output_path(args.input_file, args.output_dir)
    print(f"Saving results to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=4)

    print("Evaluation completed successfully.")

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()