"""
PyTorch Dataset classes and utility functions for handling the ARSO radar dataset.

This module provides two Dataset implementations for different stages of a
nowcasting pipeline:
- `ArsoH5Dataset`: Reads raw sequence data directly from HDF5 files.
- `ArsoH5LatentDataset`: Reads pre-encoded latent space data from HDF5 files.

It also includes helper functions for data preprocessing, such as bicubic
interpolation for downsampling, and padding/cropping utilities to handle
variable spatial dimensions.
"""

import h5py
import numpy as np
from torch.utils.data import Dataset
import torch
import torch.nn.functional as F
from torch import cuda
import common.utils.bicubic_interpolation as bicubic_interpolation
import time


def post_process_samples(
    samples: np.ndarray, clamp_min: int = 12, clamp_max: int = 57
) -> np.ndarray:
    """
    Limit every element of an array to the closed interval [clamp_min, clamp_max].

    Args:
        samples (np.ndarray): Array whose values should be bounded.
        clamp_min (int): Lower bound; smaller values are raised to it.
        clamp_max (int): Upper bound; larger values are lowered to it.

    Returns:
        np.ndarray: A new array of the same shape in which out-of-range values
        are replaced by the nearest bound.
    """
    lower_bound, upper_bound = clamp_min, clamp_max
    return np.clip(samples, lower_bound, upper_bound)


def apply_bicubic_interpolation(
    data, downsample_factor, batch_size=32, permute=False, channel_last=False
):
    """
    Downsamples spatial dimensions of a tensor using bicubic interpolation.

    The data is split along its first dimension into chunks of `batch_size`,
    each chunk is passed to `bicubic_interpolation.imresize` with
    `scale=1 / downsample_factor` (on the GPU when CUDA is available), and the
    results are concatenated back together on the CPU.

    Args:
        data (torch.Tensor or np.ndarray): Input data. A 5D input has its
            channel axis squeezed before resizing and restored afterwards.
        downsample_factor (float): The factor by which to downsample.
        batch_size (int): Number of leading-dimension entries resized per call.
        permute (bool): If True, the (squeezed) input is treated as
            (B, H, W, T): it is permuted to (B, T, H, W) for resizing and
            permuted back afterwards.
        channel_last (bool): For 5D input, selects which axis is squeezed and
            restored: axis 4 if True, axis 1 otherwise.

    Returns:
        np.ndarray for ndarray input. For torch.Tensor input the result is a
        torch.Tensor only when the input was 5D; other tensor inputs come back
        as np.ndarray (existing behaviour, kept for backward compatibility).
    """
    is_original_tensor = isinstance(data, torch.Tensor)
    is_len_5 = len(data.shape) == 5
    # Axis holding the channel dimension of a 5D input.
    channel_axis = 4 if channel_last else 1

    if not is_original_tensor:
        data = torch.from_numpy(data)
    if is_len_5:
        data = data.squeeze(channel_axis)

    if permute:
        # Permute B x H x W x T to B x T x H x W.
        # This must stay a torch tensor: the batching loop below relies on
        # Tensor.size(0) and Tensor.cuda(), which a NumPy array does not offer
        # (converting to NumPy here made permute=True raise a TypeError).
        # contiguous() materialises the new layout in case the resize routine
        # reshapes its input with view() -- presumably it does; harmless if not.
        data = data.permute(0, 3, 1, 2).contiguous()

    use_cuda = torch.cuda.is_available()
    total = data.size(0)

    # Resize chunk by chunk to bound peak (GPU) memory usage.
    resized_chunks = []
    for start_idx in range(0, total, batch_size):
        batch = data[start_idx : start_idx + batch_size]

        if use_cuda:
            batch = batch.cuda()

        batch_downsampled = bicubic_interpolation.imresize(
            batch, scale=(1 / downsample_factor)
        )

        # Collect results on the CPU so GPU memory is only held per chunk.
        if use_cuda:
            batch_downsampled = batch_downsampled.cpu()

        resized_chunks.append(batch_downsampled)

    # Concatenate all chunks along the leading dimension and go to NumPy.
    downsampled_data = torch.cat(resized_chunks, dim=0).numpy()
    if permute:
        # Undo the earlier permutation: B x T x H x W back to B x H x W x T.
        downsampled_data = np.transpose(downsampled_data, (0, 2, 3, 1))

    # Release cached GPU blocks that were only needed during resizing.
    if cuda.is_available():
        cuda.empty_cache()

    if is_len_5:
        # Restore the channel axis that was squeezed away above.
        downsampled_data = torch.from_numpy(downsampled_data).unsqueeze(channel_axis)
        if not is_original_tensor:
            downsampled_data = downsampled_data.numpy()

    return downsampled_data


