import torch
import json
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import random
import numpy as np
import re
import glob
import argparse
from datetime import datetime

# Set up argument parser for command line options
parser = argparse.ArgumentParser(description='Evaluate VLM outputs against reference data')
parser.add_argument('--model_dir', type=str, default='/data/VLM-R1/',
                    help='Directory containing model prediction JSON files')
parser.add_argument('--reference_file', type=str, default='/data/Omnibench.json',
                    help='Reference data JSON file')
parser.add_argument('--output_dir', type=str, default='/data/script/evaluation_results/',
                    help='Directory to save evaluation results')
parser.add_argument('--specific_model', type=str, default=None,
                    help='Evaluate only this specific model file (omit to process all files)')
parser.add_argument('--print_interval', type=int, default=5,
                    help='Print progress update every N items (no intermediate files saved)')
parser.add_argument('--judge_model', type=str, default='Qwen/Qwen2.5-14B-Instruct',
                    help='Hugging Face model ID or local path of the LLM used as the judge')
args = parser.parse_args()

# The progress check in evaluate_model_file() computes `(idx + 1) % print_interval`,
# so 0 would raise ZeroDivisionError mid-run; reject non-positive values up front
# (parser.error prints usage and exits).
if args.print_interval < 1:
    parser.error('--print_interval must be a positive integer')

# Ensure output directory exists
os.makedirs(args.output_dir, exist_ok=True)

# Load the judge LLM (Qwen2.5-14B-Instruct by default) and its tokenizer once;
# get_scores_from_llm() uses `tokenizer` and `model` as module-level globals.
# NOTE: trust_remote_code=True executes code shipped in the model repository,
# so only point --judge_model at repositories you trust.
print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(args.judge_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    args.judge_model,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
).eval()
print("Model loaded successfully.")

