import os
import csv
import json
import time
import random
import concurrent.futures
from difflib import SequenceMatcher
from typing import List, Dict, Any

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


# Checkpoint location for both the model and its processor.
# NOTE(review): presumably a local copy of Qwen2-VL-2B-Instruct; being a
# relative path, it depends on the current working directory — confirm.
MODEL_NAME = "./Qwen2-VL-2B-Instruct"

# Loaded once at import time and used (as a global) by
# qwen_generate_from_messages. torch_dtype="auto" and device_map="auto" defer
# the dtype and device-placement decisions to transformers/accelerate.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)

# Processor from the same checkpoint: used for the chat template, for encoding
# text + images/videos into model inputs, and for decoding generated ids.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# Preferred device string: CUDA when available, otherwise CPU.
# NOTE(review): computed independently of device_map="auto" above, so it is
# not guaranteed to match where the model weights were actually placed.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def to_file_uri(path: str) -> str:
    """Build a ``file://`` reference for *path* after making it absolute.

    The absolute path is appended as-is (no percent-encoding), so on POSIX a
    path like ``/data/x.png`` becomes ``file:///data/x.png``. The result is
    used as the ``image`` entry of chat message content parts.
    """
    absolute_path = os.path.abspath(path)
    return f"file://{absolute_path}"


def qwen_generate_from_messages(
    messages_content: List[Dict[str, Any]],
    max_tokens: int = 512,
) -> str:
    """Run one single-turn chat generation with the module-level Qwen2-VL model.

    Args:
        messages_content: Content parts of a single user turn, e.g.
            ``{"type": "text", "text": ...}`` and
            ``{"type": "image", "image": <file uri>}`` entries.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace; "" if nothing decoded.
    """
    chat_messages = [{"role": "user", "content": messages_content}]

    text = processor.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(chat_messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # The model was loaded with device_map="auto", so its weights are not
    # necessarily on DEVICE (e.g. CUDA unavailable but another accelerator
    # chosen). Send the inputs to the device the model itself reports.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)

    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]

    output_texts = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_texts[0].strip() if output_texts else ""



# Retry policy for call_llm_with_retry: up to MAX_RETRY attempts, waiting
# RETRY_BASE_SEC * 2**attempt seconds plus up to 1 s of random jitter between them.
MAX_RETRY = 3
RETRY_BASE_SEC = 2
# Thread-pool size used by main().
# NOTE(review): presumably kept at 1 because every worker would share the single
# module-level model instance — confirm before raising it.
MAX_WORKERS = 1

# Root of per-assembly directories; the code reads DATA_ROOT/<Assembly>/assembly.png.
DATA_ROOT = "../data/a1.0.0_00"

# Directory of per-assembly part-description files named "<Assembly>.json".
JSON_OUT_ROOT = "../result/ss1_json_qwen2vl2b"
# CSV read by main(), and the CSV main() writes with the RAG result columns added.
INPUT_CSV = "../result/ss3_output_cot_corrected_qwen2vl2b.csv"
OUTPUT_CSV = "../result/ss4_output_qwen2vl2b.csv"


def call_llm_with_retry(messages, idx, retry=MAX_RETRY, max_tokens=512):
    """Call qwen_generate_from_messages, retrying failures with exponential backoff.

    Args:
        messages: Content parts for the user turn (see qwen_generate_from_messages).
        idx: Caller's row index. Currently unused; kept for interface compatibility.
        retry: Maximum number of attempts.
        max_tokens: Forwarded to qwen_generate_from_messages.

    Returns:
        The stripped model output, or the sentinel string "[ERROR]" when every
        attempt raised (or when *retry* is not positive).
    """
    for attempt in range(retry):
        try:
            content = qwen_generate_from_messages(messages, max_tokens=max_tokens)
            return content.strip()
        except Exception as e:  # boundary: any failure is retried, then reported as "[ERROR]"
            print(f"[Qwen2-VL] call failed (attempt {attempt+1}/{retry}), error: {e}")
            if attempt + 1 >= retry:
                # No attempts left: give up immediately instead of sleeping
                # through a backoff that nothing will use.
                break
            wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
            print(f"Retrying after {wait_sec:.1f} seconds...")
            time.sleep(wait_sec)
    return "[ERROR]"


def calc_similarity(a, b):
    """Return difflib's ``SequenceMatcher`` similarity ratio between *a* and *b*.

    The ratio lies in [0.0, 1.0]; 1.0 means the sequences are identical.
    """
    return SequenceMatcher(None, a, b).ratio()


def _row_text(row, key):
    """Return ``row[key]``, or "" when the key is missing or its value is None."""
    return row.get(key) or ""


