import os
import csv
import json
import requests
import base64
import mimetypes
import pathlib
import time
import random
import concurrent.futures

# URLs that call_llm_with_retry POSTs to; one is chosen per call via
# `idx % len(ENDPOINTS)`.
# NOTE(review): "xx" looks like a redacted placeholder - supply real values.
ENDPOINTS = ["xx", "xx"]
# Values sent in the "api-key" request header, chosen via `idx % len(API_KEYS)`.
# presumably API_KEYS[i] is the key for ENDPOINTS[i], so both lists should
# have the same length - TODO confirm.
API_KEYS = ["xx", "xx"]

# Default number of attempts call_llm_with_retry makes per request.
MAX_RETRY = 3
# Base of the exponential backoff between attempts, in seconds
# (wait = RETRY_BASE_SEC * 2**attempt + random jitter in [0, 1)).
RETRY_BASE_SEC = 2
MAX_WORKERS = 4  # Number of concurrent threads

# Directory holding one sub-directory per assembly; each is expected to
# contain an "assembly.png".
DATA_ROOT = "../data/a1.0.0_00"
# Directory holding "<assembly>.json" files (filename -> description mapping).
JSON_OUT_ROOT = "../result/ss1_json_v2"
# CSV read by main(); rows are accessed by the columns "Assembly",
# "Specification", "LLM_CoT" and "Matched_Part_Names".
INPUT_CSV = "../result/ss1_output_v2.csv"
# CSV written by main(): the input columns plus "LLM_CoT_Corrected".
OUTPUT_CSV = "../result/ss3_output_cot_corrected.csv"

def to_data_url(path: str) -> str:
    """Return the file at *path* encoded as a ``data:`` URL.

    The MIME type is guessed from the file name; when it cannot be guessed,
    ``application/octet-stream`` is used. The file content is embedded as
    base64 text.
    """
    file_path = pathlib.Path(path)

    guessed_type, _ = mimetypes.guess_type(file_path.name)
    if not guessed_type:
        guessed_type = "application/octet-stream"

    raw_bytes = file_path.read_bytes()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:" + guessed_type + ";base64," + payload

def call_llm_with_retry(messages, idx, retry=MAX_RETRY, max_tokens=512):
    """POST one chat request to an LLM endpoint, retrying on failure.

    The endpoint and API key are selected from ``ENDPOINTS`` / ``API_KEYS``
    by ``idx`` modulo the respective list length, so every attempt of one
    call goes to the same endpoint.

    Args:
        messages: Content of the single ``user`` message placed in the
            request body.
        idx: Integer used to select the endpoint and the API key.
        retry: Maximum number of attempts.
        max_tokens: Value sent as ``max_tokens`` in the request body.

    Returns:
        The stripped response text. An empty string is returned when the
        response JSON has neither a ``choices`` key nor ``data.output``.
        ``"[ERROR]"`` is returned when every attempt failed.
    """
    # These values do not change between attempts, so build them once.
    endpoint = ENDPOINTS[idx % len(ENDPOINTS)]
    api_key = API_KEYS[idx % len(API_KEYS)]
    headers = {"Content-type": "application/json", "api-key": api_key}
    body = {"messages": [{"role": "user", "content": messages}], "max_tokens": max_tokens}

    for attempt in range(retry):
        try:
            res = requests.post(endpoint, headers=headers, json=body, timeout=180)
            res.raise_for_status()
            result = res.json()
            if "choices" in result:
                content = result.get("choices", [{}])[0].get("message", {}).get("content", "")
            elif "data" in result and "output" in result["data"]:
                content = result["data"]["output"]
            else:
                content = ""
            return content.strip()
        except Exception as e:
            # Deliberately broad: network errors, HTTP errors and malformed
            # payloads are all treated as "this attempt failed".
            print(f"[API {endpoint}] call failed (attempt {attempt+1}/{retry}), error: {e}")
            if attempt + 1 >= retry:
                # No attempt left: do not announce a retry or sleep for nothing.
                break
            wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
            print(f"Retrying after {wait_sec:.1f} seconds...")
            time.sleep(wait_sec)
    return "[ERROR]"

