
"""Load and preprocess the Darcy Flow dataset."""
import os

import torch
from dotenv import load_dotenv
from torch.utils.data import Dataset

from .utils import Scale, ScaleRMSE, StackDataset


class ScaleDataset(Dataset):
    """Wrap a dataset and standardize its ``cy`` and ``ty`` values on access.

    The wrapped dataset is left untouched; scaling happens every time an item
    is fetched.
    """

    def __init__(self, dataset, mean_cy, std_cy, mean_ty, std_ty):
        """Store the wrapped dataset and the scaling statistics.

        Args:
            dataset (torch.utils.data.Dataset): Dataset to wrap. Each item is
                unpacked as ``((cx, cy, tx), ty)``.
            mean_cy (torch.Tensor): Value subtracted from ``cy``.
            std_cy (torch.Tensor): Value ``cy`` is divided by after centering.
            mean_ty (torch.Tensor): Value subtracted from ``ty``.
            std_ty (torch.Tensor): Value ``ty`` is divided by after centering.
        """
        self.dataset = dataset
        self.mean_cy, self.std_cy = mean_cy, std_cy
        self.mean_ty, self.std_ty = mean_ty, std_ty

    def __getitem__(self, i):
        """Fetch item ``i`` from the wrapped dataset and standardize it.

        Args:
            i (int): Index of the item.

        Returns:
            tuple: ``((cx, cy, tx), ty)`` where ``cy`` and ``ty`` have been
                centered and scaled; ``cx`` and ``tx`` are passed through
                unchanged.
        """
        inputs, target = self.dataset[i]
        context_x, context_y, target_x = inputs
        scaled_context_y = (context_y - self.mean_cy) / self.std_cy
        scaled_target = (target - self.mean_ty) / self.std_ty
        return (context_x, scaled_context_y, target_x), scaled_target

    def __len__(self):
        """Return the number of items in the wrapped dataset.

        Returns:
            int: Number of items.
        """
        return len(self.dataset)


def _load_split(env_var):
    """Load one dataset split from the path stored in an environment variable.

    Args:
        env_var (str): Name of the environment variable holding the file path.

    Returns:
        tuple: The nested ``StackDataset`` for the split, followed by the
            flattened ``cy`` and ``ty`` tensors.

    Raises:
        RuntimeError: If the environment variable is unset or empty.
    """
    path = os.getenv(env_var)
    if not path:
        # os.getenv returns None for a missing variable; fail here with a
        # clear message instead of an opaque error from inside torch.load.
        raise RuntimeError(
            f"Environment variable {env_var} is not set; "
            "it must point to the dataset file."
        )
    # The file is unpacked as four tensors; the first two dimensions of each
    # are merged into one (presumably so every row is one sample -- confirm
    # against the code that writes the file).
    cx, cy, tx, ty = (t.flatten(0, 1) for t in torch.load(path))
    return StackDataset(StackDataset(cx, cy, tx), ty), cy, ty


def datasets():
    """Load the Darcy Flow dataset.

    File paths are read from the ``DARCY_TRAIN_PATH`` and ``DARCY_VAL_PATH``
    environment variables, after loading a ``.env`` file if one is present.
    Both splits are scaled with statistics computed on the training split.

    Returns:
        tuple: Train dataset, validation dataset, metric.

    Raises:
        RuntimeError: If either path environment variable is unset or empty.
    """
    load_dotenv()
    train_dataset, train_cy, train_ty = _load_split("DARCY_TRAIN_PATH")
    # Scaling statistics come from the training split only and are reused for
    # the validation split below.
    mean_cy, std_cy = train_cy.mean(), train_cy.std()
    mean_ty, std_ty = train_ty.mean(), train_ty.std()
    val_dataset, _, _ = _load_split("DARCY_VAL_PATH")

    train_dataset = ScaleDataset(train_dataset, mean_cy, std_cy, mean_ty, std_ty)
    val_dataset = ScaleDataset(val_dataset, mean_cy, std_cy, mean_ty, std_ty)
    # NOTE(review): presumably this reports RMSE in the original (unscaled)
    # target units -- confirm against ScaleRMSE/Scale in .utils.
    metric = ScaleRMSE(Scale(mean_ty, std_ty))

    return train_dataset, val_dataset, metric