def apply_bicubic_interpolation_sizes(
    data, sizes, batch_size=32, permute=False, channel_last=False
):
    """
    Resizes spatial dimensions to a fixed size using bicubic interpolation.

    The data is split along its first dimension into chunks of `batch_size`,
    each chunk is passed to `bicubic_interpolation.imresize` with
    `sizes=sizes` (on the GPU when CUDA is available), and the results are
    concatenated back together on the CPU. The elapsed time of the resizing
    loop is printed to stdout.

    Args:
        data (torch.Tensor or np.ndarray): Input data. A 5D input has its
            channel axis squeezed before resizing and restored afterwards.
        sizes (tuple): Target (height, width) for the output.
        batch_size (int): Number of leading-dimension entries resized per call.
        permute (bool): If True, the (squeezed) input is treated as
            (B, H, W, T): it is permuted to (B, T, H, W) for resizing and
            permuted back afterwards.
        channel_last (bool): For 5D input, selects which axis is squeezed and
            restored: axis 4 if True, axis 1 otherwise.

    Returns:
        np.ndarray for ndarray input. For torch.Tensor input the result is a
        torch.Tensor only when the input was 5D; other tensor inputs come back
        as np.ndarray (existing behaviour, kept for backward compatibility).
    """
    is_original_tensor = isinstance(data, torch.Tensor)
    is_len_5 = len(data.shape) == 5
    # Axis holding the channel dimension of a 5D input.
    channel_axis = 4 if channel_last else 1

    if not is_original_tensor:
        data = torch.from_numpy(data)
    if is_len_5:
        data = data.squeeze(channel_axis)

    if permute:
        # Permute B x H x W x T to B x T x H x W.
        # This must stay a torch tensor: the batching loop below relies on
        # Tensor.size(0) and Tensor.cuda(), which a NumPy array does not offer
        # (converting to NumPy here made permute=True raise a TypeError).
        # contiguous() materialises the new layout in case the resize routine
        # reshapes its input with view() -- presumably it does; harmless if not.
        data = data.permute(0, 3, 1, 2).contiguous()

    use_cuda = torch.cuda.is_available()
    total = data.size(0)

    # Measure time taken
    start_time = time.time()

    # Resize chunk by chunk to bound peak (GPU) memory usage.
    resized_chunks = []
    for start_idx in range(0, total, batch_size):
        batch = data[start_idx : start_idx + batch_size]

        if use_cuda:
            batch = batch.cuda()

        batch_downsampled = bicubic_interpolation.imresize(batch, sizes=sizes)

        # Collect results on the CPU so GPU memory is only held per chunk.
        if use_cuda:
            batch_downsampled = batch_downsampled.cpu()

        resized_chunks.append(batch_downsampled)

    end_time = time.time()
    print(f"Time taken for bicubic interpolation: {end_time - start_time:.2f} seconds")

    # Concatenate all chunks along the leading dimension and go to NumPy.
    downsampled_data = torch.cat(resized_chunks, dim=0).numpy()
    if permute:
        # Undo the earlier permutation: B x T x H x W back to B x H x W x T.
        downsampled_data = np.transpose(downsampled_data, (0, 2, 3, 1))

    if is_len_5:
        # Restore the channel axis that was squeezed away above.
        downsampled_data = torch.from_numpy(downsampled_data).unsqueeze(channel_axis)
        if not is_original_tensor:
            downsampled_data = downsampled_data.numpy()

    # Release cached GPU blocks that were only needed during resizing.
    if cuda.is_available():
        cuda.empty_cache()
    return downsampled_data


