
"""Load and preprocess the wind dataset."""
import os
from argparse import ArgumentParser
from pathlib import Path

import torch
from dotenv import load_dotenv
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_and_extract_archive

from .sequential import AdaptDataset, CombineSequentialDataset, ContextSeqDataset
from .utils import Scale, ScaleRMSE
from .windspeed import BaseDataset, Day, Week


class ScaleDataset(Dataset):
    """Wrap a dataset and standardize every sample on access.

    Samples of the wrapped dataset must unpack as ``((cx, cy, tx), ty)``.
    ``cx`` and ``tx`` are standardized with ``meanx``/``stdx``; ``cy`` and
    ``ty`` are standardized with ``meany``/``stdy``.
    """

    def __init__(self, dataset, meanx, stdx, meany, stdy):
        """Store the wrapped dataset and the normalization statistics.

        Args:
            dataset: Indexable dataset whose items unpack as
                ``((cx, cy, tx), ty)``.
            meanx: Value subtracted from ``cx`` and ``tx``.
            stdx: Value ``cx`` and ``tx`` are divided by after centering.
            meany: Value subtracted from ``cy`` and ``ty``.
            stdy: Value ``cy`` and ``ty`` are divided by after centering.
        """
        self.dataset = dataset
        self.meanx, self.stdx = meanx, stdx
        self.meany, self.stdy = meany, stdy

    @staticmethod
    def _standardize(value, mean, std):
        """Return ``(value - mean) / std``."""
        return (value - mean) / std

    def __getitem__(self, i):
        """Fetch item ``i`` from the wrapped dataset and standardize it.

        Args:
            i: Index forwarded unchanged to the wrapped dataset.

        Returns:
            tuple: ``((cx, cy, tx), ty)`` with every part standardized.
        """
        (raw_cx, raw_cy, raw_tx), raw_ty = self.dataset[i]
        # x-like parts first, then y-like parts.
        scaled_cx = self._standardize(raw_cx, self.meanx, self.stdx)
        scaled_tx = self._standardize(raw_tx, self.meanx, self.stdx)
        scaled_cy = self._standardize(raw_cy, self.meany, self.stdy)
        scaled_ty = self._standardize(raw_ty, self.meany, self.stdy)
        return (scaled_cx, scaled_cy, scaled_tx), scaled_ty

    def __len__(self):
        """Return the length of the wrapped dataset.

        Returns:
            int: ``len(self.dataset)``.
        """
        return len(self.dataset)


def get_mean_std(dataset):
    """Compute normalization statistics for the x and y parts of a dataset.

    Every sample is read as ``sample[0][0]`` (cx), ``sample[0][1]`` (cy),
    ``sample[0][2]`` (tx) and ``sample[1]`` (ty). The cx and tx tensors are
    concatenated into one x tensor, the cy and ty tensors into one y tensor,
    and mean / std are reduced over dimension 0 of each.

    The dataset is traversed exactly once, so each sample is only produced a
    single time and one-shot iterables (e.g. generators) are supported.

    Args:
        dataset: Iterable of samples shaped ``((cx, cy, tx), ty)`` where each
            part is a tensor accepted by ``torch.cat``.

    Returns:
        tuple: ``(x_mean, x_std, y_mean, y_std)``.
    """
    cx_parts, cy_parts, tx_parts, ty_parts = [], [], [], []
    for sample in dataset:
        inputs = sample[0]
        cx_parts.append(inputs[0])
        cy_parts.append(inputs[1])
        tx_parts.append(inputs[2])
        ty_parts.append(sample[1])

    # Keep the row order "all cx, then all tx" (resp. cy, ty) so results are
    # identical to concatenating each group separately first.
    x = torch.cat([torch.cat(cx_parts), torch.cat(tx_parts)])
    y = torch.cat([torch.cat(cy_parts), torch.cat(ty_parts)])
    return x.mean(0), x.std(0), y.mean(0), y.std(0)


def datasets(tw):
    """Load the saved train/validation datasets and standardize them.

    Environment variables are populated via ``load_dotenv()`` first. The
    normalization statistics are computed on the training set only and are
    applied to both splits.

    Args:
        tw: Selects which saved files to load. ``None`` or ``30`` loads the
            files named by ``TRAIN_PATH`` and ``VAL_PATH``; any other value
            loads ``train{tw}.pt`` and ``val{tw}.pt`` from ``DATASET_PATH``.
            Presumably the target window length used when the files were
            generated -- TODO confirm.

    Returns:
        tuple: Train dataset, validation dataset, metric.
    """
    load_dotenv()

    # Resolve the two file locations, then load them in train -> val order.
    if tw is None or tw == 30:
        train_file = os.getenv("TRAIN_PATH")
        val_file = os.getenv("VAL_PATH")
    else:
        root = os.getenv("DATASET_PATH")
        train_file = root + f"/train{tw}.pt"
        val_file = root + f"/val{tw}.pt"

    # NOTE(review): torch.load deserializes pickled objects; only point these
    # paths at files you trust.
    raw_train = torch.load(train_file)
    raw_val = torch.load(val_file)

    meanx, stdx, meany, stdy = get_mean_std(raw_train)
    scaled_train = ScaleDataset(raw_train, meanx, stdx, meany, stdy)
    scaled_val = ScaleDataset(raw_val, meanx, stdx, meany, stdy)

    # The metric receives the y statistics wrapped in a Scale object.
    return scaled_train, scaled_val, ScaleRMSE(Scale(meany, stdy))


if __name__ == "__main__":
    # Script entry point: download the raw archive or build the saved
    # train/validation dataset files, depending on --dataset.
    load_dotenv()
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset", choices=["default", "otherhours", "download"], default="default"
    )
    args = parser.parse_args()

    csv = Path(os.getenv("CSV_PATH"))
    out = Path(os.getenv("DATASET_PATH"))
    print(f"{csv} -> {out}")

    def _build_and_save(window, val_name, train_name):
        """Build the val and train sets for one window and save them to ``out``."""
        # Validation split: built and saved before the training split.
        val_set = BaseDataset(csv, Day(3, 3))
        val_set = AdaptDataset(ContextSeqDataset(val_set, 1 * 60, window * 60))
        torch.save(val_set, out / val_name)

        # Training split: one dataset per selected week, combined into one.
        weekly = [BaseDataset(csv, Week(i)) for i in [0, 1, 2, 4]]
        weekly = [ContextSeqDataset(d, 1 * 60, window * 60) for d in weekly]
        train_set = AdaptDataset(CombineSequentialDataset(weekly))
        torch.save(train_set, out / train_name)

    if args.dataset == "download":
        mirror = "https://zenodo.org/record/5074237/files/"
        tarname = "SkySoft_WindSpeed.tar.gz"
        md5 = "87b5f8c8e7faad810b43af6d67a49428"

        csv_folder = os.getenv("CSV_FOLDER")
        # Skip the download when the target folder already exists.
        if not Path(csv_folder).exists():
            download_and_extract_archive(
                mirror + tarname,
                download_root=csv_folder,
                filename=tarname,
                md5=md5,
            )
    elif args.dataset == "default":
        _build_and_save(30, "val.pt", "train.pt")
    else:
        for tw in [10, 20, 60, 120, 240, 360]:
            _build_and_save(tw, f"val{tw}.pt", f"train{tw}.pt")
