
"""Utilities to load the Poisson dataset."""
import os

import torch
from dotenv import load_dotenv
from graph_element_networks.gen_datasets import FTDataset
from graph_element_networks.poisson_datasets import (
    PoissonSquareRoomInpDataset,
    PoissonSquareRoomOutDataset,
)

from .utils import StackDataset


def _stack_split(raw_split, device):
    """Collect one split of the raw dataset into stacked device tensors.

    Args:
        raw_split: Iterable of samples shaped like
            ``((((cx, cy),), ((tx, ty),)), _)``, i.e. what iterating a split
            of the ``FTDataset`` built in ``datasets`` yields.
            NOTE(review): cx/cy and tx/ty presumably denote context and
            target inputs/values -- inferred from naming only, confirm
            against FTDataset.
        device (torch.device): The device to move the concatenated tensors to.

    Returns:
        StackDataset: ``StackDataset(StackDataset(cx, cy, tx), ty)`` where each
        field is the concatenation of that field over all samples.
    """
    cxs, cys, txs, tys = [], [], [], []
    # Single pass over the split: every sample is read exactly once, so the
    # four fields of a sample always come from the same dataset access.
    for (((cx, cy),), ((tx, ty),)), _ in raw_split:
        cxs.append(cx)
        cys.append(cy)
        txs.append(tx)
        tys.append(ty)

    cx_all = torch.cat(cxs).to(device)
    cy_all = torch.cat(cys).to(device)
    tx_all = torch.cat(txs).to(device)
    ty_all = torch.cat(tys).to(device)

    inputs = StackDataset(cx_all, cy_all, tx_all)
    return StackDataset(inputs, ty_all)


def datasets(device, train_data_frac=0.8):
    """Load the Poisson dataset and split it into training and validation sets.

    The data folder is read from the ``POISSON_DATA_FOLDER`` environment
    variable (a ``.env`` file is loaded first via ``load_dotenv``). The split
    uses a fixed seed (42), so it is reproducible across calls.

    Args:
        device (torch.device): The device to load the data on.
        train_data_frac (float): The fraction of the data to use for training.
            Must be strictly between 0 and 1.

    Returns:
        train_dataset (torch.utils.data.Dataset): The training dataset.
        val_dataset (torch.utils.data.Dataset): The validation dataset.

    Raises:
        ValueError: If ``POISSON_DATA_FOLDER`` is not set, or if
            ``train_data_frac`` is not strictly between 0 and 1.
    """
    load_dotenv()

    data_dir = os.getenv("POISSON_DATA_FOLDER")
    if data_dir is None:
        raise ValueError(
            "Please set the environment variable POISSON_DATA_FOLDER to the data folder. "
            "You can download the data from "
            "https://github.com/FerranAlet/graph_element_networks/tree/master/data"
        )

    # Fail early: a fraction of 0 or 1 leaves one split empty, which would
    # otherwise only surface later when concatenating that split's tensors.
    if not 0.0 < train_data_frac < 1.0:
        raise ValueError(
            f"train_data_frac must be strictly between 0 and 1, got {train_data_frac!r}."
        )

    full_dataset = FTDataset(
        inp_datasets=[PoissonSquareRoomInpDataset],
        inp_datasets_args=[{"dir_path": os.path.join(data_dir, "poisson_inp")}],
        out_datasets=[PoissonSquareRoomOutDataset],
        out_datasets_args=[{"file_path": os.path.join(data_dir, "poisson_out.hdf5")}],
    )

    num_rows = len(full_dataset)
    train_size = int(train_data_frac * num_rows)
    val_size = num_rows - train_size  # remainder, so the sizes always sum to num_rows
    train_raw, val_raw = torch.utils.data.random_split(
        full_dataset,
        [train_size, val_size],
        # Fixed seed keeps the train/validation partition reproducible.
        generator=torch.Generator().manual_seed(42),
    )

    train_dataset = _stack_split(train_raw, device)
    val_dataset = _stack_split(val_raw, device)

    return train_dataset, val_dataset
