import os
import csv
import requests
import base64
import mimetypes
import pathlib
import concurrent.futures
import time
import random
import json

def _env_list(var_name, default):
    """Return the comma-separated values of environment variable *var_name*.

    Surrounding whitespace and blank items are dropped. If the variable is unset
    or holds no non-blank items, a copy of *default* is returned instead.
    """
    raw = os.environ.get(var_name, "")
    items = [item.strip() for item in raw.split(",") if item.strip()]
    return items or list(default)


# LLM service endpoints and their API keys. Both lists are indexed by the same
# request index (modulo their lengths), so keep them the same length so that each
# endpoint is always used with its own key. Real values can be supplied via the
# LLM_ENDPOINTS / LLM_API_KEYS environment variables (comma-separated) so that
# secrets do not have to be committed; the placeholders below are the defaults.
ENDPOINTS = _env_list("LLM_ENDPOINTS", ["xx", "xx"])
API_KEYS = _env_list("LLM_API_KEYS", ["xx", "xx"])

MAX_WORKERS = 4  # threads used to describe the parts of one assembly concurrently
MAX_RETRY = 3  # attempts per LLM request
RETRY_BASE_SEC = 2  # base of the exponential backoff between attempts, in seconds


# Input images (one sub-directory per assembly) and output directory for the
# per-assembly part-description JSON files. Note: the output directory is created
# at import time.
DATA_ROOT = "../data/a1.0.0_00"
JSON_OUT_ROOT = "../result/ss1_json_v2"
os.makedirs(JSON_OUT_ROOT, exist_ok=True)

def to_data_url(path: str) -> str:
    """Encode the file at *path* as a base64 ``data:`` URL.

    The MIME type is guessed from the file name's extension; when it cannot be
    determined, ``application/octet-stream`` is used.
    """
    file_path = pathlib.Path(path)
    mime_type, _encoding = mimetypes.guess_type(file_path.name)
    if not mime_type:
        mime_type = "application/octet-stream"
    payload = base64.b64encode(file_path.read_bytes()).decode("utf-8")
    return "data:" + mime_type + ";base64," + payload

def call_llm_with_retry(messages, idx, retry=MAX_RETRY, max_tokens=256):
    """Send one multimodal user message to the LLM service, retrying on failure.

    Args:
        messages: Content parts (``{"type": "text" | "image_url", ...}`` dicts)
            sent as the ``content`` of a single ``user`` message.
        idx: Caller-supplied index that selects the endpoint/API-key pair
            round-robin. Each retry advances to the next pair.
        retry: Maximum number of attempts.
        max_tokens: Completion token limit sent with the request.

    Returns:
        The stripped response text (from ``choices[0].message.content`` or
        ``data.output``), ``""`` if neither field is present, or ``"[ERROR]"``
        if every attempt failed.
    """
    body = {"messages": [{"role": "user", "content": messages}], "max_tokens": max_tokens}
    for attempt in range(retry):
        # Move to the next endpoint/key pair on each retry so that a single
        # failing service does not consume every attempt for this index.
        slot = idx + attempt
        endpoint = ENDPOINTS[slot % len(ENDPOINTS)]
        api_key = API_KEYS[slot % len(API_KEYS)]
        headers = {"Content-type": "application/json", "api-key": api_key}
        try:
            res = requests.post(endpoint, headers=headers, json=body, timeout=120)
            res.raise_for_status()
            result = res.json()
            if "choices" in result:
                content = result.get("choices", [{}])[0].get("message", {}).get("content", "")
            elif "data" in result and "output" in result["data"]:
                content = result["data"]["output"]
            else:
                content = ""
            return content.strip()
        except Exception as e:  # network, HTTP status, JSON or response-shape errors
            print(f"[API {endpoint}] call failed (attempt {attempt+1}/{retry}), error: {e}")
            # Back off only if another attempt follows; sleeping after the last
            # failure would just delay the "[ERROR]" result.
            if attempt + 1 < retry:
                wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
                print(f"Retrying after {wait_sec:.1f} seconds...")
                time.sleep(wait_sec)
    return "[ERROR]"


