import os
import csv
import json
import requests
import base64
import mimetypes
import pathlib
import time
import random
import concurrent.futures
from difflib import SequenceMatcher

# LLM endpoints and their API keys. call_llm_with_retry selects
# ENDPOINTS[idx % len(ENDPOINTS)] and API_KEYS[idx % len(API_KEYS)], so the
# two lists are presumably meant to be index-aligned and of equal length.
# NOTE(review): "xx" looks like a redacted placeholder -- fill in real values.
ENDPOINTS = ["xx", "xx"]
API_KEYS = ["xx", "xx"]

# Default number of attempts per LLM call.
MAX_RETRY = 3
# Base of the exponential backoff: the wait before the next attempt is
# RETRY_BASE_SEC * 2**attempt plus up to one second of random jitter.
RETRY_BASE_SEC = 2
# Size of the thread pool used by main() to process rows concurrently.
MAX_WORKERS = 4

# Directory holding one sub-directory per assembly, each with "assembly.png".
DATA_ROOT = "../data/a1.0.0_00"
# Directory holding "<assembly>.json" files whose key/value pairs are rendered
# as the "Part descriptions" section of the prompts.
JSON_OUT_ROOT = "../result/ss1_json_v2"
# CSV read by main(); rows are accessed by the columns "Assembly",
# "Specification", "Part_count", "LLM_CoT" and "Matched_Part_Names".
INPUT_CSV = "../result/ss3_output_cot_corrected.csv"
# CSV written by main(): the input columns plus the two "*_with_RAG" columns.
OUTPUT_CSV = "../result/ss4_output.csv"

def to_data_url(path: str) -> str:
    """Return the file at *path* encoded as a ``data:`` URL.

    The MIME type is guessed from the file name; when it cannot be guessed,
    ``application/octet-stream`` is used. The file content is embedded as
    base64 text.

    Raises:
        OSError: if the file cannot be read.
    """
    file_path = pathlib.Path(path)
    guessed_type, _encoding = mimetypes.guess_type(file_path.name)
    if not guessed_type:
        guessed_type = "application/octet-stream"
    payload = base64.b64encode(file_path.read_bytes())
    return "data:" + guessed_type + ";base64," + payload.decode("utf-8")

def call_llm_with_retry(messages, idx, retry=MAX_RETRY, max_tokens=512):
    """POST *messages* to one of the configured LLM endpoints, retrying on failure.

    Args:
        messages: Content of the single user message (sent as-is under
            ``"content"``).
        idx: Integer used to pick the endpoint and API key
            (``idx % len(...)``); the same endpoint is used for every attempt.
        retry: Maximum number of attempts.
        max_tokens: Value forwarded as ``max_tokens`` in the request body.

    Returns:
        The stripped response text. Two response shapes are recognised:
        ``result["choices"][0]["message"]["content"]`` and
        ``result["data"]["output"]``; any other shape, or a null content
        value, yields ``""``. Returns the sentinel ``"[ERROR]"`` when every
        attempt failed.
    """
    # Request parameters do not change between attempts, so build them once.
    endpoint = ENDPOINTS[idx % len(ENDPOINTS)]
    api_key = API_KEYS[idx % len(API_KEYS)]
    headers = {"Content-type": "application/json", "api-key": api_key}
    body = {"messages": [{"role": "user", "content": messages}], "max_tokens": max_tokens}

    for attempt in range(retry):
        try:
            res = requests.post(endpoint, headers=headers, json=body, timeout=180)
            res.raise_for_status()
            result = res.json()
            if "choices" in result:
                content = result.get("choices", [{}])[0].get("message", {}).get("content", "")
            elif "data" in result and "output" in result["data"]:
                content = result["data"]["output"]
            else:
                content = ""
            # A JSON null content is treated like a missing one instead of
            # blowing up on None.strip() and being retried as a failure.
            return (content or "").strip()
        except Exception as e:
            # Deliberately broad: network errors, HTTP errors, invalid JSON and
            # unexpected response shapes are all handled by retrying.
            print(f"[API {endpoint}] call failed (attempt {attempt+1}/{retry}), error: {e}")
            if attempt + 1 >= retry:
                # No further attempt follows, so waiting would only delay the
                # error result.
                break
            wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
            print(f"Retrying after {wait_sec:.1f} seconds...")
            time.sleep(wait_sec)
    return "[ERROR]"

def calc_similarity(a, b):
    """Return the difflib ``SequenceMatcher`` ratio of *a* compared against *b*."""
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()

