import os
import json
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from datetime import datetime

# Initialize model and processor.
# Both objects are created once, at module import time, and are used as
# module-level globals by call_qwen_vl_with_video().
# NOTE(review): the empty strings passed to from_pretrained() look like
# placeholders for the checkpoint name/path -- fill them in before running.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "", torch_dtype="auto", device_map="auto"
)
processor = AutoProcessor.from_pretrained("")

# Build evaluation prompt
def build_prompt(action):
    """Build the instruction prompt sent to the model for one action.

    Args:
        action: Description of the action to check; it is interpolated
            after the "Action: " label at the end of the prompt.

    Returns:
        The full prompt string: the answer-format instructions, a blank
        line, and the "Action: ..." line.
    """
    instructions = (
        "Given a video and an action description, reply with one of the following options ONLY:\n"
        "- 'yes' if the action is completed,\n"
        "- 'no' if the action is not completed,\n"
        "- 'Not exists' if the action in the video does not match the given action.\n\n"
    )
    action_line = f"Action: {action}"
    return "".join((instructions, action_line))

# Call the Qwen2.5-VL model with direct video input
def call_qwen_vl_with_video(prompt, video_path, max_pixels=360*420, fps=0.5):
    """Run the module-level Qwen2.5-VL model on one video plus a text prompt.

    Args:
        prompt: Text instruction placed after the video in the user message.
        video_path: Path of the video file; passed to the model as a
            "file://" URI.
        max_pixels: Value forwarded as the video entry's "max_pixels".
        fps: Value forwarded as the video entry's "fps" and to the processor.

    Returns:
        The first decoded generation, stripped and lower-cased. On failure a
        sentinel string is returned instead of raising:
        "Error: Video not found" when the path does not exist, or
        "Error: Video processing failed" when any exception occurs while
        preparing inputs or generating.
    """
    video_missing = not os.path.exists(video_path)
    if video_missing:
        print(f"❌ Video not found: {video_path}")
        return "Error: Video not found"

    try:
        # Single-turn conversation: the video first, then the text prompt.
        video_part = {
            "type": "video",
            "video": f"file://{video_path}",
            "max_pixels": max_pixels,
            "fps": fps,
        }
        text_part = {"type": "text", "text": prompt}
        conversation = [{"role": "user", "content": [video_part, text_part]}]

        # Turn the conversation into model-ready tensors.
        chat_text = processor.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )
        images, videos = process_vision_info(conversation)
        model_inputs = processor(
            text=[chat_text],
            images=images,
            videos=videos,
            fps=fps,
            padding=True,
            return_tensors="pt",
        )
        model_inputs = model_inputs.to(model.device)

        # Generate, then slice off the prompt tokens so only the newly
        # generated tokens are decoded.
        output_ids = model.generate(**model_inputs, max_new_tokens=20)
        new_token_ids = []
        for prompt_ids, full_ids in zip(model_inputs.input_ids, output_ids):
            new_token_ids.append(full_ids[len(prompt_ids):])

        decoded = processor.batch_decode(
            new_token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        answer = decoded[0]
        return answer.strip().lower()

    except Exception as exc:
        print(f"❌ Error while processing video {video_path}: {exc}")
        return "Error: Video processing failed"

# Get the list of already processed videos
def get_processed_videos(output_json_path):
    """Return the set of video paths already recorded in the results file.

    The results file is expected to hold a JSON list of result dicts, each
    with a "video" key. Anything that does not fit that shape is ignored
    rather than raising, in line with how unreadable JSON is already treated
    as "nothing processed yet".

    Args:
        output_json_path: Path to the JSON results file.

    Returns:
        A set of the string "video" values found in the file. The set is
        empty when the file does not exist, is not valid JSON, or does not
        contain a list.
    """
    if not os.path.exists(output_json_path):
        return set()

    try:
        with open(output_json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        return set()

    # Valid JSON is not necessarily a list of result dicts (e.g. a lone
    # object or a list containing stray values); skip whatever is malformed
    # instead of crashing with TypeError/KeyError.
    if not isinstance(data, list):
        return set()
    return {
        entry["video"]
        for entry in data
        if isinstance(entry, dict) and isinstance(entry.get("video"), str)
    }

# Save results to a JSON file
def save_result_json_list(output_path, result_entry):
    """Append one result entry to the JSON list stored at output_path.

    The whole list is re-read and re-written on every call. The new content
    is first written to a temporary sibling file and then swapped in with
    os.replace(), so an interruption or a serialization error cannot leave a
    truncated results file behind (which would otherwise be treated as
    empty on the next call and wipe the accumulated history).

    Args:
        output_path: Path of the JSON results file; created if missing.
        result_entry: Dict describing one result. It is modified in place:
            a "timestamp" key ("%Y-%m-%d %H:%M:%S", local time) is added.

    Raises:
        OSError: If the file cannot be read, written or replaced.
        TypeError: If result_entry contains values json cannot serialize.
    """
    # Add timestamp
    result_entry["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    data = []
    if os.path.exists(output_path):
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except json.JSONDecodeError:
            # Unreadable file: start a fresh list (same as before).
            data = []
        # Valid JSON that is not a list cannot be appended to.
        if not isinstance(data, list):
            data = []

    data.append(result_entry)

    # Write to a temp file in the same directory, then rename over the
    # target so readers only ever see a complete file.
    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, output_path)
    except BaseException:
        # Do not leave a partial temp file behind; the original results
        # file is still intact at this point.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Main evaluation function
def evaluate(json_path, output_json_path):
    """Run the model over every item in json_path and record the results.

    Each input item must provide "video", "action" and "if_finish" keys.
    Videos already present in the results file are skipped, so an
    interrupted run can be resumed. Every new prediction is appended to the
    results file immediately. Items whose video file is missing, or for
    which the model call returns an "Error:" sentinel, are skipped without
    being recorded.

    At the end, the accuracy of this session and the accuracy over all
    entries stored in the results file are printed.

    Args:
        json_path: Path to the input JSON list of evaluation items.
        output_json_path: Path to the JSON results file (read and appended).
    """
    # Get the list of previously processed videos
    processed_videos = get_processed_videos(output_json_path)
    print(f"ℹ️ Found {len(processed_videos)} previously processed videos")

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    total = 0
    correct = 0
    skipped = 0

    for idx, item in enumerate(data):
        video_path = item["video"]
        action = item["action"]
        ground_truth = item["if_finish"].strip().lower()

        # Skip if already processed
        if video_path in processed_videos:
            print(f"[{idx+1}/{len(data)}] Skipped: {os.path.basename(video_path)} (already processed)")
            skipped += 1
            continue

        print(f"[{idx+1}/{len(data)}] Processing: {os.path.basename(video_path)}")

        if not os.path.exists(video_path):
            print(f"⏭️ Skipping {video_path} - file not found")
            continue

        # Prepare prompt and call Qwen2.5-VL
        prompt = build_prompt(action)
        prediction = call_qwen_vl_with_video(prompt, video_path)

        # Check for errors
        if prediction.startswith("Error:"):
            print(f"⏭️ Skipping {video_path} - {prediction}")
            continue

        # Parse prediction result
        predicted_label = _parse_prediction(prediction)
        if predicted_label == "Unknown":
            print(f"⚠️ Unrecognized output: {prediction}")

        is_correct = predicted_label.lower() == ground_truth

        # Save result
        result_entry = {
            "video": video_path,
            "action": action,
            "ground_truth": ground_truth,
            "prediction": predicted_label,
            "raw_output": prediction,
            "correct": is_correct
        }

        save_result_json_list(output_json_path, result_entry)
        # Remember the video within this session too, so a path that
        # appears twice in the input is not evaluated and stored twice.
        processed_videos.add(video_path)

        total += 1
        if is_correct:
            correct += 1

    # Calculate current session accuracy
    if total > 0:
        current_accuracy = correct / total
        print(f"\n✅ Evaluation complete -- Accuracy: {current_accuracy * 100:.2f}% ({correct}/{total})")
    else:
        print("\n⚠️ No new videos were processed in this session.")

    # Calculate overall accuracy from the output file. Numerator and
    # denominator both come from the same list of stored entries, so the
    # ratio cannot exceed 100% even if the file holds duplicate videos.
    try:
        with open(output_json_path, 'r', encoding='utf-8') as f:
            stored = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        stored = []
    if not isinstance(stored, list):
        stored = []
    results_data = [entry for entry in stored if isinstance(entry, dict)]

    all_total = len(results_data)
    if all_total > 0:
        all_correct = sum(1 for entry in results_data if entry.get("correct", False))
        overall_accuracy = all_correct / all_total
        print(f"✅ Overall evaluation -- Accuracy: {overall_accuracy * 100:.2f}% ({all_correct}/{all_total})")
        print(f"ℹ️ Skipped {skipped} already processed videos")


def _parse_prediction(raw_output):
    """Map the model's raw reply to "Not exists", "yes", "no" or "Unknown".

    Matching is done on whole words instead of substrings, so replies such
    as "unknown", "cannot tell" or "not sure" are not mistaken for "no" and
    "eyes" is not mistaken for "yes". "not exist"/"not exists" takes
    priority over "yes", which takes priority over "no".
    """
    # Turn punctuation and quotes into spaces, collapse whitespace, and pad
    # with spaces so every word can be matched as " word ".
    words = "".join(ch if ch.isalnum() else " " for ch in raw_output.lower()).split()
    padded = f" {' '.join(words)} "

    if " not exists " in padded or " not exist " in padded:
        return "Not exists"
    if " yes " in padded:
        return "yes"
    if " no " in padded:
        return "no"
    return "Unknown"

# Entry point
def main():
    """Script entry point: prepare the output directory and run evaluate().

    json_path and output_json_path are hard-coded here and must be set to
    the input dataset file and the results file before running.
    """
    json_path = ""
    output_json_path = ""

    # os.path.dirname() returns "" for a bare filename (or an empty path),
    # and os.makedirs("") raises FileNotFoundError, so only create the
    # directory when there is one. exist_ok=True makes the call safe when
    # the directory already exists.
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("🚀 Starting evaluation using Qwen2.5-VL...")
    evaluate(json_path, output_json_path)

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