def generate_part_descriptions(assembly_img_path, part_img_paths):
    """Ask the LLM for a short noun-phrase description of every part image.

    Each part image is sent together with the assembly image. Requests run on a
    thread pool of ``MAX_WORKERS`` threads.

    Args:
        assembly_img_path: Path to the assembled-product image.
        part_img_paths: Paths to the individual part images.

    Returns:
        Dict mapping each part's file basename to the LLM's description (the
        value returned by ``call_llm_with_retry``). Keys are in the order of
        ``part_img_paths``.
    """
    if not part_img_paths:
        return {}

    # The assembly image is the same for every request: encode it once instead
    # of re-reading and base64-encoding it for each part.
    assembly_url = to_data_url(assembly_img_path)
    instruction = """You are an expert mechanical engineer. Given Image 1 (the assembly) and Image 2 (an individual part from the assembly), please generate a concise and descriptive noun phrase (not a full sentence). The phrase should briefly describe the part's main shape and any key features, in a way that clearly distinguishes it from the other parts in the assembly. Avoid generic names like "part" or "component". Be specific about the shape and any holes, slots, or functional features.
Your output should be a single noun phrase.
For example:
    - A conical mount with a forked top;
    - A cylindrical pin;
    - Two plates with each having holes;
    - A flat round disk with three small holes;
    - A rectangular bracket with two mounting slots.
"""

    def describe(idx, part_img_path):
        """Return ``(basename, description)`` for one part image."""
        messages = [
            {"type": "text", "text": "Image 1: Assembly (assembled product)"},
            {"type": "image_url", "image_url": {"url": assembly_url}},
            {"type": "text", "text": "Image 2: Individual part"},
            {"type": "image_url", "image_url": {"url": to_data_url(part_img_path)}},
            {"type": "text", "text": instruction},
        ]
        return os.path.basename(part_img_path), call_llm_with_retry(messages, idx)

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = [
            executor.submit(describe, idx, part_img_path)
            for idx, part_img_path in enumerate(part_img_paths)
        ]
        # Collect results in submission order, not completion order, so that the
        # mapping (and the JSON file and matching prompt built from it) is
        # deterministic across runs.
        return dict(future.result() for future in futures)


def infer_matching_parts(assembly_img_path, part2desc, specification, idx, max_tokens=1024):
    """Ask the LLM which part file(s) satisfy *specification*.

    The assembly image is sent with every part's description and the
    specification. The model is asked for its reasoning followed by a
    ``Final Answer:`` line.

    Args:
        assembly_img_path: Path to the assembled-product image.
        part2desc: Mapping of part filename to description.
        specification: Natural-language specification to match.
        idx: Request index forwarded to ``call_llm_with_retry`` (endpoint selection).
        max_tokens: Completion token limit. The prompt asks for chain-of-thought
            before the answer, so this must leave room for the reasoning.
            256 tokens could truncate the output before ``Final Answer:``.

    Returns:
        ``(filenames_str, cot)``: ``;``-separated candidate filenames (not yet
        validated against ``part2desc``) and the reasoning text.
    """
    desc_list = "\n".join(f"{k}: {v}" for k, v in part2desc.items())
    prompt = (
        "You are an expert mechanical engineer.\n"
        "The image above shows the complete assembly. Each part description below corresponds to a unique part image from this assembly, "
        "mapped as 'filename: description'.\n"
        "You should answer the filenames of the parts that correspond to the specification.\n"
        "Here are all the part descriptions:\n"
        f"{desc_list}\n"
        f"Specification: {specification}\n"
        "Your task:\n"
        "1. Think step by step (Chain-of-Thought) and explain how you identify the required part(s).\n"
        "2. In the last line, write 'Final Answer:' followed by only the selected part filenames (from the keys above), separated by semicolons (;), with no extra words or punctuation.\n"
        "\n"
        "Example output:\n"
        "Chain-of-Thought:\n"
        "First, I check the descriptions of all parts. Only part1.png and part2.png are described as cylindrical pins. Therefore, the required parts are part1.png and part2.png.\n"
        "Final Answer:\n"
        "part1.png;part2.png\n"
    )
    messages = [
        {"type": "text", "text": "Assembly image:"},
        {"type": "image_url", "image_url": {"url": to_data_url(assembly_img_path)}},
        {"type": "text", "text": prompt},
    ]
    full_output = call_llm_with_retry(messages, idx, max_tokens=max_tokens)
    return _split_cot_and_answer(full_output)


