import concurrent.futures
import csv
import json
import os
import random
import threading
import time
from typing import List, Dict, Any

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Local checkpoint directory for Qwen2-VL-2B-Instruct (a relative path, so it
# is resolved against the current working directory).
MODEL_NAME = "./Qwen2-VL-2B-Instruct"

# Loaded once at import time and shared by every worker thread that main()
# starts. Per the transformers docs, torch_dtype="auto" takes the dtype from
# the checkpoint config and device_map="auto" lets accelerate place weights.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)

# Chat-template tokenizer + image/video preprocessor matching the checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# NOTE(review): a guess at where inputs should live; with device_map="auto"
# the model's actual placement is `model.device`, which may differ on
# multi-GPU or offloaded setups.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def to_file_uri(path: str) -> str:
    """Return ``path`` as an absolute ``file://`` URI string.

    The absolute path is appended verbatim, with no percent-encoding, e.g.
    ``"/tmp/a.png"`` -> ``"file:///tmp/a.png"`` on POSIX.
    """
    # NOTE(review): kept un-encoded on purpose; qwen_vl_utils appears to open
    # the text after the "file://" prefix directly, so confirm before
    # switching to pathlib.Path.as_uri().
    absolute_path = os.path.abspath(path)
    return f"file://{absolute_path}"


def qwen_generate_from_messages(
    messages_content: List[Dict[str, Any]],
    max_tokens: int = 512,
) -> str:
    chat_messages = [
        {
            "role": "user",
            "content": messages_content,
        }
    ]

    text = processor.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(chat_messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(DEVICE)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
        )

    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids
        in zip(inputs["input_ids"], generated_ids)
    ]

    output_texts = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_texts[0].strip() if output_texts else ""


# Retry policy for model calls: number of attempts and the base (seconds) of
# the exponential backoff between them.
MAX_RETRY: int = 3
RETRY_BASE_SEC: int = 2
# Size of main()'s thread pool (rows processed concurrently).
MAX_WORKERS: int = 4

# One sub-directory per assembly, each expected to contain assembly.png.
DATA_ROOT: str = "../data/a1.0.0_00"

# Per-assembly part2desc JSON files and the CSV consumed by this pass
# (presumably produced by an earlier "ss1" stage), plus this pass's output.
JSON_OUT_ROOT: str = "../result/ss1_json_qwen2vl2b"
INPUT_CSV: str = "../result/ss1_output_qwen2vl2b.csv"
OUTPUT_CSV: str = "../result/ss3_output_cot_corrected_qwen2vl2b.csv"

def call_llm_with_retry(messages, idx, retry=MAX_RETRY, max_tokens=512):
    """Call the local Qwen2-VL model, retrying with exponential backoff.

    Args:
        messages: Content items for one user turn, forwarded unchanged to
            ``qwen_generate_from_messages``.
        idx: Caller's row index; used only to tag log lines.
        retry: Total number of attempts. A value ``<= 0`` makes no attempt.
        max_tokens: Generation budget, forwarded to the model call.

    Returns:
        The stripped model output, or the sentinel ``"[ERROR]"`` once every
        attempt has raised.
    """
    for attempt in range(retry):
        try:
            content = qwen_generate_from_messages(messages, max_tokens=max_tokens)
        except Exception as e:  # any model/IO failure is retried, not propagated
            print(
                f"[Qwen2-VL] call failed for row {idx} "
                f"(attempt {attempt + 1}/{retry}), error: {e}"
            )
            # Back off only if another attempt follows; sleeping after the
            # last failure would just delay the "[ERROR]" result.
            if attempt + 1 < retry:
                wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
                print(f"Retrying after {wait_sec:.1f} seconds...")
                time.sleep(wait_sec)
        else:
            return content.strip()
    return "[ERROR]"

