# ---------------------------------------------------------------
# This file has been modified from the following sources:
# Source:
# 1. https://github.com/NVlabs/LSGM/blob/main/util/ema.py (NVIDIA License)
# 2. https://github.com/NVlabs/denoising-diffusion-gan/blob/main/train_ddgan.py (NVIDIA License)
# ---------------------------------------------------------------

import os
import torch
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, CIFAR10
from PIL import Image
import os.path
from torch.utils.data import Dataset
import random
from torchvision.transforms import InterpolationMode, RandomCrop, CenterCrop
from forward_operator import get_operator


class Sampler:
    '''Endless sampler that cycles over a dataloader of (source, target) batches.

    The wrapped dataloader is re-iterated from the start every time it is
    exhausted, so `sample()` can be called indefinitely.
    '''

    def __init__(self, dataloader):
        self.dataloader = dataloader
        # The iterator is created lazily on the first call to sample().
        self.iterloader = None

    def sample(self):
        '''Return the next `(src, tgt)` batch, both converted with `.float()`.

        Raises:
            StopIteration: if the dataloader yields no batch at all.
            ValueError: if a batch cannot be unpacked into exactly two items.
        '''
        # getattr keeps instances that were built without __init__ working.
        if getattr(self, 'iterloader', None) is None:
            self.iterloader = iter(self.dataloader)

        try:
            batch = next(self.iterloader)
        except StopIteration:
            # Only exhaustion restarts the loader; any other error raised
            # while loading is a real failure and must propagate.
            self.iterloader = iter(self.dataloader)
            batch = next(self.iterloader)

        try:
            src, tgt = batch
        except (TypeError, ValueError) as err:
            raise ValueError(
                'Sampler expects each batch to unpack into (src, tgt); '
                f'got an object of type {type(batch).__name__}'
            ) from err

        return src.float(), tgt.float()
    

# Image datasets
class CelebA_HQ(data.Dataset):
    '''CelebA-HQ image dataset split into train/val/test by a partition file.

    Note: CelebA (about 200000 images) vs CelebA-HQ (30000 images)

    Args:
        root: directory that directly contains the image files.
        partition_path: text file with one "<filename> <label>" entry per
            line, where label is 0 (train), 1 (val) or 2 (test).
        mode: which split to expose: 'train', 'val' or 'test'.
        transform: optional callable applied to each opened PIL image.

    Raises:
        ValueError: on an unknown `mode`, a malformed partition line, a
            partition label outside {0, 1, 2}, or a non-file entry in `root`.
        KeyError: if a file in `root` has no entry in the partition file.
    '''
    def __init__(self, root, partition_path, mode='train', transform=None):
        # Fail fast on a bad mode, before scanning the image directory.
        if mode not in ('train', 'val', 'test'):
            raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")

        self.root = root
        self.mode = mode
        self.transform = transform

        # Split train/val/test
        self.partition_dict = {}
        self.get_partition_label(partition_path)
        self.train_dataset = []
        self.val_dataset = []
        self.test_dataset = []
        self.save_img_path()
        print('[Celeba-HQ Dataset]')
        print(f'Train {len(self.train_dataset)} | Val {len(self.val_dataset)} | Test {len(self.test_dataset)}')

        if mode == 'train':
            self.dataset = self.train_dataset
        elif mode == 'val':
            self.dataset = self.val_dataset
        else:
            self.dataset = self.test_dataset

    def get_partition_label(self, list_eval_partition_celeba_path):
        '''Get partition labels (Train 0, Valid 1, Test 2) from CelebA
        See "celeba/Eval/list_eval_partition.txt"

        Fills `self.partition_dict`, keyed by the filename stem (kept as a
        string, e.g. a 6-digit number) with the integer label as value.
        '''
        with open(list_eval_partition_celeba_path, 'r') as f:
            for line in f:
                # split() with no argument tolerates repeated whitespace and
                # the trailing newline; blank lines yield no fields.
                fields = line.split()
                if not fields:
                    continue
                if len(fields) < 2:
                    raise ValueError(f'Malformed partition line: {line!r}')
                filenum = fields[0].split('.')[0]  # Use 6-digit 'str' instead of int type
                self.partition_dict[filenum] = int(fields[1])  # 0 (train), 1 (val), 2 (test)

    def save_img_path(self):
        '''Sort every file in `self.root` into the train/val/test path lists.'''
        # Sorted so that index -> image is reproducible across file systems.
        for filename in sorted(os.listdir(self.root)):
            img_path = os.path.join(self.root, filename)
            if not os.path.isfile(img_path):
                raise ValueError(f'Expected only regular files in {self.root!r}, found {filename!r}')
            filenum = filename.split('.')[0]
            if filenum not in self.partition_dict:
                raise KeyError(f'No partition label for image {filename!r}')
            label = self.partition_dict[filenum]
            if label == 0:
                self.train_dataset.append(img_path)
            elif label == 1:
                self.val_dataset.append(img_path)
            elif label == 2:
                self.test_dataset.append(img_path)
            else:
                raise ValueError(f'Unknown partition label {label!r} for image {filename!r}')

    def __getitem__(self, index):
        '''Open the image at `index` and return it, transformed if a transform is set.'''
        img_path = self.dataset[index]
        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        '''Number of images in the selected split.'''
        return len(self.dataset)
    