def pad_to_shape(tensor, target_shape_latent):
    """
    Zero-pad the two spatial axes of a channel-last tensor up to a target size.

    The tensor is interpreted as (..., H, W, C): axis -3 is height, axis -2 is
    width and the trailing channel axis is left untouched. Padding is split
    between both sides of each axis; when the amount is odd, the extra element
    goes to the bottom / right.

    Args:
        tensor (torch.Tensor): The input tensor.
        target_shape_latent (tuple): The target (Height, Width).

    Returns:
        torch.Tensor: The padded tensor, or the input object itself when it
        already has the target spatial size.

    Raises:
        ValueError: If the target is smaller than the current size along
            either spatial axis.
    """
    height, width = tensor.shape[-3], tensor.shape[-2]
    goal_height, goal_width = target_shape_latent

    extra_rows = goal_height - height
    extra_cols = goal_width - width

    # Shrinking is out of scope here; reject it up front.
    if extra_rows < 0 or extra_cols < 0:
        raise ValueError(
            "Target shape is smaller than current shape. Cropping should be handled by a different function."
        )

    rows_before = extra_rows // 2
    rows_after = extra_rows - rows_before
    cols_before = extra_cols // 2
    cols_after = extra_cols - cols_before

    # F.pad consumes (before, after) pairs starting from the LAST axis:
    # channel (-1) gets no padding, then width (-2), then height (-3).
    pad_spec = (0, 0, cols_before, cols_after, rows_before, rows_after)

    if max(pad_spec) <= 0:
        # Nothing to add: hand back the original tensor unchanged.
        return tensor
    return F.pad(tensor, pad_spec, mode="constant", value=0)


def crop_to_shape(tensor, target_shape_latent):
    """
    Center-crop the two spatial axes of a channel-last tensor to a target size.

    The tensor is interpreted as (..., H, W, C): axis -3 is height, axis -2 is
    width and the trailing channel axis is kept whole. When the amount to
    remove is odd, one more element is removed from the bottom / right than
    from the top / left.

    Args:
        tensor (torch.Tensor): The input tensor.
        target_shape_latent (tuple): The target (Height, Width).

    Returns:
        torch.Tensor: The cropped tensor (obtained by slicing the input).

    Raises:
        ValueError: If the target is larger than the current size along
            either spatial axis.
    """
    height, width = tensor.shape[-3], tensor.shape[-2]
    goal_height, goal_width = target_shape_latent

    # Growing is out of scope here; reject it up front.
    if goal_height > height or goal_width > width:
        raise ValueError(
            "Target shape is larger than current shape. Padding should be handled by a different function."
        )

    # Offset of the kept window: half of the surplus, rounded down.
    row_start = (height - goal_height) // 2
    col_start = (width - goal_width) // 2

    row_window = slice(row_start, row_start + goal_height)
    col_window = slice(col_start, col_start + goal_width)

    return tensor[..., row_window, col_window, :]


#############################################
# Adapted Dataset: Reading Directly from HDF5
#############################################


