import os
import math
import random
from pathlib import Path

import cv2
import numpy as np
from tqdm import tqdm
import albumentations as A


# ---------------------------
# Augmentation pipeline
# ---------------------------
def build_aug_pipeline(target_h, target_w):
    """
    Build the augmentation pipeline used to synthesize new image/mask pairs.

    All randomness lives inside the returned Compose (per-transform
    probabilities plus OneOf groups that pick at most one member), so calling
    it repeatedly on the same pair yields different outputs.

    Normalization is intentionally NOT included: the results are written to
    disk as viewable images rather than fed straight into a network.

    Args:
        target_h: output height in pixels (final ``A.Resize``).
        target_w: output width in pixels (final ``A.Resize``).

    Returns:
        ``A.Compose`` applying, in order: flips/90-degree rotations, a
        shift-scale-rotate, at most one elastic/grid/optical warp, at most one
        color op, at most one blur/noise op, and a final resize.

    NOTE(review): ``alpha_affine`` (ElasticTransform), ``var_limit``
    (GaussNoise) and ``mask_interpolation`` (Resize) are accepted only by
    some albumentations releases -- confirm against the installed version.
    """
    reflect = cv2.BORDER_REFLECT_101

    # Orientation changes; each fires independently with its own probability.
    orientation = [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
        A.RandomRotate90(p=0.3),
    ]

    # Mild affine jitter; reflected borders avoid black corners.
    affine = A.ShiftScaleRotate(
        shift_limit=0.08,
        scale_limit=0.15,
        rotate_limit=25,
        border_mode=reflect,
        p=0.5,
    )

    # Non-rigid warps: OneOf applies at most one of them.
    warp = A.OneOf(
        [
            A.ElasticTransform(alpha=70, sigma=10, alpha_affine=10, border_mode=reflect, p=1),
            A.GridDistortion(num_steps=5, distort_limit=0.3, border_mode=reflect, p=1),
            A.OpticalDistortion(distort_limit=0.25, shift_limit=0.05, border_mode=reflect, p=1),
        ],
        p=0.35,
    )

    # Photometric changes (image-only transforms; masks are not altered).
    photometric = A.OneOf(
        [
            A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=1),
            A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=1),
            A.RandomGamma(gamma_limit=(70, 130), p=1),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=1),
        ],
        p=0.6,
    )

    # Occasional degradation: blur or additive noise.
    degradation = A.OneOf(
        [
            A.GaussianBlur(blur_limit=(3, 5), p=1),
            A.MedianBlur(blur_limit=3, p=1),
            A.GaussNoise(var_limit=(5.0, 25.0), p=1),
        ],
        p=0.25,
    )

    # Final resize: bilinear for the image, nearest for the mask so label
    # values are never blended.
    resize = A.Resize(
        height=target_h,
        width=target_w,
        interpolation=cv2.INTER_LINEAR,
        mask_interpolation=cv2.INTER_NEAREST,
    )

    return A.Compose(orientation + [affine, warp, photometric, degradation, resize])


# ---------------------------
# IO helpers
# ---------------------------
def _read_image_rgb(path):
    """
    Load the image at *path* as a 3-channel array in RGB channel order.

    OpenCV decodes color images as BGR; the channels are swapped before
    returning.

    Raises:
        FileNotFoundError: OpenCV could not read or decode the file.
    """
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Could not read image: {path}")

def _read_mask_keep(path):
    """
    Load a mask with no color or bit-depth conversion.

    Grayscale masks come back as (H, W). Multi-channel masks keep OpenCV's
    channel order (BGR, not RGB). A 4th (alpha) channel is dropped, so RGBA
    palettes become 3-channel. The stored dtype is preserved
    (``IMREAD_UNCHANGED``).

    Raises:
        FileNotFoundError: OpenCV could not read or decode the file.
    """
    raw = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise FileNotFoundError(f"Could not read mask: {path}")
    has_alpha = raw.ndim == 3 and raw.shape[2] == 4
    return raw[:, :, :3] if has_alpha else raw


