from openai import OpenAI
import base64
import os
from typing import Dict
from PIL import Image
import json
from pathlib import Path
from tqdm import tqdm


# The API key is read from the environment instead of being embedded in the
# source file: a key committed to a repository must be considered leaked and
# should be revoked. Export OPENAI_API_KEY before running this script.
# NOTE(review): with no key available the OpenAI client is expected to raise
# at construction time — confirm against the installed `openai` version.
openai_api_key = os.environ.get('OPENAI_API_KEY')
openai_base_url = 'https://api.openai.com/v1'

client = OpenAI(api_key=openai_api_key, base_url=openai_base_url)

# Running token totals per model, filled in by call_gpt4v_on_segment:
# {model_name: {'prompt': int, 'completion': int}}
TOKENS_BLOG = {}

# -------------------- Token & Cost -------------------- #
def calculate_api_cost(tokens_blog=None):
    """Compute the cost of the token usage recorded in ``tokens_blog``.

    Args:
        tokens_blog: Mapping of model name to a usage mapping that may hold
            ``'prompt'`` and/or ``'completion'`` token counts. ``None`` (the
            default) is treated as an empty mapping.

    Returns:
        A ``(total_cost, detailed_cost)`` tuple. ``detailed_cost`` contains
        one entry per model found in ``tokens_blog``, mapping each priced
        token kind to its cost. Models that are absent from the price table
        get an empty entry and add nothing to the total.
    """
    # Prices are expressed per one million tokens (see the division below).
    cost_dict = {
        'gpt-4o': {'prompt': 5.0, 'completion': 15.0},
        'gpt-4-vision-preview': {'prompt': 5.0, 'completion': 15.0},
    }

    # Avoid a shared mutable default; None simply means "no usage recorded".
    if tokens_blog is None:
        tokens_blog = {}

    total_cost = 0
    detailed_cost = {}
    for model, usage in tokens_blog.items():
        prices = cost_dict.get(model, {})
        detailed_cost[model] = {}
        for k in ('prompt', 'completion'):
            if k in usage and k in prices:
                cost = usage[k] * prices[k] / 1_000_000
                detailed_cost[model][k] = cost
                total_cost += cost
    return total_cost, detailed_cost

# -------------------- Image Utils -------------------- #
def gif_to_base64(gif_path: str) -> str:
    """Return the raw bytes of the file at ``gif_path`` as a Base64 string.

    The file is opened in binary mode and read in full; the Base64 output is
    decoded to ``str`` using UTF-8.
    """
    with open(gif_path, "rb") as gif_file:
        raw_bytes = gif_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def get_image_resolution(image_path: str):
    """Open the image at ``image_path`` and return its ``size`` attribute.

    The image is opened through PIL inside a context manager so the file is
    closed on exit; the returned value is ``(width, height)``.
    """
    with Image.open(image_path) as opened_image:
        resolution = opened_image.size
    return resolution

# -------------------- Prompt -------------------- #
def build_prompt(full_desc: str, person_a_desc: str, person_b_desc: str,
                 start_frame: int, end_frame: int) -> str:
    """Render the annotation prompt for one GIF segment.

    The three descriptions are inserted as quoted reference context, and the
    frame bounds appear both in the prose and in the JSON skeleton that the
    model is asked to fill in.

    Returns:
        The complete prompt text, starting and ending with a newline.
    """
    prompt_text = f"""
You are given a short video segment from a two-person interaction, in the form of a GIF image.  

Your job is to describe what happens **based on the visual motion in the GIF only**, even if it contradicts or only partially reflects the textual scene description.

Context (for reference only):
- Full scene description: "{full_desc}"
- Person A's general behavior: "{person_a_desc}"
- Person B's general behavior: "{person_b_desc}"

This segment includes frames {start_frame} to {end_frame}.

If the motion appears as idle, transition, or contains no meaningful visual action, label all descriptions as "transition".

Otherwise, describe:
- What happens overall in this segment (visually observed)
- What Person A is doing (visually)
- What Person B is doing (visually)

Please respond with a **valid JSON object only**, with the following structure:

{{
  "start_frame": {start_frame},
  "end_frame": {end_frame},
  "overall_description": "...",
  "personA_description": "...",
  "personB_description": "..."
}}
Please output only a valid JSON object, without any code formatting (e.g., no triple backticks like ```json).
"""
    return prompt_text

