#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
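From the ../MRI_7_20_v3 directory:
1. For each apple, find the rectangle with the largest area across all JSON files;
2. Take the image that holds that largest box together with the 1 image before
   and the 2 images after it (4 images in total) and crop all 4 with the same box;
3. Stitch the 4 crops in temporal order into a 2×2 montage
   (top-left = -1, top-right = 0, bottom-left = +1, bottom-right = +2);
4. Save the result to LLM_data/MNR_figure/{absolute apple id}.png.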
"""

"""
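Build the 2×2 montage: missing frames are not padded with black; the search
automatically extends forward/backward until 4 frames are found.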
"""

import re
import json
from pathlib import Path
from PIL import Image

# Root directory that is scanned for per-folder JSON annotations; the relative
# path is resolved against the current working directory.
SRC_ROOT = Path("../MRI_7_20_v3").resolve()
# Output directory for the generated montages (also resolved against the
# current working directory).
DST_DIR  = Path("LLM_data/MNR_figure").resolve()
# NOTE: runs at import time, so merely importing this module creates the
# output directory (including missing parents).
DST_DIR.mkdir(parents=True, exist_ok=True)


def parse_folder_start_id(folder_name: str) -> int:
    """Return the first run of digits found in *folder_name* as an integer.

    Raises:
        ValueError: if the name contains no digit at all.
    """
    hit = re.search(r"\d+", folder_name)
    if hit is not None:
        return int(hit.group(0))
    raise ValueError(f"无法解析文件夹名中的编号：{folder_name}")


def get_bbox_area(bbox):
    """Return the area of *bbox*, given as ``(x1, y1, x2, y2)``.

    Absolute differences are used, so the result is the same regardless of
    which corner comes first.
    """
    width = abs(bbox[2] - bbox[0])
    height = abs(bbox[3] - bbox[1])
    return width * height


def load_json(path):
    """Read the UTF-8 encoded file at *path* and return its parsed JSON content."""
    with open(path, encoding="utf-8") as handle:
        text = handle.read()
    return json.loads(text)


def update_max_bbox(record, apple_id, bbox, img_path):
    """Store *bbox* for *apple_id* in *record* if it is the largest seen so far.

    *record* is mutated in place. An entry is written when the apple has no
    entry yet or when the new area is strictly greater than the stored one,
    so on equal areas the earlier entry is kept.
    """
    size = get_bbox_area(bbox)
    # Keep the existing entry unless the new box is strictly larger.
    if apple_id in record and not size > record[apple_id]["area"]:
        return
    record[apple_id] = {"area": size, "bbox": bbox, "img": img_path}


def collect_max_bboxes():
    """Scan every JSON file under ``SRC_ROOT`` and record each apple's largest box.

    Each sub-directory of ``SRC_ROOT`` contributes a start id parsed from its
    name; a shape's integer label is relative to that start id (1-based), so
    the absolute apple id is ``start_id + label - 1``.

    Shapes are skipped when the label is missing or not an integer, or when
    ``points`` is missing or is not exactly two numeric ``(x, y)`` pairs.

    Returns:
        dict mapping apple id to ``{"area", "bbox", "img"}`` where ``bbox`` is
        ``(left, upper, right, lower)`` in integer pixels and ``img`` is the
        JSON path with its suffix replaced by ``.png``.

    Raises:
        ValueError: propagated from ``parse_folder_start_id`` when a folder
            name contains no digits.
    """
    max_bbox_record = {}
    # Sorted traversal keeps the result reproducible: which entry wins on
    # equal areas no longer depends on the file system's listing order.
    for folder in sorted(SRC_ROOT.iterdir()):
        if not folder.is_dir():
            continue
        start_id = parse_folder_start_id(folder.name)
        for json_path in sorted(folder.glob("*.json")):
            data = load_json(json_path)
            for shape in data.get("shapes", []):
                try:
                    rel_label = int(shape["label"])
                    (x1, y1), (x2, y2) = shape["points"]
                    xs = sorted((int(x1), int(x2)))
                    ys = sorted((int(y1), int(y2)))
                except (KeyError, TypeError, ValueError):
                    # Not a two-point rectangle with a numeric label.
                    continue
                apple_id = start_id + rel_label - 1
                # Normalise corner order: the two points may be stored in
                # either order, but Image.crop needs left <= right and
                # upper <= lower.
                bbox = (xs[0], ys[0], xs[1], ys[1])
                update_max_bbox(max_bbox_record, apple_id, bbox, json_path.with_suffix(".png"))
    return max_bbox_record


def collect_neighbour_imgs(center_img: Path, need: int = 4):
    """Collect up to *need* existing frames around *center_img*, in frame order.

    The file name must look like ``<prefix>-<digits>.<ext>``; sibling frames
    are looked up in the same directory with the same zero-padding width.

    Search strategy:
      1. try the offsets ``[-1, 0, +1, +2]`` relative to the centre frame;
      2. if still short, extend forward (+3, +4, ...) until the first gap;
      3. if still short, extend backward (-2, -3, ...) until the first gap.

    Args:
        center_img: path of the centre frame (it does not have to exist).
        need: number of frames wanted.

    Returns:
        List of existing paths sorted by ascending frame number. Its length
        may be smaller than *need* when not enough frames exist.

    Raises:
        ValueError: if the file name does not match the expected pattern.
    """
    m = re.match(r"(.*-)(\d+)(\.\w+)$", center_img.name)
    if not m:
        raise ValueError(f"无法解析文件名：{center_img}")
    prefix, idx_str, suffix = m.groups()
    idx = int(idx_str)
    pad = len(idx_str)

    def make(n):
        # Build the sibling path for frame number n, keeping the zero padding.
        return center_img.with_name(f"{prefix}{str(n).zfill(pad)}{suffix}")

    def in_frame_order(frames):
        # Frames found by the backward search are discovered last but belong
        # first; sorting by frame number restores the temporal order.
        return [frames[n] for n in sorted(frames)]

    offsets_primary = (-1, 0, 1, 2)
    found = {}  # frame number -> path

    # 1. primary offsets
    for off in offsets_primary:
        n = idx + off
        # Negative frame numbers cannot name a real file.
        if n >= 0:
            p = make(n)
            if p.exists():
                found[n] = p
        if len(found) == need:
            return in_frame_order(found)

    # 2. forward search, stops at the first missing frame
    n = idx + offsets_primary[-1] + 1
    while len(found) < need:
        p = make(n)
        if not p.exists():
            break
        found[n] = p
        n += 1

    # 3. backward search, stops at the first missing frame or at frame 0
    n = idx + offsets_primary[0] - 1
    while len(found) < need and n >= 0:
        p = make(n)
        if not p.exists():
            break
        found[n] = p
        n -= 1

    return in_frame_order(found)


def crop(img: Image.Image, bbox):
    """Return the region of *img* selected by *bbox*.

    *bbox* is handed unchanged to ``Image.crop``.
    """
    region = img.crop(bbox)
    return region


def compose_2x2(imgs):
    """Paste the first four images of *imgs* onto a 2×2 RGB canvas.

    The images are laid out row by row (top-left, top-right, bottom-left,
    bottom-right) in the order given; the cell size is taken from the first
    image. Returns ``None`` when fewer than four images are supplied.
    """
    if len(imgs) < 4:
        return None
    tile_w, tile_h = imgs[0].size
    sheet = Image.new("RGB", (tile_w * 2, tile_h * 2))
    for slot, tile in zip(range(4), imgs):
        # slot 0..3 -> (row, col) in row-major order
        row, col = divmod(slot, 2)
        sheet.paste(tile, (col * tile_w, row * tile_h))
    return sheet


def main():
    """Build and save one 2×2 montage per apple.

    For every apple returned by ``collect_max_bboxes`` the four frames chosen
    by ``collect_neighbour_imgs`` are cropped with the apple's largest box,
    stitched with ``compose_2x2`` and written to ``DST_DIR/<apple_id>.png``.
    Apples with fewer than four available frames are skipped and counted.
    """
    max_record = collect_max_bboxes()
    print(f"共找到 {len(max_record)} 颗苹果的最大框")

    skipped = 0
    for apple_id, info in max_record.items():
        bbox = info["bbox"]
        center_img = info["img"]

        neighbours = collect_neighbour_imgs(center_img, need=4)
        if len(neighbours) < 4:
            skipped += 1
            print(f"[!] Apple {apple_id} 缺帧过多，已跳过")
            continue

        crops = []
        for p in neighbours:
            # Open each frame in a context manager so its file handle is
            # released as soon as the crop has been taken.
            with Image.open(p) as frame:
                crops.append(crop(frame, bbox))

        montage = compose_2x2(crops)
        if montage is None:
            skipped += 1
            continue

        out_path = DST_DIR / f"{apple_id}.png"
        montage.save(out_path)
        print(f"[✓] Apple {apple_id} => {out_path.relative_to(DST_DIR.parent)}")

    if skipped:
        print(f"\n共有 {skipped} 颗苹果因缺帧过多被跳过")


if __name__ == "__main__":
    main()

