import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data import random_split

from evaluation.predictive.dataset import PredictiveDataset


class PredictiveDM(LightningDataModule):
    """Data module serving real or synthetic training data and held-out real validation data.

    The real data is randomly split into a training part and a validation part.
    The synthetic data is randomly subsampled to the same number of samples as the
    real training part (remaining synthetic samples are discarded). Validation
    always uses the held-out real samples; which training set ``train_dataloader``
    returns is selected by the ``train_on_real`` attribute (``True`` by default).
    """

    def __init__(
        self, data_real_ev: torch.Tensor, data_synthetic_ev: torch.Tensor, train_percentage: float, cutoff: float,
        batch_size: int, num_workers: int, pin_memory: bool
    ) -> None:
        """Split the data and wrap each split in a ``PredictiveDataset``.

        Args:
            data_real_ev: Real samples; split along the first dimension (``len``).
            data_synthetic_ev: Synthetic samples; must contain at least as many
                samples as the real training split.
            train_percentage: Fraction in ``[0, 1]`` of the real samples used for
                training; the rest is used for validation.
            cutoff: Forwarded unchanged to ``PredictiveDataset`` (its meaning is
                defined there).
            batch_size: Batch size used by both data loaders.
            num_workers: Number of worker processes used by both data loaders.
            pin_memory: Whether both data loaders use pinned memory.

        Raises:
            ValueError: If ``train_percentage`` is outside ``[0, 1]`` or the
                synthetic data has fewer samples than the real training split.
        """
        super().__init__()

        # Out-of-range fractions would yield a negative split length, which
        # random_split does not reject for integer lengths (it silently produces
        # truncated/overlapping subsets), so validate up front.
        if not 0.0 <= train_percentage <= 1.0:
            raise ValueError(f"train_percentage must be in [0, 1], got {train_percentage}")

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        n_real = len(data_real_ev)
        n_synthetic = len(data_synthetic_ev)

        n_samples_train = int(train_percentage * n_real)
        n_samples_val = n_real - n_samples_train

        # The synthetic training split has the same size as the real one. The
        # discarded remainder must be computed from the synthetic length: reusing
        # n_samples_val would make random_split fail whenever the two datasets
        # differ in size.
        n_synthetic_unused = n_synthetic - n_samples_train
        if n_synthetic_unused < 0:
            raise ValueError(
                f"data_synthetic_ev has {n_synthetic} samples, but the real training split "
                f"needs {n_samples_train}"
            )

        data_real_train, data_real_val = random_split(data_real_ev, [n_samples_train, n_samples_val])
        data_synthetic_train, _ = random_split(data_synthetic_ev, [n_samples_train, n_synthetic_unused])

        self.pd_train_real = PredictiveDataset(data_real_train, cutoff)
        self.pd_train_synthetic = PredictiveDataset(data_synthetic_train, cutoff)
        self.pd_val = PredictiveDataset(data_real_val, cutoff)

        # Selects which training set train_dataloader serves; callers may flip it.
        self.train_on_real = True

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling loader over the real or synthetic training split.

        ``self.train_on_real`` is read each time this method is called.
        """
        dataset = self.pd_train_real if self.train_on_real else self.pd_train_synthetic
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffling loader over the held-out real validation split."""
        return DataLoader(
            self.pd_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