def _split_cot_and_answer(full_output):
    """Split raw LLM output into ``(filenames_str, cot)``.

    Uses the LAST line that starts with ``Final Answer:`` (case-insensitive;
    leading Markdown ``#``, ``*``, backticks and spaces are ignored). The answer
    is the text after the marker plus any following lines other than a
    ``Chain-of-Thought`` header; lines before the marker are the reasoning.
    Answer items are re-joined with ``;`` after dropping empty items and
    stripping surrounding spaces, ``*`` and backticks. If no marker is found
    (e.g. truncated output), the last non-empty line is returned unchanged as
    the answer.
    """
    marker = "final answer:"
    lines = [line.strip() for line in full_output.splitlines() if line.strip()]
    if not lines:
        return "", ""

    def unmark(text):
        # Tolerate Markdown decoration such as "**Final Answer:**" or "### Final Answer:".
        return text.lstrip("#*` ")

    marker_idx = None
    for i, line in enumerate(lines):
        if unmark(line).lower().startswith(marker):
            marker_idx = i  # keep scanning: the prompt puts the answer on the last line
    if marker_idx is None:
        return lines[-1], "\n".join(lines[:-1])

    # The answer may follow the marker on the same line or (as in the prompt's
    # example) on the next line(s).
    pieces = [unmark(lines[marker_idx])[len(marker):]]
    pieces += [
        line for line in lines[marker_idx + 1:]
        if not line.lower().startswith("chain-of-thought")
    ]
    names = []
    for piece in pieces:
        names.extend(name.strip(" *`") for name in piece.split(";"))
    filenames_str = ";".join(name for name in names if name)
    return filenames_str, "\n".join(lines[:marker_idx])



def _list_part_images(dir_path):
    """Return the sorted paths of ``.png`` part images in *dir_path*, excluding ``assembly.png``."""
    # Sorted so that the part order (and therefore prompts, JSON output and
    # endpoint assignment) does not depend on arbitrary os.listdir order.
    return [
        os.path.join(dir_path, fn)
        for fn in sorted(os.listdir(dir_path))
        if fn.lower().endswith(".png") and fn != "assembly.png"
    ]


def _process_row(idx, row):
    """Fill ``Matched_Filenames`` (and ``LLM_CoT`` when the LLM is queried) for one CSV row, in place.

    Also writes the part-description JSON for the row's assembly to JSON_OUT_ROOT.
    """
    # csv.DictReader yields None for missing trailing cells, hence the "or ''".
    assembly = (row["Assembly"] or "").strip()
    specification = (row.get("Specification") or "").strip()
    dir_path = os.path.join(DATA_ROOT, assembly)
    assembly_img_path = os.path.join(dir_path, "assembly.png")
    if not os.path.exists(assembly_img_path):
        print(f"[WARN] Assembly image not found: {assembly_img_path}")
        row["Matched_Filenames"] = "[NO ASSEMBLY IMAGE]"
        return
    part_img_paths = _list_part_images(dir_path)
    if not part_img_paths:
        print(f"[WARN] No part images found in {dir_path}")
        row["Matched_Filenames"] = "[NO PART IMAGES]"
        return

    # Step 1: describe every part and save the descriptions as JSON.
    part2desc = generate_part_descriptions(assembly_img_path, part_img_paths)
    json_out_path = os.path.join(JSON_OUT_ROOT, f"{assembly}.json")
    with open(json_out_path, "w", encoding="utf-8") as jf:
        json.dump(part2desc, jf, indent=2, ensure_ascii=False)

    # Step 2: LLM-based matching; keep only real part filenames, de-duplicated in answer order.
    filenames_str, cot = infer_matching_parts(assembly_img_path, part2desc, specification, idx)
    row["LLM_CoT"] = cot
    valid_names = {os.path.basename(p) for p in part_img_paths}
    candidates = (fn.strip() for fn in filenames_str.split(";"))
    matched = list(dict.fromkeys(fn for fn in candidates if fn in valid_names))
    if matched:
        row["Matched_Filenames"] = ";".join(matched)
    else:
        row["Matched_Filenames"] = "[NOT FOUND]"
        print(f"[WARN] LLM returned no valid matches for {assembly}")
    print(f"[OK] {assembly} done. Matched: {row['Matched_Filenames']}")


def main():
    """Match parts for every row of the step-4 CSV and write the results CSV.

    Reads ``../result/s4_output.csv`` and writes ``../result/ss1_output_v2.csv``
    with two extra columns: ``Matched_Filenames`` and ``LLM_CoT``.
    """
    s4_csv = "../result/s4_output.csv"
    ss1_csv = "../result/ss1_output_v2.csv"
    with open(s4_csv, "r", encoding="utf-8", newline="") as fin:
        reader = csv.DictReader(fin)
        rows = list(reader)
        # Take the header from the reader itself: it exists even when the file
        # has no data rows (reader[0] would raise IndexError).
        input_fields = list(reader.fieldnames or [])
    # Append the output columns unless they are already present (e.g. when
    # re-running on a previous output), which would duplicate header columns.
    fieldnames = input_fields + [
        col for col in ("Matched_Filenames", "LLM_CoT") if col not in input_fields
    ]

    for idx, row in enumerate(rows):
        _process_row(idx, row)

    # Write final ss1.csv (rows skipped early have no LLM_CoT; DictWriter leaves it blank).
    with open(ss1_csv, "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print("All done.")


if __name__ == "__main__":
    # Run the batch pipeline only when executed as a script, not when imported.
    main()
