import os
from os.path import isfile
from time import perf_counter

from PIL import Image, ImageOps
import numpy as np
import h5py
import torch
from torch.utils.data import Dataset, DataLoader

from se.configs import DatasetConfig, TrainConfig


def load_data_m(data, batch_size=100):
    """Build the loaders for the Mohan et al. (2020) bias-free denoising splits.

    Args:
        data: Directory that contains ``train.h5`` and ``valid.h5``.
        batch_size: Batch size of the training loader. The validation loader
            always yields one sample at a time.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. Only the training loader
        shuffles; both use a single worker process.
    """
    train_file = os.path.join(data, "train.h5")
    valid_file = os.path.join(data, "valid.h5")

    train_loader = DataLoader(
        DatasetM(filename=train_file),
        shuffle=True,
        num_workers=1,
        batch_size=batch_size,
    )
    valid_loader = DataLoader(
        DatasetM(filename=valid_file),
        shuffle=False,
        num_workers=1,
        batch_size=1,
    )
    return train_loader, valid_loader


class DatasetM(Dataset):
    """Map-style dataset over the top-level entries of an HDF5 file.

    Each top-level key of the file is one sample; ``dataset[i]`` returns the
    entry stored under the ``i``-th key as a ``torch.Tensor``.

    The HDF5 handle is opened lazily, once per process, instead of being
    created in ``__init__`` and carried into DataLoader workers: h5py objects
    cannot be pickled (needed when workers are started with ``spawn``), and a
    handle shared across a ``fork`` is not safe to rely on.
    """

    def __init__(self, filename):
        """Record *filename* and read the list of sample keys.

        The file is opened here only long enough to list its keys, so a
        missing or unreadable file still fails at construction time.
        """
        super().__init__()
        self.filename = filename
        self._h5f = None
        with h5py.File(filename, "r") as h5f:
            self.keys = list(h5f.keys())

    @property
    def h5f(self):
        """Read-only HDF5 handle for the current process, opened on first use."""
        if self._h5f is None:
            self._h5f = h5py.File(self.filename, "r")
        return self._h5f

    def close(self):
        """Close the HDF5 handle if open; it is reopened on the next access."""
        if self._h5f is not None:
            self._h5f.close()
            self._h5f = None

    def __getstate__(self):
        # Never ship an open handle to another process; the receiver reopens
        # the file through the ``h5f`` property.
        state = self.__dict__.copy()
        state["_h5f"] = None
        return state

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        key = self.keys[index]
        # Materialise the entry as an in-memory array before building the tensor.
        data = np.array(self.h5f[key])
        return torch.Tensor(data)


def load_images(in_folders: list[str]) -> list[np.ndarray]:
    """Load grayscale images following Herbreteau et al. (2022) preprocessing.

    Every regular, non-hidden file located directly inside one of the folders
    whose name ends in ``.jpg``, ``.png`` or ``.bmp`` (case-insensitive) is
    opened with Pillow, converted to grayscale and stored as a ``uint8`` array.
    A one-line timing summary is printed once loading has finished.

    Args:
        in_folders: Directories to scan (not recursively).

    Returns:
        The images as arrays, ordered by folder (in the order given) and then
        by file name.
    """
    start = perf_counter()
    exts = (".jpg", ".png", ".bmp")
    files = [
        f"{in_folder}/{name}"
        for in_folder in in_folders
        # os.listdir returns names in arbitrary order; sorting keeps the image
        # order (and therefore seeded sampling) reproducible across machines.
        for name in sorted(os.listdir(in_folder))
        if isfile(f"{in_folder}/{name}")
        and not name.startswith(".")
        and name.lower().endswith(exts)
    ]

    images = []
    for path in files:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(path) as img:
            images.append(np.array(ImageOps.grayscale(img)).astype(np.uint8))

    duration = perf_counter() - start
    print(f"load_images: loaded {len(images)} images in {duration:.3f}s", flush=True)
    return images


def augmentation(x, k=0, inverse=False):
    """Apply one of eight flip/rotation transforms to ``x``.

    ``k`` is taken modulo 8. An odd code flips ``x`` along dimension 2 first;
    the result is then rotated by ``code // 2`` quarter turns in the plane of
    dimensions 1 and 2, so ``x`` needs at least three dimensions.

    Args:
        x: Tensor to transform.
        k: Transform code.
        inverse: If true, apply the transform that undoes code ``k``.

    Returns:
        The transformed tensor.
    """
    code = k % 8
    if inverse:
        # Codes with a flip undo themselves; among the pure rotations only
        # the 90-degree and 270-degree ones (codes 2 and 6) swap.
        code = (0, 1, 6, 3, 4, 5, 2, 7)[code]
    quarter_turns, flipped = divmod(code, 2)
    out = torch.flip(x, dims=[2]) if flipped == 1 else x
    return torch.rot90(out, k=quarter_turns % 4, dims=[1, 2])


