import os
import json
import time
import base64
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Any, Tuple
from tqdm import tqdm
from openai import OpenAI

# Shared API client used by call_gpt_on_item() for every request.
# NOTE(review): base_url and api_key are the literal string "XXX" here, which
# looks like a redacted placeholder — supply real values before running.
client = OpenAI(
    base_url="XXX",
    api_key="XXX"
)

# Model identifier passed to the chat-completions call; also embedded in the
# output filename and the progress-bar label.
MODEL_NAME = "gpt-4o"
# max_workers for the ThreadPoolExecutor created in main().
CONCURRENCY = 16
# Annotation file loaded by main(); each entry's "image" field is read.
INPUT_JSON   = "ChartP_Bench_Compressed/chartp_annotations.json"
# Directory that each entry's "image" value is joined onto.
IMAGE_FOLDER = "ChartP_Bench_Compressed"
# File the predictions are written to.
# NOTE(review): the leading "/" makes this an absolute path under the
# filesystem root — confirm that is intended rather than "Results/...".
OUTPUT_JSON  = f"/Results/chartp_{MODEL_NAME}.json"
# main() rewrites OUTPUT_JSON each time this many more items have completed.
CHECKPOINT_EVERY = 100

# Instruction text sent as the text part of every request, next to the chart
# image. It asks for a single JSON object with a "title" string and a nested
# "values" mapping; parse_to_unified() reads those same two keys.
PROMPT = """You are given a chart image. Convert its data into a strict JSON with this schema:

{
  "title": "<string or N/A>",
  "values": {
    "<series_or_None>": {
      "<x_label_or_category>": "<numeric_or_text_value_as_string>",
      ...
    },
    ...
  }
}

Rules:
- If the chart has no legend/series, use the single key "None" inside "values".
- Keep numbers as strings (e.g., "12.3").
- Include all visible categories/years/labels on X-axis as keys.
- Do NOT add explanations or code. Output ONLY the JSON object.
"""