def infer_corrected_cot(
    assembly_img_path,
    part2desc,              # dict, filename -> description
    specification,          # string
    prev_cot,               # string, previous LLM CoT (with Final Answer)
    gt_filenames,           # string, e.g. 'part1;part2'
    idx,
    max_tokens=512
):
    """Ask the LLM to rewrite a previous chain of thought so it ends correctly.

    Builds a prompt from the part descriptions, the specification, the
    previous reasoning and the expected filenames, attaches the assembly
    image as a data URL, and sends everything through
    ``call_llm_with_retry``.

    Args:
        assembly_img_path: Path of the assembly image file to embed.
        part2desc: Mapping of part filename to its description.
        specification: Inspection specification text.
        prev_cot: Previous reasoning text, including its final answer.
        gt_filenames: Expected filenames, given to the model for reference.
        idx: Passed through to ``call_llm_with_retry`` to pick the endpoint.
        max_tokens: Completion token limit forwarded to the API call. The
            default keeps the previous fixed value of 512; callers can raise
            it when long reasoning would otherwise be cut off before the
            "Final Answer:" line.

    Returns:
        Whatever ``call_llm_with_retry`` returns (response text or
        ``"[ERROR]"``).
    """
    desc_list = "\n".join(f"{k}: {v}" for k, v in part2desc.items())
    prompt = (
        "You are an expert mechanical engineer with a sharp analytical mind. "
        "You are given the assembly image, the descriptions of all parts (each as 'filename: description'), the inspection specification, and a previous reasoning process (including its step-by-step thoughts and its Final Answer).\n"
        "Your job:\n"
        "1. Carefully read the previous reasoning step-by-step. Follow along and reproduce the steps until you encounter the first error or mistake.\n"
        "2. Once you spot the first mistake, stop following the previous reasoning and use a natural transition phrase (such as: “But, wait, let’s pause and examine this more carefully.” or “Wait, something seems off. Let’s pause and consider what we know so far.”) to point out the error and correct it.\n"
        "3. From that point on, continue the reasoning process in your own words, step-by-step, until you reach the correct answer (i.e., the filenames consistent with the correct ground-truth solution).\n"
        "4. Do not mention “previous attempt” or “ground-truth solution” explicitly. Make your reasoning sound like a student discovering and correcting their own mistake in real time.\n"
        "5. If the previous reasoning is already correct, simply reproduce the previous reasoning and the final answer as is.\n"
        "6. End your output with a “Final Answer:” line followed by the filenames (from the keys above), separated by semicolons (;), with no extra words or punctuation.\n"
        "\n"
        "Part descriptions:\n"
        f"{desc_list}\n"
        f"Specification: {specification}\n"
        "Previous Reasoning:\n"
        f"{prev_cot}\n"
        "\n"
        "The correct filenames (for your reference):\n"
        f"{gt_filenames}\n"
    )
    # Multi-part user content: a caption, the image itself, then the task text.
    messages = [
        {"type": "text", "text": "Assembly image:"},
        {"type": "image_url", "image_url": {"url": to_data_url(assembly_img_path)}},
        {"type": "text", "text": prompt}
    ]
    return call_llm_with_retry(messages, idx, max_tokens=max_tokens)

