import argparse
import glob
import json
import os
import random
import re
import sys
from datetime import datetime

import numpy as np
import torch
from tqdm import tqdm


# Command-line options as (flag, type, default, help). Every option has a
# default, so the script can also be run without arguments.
_CLI_OPTIONS = (
    ('--model_dir', str, '/root/temp_json/',
     'Directory containing model prediction JSON files'),
    ('--reference_file', str, '/root/OmniBench.json',
     'Reference data JSON file'),
    ('--output_dir', str, '/root/',
     'Directory to save evaluation results'),
    ('--specific_model', str, None,
     'Evaluate only this specific model file (omit to process all files)'),
    ('--print_interval', int, 5,
     'Print progress update every N items (no intermediate files saved)'),
)

parser = argparse.ArgumentParser(description='Evaluate VLM outputs against reference data')
for _flag, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
del _flag, _type, _default, _help
args = parser.parse_args()

# Evaluation results are written into this directory, so create it up front.
os.makedirs(args.output_dir, exist_ok=True)

# Load the judge LLM (tokenizer + model) from ModelScope.
print("Loading tokenizer and model from ModelScope...")
from modelscope import AutoModelForCausalLM, AutoTokenizer, snapshot_download

# No cache_dir is passed, so snapshot_download uses the library's default cache location.
model_dir_path = snapshot_download('qwen/Qwen2.5-14B-Instruct')
tokenizer = AutoTokenizer.from_pretrained(model_dir_path, trust_remote_code=True)

# Model loading settings; they are echoed in the error message if loading fails.
device_map = "auto"
torch_dtype = torch.float16
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_dir_path,
        device_map=device_map,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
    ).eval()
except Exception as e:
    # Nothing below can run without the judge model: report to stderr and stop.
    # sys.exit is used because the bare exit() builtin is only installed by `site`.
    print(f"Error loading model: {e}. Please check the model parameters (device_map: {device_map}, torch_dtype: {torch_dtype}).", file=sys.stderr)
    sys.exit(1)

# Printed once (the original printed this message twice on success).
print("Model loaded successfully from ModelScope.")

