import csv
import requests
import base64
import mimetypes
import pathlib
import os
import concurrent.futures
import time
import random

# API endpoints and their keys. call_llm_with_retry selects one of each by
# taking the row index modulo the list length, so entries at the same
# position are presumably meant to pair up -- TODO confirm.
ENDPOINTS = ["xx", "xx"]
API_KEYS = ["xx", "xx"]
# Directory that process_row joins with the assembly name and "assembly.png"
# to locate each assembly image.
DATA_ROOT = "../data/a1.0.0_00"

# Size of the thread pool used by main() to process rows.
MAX_WORKERS = 4
# Default number of attempts per request in call_llm_with_retry.
MAX_RETRY = 3
# Base delay in seconds for the exponential backoff between attempts.
RETRY_BASE_SEC = 2

def to_data_url(path: str) -> str:
    """Return the file at *path* encoded as a ``data:`` URL.

    The MIME type is guessed from the file name; when no type can be
    guessed, ``application/octet-stream`` is used. The file's bytes are
    embedded base64-encoded.
    """
    file_path = pathlib.Path(path)
    guessed_type, _encoding = mimetypes.guess_type(file_path.name)
    if not guessed_type:
        guessed_type = "application/octet-stream"
    raw_bytes = file_path.read_bytes()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:" + guessed_type + ";base64," + payload

def build_llm_message(assembly_img_path, descriptions):
    """Build the multimodal user content for one assembly.

    Args:
        assembly_img_path: Path of the assembly image; it is embedded as a
            data URL via ``to_data_url``.
        descriptions: Part descriptions, rendered in the prompt as a
            1-based numbered list (one per line).

    Returns:
        A list of three content parts, in order: a text label, the image,
        and the instruction prompt.
    """
    numbered_parts = '\n'.join(
        f"{number}. {text}" for number, text in enumerate(descriptions, start=1)
    )
    prompt = f"""You are an expert mechanical engineer.
Given an image of an assembled product (assembly) and a list of its part descriptions below:

Part descriptions:
{numbered_parts}

Your task:
1. Review the assembly image and the list of part descriptions.
2. Choose any two part descriptions that are most likely to have a direct physical, spatial, or functional relationship in the assembly (such as fit, mounting, alignment, or coupling).
3. Generate one English specification sentence (inspection/check item) that describes the required relationship, fit, or assembly condition between these two parts, as would appear in a manufacturing or assembly checklist.
4. Your specification should be clear, specific, and professional, mentioning both selected part descriptions explicitly.
5. Output only one specification sentence. Do not explain your reasoning.
6. Output format: The selected two part descriptions (exactly as shown above, separated by a semicolon), then a line break, then the specification sentence.

For example, given descriptions like:
  1. A cylindrical pin
  2. A flat plate with holes
Output:
A cylindrical pin;A flat plate with holes
The cylindrical pin must be fully inserted into one of the holes on the flat plate.
"""
    label_part = {"type": "text", "text": "Assembly image:"}
    image_part = {
        "type": "image_url",
        "image_url": {"url": to_data_url(assembly_img_path)},
    }
    prompt_part = {"type": "text", "text": prompt}
    return [label_part, image_part, prompt_part]

def _parse_llm_response(result):
    """Extract ``(descriptive_pair, specification)`` from a decoded response.

    Two response shapes are recognised: a ``"choices"`` list whose first
    item carries ``message.content``, or a ``"data"`` mapping with an
    ``"output"`` entry. Any other shape is treated as empty content.

    The first two non-blank lines of the content become the pair and the
    specification. With fewer than two such lines the pair is empty and the
    whole stripped content is returned as the specification.
    """
    if "choices" in result:
        content = result.get("choices", [{}])[0].get("message", {}).get("content", "")
    elif "data" in result and "output" in result["data"]:
        content = result["data"]["output"]
    else:
        content = ""
    lines = [line.strip() for line in content.split('\n') if line.strip()]
    if len(lines) >= 2:
        return lines[0], lines[1]
    return "", content.strip()


