import os
import json
import base64
import cv2
import numpy as np
from openai import OpenAI, BadRequestError

# Initialize the OpenAI client.
# The key is taken from the OPENAI_API_KEY environment variable when it is set,
# so it does not have to be pasted into (and committed with) this file. If the
# variable is unset, the '' placeholder below is used.
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY", ''),  # ← Your API Key (or set OPENAI_API_KEY)
)

# Assemble the classification prompt from plain strings (no template engine)
def build_prompt(action):
    """Return the instruction text that accompanies the frame image.

    Args:
        action: Free-text description of the action to be judged.

    Returns:
        A prompt asking the model to answer only 'yes', 'no' or 'Not exists',
        followed by a final ``Action: <action>`` line.

    NOTE(review): the wording says "an image", while the caller sends a grid of
    sampled video frames; kept unchanged so results stay comparable.
    """
    option_lines = [
        "Given an image and an action description, reply with one of the following options ONLY:",
        "- 'yes' if the action is completed,",
        "- 'no' if the action is not completed,",
        "- 'Not exists' if the action in the video does not match the given action.",
    ]
    header = "\n".join(option_lines)
    # A blank line separates the instructions from the action being judged.
    return header + "\n\n" + f"Action: {action}"

# Sample evenly spaced frames from a video and tile them into a single image
def extract_sampled_frames(video_path, output_jpg_path, num_frames=8, cols=4, frame_size=(320, 240)):
    """Sample ``num_frames`` evenly spaced frames and save them as one grid image.

    Frames are resized to ``frame_size`` (width, height) and tiled row by row in
    temporal order, at most ``cols`` tiles per row. With the defaults this is a
    2x4 grid of 320x240 tiles. A short last row is padded with black tiles so
    all rows have equal width.

    Args:
        video_path: Path of the input video.
        output_jpg_path: Destination path of the grid image.
        num_frames: Number of frames to sample (>= 1).
        cols: Maximum number of tiles per grid row (>= 1).
        frame_size: (width, height) passed to ``cv2.resize`` for each frame.

    Returns:
        ``output_jpg_path`` on success; ``None`` if the video is missing, reports
        fewer than ``num_frames`` frames, cannot be decoded far enough, or the
        grid image cannot be written.

    Raises:
        ValueError: if ``num_frames`` or ``cols`` is less than 1.
    """
    if num_frames < 1 or cols < 1:
        raise ValueError("num_frames and cols must be positive")

    if not os.path.exists(video_path):
        print(f"❌ Video not found: {video_path}")
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < num_frames:
            print(f"⚠️ Not enough frames in video: {video_path}")
            return None

        # With total_frames >= num_frames the step is >= 1, so the truncated
        # indices are distinct and increasing.
        wanted = set(np.linspace(0, total_frames - 1, num_frames, dtype=int).tolist())
        frames = []
        # Frames are read sequentially from the start (no seeking). grab()
        # advances to the next frame; retrieve() is called only for sampled ones.
        for i in range(total_frames):
            if not cap.grab():
                break
            if i not in wanted:
                continue
            ok, frame = cap.retrieve()
            if not ok:
                break
            frames.append(cv2.resize(frame, frame_size))
            if len(frames) == num_frames:
                break
    finally:
        # Release on every path, including the early "not enough frames" return.
        cap.release()

    if len(frames) < num_frames:
        print(f"⚠️ Could not extract enough frames: {video_path}")
        return None

    per_row = min(cols, num_frames)
    # Pad with black tiles so every row has the same width for vstack.
    padding = -len(frames) % per_row
    frames.extend(np.zeros_like(frames[0]) for _ in range(padding))
    rows = [np.hstack(frames[start:start + per_row]) for start in range(0, len(frames), per_row)]
    grid = np.vstack(rows)

    # imwrite signals failure (e.g. unwritable directory) via its return value.
    if not cv2.imwrite(output_jpg_path, grid):
        print(f"❌ Could not write frame grid: {output_jpg_path}")
        return None
    return output_jpg_path

