import os
from glob import glob
from typing import Any

import numpy as np
from torch.utils.data import Dataset


class Era5(Dataset):
    """Dataset of paired target/measurement arrays stored as ``.npy`` files.

    Files are discovered in ``data_dir`` with the patterns ``target_*.npy``
    and ``measurement_*.npy``. Both file lists are sorted, and sample ``i``
    is the pair formed by the ``i``-th entry of each list. Each loaded array
    is passed through ``transform`` and then min-max scaled to ``[-1, 1]``
    independently of the other.

    NOTE(review): the class name suggests ERA5 reanalysis fields, but nothing
    in this class depends on the content of the arrays.
    """

    def __init__(self, data_dir: str, transform: Any) -> None:
        """Index the ``.npy`` files found in ``data_dir``.

        Args:
            data_dir: Directory containing ``target_*.npy`` and
                ``measurement_*.npy`` files.
            transform: Callable applied to every loaded NumPy array. Its
                result must support ``min()``, ``max()``, arithmetic and
                ``float()`` (presumably a ``torch.Tensor`` -- TODO confirm
                against callers).

        Raises:
            ValueError: If the number of target files differs from the
                number of measurement files.
        """
        super().__init__()

        # glob() returns entries in arbitrary, filesystem-dependent order.
        # Sort both lists so that index i points at the same sample in each.
        self.target_files = sorted(glob(os.path.join(data_dir, "target_*.npy")))
        self.measurement_files = sorted(
            glob(os.path.join(data_dir, "measurement_*.npy"))
        )
        print(f"Dataset size: {len(self.target_files)}")

        # Fail early instead of raising IndexError (or silently ignoring
        # files) later in __getitem__.
        if len(self.target_files) != len(self.measurement_files):
            raise ValueError(
                f"Found {len(self.target_files)} target files but "
                f"{len(self.measurement_files)} measurement files in "
                f"{data_dir!r}; the two sets must pair up one-to-one."
            )

        self.tfm = transform

    def __len__(self) -> int:
        """Return the number of target/measurement pairs."""
        return len(self.target_files)

    @staticmethod
    def _normalize(x):
        """Min-max scale ``x`` to ``[-1, 1]`` using its own min and max."""
        lo = x.min()
        span = x.max() - lo
        # A constant input has zero range; dividing by it would yield NaNs.
        # Treat the range as 1 so the input maps to the lower bound (-1).
        if span == 0:
            span = 1
        return (x - lo) / span * 2 - 1

    def __getitem__(self, idx: int):
        """Load, transform and normalise the pair at position ``idx``.

        Returns:
            A ``(target, measurement)`` tuple, each scaled to ``[-1, 1]``
            and converted with ``.float()``.
        """
        target = self.tfm(np.load(self.target_files[idx]))
        measurement = self.tfm(np.load(self.measurement_files[idx]))

        return (
            self._normalize(target).float(),
            self._normalize(measurement).float(),
        )