def call_llm_with_retry(messages, idx, retry=MAX_RETRY):
    """POST *messages* to an LLM endpoint, retrying with exponential backoff.

    Args:
        messages: Content parts sent as the single user message.
        idx: Row index; taken modulo the length of ``ENDPOINTS`` and
            ``API_KEYS`` to choose the endpoint and key, which spreads rows
            across the configured endpoints.
        retry: Maximum number of attempts.

    Returns:
        ``(descriptive_pair, specification)`` on success, or
        ``("[ERROR]", "[ERROR]")`` once every attempt has failed.
    """
    # Endpoint, key and payload do not change between attempts.
    endpoint = ENDPOINTS[idx % len(ENDPOINTS)]
    api_key = API_KEYS[idx % len(API_KEYS)]
    headers = {
        "Content-type": "application/json",
        "api-key": api_key,
    }
    body = {"messages": [{"role": "user", "content": messages}], "max_tokens": 256}

    for attempt in range(retry):
        try:
            res = requests.post(endpoint, headers=headers, json=body, timeout=120)
            res.raise_for_status()
            descriptive_pair, specification = _parse_llm_response(res.json())
            # Either field being exactly "[ERROR]" counts as a failed attempt.
            if "[ERROR]" in (descriptive_pair, specification):
                raise RuntimeError("API returned error flag.")
            return descriptive_pair, specification
        except Exception as e:
            # Deliberately broad: any transport, HTTP, or parsing failure is
            # treated as a failed attempt so one bad row cannot kill the run.
            print(f"[API {endpoint}] call failed (attempt {attempt+1}/{retry}), error: {e}")
            if attempt + 1 >= retry:
                # No attempt follows, so do not announce a retry or sleep.
                break
            wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
            print(f"Retrying after {wait_sec:.1f} seconds...")
            time.sleep(wait_sec)
    return "[ERROR]", "[ERROR]"

def process_row(row, idx):
    """Produce the descriptive pair and specification for one CSV row.

    Args:
        row: Mapping for one input row; ``"Assembly"`` names the assembly
            folder under ``DATA_ROOT`` and ``"Descriptive"`` holds the
            semicolon-separated part descriptions.
        idx: Row index, forwarded to ``call_llm_with_retry``.

    Returns:
        ``(row, descriptive_pair, specification)``. When the assembly image
        is missing, or fewer than two descriptions are available, both
        fields carry a bracketed marker and no API call is made.
    """
    assembly_name = row["Assembly"].strip()
    # csv.DictReader fills cells missing from a short row with None, which
    # .get()'s default does not cover; treat that like an empty cell.
    descriptive_text = row.get("Descriptive") or ""
    stripped_parts = (part.strip() for part in descriptive_text.split(";"))
    descriptions = [part for part in stripped_parts if part]
    assembly_img_path = os.path.join(DATA_ROOT, assembly_name, "assembly.png")

    if not os.path.exists(assembly_img_path):
        print(f"[WARN] Assembly image not found: {assembly_img_path}")
        return row, "[NO ASSEMBLY IMAGE]", "[NO ASSEMBLY IMAGE]"
    if len(descriptions) < 2:
        # The prompt asks the model to pick two parts, so one is not enough.
        print(f"[WARN] Not enough part descriptions for {assembly_name}")
        return row, "[NO PART DESCRIPTIONS]", "[NO PART DESCRIPTIONS]"

    messages = build_llm_message(assembly_img_path, descriptions)
    descriptive_pair, specification = call_llm_with_retry(messages, idx)
    return row, descriptive_pair, specification

def main():
    """Read the input CSV, annotate every row concurrently, write the output.

    Each row gains ``Descriptive_Pair`` and ``Specification`` columns from
    ``process_row``. Rows are written in their input order. A row whose
    worker raised is kept in the output with both columns set to
    ``"[ERROR]"`` rather than being dropped.
    """
    input_csv = "../result/s2_output.csv"
    output_csv = "../result/s3_output.csv"
    with open(input_csv, "r", encoding="utf-8", newline='') as fin:
        reader = csv.DictReader(fin)
        rows = list(reader)
        # Copy so that appending columns does not mutate the reader's list.
        fieldnames = list(reader.fieldnames or [])
    for col in ("Descriptive_Pair", "Specification"):
        if col not in fieldnames:
            fieldnames.append(col)

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Remember which row each future belongs to, so results can be
        # written back by index instead of searching the list afterwards.
        future_to_idx = {
            executor.submit(process_row, row, idx): idx
            for idx, row in enumerate(rows)
        }
        for future in concurrent.futures.as_completed(future_to_idx):
            row = rows[future_to_idx[future]]
            try:
                _, descriptive_pair, specification = future.result()
            except Exception as exc:
                print(f"[FATAL ERROR] {exc}")
                descriptive_pair = specification = "[ERROR]"
            else:
                print(f"[DONE] {row['Assembly']} : {descriptive_pair} | {specification}")
            row["Descriptive_Pair"] = descriptive_pair
            row["Specification"] = specification

    # rows was annotated in place and is still in input order.
    with open(output_csv, "w", encoding="utf-8", newline='') as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