def infer_corrected_cot(
    assembly_img_path,
    part2desc,
    specification,
    prev_cot,
    gt_filenames,
    idx,
    max_tokens=512,
):
    """Ask the model to rewrite ``prev_cot`` so it reaches the correct filenames.

    The prompt contains:

    * the assembly image;
    * every part description;
    * the specification;
    * the previous reasoning;
    * the correct filenames.

    The model is asked to follow the previous reasoning up to its first
    mistake, correct it in a self-discovery style, and finish with a
    ``Final Answer:`` line of ``;``-separated filenames.

    Args:
        assembly_img_path: Path to the assembly image, sent as a ``file://``
            URI.
        part2desc: Mapping of part filename -> description, rendered one
            ``"filename: description"`` per line.
        specification: Inspection specification text.
        prev_cot: Previous reasoning, including its Final Answer.
        gt_filenames: Correct filenames, inserted verbatim for reference.
        idx: Row index, forwarded for logging.
        max_tokens: Generation budget. The default of 512 is the previously
            hard-coded value; reproducing a long previous reasoning plus the
            correction may need more.

    Returns:
        The model's stripped output, or ``"[ERROR]"`` if every retry failed.
    """
    desc_list = "\n".join(f"{k}: {v}" for k, v in part2desc.items())
    prompt = (
        "You are an expert mechanical engineer with a sharp analytical mind. "
        "You are given the assembly image, the descriptions of all parts (each as 'filename: description'), "
        "the inspection specification, and a previous reasoning process (including its step-by-step thoughts "
        "and its Final Answer).\n"
        "Your job:\n"
        "1. Carefully read the previous reasoning step-by-step. Follow along and reproduce the steps until you "
        "encounter the first error or mistake.\n"
        "2. Once you spot the first mistake, stop following the previous reasoning and use a natural transition "
        "phrase (such as: “But, wait, let’s pause and examine this more carefully.” or “Wait, something seems off. "
        "Let’s pause and consider what we know so far.”) to point out the error and correct it.\n"
        "3. From that point on, continue the reasoning process in your own words, step-by-step, until you reach "
        "the correct answer (i.e., the filenames consistent with the correct ground-truth solution).\n"
        "4. Do not mention “previous attempt” or “ground-truth solution” explicitly. Make your reasoning sound "
        "like a student discovering and correcting their own mistake in real time.\n"
        "5. If the previous reasoning is already correct, simply reproduce the previous reasoning and the final "
        "answer as is.\n"
        # The part descriptions follow this instruction, so point the model
        # *down* to them (the old text said "from the keys above").
        "6. End your output with a “Final Answer:” line followed by the filenames (the keys listed under "
        "Part descriptions below), separated "
        "by semicolons (;), with no extra words or punctuation.\n"
        "\n"
        "Part descriptions:\n"
        f"{desc_list}\n"
        f"Specification: {specification}\n"
        "Previous Reasoning:\n"
        f"{prev_cot}\n"
        "\n"
        "The correct filenames (for your reference):\n"
        f"{gt_filenames}\n"
    )
    messages = [
        {"type": "text", "text": "Assembly image:"},
        {"type": "image", "image": to_file_uri(assembly_img_path)},
        {"type": "text", "text": prompt},
    ]
    return call_llm_with_retry(messages, idx, max_tokens=max_tokens)