class DatasetH(Dataset):
    """Random-patch dataset over grayscale images kept in memory.

    The images are loaded once with ``load_images``. Every item is a square
    patch cropped at a random position from a randomly chosen image, scaled
    to ``[0, 1]`` and passed through a random ``augmentation``. The requested
    index is ignored, so ``samples_per_epoch`` only sets the epoch length.
    """

    def __init__(
        self,
        in_folders: list[str],
        patch_size=70,
        samples_per_epoch=1000,
    ):
        """Load all images found in *in_folders*.

        Args:
            in_folders: Directories passed to ``load_images``.
            patch_size: Side length, in pixels, of the square patches.
            samples_per_epoch: Value reported by ``len()``.

        Raises:
            ValueError: If no image was found in *in_folders*.
        """
        super().__init__()
        self.patch_size = patch_size
        self.samples_per_epoch = samples_per_epoch

        self.images_train = load_images(in_folders)
        self.number_of_images = len(self.images_train)
        if self.number_of_images == 0:
            # Fail here with a clear message rather than inside a DataLoader
            # worker on the first sample.
            raise ValueError(f"No images found in {in_folders}.")

    def __len__(self):
        return self.samples_per_epoch

    def _sample_patch(self):
        """Return a random ``patch_size`` x ``patch_size`` crop of a random image."""
        img_np = self.images_train[np.random.randint(self.number_of_images)]
        h, w = img_np.shape
        size = self.patch_size
        if h < size or w < size:
            raise ValueError(
                f"Image of shape {(h, w)} is smaller than patch_size={size}."
            )
        # Valid top-left corners are 0 .. dim - size inclusive, so an image
        # exactly as large as the patch is usable and border patches are
        # reachable.
        i = np.random.randint(h - size + 1)
        j = np.random.randint(w - size + 1)
        return img_np[i : i + size, j : j + size]

    def __getitem__(self, idx):
        # ``idx`` is intentionally unused: every sample is drawn at random.
        patch = self._sample_patch().astype(np.float32) / 255.0
        # Add a leading channel dimension: (H, W) -> (1, H, W).
        img_torch = torch.from_numpy(patch).view(1, *patch.shape).float()

        k = np.random.randint(8)
        return augmentation(img_torch, k)


def _resolve_train_dirs(base_path: str, folders: list[str]) -> list[str]:
    """Turn training-folder entries into paths of existing directories.

    Absolute entries are used unchanged; relative entries are joined onto
    ``base_path``.

    Args:
        base_path: Directory that relative entries are resolved against.
        folders: Folder names or absolute paths.

    Returns:
        The resolved paths, in the same order as ``folders``.

    Raises:
        FileNotFoundError: At the first entry that is not an existing directory.
    """

    def _locate(entry: str) -> str:
        if os.path.isabs(entry):
            candidate = entry
        else:
            candidate = os.path.join(base_path, entry)
        if os.path.isdir(candidate):
            return candidate
        raise FileNotFoundError(f"Folder '{candidate}' does not exist.")

    return [_locate(entry) for entry in folders]


def load_data_h(cfg: TrainConfig | DatasetConfig):
    """Build the loaders for the 'h' (image-folder) training setup.

    Training samples are random patches drawn by ``DatasetH`` from the image
    folders listed in ``cfg.train_image_dirs`` (resolved against
    ``cfg.train_path``). Validation samples are read by ``DatasetM`` from
    ``valid.h5`` inside ``cfg.train_path``.

    Args:
        cfg: Configuration providing ``train_path``, ``train_image_dirs``,
            ``s_patch_size``, ``s_samples_per_epoch`` and ``batch_size``.

    Returns:
        A ``(train_loader, valid_loader)`` tuple; the validation loader yields
        one sample at a time and does not shuffle.

    Raises:
        FileNotFoundError: If one of the training folders does not exist.
        ValueError: If ``cfg.s_samples_per_epoch`` is ``None``.
    """
    train_dirs = _resolve_train_dirs(cfg.train_path, cfg.train_image_dirs)
    # Explicit check instead of ``assert``: assertions are stripped under
    # ``python -O``, which would let ``None`` through as the dataset length.
    if cfg.s_samples_per_epoch is None:
        raise ValueError("s_samples_per_epoch must be set for 'h' dataset")
    train_dataset = DatasetH(
        in_folders=train_dirs,
        patch_size=cfg.s_patch_size,
        samples_per_epoch=cfg.s_samples_per_epoch,
    )
    train_loader = DataLoader(
        train_dataset, batch_size=cfg.batch_size, num_workers=1, shuffle=True
    )

    valid_dataset = DatasetM(filename=os.path.join(cfg.train_path, "valid.h5"))
    valid_loader = DataLoader(valid_dataset, batch_size=1, num_workers=1, shuffle=False)
    return train_loader, valid_loader


# %% data loaders
def build_loaders(cfg: TrainConfig | DatasetConfig):
    """Return ``(train_loader, valid_loader)`` for the configured dataset type.

    ``cfg.train_dataset_type`` is matched case-insensitively: ``"h"`` selects
    ``load_data_h`` and ``"m"`` selects ``load_data_m``.

    Raises:
        ValueError: If the dataset type is neither ``"m"`` nor ``"h"``.
    """
    kind = cfg.train_dataset_type.lower()
    if kind not in ("h", "m"):
        raise ValueError(
            f"Unknown dataset type '{cfg.train_dataset_type}'. Use 'm' or 'h'."
        )
    if kind == "m":
        return load_data_m(cfg.train_path, batch_size=cfg.batch_size)
    return load_data_h(cfg)
