#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Main Pipeline for IC Datasheet Analysis

- Stage-1: Fuses predictions from the Diagram Agent and a YOLO object
           detector to accurately locate the 'Suggest Pad' region.
- Stage-2: Classifies the IC footprint and plans parameter extraction based on
           the located region.
- Stage-3: Extracts the specific numerical parameters according to the plan.

"""

import os
import re
import gc
import json
import time
import warnings
import argparse
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional, Any

from PIL import Image

# ============================== Core Dependencies ==============================
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, GenerationConfig
from qwen_vl_utils import process_vision_info

# ============================== Environment Setup ==============================
# Provide defaults for these variables without overriding values that are
# already present in the environment (setdefault only fills in missing keys).
# NOTE(review): these run after `transformers` has been imported above; if the
# Hugging Face libraries read the offline flags at import time, the defaults
# set here may come too late — confirm. Model loading in this file passes
# local_files_only=True explicitly in any case.
os.environ.setdefault("HF_HUB_OFFLINE", "1")
os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
try:
    from ultralytics.utils import SETTINGS
    SETTINGS.update({'checks': False})
except Exception:
    # Best-effort tweak: a missing ultralytics install (or any failure while
    # updating its settings) must not prevent the script from starting.
    pass

# ============================== Optional Imports ==============================
try:
    from tqdm import tqdm
except ImportError:
    class tqdm:
        """No-op stand-in used when the real ``tqdm`` package is not installed.

        It supports only what this script needs: iterating over a wrapped
        iterable, ``update()``/``close()`` calls, and use as a context manager.
        Nothing is ever displayed.
        """

        def __init__(self, iterable=None, total=None, desc=""):
            # ``total`` is accepted for signature compatibility but not stored.
            self.iterable = iterable
            self.desc = desc

        def __iter__(self):
            return iter(self.iterable)

        def update(self, n=1):
            return None

        def close(self):
            return None

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc, tb):
            # Returning None lets any exception propagate.
            return None

# YOLO is an optional dependency. When ultralytics cannot be imported, YOLO is
# bound to None and _YOLO_OK to False; detect_with_yolo() checks _YOLO_OK and
# returns no detections in that case.
try:
    from ultralytics import YOLO
    _YOLO_OK = True
except ImportError:
    YOLO, _YOLO_OK = None, False

# ============================== Utility Functions ==============================

def log(msg: str) -> None:
    """Print ``msg`` to stdout prefixed with the local time as ``[HH:MM:SS]``.

    The output is flushed immediately so messages interleave correctly with
    progress bars and survive an abrupt exit.
    """
    stamp = time.strftime('%H:%M:%S')
    print(f"[{stamp}] {msg}", flush=True)

def write_jsonl(path: str, rows: List[dict]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines (one JSON object per line).

    The parent directory is created when missing. A ``"label"`` key, if
    present, is left out of the written record. The input dictionaries are
    not modified: the key is filtered on a copy rather than popped from the
    caller's objects.

    Args:
        path: Destination file; overwritten if it already exists.
        rows: Records to serialize. Non-ASCII text is written verbatim
            (``ensure_ascii=False``) using UTF-8.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps({k: v for k, v in row.items() if k != "label"}, ensure_ascii=False) + "\n"
            for row in rows
        )

class SuppressOutput:
    """Context manager that silences Python warnings inside its ``with`` body.

    Only warnings issued through the ``warnings`` module are affected; nothing
    is done to stdout/stderr. On exit the previous warning configuration is
    restored and exceptions are never swallowed.

    The earlier implementation assigned the string ``"ignore"`` to
    ``warnings.simplefilter``, which replaced that function for the whole
    process and did not suppress anything. This version delegates the
    save/restore to ``warnings.catch_warnings``.
    """

    def __enter__(self):
        # catch_warnings snapshots the current filters / showwarning hook so
        # they can be put back in __exit__.
        self._catcher = warnings.catch_warnings()
        self._catcher.__enter__()
        warnings.simplefilter("ignore")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._catcher.__exit__(exc_type, exc_val, exc_tb)
        # False: let any exception raised in the body propagate.
        return False

def parse_assistant_response(full_text: str) -> str:
    full_text = full_text.strip()
    if "<|im_start|>assistant" in full_text:
        return full_text.split("<|im_start|>assistant")[-1].strip()
    if "</answer>" in full_text and "assistant" in full_text:
        match = re.search(r"</answer>.*?assistant\s*(.*)", full_text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()
    if "assistant" in full_text:
        return full_text.split("assistant")[-1].strip()
    return full_text

# ============================== Prompt Templates ==============================

# System message sent with every request.
SYSTEM_PROMPT = "You are a helpful assistant."

# NOTE(review): the prompt texts below are runtime input to the models. Their
# wording (including spelling such as "coordination", "choosen") presumably
# mirrors the prompts the checkpoints were tuned on — confirm before editing,
# since changing a single character changes what the model receives.

# Stage 1: ask for the Suggest Pad region as a relative (x,y,width,height) box.
S1_PROMPT = (
    "Suggest Pad is the recommended land pattern for an IC. There is only one Suggest Pad image in the picture. Please locate the Suggest Pad image and give the area of the Suggest Pad image in pure number pair (x,y,width,height) in proportion of the datasheet image size, where x and y are the coordination of the left-upper corner of the area, width and height are the dimension of the area. Note that (x,y)=(0,0) denote the left upper corner of the image."
)

# Stage 2 (classification): template with a single `{bbox}` placeholder that
# is filled with the Stage-1 box string via str.format().
S2_CLS_PROMPT_TMPL = (
    "Suggest Pad is the recommended land pattern for an IC. There is only one Suggest Pad image in the picture. The area of the Suggest Pad image is located by pure number pair (x,y,width,height)={bbox} in proportion of the datasheet image size, where x and y are the coordination of the left-upper corner of the area, width and height are the dimension of the area. Note that (x,y)=(0,0) denote the left upper corner of the image. Based on the located image, please classify this IC footprint as \"2-sides\", \"4-sides\", \"grid\" or \"other\"."
)

# Stage 2 (planning): appended after the classification prompt and label by
# build_s2_plan_prompt(); asks which parameters to extract for the IC type.
S2_PLAN_PROMPT = (
    "Based on the located image and the above classification, please choose the appropriate parameters to describe the positions and dimensions of the IC pins. Note that for \"2-sides\" ICs, parameters are chosen from: row, column, row spacing, column spacing, inner row spacing, outer row spacing, inner column spacing, outer column spacing, dx, dy, diameter; for \"4-sides\" ICs, parameters are chosen from: pin per side, side-to-side distance, inner side distance, outer side distance, pin spacing within side, dx1, dx2, dy1, dy2, dx, dy; for \"grid\" ICs, parameters are choosen from: row, column, row spacing, column spacing, dx, dy, diameter; for \"other\" ICs, parameters are: pin count, pin coordinates, pin dimensions. The reasoning process and answers are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>."
)

# Equation texts that gen_reasoning() quotes verbatim inside the Stage-3
# prompt. eq1-eq3 derive row spacing, eq4-eq6 column spacing (both used for
# "2-sides" ICs); eq7-eq9 derive side-to-side distance ("4-sides" ICs).
eq1="row spacing = (inner row spacing + outer row spacing)/2"
eq2="row spacing = outer row spacing - dy"
eq3="row spacing = inner row spacing + dy"
eq4="column spacing = (inner column spacing + outer column spacing)/2"
eq5="column spacing = outer column spacing - dx"
# NOTE(review): eq6 subtracts dx, whereas the parallel "inner" equations eq3
# and eq9 add the pad dimension. This looks like it should be "+ dx" — confirm
# against the prompts the Stage-3 model was tuned on before changing the text.
eq6="column spacing = inner column spacing - dx"
eq7="side-to-side distance = (inner side distance + outer side distance)/2"
eq8="side-to-side distance = outer side distance - dx1"
eq9="side-to-side distance = inner side distance + dx1"

def gen_reasoning(IC_type: str, paras: List[str]) -> str:
    """Return the extra reasoning sentence for a Stage-3 prompt, or ``''``.

    Args:
        IC_type: Footprint class ("2-sides", "4-sides", "grid", ...).
        paras: Parameter names planned in Stage 2 (matched exactly).

    Returns:
        * "2-sides": a sentence quoting the equation for row spacing, or for
          column spacing when no inner/outer *row* spacing was planned.
        * "4-sides": a sentence quoting the side-to-side distance equation.
        * "grid" with "missing pins" planned: a request to list those pins.
        * Anything else, or no relevant parameter: the empty string.
    """
    if IC_type == "2-sides":
        inner_row = "inner row spacing" in paras
        outer_row = "outer row spacing" in paras
        inner_col = "inner column spacing" in paras
        outer_col = "outer column spacing" in paras
        # Row equations take precedence over column equations.
        if inner_row and outer_row:
            axis, equation = "row", eq1
        elif outer_row:
            axis, equation = "row", eq2
        elif inner_row:
            axis, equation = "row", eq3
        elif inner_col and outer_col:
            axis, equation = "column", eq4
        elif outer_col:
            axis, equation = "column", eq5
        elif inner_col:
            axis, equation = "column", eq6
        else:
            return ''
        return f"Then, calculate {axis} spacing base on equation \"{equation}\"."

    if IC_type == "4-sides":
        inner_side = "inner side distance" in paras
        outer_side = "outer side distance" in paras
        if inner_side and outer_side:
            equation = eq7
        elif outer_side:
            equation = eq8
        elif inner_side:
            equation = eq9
        else:
            return ''
        return f"Then, calculate side-to-side distance base on equation \"{equation}\"."

    if IC_type == "grid" and "missing pins" in paras:
        return "There are missing pins, please give their names in one line with comma separators."
    return ''

def build_s2_plan_prompt(bbox_str: str, cls_label: str) -> str:
    """
    Assemble the Stage2 planning prompt.

    The prompt consists of three newline-separated parts: the classification
    prompt filled with ``bbox_str``, a sentence stating the class
    ``cls_label``, and the parameter-planning instructions.
    """
    sections = [
        S2_CLS_PROMPT_TMPL.format(bbox=bbox_str),
        f"The IC footprint is classified as \"{cls_label}\".",
        f"{S2_PLAN_PROMPT}",
    ]
    return "\n".join(sections)

def build_s3_prompt(ic_type: str, paras: List[str], bbox: str) -> str:
    """
    Build the Stage3 extraction prompt from the Stage1 box and the Stage2 plan.

    Args:
        ic_type: Footprint class; "other" uses a fixed parameter list, while
            "2-sides" and "4-sides" get an additional closing "Finally, ..."
            sentence.
        paras: Parameter names planned in Stage2 (ignored for "other").
        bbox: Box string inserted verbatim after "(x,y,width,height)=".

    Returns:
        The prompt: introduction, task description, optional reasoning hint
        from gen_reasoning(), optional final-parameter sentence, and the
        <think>/<answer> formatting instructions.
    """
    intro = (f"The area of the Suggest Pad image is located by pure number pair (x,y,width,height)={bbox} "
             f"in proportion of the datasheet image size. Note that (x,y)=(0,0) denote the left upper corner of the image. "
             f"Based on the located suggest pad image, this IC can be classified as \"{ic_type}\".")
    outro = (" The reasoning process and answers are enclosed within <think> </think> and <answer> </answer> tags, "
             "respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.")

    if ic_type == "other":
        # NOTE(review): unlike the branch below, this text has no leading
        # space, so the prompt reads `"other".To describe`. Kept as-is because
        # the wording is model input — confirm against the tuning prompts.
        other_task = ("To describe the pins' positions and dimensions of this IC, please give the following pin parameters in pure numbers: "
                      "pin count, pin coordinates: \"name, x, y;\", and pin dimensions: \"name, dx, dy;\".")
        return intro + other_task + outro

    para_line = ", ".join(paras)
    # True when the plan distinguishes side-pin (dx1/dy1) and center-pin (dx2/dy2) sizes.
    has_split_dims = any(p in paras for p in ["dx1", "dx2", "dy1", "dy2"])
    if has_split_dims:
        para_line += ", where dx1, dy1 are the dimension of side pins, and dx2, dy2 are the dimension of the center pin"

    pieces = [
        intro,
        f" To describe the pins' positions and dimensions of this IC, please give the following pin parameters in pure numbers: {para_line}.",
    ]

    hint = gen_reasoning(ic_type, paras)
    if hint:
        pieces.append(" " + hint)

    if ic_type == "2-sides":
        pieces.append(" Finally, give the pin parameters in pure numbers: row, column, row spacing, column spacing, dx, dy.")
    elif ic_type == "4-sides":
        dims = "dx1, dy1, dx2, dy2." if has_split_dims else "dx, dy."
        pieces.append(" Finally, give the pin parameters in pure numbers: pin per side, side-to-side distance, pin spacing within side, " + dims)

    pieces.append(outro)
    return "".join(pieces)

# ============================== Parsing Functions ==============================

def parse_cls_label(text: str) -> str:
    """Map a free-form classification reply to one of the four class labels.

    The reply is lower-cased and searched for "2-sides", "4-sides", "grid"
    and "other" in that order; the first label found is returned. Empty or
    unrecognized replies yield "other".
    """
    lowered = (text or "").strip().lower()
    known_labels = ("2-sides", "4-sides", "grid", "other")
    return next((label for label in known_labels if label in lowered), "other")

def parse_plan_answer(text: str) -> Tuple[str, List[str]]:
    """Parse the Stage-2 planning reply into ``(ic_type, parameter_names)``.

    If the reply contains an ``<answer>...</answer>`` section, only that
    section is inspected; otherwise the whole text is. Two case-insensitive
    fields are looked for:

    * ``IC type: <value>`` - the value runs to the next comma or line break
      and is returned lower-cased and stripped.
    * ``Extract parameters: <a>, <b>; ...`` - the rest of that line, split
      on commas/semicolons, with surrounding whitespace, dots and semicolons
      removed from each name.

    The parameter line is matched per line (``re.M``) on the stripped answer,
    so trailing blank lines or further lines after it no longer cause the
    parameters to be silently dropped. Empty names are discarded.

    Args:
        text: Raw model reply; may be empty or ``None``.

    Returns:
        ``(ic_type, params)``; ``ic_type`` is ``""`` and ``params`` is ``[]``
        when the respective field is absent.
    """
    if not text:
        return "", []
    ans = text
    if m := re.search(r"<\s*answer\s*>(.*?)<\s*/\s*answer\s*>", text, re.I | re.S):
        ans = m.group(1)
    ans = ans.strip()

    ic_type = ""
    if m := re.search(r"ic\s*type\s*:\s*([^\n,]+)", ans, re.I):
        ic_type = m.group(1).lower().strip()

    params: List[str] = []
    # re.M makes `$` match at the end of the parameter line itself rather
    # than only at the very end of the answer.
    if m := re.search(r"extract\s*parameters\s*:\s*(.+)$", ans, re.I | re.M):
        cleaned = (p.strip().strip(" .;") for p in re.split(r"[,;]", m.group(1)))
        params = [p for p in cleaned if p]
    return ic_type, params

# ============================== VLM Wrapper Class ==============================

class QwenVL_LLM:
    """Wrapper around a Qwen2-VL checkpoint loaded from a local directory.

    Loading never raises: on any failure the error is logged and
    ``available`` is set to False, after which ``_call`` returns placeholder
    strings instead of model output.

    Attributes:
        processor: The ``AutoProcessor`` for the checkpoint (None if loading
            failed before it was created).
        model: The loaded ``Qwen2VLForConditionalGeneration`` (None if
            loading failed).
        gcfg: Default ``GenerationConfig`` (greedy decoding, 512 new tokens).
        available: True only when processor and model loaded successfully.
    """

    def __init__(self, model_path: str, name: str = "Qwen", pbar: Optional[tqdm] = None):
        """Load processor and model from ``model_path``.

        Args:
            model_path: Directory containing the checkpoint; only local files
                are used.
            name: Label used in log messages.
            pbar: Optional progress bar, advanced by one step whether or not
                loading succeeded.
        """
        # Define every attribute up front so they exist after a failed load.
        self.processor = None
        self.model = None
        self.gcfg = None
        self.available = False
        try:
            mp = os.path.abspath(model_path)
            if not os.path.isdir(mp):
                raise FileNotFoundError(f"[{name}] Model directory not found: {mp}")
            has_cuda = torch.cuda.is_available()
            device_map = "auto" if has_cuda else None
            # bfloat16 only on CUDA devices with compute capability >= 8;
            # float32 otherwise (including CPU).
            use_bf16 = has_cuda and torch.cuda.get_device_capability()[0] >= 8
            dtype = torch.bfloat16 if use_bf16 else torch.float32
            self.processor = AutoProcessor.from_pretrained(mp, use_fast=True, local_files_only=True, trust_remote_code=True)
            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                mp, torch_dtype=dtype, device_map=device_map, local_files_only=True, trust_remote_code=True
            )
            self.gcfg = GenerationConfig(max_new_tokens=512, do_sample=False)
            self.available = True
        except Exception as e:
            log(f"[ERROR] Failed to load {name} from {model_path}: {e}")
            self.available = False
        # Identity check: truthiness of a progress-bar object is not a
        # reliable "was one supplied" test.
        if pbar is not None:
            pbar.update(1)

    def _call(self, prompts: List[str], images: List[Image.Image], gen_kwargs: Optional[dict] = None) -> List[str]:
        """Generate one reply per ``(prompt, image)`` pair.

        Previously only ``prompts[0]``/``images[0]`` were processed and the
        rest silently ignored; now every pair is processed in order, so the
        result has one entry per pair. Single-pair calls behave as before.

        Args:
            prompts: User prompts.
            images: Images, paired with ``prompts`` by position.
            gen_kwargs: Optional overrides merged into the default
                generation config (e.g. ``{"max_new_tokens": 1024}``).

        Returns:
            Replies in input order. A failed generation yields
            ``"[MODEL_ERROR] ..."`` for that pair; if the model is not loaded
            every entry is ``"[MODEL NOT LOADED]"``.
        """
        if not self.available:
            return ["[MODEL NOT LOADED]" for _ in prompts]
        gcfg = self.gcfg
        if gen_kwargs:
            gcfg = GenerationConfig.from_dict({**self.gcfg.to_dict(), **gen_kwargs})
        return [self._generate_one(prompt, image, gcfg) for prompt, image in zip(prompts, images)]

    def _generate_one(self, prompt: str, image: Image.Image, gcfg: Any) -> str:
        """Run a single system+user chat turn and return the parsed reply."""
        messages = [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}
        ]
        try:
            text_in = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            image_inputs, _ = process_vision_info(messages)
            inputs = self.processor(text=[text_in], images=image_inputs, return_tensors="pt").to(self.model.device)
            out = self.model.generate(**inputs, generation_config=gcfg)
            # The whole generated sequence is decoded; parse_assistant_response
            # then keeps only the text after the assistant marker.
            full_decoded_text = self.processor.batch_decode(out, skip_special_tokens=True)[0]
            return parse_assistant_response(full_decoded_text)
        except Exception as e:
            return f"[MODEL_ERROR] {e}"

# ============================== Stage1 Fusion Logic ==============================

def calculate_iou(box1: Tuple[float, ...], box2: Tuple[float, ...]) -> float:
    """Return the intersection-over-union of two ``(x, y, w, h)`` boxes.

    ``x, y`` is the top-left corner. The result is 0.0 when the boxes do not
    overlap or when the union area is not positive.
    """
    ax, ay, aw, ah = box1
    bx, by, bw, bh = box2

    # Edges of the overlap rectangle.
    left = max(ax, bx)
    top = max(ay, by)
    right = min(ax + aw, bx + bw)
    bottom = min(ay + ah, by + bh)

    # A negative extent means no overlap along that axis.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    inter_area = overlap_w * overlap_h

    union_area = (aw * ah) + (bw * bh) - inter_area
    if not union_area > 0:
        return 0.0
    return inter_area / union_area

def parse_vlm_bbox_from_string(text: str) -> Optional[Tuple[float, ...]]:
    """Extract the first parenthesized group of four numbers from ``text``.

    The group must look like ``(a, b, c, d)`` where each item consists of
    digits and dots only (no sign, no exponent).

    Returns:
        The four values as a tuple of floats, or None when no such group is
        found or one of its items is not a valid float (e.g. ``1.2.3``).
    """
    found = re.search(r"\((\s*[\d\.]+\s*,\s*[\d\.]+\s*,\s*[\d\.]+\s*,\s*[\d\.]+\s*)\)", text)
    if found is None:
        return None
    pieces = found.group(1).replace(" ", "").split(',')
    try:
        return tuple(float(piece) for piece in pieces)
    except ValueError:
        return None

def detect_with_yolo(yolo_model: Any, image_path: str) -> List[Dict[str, Any]]:
    """Run the YOLO detector on one image and return its boxes.

    Args:
        yolo_model: Callable detector, or a falsy value when none is loaded.
        image_path: Path of the image to analyze.

    Returns:
        One dict per box of the first result, each holding ``"box"`` as
        ``(x, y, w, h)`` with a top-left origin (converted from the
        detector's normalized center format ``xywhn``) and ``"confidence"``.
        An empty list is returned when YOLO is unavailable, no model was
        given, or detection raised (the error is logged as a warning).
    """
    if not _YOLO_OK or not yolo_model:
        return []

    found: List[Dict[str, Any]] = []
    try:
        results = yolo_model(image_path, verbose=False)
        for det in results[0].boxes:
            score = det.conf.item()
            cx, cy, bw, bh = det.xywhn.cpu().numpy()[0]
            # Shift from box center to top-left corner.
            top_left_box = (cx - bw / 2, cy - bh / 2, bw, bh)
            found.append({"box": top_left_box, "confidence": score})
    except Exception as e:
        log(f"[WARN] YOLO detection failed for {os.path.basename(image_path)}: {e}")
        return []
    return found

# ============================== Data Loading & Configuration ==============================

@dataclass
class Cfg:
    """Paths that configure one pipeline run."""

    # Directory the image names from the JSON file are resolved against.
    image_dir: str
    # JSON file listing the records to process.
    json_file: str
    # Checkpoint directories for Stage 1, Stage 2 and Stage 3.
    model_path1: str
    model_path2: str
    model_path3: str
    # YOLO weights; optional.
    yolo_model_path: Optional[str] = None
    # Per-stage JSONL output files.
    out_stage1: str = "generated_predictions_stage1.jsonl"
    out_stage2: str = "generated_predictions_stage2.jsonl"
    out_stage3: str = "generated_predictions_stage3.jsonl"

def load_items_from_json(cfg: "Cfg") -> List[Dict[str, Any]]:
    """Load the records to process and resolve their image paths.

    The JSON file must hold an array of objects. For each object with a
    non-empty ``"images"`` list, the first entry is joined with
    ``cfg.image_dir``; records whose image file does not exist are skipped
    with a warning. Entries that are not objects, or that have no images,
    are skipped silently.

    Compared to the earlier version, every read/decode failure (not only a
    missing file or invalid JSON, but also e.g. permission or encoding
    errors) is reported and yields an empty list, as does a top-level JSON
    value that is not an array. A malformed ``"images"`` field skips the
    record instead of raising.

    Args:
        cfg: Configuration; only ``json_file`` and ``image_dir`` are read.

    Returns:
        A list of ``{"image_path": <path>, "raw": <original record>}`` dicts
        in file order; empty when the file cannot be used.
    """
    try:
        with open(cfg.json_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers missing/unreadable files; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        log(f"[ERROR] Could not read or parse JSON file: {cfg.json_file} ({e})")
        return []
    if not isinstance(data, list):
        log(f"[ERROR] Expected a JSON array of records in {cfg.json_file}, got {type(data).__name__}")
        return []

    items = []
    for entry in data:
        if not isinstance(entry, dict):
            continue
        images = entry.get("images")
        if not images:
            continue
        if not isinstance(images, list) or not isinstance(images[0], str):
            log(f"[WARN] Malformed 'images' field, skipping record: {images!r}")
            continue
        img_path = os.path.join(cfg.image_dir, images[0])
        if os.path.isfile(img_path):
            items.append({"image_path": img_path, "raw": entry})
        else:
            log(f"[WARN] Image file not found, skipping: {img_path}")
    return items

# ============================== Main Pipeline ==============================

def _format_bbox(bbox, decimals: int) -> str:
    """Render a box as "(x,y,w,h)" with a fixed number of decimals."""
    return "(" + ",".join(f"{v:.{decimals}f}" for v in bbox) + ")"


def _fuse_stage1_bbox(vlm_box, yolo_results, iou_threshold, fallback):
    """Combine the VLM box and the YOLO detections into one Stage-1 box.

    * No YOLO detections: use the VLM box, or ``fallback`` if there is none.
    * No VLM box: use the most confident YOLO box.
    * Both: take the YOLO box overlapping the VLM box most. If their IoU
      exceeds ``iou_threshold`` return the confidence-weighted average of the
      two boxes, otherwise that YOLO box unchanged.
    """
    if not yolo_results:
        return vlm_box if vlm_box else fallback
    if not vlm_box:
        return max(yolo_results, key=lambda r: r['confidence'])['box']
    best_yolo = max(yolo_results, key=lambda r: calculate_iou(vlm_box, r['box']))
    if calculate_iou(vlm_box, best_yolo['box']) > iou_threshold:
        conf = best_yolo["confidence"]
        return tuple(conf * y + (1 - conf) * v for y, v in zip(best_yolo['box'], vlm_box))
    return best_yolo["box"]


def run_pipeline(cfg: Cfg, skip_yolo: bool = False, max_images: int = 0):
    """Run Stage 1 (localization), Stage 2 (class & plan) and Stage 3 (extraction).

    Each stage processes every loaded record and writes its predictions to
    the corresponding ``cfg.out_stage*`` JSONL file before the next stage
    starts.

    Args:
        cfg: Input, model and output paths.
        skip_yolo: When True, YOLO is neither loaded nor used; Stage 1 relies
            on the VLM box alone.
        max_images: If positive, only the first ``max_images`` records are
            processed.

    Returns:
        None. The function returns early (after logging) when there are no
        records or when one of the three model paths is not set; previously
        an unset model path led to an AttributeError in the stage that
        needed the model.
    """
    items = load_items_from_json(cfg)
    log(f"Loaded {len(items)} valid records from JSON.")
    if max_images > 0:
        items = items[:max_images]
        log(f"Processing first {max_images} records.")
    if not items:
        return

    # Every stage calls its model unconditionally, so all three are required.
    model_paths = {
        "model_path1": cfg.model_path1,
        "model_path2": cfg.model_path2,
        "model_path3": cfg.model_path3,
    }
    unset = [field for field, path in model_paths.items() if not path]
    if unset:
        log(f"[ERROR] Missing model path(s): {', '.join(unset)}")
        return

    use_yolo = bool(cfg.yolo_model_path) and not skip_yolo
    yolo_model = None
    with tqdm(total=3 + int(use_yolo), desc="Loading models") as pbar:
        model1 = QwenVL_LLM(cfg.model_path1, "Qwen(S1)", pbar)
        model2 = QwenVL_LLM(cfg.model_path2, "Qwen(S2)", pbar)
        model3 = QwenVL_LLM(cfg.model_path3, "Qwen(S3)", pbar)
        if use_yolo:
            # A YOLO load failure is not fatal: Stage 1 then gets no
            # detections and falls back to the VLM box.
            try:
                yolo_model = YOLO(cfg.yolo_model_path)
                pbar.update(1)
            except Exception as e:
                log(f"[ERROR] Failed to load YOLO model: {e}")

    # ---------- Stage1: Diagram detection ----------
    iou_threshold = 0.4
    default_bbox = (0.3, 0.3, 0.4, 0.4)  # used when neither source yields a box
    s1_rows, s1_cache_bbox = [], {}
    with tqdm(items, desc="Stage 1: Fusion") as pbar:
        for it in pbar:
            img_path = it["image_path"]
            with Image.open(img_path).convert("RGB") as img:
                vlm_pred = model1._call([S1_PROMPT], [img])[0]
            vlm_box = parse_vlm_bbox_from_string(vlm_pred)
            yolo_results = [] if skip_yolo else detect_with_yolo(yolo_model, img_path)
            final_bbox = _fuse_stage1_bbox(vlm_box, yolo_results, iou_threshold, default_bbox)
            s1_cache_bbox[img_path] = final_bbox
            s1_rows.append({
                "image": os.path.basename(img_path),
                "prompt": S1_PROMPT,
                "predict": _format_bbox(final_bbox, 4),
            })
    write_jsonl(cfg.out_stage1, s1_rows)

    # ---------- Stage2: Classification & Planning ----------
    s2_rows, plan_params, cls_cache = [], {}, {}
    with tqdm(items, desc="Stage 2: Class & Plan") as pbar:
        for it in pbar:
            img_path = it["image_path"]
            bbox_str = _format_bbox(s1_cache_bbox[img_path], 4)
            cls_prompt = S2_CLS_PROMPT_TMPL.format(bbox=bbox_str)
            with Image.open(img_path).convert("RGB") as img:
                cls_pred = model2._call([cls_prompt], [img])[0]
                cls_label = parse_cls_label(cls_pred)
                cls_cache[img_path] = cls_label
                plan_prompt = build_s2_plan_prompt(bbox_str, cls_label)
                plan_pred = model2._call([plan_prompt], [img], gen_kwargs={"max_new_tokens": 1024})[0]
            ic, params = parse_plan_answer(plan_pred)
            # Fall back to the classifier's label when the plan names no type.
            plan_params[img_path] = (ic or cls_label, params)
            image_name = os.path.basename(img_path)
            s2_rows.extend([
                {"image": image_name, "prompt": cls_prompt, "predict": cls_pred},
                {"image": image_name, "prompt": plan_prompt, "predict": plan_pred}
            ])
    write_jsonl(cfg.out_stage2, s2_rows)

    # ---------- Stage3: Parameter Extraction ----------
    s3_rows = []
    with tqdm(items, desc="Stage 3: Extraction") as pbar:
        for it in pbar:
            img_path = it["image_path"]
            # NOTE(review): Stage 3 formats the box with 2 decimals while
            # Stages 1 and 2 use 4 — kept as-is; confirm this is intended.
            bbox_str = _format_bbox(s1_cache_bbox[img_path], 2)
            ic_type, paras = plan_params.get(img_path, (cls_cache.get(img_path, "other"), []))
            s3_prompt = build_s3_prompt(ic_type, paras, bbox_str)
            with Image.open(img_path).convert("RGB") as img:
                s3_pred = model3._call([s3_prompt], [img], gen_kwargs={"max_new_tokens": 1024})[0]
            s3_rows.append({"image": os.path.basename(img_path), "prompt": s3_prompt, "predict": s3_pred})
    write_jsonl(cfg.out_stage3, s3_rows)

    # Drop the model references and release cached GPU memory.
    del model1, model2, model3, yolo_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    log("Pipeline finished successfully.")

# ============================== Command-Line Interface ==============================

def main():
    """Command-line entry point: parse arguments and run the pipeline.

    ``--image_dir``, ``--json_file`` and the three ``--model_path*`` options
    are always required; ``--yolo_model_path`` is required unless
    ``--skip_yolo`` is given. When a required option is missing or empty,
    the parser reports the error on stderr and the process exits with
    status 2 (previously the function returned normally, i.e. exit status 0).
    """
    parser = argparse.ArgumentParser(description="Run the IC Datasheet Analysis Pipeline.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--image_dir", type=str, help="Directory containing the datasheet images.")
    parser.add_argument("--json_file", type=str, help="JSON file listing the images to process.")
    parser.add_argument("--model_path1", type=str, help="Path to the Stage1 VLM.")
    parser.add_argument("--model_path2", type=str, help="Path to the Stage2 VLM.")
    parser.add_argument("--model_path3", type=str, help="Path to the Stage3 VLM.")
    parser.add_argument("--yolo_model_path", type=str, help="Path to the fine-tuned YOLO model weights.")
    parser.add_argument("--out_stage1", type=str, default="stage1_output.jsonl", help="Output file for Stage 1.")
    parser.add_argument("--out_stage2", type=str, default="stage2_output.jsonl", help="Output file for Stage 2.")
    parser.add_argument("--out_stage3", type=str, default="stage3_output.jsonl", help="Output file for Stage 3.")
    parser.add_argument("--skip_yolo", action="store_true", help="Disable YOLO and use only VLM for Stage 1.")
    parser.add_argument("--max_images", type=int, default=0, help="Limit processing to the first N images for debugging (0=all).")
    args = parser.parse_args()

    # The YOLO weights are only mandatory when YOLO is actually used, so the
    # requirement is checked here rather than via required=True.
    required = ["image_dir", "json_file", "model_path1", "model_path2", "model_path3"]
    if not args.skip_yolo:
        required.append("yolo_model_path")
    missing = [f"--{arg}" for arg in required if not getattr(args, arg)]
    if missing:
        # parser.error prints usage plus the message and exits with status 2.
        parser.error(f"Missing required arguments: {', '.join(missing)}")

    cfg = Cfg(
        image_dir=args.image_dir, json_file=args.json_file,
        model_path1=args.model_path1, model_path2=args.model_path2, model_path3=args.model_path3,
        yolo_model_path=args.yolo_model_path,
        out_stage1=args.out_stage1, out_stage2=args.out_stage2, out_stage3=args.out_stage3
    )
    run_pipeline(cfg, skip_yolo=args.skip_yolo, max_images=args.max_images)

if __name__ == "__main__":
    main()