#!/usr/bin/env python3
"""
Build controlled variants of a dataset to reduce *diversity* and/or *quality*.

Input (fixed):
  evals-cata-ffhq/all/0/generated_images.zip
  with a dataset.json like: {"labels":[["000000.png", 0], ["000001.png", 1], ...]}

Output (one directory per requested variant):
  cata/<label>/<variant>/
    ├─ generated_images.zip   # every image (degraded where targeted) + dataset.json
    └─ 00000/<lbl>/           # the first 100 images as standalone PNGs, by numeric label
  <variant> is named after the chosen parameters (see Naming below):
    - only --div given : one div_<k> variant per value (k = anchor count)
    - only --qua given : one qua_<s> variant per value (s = strength)
    - both given       : only combined div_<k>_qua_<s> variants are built
                         (full cartesian product, or pairwise with --pairs)

Arguments:
  label        : 'male' | 'female' | 'all' -> which subset(s) to DEGRADE
  --div ...    : zero or more ints (e.g., --div 1000 2000 5000)
  --qua ...    : zero or more floats (e.g., --qua 0.25 0.5 0.75)
  --pairs      : combine lists pairwise instead of full cartesian product

Naming:
  - Only div:  div_<k>            (e.g., div_1000)
  - Only qua:  qua_<s>            (e.g., qua_0p25)
  - Both:      div_<k>_qua_<s>    (e.g., div_1000_qua_0p5)
  (floats are sanitized: 0.5 -> 0p5, 0.250 -> 0p25)
"""

import argparse
import io
import json
import math
import os
import random
import zipfile
import zlib
from collections import defaultdict
from typing import Dict, List, Tuple

from PIL import Image, ImageFilter
import numpy as np
from tqdm import tqdm


def _san_float(x: float) -> str:
    s = f"{x:.6f}".rstrip("0").rstrip(".")
    return s.replace(".", "p")


# ---- Change these if your numeric mapping differs ----
# Numeric class ids, compared against the label values read from dataset.json.
MALE_ID = 0
FEMALE_ID = 1
# ------------------------------------------------------

# Source archive; a relative path, so it resolves against the current working directory.
SRC_ZIP = os.path.join("evals-cata-ffhq", "all", "0", "generated_images.zip")


# --------------- Quality degradation ops ----------------
def jpeg_reencode(img: Image.Image, quality: int) -> Image.Image:
    """Round-trip ``img`` through an in-memory JPEG and return the decoded RGB image.

    ``quality`` is cast to ``int`` and clamped to the range [5, 95] before encoding.
    """
    clamped = min(95, int(quality))
    if clamped < 5:
        clamped = 5
    stream = io.BytesIO()
    img.save(stream, format="JPEG", quality=clamped, optimize=False)
    stream.seek(0)
    decoded = Image.open(stream)
    return decoded.convert("RGB")


def add_gaussian_noise(img: Image.Image, std: float) -> Image.Image:
    """Add zero-mean Gaussian noise to every pixel value of ``img``.

    Args:
        img: Source image; its pixels are read through ``np.asarray``.
        std: Noise standard deviation as a fraction of the 8-bit range (it is
            multiplied by 255). Values <= 0 return ``img`` itself, unchanged.

    Returns:
        A new uint8 image with values clipped to [0, 255].

    Noise is drawn from NumPy's global RNG, so output is reproducible only if
    the caller seeds ``np.random`` beforehand.
    """
    if std <= 0:
        return img
    arr = np.asarray(img).astype(np.float32)
    noise = np.random.normal(0.0, std * 255.0, size=arr.shape).astype(np.float32)
    out = np.clip(arr + noise, 0, 255).astype(np.uint8)
    # Let Pillow infer the mode from the array's shape/dtype instead of forcing
    # "RGB": an (H, W, 3) uint8 array still yields RGB, while grayscale or RGBA
    # inputs no longer hit a mode/shape mismatch.
    return Image.fromarray(out)


def blur(img: Image.Image, radius: float) -> Image.Image:
    """Gaussian-blur ``img`` with ``radius``; a radius <= 0 returns ``img`` untouched."""
    if radius <= 0:
        return img
    kernel = ImageFilter.GaussianBlur(radius=radius)
    return img.filter(kernel)