# -------------------- GPT-4V call -------------------- #
def call_gpt4v_on_segment(
    gif_path: str,
    full_desc: str,
    person_a_desc: str,
    person_b_desc: str,
    start_frame: int,
    end_frame: int,
    model: str = "gpt-4o"
) -> Dict:
    """Ask the chat model to describe one GIF segment and record token usage.

    The GIF at ``gif_path`` is Base64-encoded and sent as a data URL together
    with the prompt produced by ``build_prompt``. The token counts reported
    in the response are added to the module-level ``TOKENS_BLOG`` totals for
    ``model``.

    Args:
        gif_path: Path of the GIF file to send.
        full_desc: Scene description inserted into the prompt.
        person_a_desc: Person A description inserted into the prompt.
        person_b_desc: Person B description inserted into the prompt.
        start_frame: First frame index mentioned in the prompt.
        end_frame: Last frame index mentioned in the prompt.
        model: Model name passed to the chat completions call.

    Returns:
        A dict with the raw reply text under ``"result"`` and the call's
        ``"prompt_tokens"`` and ``"completion_tokens"``.
    """
    encoded_gif = gif_to_base64(gif_path)
    prompt_text = build_prompt(full_desc, person_a_desc, person_b_desc, start_frame, end_frame)

    # NOTE(review): the image is submitted with a GIF MIME type — confirm the
    # target model actually processes animated GIFs rather than one frame.
    text_part = {"type": "text", "text": prompt_text}
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": f"data:image/gif;base64,{encoded_gif}",
            "detail": "high",
        },
    }
    user_message = {"role": "user", "content": [text_part, image_part]}

    response = client.chat.completions.create(
        model=model,
        messages=[user_message],
        max_tokens=512,
        temperature=0.3,
    )

    reply_text = response.choices[0].message.content
    usage = response.usage
    used_prompt = usage.prompt_tokens
    used_completion = usage.completion_tokens

    # Create the per-model counters on first use, then accumulate.
    totals = TOKENS_BLOG.setdefault(model, {'prompt': 0, 'completion': 0})
    totals['prompt'] += used_prompt
    totals['completion'] += used_completion

    return {
        "result": reply_text,
        "prompt_tokens": used_prompt,
        "completion_tokens": used_completion,
    }

# -------------------- Main flow: batch processing -------------------- #
def run_batch_annotation(gif_dir, segments, full_desc, person_a_desc, person_b_desc, save_path):
    """Annotate every segment GIF, save the results as JSON and print the cost.

    Args:
        gif_dir: Directory holding the segment GIFs. The GIF for the segment
            at position ``idx`` in ``segments`` is read from
            ``test_segment{idx}.gif`` inside this directory.
        segments: Iterable of ``(start_frame, end_frame)`` pairs.
        full_desc: Scene description forwarded to the prompt.
        person_a_desc: Person A description forwarded to the prompt.
        person_b_desc: Person B description forwarded to the prompt.
        save_path: File that receives the JSON list of annotations.

    A reply that cannot be parsed as JSON is replaced by a placeholder record
    whose description fields are ``"ERROR"``, so the output always has one
    entry per segment.
    """
    results = []
    for idx, (start, end) in enumerate(tqdm(segments)):
        gif_path = os.path.join(gif_dir, f"test_segment{idx}.gif")
        output = call_gpt4v_on_segment(
            gif_path,
            full_desc,
            person_a_desc,
            person_b_desc,
            start,
            end
        )
        try:
            output_json = json.loads(output['result'])
        except (json.JSONDecodeError, TypeError):
            # Only parse failures are tolerated here: malformed JSON, or a
            # reply that is not a string at all (json.loads raises TypeError).
            # A bare except would also swallow KeyboardInterrupt/SystemExit.
            output_json = {
                "start_frame": start,
                "end_frame": end,
                "overall_description": "ERROR",
                "personA_description": "ERROR",
                "personB_description": "ERROR"
            }
        results.append(output_json)

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    # TOKENS_BLOG is module-wide, so the totals cover every call made so far
    # in this process, not only the calls issued by this batch.
    cost_sum, cost_detail = calculate_api_cost(TOKENS_BLOG)
    print("\n💰 Total Cost: ${:.4f}".format(cost_sum))
    print("📊 Breakdown:")
    print(json.dumps(cost_detail, indent=2))


if __name__ == "__main__":
    # Example run over four hand-picked frame ranges of a single clip.
    # Note that `gif_path` names the directory containing the segment GIFs.
    gif_path = "./vis_segment_gif"
    save_path = "annotated_segments.json"
    segments = [(0, 5), (6, 69), (70, 165), (166, 285)]

    # Reference descriptions handed to the model as context only.
    full_desc = "both people are walking in the same direction. suddenly, an item falls to the ground from one. the other approaches it and picks it up. then, the other catches up with one and holds onto one's right arm with their left hand."
    person_a_desc = "Another person catches up with the first person and takes hold of their right arm with their left hand."
    person_b_desc = "A person walks towards an item that has fallen on the ground and picks it up, while continuing to walk in the same direction as others."

    run_batch_annotation(
        gif_dir=gif_path,
        segments=segments,
        full_desc=full_desc,
        person_a_desc=person_a_desc,
        person_b_desc=person_b_desc,
        save_path=save_path,
    )