import os
import csv
import requests
import base64
import mimetypes
import pathlib
import concurrent.futures
import time
import random
import json

def _env_list(name, default):
    """Return the comma-separated values of env var *name*, or *default* if unset/empty."""
    raw = os.environ.get(name, "")
    values = [item.strip() for item in raw.split(",") if item.strip()]
    # Fall back to the default so the lists are never empty (callers index
    # them with ``idx % len(...)``, which would divide by zero).
    return values or list(default)


# LLM endpoint URLs and their API keys. call_llm_with_retry picks entry i of
# both lists with the same index arithmetic, so keep them the same length and
# in matching order. Set LLM_ENDPOINTS / LLM_API_KEYS (comma-separated) in the
# environment instead of committing real credentials to source.
ENDPOINTS = _env_list("LLM_ENDPOINTS", ["xx", "xx"])
API_KEYS = _env_list("LLM_API_KEYS", ["xx", "xx"])

MAX_WORKERS = 4      # thread-pool size for the per-part description requests
MAX_RETRY = 3        # maximum attempts per LLM request
RETRY_BASE_SEC = 2   # initial back-off delay in seconds; doubles each attempt


DATA_ROOT = "../data/a1.0.0_00"            # one sub-directory per assembly
JSON_OUT_ROOT = "../result/ss1_json_v2"    # per-assembly part-description JSON
os.makedirs(JSON_OUT_ROOT, exist_ok=True)

def to_data_url(path: str) -> str:
    """Return the file at *path* encoded as a base64 ``data:`` URL.

    The MIME type is guessed from the file name's extension; when it cannot
    be guessed, ``application/octet-stream`` is used.
    """
    file_path = pathlib.Path(path)
    guessed_type, _encoding = mimetypes.guess_type(file_path.name)
    mime = guessed_type if guessed_type else "application/octet-stream"
    payload = base64.b64encode(file_path.read_bytes()).decode("utf-8")
    return "data:" + mime + ";base64," + payload

def _extract_content(result):
    """Return the stripped completion text from a parsed response body.

    Recognizes OpenAI-style ``{"choices": [{"message": {"content": ...}}]}``
    and ``{"data": {"output": ...}}`` payloads. Any other shape, or a
    missing / null / non-string content field, yields ``""`` instead of
    raising.
    """
    content = None
    if isinstance(result, dict):
        if "choices" in result:
            choices = result.get("choices")
            first = choices[0] if isinstance(choices, list) and choices else None
            message = first.get("message") if isinstance(first, dict) else None
            content = message.get("content") if isinstance(message, dict) else None
        elif isinstance(result.get("data"), dict) and "output" in result["data"]:
            content = result["data"]["output"]
    return content.strip() if isinstance(content, str) else ""


def call_llm_with_retry(messages, idx, retry=MAX_RETRY, max_tokens=256):
    """Send one user turn to an LLM endpoint, retrying with exponential back-off.

    Args:
        messages: Content parts for the single user message (the callers in
            this file pass a list of ``{"type": "text" | "image_url", ...}``
            dicts).
        idx: Caller-supplied index that selects the endpoint/key pair, so
            concurrent callers are spread over ``ENDPOINTS``.
        retry: Maximum number of attempts.
        max_tokens: Completion token budget sent with the request.

    Returns:
        The stripped response text (``""`` if the response has no
        recognizable content), or the sentinel ``"[ERROR]"`` after every
        attempt has failed.
    """
    for attempt in range(retry):
        # Rotate to the next endpoint/key pair on each retry so a single
        # unhealthy endpoint cannot consume every attempt. Attempt 0 uses the
        # same pair as ``idx`` alone would select.
        slot = idx + attempt
        endpoint = ENDPOINTS[slot % len(ENDPOINTS)]
        api_key = API_KEYS[slot % len(API_KEYS)]
        headers = {"Content-type": "application/json", "api-key": api_key}
        body = {"messages": [{"role": "user", "content": messages}], "max_tokens": max_tokens}
        try:
            res = requests.post(endpoint, headers=headers, json=body, timeout=120)
            res.raise_for_status()
            result = res.json()
        except (requests.RequestException, ValueError) as e:
            # RequestException covers connection/timeout/HTTP-status errors;
            # ValueError covers a non-JSON response body.
            print(f"[API {endpoint}] call failed (attempt {attempt+1}/{retry}), error: {e}")
            if attempt + 1 < retry:
                # No point sleeping after the last attempt.
                wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
                print(f"Retrying after {wait_sec:.1f} seconds...")
                time.sleep(wait_sec)
            continue
        return _extract_content(result)
    return "[ERROR]"


