#%%
import os, json, cv2, numpy as np
from skimage.segmentation import slic
from skimage.color import rgb2lab
from tqdm import tqdm

def _greedy_color_groups(label_ids, mean_lab, color_thresh):
    """Group superpixel labels whose mean Lab colours lie close to a seed label.

    Labels are visited in order. Each label not yet assigned to a group becomes
    a seed and absorbs every *later*, still-unassigned label whose mean colour
    is within ``color_thresh`` (Euclidean distance, strict) of the seed's mean.

    Args:
        label_ids: 1-D array of superpixel labels, in visiting order.
        mean_lab: ``(len(label_ids), 3)`` array; row ``k`` is the mean Lab
            colour of ``label_ids[k]``.
        color_thresh: merge threshold on the seed-to-member distance.

    Returns:
        List of 1-D label arrays (ascending when ``label_ids`` is ascending);
        every label appears in exactly one group.
    """
    n = len(label_ids)
    assigned = np.zeros(n, dtype=bool)
    groups = []
    for seed in range(n):
        if assigned[seed]:
            # Already absorbed by an earlier seed: it must not start (or keep
            # collecting into) a group of its own.
            continue
        dists = np.linalg.norm(mean_lab[seed + 1:] - mean_lab[seed], axis=1)
        members = seed + 1 + np.flatnonzero(
            (dists < color_thresh) & ~assigned[seed + 1:]
        )
        assigned[members] = True
        groups.append(label_ids[np.concatenate(([seed], members))])
    return groups


def _is_valid_rgb(rgb, black_thresh=20, white_thresh=235):
    """Return False when every channel of *rgb* is near-black or near-white."""
    return not (
        all(c <= black_thresh for c in rgb)
        or all(c >= white_thresh for c in rgb)
    )


def superpixel_boxes_by_variance(
        image_rgb,
        n_segments=200,
        color_thresh=20,
        min_brightness=0,
        area_ratio_thresh=0.05,
        topk=10
):
    """Propose boxes around colour-coherent regions, ranked by Lab colour variance.

    The image is over-segmented with SLIC. Superpixels whose mean CIE-Lab
    colour is within ``color_thresh`` of a seed superpixel are merged into one
    group, and every connected component of a group whose bounding box is
    large enough becomes a candidate. Candidates are ranked by the Lab colour
    variance of the pixels inside the box; boxes whose most frequent colour is
    near-black or near-white are dropped.

    Args:
        image_rgb: RGB image; assumed ``(H, W, 3)`` uint8 in 0..255 (the code
            divides by 255 and reshapes regions to 3 channels) — confirm for
            other dtypes.
        n_segments: requested number of SLIC superpixels.
        color_thresh: Lab-distance threshold for merging superpixels.
        min_brightness: currently has no effect; kept for backward compatibility.
        area_ratio_thresh: minimum bounding-box area as a fraction of the image.
        topk: maximum number of boxes returned.

    Returns:
        Up to ``topk`` tuples ``((x1, y1, x2, y2), variance, mode_rgb)`` in
        descending variance order. Coordinates are inclusive Python ints and
        ``mode_rgb`` is the most frequent ``[R, G, B]`` colour in the box.
    """
    H, W = image_rgb.shape[:2]
    img_area = H * W

    img_float = np.clip(image_rgb.astype(np.float32) / 255.0, 0, 1)
    segments = slic(img_float, n_segments=n_segments, compactness=10, start_label=1)
    lab_image = rgb2lab(img_float)

    # Mean Lab colour per superpixel in a single bincount pass, rather than
    # building one full-image boolean mask per label.
    label_ids, inverse, counts = np.unique(
        segments.ravel(), return_inverse=True, return_counts=True
    )
    lab_flat = lab_image.reshape(-1, 3)
    mean_lab = np.column_stack([
        np.bincount(inverse, weights=lab_flat[:, ch], minlength=len(label_ids))
        for ch in range(3)
    ]) / counts[:, None]

    groups = _greedy_color_groups(label_ids, mean_lab, color_thresh)

    candidates = []
    for group in groups:
        mask = np.isin(segments, group).astype(np.uint8)
        # Label 0 is the background; foreground components are
        # 1..n_labels-1. Iterating that range (instead of dropping the first
        # unique label) keeps the lone component of a mask covering the
        # whole image.
        n_labels, _, stats, _ = cv2.connectedComponentsWithStats(mask)
        for cid in range(1, n_labels):
            x1 = int(stats[cid, cv2.CC_STAT_LEFT])
            y1 = int(stats[cid, cv2.CC_STAT_TOP])
            box_w = int(stats[cid, cv2.CC_STAT_WIDTH])
            box_h = int(stats[cid, cv2.CC_STAT_HEIGHT])
            if box_w * box_h < area_ratio_thresh * img_area:
                continue
            x2, y2 = x1 + box_w - 1, y1 + box_h - 1

            # Colour spread inside the box: mean squared Lab distance to the
            # box's mean colour. Reuse lab_image (built from [0, 1] input);
            # cv2.cvtColor on unscaled 0..255 float32 data assumes a 0..1
            # range and yields meaningless Lab values.
            pixels = lab_image[y1:y2 + 1, x1:x2 + 1].reshape(-1, 3)
            var = float(((pixels - pixels.mean(axis=0)) ** 2).sum(axis=1).mean())

            region_rgb = image_rgb[y1:y2 + 1, x1:x2 + 1].reshape(-1, 3)
            colours, freq = np.unique(region_rgb, axis=0, return_counts=True)
            most_common_rgb = colours[np.argmax(freq)].tolist()

            candidates.append(((x1, y1, x2, y2), var, most_common_rgb))

    candidates.sort(key=lambda cand: cand[1], reverse=True)
    filtered = [cand for cand in candidates if _is_valid_rgb(cand[2])]
    return filtered[:topk]


