"""
PyTorch Dataset for training an autoencoder on the ARSO dataset.

This module provides a Dataset class, `DynamicAutoencoderArsoDataset`, which is
designed to load ARSO radar data from HDF5 files and provide single, randomly
selected frames for autoencoder training. It also includes utility functions for
padding and unpadding tensors to dimensions that are multiples of 16, which is
often required for convolutional autoencoder architectures.
"""

from torch.utils.data import Dataset
import torch
import h5py
import numpy as np


def pad_to_multiple_of_16(x):
    """
    Grow the last two dimensions of a 4D tensor to the next multiple of 16.

    The extra rows/columns are split between both sides; when the amount is
    odd, the bottom/right side receives the extra one. New pixels are filled
    by replicating the nearest edge value.

    Args:
        x (torch.Tensor): Input tensor of shape (B, C, H, W).

    Returns:
        A tuple `(padded, (top, bottom, left, right))` with the padded tensor
        and the number of rows/columns added on each side.
    """
    _, _, height, width = x.shape

    # `-n % 16` is the distance from n up to the next multiple of 16
    # (0 when n is already aligned).
    extra_rows = -height % 16
    extra_cols = -width % 16

    # Split each amount as evenly as possible between the two sides.
    top, left = extra_rows // 2, extra_cols // 2
    bottom, right = extra_rows - top, extra_cols - left

    # F.pad expects the padding for the last dimension first: (left, right, top, bottom).
    padded = torch.nn.functional.pad(x, (left, right, top, bottom), mode="replicate")
    return padded, (top, bottom, left, right)


def pad_to_multiple_of_16_5d(x):
    """
    Pads a 5D tensor's height and width to be a multiple of 16.

    This is a wrapper around `pad_to_multiple_of_16` for tensors with a time
    dimension: batch and time are folded into one leading dimension, the 4D
    padding routine is applied, and the original leading dimensions are restored.

    Args:
        x (torch.Tensor): Input tensor of shape (B, T, C, H, W). It does not
            need to be contiguous.

    Returns:
        A tuple containing the padded tensor of shape (B, T, C, H', W') and the
        padding values (top, bottom, left, right).
    """
    b, t, c, h, w = x.shape

    # `reshape` instead of `view`: `view` raises on non-contiguous inputs
    # (e.g. the result of `permute`/`transpose`), whereas `reshape` returns a
    # view when it can and copies only when it must.
    x_flat = x.reshape(b * t, c, h, w)
    padded_flat, padding_info = pad_to_multiple_of_16(x_flat)

    # Unfold the leading dimension back into (B, T), keeping the new H and W.
    padded = padded_flat.reshape(b, t, c, *padded_flat.shape[-2:])
    return padded, padding_info


def remove_padding_5d(x, padding_info):
    """
    Removes padding from a 5D tensor.

    Wrapper around `remove_padding` for tensors with a time dimension: batch
    and time are folded into one leading dimension, the 4D routine crops the
    height and width, and the original leading dimensions are restored.

    Args:
        x (torch.Tensor): Padded tensor of shape (B, T, C, H, W). It does not
            need to be contiguous.
        padding_info (tuple): Padding values (top, bottom, left, right).

    Returns:
        torch.Tensor: The tensor with padding removed.
    """
    b, t, c, h, w = x.shape

    # `reshape` instead of `view`: `view` raises on non-contiguous inputs
    # (e.g. the result of `permute`/`transpose`), whereas `reshape` returns a
    # view when it can and copies only when it must.
    x_flat = x.reshape(b * t, c, h, w)
    unpadded_flat = remove_padding(x_flat, padding_info)

    # Unfold the leading dimension back into (B, T), keeping the cropped H and W.
    return unpadded_flat.reshape(b, t, c, *unpadded_flat.shape[-2:])


def remove_padding(x, padding_info):
    """
    Crop a padded 4D tensor back using the recorded padding amounts.

    Args:
        x (torch.Tensor): Padded tensor of shape (B, C, H, W).
        padding_info (tuple): `(top, bottom, left, right)` — the number of
            rows/columns to strip from each side.

    Returns:
        torch.Tensor: The cropped tensor (dimensions 0 and 1 are untouched).
    """
    top, bottom, left, right = padding_info

    # A stop of `-0` would produce an empty slice, so "nothing to strip" is
    # expressed as an open-ended slice (stop=None) instead.
    row_stop = -bottom if bottom > 0 else None
    col_stop = -right if right > 0 else None

    rows = slice(top, row_stop)
    cols = slice(left, col_stop)
    return x[:, :, rows, cols]