def generate_part_descriptions(assembly_img_path, part_img_paths):
    """Ask the LLM for a short, distinguishing noun phrase for every part image.

    Each request carries the assembly image plus one part image. Requests run
    concurrently on a thread pool of ``MAX_WORKERS`` threads.

    Args:
        assembly_img_path: Path to the assembled-product image.
        part_img_paths: Paths to the individual part images.

    Returns:
        Dict mapping each part's basename to its description (or the
        ``"[ERROR]"`` sentinel returned by ``call_llm_with_retry``), ordered
        like ``part_img_paths``. Empty input returns ``{}`` without reading
        any image.
    """
    if not part_img_paths:
        return {}

    instruction = """You are an expert mechanical engineer. Given Image 1 (the assembly) and Image 2 (an individual part from the assembly), please generate a concise and descriptive noun phrase (not a full sentence). The phrase should briefly describe the part's main shape and any key features, in a way that clearly distinguishes it from the other parts in the assembly. Avoid generic names like "part" or "component". Be specific about the shape and any holes, slots, or functional features.
Your output should be a single noun phrase.
For example:
    - A conical mount with a forked top;
    - A cylindrical pin;
    - Two plates with each having holes;
    - A flat round disk with three small holes;
    - A rectangular bracket with two mounting slots.
"""
    # The assembly image is identical in every request: read and encode it
    # once instead of once per part.
    assembly_url = to_data_url(assembly_img_path)

    def process(idx, part_img_path):
        """Describe one part; return (part basename, description)."""
        messages = [
            {"type": "text", "text": "Image 1: Assembly (assembled product)"},
            {"type": "image_url", "image_url": {"url": assembly_url}},
            {"type": "text", "text": "Image 2: Individual part"},
            {"type": "image_url", "image_url": {"url": to_data_url(part_img_path)}},
            {"type": "text", "text": instruction},
        ]
        return os.path.basename(part_img_path), call_llm_with_retry(messages, idx)

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [
            executor.submit(process, idx, part_img_path)
            for idx, part_img_path in enumerate(part_img_paths)
        ]
        # Collect in submission order rather than completion order so the
        # returned dict (and the JSON dump / matching prompt built from it)
        # is deterministic across runs.
        results = [future.result() for future in futures]

    return dict(results)


def _parse_final_answer(full_output):
    """Split raw LLM output into ``(filenames_str, chain_of_thought)``.

    The first line that starts with "Final Answer:" (case-insensitive, with
    leading Markdown markers such as ``**`` or ``#`` ignored) separates the
    chain of thought (the lines before it) from the answer. The answer is the
    text after the colon plus every following non-empty line except a
    "Chain-of-Thought" header, joined with ";". If no such line exists, the
    last non-empty line is taken as the answer and the rest as the chain of
    thought.
    """
    marker = "final answer:"
    lines = [text.strip() for text in full_output.splitlines() if text.strip()]
    for i, line in enumerate(lines):
        head = line.lstrip("*#` ")
        if not head.lower().startswith(marker):
            continue
        answer_parts = [head[len(marker):].strip("*` ")]
        answer_parts += [
            text.strip("*` ")
            for text in lines[i + 1:]
            if not text.lower().startswith("chain-of-thought")
        ]
        # Join only non-empty pieces so an answer on the line *after* the
        # marker (as in the prompt's example) has no leading ";".
        filenames_str = ";".join(part for part in answer_parts if part)
        return filenames_str, "\n".join(lines[:i])
    if not lines:
        return "", ""
    return lines[-1], "\n".join(lines[:-1])


def infer_matching_parts(assembly_img_path, part2desc, specification, idx, max_tokens=256):
    """Ask the LLM which part files satisfy *specification*.

    Args:
        assembly_img_path: Path to the assembly image sent with the prompt.
        part2desc: Mapping of part filename -> description (as produced by
            ``generate_part_descriptions``), listed in the prompt.
        specification: Free-text specification to match against.
        idx: Index forwarded to ``call_llm_with_retry`` for endpoint choice.
        max_tokens: Completion budget. The prompt asks for a chain of thought
            before the answer, so a small budget may truncate the output
            before the "Final Answer:" line.

    Returns:
        ``(filenames_str, cot)``: the raw ";"-separated filename answer (not
        yet validated against real files) and the chain-of-thought text. If
        the LLM call failed, ``filenames_str`` is ``"[ERROR]"`` and ``cot`` is
        empty.
    """
    desc_list = "\n".join(f"{k}: {v}" for k, v in part2desc.items())
    prompt = (
        "You are an expert mechanical engineer.\n"
        "The image above shows the complete assembly. Each part description below corresponds to a unique part image from this assembly, "
        "mapped as 'filename: description'.\n"
        "You should answer the filenames of the parts that correspond to the specification.\n"
        "Here are all the part descriptions:\n"
        f"{desc_list}\n"
        f"Specification: {specification}\n"
        "Your task:\n"
        "1. Think step by step (Chain-of-Thought) and explain how you identify the required part(s).\n"
        "2. In the last line, write 'Final Answer:' followed by only the selected part filenames (from the keys above), separated by semicolons (;), with no extra words or punctuation.\n"
        "\n"
        "Example output:\n"
        "Chain-of-Thought:\n"
        "First, I check the descriptions of all parts. Only part1.png and part2.png are described as cylindrical pins. Therefore, the required parts are part1.png and part2.png.\n"
        "Final Answer:\n"
        "part1.png;part2.png\n"
    )
    messages = [
        {"type": "text", "text": "Assembly image:"},
        {"type": "image_url", "image_url": {"url": to_data_url(assembly_img_path)}},
        {"type": "text", "text": prompt},
    ]
    full_output = call_llm_with_retry(messages, idx, max_tokens=max_tokens)
    return _parse_final_answer(full_output)



