import json
import os
import re
from datetime import datetime

import torch
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor

# Load the LLaVA-NeXT-Video model and its processor once, at import time.
# Both module globals are used by call_llava_next_with_video().
# The checkpoint id is a blank placeholder and must be filled in before running.
_checkpoint = ""
model = LlavaNextVideoForConditionalGeneration.from_pretrained(
    _checkpoint,
    device_map="auto",
    torch_dtype=torch.float16,
)
processor = LlavaNextVideoProcessor.from_pretrained(_checkpoint)

# Build evaluation prompt
def build_prompt(action):
    return (
        "Given a video and an action description, reply with one of the following options ONLY:\n"
        "- 'yes' if the action is completed,\n"
        "- 'no' if the action is not completed,\n"
        "- 'Not exists' if the action in the video does not match the given action.\n\n"
        f"Action: {action}"
    )

# Call LLaVA-NeXT-Video with direct video input
def call_llava_next_with_video(prompt, video_path, num_frames=8, max_new_tokens=20):
    """Ask the module-level LLaVA-NeXT-Video model about a single video.

    Args:
        prompt: Text instruction sent alongside the video.
        video_path: Path to the video file; it is passed to the processor's
            chat template.
        num_frames: Number of frames the processor samples from the video.
        max_new_tokens: Generation budget for the answer.

    Returns:
        The model's answer, stripped and lower-cased. On failure, a string
        starting with "Error:" (either a missing file or an exception raised
        during processing/inference).
    """
    if not os.path.exists(video_path):
        print(f"❌ Video not found: {video_path}")
        return "Error: Video not found"

    try:
        # Prepare message with video input
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "video", "path": video_path},
                ],
            }
        ]

        # Process input
        inputs = processor.apply_chat_template(
            conversation,
            num_frames=num_frames,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        ).to(model.device)

        # Inference
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Decode only the newly generated tokens. Decoding the full sequence
        # would also return the prompt, which itself contains "yes", "no" and
        # "not exists". If the role marker were then not found, the caller's
        # label parsing would match the prompt instead of the answer.
        prompt_len = inputs["input_ids"].shape[1]
        response = processor.batch_decode(
            output[:, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0].strip().lower()

        # Defensive: drop an echoed role marker if one is still present.
        if "assistant:" in response:
            response = response.rsplit("assistant:", 1)[1].strip()
        return response

    except Exception as e:
        print(f"❌ Error while processing video {video_path}: {e}")
        return f"Error: Video processing failed - {str(e)}"

# Get list of already processed videos
def get_processed_videos(output_json_path):
    """Return the set of video paths already recorded in the results file.

    Args:
        output_json_path: Path to the JSON results file (a JSON array of
            result dicts, each with a "video" key).

    Returns:
        The set of "video" values found in the file. Returns an empty set if
        the file does not exist or cannot be parsed (a warning is printed in
        the latter case). Entries that are not dicts or lack a "video" key are
        ignored instead of raising.
    """
    if not os.path.exists(output_json_path):
        return set()

    try:
        with open(output_json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        # Removed between the exists() check and open().
        return set()
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"⚠️ Could not parse {output_json_path} ({e}); treating it as empty")
        return set()

    if not isinstance(data, list):
        print(f"⚠️ {output_json_path} does not contain a JSON list; treating it as empty")
        return set()

    return {item["video"] for item in data if isinstance(item, dict) and "video" in item}

# Save result to JSON file
def save_result_json_list(output_path, result_entry):
    """Append one result record to the JSON array stored at output_path.

    Stores a copy of result_entry with an added "timestamp" field
    ("%Y-%m-%d %H:%M:%S", local time). The caller's dict is not modified.

    The whole array is written to "<output_path>.tmp" and then moved over
    output_path with os.replace(). A crash or serialization error mid-write
    therefore leaves the previous results file intact, instead of a truncated
    one.

    If the existing file is not a valid JSON list, it is renamed to
    "<output_path>.corrupt-<timestamp>" rather than silently overwritten, and
    a new list is started.
    """
    entry = dict(result_entry)
    entry["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    data = []
    if os.path.exists(output_path):
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            data = None
        if not isinstance(data, list):
            # Keep the unreadable file around so earlier results can be recovered.
            backup_path = f"{output_path}.corrupt-{datetime.now():%Y%m%d-%H%M%S}"
            os.replace(output_path, backup_path)
            print(f"⚠️ {output_path} is not a valid JSON list; moved it to {backup_path}")
            data = []

    data.append(entry)

    tmp_path = f"{output_path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        os.replace(tmp_path, output_path)
    except BaseException:
        # The previous results file is untouched; just drop the partial temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Word-boundary patterns used to map free-form model output to a label.
# Plain substring checks misfire: "no" occurs inside "know", "cannot" and
# "nothing". Also, "not exist" (without the trailing "s") would be read as "no".
_NOT_EXISTS_PATTERN = re.compile(r"\bnot\s+exists?\b")
_YES_PATTERN = re.compile(r"\byes\b")
_NO_PATTERN = re.compile(r"\b(?:no|not)\b")


def _parse_prediction(prediction):
    """Map a raw model answer to 'yes', 'no', 'Not exists' or 'Unknown'."""
    text = prediction.lower()
    if _NOT_EXISTS_PATTERN.search(text):
        return "Not exists"
    yes_match = _YES_PATTERN.search(text)
    no_match = _NO_PATTERN.search(text)
    if yes_match and no_match:
        # Both words present: trust whichever the model said first.
        return "yes" if yes_match.start() < no_match.start() else "no"
    if yes_match:
        return "yes"
    if no_match:
        return "no"
    return "Unknown"


def _report_overall_accuracy(output_json_path, skipped):
    """Print accuracy over every record in the results file (all sessions)."""
    if not os.path.exists(output_json_path):
        return
    try:
        with open(output_json_path, 'r', encoding='utf-8') as f:
            results_data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"⚠️ Could not read {output_json_path} to compute overall accuracy: {e}")
        return
    if not isinstance(results_data, list):
        print(f"⚠️ {output_json_path} does not contain a JSON list; skipping overall accuracy")
        return

    all_total = len(results_data)
    all_correct = sum(
        1 for item in results_data if isinstance(item, dict) and item.get("correct", False)
    )
    if all_total > 0:
        overall_accuracy = all_correct / all_total
        print(f"✅ Overall accuracy: {overall_accuracy * 100:.2f}% ({all_correct}/{all_total})")
        print(f"ℹ️ Skipped {skipped} previously processed videos")


# Main evaluation function
def evaluate(json_path, output_json_path, num_frames=8):
    """Run the model over every item in json_path and append results to output_json_path.

    Args:
        json_path: JSON file holding a list of dicts with "video", "action"
            and "if_finish" (ground-truth label) keys.
        output_json_path: Results file. Records are appended one at a time,
            so an interrupted run can resume where it stopped.
        num_frames: Number of frames the processor samples per video.

    Behaviour:
        Items whose video path is already recorded in output_json_path are
        skipped. Items whose video file is missing, or whose inference fails,
        are skipped without writing a record. Prints the accuracy for this
        session and the accuracy over all recorded results.
    """
    # Load list of already processed videos
    processed_videos = get_processed_videos(output_json_path)
    print(f"ℹ️ Found {len(processed_videos)} previously processed videos")

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    n_items = len(data)
    total = 0
    correct = 0
    skipped = 0

    for idx, item in enumerate(data, start=1):
        video_path = item["video"]
        action = item["action"]
        ground_truth = item["if_finish"].strip().lower()
        video_name = os.path.basename(video_path)

        # NOTE(review): resume is keyed on the video path only. If the same
        # video appears with several actions, every pair for it is skipped once
        # any one is recorded. Confirm whether (video, action) should be the key.
        if video_path in processed_videos:
            print(f"[{idx}/{n_items}] Skipped: {video_name} (already processed)")
            skipped += 1
            continue

        print(f"[{idx}/{n_items}] Processing: {video_name}")

        if not os.path.exists(video_path):
            print(f"⏭️ Skipping {video_path} - file not found")
            continue

        # Prepare prompt and call LLaVA-NeXT-Video
        prompt = build_prompt(action)
        prediction = call_llava_next_with_video(prompt, video_path, num_frames=num_frames)

        # call_llava_next_with_video signals failures with an "Error:" prefix.
        if prediction.startswith("Error:"):
            print(f"⏭️ Skipping {video_path} - {prediction}")
            continue

        predicted_label = _parse_prediction(prediction)
        if predicted_label == "Unknown":
            print(f"⚠️ Unrecognized output: {prediction}")

        # ground_truth is already lower-cased above.
        is_correct = predicted_label.lower() == ground_truth

        result_entry = {
            "video": video_path,
            "action": action,
            "ground_truth": ground_truth,
            "prediction": predicted_label,
            "raw_output": prediction,
            "correct": is_correct
        }
        save_result_json_list(output_json_path, result_entry)

        total += 1
        if is_correct:
            correct += 1

    # Accuracy for this session only
    if total > 0:
        current_accuracy = correct / total
        print(f"\n✅ Evaluation completed -- Accuracy: {current_accuracy * 100:.2f}% ({correct}/{total})")
    else:
        print("\n⚠️ No new videos were processed in this session.")

    # Accuracy across all sessions, read back from the results file
    _report_overall_accuracy(output_json_path, skipped)

# Entry point
def main():
    """Set the input/output paths and run the evaluation.

    json_path and output_json_path are blank placeholders and must be filled
    in before running.
    """
    json_path = ""
    output_json_path = ""
    num_frames = 8  # Adjustable if needed

    # dirname() is "" for a bare filename (or the blank placeholder), and
    # os.makedirs("") raises. Only create a directory when there is one.
    # exist_ok avoids a race with a concurrently created directory.
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("🚀 Starting evaluation with LLaVA-NeXT-Video...")
    evaluate(json_path, output_json_path, num_frames=num_frames)


if __name__ == "__main__":
    main()