class ArsoH5Dataset(Dataset):
    """
    Dataset for ARSO radar data that reads directly from an HDF5 file.

    The HDF5 file is opened lazily on first item access so that every
    DataLoader worker ends up with its own handle. The handle is dropped when
    the dataset is pickled or copied and is closed by `close()` or when the
    dataset is garbage collected. Assumes the HDF5 file contains 'zm_IN' and
    'zm_OUT' datasets (and 'special_indices' if `return_special_indices` is set).
    """

    def __init__(
        self,
        h5_file_path,
        indices=None,
        transform=None,
        channel_last=False,
        return_special_indices=False,
    ):
        """
        Initializes the dataset.

        Args:
            h5_file_path (str): Path to the HDF5 data file.
            indices (array-like, optional): List of indices for this dataset split.
                                           If None, all samples are used.
            transform (callable, optional): Transform to apply to the input sample.
            channel_last (bool): If True, permutes tensors to (T, H, W, C).
            return_special_indices (bool): If True, also returns 'special_indices' from the HDF5 file.
        """
        self.h5_file_path = h5_file_path
        self.indices = indices
        self.transform = transform
        self.channel_last = channel_last
        self.h5_file = None  # The file will be opened lazily in __getitem__
        self.return_special_indices = return_special_indices

    def __getstate__(self):
        """
        Returns the picklable state, without the open HDF5 handle.

        An open h5py file object cannot be pickled, so the handle is left out;
        the copy reopens the file lazily on its first `__getitem__`.
        """
        state = self.__dict__.copy()
        state["h5_file"] = None
        return state

    def close(self):
        """Closes the HDF5 file if it is open; safe to call repeatedly."""
        handle = getattr(self, "h5_file", None)
        self.h5_file = None
        if handle is not None:
            handle.close()

    def __del__(self):
        # Best-effort cleanup: a finalizer must never raise, and during
        # interpreter shutdown the underlying library may already be gone.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        """Returns the number of samples in this split (or in the whole file)."""
        if self.indices is not None:
            return len(self.indices)
        if self.h5_file is not None:
            return len(self.h5_file["zm_IN"])
        # Open the file only temporarily so that asking for the length does not
        # leave a persistent handle behind.
        with h5py.File(self.h5_file_path, "r") as f:
            return len(f["zm_IN"])

    def __getitem__(self, idx):
        """
        Retrieves a single input/target sequence from the HDF5 file.

        Returns:
            (input_tensor, target_tensor), both float32 with an added channel
            axis: (C, T, H, W), or (T, H, W, C) if `channel_last` is set. When
            `return_special_indices` is set, the raw 'special_indices' entry of
            the sample is appended as a third element.
        """
        # Lazy file opening (important for DataLoader workers)
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_file_path, "r")
        # Map the provided index to the actual sample index.
        real_idx = self.indices[idx] if self.indices is not None else idx

        # Read input and target directly from the HDF5 file.
        input_data = self.h5_file["zm_IN"][real_idx]  # e.g. shape: (T, H, W)
        target_data = self.h5_file["zm_OUT"][real_idx]  # e.g. shape: (T, H, W)

        # Optionally apply a transform (e.g. interpolation) to the input only.
        if self.transform:
            input_data = self.transform(input_data)

        # as_tensor accepts both ndarrays and tensors, so a transform may
        # return either; for ndarrays it behaves like torch.from_numpy.
        input_tensor = torch.as_tensor(input_data).float()
        target_tensor = torch.from_numpy(target_data).float()

        # Add a channel dimension (from (T, H, W) to (C, T, H, W))
        input_tensor = input_tensor.unsqueeze(0)
        target_tensor = target_tensor.unsqueeze(0)

        if self.channel_last:
            # Permute so that channels come last: (C, T, H, W) -> (T, H, W, C)
            input_tensor = input_tensor.permute(1, 2, 3, 0)
            target_tensor = target_tensor.permute(1, 2, 3, 0)

        if self.return_special_indices:
            special_indices = self.h5_file["special_indices"][real_idx]
            return input_tensor, target_tensor, special_indices
        return input_tensor, target_tensor


