from pytorch_lightning import LightningDataModule
import pytorch_lightning as pl
from abc import ABC, abstractmethod
from torch.utils.data import DataLoader, TensorDataset

import torch
import os
from src.utils.transforms import get_transforms

class AnyDatamodule(pl.LightningDataModule, ABC):
    """Base LightningDataModule that serves train/val/test splits as DataLoaders.

    Subclasses implement :meth:`get_datasets`, which must set
    ``self.train_dataset``, ``self.val_dataset`` and ``self.test_dataset``
    before the corresponding ``*_dataloader`` method is called (this base
    class never sets them itself).

    Args:
        dataset_name: Identifier of the dataset; stored as ``self.dataset_name``.
        data_dir: Directory the subclass reads its data from; stored as
            ``self.data_dir``.
        batch_size: Batch size used by every DataLoader.
        transforms: Optional transforms object; stored as ``self.transforms``.
            This base class does not apply it.
        num_workers: Worker processes per DataLoader. Defaults to 0 (load in
            the main process), which was previously hard-coded.
        shuffle_train: Whether the training DataLoader shuffles its data.
            Defaults to False, which keeps the previous behaviour.
    """

    def __init__(self, dataset_name: str, data_dir: str,
                 batch_size: int = 32,
                 transforms=None,
                 *,
                 num_workers: int = 0,
                 shuffle_train: bool = False):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.dataset_name = dataset_name
        self.transforms = transforms
        # Configurable instead of hard-coded; pass e.g. os.cpu_count() to
        # load batches in parallel worker processes.
        self.num_workers = num_workers
        self.shuffle_train = shuffle_train

    @abstractmethod
    def get_datasets(self):
        """Set ``self.train_dataset``, ``self.val_dataset`` and ``self.test_dataset``."""

    def _make_loader(self, dataset, shuffle: bool = False) -> DataLoader:
        """Wrap *dataset* in a DataLoader using this module's batch/worker settings."""
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          shuffle=shuffle,
                          num_workers=self.num_workers)

    def train_dataloader(self):
        """Return a DataLoader over ``self.train_dataset``."""
        return self._make_loader(self.train_dataset, shuffle=self.shuffle_train)

    def val_dataloader(self):
        """Return a (non-shuffled) DataLoader over ``self.val_dataset``."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return a (non-shuffled) DataLoader over ``self.test_dataset``."""
        return self._make_loader(self.test_dataset)

    def testt(self):
        """Return the constant 20.

        NOTE(review): looks like a debugging leftover; kept only so that any
        existing caller keeps working.
        """
        return 20

class GaussianMixtureDatamodule(AnyDatamodule):
    """Datamodule for Gaussian-mixture data stored as ``torch.load``-able files.

    All three splits are loaded eagerly in ``__init__`` from
    ``data_dir/trainset.py``, ``data_dir/valset.py`` and ``data_dir/testset.py``.

    NOTE(review): these files are read with ``torch.load``, so the ``.py``
    extension looks unintended. The names are kept unchanged so that existing
    data directories keep loading; confirm before renaming.
    """

    def __init__(self, dataset_name: str, data_dir: str,
                 batch_size: int = 32,
                 transforms=None):
        super().__init__(dataset_name, data_dir, batch_size, transforms)
        # Load eagerly so construction fails fast if a split file is missing.
        self.get_datasets()

    def _load_split(self, filename: str) -> TensorDataset:
        """Load ``data_dir/filename`` with ``torch.load`` and wrap it in a TensorDataset.

        A single tensor becomes a one-field dataset. A tuple of tensors (e.g.
        ``(x, y)``) is unpacked so that each element becomes one field;
        passing the tuple itself to ``TensorDataset`` would fail.
        """
        # torch.load unpickles the file; only point data_dir at trusted data.
        data = torch.load(os.path.join(self.data_dir, filename))
        if isinstance(data, tuple):
            return TensorDataset(*data)
        return TensorDataset(data)

    def get_datasets(self):
        """Populate the train/val/test datasets from ``self.data_dir``."""
        self.train_dataset = self._load_split("trainset.py")
        self.val_dataset = self._load_split("valset.py")
        self.test_dataset = self._load_split("testset.py")




def get_datamodule(cfg):
    """Return the datamodule described by *cfg*.

    Args:
        cfg: Config object. This function reads ``cfg.dataset.dataset_name``,
            ``cfg.dataset.data_dir`` and ``cfg.trainer.batch_size``, and also
            passes the whole config to ``get_transforms``.

    Returns:
        A ``GaussianMixtureDatamodule`` when the dataset name is
        ``"gauss_mixture"`` or ``"gauss_mixture_v3"``.

    Raises:
        ValueError: If ``cfg.dataset.dataset_name`` is not a supported dataset.
            Previously this case crashed with an ``UnboundLocalError``.
    """
    transforms = get_transforms(cfg)

    dataset_name = cfg.dataset.dataset_name
    if dataset_name in ("gauss_mixture", "gauss_mixture_v3"):
        return GaussianMixtureDatamodule(data_dir=cfg.dataset.data_dir,
                                         batch_size=cfg.trainer.batch_size,
                                         dataset_name=dataset_name,
                                         transforms=transforms)

    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; "
        "expected one of 'gauss_mixture', 'gauss_mixture_v3'."
    )

def get_simple_dataset():
    """Build the default ``gauss_mixture_v3`` datamodule.

    Uses a batch size of 32 and the relative directory
    ``data/raw/gauss_mixture_v3``, resolved against the current working
    directory. No transforms are passed.
    """
    dm = GaussianMixtureDatamodule(
        dataset_name="gauss_mixture_v3",
        data_dir="data/raw/gauss_mixture_v3",
        batch_size=32,
    )
    return dm

def get_complicated_dataset(sphere, dim, mix):
    """Build a batch-size-32 datamodule for a parameterized Gaussian-mixture dataset.

    The dataset name is ``gauss_mixture_{sphere}dim_{dim}_mix_{mix}``, with
    each argument interpolated via ``str`` formatting. The data is read from
    ``data/raw/<that name>``.

    Args:
        sphere: Value interpolated before ``dim`` in the name.
        dim: Value interpolated after ``dim_`` in the name.
        mix: Value interpolated after ``mix_`` in the name.

    Returns:
        A ``GaussianMixtureDatamodule`` whose ``dataset_name`` is the
        interpolated name.
    """
    # Build the name once, as an f-string, so the directory and the
    # dataset_name cannot drift apart. The name used to be passed as a
    # plain string, so it literally contained "{sphere}", "{dim}" and "{mix}".
    dataset_name = f"gauss_mixture_{sphere}dim_{dim}_mix_{mix}"
    data_dir = f"data/raw/{dataset_name}"
    return GaussianMixtureDatamodule(data_dir=data_dir,
                                     batch_size=32,
                                     dataset_name=dataset_name)