def rag_topn_rows(current_row, all_rows, topn=2):
    """Pick the *topn* rows whose Specification is most similar to *current_row*'s.

    A row is a candidate only if it belongs to a different Assembly, its
    Part_count is a decimal number below 10, and both LLM_CoT and
    Matched_Part_Names are non-blank. Missing or ``None`` fields (which
    ``csv.DictReader`` produces for short rows) disqualify the row instead of
    raising.

    Args:
        current_row: Mapping for the row being processed.
        all_rows: Iterable of row mappings to search.
        topn: Maximum number of rows to return; values <= 0 yield ``[]``.

    Returns:
        Up to *topn* candidate rows, most similar first. Ties keep their
        order of appearance in *all_rows*.
    """
    if topn <= 0:
        return []

    def _text(record, key):
        # Normalise absent / None cells to "" so str methods are always safe.
        return (record.get(key) or "").strip()

    cur_spec = current_row.get("Specification") or ""
    cur_assembly = current_row.get("Assembly")

    candidates = []
    for r in all_rows:
        if r.get("Assembly") == cur_assembly:
            continue
        part_count = _text(r, "Part_count")
        # isdecimal() only accepts characters int() can parse; isdigit() also
        # accepts e.g. superscript digits, on which int() raises ValueError.
        if not part_count.isdecimal() or int(part_count) >= 10:
            continue
        if not _text(r, "LLM_CoT") or not _text(r, "Matched_Part_Names"):
            continue
        candidates.append(r)

    scored = [
        (calc_similarity(cur_spec, r.get("Specification") or ""), r)
        for r in candidates
    ]
    # Stable descending sort: equal scores stay in input order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [r for _score, r in scored[:topn]]

def get_fewshot_prompt(row, json_root, data_root, idx):
    """Render *row* as a numbered few-shot example.

    Args:
        row: Mapping with an "Assembly" key; "LLM_CoT", "Matched_Part_Names"
            and "Specification" are read if present.
        json_root: Directory searched for "<Assembly>.json" (part descriptions).
        data_root: Directory searched for "<Assembly>/assembly.png".
        idx: Zero-based example index; shown one-based in the prompt.

    Returns:
        A ``(prompt_text, image_url)`` tuple. ``image_url`` is a data URL of
        the assembly image, or the string "[Image Not Found]" when the image
        file does not exist.
    """
    name = row["Assembly"]
    image_file = os.path.join(data_root, name, "assembly.png")
    desc_file = os.path.join(json_root, f"{name}.json")

    if os.path.exists(image_file):
        image_url = to_data_url(image_file)
    else:
        image_url = "[Image Not Found]"

    # Part descriptions are optional: a missing JSON file gives an empty section.
    part_desc = {}
    if os.path.exists(desc_file):
        with open(desc_file, "r", encoding="utf-8") as handle:
            part_desc = json.load(handle)
    described = "\n".join(f"{key}: {value}" for key, value in part_desc.items())

    reasoning = row.get("LLM_CoT", "").strip()
    answer = row.get("Matched_Part_Names", "").strip()
    specification = row.get("Specification", "").strip()

    sections = [
        f"Example {idx+1}:\n",
        "Assembly image:\n",
        "[image]\n",
        f"Part descriptions:\n{described}\n",
        f"Specification:\n{specification}\n",
        f"Chain-of-Thought:\n{reasoning}\n",
        f"Final Answer:\n{answer}\n",
    ]
    return "".join(sections), image_url

def extract_final_answer(output):
    """Extract the semicolon-separated filenames from an LLM response.

    The last line starting with "Final Answer:" (case-insensitive, optionally
    wrapped in markdown emphasis such as ``**Final Answer:**``) is located.
    The answer is the text after the marker on that line, or, if that is
    empty, the next non-blank line -- the layout shown in the prompt's own
    example output. If no marker is present, the last non-blank line is used.

    Args:
        output: Raw model response; ``None`` or blank yields ``""``.

    Returns:
        The answer with all spaces removed; ``""`` when the marker is the
        final line and nothing follows it.
    """
    marker = "final answer:"

    def _clean(text):
        # Drop surrounding whitespace and markdown emphasis/backticks only;
        # other characters may be part of a filename.
        return text.strip().strip("*`").strip()

    lines = [ln.strip() for ln in (output or "").split("\n") if ln.strip()]
    if not lines:
        return ""

    # Scan backwards so the last marker wins when the response repeats it.
    for pos in range(len(lines) - 1, -1, -1):
        candidate = lines[pos].lstrip("*#`> ")
        if not candidate.lower().startswith(marker):
            continue
        answer = _clean(candidate[len(marker):])
        if not answer and pos + 1 < len(lines):
            # Marker on its own line: the filenames follow on the next line.
            answer = _clean(lines[pos + 1])
        return answer.replace(" ", "")

    return lines[-1].replace(" ", "")