def _save_image_rgb(path, img_rgb):
    """
    Write an RGB image array to *path*.

    OpenCV expects BGR, so channels are swapped before writing. The file
    format is chosen by OpenCV from the extension of *path*.

    Raises:
        OSError: ``cv2.imwrite`` reported failure (e.g. the target folder
            does not exist or is not writable). Previously such failures
            were silently ignored.
    """
    bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(str(path), bgr):
        raise OSError(f"Could not write image: {path}")

def _save_mask(path, mask):
    """
    Write *mask* to *path* exactly as given (no channel swap, no casting).

    A (H, W) array becomes a single-channel file and an (H, W, 3) array a
    3-channel file. The bit depth follows the array's dtype. The file format
    is chosen from the extension of *path*; the result is lossless only for
    lossless formats such as PNG.

    Raises:
        OSError: ``cv2.imwrite`` reported failure (e.g. the target folder
            does not exist or is not writable). Previously such failures
            were silently ignored.
    """
    if not cv2.imwrite(str(path), mask):
        raise OSError(f"Could not write mask: {path}")


# ---------------------------
# Core augmentation driver
# ---------------------------
def _list_by_stem(root, exts):
    """Map file stem -> path for files in *root* whose lower-cased suffix is in *exts*."""
    found = {}
    # Sorted iteration makes the winner deterministic when two files share a
    # stem (e.g. a.png and a.jpg); raw iterdir() order is filesystem-dependent.
    for p in sorted(root.iterdir()):
        if p.is_file() and p.suffix.lower() in exts:
            found[p.stem] = p
    return found


def _restore_mask_dtype(mask, dtype):
    """Round and clip an augmented mask back to *dtype* (the source mask's dtype)."""
    dtype = np.dtype(dtype)
    if mask.dtype == dtype:
        return mask
    if np.issubdtype(dtype, np.integer):
        # Clip BEFORE casting: a bare astype() wraps out-of-range values
        # (e.g. 300 -> 44 for uint8), silently corrupting labels.
        info = np.iinfo(dtype)
        return np.clip(np.rint(mask), info.min, info.max).astype(dtype)
    return mask.astype(dtype)