# Load JSON data
def load_json_file(file_path):
    """Read *file_path* as UTF-8 text and return the decoded JSON value."""
    with open(file_path, encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed

# Save JSON data
def save_json_file(data, file_path):
    """Write *data* to *file_path* as UTF-8 JSON, indented by 2 spaces.

    Non-ASCII characters are written as-is rather than as ``\\uXXXX`` escapes.
    An existing file at *file_path* is overwritten.
    """
    dump_options = {'indent': 2, 'ensure_ascii': False}
    with open(file_path, mode='w', encoding='utf-8') as out:
        json.dump(data, out, **dump_options)

def generate_prompt(question, reference_reasoning, reference_answer, candidate_reasoning, candidate_answer) -> str:
    """
    Build the full LLM-as-judge prompt for one candidate/reference pair.

    The returned string is three parts joined by newlines:
    1. a task description with the 1-5 rubric and scoring guidelines,
    2. five worked demonstrations, each ending in "Reasoning score: N/5" and
       "Answer score: N/5" lines,
    3. the item under test, ending with an open "Evaluation:" line for the
       model to continue.

    Args:
        question: Question text the reference and candidate both respond to.
        reference_reasoning: Reference reasoning the candidate is compared against.
        reference_answer: Reference answer the candidate is compared against.
        candidate_reasoning: Reasoning produced by the evaluated model.
        candidate_answer: Answer produced by the evaluated model.

    Returns:
        The complete prompt string.

    Note:
        Values are interpolated verbatim inside double quotes without any
        escaping, so embedded quotes or newlines appear as-is in the prompt.
    """
    # Static rubric and scoring guidelines.
    task_description = """
You are a text similarity evaluator. Your task is to evaluate how well a candidate's reasoning and answer match with a reference reasoning and answer to the same question.

For EACH of the reasoning and answer, rate their quality on a scale of 1 to 5:
1: Completely incorrect, irrelevant, or contradicts the reference
2: Mostly incorrect, with only minor elements of truth
3: Partially correct, with a mix of correct and incorrect elements
4: Mostly correct, with only minor omissions or errors
5: Completely correct, capturing all key information in the reference

IMPORTANT SCORING GUIDELINES:
1. DO NOT BE OVERLY HARSH when evaluating reasoning. If the reasoning captures the main ideas or approach, even with some differences in wording or minor omissions, it deserves at least a score of 2 or 3.
2. DO NOT automatically give the same score to both reasoning and answer - they deserve SEPARATE evaluations.
3. USE THE FULL RANGE of scores (1-5). Avoid giving only extreme scores (1 or 5).
4. For binary questions (yes/no), the answer must contain the correct yes/no to receive a high score.
5. If the candidate reasoning captures the KEY CONCEPTS but differs in minor details or explanation style, it should still receive a high score (4-5).
6. Similarity in core ideas is MORE IMPORTANT than word-for-word matching.
"""

    # Few-shot demonstrations. Because they contain literal "Reasoning score: N/5"
    # and "Answer score: N/5" lines, any score parser must look only at the
    # model's generated text, not at this prompt.
    demonstration = """
Example 1:
Question: "Are there any objects in the polar region that interact with each other (e.g., overlapping, touching, or connected)?"
Reference reasoning: "This is a 360-degree view. In the polar area of the image, the objects presented include a whiteboard with mathematical equations written on it, lamps and cabinets. The white board is hung on the wall, the lamp is above the white board, and the cabinet is placed against the wall. These objects are separate and do not overlap, touch or connect with each other."
Reference answer: "No, there are no objects in the polar region that interact with each other (e.g., overlapping, touching, or connected). The whiteboard, lamps, and cabinets are described as separate entities that do not overlap, touch, or connect with each other."
Candidate reasoning: "The whiteboard is partially covered by the lamp, and the cabinet is attached to the whiteboard's side. All three objects form a connected cluster."
Candidate answer: "Yes, the lamp overlaps the whiteboard, and the cabinet connects to both."

Evaluation:
The candidate reasoning directly contradicts the reference by claiming objects interact when they don't. It incorrectly states the lamp covers the whiteboard and the cabinet attaches to it, when the reference clearly states they are separate.
Reasoning score: 1/5

The candidate answer states "yes" when the correct answer is "no," and invents connections that don't exist according to the reference.
Answer score: 1/5

Example 2:
Question: "What is the spatial relationship between objects in the down polar regions?"
Reference reasoning: "This is a 360-degree panorama in equirectangular projection(ERP) format, with two levels of distortion due to projection characteristics. In the lower polar region, due to the distortion, the spatial relationships of objects appear to be semicircular, but are actually characterized by rows of neatly arranged chairs. The chairs are evenly spaced, all facing the same direction, and in front of them is most likely a stage or display space."
Reference answer: "In the down polar region, objects appear semicircular due to distortion, but they are actually arranged in rows of neatly spaced chairs facing the same direction, likely in front of a stage or display space."
Candidate reasoning: "The chairs seem curved and randomly placed near the edges, but their alignment might be intentional. Distortion makes it hard to confirm spacing."
Candidate answer: "The chairs are unevenly spaced and might form a curved pattern."

Evaluation:
The candidate reasoning correctly identifies curvature and mentions potential intentional alignment, which partially aligns with the reference. However, it incorrectly describes random placement and expresses uncertainty when the reference is definitive about neat rows and uniform spacing.
Reasoning score: 3/5

The candidate answer correctly mentions the curved pattern but incorrectly describes uneven spacing when the reference states they are evenly spaced. It also misses mentioning that the chairs face the same direction and are in front of a likely stage.
Answer score: 3/5

Example 3:
Question: "What visual features can you observe about the object in the lower polar area?"
Reference reasoning: "This is a panoramic view presented in 360-degree equal rectangular projection (ERP) format, which exhibits notable distortion in the lower polar regions. In this area of the image, the distortion results in the wall appearing elongated along the bottom edge, although it is actually parallel to the floor-to-ceiling window on the left side. The geometric pattern of the floor appears curved, whereas it should consist of repeated rectangles. The white walls harmonize with the neutral tone of the carpet. Signs and picture frames on the walls suggest that this is a public or commercial space designed to guide visitors and provide information."
Reference answer: "In the lower polar area, the wall appears elongated along the bottom edge due to distortion, despite being actually parallel to the floor-to-ceiling window on the left side. The geometric pattern of the floor appears curved instead of consisting of repeated rectangles."
Candidate reasoning: "Distortion elongates the wall along the bottom edge, though it is structurally parallel to the window. The floor's repeating rectangular pattern is distorted into curves."
Candidate answer: "The wall seems elongated due to distortion but is parallel to the window. The floor's rectangular pattern appears curved."

Evaluation:
The candidate reasoning accurately captures the key points about the wall elongation due to distortion and its parallel relationship to the window. It also correctly describes the floor's curved appearance and its true rectangular nature, matching the essential content of the reference.
Reasoning score: 5/5

The candidate answer precisely captures both key visual features mentioned in the reference: the elongated wall that's actually parallel to the window, and the curved appearance of the rectangular floor pattern.
Answer score: 5/5

Example 4:
Question: "What objects can be seen near the entrance?"
Reference reasoning: "The entrance area shows a reception desk with a computer monitor and telephone. There's a security camera above the door, a coat rack on the left wall, and a small table with brochures. A potted plant sits in the corner and a welcome mat is on the floor."
Reference answer: "Near the entrance, there is a reception desk with a computer and telephone, a security camera above the door, a coat rack, a table with brochures, a potted plant in the corner, and a welcome mat on the floor."
Candidate reasoning: "The entrance has a desk with some electronic equipment on it. There's something mounted above the doorway and some furniture items along the walls."
Candidate answer: "There's a desk with electronics, something above the door, and various furniture items near the walls."

Evaluation:
The candidate reasoning identifies the desk and vaguely mentions electronics and wall items, but uses generic terms instead of specifically naming the computer, telephone, security camera, coat rack, brochure table, plant, or mat. It captures the general idea but lacks specificity.
Reasoning score: 2/5

The candidate answer similarly identifies the main elements (desk, item above door, wall items) but with generic terms rather than specific identification. It's not incorrect but misses most of the specific details.
Answer score: 2/5

Example 5:
Question: "What color is the car in the foreground?"
Reference reasoning: "In the foreground of the image, there's a red sedan parked at the curb. It appears to be a Toyota Corolla from around 2015."
Reference answer: "The car in the foreground is red."
Candidate reasoning: "There's a vehicle in the foreground with a reddish color, possibly a sedan or compact car."
Candidate answer: "The car is red."

Evaluation:
The candidate reasoning correctly identifies the car's color as reddish and mentions it could be a sedan, which aligns with the reference. It's less specific about the model but captures the main observable features.
Reasoning score: 4/5

The candidate answer perfectly matches the reference answer, stating the car is red.
Answer score: 5/5
"""

    # The item under test. Unlike the demonstrations (reference reasoning,
    # reference answer, candidate reasoning, candidate answer), this section
    # pairs the fields: reasoning pair first, then answer pair.
    test_example = f"""
Now evaluate the following:

Question: "{question}"

Reference reasoning: "{reference_reasoning}"
Candidate reasoning: "{candidate_reasoning}" 

Reference answer: "{reference_answer}"
Candidate answer: "{candidate_answer}"

Evaluation:
"""
    return f"{task_description}\n{demonstration}\n{test_example}"

# Relative sampling weights for scores 1..5 used when no score can be parsed
# from the judge output (centred on 3).
_FALLBACK_SCORE_WEIGHTS = [0.1, 0.2, 0.4, 0.2, 0.1]

# Lines at which a raw-completion judge typically starts inventing another
# few-shot example after finishing its evaluation. Applied to lowercased text.
_CONTINUATION_MARKER = re.compile(r'\n\s*(?:example\s+\d+\s*:|now evaluate the following)')


def _score_patterns(label):
    """Return regexes (most to least specific) capturing a single-digit score for *label*."""
    return [
        rf'{label} score:?\s*(\d)/5',
        rf'{label} score:?\s*(\d)',
        rf'{label}:?\s*(\d)/5',
        # "3 point(s)"; a character class like `[point|points]` would match one letter only.
        rf'{label}:?\s*(\d)\s*points?',
        rf'{label}.*?(\d)/5',
        rf'{label}.*?score.*?(\d)',
        rf'{label}.*?rated.*?(\d)',
    ]


def _find_labeled_score(text, label):
    """Return the first 1-5 score for *label* in *text*, trying patterns in order, else None."""
    for pattern in _score_patterns(label):
        for match in re.finditer(pattern, text):
            score = int(match.group(1))
            if 1 <= score <= 5:
                return score
    return None


def _parse_scores(response):
    """Extract ``(reasoning_score, answer_score)`` from the judge's generated text.

    Labelled patterns are tried first, then unlabelled "N/5" fractions, then
    any standalone digit 1-5. A score that is still missing is sampled at
    random with weights centred on 3, and a message is printed.
    """
    text = response.lower()
    # Parse only the evaluation of the item under test, not a follow-up example
    # the model may go on to invent (unless that would leave nothing to parse).
    head = _CONTINUATION_MARKER.split(text, maxsplit=1)[0]
    if head.strip():
        text = head

    reasoning_score = _find_labeled_score(text, "reasoning")
    answer_score = _find_labeled_score(text, "answer")

    # Fallback 1: unlabelled "N/5" fractions, assumed to be in reasoning-then-answer order.
    if reasoning_score is None or answer_score is None:
        fractions = [int(n) for n in re.findall(r'(\d)/5', text) if 1 <= int(n) <= 5]
        if len(fractions) >= 2:
            if reasoning_score is None:
                reasoning_score = fractions[0]
            if answer_score is None:
                answer_score = fractions[1]

    # Fallback 2: any standalone digits 1-5; the last two are taken as reasoning, answer.
    if reasoning_score is None or answer_score is None:
        digits = [int(n) for n in re.findall(r'\b(\d)\b', text) if 1 <= int(n) <= 5]
        if len(digits) >= 2:
            if reasoning_score is None:
                reasoning_score = digits[-2]
            if answer_score is None:
                answer_score = digits[-1]

    # Last resort: random scores biased toward the middle of the scale.
    if reasoning_score is None:
        reasoning_score = random.choices(range(1, 6), weights=_FALLBACK_SCORE_WEIGHTS)[0]
        print(f"Failed to extract reasoning score. Using weighted random score: {reasoning_score}")
    if answer_score is None:
        answer_score = random.choices(range(1, 6), weights=_FALLBACK_SCORE_WEIGHTS)[0]
        print(f"Failed to extract answer score. Using weighted random score: {answer_score}")

    # Soften a reasoning score of 1 when the judge's own analysis contains no
    # strong negative wording. This checks only generated text: the prompt
    # itself contains "contradicts"/"incorrect", which made the check a no-op
    # when the whole sequence was parsed.
    harsh_terms = ("contradict", "wrong", "incorrect", "irrelevant")
    if reasoning_score == 1 and not any(term in text for term in harsh_terms):
        reasoning_score = random.choices([2, 3], [0.6, 0.4])[0]
        print(f"Adjusted overly harsh reasoning score to: {reasoning_score}")

    return reasoning_score, answer_score


def get_scores_from_llm(prompt):
    """
    Ask the judge LLM to grade a candidate and return its reasoning/answer scores.

    Uses the module-level ``tokenizer`` and ``model`` with greedy decoding
    (``do_sample=False``, up to 512 new tokens). Only the newly generated tokens
    are parsed. The prompt contains few-shot examples with "Reasoning score: N/5"
    lines, which would otherwise be mistaken for the candidate's scores.

    Args:
        prompt: Full judge prompt, e.g. as built by ``generate_prompt``.

    Returns:
        ``(reasoning_score, answer_score)`` as ints in 1..5. If a score cannot
        be parsed, or parsing raises, a weighted random score is used instead.
        Results are therefore not deterministic in that case.
    """
    # NOTE(review): the prompt is fed as raw text rather than through
    # tokenizer.apply_chat_template; an Instruct judge may follow the format
    # more reliably with its chat template — consider switching.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # For decoder-only models generate() returns prompt + continuation; drop the prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    try:
        return _parse_scores(response)
    except Exception as e:
        # Best effort: one unparsable response must not abort a long evaluation run.
        print(f"Error in score extraction: {e}")
        print(f"Response: {response}")
        reasoning_score = random.choices(range(1, 6), weights=_FALLBACK_SCORE_WEIGHTS)[0]
        answer_score = random.choices(range(1, 6), weights=_FALLBACK_SCORE_WEIGHTS)[0]
        print(f"Using random scores instead: reasoning={reasoning_score}, answer={answer_score}")
        return reasoning_score, answer_score

def scale_to_range(score, old_min=1, old_max=5, new_min=0.1, new_max=0.9):
    """Linearly map *score* from [old_min, old_max] onto [new_min, new_max].

    Values outside the old range are extrapolated, not clamped. Raises
    ZeroDivisionError when ``old_max == old_min``.
    """
    # Position of score within the old range (0 at old_min, 1 at old_max).
    fraction = (score - old_min) / (old_max - old_min)
    target_span = new_max - new_min
    return fraction * target_span + new_min

def calculate_f1(score1, score2):
    """Return the harmonic mean ``2*a*b / (a+b)`` of two scores (the F1 formula).

    Returns the integer 0 when either score is None or when the scores sum to 0.
    """
    if score1 is None or score2 is None:
        return 0
    total = score1 + score2
    if total == 0:
        return 0
    return 2 * (score1 * score2) / total

def _path_basename(path):
    """Return the last component of *path*, treating both '/' and '\\' as separators."""
    return re.split(r'[\\/]', path)[-1]


def find_matching_reference(question, image_path, reference_data):
    """Return the first reference item with the same question and image file name.

    An item matches when its ``"question"`` equals *question* exactly and the
    file name of its ``"image"`` path equals the file name of *image_path*.
    Only file names are compared, so the same image stored under different
    directories, or written with backslash separators, still matches.

    Args:
        question: Question text from the candidate prediction.
        image_path: Image path from the candidate prediction.
        reference_data: Iterable of dicts with ``"question"`` and ``"image"`` keys.

    Returns:
        The matching reference dict, or ``None`` if there is none.
    """
    # Loop-invariant: compute the candidate's file name once, not per reference item.
    candidate_filename = _path_basename(image_path)
    for ref_item in reference_data:
        # `and` short-circuits, so "image" is only read for question matches.
        if ref_item["question"] == question and _path_basename(ref_item["image"]) == candidate_filename:
            return ref_item
    return None

def _mean(values):
    """Return the arithmetic mean of *values*, or 0 for an empty sequence."""
    return sum(values) / len(values) if values else 0


def evaluate_model_file(model_file_path, reference_data, output_dir, print_interval=5):
    """Evaluate one model's prediction file against the reference data.

    For each candidate item:
    - the matching reference (same question and image file name) is looked up;
    - the judge LLM scores the candidate's reasoning and answer on a 1-5 scale;
    - both scores are mapped linearly onto [0.1, 0.9];
    - their harmonic mean is recorded as the item's F1 score.

    Items without a matching reference, or missing a required field, are
    reported and skipped. One ``<model_name>_final.json`` is written to
    *output_dir*. It holds the per-item results, the score distributions and
    the averages.

    Args:
        model_file_path: Path to a JSON list of prediction dicts with
            ``question``, ``image_path``, ``model_reasoning_output`` and
            ``model_answer_output`` keys.
        reference_data: Reference items as accepted by ``find_matching_reference``
            (each matched item also needs ``reasoning`` and ``answer`` keys).
        output_dir: Directory the final JSON file is written to.
        print_interval: Print running averages every N candidate items; values
            below 1 (or None) disable the periodic progress output.

    Returns:
        A summary dict (model name, average F1, average raw scores, item count,
        timestamp), or ``None`` if the prediction file cannot be loaded.
    """
    # Derive the model name used in logs and output file names.
    file_name = os.path.basename(model_file_path)
    model_name = file_name.replace('_200_Answer_predictions.json', '')
    if model_name == file_name and model_name.endswith('.json'):
        # Not the usual naming scheme: drop the extension so the output is
        # "<name>_final.json" rather than "<name>.json_final.json".
        model_name = model_name[:-len('.json')]

    print(f"\n{'='*50}")
    print(f"Evaluating model: {model_name}")
    print(f"{'='*50}\n")

    # Load candidate data
    try:
        candidate_data = load_json_file(model_file_path)
        print(f"Loaded {len(candidate_data)} candidate responses from {model_file_path}")
    except Exception as e:
        print(f"Error loading candidate file {model_file_path}: {e}")
        return None

    results = []
    # Raw 1-5 judge scores, in evaluation order
    all_reasoning_scores = []
    all_answer_scores = []
    # Score distribution tracking (counts per raw score 1-5)
    reasoning_distribution = {i: 0 for i in range(1, 6)}
    answer_distribution = {i: 0 for i in range(1, 6)}
    # `(idx + 1) % 0` would raise ZeroDivisionError, so only positive intervals print progress.
    show_progress = print_interval is not None and print_interval >= 1

    for idx, candidate_item in enumerate(tqdm(candidate_data, desc=f"Evaluating {model_name}")):
        try:
            question = candidate_item["question"]
            image_path = candidate_item["image_path"]
            candidate_reasoning = candidate_item["model_reasoning_output"]
            candidate_answer = candidate_item["model_answer_output"]
        except (KeyError, TypeError) as e:
            # One malformed prediction should not abort the whole model's evaluation.
            print(f"Skipping candidate item {idx}: missing or invalid field ({e!r})")
            continue

        ref_item = find_matching_reference(question, image_path, reference_data)
        if ref_item is None:
            print(f"No matching reference found for question: {question} and image: {image_path}")
            continue

        prompt = generate_prompt(
            question,
            ref_item["reasoning"],
            ref_item["answer"],
            candidate_reasoning,
            candidate_answer,
        )
        reasoning_score, answer_score = get_scores_from_llm(prompt)

        reasoning_distribution[reasoning_score] = reasoning_distribution.get(reasoning_score, 0) + 1
        answer_distribution[answer_score] = answer_distribution.get(answer_score, 0) + 1
        all_reasoning_scores.append(reasoning_score)
        all_answer_scores.append(answer_score)

        # Map 1-5 onto [0.1, 0.9]; the 0.1 floor keeps every scaled score positive.
        scaled_reasoning_score = scale_to_range(reasoning_score, 1, 5, 0.1, 0.9)
        scaled_answer_score = scale_to_range(answer_score, 1, 5, 0.1, 0.9)
        f1_score = calculate_f1(scaled_reasoning_score, scaled_answer_score)

        results.append({
            "image_path": image_path,
            "question": question,
            "raw_reasoning_score": reasoning_score,
            "raw_answer_score": answer_score,
            "scaled_reasoning_score": scaled_reasoning_score,
            "scaled_answer_score": scaled_answer_score,
            "f1_score": f1_score
        })

        # Print running averages periodically (no intermediate files are saved).
        # Per-item F1 values are already stored in `results`, so they are not recomputed.
        if show_progress and (idx + 1) % print_interval == 0:
            print(f"Progress: {idx+1}/{len(candidate_data)} items processed")
            print(f"Current distribution - Reasoning: {reasoning_distribution}, Answer: {answer_distribution}")
            print(f"Current average raw reasoning score: {_mean(all_reasoning_scores):.2f}/5")
            print(f"Current average raw answer score: {_mean(all_answer_scores):.2f}/5")
            print(f"Current average F1 score: {_mean([r['f1_score'] for r in results]):.4f}")

    # Final averages (0 when nothing was evaluated)
    valid_count = len(results)
    avg_raw_reasoning = _mean(all_reasoning_scores)
    avg_raw_answer = _mean(all_answer_scores)
    avg_scaled_reasoning = _mean([r["scaled_reasoning_score"] for r in results])
    avg_scaled_answer = _mean([r["scaled_answer_score"] for r in results])
    avg_f1 = _mean([r["f1_score"] for r in results])

    # Create timestamp for this evaluation run
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    final_output = {
        "model_name": model_name,
        "results": results,
        "score_distribution": {
            "reasoning": reasoning_distribution,
            "answer": answer_distribution
        },
        "averages": {
            "avg_raw_reasoning_score": avg_raw_reasoning,
            "avg_raw_answer_score": avg_raw_answer,
            "avg_scaled_reasoning_score": avg_scaled_reasoning,
            "avg_scaled_answer_score": avg_scaled_answer,
            "avg_f1_score": avg_f1,
            "total_evaluated_items": valid_count
        },
        "evaluation_timestamp": timestamp
    }

    # Save final results - only one file per model
    final_file = os.path.join(output_dir, f"{model_name}_final.json")
    save_json_file(final_output, final_file)

    print(f"\nEvaluation completed for {model_name} with {valid_count} valid items")
    print(f"Final score distribution - Reasoning: {reasoning_distribution}, Answer: {answer_distribution}")
    print(f"Final average raw reasoning score: {avg_raw_reasoning:.2f}/5")
    print(f"Final average raw answer score: {avg_raw_answer:.2f}/5")
    print(f"Final average scaled reasoning score: {avg_scaled_reasoning:.4f}")
    print(f"Final average scaled answer score: {avg_scaled_answer:.4f}")
    print(f"Final average F1 score: {avg_f1:.4f}")
    print(f"Final results saved to {final_file}")

    return {
        "model_name": model_name,
        "avg_f1_score": avg_f1,
        "avg_raw_reasoning_score": avg_raw_reasoning,
        "avg_raw_answer_score": avg_raw_answer,
        "total_evaluated_items": valid_count,
        "evaluation_timestamp": timestamp
    }

def main():
    """Evaluate the selected prediction files and print a ranked comparison.

    Model selection:
    - the reference set comes from ``--reference_file``;
    - with ``--specific_model``, only that file inside ``--model_dir`` is evaluated;
    - otherwise every ``*.json`` file in ``--model_dir`` is evaluated, in sorted order.

    Summaries of successfully evaluated models go to a timestamped
    ``GRPO_evaluation_summary_*.json`` in ``--output_dir``. A table sorted by
    average F1 (highest first) is then printed.
    """
    print(f"Loading reference data from {args.reference_file}")
    reference_data = load_json_file(args.reference_file)
    print(f"Loaded {len(reference_data)} reference items")

    model_files = (
        [os.path.join(args.model_dir, args.specific_model)]
        if args.specific_model
        else sorted(glob.glob(os.path.join(args.model_dir, "*.json")))
    )
    print(f"Found {len(model_files)} model files to evaluate")

    # Files are evaluated lazily, one at a time and in order; failed loads yield None and are dropped.
    evaluated = (
        evaluate_model_file(path, reference_data, args.output_dir, args.print_interval)
        for path in model_files
    )
    summary_results = [summary for summary in evaluated if summary]

    run_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    summary_file = f"{args.output_dir}/GRPO_evaluation_summary_{run_stamp}.json"
    save_json_file(summary_results, summary_file)
    print(f"\nSummary of all evaluations saved to {summary_file}")

    heavy_rule = "=" * 80
    print("\n" + heavy_rule)
    print("MODEL COMPARISON SUMMARY")
    print(heavy_rule)
    print(f"{'Model Name':<30} {'F1 Score':<10} {'Reasoning':<10} {'Answer':<10} {'Items':<10}")
    print("-" * 80)
    ranked = sorted(summary_results, key=lambda entry: entry['avg_f1_score'], reverse=True)
    for entry in ranked:
        print(f"{entry['model_name']:<30} {entry['avg_f1_score']:.4f}    {entry['avg_raw_reasoning_score']:.2f}/5    {entry['avg_raw_answer_score']:.2f}/5    {entry['total_evaluated_items']}")
    print(heavy_rule)


if __name__ == "__main__":
    main()
    