import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset, random_split

from evaluation.discriminative.dataset import DiscriminativeDataset


class DiscriminativeDM(LightningDataModule):
    """Lightning data module for a real-vs-synthetic discriminative evaluation.

    Real samples are labelled ``1`` and synthetic samples ``0`` (default float
    dtype). Each source is randomly split into a training and a validation part
    according to ``train_percentage``, independently of the other source. The
    matching parts are then combined with ``DiscriminativeDataset``.
    """

    def __init__(
        self, data_real_ev: torch.Tensor, data_synthetic_ev: torch.Tensor, train_percentage: float,
        batch_size: int, num_workers: int, pin_memory: bool
    ) -> None:
        """Build the labelled datasets and their train/validation splits.

        Args:
            data_real_ev: Real samples, converted with ``torch.Tensor``.
            data_synthetic_ev: Synthetic samples, converted the same way. It may
                contain a different number of samples than ``data_real_ev``.
            train_percentage: Fraction in ``[0, 1]`` of each source used for
                training. The remainder is used for validation.
            batch_size: Batch size used by both dataloaders.
            num_workers: Number of worker processes used by both dataloaders.
            pin_memory: Forwarded to ``DataLoader``.

        Raises:
            ValueError: If ``train_percentage`` is not within ``[0, 1]``.
        """
        super().__init__()

        # Out-of-range fractions would produce a negative validation length,
        # which ``random_split`` does not reliably reject.
        if not 0.0 <= train_percentage <= 1.0:
            raise ValueError(f"train_percentage must be in [0, 1], got {train_percentage}")

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

        real_ds = TensorDataset(torch.Tensor(data_real_ev), torch.ones(len(data_real_ev)))
        synthetic_ds = TensorDataset(
            torch.Tensor(data_synthetic_ev), torch.zeros(len(data_synthetic_ev))
        )

        # Split each source using its own length. Reusing the real dataset's
        # sizes for the synthetic one fails whenever the two sizes differ.
        real_train, real_val = self._split(real_ds, train_percentage)
        synthetic_train, synthetic_val = self._split(synthetic_ds, train_percentage)

        self.dd_train = DiscriminativeDataset(real_train, synthetic_train)
        self.dd_val = DiscriminativeDataset(real_val, synthetic_val)

    @staticmethod
    def _split(dataset: TensorDataset, train_percentage: float):
        """Randomly split ``dataset`` into ``(train, val)`` subsets.

        The training subset receives ``int(train_percentage * len(dataset))``
        samples. The validation subset receives the rest.
        """
        n_train = int(train_percentage * len(dataset))
        return random_split(dataset, [n_train, len(dataset) - n_train])

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled loader over the training split."""
        return DataLoader(
            self.dd_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a loader over the validation split in a fixed order."""
        # Shuffling a validation set only adds run-to-run nondeterminism;
        # Lightning explicitly warns against shuffled val dataloaders.
        return DataLoader(
            self.dd_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
