"""
Down-sample images in ./output_lb and ./output_ub by 5×,
assemble them into two NumPy arrays ordered by the index i in
'plane_{i}_0.0

005.png', and save both arrays in one NPZ file.
"""

from pathlib import Path
import re
import numpy as np
from PIL import Image

# ------------------------------------------------------------------
# 1. Locate and sort all image paths
# ------------------------------------------------------------------
pat      = re.compile(r"plane_(\d+)_0\.0001\.png$")
root_lb  = Path("output_lb_0.0001_0-2pi")
root_ub  = Path("output_ub_0.0001_0-2pi")


def _indexed_planes(root):
    """Return ``(index, path)`` pairs for every plane image in *root*.

    The glob is looser than ``pat`` (``*`` also matches non-digits), so each
    name is re-checked against the regex and non-matching files are skipped
    instead of crashing the sort. Pairs are sorted numerically by index, so
    ``plane_10`` comes after ``plane_9``.
    """
    found = []
    for p in root.glob("plane_*_0.0001.png"):
        m = pat.match(p.name)
        if m is not None:
            found.append((int(m.group(1)), p))
    return sorted(found, key=lambda pair: pair[0])


lb_indexed = _indexed_planes(root_lb)
ub_indexed = _indexed_planes(root_ub)

# Images are later paired positionally with zip(), so the two folders must
# hold exactly the same plane indices, not merely the same number of files.
lb_indices = [i for i, _ in lb_indexed]
ub_indices = [i for i, _ in ub_indexed]
if lb_indices != ub_indices:
    raise ValueError(
        f"{root_lb} and {root_ub} must contain the same plane indices "
        f"(found {len(lb_indices)} vs {len(ub_indices)} images)"
    )
if not lb_indices:
    raise FileNotFoundError(
        f"No plane_*_0.0001.png images found in {root_lb} or {root_ub}"
    )

lb_paths = [p for _, p in lb_indexed]
ub_paths = [p for _, p in ub_indexed]

# ------------------------------------------------------------------
# 2. Helper to down-sample one image
# ------------------------------------------------------------------
def downsample(path, factor=5):
    """Load *path* as an RGB image and shrink it by an integer *factor*.

    Parameters
    ----------
    path : str or os.PathLike
        Image file to read; anything ``PIL.Image.open`` accepts.
    factor : int, default 5
        Shrink factor applied to both width and height via floor division.
        ``1`` keeps the original resolution.

    Returns
    -------
    numpy.ndarray
        ``uint8`` array of shape ``(H // factor, W // factor, 3)``.

    Raises
    ------
    ValueError
        If *factor* is smaller than 1.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor}")
    # Close the file handle deterministically instead of relying on Pillow's
    # lazy loading / garbage collection; convert() loads the pixel data, so
    # `rgb` stays valid after the file is closed.
    with Image.open(path) as img:
        rgb = img.convert("RGB")                         # ensure 3-channel
    w, h = rgb.size
    # NEAREST keeps one source pixel per output pixel (no averaging).
    small = rgb.resize((w // factor, h // factor), Image.NEAREST)
    return np.asarray(small, dtype=np.uint8)             # (H, W, 3)

# ------------------------------------------------------------------
# 3. Build the (N, H, W, 3) tensors in memory
#    (H, W are the dimensions after down-sampling by DOWNSAMPLE_FACTOR)
# ------------------------------------------------------------------
# Shrink factor used for every image. downsample() defaults to 5, but this
# script passes 1 (full resolution); keep 1 to preserve the current output.
DOWNSAMPLE_FACTOR = 1

sample = downsample(lb_paths[0], DOWNSAMPLE_FACTOR)          # probe size
N, H, W, C = len(lb_paths), *sample.shape
output_lb = np.empty((N, H, W, C), dtype=np.uint8)
output_ub = np.empty_like(output_lb)


def _load_checked(path):
    """Down-sample *path* and ensure it matches the probed (H, W, C) shape.

    An explicit check is needed because NumPy broadcasting would silently
    accept e.g. a (1, W, 3) array into an (H, W, 3) slot, and other
    mismatches would fail without naming the offending file.
    """
    arr = downsample(path, DOWNSAMPLE_FACTOR)
    if arr.shape != sample.shape:
        raise ValueError(
            f"{path} has shape {arr.shape} after down-sampling; "
            f"expected {sample.shape} (probed from {lb_paths[0]})"
        )
    return arr


for i, (plb, pub) in enumerate(zip(lb_paths, ub_paths)):
    output_lb[i] = _load_checked(plb)
    output_ub[i] = _load_checked(pub)

# ------------------------------------------------------------------
# 4. Save both arrays in a single compressed NPZ archive
# ------------------------------------------------------------------
# Single source of truth for the archive name, so the message below always
# names the file that was actually written.
OUTPUT_NPZ = "res_classification.npz"

np.savez_compressed(OUTPUT_NPZ,
                    output_lb=output_lb,
                    output_ub=output_ub)

print(f"Saved {N} down-sampled pairs to {OUTPUT_NPZ} "
      f"with shape {(N, H, W, C)}")

# import matplotlib.pyplot as plt 
# plt.figure(0)
# plt.imshow(output_lb[0])
# plt.figure(1)
# plt.imshow(output_ub[0])
# plt.show()
# print(output_lb.shape, output_ub.shape)