def encode_image(image_path: str) -> str:
    """Return the contents of the file at *image_path* as base64 text.

    The file is read in binary mode and the base64 bytes are decoded to
    ``str`` as UTF-8. Errors from ``open`` propagate to the caller.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def _strip_code_fences(text: str) -> str:
    """Remove a surrounding Markdown code fence from *text*, if present.

    When the trimmed text starts with three backticks, everything up to and
    including the first newline (the fence line, with any language tag) is
    dropped, and a trailing three-backtick fence is removed. If there is no
    newline, only the trailing fence is removed. The result is always
    whitespace-trimmed.
    """
    fence = "```"
    body = text.strip()
    if not body.startswith(fence):
        return body

    # partition() yields an empty separator when no newline exists, in which
    # case the opening fence line cannot be cut off.
    _, newline, remainder = body.partition("\n")
    if newline:
        body = remainder
    if body.endswith(fence):
        body = body[:-len(fence)]
    return body.strip()

def _extract_json_object(text: str) -> str | None:
    start = text.find("{")
    if start == -1:
        return None
    depth = 0
    for i in range(start, len(text)):
        c = text[i]
        if c == "{":
            depth += 1
        elif c == "}":
            depth -= 1
            if depth == 0:
                return text[start:i+1]
    return None

def parse_to_unified(content: str) -> dict:
    """Normalise a raw model reply into ``{"title": str, "values": dict}``.

    The reply is stripped of Markdown code fences, the first balanced JSON
    object is extracted (falling back to the whole cleaned text), and the
    result is parsed with ``json.loads``.

    Returns ``{"title": "N/A", "values": {}}`` when *content* is not a
    string, is not valid JSON, or parses to something other than a JSON
    object (a list, number, string or null). Otherwise:

    * a non-string ``title`` becomes ``"N/A"``;
    * a non-dict ``values`` is treated as empty;
    * empty ``values`` becomes ``{"None": {}}``;
    * flat ``values`` (no entry is itself a dict) is wrapped as
      ``{"None": values}``.
    """
    if not isinstance(content, str):
        return {"title": "N/A", "values": {}}

    cleaned = _strip_code_fences(content)
    json_str = _extract_json_object(cleaned) or cleaned
    try:
        data = json.loads(json_str)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError covers
        # pathologically nested input.
        return {"title": "N/A", "values": {}}

    # The fallback to `cleaned` can yield valid JSON that is not an object
    # (e.g. "[]" or "null"); those have no .get and carry no usable data.
    if not isinstance(data, dict):
        return {"title": "N/A", "values": {}}

    title = data.get("title", "N/A")
    if not isinstance(title, str):
        title = "N/A"

    values = data.get("values", {})
    if not isinstance(values, dict):
        values = {}

    if not values:
        values = {"None": {}}
    elif all(not isinstance(v, dict) for v in values.values()):
        # Single-series chart returned without the series level.
        values = {"None": values}
    return {"title": title, "values": values}

def call_gpt_on_item(idx: int, item: Dict[str, Any], max_retries: int = 2) -> Tuple[int, Dict[str, Any]]:
    """Request a chart transcription for one dataset item.

    Args:
        idx: Position of *item* in the dataset; returned unchanged so the
            caller can place the result.
        item: Dataset entry; its ``"image"`` value is joined onto
            ``IMAGE_FOLDER`` to locate the image file.
        max_retries: Total number of API attempts. Values below 1 make no
            attempt and yield the placeholder prediction.

    Returns:
        ``(idx, {"image": <item["image"]>, "pred": <prediction>})``. The
        prediction is the output of ``parse_to_unified`` on success, or
        ``{"title": "N/A", "values": {}}`` when the image cannot be read or
        every attempt fails.

    Raises:
        KeyError: if *item* has no ``"image"`` key.
    """
    image_rel = item["image"]
    image_path = os.path.join(IMAGE_FOLDER, image_rel)

    try:
        b64 = encode_image(image_path)
    except OSError as e:
        # A missing or unreadable file cannot be fixed by retrying the API
        # call, and must not take the whole run down with it.
        print(e)
        return idx, {"image": image_rel, "pred": {"title": "N/A", "values": {}}}

    # NOTE(review): the data URL always declares image/jpeg regardless of the
    # file's actual format — confirm the endpoint tolerates this for non-JPEG
    # images.
    img_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{b64}"}
    }
    text_part = {"type": "text", "text": PROMPT}

    for attempt in range(1, max_retries + 1):
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": [text_part, img_part]}],
                timeout=30
            )
            content = resp.choices[0].message.content.strip()
            pred = parse_to_unified(content)
            print(pred)
            return idx, {"image": image_rel, "pred": pred}
        except Exception as e:
            # KeyboardInterrupt is not an Exception subclass, so it still
            # propagates without a dedicated handler.
            print(e)
            if attempt < max_retries:
                # Exponential backoff capped at 10 seconds; skipped after the
                # final attempt.
                time.sleep(min(2 ** attempt, 10))

    # Reached when every attempt failed or max_retries < 1; always hand the
    # caller a well-formed tuple.
    return idx, {"image": image_rel, "pred": {"title": "N/A", "values": {}}}

def _empty_result(item: Any) -> Dict[str, Any]:
    """Build the placeholder record stored for an item with no prediction."""
    image = item.get("image") if isinstance(item, dict) else None
    return {"image": image, "pred": {"title": "N/A", "values": {}}}


def _save_results(results: list) -> None:
    """Write *results* to OUTPUT_JSON as indented UTF-8 JSON."""
    out_dir = os.path.dirname(OUTPUT_JSON)
    if out_dir:
        # os.makedirs("") raises, so skip it for a bare filename.
        os.makedirs(out_dir, exist_ok=True)
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)


def main():
    """Run every dataset item through the model and save the predictions.

    Loads INPUT_JSON, submits one ``call_gpt_on_item`` task per entry to a
    thread pool of CONCURRENCY workers, and stores each result at its
    dataset index. OUTPUT_JSON is rewritten every CHECKPOINT_EVERY completed
    items (unfinished slots are written as ``null`` in checkpoints) and once
    more at the end, where unfinished slots are replaced by placeholder
    records.

    A task that raises is logged and recorded as a placeholder instead of
    aborting the run. On Ctrl-C, queued tasks are cancelled so only requests
    already in flight are waited for before the partial results are saved.
    """
    with open(INPUT_JSON, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    results = [None] * len(dataset)
    completed = 0
    futures = {}

    try:
        with ThreadPoolExecutor(max_workers=CONCURRENCY) as ex:
            try:
                for i, item in enumerate(dataset):
                    futures[ex.submit(call_gpt_on_item, i, item)] = i
                with tqdm(total=len(futures), desc=f"{MODEL_NAME} evaluating (concurrent x{CONCURRENCY})") as pbar:
                    for fut in as_completed(futures):
                        try:
                            idx, res = fut.result()
                        except Exception as e:
                            # Keep going: one bad item should not discard
                            # the rest of the run.
                            print(e)
                            idx = futures[fut]
                            res = _empty_result(dataset[idx])
                        results[idx] = res
                        completed += 1
                        pbar.update(1)
                        if completed % CHECKPOINT_EVERY == 0:
                            _save_results(results)
            except BaseException:
                # Leaving the `with` block waits for all submitted work;
                # cancel what has not started so an interrupt or error does
                # not first run the entire remaining queue.
                for pending in futures:
                    pending.cancel()
                raise

    except KeyboardInterrupt:
        print("\n^C detected, saving partial results...")
    finally:
        for i, res in enumerate(results):
            if res is None:
                results[i] = _empty_result(dataset[i])

        _save_results(results)

        print(f"✅ Saved {len(results)} predictions to {OUTPUT_JSON}")

# Start the evaluation only when this file is executed as a script, so that
# importing the module does not trigger any API calls.
if __name__ == "__main__":
    main()