class ArsoH5LatentDataset(Dataset):
    """
    Dataset for ARSO data in latent space, read from an HDF5 file.

    The HDF5 file is opened lazily on first item access so that every
    DataLoader worker ends up with its own handle. The handle is dropped when
    the dataset is pickled or copied and is closed by `close()` or when the
    dataset is garbage collected. Latent tensors can optionally be padded to a
    uniform shape. Assumes the HDF5 file contains 'zm_IN_latent' and
    'zm_OUT_latent' datasets.
    """

    def __init__(
        self,
        h5_file_path,
        indices=None,
        transform=None,
        channel_last=False,
        target_shape_latent=None,
    ):
        """
        Initializes the latent space dataset.

        Args:
            h5_file_path (str): Path to the HDF5 data file.
            indices (array-like, optional): List of indices for this dataset split.
            transform (callable, optional): Transform to apply to the input sample.
            channel_last (bool): If True, permutes tensors to (T, H, W, C).
            target_shape_latent (tuple, optional): Target (H, W) to pad latent tensors to.
        """
        self.h5_file_path = h5_file_path
        self.indices = indices
        self.transform = transform
        self.channel_last = channel_last
        self.h5_file = None  # The file will be opened lazily in __getitem__
        self.target_shape_latent = target_shape_latent

    def __getstate__(self):
        """
        Returns the picklable state, without the open HDF5 handle.

        An open h5py file object cannot be pickled, so the handle is left out;
        the copy reopens the file lazily on its first `__getitem__`.
        """
        state = self.__dict__.copy()
        state["h5_file"] = None
        return state

    def close(self):
        """Closes the HDF5 file if it is open; safe to call repeatedly."""
        handle = getattr(self, "h5_file", None)
        self.h5_file = None
        if handle is not None:
            handle.close()

    def __del__(self):
        # Best-effort cleanup: a finalizer must never raise, and during
        # interpreter shutdown the underlying library may already be gone.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        """Returns the number of samples in this split (or in the whole file)."""
        if self.indices is not None:
            return len(self.indices)
        if self.h5_file is not None:
            return len(self.h5_file["zm_IN_latent"])
        # Open the file only temporarily so that asking for the length does not
        # leave a persistent handle behind.
        with h5py.File(self.h5_file_path, "r") as f:
            return len(f["zm_IN_latent"])

    def __getitem__(self, idx):
        """
        Retrieves a single latent input/target sequence.

        Returns:
            (input_tensor, target_tensor) as float32 tensors; the transform,
            if any, is applied to the input only.
        """
        # Lazy file opening (important for DataLoader workers)
        if self.h5_file is None:
            self.h5_file = h5py.File(self.h5_file_path, "r")

        # Map the provided index to the actual sample index.
        real_idx = self.indices[idx] if self.indices is not None else idx

        input_data_np = self.h5_file["zm_IN_latent"][real_idx]  # Shape: (T, C, H, W)
        target_data_np = self.h5_file["zm_OUT_latent"][real_idx]  # Shape: (T, C, H, W)

        input_tensor = torch.from_numpy(input_data_np).float()
        target_tensor = torch.from_numpy(target_data_np).float()

        if self.target_shape_latent:
            # NOTE(review): pad_to_shape pads dims -3 and -2, i.e. it expects a
            # channel-last (..., H, W, C) layout, while the tensors here are
            # annotated as (T, C, H, W) and are only permuted to channel-last
            # further below. If the stored layout really is (T, C, H, W), this
            # pads the C and H axes instead of H and W -- confirm the layout
            # before relying on target_shape_latent.
            input_tensor = pad_to_shape(input_tensor, self.target_shape_latent)
            target_tensor = pad_to_shape(target_tensor, self.target_shape_latent)

        if self.transform:  # Apply transform after padding
            # The transform receives the (possibly padded) tensor and must be
            # compatible with that shape.
            input_tensor = self.transform(input_tensor)

        if self.channel_last:
            input_tensor = input_tensor.permute(0, 2, 3, 1)  # (T, H, W, C)
            target_tensor = target_tensor.permute(0, 2, 3, 1)  # (T, H, W, C)

        return input_tensor, target_tensor