def rag_topn_rows(current_row, all_rows, topn=2, max_part_count=10):
    """Retrieve the rows most similar to *current_row* for few-shot prompting.

    A row is a candidate when it belongs to a different ``Assembly``, has a
    ``Part_count`` made only of decimal digits whose value is below
    *max_part_count*, and has non-blank ``LLM_CoT`` and ``Matched_Part_Names``.
    Candidates are ranked by calc_similarity of their ``Specification`` to
    the current row's. Ties keep input order.

    Args:
        current_row: CSV row dict being answered.
        all_rows: All CSV row dicts to retrieve from.
        topn: Number of rows to return.
        max_part_count: Exclusive upper bound on ``Part_count`` (default 10).

    Returns:
        Up to *topn* candidate row dicts, most similar first.
    """
    cur_assembly = current_row["Assembly"]
    cur_spec = _row_text(current_row, "Specification")

    scored = []
    for r in all_rows:
        if r["Assembly"] == cur_assembly:
            continue
        part_count = _row_text(r, "Part_count")
        # isdecimal (not isdigit): int() accepts every decimal character, while
        # isdigit also admits e.g. superscript digits that make int() raise.
        if not part_count.isdecimal() or int(part_count) >= max_part_count:
            continue
        if not _row_text(r, "LLM_CoT").strip():
            continue
        if not _row_text(r, "Matched_Part_Names").strip():
            continue
        scored.append((calc_similarity(cur_spec, _row_text(r, "Specification")), r))

    # reverse=True keeps the sort stable, so equal scores retain input order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [r for _, r in scored[:topn]]


def get_fewshot_prompt(row, json_root, data_root, idx):
    """Render one retrieved CSV row as a text-only few-shot example.

    Args:
        row: CSV row dict with an ``Assembly`` key and, optionally,
            ``LLM_CoT``, ``Matched_Part_Names`` and ``Specification``;
            missing or None values are rendered as "".
        json_root: Directory holding ``<Assembly>.json`` part descriptions.
            A missing file yields an empty description section.
        data_root: Not used; kept so existing callers keep working.
        idx: Zero-based example index, rendered as ``Example {idx+1}``.

    Returns:
        The example text, ending with a newline.

    NOTE(review): no image is attached for the example. The ``[image]`` line
    is only a text placeholder, so the model never sees this example's
    assembly. Attaching it would require returning message parts instead of
    a string.
    """
    assembly = row["Assembly"]
    json_path = os.path.join(json_root, f"{assembly}.json")

    # Assumes the file holds a flat {part filename: description} object — TODO confirm.
    desc = {}
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as jf:
            desc = json.load(jf)
    desc_lines = "\n".join(f"{k}: {v}" for k, v in desc.items())

    # `or ""` also covers None, which csv.DictReader uses for missing fields.
    cot = (row.get("LLM_CoT") or "").strip()
    final_ans = (row.get("Matched_Part_Names") or "").strip()
    spec = (row.get("Specification") or "").strip()

    return (
        f"Example {idx+1}:\n"
        f"Assembly image:\n"
        f"[image]\n"
        f"Part descriptions:\n{desc_lines}\n"
        f"Specification:\n{spec}\n"
        f"Chain-of-Thought:\n{cot}\n"
        f"Final Answer:\n{final_ans}\n"
    )


def extract_final_answer(output):
    """Pull the semicolon-separated part filenames out of a model response.

    The prompt asks the model to finish with ``Final Answer:`` followed by
    the filenames, and its example puts them on the *next* line. The last
    line starting with ``final answer:`` (case-insensitive) is used:

    * text after the marker on that same line, if any;
    * otherwise the next non-empty line;
    * otherwise "" (the marker was the final line, so there is no answer).

    Without any marker, the last non-empty line is used. All spaces are
    removed from the result.

    >>> extract_final_answer("Reasoning...\\nFinal Answer:\\npart1.png;part2.png")
    'part1.png;part2.png'
    """
    marker = "final answer:"
    lines = [line.strip() for line in output.splitlines() if line.strip()]

    for i in reversed(range(len(lines))):
        if lines[i].lower().startswith(marker):
            answer = lines[i][len(marker):].strip()
            if not answer and i + 1 < len(lines):
                # Answer written on the line after the marker, as in the prompt's example.
                answer = lines[i + 1]
            return answer.replace(" ", "")

    return lines[-1].replace(" ", "") if lines else ""