# Load JSON data
def load_json_file(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def save_json_file(data, file_path):
    """Write *data* to *file_path* as 2-space-indented UTF-8 JSON, keeping non-ASCII characters as-is."""
    with open(file_path, mode='w', encoding='utf-8') as out:
        json.dump(data, out, ensure_ascii=False, indent=2)

def generate_prompt(question, reference_reasoning, reference_answer, candidate_reasoning, candidate_answer):
    """
    Build the judge prompt comparing a candidate's reasoning/answer with the reference.

    The prompt is a fixed scoring rubric followed by the question, the reference
    and candidate reasoning, and the reference and candidate answers. It ends by
    restating the two-line reply format (REASONING_SCORE / ANSWER_SCORE).
    """
    rubric = (
        "You are a text similarity evaluator. Your task is to compare a candidate's reasoning "
        "and answer with reference reasoning and answer. \n"
        "Rate each on a scale of 1-5 (1=completely incorrect, 5=completely correct).\n"
        "IMPORTANT: Your response must ONLY contain these two lines:\n"
        "REASONING_SCORE: [1/2/3/4/5]\n"
        "ANSWER_SCORE: [1/2/3/4/5]\n"
        "Do not add any explanation or other text."
    )

    # The leading "" makes the comparison section start on a fresh line.
    comparison_lines = [
        "",
        f"Question: {question}",
        "",
        f"Reference reasoning: {reference_reasoning}",
        f"Candidate reasoning: {candidate_reasoning}",
        "",
        f"Reference answer: {reference_answer}",
        f"Candidate answer: {candidate_answer}",
        "",
        "Remember, respond ONLY with:",
        "REASONING_SCORE: [1/2/3/4/5]",
        "ANSWER_SCORE: [1/2/3/4/5]",
    ]
    return rubric + "\n" + "\n".join(comparison_lines)

# Labelled score lines in the judge's reply. ``\W*`` tolerates the separators the
# model tends to put before the value (":", spaces, an echoed "[" or markdown "**"),
# and ``([1-5])\b`` rejects out-of-range values such as "10" instead of reading
# only their first digit.
_REASONING_SCORE_RE = re.compile(r'REASONING_SCORE\W*([1-5])\b', re.IGNORECASE)
_ANSWER_SCORE_RE = re.compile(r'ANSWER_SCORE\W*([1-5])\b', re.IGNORECASE)
# Any standalone digit 1-5; only consulted when the reply has no labels at all.
_BARE_SCORE_RE = re.compile(r'\b[1-5]\b')


def _parse_scores(response_text):
    """Return (reasoning_score, answer_score) parsed from the judge reply; None where missing."""
    reasoning_match = _REASONING_SCORE_RE.search(response_text)
    answer_match = _ANSWER_SCORE_RE.search(response_text)
    reasoning_score = int(reasoning_match.group(1)) if reasoning_match else None
    answer_score = int(answer_match.group(1)) if answer_match else None

    # Unlabelled fallback (e.g. a bare "4\n3"). It is only used when neither label
    # was found, so a digit belonging to one parsed score is never reused for the other.
    if reasoning_match is None and answer_match is None:
        bare_scores = _BARE_SCORE_RE.findall(response_text)
        if len(bare_scores) >= 2:
            reasoning_score, answer_score = int(bare_scores[0]), int(bare_scores[1])

    return reasoning_score, answer_score


def get_scores_from_llm(prompt):
    """
    Ask the judge model to score *prompt* and parse the two scores from its reply.

    Args:
        prompt: Text built by ``generate_prompt``.

    Returns:
        ``(reasoning_score, answer_score)``. Each is an int in 1..5, or None when it
        could not be parsed from the reply. Missing scores are reported as None
        instead of being replaced with random values, so unparseable replies cannot
        bias the averages. ``evaluate_model_file`` skips items with a None score.
    """
    # Qwen2.5-Instruct is a chat model. Wrap the prompt in its chat template so the
    # model replies as the assistant instead of continuing the user's text.
    if getattr(tokenizer, "chat_template", None):
        model_input = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        model_input = prompt
    inputs = tokenizer(model_input, return_tensors="pt").to(model.device)

    with torch.no_grad():
        response_ids = model.generate(
            **inputs,
            max_new_tokens=100,  # the expected reply is two short lines
            # Greedy decoding: a judge must give the same scores for the same input.
            do_sample=False,
            repetition_penalty=1.2,
        )

    # Decode only the newly generated tokens (drop the echoed prompt).
    generated_ids = response_ids[0][inputs.input_ids.shape[1]:]
    response_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Print first 100 chars of response for debugging
    print(f"Model response (first 100 chars): {response_text[:100]}")

    reasoning_score, answer_score = _parse_scores(response_text)
    if reasoning_score is None:
        print("Failed to extract reasoning score from the model response.")
    if answer_score is None:
        print("Failed to extract answer score from the model response.")

    return reasoning_score, answer_score

def scale_to_range(score, old_min=1, old_max=5, new_min=0.1, new_max=0.9):
    """
    Linearly map *score* from [old_min, old_max] onto [new_min, new_max].

    A ``None`` score maps to ``new_min``, the lowest value of the target range.
    """
    if score is not None:
        fraction = (score - old_min) / (old_max - old_min)
        return fraction * (new_max - new_min) + new_min
    return new_min

def calculate_f1(score1, score2):
    """
    Return the harmonic mean ``2*a*b / (a+b)`` of the two scores (F1-style combination).

    Returns 0.0 when either score is None or when the scores sum to zero.
    """
    if score1 is not None and score2 is not None:
        total = score1 + score2
        if total != 0:
            return 2 * (score1 * score2) / total
    return 0.0

def find_matching_reference(question, image_path, reference_data):
    """
    Return the first reference item matching *question* and the image file name.

    The question must match exactly. Images are compared by base file name only, so
    differing directories are ignored. Each reference item must have "question"
    and (for question matches) "image" keys. Returns None if nothing matches.
    """
    return next(
        (
            entry
            for entry in reference_data
            if entry["question"] == question
            and os.path.basename(image_path) == os.path.basename(entry["image"])
        ),
        None,
    )

def _truncate_text(value, max_len):
    """Return *value* as text, cut to *max_len* characters plus "..." when longer.

    Non-string values (e.g. numeric answers) are converted with ``str()``; a plain
    ``len()`` check would raise TypeError on them.
    """
    text = value if isinstance(value, str) else str(value)
    if len(text) > max_len:
        return text[:max_len] + "..."
    return text


def evaluate_model_file(model_file_path, reference_data, output_dir, print_interval=5, max_input_chars=800):
    """
    Score one model's prediction file against the reference data with the LLM judge.

    Each candidate item is matched to a reference item by question text and image
    file name. The four text fields are truncated to ``max_input_chars`` characters,
    and the judge's 1-5 reasoning/answer scores are collected. Per-item results,
    score distributions and averages are written to
    ``<output_dir>/<model_name>_final_evaluation_<timestamp>.json``.

    Args:
        model_file_path: Path to a prediction JSON file. The model name is the file
            name with the ``_200predictions.json`` suffix removed.
        reference_data: Loaded reference items (read keys: "question", "image",
            "reasoning", "answer").
        output_dir: Directory for the final evaluation JSON.
        print_interval: Print running statistics every N candidate items. Values
            <= 0 (or None) disable progress reports instead of raising
            ZeroDivisionError.
        max_input_chars: Per-field truncation length used when building the prompt.

    Returns:
        A summary dict (model name, averages, item count, timestamp), or None when
        the file cannot be loaded or no item could be scored.
    """
    model_name = os.path.basename(model_file_path).replace('_200predictions.json', '')

    print(f"\n{'='*50}")
    print(f"Evaluating model: {model_name}")
    print(f"{'='*50}\n")

    try:
        candidate_data = load_json_file(model_file_path)
        print(f"Loaded {len(candidate_data)} candidate responses from {model_file_path}")
    except Exception as e:
        # Best-effort per model: report the bad file and let the caller continue with others.
        print(f"Error loading candidate file {model_file_path}: {e}")
        return None

    report_every = print_interval if print_interval and print_interval > 0 else 0

    results = []
    all_reasoning_scores = []
    all_answer_scores = []
    valid_count = 0

    reasoning_distribution = {i: 0 for i in range(1, 6)}
    answer_distribution = {i: 0 for i in range(1, 6)}

    for idx, candidate_item in enumerate(tqdm(candidate_data, desc=f"Evaluating {model_name}")):
        question = candidate_item.get("question")
        image_path = candidate_item.get("image_path")
        candidate_reasoning = candidate_item.get("model_reasoning_output")
        candidate_answer = candidate_item.get("answer")

        if not all([question, image_path, candidate_reasoning is not None, candidate_answer is not None]):
            print(f"Skipping item {idx+1} due to missing fields: Q={question is not None}, I={image_path is not None}, CR={candidate_reasoning is not None}, CA={candidate_answer is not None}")
            continue

        ref_item = find_matching_reference(question, image_path, reference_data)
        if not ref_item:
            print(f"No matching reference found for Q: {question}, I: {image_path}. Skipping.")
            continue

        reference_reasoning = ref_item.get("reasoning")
        reference_answer = ref_item.get("answer")
        if reference_reasoning is None or reference_answer is None:
            print(f"Skipping item {idx+1} due to missing reference reasoning/answer for Q: {question}, I: {image_path}")
            continue

        # Keep the judge prompt bounded: overly long fields are cut and marked with "...".
        prompt = generate_prompt(
            question,
            _truncate_text(reference_reasoning, max_input_chars),
            _truncate_text(reference_answer, max_input_chars),
            _truncate_text(candidate_reasoning, max_input_chars),
            _truncate_text(candidate_answer, max_input_chars),
        )

        reasoning_score, answer_score = get_scores_from_llm(prompt)

        if reasoning_score is not None and answer_score is not None:
            reasoning_distribution[reasoning_score] = reasoning_distribution.get(reasoning_score, 0) + 1
            answer_distribution[answer_score] = answer_distribution.get(answer_score, 0) + 1

            all_reasoning_scores.append(reasoning_score)
            all_answer_scores.append(answer_score)

            scaled_reasoning_score = scale_to_range(reasoning_score)
            scaled_answer_score = scale_to_range(answer_score)
            f1_score = calculate_f1(scaled_reasoning_score, scaled_answer_score)

            results.append({
                "image_path": image_path,
                "question": question,
                "raw_reasoning_score": reasoning_score,
                "raw_answer_score": answer_score,
                "scaled_reasoning_score": scaled_reasoning_score,
                "scaled_answer_score": scaled_answer_score,
                "f1_score": f1_score
            })
            valid_count += 1
        else:
            print(f"Skipping item {idx+1} due to None scores returned by LLM for Q: {question}")

        if report_every and (idx + 1) % report_every == 0 and valid_count > 0:
            current_avg_raw_reasoning = sum(all_reasoning_scores) / valid_count
            current_avg_raw_answer = sum(all_answer_scores) / valid_count
            current_avg_f1 = sum(r['f1_score'] for r in results) / valid_count

            print(f"\nProgress: {idx+1}/{len(candidate_data)} items processed for {model_name}")
            print(f"Current distribution - Reasoning: {reasoning_distribution}, Answer: {answer_distribution}")
            print(f"Current average raw reasoning score: {current_avg_raw_reasoning:.2f}/5")
            print(f"Current average raw answer score: {current_avg_raw_answer:.2f}/5")
            print(f"Current average F1 score: {current_avg_f1:.4f}")

    if valid_count == 0:
        print(f"No valid items were processed for {model_name}. Skipping final report for this model.")
        return None

    avg_raw_reasoning = sum(all_reasoning_scores) / valid_count
    avg_raw_answer = sum(all_answer_scores) / valid_count

    avg_scaled_reasoning = sum(scale_to_range(score) for score in all_reasoning_scores) / valid_count
    avg_scaled_answer = sum(scale_to_range(score) for score in all_answer_scores) / valid_count

    avg_f1 = sum(r['f1_score'] for r in results) / valid_count

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    final_output = {
        "model_name": model_name,
        "results": results,
        "score_distribution": {
            "reasoning": reasoning_distribution,
            "answer": answer_distribution
        },
        "averages": {
            "avg_raw_reasoning_score": avg_raw_reasoning,
            "avg_raw_answer_score": avg_raw_answer,
            "avg_scaled_reasoning_score": avg_scaled_reasoning,
            "avg_scaled_answer_score": avg_scaled_answer,
            "avg_f1_score": avg_f1,
            "total_evaluated_items": valid_count
        },
        "evaluation_timestamp": timestamp
    }

    final_file = os.path.join(output_dir, f"{model_name}_final_evaluation_{timestamp}.json")
    save_json_file(final_output, final_file)

    print(f"\nEvaluation completed for {model_name} with {valid_count} valid items")
    print(f"Final score distribution - Reasoning: {reasoning_distribution}, Answer: {answer_distribution}")
    print(f"Final average raw reasoning score: {avg_raw_reasoning:.2f}/5")
    print(f"Final average raw answer score: {avg_raw_answer:.2f}/5")
    print(f"Final average scaled reasoning score: {avg_scaled_reasoning:.4f}")
    print(f"Final average scaled answer score: {avg_scaled_answer:.4f}")
    print(f"Final average F1 score: {avg_f1:.4f}")
    print(f"Final results saved to {final_file}")

    return {
        "model_name": model_name,
        "avg_f1_score": avg_f1,
        "avg_raw_reasoning_score": avg_raw_reasoning,
        "avg_raw_answer_score": avg_raw_answer,
        "total_evaluated_items": valid_count,
        "evaluation_timestamp": timestamp
    }

def _print_comparison_table(summary_results):
    """Print per-model summaries as a fixed-width table, highest average F1 first."""
    print("\n" + "=" * 80)
    print("MODEL COMPARISON SUMMARY")
    print("=" * 80)
    header = f"{'Model Name':<40} {'F1 Score':<10} {'Reasoning':<12} {'Answer':<12} {'Items':<10}"
    print(header)
    print("-" * len(header))

    ranked = sorted(summary_results, key=lambda row: row.get('avg_f1_score', 0), reverse=True)
    for row in ranked:
        name_cell = row.get('model_name', 'N/A')[:38]  # truncate long names to fit the column
        f1_cell = f"{row.get('avg_f1_score', 0):.4f}"
        reasoning_cell = f"{row.get('avg_raw_reasoning_score', 0):.2f}/5"
        answer_cell = f"{row.get('avg_raw_answer_score', 0):.2f}/5"
        items_cell = str(row.get('total_evaluated_items', 0))
        print(f"{name_cell:<40} {f1_cell:<10} {reasoning_cell:<12} {answer_cell:<12} {items_cell:<10}")
    print("=" * len(header))


def main():
    """
    Evaluate every selected prediction file and write a combined summary.

    Loads the reference file and selects either ``--specific_model`` or every
    ``*_200predictions.json`` file in ``--model_dir``. Each file is evaluated, the
    per-model summaries are saved to
    ``GRPO_evaluation_summary_<timestamp>.json`` in ``--output_dir``, and a
    comparison table is printed. Fatal setup problems are printed and end the run
    early (the function returns, it does not raise).
    """
    ref_path = args.reference_file
    print(f"Loading reference data from {ref_path}")
    try:
        reference_data = load_json_file(ref_path)
        print(f"Loaded {len(reference_data)} reference items")
    except Exception as err:
        print(f"Fatal error: Could not load reference file {ref_path}: {err}")
        return

    pattern = "*_200predictions.json"
    if not args.specific_model:
        model_files = sorted(glob.glob(os.path.join(args.model_dir, pattern)))
    else:
        only_file = os.path.join(args.model_dir, args.specific_model)
        if not os.path.exists(only_file):
            print(f"Fatal error: Specified model file does not exist: {only_file}")
            return
        model_files = [only_file]

    if not model_files:
        print(f"No model prediction files found in {args.model_dir} matching pattern '{pattern}'.")
        if args.specific_model:
            print(f"Or the specific model file {args.specific_model} was not found.")
        return

    print(f"Found {len(model_files)} model files to evaluate: {model_files}")

    summary_results = []
    for path in model_files:
        print(f"Processing model file: {path}")
        summary = evaluate_model_file(path, reference_data, args.output_dir, args.print_interval)
        if not summary:
            print(f"Skipped adding summary for {path} as evaluation returned None (likely due to errors or no valid items).")
            continue
        summary_results.append(summary)

    if not summary_results:
        print("\nNo models were successfully evaluated. Skipping summary generation.")
        return

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    summary_file = os.path.join(args.output_dir, f"GRPO_evaluation_summary_{timestamp}.json")
    save_json_file(summary_results, summary_file)
    print(f"\nSummary of all evaluations saved to {summary_file}")

    _print_comparison_table(summary_results)


if __name__ == "__main__":
    main()