"""
This module contains pytorch dataset classes for loading and handling
the holey layer datasets generated using the `generate_dataset.py` script.
"""

import functools
import json
import numbers
import random

import cv2
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms


class RandomRoll:
    """
    Randomly roll a (T, H, W) tensor along H and W (periodic shift).

    The same shift is applied to every element of the leading sequence
    dimension, so all layers of one structure stay aligned with each other.
    """

    def __init__(self, max_shift=None):
        """
        Args:
            max_shift (int or tuple, optional): exclusive upper bound of the
                random shift; each shift is drawn uniformly from [0, max_shift).
                - If int (numpy integers included): same bound for H and W.
                - If tuple: interpreted as (max_shift_x, max_shift_y), where the
                  first entry bounds the shift along H (dim 1) and the second
                  the shift along W (dim 2).
                - If None: defaults to the full size (H, W) of the input.
                A bound <= 0 disables shifting along that axis.
        """
        self.max_shift = max_shift

    @staticmethod
    def _draw_shift(max_shift) -> int:
        """Return a uniform random integer in [0, max_shift), or 0 if max_shift <= 0."""
        # random.randint(0, -1) raises ValueError, so an empty range must be
        # mapped explicitly to "no shift".
        if max_shift <= 0:
            return 0
        return random.randint(0, max_shift - 1)

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Args:
            tensor (torch.Tensor): 3D tensor of shape (T, H, W).

        Returns:
            torch.Tensor: `tensor` rolled by the same random (shift_x, shift_y)
            along dims (1, 2) for every element of the sequence.
        """
        assert tensor.ndim == 3, f"Expected 3D tensor (T,H,W), got {tensor.shape}"

        _, H, W = tensor.shape

        if self.max_shift is None:
            max_shift_x, max_shift_y = H, W
        elif isinstance(self.max_shift, numbers.Integral):
            # numbers.Integral also covers numpy integer scalars (e.g. np.int64),
            # which a plain `isinstance(..., int)` check would reject.
            max_shift_x = max_shift_y = self.max_shift
        else:
            max_shift_x, max_shift_y = self.max_shift

        shift_x = self._draw_shift(max_shift_x)
        shift_y = self._draw_shift(max_shift_y)

        return torch.roll(tensor, shifts=(shift_x, shift_y), dims=(1, 2))


class HoleyLayerDataset(Dataset):
    """
    PyTorch Dataset class for loading holey layer datasets from HDF5 files.

    Each HDF5 file holds one sample. `__getitem__` returns a dict with:
        - pixmap_layers: (max_seq_len, Nx, Nx) float32 tensor of per-layer
          permittivity maps (zero-padded after the real layers)
        - thicknesses: (max_seq_len,) float32 tensor of layer thicknesses
          (zero-padded after the real layers)
        - fs: (Nf,) float32 tensor of frequencies
        - kps: (Nk, 2) float32 tensor of wavevectors
        - R: reflection coefficients (real part of the "R" dataset)
        - T: transmission coefficients (real part of the "T" dataset)
        - BD: band diagram derived from |d2S| (min-max rescaled, optionally
          binarized)
        - key_padding_mask: (max_seq_len,) bool tensor, False for real layers
          and True for padding
    R, T and BD are passed through their `additional_transforms`. With the
    default `ToTensor()` they come out as (1, Nf, Nk) tensors.
    """

    def __init__(
        self,
        config_file: str,
        h5_files: list[str],
        Nx: int = 64,
        max_seq_len: int = None,
        scale_eps: bool = True,
        binarize_BD: bool = True,
        clip_limit: float = 1.0,
        threshold_size: int = 31,
        kernel_size: int = 3,
        additional_transforms: dict = None,
        random_air_layers: bool = True,
    ):
        """
        Args:
            config_file (str): Path to the JSON configuration file.
            h5_files (list[str]): List of paths to HDF5 files, one sample per file.
            Nx (int, optional): Size the layer pixmaps are resized to (default: 64).
            max_seq_len (int, optional): Maximum sequence length (number of layers * 2).
                If None, it will be set to the maximum number of layers in the config * 2.
            scale_eps (bool, optional): Whether to minmax scale the permittivity values
                to [0, 1] using the default and material permittivities of the config.
            binarize_BD (bool, optional): Whether to binarize the band diagram.
            clip_limit (float, optional): Clip limit for CLAHE when binarizing the band diagram.
            threshold_size (int, optional): Block size for adaptive thresholding when
                binarizing the band diagram (OpenCV requires an odd value > 1).
            kernel_size (int, optional): Side length of the Gaussian blur kernel applied
                before adaptive thresholding (OpenCV requires an odd positive value).
            additional_transforms (dict, optional): Mapping from "pixmap_layers", "R",
                "T", "BD" to a list of transforms, composed and applied after the default
                processing. Missing keys fall back to the defaults: identity for
                "pixmap_layers" and `ToTensor()` for "R", "T" and "BD". Providing a key
                replaces its default list entirely.
            random_air_layers (bool, optional): Whether to randomly add layers of eps=1.0
                at the beginning and/or end of the structure as data augmentation.
        """
        super().__init__()
        self.h5_files = h5_files
        self.binarize_BD = binarize_BD
        self.clip_limit = clip_limit
        self.threshold_size = threshold_size
        self.kernel_size = kernel_size
        self.random_air_layers = random_air_layers

        # Load config file
        with open(config_file, "r") as f:
            config = json.load(f)

        config_max_seq_len = config["variables"]["n_layers"][1] * 2
        if max_seq_len is None:
            self.max_seq_len = config_max_seq_len
        else:
            assert max_seq_len >= config_max_seq_len, (
                "max_seq_len must be >= max number of layers in the config"
            )
            self.max_seq_len = max_seq_len
        self.min_thickness = config["variables"]["layer_thickness"][0]
        self.max_thickness = config["variables"]["layer_thickness"][1]

        # Default transforms
        pixmap_transforms = [
            transforms.ToTensor(),
            transforms.Resize((Nx, Nx)),
        ]
        if scale_eps:
            all_eps = [
                config["constants"]["default_eps"],
                *config["constants"]["material_eps"],
            ]
            # A functools.partial of a class-level function (instead of a lambda)
            # keeps the dataset picklable, which multi-process DataLoader workers
            # need under the "spawn" start method.
            pixmap_transforms.append(
                transforms.Lambda(
                    functools.partial(
                        self._minmax_scale,
                        min_eps=np.min(all_eps),
                        max_eps=np.max(all_eps),
                    )
                )
            )
        self.default_pixmap_transforms = transforms.Compose(pixmap_transforms)

        # Start from fresh defaults so a partial user dict cannot cause a
        # KeyError in __getitem__, and no default object is shared between calls.
        merged_transforms = self._default_additional_transforms()
        if additional_transforms is not None:
            merged_transforms.update(additional_transforms)
        self.additional_transforms = {
            k: transforms.Compose(v) for k, v in merged_transforms.items()
        }

    @staticmethod
    def _default_additional_transforms() -> dict:
        """Return a new dict with the default per-key transform lists."""
        return {
            # An empty Compose is the identity and, unlike a lambda, pickles.
            "pixmap_layers": [],
            "R": [transforms.ToTensor()],
            "T": [transforms.ToTensor()],
            "BD": [transforms.ToTensor()],
        }

    @staticmethod
    def _minmax_scale(x, min_eps, max_eps):
        """Linearly map permittivities so that min_eps -> 0 and max_eps -> 1."""
        return (x - min_eps) / (max_eps - min_eps)

    def __len__(self):
        return len(self.h5_files)

    def _process_band_diagram(self, BD):
        """Min-max rescale BD to [0, 1] and, if enabled, binarize it to {0, 255} uint8."""
        BD_min = np.amin(BD, axis=(0, 1), keepdims=True)
        BD_max = np.amax(BD, axis=(0, 1), keepdims=True)
        # The epsilon avoids division by zero for a constant band diagram.
        BD = (BD - BD_min) / (BD_max - BD_min + 1e-8)

        if not self.binarize_BD:
            return BD

        # Contrast enhancement -> blur -> local (Gaussian-weighted) threshold.
        clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=(8, 8))
        BD = clahe.apply((BD * 255).astype("uint8"))
        BD = cv2.GaussianBlur(BD, (self.kernel_size, self.kernel_size), 0)
        return cv2.adaptiveThreshold(
            BD,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            self.threshold_size,
            0,
        )

    def _add_random_air_layers(self, pixmap_layers, thicknesses):
        """
        Randomly add eps=1.0 layers at the beginning and/or end of the structure.

        While the stack is shorter than max_seq_len, another layer is added with
        probability 0.5. It goes at the front or the back with equal probability.
        New thicknesses are drawn uniformly from [min_thickness, max_thickness).
        """
        # NOTE: mixes Python `random` and `np.random`; seed both for reproducibility.
        while len(thicknesses) < self.max_seq_len and random.random() < 0.5:
            air_layer = np.ones((1, pixmap_layers.shape[1], pixmap_layers.shape[2]))
            if random.random() < 0.5:
                # Add layer at the beginning
                pixmap_layers = np.concatenate([air_layer, pixmap_layers], axis=0)
                thicknesses = np.concatenate(
                    [
                        np.random.uniform(
                            self.min_thickness, self.max_thickness, size=(1,)
                        ),
                        thicknesses,
                    ],
                    axis=0,
                )
            else:
                # Add layer at the end
                pixmap_layers = np.concatenate([pixmap_layers, air_layer], axis=0)
                thicknesses = np.concatenate(
                    [
                        thicknesses,
                        np.random.uniform(
                            self.min_thickness, self.max_thickness, size=(1,)
                        ),
                    ],
                    axis=0,
                )
        return pixmap_layers, thicknesses

    def __getitem__(self, idx):
        filepath = self.h5_files[idx]

        with h5py.File(filepath, "r") as f:
            # Structure
            pixmap_layers = f["pixmap_layers"][:]  # (n_layers, Nx, Nx)
            thicknesses = f["thicknesses"][:]  # (n_layers,)

            # Physical parameters
            fs = f["fs"][:]  # (Nf,)
            kps = f["kps"][:]  # (Nk, 2)

            # Optical response
            R = np.real(f["R"][:])  # (Nf, Nk)
            T = np.real(f["T"][:])  # (Nf, Nk)
            BD = np.abs(f["d2S"][:])  # (Nf, Nk)

        BD = self._process_band_diagram(BD)

        if self.random_air_layers:
            pixmap_layers, thicknesses = self._add_random_air_layers(
                pixmap_layers, thicknesses
            )

        n_layers = len(thicknesses)
        if n_layers > self.max_seq_len:
            raise ValueError(
                f"{filepath} has {n_layers} layers, more than "
                f"max_seq_len={self.max_seq_len}"
            )

        # key_padding mask: False for real layers, True for padding
        mask = torch.ones(self.max_seq_len, dtype=torch.bool)
        mask[:n_layers] = False

        pixmap_tensors = torch.stack(
            [
                self.default_pixmap_transforms(pixmap_layers[i])[0, :, :].to(
                    torch.float32
                )
                for i in range(n_layers)
            ],
            dim=0,
        )

        # Apply additional transforms (before padding, so they only see real layers)
        pixmap_tensors = self.additional_transforms["pixmap_layers"](pixmap_tensors)
        R = self.additional_transforms["R"](R)
        T = self.additional_transforms["T"](T)
        BD = self.additional_transforms["BD"](BD)

        # Apply padding in one step rather than one concatenation per missing layer
        n_pad = self.max_seq_len - pixmap_tensors.size(0)
        if n_pad > 0:
            pixmap_tensors = torch.cat(
                [
                    pixmap_tensors,
                    pixmap_tensors.new_zeros((n_pad, *pixmap_tensors.shape[1:])),
                ],
                dim=0,
            )
            thicknesses = np.concatenate([thicknesses, np.zeros(n_pad)])

        return {
            "pixmap_layers": pixmap_tensors,  # (max_seq_len, Nx, Nx)
            "thicknesses": torch.tensor(thicknesses).to(
                torch.float32
            ),  # (max_seq_len,)
            "fs": torch.tensor(fs).to(torch.float32),  # (Nf,)
            "kps": torch.tensor(kps).to(torch.float32),  # (Nk, 2)
            "R": R,  # (1, Nf, Nk) with the default ToTensor transform
            "T": T,  # (1, Nf, Nk) with the default ToTensor transform
            "BD": BD,  # (1, Nf, Nk) with the default ToTensor transform
            "key_padding_mask": mask,  # (max_seq_len,)
        }
