import concurrent.futures
import csv
import datetime
import json
import os
import threading
from typing import Dict, List, Optional, Tuple

from openai import OpenAI
from tqdm import tqdm

# OpenAI-compatible client pointed at the Volcengine Ark endpoint.
# The key is read from the ARK_API_KEY environment variable so it need not be
# committed to source; the original placeholder is kept as a fallback.
client = OpenAI(
    base_url="https://ark.cn-beijing.volces.com/api/v3",
    api_key=os.environ.get("ARK_API_KEY", "your_api_key"),
)

print("---begin inference---")

def ask_problem(system_prompt: str, caption: str, examples: List[Tuple[str, str]] = None, idx: int = 0) -> str:
    messages = [
        {"role": "system", "content": system_prompt}
    ]
    
    if examples:
        for example in examples:
            if isinstance(example, tuple) and len(example) == 2:
                messages.extend([
                    {"role": "user", "content": example[0]},
                    {"role": "assistant", "content": example[1]}
                ])
    
    question = f"Extract the visual objects that are contained in the caption '{caption}', return in list separated by comma, don't contain explanation"
    messages.append({"role": "user", "content": question})
    
    try:
        response = client.chat.completions.create(
            model="deepseek-v3-250324",
            messages=messages,
            stream=False
        )    
        return response.choices[0].message.content
    except Exception as e:
        print(f"Error occurs when calling API: {str(e)}")
        if not os.path.exists('./error_captions.csv'):
            with open('./error_captions.csv', 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(['Timestamp', 'Caption', 'Error', 'Index'])
        with open('./error_captions.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([datetime.datetime.now(), caption, str(e), idx])
        return None

def process_single_item(item: Tuple[int, Tuple[str, str, str]], examples: List[Tuple[str, str]], processed_ids: set) -> Dict:
    """Query the model for one caption and package the reply as an output record.

    Args:
        item: ``(idx, (caption, image_path, image_id))``.
        examples: few-shot (user, assistant) pairs forwarded to ``ask_problem``.
        processed_ids: image ids already present in the output; matching items
            are skipped.

    Returns:
        A dict with keys ``image_path``, ``image_id``, ``caption`` and
        ``concepts`` (the raw model reply). Returns ``None`` when the item is
        skipped or the model produced no truthy reply.
    """
    position, payload = item
    text, path, img_id = payload

    # NOTE(review): skipping is keyed on image_id alone. If one image_id
    # appears with several captions in the input, a resumed run also skips its
    # not-yet-processed captions - confirm ids are unique per item.
    if img_id in processed_ids:
        print("duplicate!")
        return None

    concepts = ask_problem("You are a helpful assistant", text, examples=examples, idx=position)
    if not concepts:
        return None
    return dict(image_path=path, image_id=img_id, caption=text, concepts=concepts)

example = [("Extract the visual objects that are contained in the caption  'A blond woman is on the street hailing a taxi', each entity is one or few words, return in list separated by comma, don't contain explanation", "hair,woman,street,arm,taxi")]

processed_ids = set()

if os.path.exists("./coco_train_captions_entities.jsonl"):
    with open("./coco_train_captions_entities.jsonl", "r") as f:
        for line in f:
            data = json.loads(line)
            processed_ids.add(data["image_id"])

with open("./coco_train_captions.jsonl", "r") as f:
    data = json.load(f)
    captions = data['captions']
    image_path = data['image_path']
    image_id = data['image_id']

    items = list(enumerate(zip(captions, image_path, image_id)))
    threads = 128
    with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
        futures = [
            executor.submit(process_single_item, item, example, processed_ids)
            for item in items
        ]
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
            result = future.result()
            if result:
                with open("./coco_train_captions_entities.jsonl", "a") as f:
                    f.write(json.dumps(result) + "\n")