def augment_and_save(
    images_dir,
    masks_dir,
    out_images_dir,
    out_masks_dir,
    target_size=(512, 512),
    n_generate_total=1000,
    per_image=False,      # if True: generate this many PER SOURCE image; else: total across dataset
    seed=42,
    image_exts=(".png", ".jpg", ".jpeg"),
    mask_exts=(".png", ".jpg", ".jpeg"),
):
    """
    Generate augmented image/mask pairs and write them to disk as PNG.

    Images and masks are matched by file stem (basename without extension).

    - ``per_image=False``: draw ``n_generate_total`` stems and produce one
      augmented pair per draw. Draws are without replacement when there are
      enough stems, otherwise with replacement.
    - ``per_image=True``: produce ``n_generate_total`` pairs for every stem.

    Each output is named ``{stem}_aug{idx:04d}.png``, with ``idx`` counting up
    per source stem, so names never collide within one run. The image and its
    mask share the same name in their respective output folders.

    Args:
        images_dir: folder of source images.
        masks_dir: folder of source masks.
        out_images_dir: destination for augmented images (created if missing).
        out_masks_dir: destination for augmented masks (created if missing).
        target_size: ``(height, width)`` of every output.
        n_generate_total: total count, or per-image count when ``per_image``.
        per_image: selects the counting mode described above.
        seed: seeds Python's ``random`` and NumPy's global RNG.
            NOTE(review): whether this also makes the albumentations
            transforms reproducible depends on the installed albumentations
            version (newer releases manage their own RNG) -- confirm.
        image_exts: accepted image suffixes (case-insensitive).
        mask_exts: accepted mask suffixes (case-insensitive).

    Raises:
        RuntimeError: no stem appears in both input folders.
        ValueError: an image and its mask differ in height/width.
        FileNotFoundError: an input file could not be decoded.
    """
    random.seed(seed)
    np.random.seed(seed)

    images_dir = Path(images_dir)
    masks_dir = Path(masks_dir)
    out_images_dir = Path(out_images_dir)
    out_images_dir.mkdir(parents=True, exist_ok=True)
    out_masks_dir = Path(out_masks_dir)
    out_masks_dir.mkdir(parents=True, exist_ok=True)

    img_map = _list_by_stem(images_dir, tuple(x.lower() for x in image_exts))
    msk_map = _list_by_stem(masks_dir, tuple(x.lower() for x in mask_exts))

    common_stems = sorted(img_map.keys() & msk_map.keys())
    if not common_stems:
        raise RuntimeError("No matching image/mask basenames found.")

    H, W = target_size
    aug = build_aug_pipeline(H, W)

    # Plan: list of (stem, how many augmentations to produce from it).
    if per_image:
        plan = [(stem, n_generate_total) for stem in common_stems]
        total = len(common_stems) * n_generate_total
    else:
        if n_generate_total <= len(common_stems):
            sampled = random.sample(common_stems, n_generate_total)
        else:
            sampled = [random.choice(common_stems) for _ in range(n_generate_total)]
        plan = [(s, 1) for s in sampled]
        total = n_generate_total

    # Per-stem running index so repeated draws of a stem get unique names.
    stem_counts = {s: 0 for s in common_stems}
    written = 0

    # Context manager guarantees the progress bar is closed even on error.
    with tqdm(total=total, desc="Augmenting") as pbar:
        for stem, k in plan:
            img = _read_image_rgb(img_map[stem])    # RGB, 3 channels
            mask = _read_mask_keep(msk_map[stem])   # gray or 3ch, stored dtype kept
            if img.shape[:2] != mask.shape[:2]:
                raise ValueError(
                    f"Image/mask size mismatch for '{stem}': "
                    f"{img.shape[:2]} vs {mask.shape[:2]}"
                )

            for _ in range(k):
                # No normalization: outputs are saved as viewable images.
                data = aug(image=img, mask=mask)
                aug_img = np.clip(data["image"], 0, 255).astype(np.uint8)
                # Keep the mask in the source dtype (uint16 label masks must
                # not be squeezed into uint8).
                aug_mask = _restore_mask_dtype(data["mask"], mask.dtype)

                idx = stem_counts[stem]
                stem_counts[stem] += 1
                out_name = f"{stem}_aug{idx:04d}.png"

                _save_image_rgb(out_images_dir / out_name, aug_img)
                _save_mask(out_masks_dir / out_name, aug_mask)

                written += 1
                pbar.update(1)

    print(f"Done. Wrote {written} augmented pairs to:\n  {out_images_dir}\n  {out_masks_dir}")


# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    # INPUT folders; images and masks are paired by identical basename (stem)
    IMAGES_DIR = "./train/"
    MASKS_DIR  = "./train_mask"

    # OUTPUT folders (created if missing)
    OUT_IMAGES = "out_aug/images"
    OUT_MASKS  = "out_aug/masks"

    # Generate a TOTAL of 50000 new pairs. Stems are sampled across the
    # dataset, with replacement once 50000 exceeds the number of source pairs.
    augment_and_save(
        images_dir=IMAGES_DIR,
        masks_dir=MASKS_DIR,
        out_images_dir=OUT_IMAGES,
        out_masks_dir=OUT_MASKS,
        target_size=(512, 512),
        n_generate_total=50000,
        per_image=False,     # TOTAL count across dataset
        seed=1337
    )

    # Alternative: generate 5 augmentations PER SOURCE image at 512x256 (HxW)
    # augment_and_save(
    #     images_dir=IMAGES_DIR,
    #     masks_dir=MASKS_DIR,
    #     out_images_dir=OUT_IMAGES,
    #     out_masks_dir=OUT_MASKS,
    #     target_size=(512, 256),
    #     n_generate_total=5,
    #     per_image=True,    # 5 per image
    #     seed=1337
    # )