def sinusoidal_warp(img: Image.Image, amp: float, freq: float) -> Image.Image:
    """Cheap wavy distortion: cyclically shift each row horizontally along a sine.

    Args:
        img: Source image.
        amp: Peak shift in pixels. Values <= 0 return ``img`` unchanged.
        freq: Number of sine cycles across the image height. Values <= 0
            return ``img`` unchanged.

    Returns:
        A new image in which row ``y`` is rolled by
        ``int(amp * sin(2*pi*freq * y / (h - 1)))`` pixels; pixels pushed past
        one edge wrap around to the other.

    NOTE(review): the shift is truncated toward zero, so for ``amp`` below 1 a
    row only moves where the sine evaluates to exactly +/-1; in practice the
    warp is then a no-op. Confirm whether rounding was intended before
    changing it, since that would alter generated datasets.
    """
    if amp <= 0 or freq <= 0:
        return img
    arr = np.asarray(img)
    h = arr.shape[0]
    # Every row is overwritten below, so no zero-fill is needed.
    out = np.empty_like(arr)
    # Loop invariants, hoisted with the original evaluation order preserved.
    angular = 2 * math.pi * freq
    denom = max(1, h - 1)
    for y in range(h):
        shift = int(amp * math.sin(angular * (y / denom)))
        # arr[y] is one row (width first), so axis 0 is the horizontal axis.
        out[y] = np.roll(arr[y], shift, axis=0)
    return Image.fromarray(out)


def degrade_quality(img: Image.Image, strength: float, seed: int) -> Image.Image:
    """
    Degrade ``img`` with a fixed chain:
    JPEG re-encode -> blur -> gaussian noise -> sinusoidal warp.

    Every stage parameter scales linearly with ``strength`` (typically 0.1 .. 1.0).
    Both the ``random`` module and NumPy's global RNG are re-seeded from ``seed``
    (NumPy receives only its low 16 bits), so calling this resets global RNG state.
    """
    np.random.seed(seed & 0xFFFF)
    random.seed(seed)

    # Stage parameters as functions of strength.
    jpeg_quality = 95 - int(60 * strength)  # 95 at strength 0, 35 at strength 1
    blur_radius = 0.5 * strength * 2.0  # i.e. equal to strength
    noise_sigma = 0.03 * strength * 2.0  # 0.06 at strength 1
    warp_amp = 1.0 * strength  # 1 px at strength 1
    warp_freq = 0.03 + 0.5 * strength  # 0.03 at strength 0, 0.53 at strength 1

    stage = jpeg_reencode(img, jpeg_quality)
    stage = blur(stage, blur_radius)
    stage = add_gaussian_noise(stage, noise_sigma)
    return sinusoidal_warp(stage, warp_amp, warp_freq)


# --------------- Diversity reduction ops ----------------
def cheap_permutation(img: Image.Image, rng=None) -> Image.Image:
    """
    Return ``img`` or its horizontal mirror, each with probability 0.5, so that
    repeated anchors aren't byte-identical.

    Args:
        img: Source image (never modified).
        rng: Optional object with a ``random()`` method, e.g. a seeded
            ``random.Random``. When omitted, the module-level ``random``
            generator is used, so the outcome is reproducible only if the
            caller seeded it beforehand.

    Returns:
        ``img`` itself, or a new left-right flipped copy.
    """
    source = random if rng is None else rng
    if source.random() < 0.5:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def build_anchor_indices(filelist: List[str], anchor_count: int) -> List[str]:
    """
    Return the anchor filenames: the first N entries of ``filelist`` in sorted
    order, where N is ``anchor_count`` clamped to [1, len(filelist)].

    Despite the name, the result holds filenames, not indices. An empty
    ``filelist`` yields an empty list.
    """
    limit = min(anchor_count, len(filelist))
    if limit < 1:
        limit = 1
    ordered = list(filelist)
    ordered.sort()
    return ordered[:limit]


def map_to_anchor(filename: str, anchors: List[str]) -> int:
    """
    Deterministic mapping filename -> anchor index.

    The index is the CRC-32 of the UTF-8 encoded filename modulo
    ``len(anchors)``. The built-in ``hash()`` is deliberately avoided: string
    hashes are salted per interpreter process, which would assign files to
    different anchors on every run.

    Args:
        filename: Name of the file to be replaced by an anchor.
        anchors: Non-empty list of anchor filenames.

    Returns:
        An index in ``range(len(anchors))``.

    Raises:
        ValueError: If ``anchors`` is empty.
    """
    if not anchors:
        raise ValueError("anchors must not be empty")
    return zlib.crc32(filename.encode("utf-8")) % len(anchors)


# ------------------------ IO utils ----------------------
def read_dataset(zip_path: str) -> Tuple[zipfile.ZipFile, List[Tuple[str, int]]]:
    """
    Open the source archive and parse its label metadata.

    Args:
        zip_path: Path of the zip archive (anything ``zipfile.ZipFile`` accepts).

    Returns:
        ``(z, items)`` where ``z`` is the still-open archive, which the caller
        must close, and ``items`` is the list of ``(filename, int_label)``
        pairs taken from the metadata's "labels" entry, in file order.

    Raises:
        KeyError: If the archive contains no metadata file, or the metadata
            has no "labels" entry.

    The archive is closed again if anything goes wrong while reading the
    metadata, so a failed call does not leak the file handle.
    """
    z = zipfile.ZipFile(zip_path, "r")
    try:
        names = set(z.namelist())
        # NOTE(review): "datajson" is the fallback name used by the original
        # code; it looks like a typo but is kept for backward compatibility.
        meta_name = next(
            (cand for cand in ("dataset.json", "datajson") if cand in names), None
        )
        if meta_name is None:
            raise KeyError(f"no 'dataset.json' found in archive {zip_path!r}")
        meta = json.loads(z.read(meta_name).decode("utf-8"))
        items = [(fn, int(lbl)) for fn, lbl in meta["labels"]]
    except Exception:
        z.close()
        raise
    return z, items