# Send image and prompt to GPT-4o
def call_gpt4o(prompt, image_path):
    """Send ``prompt`` plus the JPEG at ``image_path`` to GPT-4o in one user turn.

    The image is inlined as a base64 ``data:image/jpeg`` URL. The request uses
    ``temperature=0`` and ``max_tokens=10`` since only a short label is expected.

    Args:
        prompt: Instruction text.
        image_path: Path of the JPEG image to attach.

    Returns:
        The reply text, stripped and lower-cased. An empty string is returned
        when the response message carries no text content.

    Raises:
        OpenAI client errors (e.g. ``BadRequestError``) propagate to the caller.
    """
    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode("utf-8")

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                            "detail": "auto"
                        }
                    }
                ]
            }
        ],
        max_tokens=10,
        temperature=0
    )

    # The SDK types message.content as optional; calling .strip() on None
    # would raise AttributeError, so treat a missing reply as "".
    content = response.choices[0].message.content
    return (content or "").strip().lower()

# Save results as a JSON list, appending entries one at a time
def save_result_json_list(output_path, result_entry):
    """Append ``result_entry`` to the JSON list stored at ``output_path``.

    The full list is rewritten on each call. It is written to
    ``<output_path>.tmp`` first and then moved over ``output_path`` with
    ``os.replace``, so an interruption during the write cannot leave a
    truncated results file behind.

    If an existing file is not valid JSON or does not hold a list, it is moved
    to ``<output_path>.corrupt`` (instead of being silently overwritten) and a
    new list is started.

    Args:
        output_path: Path of the JSON results file.
        result_entry: JSON-serializable entry to append.
    """
    data = []
    if os.path.exists(output_path):
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except json.JSONDecodeError:
            loaded = None
        if isinstance(loaded, list):
            data = loaded
        else:
            # Keep the unreadable file so previous results can be recovered by hand.
            backup_path = output_path + ".corrupt"
            print(f"⚠️ Unreadable results file moved to {backup_path}; starting a new list")
            os.replace(output_path, backup_path)

    data.append(result_entry)

    tmp_path = output_path + ".tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    os.replace(tmp_path, output_path)

# Collect the videos already recorded in the results file
def get_processed_videos(output_json_path):
    """Return the set of ``"video"`` values found in ``output_json_path``.

    A missing file, or one that is not valid JSON, yields an empty set so the
    evaluation simply starts from scratch.
    """
    if not os.path.exists(output_json_path):
        return set()

    try:
        with open(output_json_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)
    except (json.JSONDecodeError, FileNotFoundError):
        # FileNotFoundError covers the file vanishing after the exists() check.
        return set()

    return {record["video"] for record in records}