def process_row(row, idx):
    """Fill ``row["LLM_CoT_Corrected"]`` for one CSV row and return the row.

    The row is modified in place. When a prerequisite is missing, a sentinel
    is stored instead of calling the LLM:

    * ``"[NO PART2DESC JSON]"`` - ``<JSON_OUT_ROOT>/<assembly>.json`` is absent
    * ``"[NO ASSEMBLY IMAGE]"`` - ``<DATA_ROOT>/<assembly>/assembly.png`` is absent
    * ``"[NO PREV_COT OR GT]"`` - ``LLM_CoT`` or ``Matched_Part_Names`` is empty

    Args:
        row: Mapping for one CSV row; must contain the ``"Assembly"`` key.
        idx: Row index, forwarded to pick the LLM endpoint.

    Returns:
        The same ``row`` object with ``"LLM_CoT_Corrected"`` set.

    Raises:
        KeyError: If ``row`` has no ``"Assembly"`` key.
    """
    def _text(name):
        # csv.DictReader stores None for fields missing from a short row, so
        # a plain `.get(name, "")` can still yield None; normalise it to "".
        return (row.get(name) or "").strip()

    assembly = row["Assembly"].strip()
    specification = _text("Specification")
    prev_cot = _text("LLM_CoT")
    gt_filenames = _text("Matched_Part_Names")
    dir_path = os.path.join(DATA_ROOT, assembly)
    assembly_img_path = os.path.join(dir_path, "assembly.png")

    # Load the filename -> description mapping for this assembly.
    json_path = os.path.join(JSON_OUT_ROOT, f"{assembly}.json")
    if not os.path.exists(json_path):
        print(f"[WARN] part2desc json not found: {json_path}")
        row["LLM_CoT_Corrected"] = "[NO PART2DESC JSON]"
        return row
    with open(json_path, "r", encoding="utf-8") as jf:
        part2desc = json.load(jf)

    if not os.path.exists(assembly_img_path):
        print(f"[WARN] Assembly image not found: {assembly_img_path}")
        row["LLM_CoT_Corrected"] = "[NO ASSEMBLY IMAGE]"
        return row

    if not prev_cot or not gt_filenames:
        print(f"[WARN] Missing previous CoT or GT for {assembly}")
        row["LLM_CoT_Corrected"] = "[NO PREV_COT OR GT]"
        return row

    row["LLM_CoT_Corrected"] = infer_corrected_cot(
        assembly_img_path,
        part2desc,
        specification,
        prev_cot,
        gt_filenames,
        idx
    )
    print(f"[OK] {assembly} done.")
    return row

def main():
    """Process every row of ``INPUT_CSV`` concurrently and write ``OUTPUT_CSV``.

    Rows are handled by ``process_row`` in a thread pool of ``MAX_WORKERS``
    threads. Output rows keep the input order. A row whose processing raised
    gets ``"[EXCEPTION]"`` in the ``LLM_CoT_Corrected`` column. An input with
    no data rows produces a header-only output file.
    """
    with open(INPUT_CSV, "r", encoding="utf-8", newline='') as fin:
        csv_reader = csv.DictReader(fin)
        rows = list(csv_reader)
        # Take the header from the reader rather than from the first row: it
        # exists even when there are no data rows, and it never contains the
        # None key that DictReader adds to rows with surplus fields.
        # dict.fromkeys drops duplicate column names while keeping order.
        fieldnames = list(dict.fromkeys(csv_reader.fieldnames or []))
    if "LLM_CoT_Corrected" not in fieldnames:
        fieldnames.append("LLM_CoT_Corrected")

    results = [None] * len(rows)
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {executor.submit(process_row, row, idx): idx for idx, row in enumerate(rows)}
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
            try:
                row = future.result()
            except Exception as exc:
                print(f"[FATAL] Row {idx} failed with exception: {exc}")
                row = rows[idx]
                row["LLM_CoT_Corrected"] = "[EXCEPTION]"
            results[idx] = row

    with open(OUTPUT_CSV, "w", encoding="utf-8", newline='') as fout:
        # extrasaction="ignore": a row with more fields than the header
        # carries a None key; without this the writer would raise only after
        # all LLM calls have already been made, losing every result.
        writer = csv.DictWriter(fout, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(results)
    print(f"All done. Results saved to {OUTPUT_CSV}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