def process_row_with_rag(row, idx, all_rows):
    """Answer one CSV row with retrieval-augmented few-shot prompting.

    Steps: verify DATA_ROOT/<Assembly>/assembly.png exists. Retrieve up to two
    similar rows from *all_rows* as text few-shot examples. Load this
    assembly's part descriptions from JSON_OUT_ROOT. Prompt the model with
    the text plus the assembly image, then parse the final answer.

    Args:
        row: CSV row dict for the assembly to answer; mutated in place.
        idx: Row index, forwarded to call_llm_with_retry.
        all_rows: All CSV rows, used as the retrieval pool.

    Returns:
        The same *row* dict with ``LLM_CoT_Corrected_with_RAG`` (raw model
        output) and ``Final_Answer_with_RAG`` (parsed filenames) set. Both
        are "[NO ASSEMBLY IMAGE]" when the image is missing.
    """
    assembly = row["Assembly"]
    assembly_img_path = os.path.join(DATA_ROOT, assembly, "assembly.png")

    # Check the image first so skipped rows don't pay for similarity retrieval
    # over every row.
    if not os.path.exists(assembly_img_path):
        print(f"[WARN] Assembly image not found: {assembly_img_path}")
        row["LLM_CoT_Corrected_with_RAG"] = "[NO ASSEMBLY IMAGE]"
        row["Final_Answer_with_RAG"] = "[NO ASSEMBLY IMAGE]"
        return row

    rag_rows = rag_topn_rows(row, all_rows, topn=2)
    # NOTE(review): when no rows are retrieved this block is empty, yet the
    # prompt below still refers to "the above reasoning".
    fs_block = "\n".join(
        get_fewshot_prompt(rag_row, JSON_OUT_ROOT, DATA_ROOT, i)
        for i, rag_row in enumerate(rag_rows)
    )

    json_path = os.path.join(JSON_OUT_ROOT, f"{assembly}.json")
    desc = {}
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as jf:
            desc = json.load(jf)
    desc_lines = "\n".join(f"{k}: {v}" for k, v in desc.items())
    # `or ""` also covers None, which csv.DictReader uses for missing fields.
    spec = (row.get("Specification") or "").strip()

    # NOTE(review): the task asks for "Final Answer:" and the filenames on the
    # last line, but the example output puts them on the line after the marker.
    main_prompt = (
        f"{fs_block}\n\n"
        "Now, for the following question, use the above reasoning as reference and answer step-by-step:\n"
        "Assembly image:\n"
        "[image below]\n"
        f"Part descriptions:\n{desc_lines}\n"
        f"Specification:\n{spec}\n"
        "Your task:\n"
        "1. Think step by step (Chain-of-Thought) and explain how you identify the required part(s).\n"
        "2. In the last line, write 'Final Answer:' followed by only the selected part filenames (from the keys above), "
        "separated by semicolons (;), with no extra words or punctuation.\n"
        "\n"
        "Example output:\n"
        "Chain-of-Thought:\n"
        "First, I check the descriptions of all parts. Only part1.png and part2.png are described as cylindrical pins. "
        "Therefore, the required parts are part1.png and part2.png.\n"
        "Final Answer:\n"
        "part1.png;part2.png\n"
    )

    # Text first, then the single assembly image (the prompt says "[image below]").
    messages = [
        {"type": "text", "text": main_prompt},
        {"type": "image", "image": to_file_uri(assembly_img_path)},
    ]

    full_output = call_llm_with_retry(messages, idx, max_tokens=1024)
    filenames_str = extract_final_answer(full_output)

    row["LLM_CoT_Corrected_with_RAG"] = full_output
    row["Final_Answer_with_RAG"] = filenames_str
    print(f"[OK] {assembly} with RAG done.")
    return row


def main():
    """Run the RAG pass over INPUT_CSV and write the results to OUTPUT_CSV.

    Each input row is processed by process_row_with_rag on a thread pool of
    MAX_WORKERS threads. Results are written in input order with the columns
    ``LLM_CoT_Corrected_with_RAG`` and ``Final_Answer_with_RAG`` appended
    when absent. A row whose processing raises gets "[EXCEPTION]" in both.
    A file with no header is reported and skipped. A header-only file
    produces a header-only output.
    """
    with open(INPUT_CSV, "r", encoding="utf-8", newline="") as fin:
        dict_reader = csv.DictReader(fin)
        rows = list(dict_reader)
        # Take columns from the header rather than rows[0], so a header-only
        # CSV doesn't raise IndexError. dict.fromkeys dedupes the names the
        # same way the row dicts do.
        fieldnames = list(dict.fromkeys(dict_reader.fieldnames or []))

    if not fieldnames:
        print(f"[WARN] No header found in {INPUT_CSV}; nothing to process.")
        return

    for column in ("LLM_CoT_Corrected_with_RAG", "Final_Answer_with_RAG"):
        if column not in fieldnames:
            fieldnames.append(column)

    # Create the output directory up front so a missing directory fails (or
    # is fixed) before the expensive generation work rather than after it.
    out_dir = os.path.dirname(OUTPUT_CSV)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    results = [None] * len(rows)
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Every task gets the same `rows` list as its retrieval pool.
        futures = {
            executor.submit(process_row_with_rag, row, idx, rows): idx
            for idx, row in enumerate(rows)
        }
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
            try:
                row = future.result()
            except Exception as exc:  # boundary: record the failure, keep the other rows
                print(f"[FATAL] Row {idx} failed with exception: {exc}")
                row = rows[idx]
                row["LLM_CoT_Corrected_with_RAG"] = "[EXCEPTION]"
                row["Final_Answer_with_RAG"] = "[EXCEPTION]"
            results[idx] = row

    with open(OUTPUT_CSV, "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    print(f"All done. Results saved to {OUTPUT_CSV}")


# Script entry point: run the RAG pass only when executed directly.
# NOTE(review): the model and processor are still loaded at import time
# (module level), even when this module is only imported.
if __name__ == "__main__":
    main()