class DynamicAutoencoderArsoDataset(Dataset):
    """
    PyTorch Dataset for training an autoencoder on the ARSO dataset.

    This dataset loads sequences from an HDF5 file, combines the input and target
    parts, and returns a single, randomly selected frame from the full sequence.
    It uses lazy loading for the HDF5 file to work efficiently with multiple
    DataLoader workers: the file is opened on the first `__getitem__` call, and
    an already-open handle is never included when the dataset is pickled.

    Call `close()` to release the HDF5 handle explicitly once the dataset is no
    longer needed.
    """

    # Divisor applied to every frame when `normalize=True`.
    NORMALIZATION_DIVISOR = 57.0

    def __init__(
        self,
        h5_file_path,
        indices=None,
        transform=None,
        channel_last=False,
        debug_mode=False,
        normalize=False,
    ):
        """
        Initializes the dataset. No file I/O happens here.

        Args:
            h5_file_path (str): Path to the HDF5 data file.
            indices (array-like, optional): A list of indices for this dataset split.
                                           If None, all samples are used.
            transform (callable, optional): A transform applied to the full
                                            concatenated sequence before a frame
                                            is selected from it.
            channel_last (bool): If True, output tensors are (H, W, C).
            debug_mode (bool): If True, the length is capped at the first 10
                               samples. Only takes effect when `indices` is None.
            normalize (bool): If True, scales data by dividing by 57.0.
        """
        self.h5_file_path = h5_file_path
        self.indices = indices
        self.transform = transform
        self.channel_last = channel_last
        self.debug_mode = debug_mode
        self.normalize = normalize
        self.h5_file = None  # The file will be opened lazily in __getitem__

    def __getstate__(self):
        """
        Returns the picklable state of the dataset, without the HDF5 handle.

        h5py objects cannot be pickled, so a handle opened in the parent process
        would otherwise make the dataset impossible to send to DataLoader
        workers that are started by pickling it. The unpickled copy reopens the
        file lazily on first access.

        Returns:
            dict: A copy of the instance attributes with `h5_file` set to None.
        """
        state = self.__dict__.copy()
        state["h5_file"] = None
        return state

    def close(self):
        """
        Closes the HDF5 file if it is open.

        Safe to call more than once; the file is reopened lazily by the next
        `__getitem__` call.
        """
        if self.h5_file is not None:
            self.h5_file.close()
            self.h5_file = None

    def __len__(self):
        """
        Returns the total number of samples in the dataset.

        With `indices` set, this is simply the number of indices. Otherwise it
        is the length of the "zm_IN" dataset, capped at 10 in debug mode.

        Returns:
            int: The number of samples.
        """
        if self.indices is not None:
            return len(self.indices)

        if self.h5_file is None:
            # Open the file temporarily so that merely asking for the length
            # does not leave a handle attached to this instance.
            with h5py.File(self.h5_file_path, "r") as f:
                length = len(f["zm_IN"])
        else:
            length = len(self.h5_file["zm_IN"])

        # If debug mode is on, use only the first 10 events
        if self.debug_mode:
            return min(10, length)
        return length

    def __getitem__(self, idx):
        """
        Retrieves a single random frame from a concatenated sequence.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            A tuple containing the frame tensor, of shape (1, H, W) or (H, W, 1)
            when `channel_last` is set, and a metadata dictionary with the keys
            "frame_idx" and "total_frames".
        """
        # Lazy file opening (important for DataLoader workers)
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_file_path, "r")

        # Map the provided index to the actual sample index
        real_idx = self.indices[idx] if self.indices is not None else idx

        # Read input and target sequences
        input_data = self.h5_file["zm_IN"][real_idx]  # shape: (T_in, H, W)
        target_data = self.h5_file["zm_OUT"][real_idx]  # shape: (T_out, H, W)

        # Concatenate input and target sequences along time dimension
        full_sequence = np.concatenate([input_data, target_data], axis=0)

        # Apply optional transform
        if self.transform is not None:
            full_sequence = self.transform(full_sequence)

        # Select a single random frame from the full sequence.
        # NOTE(review): the draw uses NumPy's global RNG; whether different
        # DataLoader workers get different streams depends on how that RNG is
        # seeded per worker — confirm for the PyTorch version in use.
        total_frames = full_sequence.shape[0]
        frame_idx = np.random.randint(0, total_frames)
        frame = full_sequence[frame_idx]  # shape: (H, W)

        # Convert to tensor and add a channel dimension
        frame = torch.from_numpy(frame).float().unsqueeze(0)  # shape: (1, H, W)

        # Normalize if requested. Deliberately out-of-place: `from_numpy` shares
        # memory with the array and `.float()` is a no-op for float32 data, so
        # an in-place divide would write through to the array returned by the
        # transform.
        if self.normalize:
            frame = frame / self.NORMALIZATION_DIVISOR

        # Convert to channel-last format if requested
        if self.channel_last:
            frame = frame.permute(1, 2, 0)  # shape: (H, W, 1)

        return frame, {"frame_idx": frame_idx, "total_frames": total_frames}


def sequential_collate(batch):
    """
    Batch samples produced by `DynamicAutoencoderArsoDataset`.

    Each sample is a `(frame, metadata)` pair. The frames are stacked along a
    new leading batch dimension; the metadata dicts are kept as-is, in order.

    Args:
        batch (list): A list of (frame, metadata) tuples from the dataset.

    Returns:
        A tuple `(frames, metadata)` where `frames` is the stacked tensor and
        `metadata` is a list with one dict per sample.
    """
    # Transpose the list of pairs into one sequence of frames and one of dicts.
    frame_seq, info_seq = zip(*batch)
    return torch.stack(frame_seq, dim=0), list(info_seq)