def group_by_label(items: List[Tuple[str, int]]) -> Dict[int, List[str]]:
    """Bucket filenames by label, keeping input order within each bucket.

    Returns a ``defaultdict(list)`` keyed by label.
    """
    buckets: Dict[int, List[str]] = defaultdict(list)
    for entry in items:
        name, label = entry
        buckets[label].append(name)
    return buckets


def ensure_dir(path: str):
    """Create directory ``path`` (including parents) if it does not exist yet.

    An empty ``path`` refers to the current directory and is a no-op;
    ``os.makedirs("")`` would otherwise raise ``FileNotFoundError``.
    """
    if path:
        os.makedirs(path, exist_ok=True)


def write_dataset_json(dest_dir: str, items: List[Tuple[str, int]]):
    """Write ``dataset.json`` into ``dest_dir``.

    The file holds ``{"labels": [[filename, label], ...]}`` with every label
    coerced to ``int``, in the order given by ``items``.
    """
    labels = []
    for fn, lbl in items:
        labels.append([fn, int(lbl)])
    target = os.path.join(dest_dir, "dataset.json")
    with open(target, "w", encoding="utf-8") as f:
        json.dump({"labels": labels}, f)


def save_image(path: str, img: Image.Image):
    """Save ``img`` as a PNG at ``path``, creating the parent directory if needed.

    A bare filename has an empty parent directory, which means the current
    working directory; in that case no directory is created.
    """
    parent = os.path.dirname(path)
    if parent:
        ensure_dir(parent)
    img.save(path, format="PNG")


# --------------------- Core pipeline --------------------
def _read_rgb(z: zipfile.ZipFile, name: str) -> Image.Image:
    """Decode archive member ``name`` of ``z`` as an RGB image."""
    return Image.open(io.BytesIO(z.read(name))).convert("RGB")


def process_variant(
    label_mode: str,
    variant_name: str,
    items: List[Tuple[str, int]],
    z: zipfile.ZipFile,
    do_diversity: bool,
    div_anchor_count: int,
    do_quality: bool,
    quality_strength: float,
):
    """
    Create one variant subdir under cata/<label_mode>/<variant_name>

    Every image listed in ``items`` is written as PNG, under its original
    name, into ``generated_images.zip`` together with a ``dataset.json``; the
    first 100 images are additionally saved as standalone PNGs under
    ``00000/<label>/``. Images whose label is not targeted are re-encoded but
    otherwise left as they are.

    Args:
        label_mode: 'male', 'female' or 'all'; selects the labels to degrade.
        variant_name: Output directory name; also part of the quality seed.
        items: ``(filename, label)`` pairs describing the source archive.
        z: Open source archive the images are read from.
        do_diversity: Replace each targeted image by one of the anchors of its
            label (randomly mirrored).
        div_anchor_count: Number of anchors per targeted label.
        do_quality: Apply ``degrade_quality`` to each targeted image.
        quality_strength: Strength passed to ``degrade_quality``.

    Raises:
        ValueError: If ``label_mode`` is not one of the accepted values.

    Seeds are derived with CRC-32 rather than the built-in ``hash()``, whose
    string hashes are salted per process and would make every run differ.
    The global ``random`` state is re-seeded while processing.
    """
    # Determine which numeric labels to degrade (before touching the disk, so
    # that a bad label_mode leaves no empty directories behind).
    if label_mode == "male":
        target_labels = {MALE_ID}
    elif label_mode == "female":
        target_labels = {FEMALE_ID}
    elif label_mode == "all":
        target_labels = {FEMALE_ID, MALE_ID}
    else:
        raise ValueError("label_mode must be 'male', 'female', or 'all'.")

    out_root = os.path.join("cata", label_mode, variant_name)
    ensure_dir(out_root)
    dest_zip_path = os.path.join(out_root, "generated_images.zip")

    # Pre-compute anchors per targeted label (for diversity reduction)
    label_to_anchors: Dict[int, List[str]] = {}
    if do_diversity:
        by_label = group_by_label(items)
        for lbl in target_labels:
            label_to_anchors[lbl] = build_anchor_indices(
                by_label.get(lbl, []), div_anchor_count
            )

    # Open destination zip and stream images into it
    with zipfile.ZipFile(dest_zip_path, "w") as zout:
        desc = f"{variant_name:>14}"
        for idx, (fn, lbl) in enumerate(tqdm(items, desc=desc, ncols=88)):
            targeted = lbl in target_labels

            # Diversity: read the mapped anchor instead of the file itself.
            anchors = label_to_anchors.get(lbl, []) if targeted else []
            if anchors:
                final_img = _read_rgb(z, anchors[map_to_anchor(fn, anchors)])
                # Seed the flip from the filename alone so that the same file
                # is mirrored identically in every variant and on every run.
                random.seed(zlib.crc32(fn.encode("utf-8")))
                final_img = cheap_permutation(final_img)
            else:
                final_img = _read_rgb(z, fn)

            if targeted and do_quality:
                # crc32 is an unsigned 32-bit value, matching the range the
                # previous hash-based seed was masked to.
                seed = zlib.crc32(f"{variant_name}/{fn}".encode("utf-8"))
                final_img = degrade_quality(
                    final_img, strength=quality_strength, seed=seed
                )

            # 1) Write PNG bytestream into the destination ZIP under original filename
            buf = io.BytesIO()
            final_img.save(buf, format="PNG")
            zout.writestr(fn, buf.getvalue())

            # 2) Save ONLY the first 100 images as standalone PNGs under 00000/<lbl>/
            if idx < 100:
                # save_image creates the parent directory itself.
                save_image(os.path.join(out_root, "00000", str(lbl), fn), final_img)

        # Write dataset.json into the ZIP at the end; every item was written
        # under its original name, so the label list mirrors ``items``.
        meta = {"labels": [[f, int(l)] for (f, l) in items]}
        zout.writestr("dataset.json", json.dumps(meta))


