from pathlib import Path
from typing import Union, Dict

import numpy as np
import torch
from torch.utils.data import Dataset


class PDE1DDataset(Dataset):
    """
    Dataset for 1D heat equation trajectories stored in a single .npz file.

    The .npz is expected to contain:
      - data: shape (n_samples, n_times, n_x)
      - x: grid (n_x,) [optional; defaults to a uniform grid on [0, 1]]
      - t: times (n_times,) [optional; not read by this class]

    Each sample uses the field at ``start_time_index`` (optionally strided
    in space) as input, and the full-resolution field at
    ``target_time_index`` as label.

    Shapes per sample:
      - x: (n_x_input, 1) where n_x_input = ceil(n_x / input_stride)
      - y: (n_x, 1)
      - input_grid: (n_x_input, 1)
      - latent_grid / output_grid: (n_x, 1)

    Args:
        data_path: Path to the .npz file.
        dtype: Torch dtype used for the data and the grids.
        start_time_index: Time index of the input field. Negative values
            count from the end, following normal tensor indexing.
        target_time_index: Time index of the label field; must satisfy
            ``0 <= target_time_index < n_times``.
        input_stride: Spatial subsampling step applied to the input field
            and its grid; must be >= 1.

    Raises:
        ValueError: If ``input_stride`` is smaller than 1, if ``data`` is
            not 3-dimensional, or if either time index is out of range.
    """

    def __init__(
            self,
            data_path: Union[str, Path],
            dtype: torch.dtype = torch.float32,
            start_time_index: int = 0,
            target_time_index: int = 1,
            input_stride: int = 1,
    ) -> None:
        super().__init__()

        # Fail early with a clear message instead of an opaque slicing error.
        if input_stride < 1:
            raise ValueError(
                f"input_stride must be a positive integer, got {input_stride}"
            )

        self.data_path = Path(data_path) if isinstance(data_path, str) else data_path
        self.dtype = dtype
        self.input_stride = input_stride

        # NpzFile keeps the underlying file open until closed, so read every
        # array we need inside the context manager.
        with np.load(self.data_path) as npz:
            data: np.ndarray = npz["data"]  # (N, T, X)
            x_grid = npz["x"] if "x" in npz.files else None

        if data.ndim != 3:
            raise ValueError(
                f"'data' must have shape (n_samples, n_times, n_x), got {data.shape}"
            )
        self.data = torch.tensor(data, dtype=dtype)

        # Optional grid: fall back to a uniform grid on [0, 1].
        if x_grid is None:
            x_grid = np.linspace(0.0, 1.0, data.shape[-1], dtype=np.float32)
        self.latent_grid = torch.tensor(x_grid, dtype=dtype).unsqueeze(-1)
        self.input_grid = self.latent_grid[::self.input_stride]

        self.n_samples = self.data.shape[0]
        self.n_times = self.data.shape[1]
        self.nx = self.data.shape[2]

        if not (0 <= target_time_index < self.n_times):
            raise ValueError(
                f"target_time_index {target_time_index} is out of bounds for n_times={self.n_times}"
            )
        # Negative start indices were usable before (they index from the
        # end), so they stay accepted; only truly out-of-range values are
        # rejected here rather than on first item access.
        if not (-self.n_times <= start_time_index < self.n_times):
            raise ValueError(
                f"start_time_index {start_time_index} is out of bounds for n_times={self.n_times}"
            )
        self.target_time_index = int(target_time_index)
        self.start_time_index = int(start_time_index)

    def __len__(self) -> int:
        """Return the number of trajectories in the file."""
        return self.n_samples

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Return one sample as a dict with keys ``x``, ``y``, ``input_grid``,
        ``latent_grid`` and ``output_grid``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if not (-self.n_samples <= idx < self.n_samples):
            raise IndexError("Index out of bounds for dataset.")

        # Input is the field at start_time_index, optionally downsampled in space
        u0 = self.data[idx, self.start_time_index]  # (X,)
        x = u0[::self.input_stride].unsqueeze(-1)  # (X_in, 1)

        # Output/label is full grid at the specified time index
        y = self.data[idx, self.target_time_index].unsqueeze(-1)  # (X, 1)

        return {
            "x": x,
            "y": y,
            "input_grid": self.input_grid,
            "latent_grid": self.latent_grid,
            "output_grid": self.latent_grid,
        }