class PairedTransform:
    '''Apply one (possibly random) transform identically to two images.

    The state of Python's `random` module and of torch's default RNG (as
    returned by `torch.get_rng_state`) is captured before the first call
    and restored before the second, so both calls draw the same random
    numbers. Randomness drawn from any other source (e.g. NumPy) is not
    synchronised.

    Args:
        transform: callable applied to each image, or None to return both
            images unchanged.
    '''
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img1, img2):
        '''Return `(transform(img1), transform(img2))` using the same random draws.'''
        # The dataset classes default their transform to None; treat that as identity.
        if self.transform is None:
            return img1, img2

        # save : random seed
        state = random.getstate()
        torch_state = torch.get_rng_state()

        img1 = self.transform(img1)

        # restore : random seed, so the second call replays the same draws
        random.setstate(state)
        torch.set_rng_state(torch_state)

        img2 = self.transform(img2)

        return img1, img2

class IPDatasetTrain(Dataset):
    '''Training dataset that yields the same image twice as (low, orig).

    Uses the first 1000 image files (sorted by path) found in `image_dir`.
    Both returned items come from the same file and go through
    `transform_orig` via a PairedTransform. `factor`, `seed` and
    `transform_low` are stored but not used by this class.
    '''
    def __init__(self, image_dir, factor=2, seed=0, transform_orig=None, transform_low=None):
        self.image_dir = image_dir
        self.factor = factor
        self.seed = seed
        self.transform_orig = transform_orig
        self.transform_low = transform_low

        self.paired_transform = PairedTransform(self.transform_orig)

        # Collect image files by extension (case-insensitive), then order them.
        image_exts = ('.jpg', '.jpeg', '.png')
        found = [
            os.path.join(image_dir, entry)
            for entry in os.listdir(image_dir)
            if entry.lower().endswith(image_exts)
        ]
        found.sort()
        self.image_paths = found[:1000]

        # The "low" side points at the very same files.
        self.low_image_paths = list(self.image_paths)

    def __len__(self):
        '''Number of images in the dataset.'''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''Return `(img_low, img_orig)` for index `idx` after the paired transform.'''
        original = Image.open(self.image_paths[idx])
        degraded = Image.open(self.low_image_paths[idx])

        low_out, orig_out = self.paired_transform(degraded, original)
        return low_out, orig_out
    
class IPDatasetTrain_paired(Dataset):
    '''Training dataset pairing clean images with pre-computed degraded images.

    The clean side is the last 1000 image files (sorted by path) of
    `image_dir`; the degraded side is every image file (sorted by path) of
    `low_image_dir`. Both lists must have the same length and are paired by
    position. `factor`, `seed` and `transform_low` are stored but not used
    by this class; both images go through `transform_orig`.
    '''
    def __init__(self, image_dir,low_image_dir, factor=2, seed=0, transform_orig=None, transform_low=None):
        self.image_dir = image_dir
        self.low_image_dir = low_image_dir
        self.factor = factor
        self.seed = seed
        self.transform_orig = transform_orig
        self.transform_low = transform_low

        self.paired_transform = PairedTransform(self.transform_orig)

        image_exts = ('.jpg', '.jpeg', '.png')

        clean = [
            os.path.join(image_dir, entry)
            for entry in os.listdir(image_dir)
            if entry.lower().endswith(image_exts)
        ]
        clean.sort()
        self.image_paths = clean[-1000:]

        degraded = [
            os.path.join(low_image_dir, entry)
            for entry in os.listdir(low_image_dir)
            if entry.lower().endswith(image_exts)
        ]
        degraded.sort()
        self.low_image_paths = degraded

        # Pairing is positional, so both sides must line up.
        assert len(self.image_paths) == len(self.low_image_paths)

    def __len__(self):
        '''Number of (low, orig) pairs.'''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''Return `(img_low, img_orig)` for index `idx` after the paired transform.'''
        original = Image.open(self.image_paths[idx])
        degraded = Image.open(self.low_image_paths[idx])

        low_out, orig_out = self.paired_transform(degraded, original)
        return low_out, orig_out


