import csv
import requests
import base64
import mimetypes
import pathlib
import os

# URL that call_llm() sends its POST requests to.
# NOTE(review): "xx" looks like a placeholder — presumably set before running.
ENDPOINT = "xx"
# Value sent in the "api-key" request header by call_llm().
# NOTE(review): "xx" looks like a placeholder — presumably set before running.
API_KEY = "xx"
# Root folder of the image data; main() looks for
# <DATA_ROOT>/<assembly>/assembly.png and <DATA_ROOT>/<assembly>/<part>.png.
DATA_ROOT = "../data/a1.0.0_00"

def to_data_url(path: str) -> str:
    """Return the contents of the file at *path* as a base64 ``data:`` URL.

    The MIME type is guessed from the file name; if no type can be guessed,
    ``application/octet-stream`` is used instead.
    """
    file_path = pathlib.Path(path)
    guessed_type, _ = mimetypes.guess_type(file_path.name)
    mime = guessed_type if guessed_type else "application/octet-stream"
    raw_bytes = file_path.read_bytes()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:" + mime + ";base64," + payload

def build_llm_message(assembly_img_path, part_img_path):
    """Build the multimodal content list for one assembly/part image pair.

    The list interleaves text labels with the two images (embedded as data
    URLs via ``to_data_url``) and ends with the instruction asking for a
    single descriptive noun phrase for the part.
    """
    instructions = """You are an expert mechanical engineer. Given Image 1 (the assembly) and Image 2 (an individual part from the assembly), please generate a concise and descriptive noun phrase (not a full sentence). The phrase should briefly describe the part's main shape and any key features, in a way that clearly distinguishes it from the other parts in the assembly. Avoid generic names like "part" or "component". Be specific about the shape and any holes, slots, or functional features.
Your output should be a single noun phrase.
For example:
    - A conical mount with a forked top;
    - A cylindrical pin;
    - Two plates with each having holes;
    - A flat round disk with three small holes;
    - A rectangular bracket with two mounting slots.
"""
    # Encode the assembly image first, then the part image.
    assembly_url = to_data_url(assembly_img_path)
    part_url = to_data_url(part_img_path)

    content = []
    content.append({"type": "text", "text": "Image 1: Assembly (assembled product)"})
    content.append({"type": "image_url", "image_url": {"url": assembly_url}})
    content.append({"type": "text", "text": "Image 2: Individual Part"})
    content.append({"type": "image_url", "image_url": {"url": part_url}})
    content.append({"type": "text", "text": instructions})
    return content

def _extract_content(result):
    """Return the generated text held in a decoded response payload.

    Two payload shapes are understood: ``choices[0].message.content`` and
    ``data.output``. Returns ``""`` when the payload is a JSON object with
    neither shape (or with a missing ``message``/``content`` key), and
    ``None`` when the payload is malformed: not an object, an empty or
    non-list ``choices``, or a text field that is not a string (e.g. null).
    """
    if not isinstance(result, dict):
        return None

    if "choices" in result:
        choices = result["choices"]
        if not isinstance(choices, list) or not choices:
            return None
        first = choices[0]
        if not isinstance(first, dict):
            return None
        message = first.get("message", {})
        if not isinstance(message, dict):
            return None
        content = message.get("content", "")
    elif isinstance(result.get("data"), dict) and "output" in result["data"]:
        content = result["data"]["output"]
    else:
        content = ""

    return content if isinstance(content, str) else None


def call_llm(messages):
    """Send *messages* to the LLM endpoint and return a one-line phrase.

    Args:
        messages: Content list for a single user message (as produced by
            ``build_llm_message``).

    Returns:
        The first non-empty line of the reply with surrounding bullet and
        punctuation characters removed. ``"[ERROR]"`` is returned (and a
        message printed) when the request fails, the body is not valid JSON,
        or the payload is malformed. An empty string is returned when the
        payload carries no text.
    """
    headers = {
        "Content-type": "application/json",
        "api-key": API_KEY,
    }
    body = {"messages": [{"role": "user", "content": messages}], "max_tokens": 100}

    # Only the network round-trip and JSON decoding are guarded, so genuine
    # programming errors are not reported as LLM failures.
    try:
        res = requests.post(ENDPOINT, headers=headers, json=body, timeout=60)
        res.raise_for_status()
        result = res.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers JSON decode failures on older requests versions.
        print(f"[ERROR] LLM call failed: {e}")
        return "[ERROR]"

    content = _extract_content(result)
    if content is None:
        print("[ERROR] LLM call failed: unexpected response payload")
        return "[ERROR]"

    # ";" is stripped as well: the prompt's examples end with ";" and the
    # caller joins phrases with ";", so a trailing one would corrupt the
    # delimited output.
    strip_chars = "-•。:; \n\t"
    for line in content.strip().split("\n"):
        phrase = line.strip(strip_chars)
        if phrase:
            return phrase
    # No line had real text; fall back to the raw trimmed reply.
    return content.strip()

def _describe_part(assembly_img_path, part_img_path):
    """Return the description for one part, or a marker if an image is missing."""
    if not os.path.exists(assembly_img_path):
        print(f"[WARN] Assembly image not found: {assembly_img_path}")
        return "[NO ASSEMBLY IMAGE]"
    if not os.path.exists(part_img_path):
        print(f"[WARN] Part image not found: {part_img_path}")
        return "[NO PART IMAGE]"
    return call_llm(build_llm_message(assembly_img_path, part_img_path))


def main():
    """Add a "Descriptive" column to the step-1 CSV and write the step-2 CSV.

    Each input row supplies an "Assembly" name and a ";"-separated
    "Part_names" list. Every part gets one phrase (or a missing-image /
    error marker), and the phrases are joined with ";" in the same order.
    """
    source_csv = "../result/s1_output.csv"
    target_csv = "../result/s2_output.csv"

    # Load every row up front so the input file is closed before writing.
    with open(source_csv, "r", encoding="utf-8", newline='') as src:
        table = csv.DictReader(src)
        records = list(table)
        columns = table.fieldnames or []

    if "Descriptive" not in columns:
        columns.append("Descriptive")

    with open(target_csv, "w", encoding="utf-8", newline='') as dst:
        out = csv.DictWriter(dst, fieldnames=columns)
        out.writeheader()
        for record in records:
            assembly = record["Assembly"].strip()
            names = [
                piece.strip()
                for piece in record["Part_names"].split(";")
                if piece.strip()
            ]
            assembly_png = os.path.join(DATA_ROOT, assembly, "assembly.png")
            phrases = [
                _describe_part(
                    assembly_png,
                    os.path.join(DATA_ROOT, assembly, f"{name}.png"),
                )
                for name in names
            ]
            record["Descriptive"] = ";".join(phrases)
            out.writerow(record)
            print(f"[DONE] {assembly} : {record['Descriptive']}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