def _output_fieldnames(input_fieldnames):
    """Return the input CSV columns plus the two result columns, without duplicates."""
    fieldnames = list(input_fieldnames or [])
    for extra in ("Matched_Filenames", "LLM_CoT"):
        # Re-running on a previous output must not duplicate these columns.
        if extra not in fieldnames:
            fieldnames.append(extra)
    return fieldnames


def _list_part_images(dir_path):
    """Return the sorted paths of the .png files in *dir_path* other than assembly.png."""
    # Sorted so part order (and hence LLM endpoint assignment and JSON
    # content) does not depend on the filesystem's listing order.
    return [
        os.path.join(dir_path, fn)
        for fn in sorted(os.listdir(dir_path))
        if fn.lower().endswith(".png") and fn != "assembly.png"
    ]


def _filter_matches(filenames_str, valid_names):
    """Return the LLM-proposed names that are real part files, in order, without repeats.

    Tokens are split on ";" as the prompt requests; a token that is not a
    valid name by itself is additionally split on "," to tolerate
    comma-separated answers. Surrounding quotes/backticks are ignored.
    """
    valid = set(valid_names)

    def clean(token):
        return token.strip().strip("`'\"").strip()

    matched = []
    for token in filenames_str.split(";"):
        token = clean(token)
        candidates = [token] if token in valid else [clean(t) for t in token.split(",")]
        matched.extend(name for name in candidates if name in valid)
    return list(dict.fromkeys(matched))


def _process_row(idx, row):
    """Run the describe-then-match pipeline for one CSV row; mutate and return *row*."""
    # DictReader yields None for cells missing from short rows.
    assembly = (row["Assembly"] or "").strip()
    specification = (row.get("Specification") or "").strip()
    dir_path = os.path.join(DATA_ROOT, assembly)
    assembly_img_path = os.path.join(dir_path, "assembly.png")
    if not os.path.exists(assembly_img_path):
        print(f"[WARN] Assembly image not found: {assembly_img_path}")
        row["Matched_Filenames"] = "[NO ASSEMBLY IMAGE]"
        return row
    part_img_paths = _list_part_images(dir_path)
    if not part_img_paths:
        print(f"[WARN] No part images found in {dir_path}")
        row["Matched_Filenames"] = "[NO PART IMAGES]"
        return row

    # Step 1: part descriptions, also written to JSON for inspection.
    part2desc = generate_part_descriptions(assembly_img_path, part_img_paths)
    json_out_path = os.path.join(JSON_OUT_ROOT, f"{assembly}.json")
    with open(json_out_path, "w", encoding="utf-8") as jf:
        json.dump(part2desc, jf, indent=2, ensure_ascii=False)

    # Step 2: LLM-based matching against the specification.
    filenames_str, cot = infer_matching_parts(assembly_img_path, part2desc, specification, idx)
    row["LLM_CoT"] = cot
    valid_names = [os.path.basename(p) for p in part_img_paths]
    matched = _filter_matches(filenames_str, valid_names)
    if matched:
        row["Matched_Filenames"] = ";".join(matched)
    else:
        row["Matched_Filenames"] = "[NOT FOUND]"
        print(f"[WARN] LLM returned no valid matches for {assembly}")
    print(f"[OK] {assembly} done. Matched: {row['Matched_Filenames']}")
    return row


def main():
    """Annotate every row of the s4 CSV with LLM-matched part filenames.

    For each row, parts under ``DATA_ROOT/<Assembly>/`` are described (JSON
    written to ``JSON_OUT_ROOT``) and matched against the row's
    Specification. The input columns plus ``Matched_Filenames`` and
    ``LLM_CoT`` are written to the ss1 CSV. Rows are written as they
    finish, so a crash part-way keeps the rows already processed.
    """
    s4_csv = "../result/s4_output.csv"
    ss1_csv = "../result/ss1_output_v2.csv"
    with open(s4_csv, "r", encoding="utf-8", newline="") as fin:
        reader = csv.DictReader(fin)
        rows = list(reader)
        # DictReader.fieldnames works for header-only files (and is None for
        # an empty file), unlike indexing the first data row.
        fieldnames = _output_fieldnames(reader.fieldnames)

    with open(ss1_csv, "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        for idx, row in enumerate(rows):
            writer.writerow(_process_row(idx, row))
            fout.flush()
    print("All done.")


# Run the batch pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