def process_multi_folder(
        input_dirs,
        output_dir="boxed_images_LVIS",
        json_out="bbox_and_variance_lvis.json",
        target_count=7000,
        *,
        max_side=512,
        **kwargs
):
    """Run superpixel_boxes_by_variance over images from several folders.

    Images (``.jpg``/``.jpeg``/``.png``, case-insensitive) are visited folder
    by folder in the order given, file names sorted within each folder. Each
    image is downscaled so its longer side is at most ``max_side`` and passed
    to ``superpixel_boxes_by_variance``. Images with at least one box get
    red rectangles drawn and are saved to ``output_dir`` as
    ``<parent-folder>_<fname>``. Processing stops once ``target_count``
    images have been saved.

    Args:
        input_dirs: directories to scan, in priority order.
        output_dir: where annotated images are written (created if missing).
        json_out: path of the JSON file mapping saved image name to its boxes.
        target_count: number of annotated images to collect.
        max_side: longest image side after downscaling; smaller images are
            left unchanged.
        **kwargs: forwarded to ``superpixel_boxes_by_variance``.

    Box coordinates in the JSON refer to the (possibly downscaled) saved
    image, not the original file.
    """
    os.makedirs(output_dir, exist_ok=True)
    results = {}
    collected = 0
    failed = 0

    # Step 1: gather image paths, folder by folder in the order given.
    all_image_paths = [
        (os.path.join(dir_path, fname), fname)
        for dir_path in input_dirs
        for fname in sorted(os.listdir(dir_path))
        if fname.lower().endswith((".jpg", ".jpeg", ".png"))
    ]

    print(f"🔍 Total candidate images: {len(all_image_paths)}")

    with tqdm(total=target_count, desc="Collecting boxes") as pbar:
        for full_path, fname in all_image_paths:
            if collected >= target_count:
                break

            img_bgr = cv2.imread(full_path)
            if img_bgr is None:
                continue

            h, w = img_bgr.shape[:2]
            scale = max_side / max(h, w)
            if scale < 1.0:
                # max(1, ...) keeps very elongated images from collapsing to a
                # zero-sized dimension, which cv2.resize rejects.
                new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))
                img_bgr = cv2.resize(img_bgr, (new_w, new_h), interpolation=cv2.INTER_AREA)

            img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

            try:
                boxes = superpixel_boxes_by_variance(img_rgb, **kwargs)
            except Exception as exc:
                # Best-effort: skip this image, but make the failure visible
                # (a bad kwarg would otherwise silently skip every image).
                failed += 1
                tqdm.write(f"⚠️ Skipping {full_path}: {type(exc).__name__}: {exc}")
                continue

            if not boxes:
                continue

            # Draw boxes on the (possibly resized) image.
            for (x1, y1, x2, y2), _, _ in boxes:
                cv2.rectangle(img_bgr, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)

            # Generate unique name (e.g. val2017_xxx.jpg)
            src_folder = os.path.basename(os.path.dirname(full_path))
            fname_new = f"{src_folder}_{fname}"
            out_path = os.path.join(output_dir, fname_new)

            if not cv2.imwrite(out_path, img_bgr):
                failed += 1
                tqdm.write(f"⚠️ Could not write {out_path}; skipping")
                continue

            results[fname_new] = [
                {
                    "bbox": [int(x1), int(y1), int(x2), int(y2)],
                    "variance": round(var, 2),
                    "mode_rgb": mode_rgb
                }
                for (x1, y1, x2, y2), var, mode_rgb in boxes
            ]
            collected += 1
            pbar.update(1)

    with open(json_out, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    print(f"\n✅ Collected {collected} images and saved to {json_out}")
    print(f"📁 Images saved in: {output_dir}")
    if failed:
        print(f"⚠️ {failed} images were skipped because of errors")


# === Execute ===
if __name__ == "__main__":
    # Split folders are scanned in this order: val, then train, then test.
    lvis_root = "/fs/scratch/PAS2099/Herz_2/VFM/Dataset/counting/LVIS/downloads"
    splits = ("val2017", "train2017", "test2017")

    process_multi_folder(
        input_dirs=[f"{lvis_root}/{split}" for split in splits],
        output_dir="boxed_images_LVIS",
        json_out="bbox_and_variance_lvis.json",
        target_count=7000,
        # The remaining options are forwarded to superpixel_boxes_by_variance.
        n_segments=400,
        # NOTE(review): this threshold is compared against Euclidean distances
        # between mean rgb2lab colours, so 0.005 likely merges almost no
        # superpixels — confirm that effectively disabling merging is intended.
        color_thresh=0.005,
        area_ratio_thresh=0.001,
        topk=1,
    )
