import os
import warnings
from pathlib import Path

import numpy as np
from skimage.transform import resize
from tqdm import tqdm
import xarray as xr

# Escalate every warning raised anywhere in this process to an exception, so
# numerical or I/O problems stop the script instead of scrolling past.
warnings.simplefilter("error")

# Output directory for the generated ERA5 .npy files; fill this in before
# running the script.
save_dir = ""

if __name__ == "__main__":
    Path.mkdir(Path(save_dir), exist_ok=True)
    noise_std = 0.01

    # Open dataset
    dset = xr.open_dataset("era5_2m_temperature_2009-2017_01.grib")
    dset = dset["t2m"].values

    # Normalize dataset
    dset = (dset - dset.min()) / (dset.max() - dset.min())

    # Splits dataset into staggered time steps
    split = np.lib.stride_tricks.sliding_window_view(dset, window_shape=2, axis=0)
    t0 = split[..., 0]
    t1 = split[..., 1]

    i = 0
    for x, y in tqdm(zip(t0, t1), total=len(t0)):
        # Resize to 256x256
        y = resize(
            y, (256, 256), mode="reflect", anti_aliasing=True, preserve_range=True
        )
        # Add noise
        y += np.random.randn(*y.shape) * noise_std

        np.save(os.path.join(save_dir, f"target_{i}.npy"), x)
        np.save(os.path.join(save_dir, f"measurement_{i}.npy"), y)
        i += 1