class IPDatasetTest(Dataset):
    '''Test dataset that yields two views, (low, orig), of a single image.

    The image list is every image file (sorted by path) in `image_dir`,
    followed by entries 2000..4999 of the sorted image files in
    `additional_image_dir`.

    Args:
        image_dir: directory with the main test images.
        additional_image_dir: optional directory with extra images; when
            None, no extra images are added.
        factor: stored but not used by this class.
        transform_orig: optional callable producing the "orig" view.
        transform_low: optional callable producing the "low" view.
    '''
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir, additional_image_dir=None, factor=2, transform_orig=None, transform_low=None):
        self.image_dir = image_dir
        self.factor = factor
        self.transform_orig = transform_orig
        self.transform_low = transform_low
        self.additional_image_dir = additional_image_dir

        # Guard explicitly: os.listdir(None) would silently list the
        # current working directory.
        if additional_image_dir is None:
            self.additional_image_paths = []
        else:
            self.additional_image_paths = self._list_images(additional_image_dir)

        self.image_paths_0 = self._list_images(image_dir)

        self.image_paths = self.image_paths_0 + self.additional_image_paths[2000:5000]

    @classmethod
    def _list_images(cls, directory):
        '''Return the sorted paths of the image files directly inside `directory`.'''
        return sorted(
            os.path.join(directory, fname)
            for fname in os.listdir(directory)
            if fname.lower().endswith(cls._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        '''Number of images in the dataset.'''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''Return `(img_low, img_orig)`; each transform is applied to the same image independently.'''
        img = Image.open(self.image_paths[idx])

        img_orig = self.transform_orig(img) if self.transform_orig else img
        img_low  = self.transform_low(img) if self.transform_low else img

        return img_low, img_orig

class IPDatasetTest_paired(Dataset):
    '''Test dataset pairing clean images with pre-computed degraded images.

    Image files are collected (sorted by path) from `image_dir` and
    `low_image_dir` and paired by position. When `args.noise == 'total'`
    every file is used; otherwise the last 1000 entries of both lists are
    dropped. Both lists must end up with the same length. `factor` is
    stored but not used by this class.
    '''
    def __init__(self, image_dir,low_image_dir,args, factor=2, transform_orig=None, transform_low=None):
        self.image_dir = image_dir
        self.low_image_dir = low_image_dir
        self.factor = factor
        self.transform_orig = transform_orig
        self.transform_low = transform_low

        image_exts = ('.jpg', '.jpeg', '.png')

        def _collect(folder):
            # Sorted paths of the image files directly inside `folder`.
            matches = [
                os.path.join(folder, entry)
                for entry in os.listdir(folder)
                if entry.lower().endswith(image_exts)
            ]
            matches.sort()
            return matches

        keep_everything = args.noise == 'total'
        clean = _collect(image_dir)
        degraded = _collect(low_image_dir)
        if not keep_everything:
            clean = clean[:-1000]
            degraded = degraded[:-1000]

        self.image_paths = clean
        self.low_image_paths = degraded
        # Pairing is positional, so both sides must line up.
        assert len(self.low_image_paths) == len(self.image_paths)

    def __len__(self):
        '''Number of (low, orig) pairs.'''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''Return `(img_low, img_orig)` for index `idx`, each passed through its own transform if set.'''
        clean_img = Image.open(self.image_paths[idx])
        degraded_img = Image.open(self.low_image_paths[idx])

        if self.transform_orig:
            img_orig = self.transform_orig(clean_img)
        else:
            img_orig = clean_img

        if self.transform_low:
            img_low = self.transform_low(degraded_img)
        else:
            img_low = degraded_img

        return img_low, img_orig


class AnomalyDataset(data.Dataset):
    '''Normal dataset with a small fraction of anomaly samples appended.

    Indices `[0, num_normal)` return samples of `dataset`; indices
    `[num_normal, num_normal + num_anomaly)` return the anomaly samples,
    which are read from `anomaly_dataset` once at construction time.
    Labels are dropped: only the sample itself is returned.
    '''
    def __init__(self, dataset, anomaly_dataset, frac=0.01):
        '''
        dataset : target dataset (CIFAR10)
        anomaly_dataset : anomaly dataset (MNIST)
        frac : fraction of anomaly dataset (p=0.01)

        Raises ValueError if the first anomaly sample's size (shape[1])
        differs from that of the first normal sample.
        '''
        normal_sample = self._strip_label(dataset[0])
        c, size, _ = normal_sample.shape # [c, w, h]

        self.dataset = dataset
        self.anomaly_dataset = anomaly_dataset

        self.num_normal = len(dataset)
        self.num_anomaly = int(frac * self.num_normal)

        self.ANOMALIES = []
        for i in range(self.num_anomaly):
            # get samples
            x = self._strip_label(anomaly_dataset[i])
            # check if image size is same (only the first sample is checked)
            if i == 0 and x.shape[1] != size:
                raise ValueError(
                    f'Anomaly image size {x.shape[1]} does not match normal image size {size}'
                )
            # match the number of channels
            if x.shape[0]==1 and c==3:
                x = x.repeat(3,1,1)
            # append to self.ANOMALIES
            self.ANOMALIES.append(x)

    @staticmethod
    def _strip_label(item):
        '''Return the sample from a `(sample, label, ...)` tuple/list, or `item` itself.

        An explicit type check is used instead of try/unpack because a bare
        tensor whose first dimension is 2 would otherwise be "unpacked"
        into two slices.
        '''
        if isinstance(item, (tuple, list)):
            return item[0]
        return item

    def __getitem__(self, index):
        '''Return the normal sample at `index`, or an anomaly once past the normal range.'''
        if index < self.num_normal:
            x = self._strip_label(self.dataset[index])
        else:
            x = self.ANOMALIES[index-self.num_normal]

        return x

    def __len__(self):
        '''Total number of normal plus anomaly samples.'''
        return self.num_normal + self.num_anomaly

class MixedDataset_train(Dataset):
    '''Training dataset mixing two image folders (dog and cat) at a given ratio.

    At most 1000 images are used (fewer if either folder is smaller).
    `low_image_paths` takes dog : cat in the ratio 1 : `mixed_ratio`, while
    `image_paths` always takes an equal number from each folder, so the two
    images returned for one index generally come from different files.
    `factor`, `seed` and `transform_low` are stored but not used by this
    class; both images go through `transform_orig`.
    '''
    def __init__(self, image_dir, image_dir2, mixed_ratio=1, factor=2, seed=0, transform_orig=None, transform_low=None):
        self.image_dir = image_dir # dog
        self.image_dir2 = image_dir2 # cat
        self.mixed_ratio = mixed_ratio # dog : cat = 1 : mixed_ratio 0.3 0.5 1
        self.factor = factor
        self.seed = seed
        self.transform_orig = transform_orig
        self.transform_low = transform_low

        self.paired_transform = PairedTransform(self.transform_orig)

        image_exts = ('.jpg', '.jpeg', '.png')

        dog_files = [
            os.path.join(image_dir, entry)
            for entry in os.listdir(image_dir)
            if entry.lower().endswith(image_exts)
        ]
        dog_files.sort()
        self.image_paths1 = dog_files ## dog

        cat_files = [
            os.path.join(image_dir2, entry)
            for entry in os.listdir(image_dir2)
            if entry.lower().endswith(image_exts)
        ]
        cat_files.sort()
        self.image_paths2 = cat_files ## cat

        # Budget: capped at 1000 and by the smaller of the two folders.
        num_data = min(len(self.image_paths1), len(self.image_paths2), 1000)
        num_dog = int(num_data//(self.mixed_ratio+1))
        num_cat = num_data - num_dog
        half = int(num_data//2)

        # The cat portion is appended in reverse order in both lists.
        low_dogs = self.image_paths1[:num_dog]
        low_cats = self.image_paths2[:num_cat]
        self.low_image_paths = low_dogs + low_cats[::-1]

        orig_dogs = self.image_paths1[:half]
        orig_cats = self.image_paths2[:half]
        self.image_paths = orig_dogs + orig_cats[::-1]

    def __len__(self):
        '''Number of entries in `image_paths`.'''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''Return `(img_low, img_orig)` for index `idx` after the paired transform.'''
        original = Image.open(self.image_paths[idx])
        degraded = Image.open(self.low_image_paths[idx])

        low_out, orig_out = self.paired_transform(degraded, original)
        return low_out, orig_out

class MixedDataset_test(Dataset):
    '''Test dataset of 200 dog images followed by 200 cat images (cats reversed).

    The first 200 image files (sorted by path) of each folder are used, and
    each folder must provide at least 200. The same file is returned for
    both the low and orig side. `mixed_ratio`, `factor`, `seed` and
    `transform_low` are stored but not used by this class; both images go
    through `transform_orig`.
    '''
    def __init__(self, image_dir, image_dir2, mixed_ratio=1, factor=2, seed=0, transform_orig=None, transform_low=None):
        self.image_dir = image_dir # dog
        self.image_dir2 = image_dir2 # cat
        self.mixed_ratio = mixed_ratio # dog : cat = 1 : mixed_ratio 0.3 0.5 1
        self.factor = factor
        self.seed = seed
        self.transform_orig = transform_orig
        self.transform_low = transform_low

        self.paired_transform = PairedTransform(self.transform_orig)

        image_exts = ('.jpg', '.jpeg', '.png')

        dog_files = [
            os.path.join(image_dir, entry)
            for entry in os.listdir(image_dir)
            if entry.lower().endswith(image_exts)
        ]
        dog_files.sort()
        self.dog_paths = dog_files[:200]

        cat_files = [
            os.path.join(image_dir2, entry)
            for entry in os.listdir(image_dir2)
            if entry.lower().endswith(image_exts)
        ]
        cat_files.sort()
        self.cat_paths = cat_files[:200]

        # Each folder has to supply the full 200 images.
        assert len(self.dog_paths) == 200
        assert len(self.cat_paths) == 200

        self.image_paths = self.dog_paths + self.cat_paths[::-1]
        self.low_image_paths = list(self.image_paths)

    def __len__(self):
        '''Number of images in the dataset.'''
        return len(self.image_paths)

    def __getitem__(self, idx):
        '''Return `(img_low, img_orig)` for index `idx` after the paired transform.'''
        original = Image.open(self.image_paths[idx])
        degraded = Image.open(self.low_image_paths[idx])

        low_out, orig_out = self.paired_transform(degraded, original)
        return low_out, orig_out


# get dataloader
def _build_paired_transforms(args, train_extra=(), val_extra=()):
    '''Return `(train_transform, val_transform)` for the paired image datasets.

    Training: RandomHorizontalFlip, then `train_extra`, then ToTensor.
    Validation: `val_extra`, then ToTensor. When `args.normalize` is truthy,
    a per-channel Normalize with mean 0.5 and std 0.5 is appended to both.
    '''
    train_steps = [transforms.RandomHorizontalFlip(), *train_extra, transforms.ToTensor()]
    val_steps = [*val_extra, transforms.ToTensor()]

    if args.normalize:
        train_steps.append(transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)))
        val_steps.append(transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)))

    return transforms.Compose(train_steps), transforms.Compose(val_steps)


def _build_loader_pair(dataset_train, dataset_test, batch_size, num_workers):
    '''Wrap a train/test dataset pair into a shuffled train loader and an ordered test loader.'''
    train_data_loader = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    test_data_loader = DataLoader(
        dataset_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=True,
    )
    return train_data_loader, test_data_loader


def _build_noisy_paired_datasets(data_root, args, train_transform, val_transform):
    '''Build the paired train/test datasets stored under `data_root` (AFHQ / FFHQ layout).'''
    dataset_train = IPDatasetTrain_paired(
        image_dir=f'{data_root}/val_raw',
        low_image_dir=f'{data_root}/train_{args.operator_type}/noise_{args.noise}',
        transform_orig=train_transform,
        transform_low=train_transform,
    )
    # The 'total' noise setting reads its clean images from a separate folder.
    if args.noise == 'total':
        clean_test_dir = f'{data_root}/val_raw_total'
    else:
        clean_test_dir = f'{data_root}/val_raw'
    dataset_test = IPDatasetTest_paired(
        image_dir=clean_test_dir,
        low_image_dir=f'{data_root}/val_{args.operator_type}/noise_{args.noise}',
        args=args,
        transform_orig=val_transform,
        transform_low=val_transform,
    )
    return dataset_train, dataset_test


def get_dataloader(args):
    '''Build the dataloader(s) for `args.dataset`.

    Reads `args.dataset` and `args.batch_size`, plus, depending on the
    dataset, `args.image_size`, `args.normalize`, `args.operator_type`,
    `args.noise` and `args.mixed_ratio`.

    Returns:
        A single shuffled training DataLoader for 'mnist', 'cifar10',
        'cifar10+mnist' and 'celeba_256'; a `(train_loader, test_loader)`
        tuple for 'AFHQ', 'FFHQ', 'DIV2K' and 'mixed'.

    Raises:
        NotImplementedError: if `args.dataset` is not one of the names above.
    '''
    num_workers = 4
    if args.dataset == 'mnist':
        dataset = MNIST('./data', train=True, transform=transforms.Compose([
                        transforms.Resize(args.image_size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5), (0.5))]), download=True)

    elif args.dataset == 'cifar10':
        dataset = CIFAR10('./data', train=True, transform=transforms.Compose([
                        transforms.Resize(args.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]), download=True)

    elif args.dataset == 'cifar10+mnist':
        normal_dataset = CIFAR10('./data', train=True, transform=transforms.Compose([
                        transforms.Resize(args.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]), download=True)

        anomaly_dataset = MNIST('./data', train=True, transform=transforms.Compose([
                        transforms.Resize(args.image_size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5), (0.5))]), download=True)

        dataset = AnomalyDataset(normal_dataset, anomaly_dataset)

    elif args.dataset == 'celeba_256':
        train_transform = transforms.Compose([
                transforms.Resize(args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
            ])
        dataset = CelebA_HQ(
            root='data/celeba-hq/celeba-256',
            partition_path='data/celeba-hq/list_eval_partition_celeba.txt',
            mode='train', # 'train', 'val', 'test'
            transform=train_transform,
        )

    elif args.dataset in ('AFHQ', 'FFHQ'):
        # Both datasets share one directory layout and differ only in the root.
        data_root = '../../data/afhq128' if args.dataset == 'AFHQ' else '../../data/ffhq128'
        train_transform, val_transform = _build_paired_transforms(args)
        dataset_train, dataset_test = _build_noisy_paired_datasets(
            data_root, args, train_transform, val_transform)
        return _build_loader_pair(dataset_train, dataset_test, args.batch_size, num_workers)

    elif args.dataset == 'DIV2K':
        train_transform, val_transform = _build_paired_transforms(
            args, train_extra=[RandomCrop(128)], val_extra=[CenterCrop(128)])

        dataset_train = IPDatasetTrain(image_dir='../../data/DIV2K/DIV2K_train_HR', transform_orig=train_transform, transform_low=train_transform)
        if args.operator_type == 'nonlinear_blur':
            # `args` is passed like at every other IPDatasetTest_paired call
            # site; without it the constructor call fails with a TypeError.
            dataset_test = IPDatasetTest_paired(image_dir='../../data/DIV2K/DIV2K_valid_centercrop', low_image_dir='../../data/DIV2K/DIV2K_valid_nonlinear', args=args, transform_orig=val_transform, transform_low=val_transform)
        else:
            # NOTE(review): no `additional_image_dir` is supplied here; confirm
            # that IPDatasetTest accepts being constructed without it.
            dataset_test = IPDatasetTest(image_dir='../../data/DIV2K/DIV2K_valid_HR', transform_orig=val_transform, transform_low=val_transform)

        return _build_loader_pair(dataset_train, dataset_test, args.batch_size, num_workers)

    elif args.dataset == 'mixed':
        ### mixed_ratio (alpha) means dog : cat = 1 : alpha
        train_transform, val_transform = _build_paired_transforms(args)

        dataset_train = MixedDataset_train(image_dir='../../data/afhq128/train', image_dir2='../../data/afhq128_cat/train', mixed_ratio = args.mixed_ratio, transform_orig=train_transform, transform_low=train_transform)
        dataset_test = MixedDataset_test(image_dir='../../data/afhq128/val', image_dir2='../../data/afhq128_cat/val', mixed_ratio = args.mixed_ratio, transform_orig=val_transform, transform_low=val_transform)

        return _build_loader_pair(dataset_train, dataset_test, args.batch_size, num_workers)

    else:
        # Raise explicitly; a bare `NotImplementedError` expression is a no-op
        # and would let an unknown name fall through to the loader below.
        raise NotImplementedError(f'Unknown dataset: {args.dataset!r}')

    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )
    return data_loader




# ------------------------
# For Toy
# ------------------------
# datasets
class ToydatasetGaussian(data.Dataset):
    '''Toy dataset of standard-normal samples shifted by the vector [0, 10].

    Draws `cfg.num_data` rows of dimension `cfg.data_dim` once at
    construction time.
    '''
    def __init__(self, cfg):
        noise = torch.randn(cfg.num_data, cfg.data_dim)
        shift = torch.tensor([0,10])
        self.dataset = noise + shift

    def __len__(self):
        '''Number of stored samples.'''
        return len(self.dataset)

    def __getitem__(self, idx):
        '''Return the stored sample at `idx`.'''
        sample = self.dataset[idx]
        return sample


class Toydatasetp(data.Dataset):
    '''Toy two-mode dataset: Gaussians with std 0.5 centred at +1 and at -1.

    The first `cfg.num_data // 2` rows come from the +1 mode and the
    remaining rows from the -1 mode; all are drawn once at construction.
    '''
    def __init__(self, cfg):
        std = 0.5
        first_count = cfg.num_data//2
        positive_mode = std*torch.randn(first_count, cfg.data_dim)+1
        negative_mode = std*torch.randn(cfg.num_data-cfg.num_data//2, cfg.data_dim)-1
        self.dataset = torch.cat([positive_mode, negative_mode])

    def __len__(self):
        '''Number of stored samples.'''
        return len(self.dataset)

    def __getitem__(self, idx):
        '''Return the stored sample at `idx`.'''
        sample = self.dataset[idx]
        return sample


class Toydatasetq(data.Dataset):
    '''Toy two-mode dataset: Gaussians with std 0.5 centred at +2 and at -1.

    The first `2 * cfg.num_data // 3` rows come from the +2 mode and the
    remaining rows from the -1 mode; all are drawn once at construction.
    '''
    def __init__(self, cfg):
        std = 0.5
        first_count = 2*cfg.num_data//3
        upper_mode = std*torch.randn(first_count, cfg.data_dim)+2
        lower_mode = std*torch.randn(cfg.num_data-2*cfg.num_data//3, cfg.data_dim)-1
        self.dataset = torch.cat([upper_mode, lower_mode])

    def __len__(self):
        '''Number of stored samples.'''
        return len(self.dataset)

    def __getitem__(self, idx):
        '''Return the stored sample at `idx`.'''
        sample = self.dataset[idx]
        return sample


class ToydatasetOutlier(data.Dataset):
    '''Toy dataset with a main mode at +1 and an outlier mode at -1 (both scaled by 0.1).

    `int(cfg.num_data * cfg.p)` rows are outliers and are stored after the
    main-mode rows; all are drawn once at construction.
    '''
    def __init__(self, cfg):
        M = int(cfg.num_data*cfg.p)
        main_mode = 0.1*torch.randn(cfg.num_data-M, cfg.data_dim) + 1
        outlier_mode = 0.1*torch.randn(M, cfg.data_dim) - 1
        self.dataset = torch.cat([main_mode, outlier_mode])

    def __len__(self):
        '''Number of stored samples.'''
        return len(self.dataset)

    def __getitem__(self, idx):
        '''Return the stored sample at `idx`.'''
        sample = self.dataset[idx]
        return sample


class ToydatasetNoise(data.Dataset):
    '''Toy dataset that draws a fresh standard-normal sample on every access.

    Nothing is stored: `__getitem__` ignores the index and returns a new
    `(1, cfg.data_dim)` tensor each time.
    '''
    def __init__(self, cfg):
        self.N = cfg.num_data
        self.dim = cfg.data_dim

    def __len__(self):
        '''Nominal dataset size, `cfg.num_data` converted to int.'''
        size = int(self.N)
        return size

    def __getitem__(self, idx):
        '''Return a newly drawn `(1, dim)` standard-normal tensor; `idx` is unused.'''
        shape = (1, self.dim)
        return torch.randn(shape)


def get_datasets(cfg):
    '''Build the toy source and target datasets named in `cfg`.

    Returns a two-element list `[source_dataset, target_dataset]` built from
    `cfg.source_name` and `cfg.target_name`. Supported names are
    'gaussian', 'p', 'q', 'outlier' and 'noise'; any other name raises
    NotImplementedError.
    '''
    def _make(name):
        # Guard-clause dispatch from dataset name to its class.
        if name == 'gaussian':
            return ToydatasetGaussian(cfg)
        if name == 'p':
            return Toydatasetp(cfg)
        if name == 'q':
            return Toydatasetq(cfg)
        if name == 'outlier':
            return ToydatasetOutlier(cfg)
        if name == 'noise':
            return ToydatasetNoise(cfg)
        raise NotImplementedError

    requested = (cfg.source_name, cfg.target_name)
    return [_make(name) for name in requested]
