import os
import csv
import time
import random
import json
import concurrent.futures
from typing import List, Dict, Any

import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Local checkpoint directory of the Qwen2-VL-2B-Instruct model.
MODEL_NAME = "./Qwen2-VL-2B-Instruct"

# Loaded once at import time; every generation call in this module reuses it.
# dtype and device placement are both left to transformers ("auto").
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
)

# Chat-template and vision preprocessing for the same checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_NAME)
# "cuda" when CUDA is available, otherwise "cpu".
# NOTE(review): chosen independently of device_map="auto" above, so it is not
# guaranteed to be where the model's weights actually live — confirm.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Thread-pool size used by generate_part_descriptions. All threads share the
# single `model` above. NOTE(review): concurrent generate() on one model
# instance has not been verified safe; keep at 1 unless confirmed.
MAX_WORKERS = 1
# Default number of attempts per model call in call_llm_with_retry.
MAX_RETRY = 3
# Base delay in seconds for the exponential backoff between retries.
RETRY_BASE_SEC = 2

# Root directory with one subdirectory per assembly; each holds assembly.png
# and the individual part PNGs.
DATA_ROOT = "../data/a1.0.0_00"

# Per-assembly JSON dumps of the generated part descriptions (<assembly>.json).
# Created at import time.
JSON_OUT_ROOT = "../result/ss1_json_qwen2vl2b"
os.makedirs(JSON_OUT_ROOT, exist_ok=True)

# Output CSV: the input rows plus Matched_Filenames and LLM_CoT columns.
SS1_CSV_OUT = "../result/ss1_output_qwen2vl2b.csv"


def to_file_uri(path: str) -> str:
    """Return *path* as a ``file://`` URI built from its absolute form.

    The absolute path is appended verbatim after the scheme; no
    percent-encoding or separator normalisation is applied.
    """
    absolute = os.path.abspath(path)
    return f"file://{absolute}"