def process_row(row, idx):
    """Generate a corrected chain-of-thought for one CSV row.

    The row's ``Assembly``, ``Specification``, ``LLM_CoT`` and
    ``Matched_Part_Names`` columns are read. The function then loads
    ``<JSON_OUT_ROOT>/<Assembly>.json`` (part filename -> description) and
    checks for ``<DATA_ROOT>/<Assembly>/assembly.png``. Finally it asks the
    model to correct the previous reasoning.

    Args:
        row: One ``csv.DictReader`` row. It is mutated in place, and the
            ``LLM_CoT_Corrected`` key is always set.
        idx: Row index, forwarded to the model call for logging.

    Returns:
        The same ``row`` object. ``LLM_CoT_Corrected`` holds one of:

        * the model output;
        * ``"[ERROR]"`` if every model attempt failed;
        * a skip sentinel: ``"[NO PART2DESC JSON]"``,
          ``"[NO ASSEMBLY IMAGE]"`` or ``"[NO PREV_COT OR GT]"``.

    Raises:
        KeyError: if the row has no ``Assembly`` column.
        OSError / json.JSONDecodeError: if the part2desc JSON cannot be read.
    """

    def _text(column):
        # csv.DictReader stores None (its default restval) for columns that a
        # short row lacks, and row.get(column, "") would hand that None back.
        return (row.get(column) or "").strip()

    assembly = (row["Assembly"] or "").strip()
    specification = _text("Specification")
    prev_cot = _text("LLM_CoT")
    gt_filenames = _text("Matched_Part_Names")
    dir_path = os.path.join(DATA_ROOT, assembly)
    assembly_img_path = os.path.join(dir_path, "assembly.png")

    json_path = os.path.join(JSON_OUT_ROOT, f"{assembly}.json")
    if not os.path.exists(json_path):
        print(f"[WARN] part2desc json not found: {json_path}")
        row["LLM_CoT_Corrected"] = "[NO PART2DESC JSON]"
        return row
    with open(json_path, "r", encoding="utf-8") as jf:
        part2desc = json.load(jf)

    if not os.path.exists(assembly_img_path):
        print(f"[WARN] Assembly image not found: {assembly_img_path}")
        row["LLM_CoT_Corrected"] = "[NO ASSEMBLY IMAGE]"
        return row

    if not prev_cot or not gt_filenames:
        print(f"[WARN] Missing previous CoT or GT for {assembly}")
        row["LLM_CoT_Corrected"] = "[NO PREV_COT OR GT]"
        return row

    corrected_cot = infer_corrected_cot(
        assembly_img_path,
        part2desc,
        specification,
        prev_cot,
        gt_filenames,
        idx,
    )
    row["LLM_CoT_Corrected"] = corrected_cot
    # call_llm_with_retry signals exhausted retries with "[ERROR]"; don't
    # report that as success.
    if corrected_cot == "[ERROR]":
        print(f"[WARN] {assembly}: model call failed after all retries.")
    else:
        print(f"[OK] {assembly} done.")
    return row


def main():
    """Correct the CoT of every row in INPUT_CSV and write OUTPUT_CSV.

    Rows are processed concurrently on MAX_WORKERS threads via process_row.
    Results are written in input order, with an extra
    ``LLM_CoT_Corrected`` column. A row whose processing raises is written
    with ``LLM_CoT_Corrected = "[EXCEPTION]"`` instead of aborting the run.
    """
    with open(INPUT_CSV, "r", encoding="utf-8", newline="") as fin:
        dict_reader = csv.DictReader(fin)
        rows = list(dict_reader)
        # Take the columns from the header rather than rows[0].keys(): the
        # header exists even when there are no data rows (rows[0] would raise
        # IndexError). It is None only for a completely empty file.
        fieldnames = list(dict_reader.fieldnames or [])
    if "LLM_CoT_Corrected" not in fieldnames:
        fieldnames.append("LLM_CoT_Corrected")

    results = [None] * len(rows)
    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        futures = {
            executor.submit(process_row, row, idx): idx
            for idx, row in enumerate(rows)
        }
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
            try:
                row = future.result()
            except Exception as exc:
                # Keep the run going and mark the row so the failure is
                # visible in the output file.
                print(f"[FATAL] Row {idx} failed with exception: {exc}")
                row = rows[idx]
                row["LLM_CoT_Corrected"] = "[EXCEPTION]"
            results[idx] = row

    with open(OUTPUT_CSV, "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    print(f"All done. Results saved to {OUTPUT_CSV}")


# Run the pipeline only when executed as a script. Note that merely importing
# this module still loads the model and processor at module level.
if __name__ == "__main__":
    main()
