'''Modified from https://github.com/alinlab/LfF/blob/master/data/util.py'''

import os
import torch
from torch.utils.data.dataset import Dataset, Subset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from glob import glob
from PIL import Image
import pandas as pd

class IdxDataset(Dataset):
    """Wrap a dataset so each item is prefixed with its own index.

    ``IdxDataset(ds)[i]`` returns ``(i, *ds[i])``, so the wrapped dataset's
    items must be iterable (e.g. tuples).
    """

    def __init__(self, dataset):
        # Underlying dataset, indexed directly in __getitem__.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return (idx,) + tuple(sample)
    
class ProbDataset(Dataset):
    """Dataset wrapper that can resample indices from a categorical distribution.

    Until :meth:`update_prob` has been called, ``ds[idx]`` simply returns
    ``dataset[idx]``. Afterwards the requested ``idx`` is ignored and a fresh
    index is drawn by inverse-CDF sampling from the stored cumulative weights.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Cumulative sum of per-sample weights (set by update_prob); None means
        # "no resampling yet, use the requested index".
        self.prob = None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Without a distribution, fall back to plain indexing instead of
        # failing on a comparison against None.
        if self.prob is None:
            return self.dataset[idx]
        return self.dataset[int(self.idx_sample())]

    def idx_sample(self):
        """Draw one index by inverse-CDF sampling from ``self.prob``.

        Returns:
            A 0-d tensor holding an index in ``[0, len(self.prob) - 1]``.
        """
        # The number of CDF entries strictly below a uniform draw is the
        # sampled index.
        idx = torch.sum(torch.rand(1) > self.prob)
        # Clamp in case the last CDF entry is below 1 (rounding or
        # unnormalized weights), which could otherwise yield len(self.prob).
        return torch.clamp(idx, 0, len(self.prob) - 1)

    def update_prob(self, prob):
        """Set the sampling distribution.

        Args:
            prob: 1-D sequence/tensor of per-sample weights; stored as its
                cumulative sum along dim 0.
        """
        self.prob = torch.cumsum(torch.as_tensor(prob), dim=0)
        return


class ZippedDataset(Dataset):
    """Combine several datasets so each index yields stacked fields.

    The length is that of the largest dataset; smaller datasets wrap around
    via modulo indexing. Item ``i`` takes item ``i % len(d)`` from every
    dataset ``d`` and stacks the k-th field of all of them along a new dim 0
    (``torch.stack`` requires those fields to have matching shapes).
    """

    def __init__(self, datasets):
        super().__init__()
        self.dataset_sizes = list(map(len, datasets))
        self.datasets = datasets

    def __len__(self):
        return max(self.dataset_sizes)

    def __getitem__(self, idx):
        per_dataset = [
            ds[idx % size] for ds, size in zip(self.datasets, self.dataset_sizes)
        ]
        # Transpose: group the k-th field of every dataset's item together.
        return [torch.stack(field, dim=0) for field in zip(*per_dataset)]

class CMNISTDataset(Dataset):
    """Image files whose names end in two integer label fields.

    The last two ``_``-separated fields of each path (extension stripped) are
    returned as ``LongTensor([a, b])`` -- presumably (target, bias); confirm
    against the data generator. Each item is ``(image, attr, path)``.

    Args:
        root: dataset root. ``'train'`` reads ``root/align/*/*`` and
            ``root/conflict/*/*``; ``'valid'`` reads ``root/valid/*``;
            ``'test'`` reads ``root/../test/*/*``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        transform: optional callable applied to the RGB PIL image.
        image_path_list: stored as-is; not used by this class.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        super(CMNISTDataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list
        self.n_classes = 10

        if split == 'train':
            self.align = glob(os.path.join(root, 'align', "*", "*"))
            self.conflict = glob(os.path.join(root, 'conflict', "*", "*"))
            self.data = self.align + self.conflict
        elif split == 'valid':
            self.data = glob(os.path.join(root, split, "*"))
        elif split == 'test':
            self.data = glob(os.path.join(root, '../test', "*", "*"))
        else:
            # Fail here rather than with an AttributeError on self.data later.
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path = self.data[index]
        fields = path.split('_')
        attr = torch.LongTensor([int(fields[-2]), int(fields[-1].split('.')[0])])
        image = Image.open(path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, path
    
class SCMNISTDataset(Dataset):
    """Image files named ``{y}_{s}_...`` stored under ``root/<split>/*/*``.

    Each item is ``(image, LongTensor([y, s]), path)``, where ``y`` and ``s``
    are the first two ``_``-separated integer fields of the file name.

    Args:
        root: dataset root; ``'train'``/``'valid'`` read ``root/<split>/*/*``
            and ``'test'`` reads ``root/../../test/*/*``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        transform: optional callable applied to the RGB PIL image.
        image_path_list: stored as-is; not used by this class.
        bias: per-class bias list (e.g. ``[0, 1, None, ...]``); ``n_s`` is its
            length. Defaults to an empty list.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """

    def __init__(self, root, split, transform=None, image_path_list=None, bias=None):
        super(SCMNISTDataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        self.n_classes = 10
        # None default avoids one list object shared by every instance.
        self.bias = [] if bias is None else bias
        self.n_s = len(self.bias)

        if split in ['train', 'valid']:
            self.data = glob(os.path.join(root, split, "*", "*"))
        elif split == 'test':
            self.data = glob(os.path.join(root, '../../test', "*", "*"))
        else:
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        self.y, self.s = [], []
        for path in self.data:
            # basename is separator-agnostic, unlike splitting on '/'.
            fields = os.path.basename(path).split('_')
            self.y.append(int(fields[0]))
            self.s.append(int(fields[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([self.y[index], self.s[index]])
        image = Image.open(self.data[index]).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, self.data[index]


class CORRUPTEDCIFAR10Dataset(Dataset):
    """Corrupted CIFAR-10 image files named ``{y}_{s}_...``.

    Each item is ``(image, LongTensor([y, s]), path)``, where ``y`` and ``s``
    are the first two ``_``-separated integer fields of the file name. The
    image is opened without ``.convert('RGB')``.

    Args:
        root: dataset root; ``'train'``/``'valid'`` read ``root/<split>/*/*``
            and ``'test'`` reads ``root/../../test/*/*``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        transform: optional callable applied to the PIL image.
        image_path_list: stored as-is; not used by this class.
        bias: per-class bias list, e.g. ``[0, None, ..., None]`` or
            ``list(range(10))``; ``n_s`` is its length. Defaults to ``[]``.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """

    def __init__(self, root, split, transform=None, image_path_list=None, bias=None):
        super(CORRUPTEDCIFAR10Dataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        # None default avoids one list object shared by every instance.
        self.bias = [] if bias is None else bias
        self.n_classes = 10
        self.n_s = len(self.bias)

        if split in ['train', 'valid']:
            self.data = glob(os.path.join(root, split, "*", "*"))
        elif split == 'test':
            self.data = glob(os.path.join(root, '../../test', "*", "*"))
        else:
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        self.y, self.s = [], []
        for path in self.data:
            # basename is separator-agnostic, unlike splitting on '/'.
            fields = os.path.basename(path).split('_')
            self.y.append(int(fields[0]))
            self.s.append(int(fields[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([self.y[index], self.s[index]])
        image = Image.open(self.data[index])

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, self.data[index]
    
class CIFAR10CMBDataset(Dataset):
    """CIFAR-10 image files with two bias fields, named ``{y}_{s}_{s2}_...``.

    Each item is ``(image, LongTensor([y, s, s2]), path)``, where ``y``, ``s``
    and ``s2`` are the first three ``_``-separated integer fields of the file
    name. The image is opened without ``.convert('RGB')``.

    Args:
        root: dataset root; ``'train'``/``'valid'`` read ``root/<split>/*/*``
            and ``'test'`` reads ``root/../../../test/*/*``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        transform: optional callable applied to the PIL image.
        image_path_list: stored as-is; not used by this class.
        bias: per-class list for the first bias attribute; ``n_s`` is its
            length. Defaults to ``[]``.
        bias2: per-class list for the second bias attribute. Defaults to ``[]``.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """

    def __init__(self, root, split, transform=None, image_path_list=None, bias=None, bias2=None):
        super(CIFAR10CMBDataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        # None defaults avoid list objects shared by every instance.
        self.bias = [] if bias is None else bias
        self.bias2 = [] if bias2 is None else bias2
        self.n_classes = 10
        self.n_s = len(self.bias)

        if split in ['train', 'valid']:
            self.data = glob(os.path.join(root, split, "*", "*"))
        elif split == 'test':
            self.data = glob(os.path.join(root, '../../../test', "*", "*"))
        else:
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        self.y, self.s, self.s2 = [], [], []
        for path in self.data:
            # basename is separator-agnostic, unlike splitting on '/'.
            fields = os.path.basename(path).split('_')
            self.y.append(int(fields[0]))
            self.s.append(int(fields[1]))
            self.s2.append(int(fields[2]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([self.y[index], self.s[index], self.s2[index]])
        image = Image.open(self.data[index])

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, self.data[index]


class bFFHQDataset(Dataset):
    """bFFHQ image files whose paths end in ``_{y}_{s}.{ext}``.

    Each item is ``(image, LongTensor([y, s]), path)``; ``y`` and ``s`` are
    the last two ``_``-separated integer fields of the path (extension
    stripped).

    Args:
        root: per-ratio root. ``'train'`` reads ``root/align/*/*`` and
            ``root/conflict/*/*``; ``'valid'``/``'test'`` read
            ``dirname(root)/<split>/*``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        transform: optional callable applied to the RGB PIL image.
        image_path_list: stored as-is; not used by this class.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        super(bFFHQDataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        self.bias = [0, 1]
        self.n_classes = 2
        self.n_s = len(self.bias)

        if split == 'train':
            self.align = glob(os.path.join(root, 'align', "*", "*"))
            self.conflict = glob(os.path.join(root, 'conflict', "*", "*"))
            self.data = self.align + self.conflict
        elif split in ('valid', 'test'):
            # NOTE: the test split could be restricted to bias-conflicting
            # samples by keeping only files whose label fields differ.
            self.data = glob(os.path.join(os.path.dirname(root), split, "*"))
        else:
            # Fail here rather than with an AttributeError on self.data below.
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        self.y, self.s = [], []
        for path in self.data:
            fields = path.split('_')
            self.y.append(int(fields[-2]))
            self.s.append(int(fields[-1].split('.')[0]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([self.y[index], self.s[index]])
        image = Image.open(self.data[index]).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, attr, self.data[index]

class BARDataset(Dataset):
    """BAR image files whose paths end in ``_{y}_{s}.{ext}``.

    ``__getitem__`` returns ``(image, attr, (path, index))``. For training
    images the bias entry of ``attr`` is overridden from the directory:
    ``train/align`` images get ``s = y``, and ``train/conflict`` images get
    ``s = (y + 1) % n_classes`` (only guaranteed to differ from ``y``).

    Args:
        root: dataset root containing ``train/align``,
            ``train/conflict/<percent>``, ``valid`` and ``test``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        transform: optional callable applied to the RGB PIL image.
        percent: sub-directory of ``train/conflict`` to use.
        image_path_list: stored as-is; not used by this class.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """

    def __init__(self, root, split, transform=None, percent=None, image_path_list=None):
        super(BARDataset, self).__init__()
        self.transform = transform
        self.root = root
        self.percent = percent
        self.split = split
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        self.n_classes = 6
        self.bias = [0, 1, 2, 3, 4, 5]
        self.n_s = len(self.bias)

        self.train_align = glob(os.path.join(root, 'train/align', "*/*"))
        self.train_conflict = glob(os.path.join(root, 'train/conflict', f"{self.percent}/*/*"))
        self.valid = glob(os.path.join(root, 'valid', "*/*"))
        self.test = glob(os.path.join(root, 'test', "*/*"))

        if self.split == 'train':
            self.data = self.train_align + self.train_conflict
        elif self.split == 'valid':
            self.data = self.valid
        elif self.split == 'test':
            self.data = self.test
        else:
            # Fail here rather than with an AttributeError on self.data below.
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        self.y, self.s = [], []
        for path in self.data:
            fields = path.split('_')
            self.y.append(int(fields[-2]))
            self.s.append(int(fields[-1].split('.')[0]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([self.y[index], self.s[index]])
        image = Image.open(self.data[index]).convert('RGB')
        image_path = self.data[index]

        if 'bar/train/conflict' in image_path:
            # Not the true bias, but at least not aligned with the target label.
            attr[1] = (attr[0] + 1) % self.n_classes
        elif 'bar/train/align' in image_path:
            attr[1] = attr[0]

        if self.transform is not None:
            image = self.transform(image)
        return image, attr, (image_path, index)
    
class DogCatDataset(Dataset):
    """Dogs-and-cats image files whose paths end in ``_{y}_{s}.{ext}``.

    Each item is ``(image, LongTensor([y, s]), path)``; ``y`` and ``s`` are
    the last two ``_``-separated integer fields of the path (extension
    stripped). They are parsed once at construction into ``self.y``/``self.s``,
    matching the sibling dataset classes.

    Args:
        root: per-ratio root. ``'train'`` reads ``root/align/*/*`` and
            ``root/conflict/*/*``; ``'valid'`` reads ``root/valid/*``;
            ``'test'`` reads ``root/../test/*/*``.
        split: one of ``'train'``, ``'valid'``, ``'test'``.
        transform: optional callable applied to the RGB PIL image.
        image_path_list: stored as-is; not used by this class.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        super(DogCatDataset, self).__init__()
        self.transform = transform
        self.root = root
        # Present on every other dataset class in this module.
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        if split == "train":
            self.align = glob(os.path.join(root, "align", "*", "*"))
            self.conflict = glob(os.path.join(root, "conflict", "*", "*"))
            self.data = self.align + self.conflict
        elif split == "valid":
            self.data = glob(os.path.join(root, split, "*"))
        elif split == "test":
            self.data = glob(os.path.join(root, "../test", "*", "*"))
        else:
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        self.y, self.s = [], []
        for path in self.data:
            fields = path.split('_')
            self.y.append(int(fields[-2]))
            self.s.append(int(fields[-1].split('.')[0]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([self.y[index], self.s[index]])
        image = Image.open(self.data[index]).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, attr, self.data[index]

class NICODataset(Dataset):
    """NICO image files named ``{y}_{s}_...`` stored under ``root/<split>/*``.

    Each item is ``(image, LongTensor([y, s]), path)``, where ``y`` and ``s``
    are the first two ``_``-separated integer fields of the file name.

    Args:
        root: dataset root; files are read from ``root/<split>/*``.
        split: split sub-directory name.
        transform: optional callable applied to the RGB PIL image.
        image_path_list: stored as-is; not used by this class.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        super(NICODataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        self.bias = [None] * 32
        self.n_classes = 10
        self.n_s = len(self.bias)

        self.data = glob(os.path.join(root, split, "*"))
        self.y, self.s = [], []
        for path in self.data:
            # basename is separator-agnostic; splitting on '/' would keep the
            # whole (underscore-containing) path on Windows.
            fields = os.path.basename(path).split('_')
            self.y.append(int(fields[0]))
            self.s.append(int(fields[1]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([self.y[index], self.s[index]])
        image = Image.open(self.data[index]).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, self.data[index]
    
class CelebADataset(Dataset):
    """CelebA with ``Blond_Hair`` as the target and ``Male`` as the bias.

    Reads ``root/metadata.csv`` (needs ``image_id``, ``Blond_Hair`` and
    ``Male`` columns) and the space-separated ``root/list_eval_partition.txt``
    (needs a ``split`` column, aligned with the metadata rows by index;
    0 = train, 1 = valid, 2 = test). Label values of -1 are mapped to 0.
    Each item is ``(image, LongTensor([y, s]), path)``, with images loaded
    from ``root/img_align_celeba/<image_id>``.

    Raises:
        ValueError: if ``split`` is not ``'train'``, ``'valid'`` or ``'test'``.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        super(CelebADataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        self.bias = [1, 0]
        self.n_classes = 2
        self.n_s = len(self.bias)

        split_ids = {"train": 0, "valid": 1, "test": 2}
        if split not in split_ids:
            # Otherwise no rows would be filtered and every split would be used.
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        metadata_path = os.path.join(root, "metadata.csv")
        split_path = os.path.join(root, "list_eval_partition.txt")
        datadir_path = os.path.join(root, "img_align_celeba")
        metadata = pd.read_csv(metadata_path)
        splitdata = pd.read_csv(split_path, sep=" ")

        # .copy() detaches the selection from the full table so the writes
        # below never target a slice (pandas SettingWithCopyWarning).
        metadata = metadata[splitdata["split"] == split_ids[split]].copy()

        # Image file paths.
        self.data = (datadir_path + "/" + metadata["image_id"]).tolist()

        # Label columns only; map -1 to 0.
        labels = metadata.drop(labels="image_id", axis="columns")
        labels[labels == -1] = 0

        self.y = labels["Blond_Hair"].tolist()
        # Male is used as-is; no label flipping is applied here.
        self.s = labels["Male"].tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([int(self.y[index]), int(self.s[index])])
        image = Image.open(self.data[index]).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, self.data[index]
    
class WaterbirdsDataset(Dataset):
    """Waterbirds with ``y`` as the target and ``place`` as the bias.

    Reads ``root/metadata.csv`` (needs ``img_filename``, ``y``, ``place`` and
    ``split`` columns; split 0 = train, 1 = valid, 2 = test). Image paths are
    ``root + img_filename`` (plain concatenation, so ``root`` should end with
    a path separator). Each item is ``(image, LongTensor([y, s]), path)``.

    Raises:
        ValueError: if ``split`` is not ``'train'``, ``'valid'`` or ``'test'``.
    """

    def __init__(self, root, split, transform=None, image_path_list=None):
        super(WaterbirdsDataset, self).__init__()
        self.transform = transform
        self.root = root
        self.image2pseudo = {}
        self.image_path_list = image_path_list

        self.bias = [0, 1]
        self.n_classes = 2
        self.n_s = len(self.bias)

        split_ids = {"train": 0, "valid": 1, "test": 2}
        if split not in split_ids:
            # Otherwise no rows would be filtered and every split would be used.
            raise ValueError(f"unknown split {split!r}; expected 'train', 'valid' or 'test'")

        metadata = pd.read_csv(os.path.join(root, "metadata.csv"))
        metadata = metadata[metadata["split"] == split_ids[split]]

        self.y = metadata["y"].tolist()
        self.s = metadata["place"].tolist()
        # Build paths without assigning into the filtered slice.
        self.data = (root + metadata["img_filename"]).tolist()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        attr = torch.LongTensor([int(self.y[index]), int(self.s[index])])
        image = Image.open(self.data[index]).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, attr, self.data[index]

# Plain (non-augmenting, non-normalizing) pipelines, keyed by dataset
# category and then by split.
transforms = {
    "mnist": {
        "train": T.Compose([T.ToTensor()]),
        "valid": T.Compose([T.ToTensor()]),
        "test": T.Compose([T.ToTensor()]),
    },
    "cmnist": {
        "train": T.Compose([T.ToTensor()]),
        "valid": T.Compose([T.ToTensor()]),
        "test": T.Compose([T.ToTensor()]),
    },
    "scmnist": {
        "train": T.Compose([T.ToTensor()]),
        "valid": T.Compose([T.ToTensor()]),
        "test": T.Compose([T.ToTensor()]),
    },
    "corruptedCifar10": {
        "train": T.Compose([T.ToTensor()]),
        "valid": T.Compose([T.ToTensor()]),
        "test": T.Compose([T.ToTensor()]),
    },
    "cifar10c": {
        "train": T.Compose([T.ToTensor()]),
        "valid": T.Compose([T.ToTensor()]),
        "test": T.Compose([T.ToTensor()]),
    },
    "cifar10c_mb": {
        "train": T.Compose([T.ToTensor()]),
        "valid": T.Compose([T.ToTensor()]),
        "test": T.Compose([T.ToTensor()]),
    },
    "bar": {
        "train": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "valid": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "test": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    },
    "bffhq": {
        "train": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "valid": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "test": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    },
    "dogs_and_cats": {
        "train": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "valid": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "test": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    },
    "celeba": {
        "train": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "valid": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "test": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    },
    # NICO and waterbirds reuse the 224x224 pipeline of the other image
    # datasets so get_dataset also works for them when use_preprocess is falsy.
    "NICO": {
        "train": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "valid": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "test": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    },
    "waterbirds": {
        "train": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "valid": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
        "test": T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    },
}


# Per-channel normalization statistics shared by every normalized pipeline
# below (presumably CIFAR-10 statistics; kept unchanged for all datasets).
_PREPRCS_MEAN = (0.4914, 0.4822, 0.4465)
_PREPRCS_STD = (0.2023, 0.1994, 0.2010)


def _preprcs_tensor_only():
    """Return fresh train/valid/test pipelines that only apply ``ToTensor``."""
    return {
        "train": T.Compose([T.ToTensor()]),
        "valid": T.Compose([T.ToTensor()]),
        "test": T.Compose([T.ToTensor()]),
    }


def _preprcs_resize_normalize(size=224, padding=4):
    """Return fresh train/valid/test pipelines for ``size`` x ``size`` inputs.

    Every split resizes to ``(size, size)``, converts to a tensor and
    normalizes with ``_PREPRCS_MEAN``/``_PREPRCS_STD``. The train split also
    applies a padded random crop and a random horizontal flip.
    """
    def _eval_pipeline():
        # Deterministic pipeline shared (as separate instances) by valid/test.
        return T.Compose([
            T.Resize((size, size)),
            T.ToTensor(),
            T.Normalize(_PREPRCS_MEAN, _PREPRCS_STD),
        ])

    return {
        "train": T.Compose([
            T.Resize((size, size)),
            T.RandomCrop(size, padding=padding),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(_PREPRCS_MEAN, _PREPRCS_STD),
        ]),
        "valid": _eval_pipeline(),
        "test": _eval_pipeline(),
    }


# Pipelines used when use_preprocess is truthy, keyed by dataset category and
# then by split.
transforms_preprcs = {
    "cmnist": _preprcs_tensor_only(),
    "scmnist": _preprcs_tensor_only(),
    "corrupted_cifar10": _preprcs_tensor_only(),
    # Same entry under the key used by `transforms` and the data2* tables,
    # so looking up "corruptedCifar10" no longer raises KeyError.
    "corruptedCifar10": _preprcs_tensor_only(),
    "cifar10c": _preprcs_tensor_only(),
    "cifar10c_mb": _preprcs_tensor_only(),
    "bar": _preprcs_resize_normalize(),
    "bffhq": _preprcs_resize_normalize(),
    "dogs_and_cats": _preprcs_resize_normalize(),
    "NICO": _preprcs_resize_normalize(),
    "waterbirds": _preprcs_resize_normalize(),
    "celeba": _preprcs_resize_normalize(),
}


# Backbone architecture name per dataset key.
data2model = {
    'cmnist': "MLP",
    'scmnist': "MLP",
    'bar': "ResNet18",
    'bffhq': "ResNet18",
    'dogs_and_cats': "ResNet18",
    'corruptedCifar10': "ResNet18",
    'cifar10c': "ResNet18",
    'cifar10c_mb': "ResNet18",
    'NICO': "ResNet18",
    'waterbirds': "ResNet18",
    'celeba': "ResNet18",
}

# Batch size per dataset key.
data2batch_size = {
    'cmnist': 256,
    'scmnist': 256,
    'bar': 64,
    'bffhq': 64,
    'dogs_and_cats': 64,
    'corruptedCifar10': 256,
    'cifar10c': 256,
    'cifar10c_mb': 256,
    'NICO': 64,
    'waterbirds': 64,
    'celeba': 128,
}

# Preprocessing flag per dataset key (True or None); presumably passed as
# `use_preprocess` to get_dataset -- confirm at the call site.
data2preprocess = {
    'cmnist': None,
    'scmnist': None,
    'bar': True,
    'bffhq': True,
    'dogs_and_cats': True,
    'corruptedCifar10': None,
    'cifar10c': None,
    'cifar10c_mb': None,
    'NICO': True,
    'waterbirds': True,
    'celeba': True,
}


def _bias_list(n_biased, n_classes=10):
    """Return ``[0, ..., n_biased - 1]`` padded with ``None`` to ``n_classes`` entries."""
    return list(range(n_biased)) + [None] * (n_classes - n_biased)


def get_dataset(dataset, data_dir, dataset_split, transform_split, percent, use_preprocess=None, image_path_list=None, args=None):
    """Construct the dataset named by ``dataset``.

    Args:
        dataset: dataset key. Synthetic datasets carry ``-``-separated integer
            suffixes (``scmnist-<n>``, ``cmnist-<n>``, ``cifar10c-<n>``,
            ``cifar10c_mb-<n1>-<n2>``); the part before the first ``-`` selects
            the transform-table entry.
        data_dir: directory containing all datasets.
        dataset_split: ``'train'``, ``'valid'`` (``'eval'`` is an alias) or ``'test'``.
        transform_split: which split's transform pipeline to use.
        percent: ratio sub-directory used by ``bffhq``, ``bar`` and ``dogs_and_cats``.
        use_preprocess: if truthy, take the pipeline from ``transforms_preprcs``;
            otherwise from ``transforms``.
        image_path_list: forwarded to the dataset constructor.
        args: namespace with ``sparsity`` and ``corr`` (plus ``corr2`` for
            ``cifar10c_mb``); required by the synthetic and corrupted-CIFAR datasets.

    Returns:
        The constructed dataset object.

    Raises:
        KeyError: if the category or ``transform_split`` has no transform entry.
        ValueError: if ``dataset`` matches no known dataset.
    """
    dataset_category = dataset.split("-")[0]
    table = transforms_preprcs if use_preprocess else transforms
    transform = table[dataset_category][transform_split]

    dataset_split = "valid" if dataset_split == "eval" else dataset_split

    # 'scmnist' must be tested before 'cmnist' (substring match).
    if 'scmnist' in dataset:
        n_b = int(dataset.split('-')[-1])
        root = data_dir + f"/SCMNIST-{n_b}/{args.sparsity}/{args.corr}"
        return SCMNISTDataset(root=root, split=dataset_split, transform=transform,
                              image_path_list=image_path_list, bias=_bias_list(n_b))

    if 'cmnist' in dataset:
        n_b = int(dataset.split('-')[-1])
        root = data_dir + f"/CMNIST-{n_b}/{args.sparsity}/{args.corr}"
        # CMNIST-<n> shares the SCMNIST directory/file-name layout.
        return SCMNISTDataset(root=root, split=dataset_split, transform=transform,
                              image_path_list=image_path_list, bias=_bias_list(n_b))

    if dataset == "bffhq":
        root = data_dir + f"/bffhq/{percent}"
        return bFFHQDataset(root=root, split=dataset_split, transform=transform,
                            image_path_list=image_path_list)

    if dataset == "bar":
        root = data_dir + "/bar"
        return BARDataset(root=root, split=dataset_split, transform=transform,
                          percent=percent, image_path_list=image_path_list)

    if dataset == "dogs_and_cats":
        root = data_dir + f"/dogs_and_cats/{percent}"
        return DogCatDataset(root=root, split=dataset_split, transform=transform,
                             image_path_list=image_path_list)

    if dataset == "corruptedCifar10":
        # Alternative layout: /corruptedCifar_tenBias/... with bias = list(range(10)).
        root = data_dir + f"/corruptedCifar_oneBias/{args.sparsity}/{args.corr}"
        bias = [0, None, None, None, None, None, None, None, None, None]
        return CORRUPTEDCIFAR10Dataset(root=root, split=dataset_split, transform=transform,
                                       image_path_list=image_path_list, bias=bias)

    # 'cifar10c_mb' must be tested before 'cifar10c' (substring match).
    if "cifar10c_mb" in dataset:
        n_b_1 = int(dataset.split('-')[-2])
        n_b_2 = int(dataset.split('-')[-1])
        root = data_dir + f"/Cifar10C-MB-{n_b_1}-{n_b_2}/{args.sparsity}/{args.corr}/{args.corr2}"
        return CIFAR10CMBDataset(root=root, split=dataset_split, transform=transform,
                                 image_path_list=image_path_list,
                                 bias=_bias_list(n_b_1), bias2=_bias_list(n_b_2))

    if "cifar10c" in dataset:
        n_b = int(dataset.split('-')[-1])
        root = data_dir + f"/Cifar10C-{n_b}/{args.sparsity}/{args.corr}"
        return CORRUPTEDCIFAR10Dataset(root=root, split=dataset_split, transform=transform,
                                       image_path_list=image_path_list, bias=_bias_list(n_b))

    if dataset == "NICO":
        root = data_dir + "/NICO/multi_classification/"
        return NICODataset(root=root, split=dataset_split, transform=transform,
                           image_path_list=image_path_list)

    if dataset == "waterbirds":
        root = data_dir + "/waterbird_complete95_forest2water2/"
        return WaterbirdsDataset(root=root, split=dataset_split, transform=transform,
                                 image_path_list=image_path_list)

    if dataset == "celeba":
        root = data_dir + "/celebA/"
        return CelebADataset(root=root, split=dataset_split, transform=transform,
                             image_path_list=image_path_list)

    # Previously printed and called sys.exit(0), reporting success on an error.
    raise ValueError(f"wrong dataset ...: {dataset!r}")