def qwen_generate_from_messages(
    messages_content: List[Dict[str, Any]],
    max_tokens: int = 256,
) -> str:
    """Run a single-turn user query through Qwen2-VL and return the reply text.

    Args:
        messages_content: Content items of one user turn, e.g.
            ``{"type": "text", "text": ...}`` and
            ``{"type": "image", "image": <uri>}``, passed unchanged to the
            processor's chat template and to ``process_vision_info``.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The decoded newly generated text (prompt tokens excluded, special
        tokens skipped), stripped of surrounding whitespace; "" if the
        decoder returned nothing.
    """
    chat_messages = [{"role": "user", "content": messages_content}]

    text = processor.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(chat_messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # Place inputs wherever device_map="auto" actually put the model. The
    # module-level DEVICE only distinguishes CUDA from CPU and is not derived
    # from that placement, so the two can disagree.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_tokens)

    # generate() returns prompt + continuation; keep only the continuation.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]

    output_texts = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_texts[0].strip() if output_texts else ""

def call_llm_with_retry(messages, idx, retry=MAX_RETRY, max_tokens=256):
    """Query the model, retrying failed calls with exponential backoff and jitter.

    Args:
        messages: Content items of one user turn, forwarded to
            ``qwen_generate_from_messages``.
        idx: Index of the item being processed. Not used here; kept so
            existing callers keep working.
        retry: Maximum number of attempts.
        max_tokens: Maximum number of new tokens per attempt.

    Returns:
        The stripped model reply, or the sentinel string "[ERROR]" when every
        attempt raised (or when ``retry`` <= 0).
    """
    for attempt in range(retry):
        try:
            content = qwen_generate_from_messages(messages, max_tokens=max_tokens)
            return content.strip()
        except Exception as e:  # any failure is retried, then reported as "[ERROR]"
            print(f"[Qwen2-VL] call failed (attempt {attempt+1}/{retry}), error: {e}")
            if attempt == retry - 1:
                # No further attempt follows, so waiting would only add latency.
                break
            wait_sec = RETRY_BASE_SEC * (2 ** attempt) + random.uniform(0, 1)
            print(f"Retrying after {wait_sec:.1f} seconds...")
            time.sleep(wait_sec)
    return "[ERROR]"

def generate_part_descriptions(assembly_img_path, part_img_paths):
    """Ask the model for a short noun-phrase description of every part image.

    Each query shows the assembly image (for context) followed by one part
    image and an instruction to describe that part as a distinguishing noun
    phrase.

    Args:
        assembly_img_path: Path to the assembled-product image.
        part_img_paths: Paths of the individual part images.

    Returns:
        Dict mapping each part's basename to its description (the "[ERROR]"
        sentinel if every retry failed). Entries follow the order of
        ``part_img_paths``; if two paths share a basename, the later one wins.
    """

    def process(idx, part_img_path):
        part_filename = os.path.basename(part_img_path)
        messages = [
            {"type": "text", "text": "Image 1: Assembly (assembled product)"},
            {"type": "image", "image": to_file_uri(assembly_img_path)},
            {"type": "text", "text": "Image 2: Individual part"},
            {"type": "image", "image": to_file_uri(part_img_path)},
            {
                "type": "text",
                "text": (
                    "You are an expert mechanical engineer. Given Image 1 (the assembly) "
                    "and Image 2 (an individual part from the assembly), please generate "
                    "a concise and descriptive noun phrase (not a full sentence). The phrase "
                    "should briefly describe the part's main shape and any key features, in a way "
                    "that clearly distinguishes it from the other parts in the assembly. Avoid "
                    'generic names like "part" or "component". Be specific about the shape and any '
                    "holes, slots, or functional features.\n"
                    "Your output should be a single noun phrase.\n"
                    "For example:\n"
                    "    - A conical mount with a forked top;\n"
                    "    - A cylindrical pin;\n"
                    "    - Two plates with each having holes;\n"
                    "    - A flat round disk with three small holes;\n"
                    "    - A rectangular bracket with two mounting slots.\n"
                ),
            },
        ]
        desc = call_llm_with_retry(messages, idx)
        return part_filename, desc

    with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Executor.map yields results in submission order, unlike
        # as_completed. This keeps the dict, and therefore the JSON dump and
        # the part list in the matching prompt, in a deterministic order.
        results = executor.map(process, range(len(part_img_paths)), part_img_paths)
        return dict(results)


def infer_matching_parts(assembly_img_path, part2desc, specification, idx, max_tokens=256):
    """Ask the model which part files satisfy *specification*.

    The model sees the assembly image and the ``filename: description`` list.
    It is asked to reason step by step and then give a ``Final Answer:`` line
    with the chosen filenames separated by semicolons.

    Args:
        assembly_img_path: Path to the assembled-product image.
        part2desc: Mapping of part filename -> description. It is listed in
            the prompt in iteration order.
        specification: Text describing the wanted part(s).
        idx: Row index, forwarded to ``call_llm_with_retry``.
        max_tokens: Generation budget shared by the reasoning and the answer.
            The default keeps the previous fixed value. Raise it if the
            reasoning gets cut off before the ``Final Answer:`` line.

    Returns:
        ``(filenames_str, cot)``. ``filenames_str`` holds the answer as
        ';'-separated names with surrounding whitespace and Markdown marks
        removed; it is NOT validated against ``part2desc``. ``cot`` is the
        reasoning text before the answer. Without a ``Final Answer:`` marker,
        the last non-empty line is taken as the answer, so the "[ERROR]"
        sentinel from a failed call comes back as ``filenames_str``.
    """
    desc_list = "\n".join([f"{k}: {v}" for k, v in part2desc.items()])
    prompt = (
        "You are an expert mechanical engineer.\n"
        "The image above shows the complete assembly. Each part description below corresponds "
        "to a unique part image from this assembly, mapped as 'filename: description'.\n"
        "You should answer the filenames of the parts that correspond to the specification.\n"
        "Here are all the part descriptions:\n"
        f"{desc_list}\n"
        f"Specification: {specification}\n"
        "Your task:\n"
        "1. Think step by step (Chain-of-Thought) and explain how you identify the required part(s).\n"
        "2. In the last line, write 'Final Answer:' followed by only the selected part filenames "
        "(from the keys above), separated by semicolons (;), with no extra words or punctuation.\n"
        "\n"
        "Example output:\n"
        "Chain-of-Thought:\n"
        "First, I check the descriptions of all parts. Only part1.png and part2.png are described "
        "as cylindrical pins. Therefore, the required parts are part1.png and part2.png.\n"
        "Final Answer:\n"
        "part1.png;part2.png\n"
    )
    messages = [
        {"type": "text", "text": "Assembly image:"},
        {"type": "image", "image": to_file_uri(assembly_img_path)},
        {"type": "text", "text": prompt},
    ]
    full_output = call_llm_with_retry(messages, idx, max_tokens=max_tokens)

    def _clean_answer(fragments):
        """Join answer fragments into one ';'-separated string of bare names.

        Each ';'-separated token is stripped of whitespace, a leading bullet
        dash, and surrounding bold/code marks ("*", "`"). Empty tokens are
        dropped.
        """
        tokens = (
            tok.strip().lstrip("-").strip(" *`")
            for fragment in fragments
            for tok in fragment.split(";")
        )
        return ";".join(tok for tok in tokens if tok)

    lines = [l.strip() for l in full_output.split("\n") if l.strip()]
    marker = "final answer:"

    # Fallback when no marker is found: the last line is the answer.
    cot_lines = lines[:-1]
    filenames_str = _clean_answer(lines[-1:])

    for i, line in enumerate(lines):
        # Accept decorated markers such as "**Final Answer:** part1.png" or
        # "### Final Answer:", not only a bare "Final Answer:" at line start.
        bare = line.strip("*#` ")
        if bare.lower().startswith(marker):
            cot_lines = lines[:i]
            # The answer is the remainder of the marker line plus any later
            # lines (the names may sit on the next line), skipping an echoed
            # "Chain-of-Thought" heading.
            fragments = [bare[len(marker):]]
            fragments.extend(
                l for l in lines[i + 1:]
                if not l.lower().startswith("chain-of-thought")
            )
            filenames_str = _clean_answer(fragments)
            break

    cot = "\n".join(cot_lines)
    return filenames_str, cot


def _list_part_images(dir_path):
    """Return the paths of the part images in *dir_path*, sorted by filename.

    Every file ending in ".png" (case-insensitive) except exactly
    "assembly.png" counts as a part. Sorting makes the processing order,
    and therefore the prompts and the JSON output, independent of
    ``os.listdir`` order.
    """
    return [
        os.path.join(dir_path, fn)
        for fn in sorted(os.listdir(dir_path))
        if fn.lower().endswith(".png") and fn != "assembly.png"
    ]


def _select_valid_filenames(filenames_str, valid_names):
    """Return the ';'-separated names in *filenames_str* that are in *valid_names*.

    Order of first appearance is kept, and a name repeated by the model is
    reported only once.
    """
    valid = set(valid_names)
    matched = []
    for name in filenames_str.split(";"):
        name = name.strip()
        if name in valid and name not in matched:
            matched.append(name)
    return matched


def main():
    """Match each input row's specification to part images and write a results CSV.

    Reads ``../result/s4_output.csv``. For each row, the part images under
    ``DATA_ROOT/<Assembly>/`` are described by the model (and dumped to
    ``JSON_OUT_ROOT/<Assembly>.json``), and the model is asked which parts
    match the row's ``Specification``. All rows are then written to
    ``SS1_CSV_OUT`` with two added columns:

    * ``Matched_Filenames``: ';'-joined valid filenames, or one of the markers
      "[NO ASSEMBLY IMAGE]", "[NO PART IMAGES]", "[NOT FOUND]".
    * ``LLM_CoT``: the model's reasoning text (empty for skipped rows).
    """
    s4_csv = "../result/s4_output.csv"
    ss1_csv = SS1_CSV_OUT

    with open(s4_csv, "r", encoding="utf-8", newline="") as fin:
        dict_reader = csv.DictReader(fin)
        reader = list(dict_reader)
        # Take columns from the header rather than reader[0], so an input
        # with no data rows does not raise IndexError.
        input_fields = list(dict_reader.fieldnames or [])
    # Avoid duplicate columns when the input already carries the outputs.
    fieldnames = input_fields + [
        col for col in ("Matched_Filenames", "LLM_CoT") if col not in input_fields
    ]

    results = []

    for idx, row in enumerate(reader):
        # DictReader fills fields missing from a short row with None.
        assembly = (row["Assembly"] or "").strip()
        specification = (row.get("Specification") or "").strip()
        dir_path = os.path.join(DATA_ROOT, assembly)
        assembly_img_path = os.path.join(dir_path, "assembly.png")

        if not os.path.exists(assembly_img_path):
            print(f"[WARN] Assembly image not found: {assembly_img_path}")
            row["Matched_Filenames"] = "[NO ASSEMBLY IMAGE]"
            results.append(row)
            continue

        part_img_paths = _list_part_images(dir_path)

        if not part_img_paths:
            print(f"[WARN] No part images found in {dir_path}")
            row["Matched_Filenames"] = "[NO PART IMAGES]"
            results.append(row)
            continue

        part2desc = generate_part_descriptions(assembly_img_path, part_img_paths)

        json_out_path = os.path.join(JSON_OUT_ROOT, f"{assembly}.json")
        with open(json_out_path, "w", encoding="utf-8") as jf:
            json.dump(part2desc, jf, indent=2, ensure_ascii=False)

        filenames_str, cot = infer_matching_parts(
            assembly_img_path, part2desc, specification, idx
        )
        row["LLM_CoT"] = cot

        valid_names = [os.path.basename(p) for p in part_img_paths]
        matched = _select_valid_filenames(filenames_str, valid_names)

        if not matched:
            row["Matched_Filenames"] = "[NOT FOUND]"
            print(f"[WARN] LLM returned no valid matches for {assembly}")
        else:
            row["Matched_Filenames"] = ";".join(matched)

        results.append(row)
        print(f"[OK] {assembly} done. Matched: {row['Matched_Filenames']}")

    with open(ss1_csv, "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)

    print("All done (Qwen2-VL-2B-Instruct).")

if __name__ == "__main__":
    main()