# Read the results file as a list; missing/unreadable/non-list content yields []
def _load_existing_results(output_json_path):
    """Return the result entries stored in ``output_json_path`` (``[]`` if none)."""
    try:
        with open(output_json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        return []
    return data if isinstance(data, list) else []


# Compare labels case-insensitively (GT is lower-cased, the label is 'Not exists')
def _labels_match(predicted_label, ground_truth):
    """Return True if the two labels are equal ignoring case and surrounding spaces."""
    return str(predicted_label).strip().lower() == str(ground_truth).strip().lower()


# Score a stored result entry
def _entry_is_correct(entry):
    """Recompute correctness from stored labels; fall back to the stored flag."""
    if "prediction" in entry and "ground_truth" in entry:
        return entry.get("error") is None and _labels_match(entry["prediction"], entry["ground_truth"])
    return bool(entry.get("correct", False))


# Map the model's free-text reply onto one of the expected labels
def _parse_prediction(prediction):
    """Return 'Not exists', 'yes', 'no' or 'Unknown' for a raw model reply."""
    # Match whole words, not substrings: substring tests read "unknown" or
    # "cannot tell" as "no", and "not exist" (no trailing s) as "no".
    words = "".join(ch if ch.isalnum() else " " for ch in prediction.lower()).split()
    padded = f" {' '.join(words)} "
    if " not exists " in padded or " not exist " in padded:
        return "Not exists"
    if "yes" in words:
        return "yes"
    if "no" in words:
        return "no"
    return "Unknown"


# Main evaluation function
def evaluate(json_path, output_json_path):
    """Query GPT-4o for every dataset item and report running/final accuracy.

    ``json_path`` holds a list of dicts with ``"video"``, ``"action"`` and
    ``"if_finish"`` keys. Each new result is appended to ``output_json_path``
    as soon as it is produced. Videos already recorded there are skipped, so an
    interrupted run can be resumed; accuracy covers both old and new entries.

    Raises:
        BadRequestError: for bad requests that are not content-filter rejections.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Read previous results once and derive both the skip set and the score.
    # Correctness is recomputed from stored labels so entries written with a
    # case-sensitive comparison ('Not exists' vs 'not exists') score correctly.
    existing_results = [r for r in _load_existing_results(output_json_path) if isinstance(r, dict)]
    processed_videos = {r["video"] for r in existing_results if "video" in r}
    print(f"🔄 Found {len(processed_videos)} previously processed videos")

    total = len(existing_results)
    correct = sum(1 for r in existing_results if _entry_is_correct(r))
    skipped = 0
    failed = 0

    for i, item in enumerate(data):
        video_path = item["video"]
        action = item["action"]
        ground_truth = item["if_finish"].strip().lower()

        # Skip videos that have already been processed
        if video_path in processed_videos:
            print(f"⏩ Skipping already processed video ({i+1}/{len(data)}): {video_path}")
            skipped += 1
            continue

        print(f"🎬 Processing video ({i+1}/{len(data)}): {video_path}")

        # splitext works for any extension; str.replace(".mp4", ...) leaves other
        # paths unchanged, which would point the grid output at the video itself.
        image_output_path = os.path.splitext(video_path)[0] + "_frames.jpg"
        image_path = extract_sampled_frames(video_path, image_output_path)
        if not image_path:
            print(f"⏭️ Skipping {video_path} due to extraction error")
            failed += 1
            continue

        error = None
        try:
            predicted_label = _parse_prediction(call_gpt4o(build_prompt(action), image_path))
        except BadRequestError as e:
            # Only content-filter rejections are recorded; other bad requests abort.
            if "content management policy" not in str(e) and "content_filter" not in str(e):
                raise
            print(f"⚠️ Content filter triggered: {video_path}")
            print(f"⚠️ Error details: {str(e)}")
            predicted_label = "Error-ContentFilter"
            error = "content_filter"

        is_correct = error is None and _labels_match(predicted_label, ground_truth)

        result_entry = {
            "video": video_path,
            "action": action,
            "ground_truth": ground_truth,
            "prediction": predicted_label,
            "correct": is_correct,
        }
        if error is not None:
            result_entry["error"] = error

        save_result_json_list(output_json_path, result_entry)
        # Guard against duplicate entries in the input list.
        processed_videos.add(video_path)

        total += 1
        if is_correct:
            correct += 1

        # Print current progress and accuracy
        if error is not None:
            print(f"❌ Result: Content filter error (GT: {ground_truth}) - Counted as incorrect")
        else:
            print(f"✓ Result: {predicted_label} (GT: {ground_truth}) - {'✅ Correct' if is_correct else '❌ Incorrect'}")
        current_accuracy = correct / total if total > 0 else 0
        print(f"📊 Current accuracy: {current_accuracy * 100:.2f}% ({correct}/{total})")

    accuracy = correct / total if total > 0 else 0
    print("\n✅ Evaluation completed")
    print(f"📊 Final accuracy: {accuracy * 100:.2f}% ({correct}/{total})")
    print(f"⏩ Skipped: {skipped} previously processed videos")
    if failed:
        print(f"⏭️ Frame extraction failed for {failed} videos (not counted)")

# Entry point
def main():
    """Configure the input/output paths and run the evaluation.

    Set ``json_path`` (dataset list) and ``output_json_path`` (results file)
    before running; with either left empty, a message is printed and nothing
    else happens.
    """
    json_path = ""
    output_json_path = ""

    if not json_path or not output_json_path:
        print("❌ Please set json_path and output_json_path in main() before running.")
        return

    # dirname() is "" for a bare filename and makedirs("") raises, so only
    # create a directory when there is one; exist_ok avoids a check/create race.
    output_dir = os.path.dirname(output_json_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("🚀 Starting evaluation...")
    evaluate(json_path, output_json_path)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