def main():
    """Parse the command line and build every requested variant.

    With only ``--div`` or only ``--qua`` values, one variant per value is
    built. With both, the values are combined (cartesian product, or pairwise
    with ``--pairs``) and only the combined variants are built. Anchor counts
    are clamped to >= 1 and strengths to >= 0.

    The source archive is opened only once there is work to do and is closed
    when processing ends, whether it succeeds or fails.
    """
    parser = argparse.ArgumentParser(
        description="Build diversity/quality-reduced dataset variants."
    )
    parser.add_argument(
        "label",
        choices=["male", "female", "all"],
        help="Which label(s) to DEGRADE (others are kept unchanged).",
    )
    parser.add_argument(
        "--div",
        type=int,
        nargs="*",
        default=None,
        help="Anchor counts for diversity reduction (e.g., --div 1000 2000). If omitted, no diversity variants are produced.",
    )
    parser.add_argument(
        "--qua",
        type=float,
        nargs="*",
        default=None,
        help="Quality strengths for degradation (e.g., --qua 0.25 0.5). If omitted, no quality variants are produced.",
    )
    parser.add_argument(
        "--pairs",
        action="store_true",
        help="Pairwise combine lists instead of full cartesian product (len aligns by min length).",
    )
    args = parser.parse_args()

    div_list = [] if args.div is None else [max(1, int(v)) for v in args.div]
    qua_list = [] if args.qua is None else [max(0.0, float(v)) for v in args.qua]

    if not div_list and not qua_list:
        print("No variants requested: provide --div and/or --qua values.")
        return

    # Build variant specs
    specs = []  # list of (name, do_div, div_count, do_qua, qua_strength)

    if div_list and not qua_list:
        for k in div_list:
            specs.append((f"div_{k}", True, k, False, 0.0))
    elif qua_list and not div_list:
        for s in qua_list:
            specs.append((f"qua_{_san_float(s)}", False, 0, True, s))
    else:
        if args.pairs:
            # zip() stops at the shorter list, i.e. aligns by min length.
            combos = list(zip(div_list, qua_list))
        else:
            combos = [(k, s) for k in div_list for s in qua_list]
        for k, s in combos:
            specs.append((f"div_{k}_qua_{_san_float(s)}", True, k, True, s))

    # Repeated or clamped values can produce the same variant name; building it
    # twice would only overwrite the same output directory.
    seen = set()
    unique_specs = []
    for spec in specs:
        if spec[0] not in seen:
            seen.add(spec[0])
            unique_specs.append(spec)

    # Load source once, only now that there is something to build, and make
    # sure the archive handle is released afterwards.
    z, items = read_dataset(SRC_ZIP)
    with z:
        for name, ddiv, dcount, dqua, qstr in unique_specs:
            process_variant(
                label_mode=args.label,
                variant_name=name,
                items=items,
                z=z,
                do_diversity=ddiv,
                div_anchor_count=dcount,
                do_quality=dqua,
                quality_strength=qstr,
            )

    print("Done.")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