def process_row_with_rag(row, idx, all_rows):
    """Query the LLM for *row* using similar rows as few-shot examples.

    Up to two example rows are selected with ``rag_topn_rows`` and rendered
    with ``get_fewshot_prompt``; they are sent as text parts, followed by the
    main question and, when the file exists, the assembly image.

    Args:
        row: Mapping for the row to process; must contain "Assembly".
            It is updated in place.
        idx: Row index, forwarded to ``call_llm_with_retry`` for endpoint
            selection.
        all_rows: All input rows, used as the pool of few-shot candidates.

    Returns:
        The same *row* object with "LLM_CoT_Corrected_with_RAG" (raw model
        output) and "Final_Answer_with_RAG" (extracted filenames) set.
    """
    rag_rows = rag_topn_rows(row, all_rows, topn=2)
    # NOTE(review): get_fewshot_prompt also returns each example's image, but
    # (as before) only the example text is sent. Confirm whether the example
    # images were meant to be attached as well.
    rag_prompts = [
        get_fewshot_prompt(rag_row, JSON_OUT_ROOT, DATA_ROOT, i)[0]
        for i, rag_row in enumerate(rag_rows)
    ]

    # main question
    assembly = row["Assembly"]
    assembly_img_path = os.path.join(DATA_ROOT, assembly, "assembly.png")
    json_path = os.path.join(JSON_OUT_ROOT, f"{assembly}.json")
    desc = {}
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as jf:
            desc = json.load(jf)
    desc_lines = "\n".join([f"{k}: {v}" for k, v in desc.items()])
    # `or ""` guards against None cells produced by csv.DictReader for short rows.
    spec = (row.get("Specification") or "").strip()
    main_prompt = (
        "Now, for the following question, use the above reasoning as reference and answer step-by-step:\n"
        "Assembly image:\n"
        "[image below]\n"
        f"Part descriptions:\n{desc_lines}\n"
        f"Specification:\n{spec}\n"
        "Your task:\n"
        "1. Think step by step (Chain-of-Thought) and explain how you identify the required part(s).\n"
        "2. In the last line, write 'Final Answer:' followed by only the selected part filenames (from the keys above), separated by semicolons (;), with no extra words or punctuation.\n"
        "\n"
        "Example output:\n"
        "Chain-of-Thought:\n"
        "First, I check the descriptions of all parts. Only part1.png and part2.png are described as cylindrical pins. Therefore, the required parts are part1.png and part2.png.\n"
        "Final Answer:\n"
        "part1.png;part2.png\n"
    )

    messages = [{"type": "text", "text": rag_p} for rag_p in rag_prompts]
    messages.append({"type": "text", "text": main_prompt})
    if os.path.exists(assembly_img_path):
        messages.append({"type": "image_url", "image_url": {"url": to_data_url(assembly_img_path)}})
    else:
        # Previously the placeholder text "[Image Not Found]" was sent as the
        # image URL, which is not a valid URL; fall back to a text-only request.
        print(f"[WARN] {assembly}: assembly image not found at {assembly_img_path}; sending text-only request.")

    full_output = call_llm_with_retry(messages, idx, max_tokens=1024)
    filenames_str = extract_final_answer(full_output)
    row["LLM_CoT_Corrected_with_RAG"] = full_output
    row["Final_Answer_with_RAG"] = filenames_str
    if full_output == "[ERROR]":
        # call_llm_with_retry returns this sentinel when all attempts failed.
        print(f"[FAIL] {assembly} with RAG: LLM call failed.")
    else:
        print(f"[OK] {assembly} with RAG done.")
    return row

def main():
    """Run the RAG pipeline over INPUT_CSV and write the result to OUTPUT_CSV.

    Every input row is processed by ``process_row_with_rag`` on a thread pool
    of MAX_WORKERS threads. Rows whose processing raises are kept and marked
    with "[EXCEPTION]" in both result columns. Output rows are written in
    input order.
    """
    with open(INPUT_CSV, "r", encoding="utf-8", newline='') as fin:
        reader = csv.DictReader(fin)
        rows = list(reader)
        # Take the header from the reader rather than from the first row, so a
        # header-only file works and overflow cells (stored under the key
        # None) never leak into the column list.
        fieldnames = list(dict.fromkeys(reader.fieldnames or []))

    if not fieldnames:
        print(f"No header found in {INPUT_CSV}; nothing to do.")
        return

    for column in ("LLM_CoT_Corrected_with_RAG", "Final_Answer_with_RAG"):
        if column not in fieldnames:
            fieldnames.append(column)

    # Make sure the destination exists before spending time on API calls.
    os.makedirs(os.path.dirname(OUTPUT_CSV) or ".", exist_ok=True)

    results = [None] * len(rows)
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {
            executor.submit(process_row_with_rag, row, idx, rows): idx
            for idx, row in enumerate(rows)
        }
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
            try:
                row = future.result()
            except Exception as exc:
                print(f"[FATAL] Row {idx} failed with exception: {exc}")
                row = rows[idx]
                row["LLM_CoT_Corrected_with_RAG"] = "[EXCEPTION]"
                row["Final_Answer_with_RAG"] = "[EXCEPTION]"
            results[idx] = row

    with open(OUTPUT_CSV, "w", encoding="utf-8", newline='') as fout:
        # extrasaction="ignore": a malformed input row with more cells than
        # the header must not abort the write after all LLM calls have
        # finished.
        writer = csv.DictWriter(fout, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(results)
    print(f"All done. Results saved to {OUTPUT_CSV}")

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
