import os
import torch
from torchvision import transforms
import torchvision.datasets.folder
from torch.utils.data import TensorDataset, Subset, ConcatDataset, Dataset
from torchvision.datasets import MNIST, ImageFolder
from torchvision.models import resnet50, ResNet50_Weights
import random
import numpy as np
from torch.utils.data import random_split
from utils_v2 import set_seed

import pandas as pd
from PIL import Image
import numpy as np
from wilds.datasets.wilds_dataset import WILDSDataset, WILDSSubset
from wilds.common.grouper import CombinatorialGrouper

from collections import defaultdict
from torchvision.datasets import MNIST
import xml.etree.ElementTree as ET
from zipfile import ZipFile
import argparse
import tarfile
import shutil
import gdown
import uuid
import json
import os
import urllib

# Set at import time, before any seeding helper runs. set_seed() below enables
# torch.use_deterministic_algorithms(True).
# NOTE(review): presumably this is the cuBLAS workspace setting PyTorch asks for
# in deterministic mode on CUDA — confirm against the PyTorch reproducibility notes.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

def set_seed(seed):
    """Seed every RNG used in this module and switch PyTorch to deterministic mode.

    Args:
        seed: Seed passed to Python's ``random``, NumPy and the PyTorch
            CPU/CUDA generators.

    Side effects:
        Sets the cuDNN ``deterministic``/``benchmark`` flags, enables
        ``torch.use_deterministic_algorithms`` and writes ``PYTHONHASHSEED``
        into ``os.environ``.
    """
    # Same seeding order as before: Python, NumPy, torch CPU, one GPU, all GPUs.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Force deterministic kernels and disable the cuDNN auto-tuner.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    torch.use_deterministic_algorithms(True)

    # Recorded in the environment for Python hashing.
    os.environ['PYTHONHASHSEED'] = str(seed)

class MultipleDomainDataset:
    """Container exposing a list of per-environment datasets via ``self.datasets``.

    Subclasses are expected to populate ``self.datasets``; indexing and
    ``len()`` are delegated to that list.
    """

    # Defaults; subclasses may override.
    N_STEPS = 5001
    CHECKPOINT_FREQ = 100
    N_WORKERS = 16
    # Subclasses should override.
    ENVIRONMENTS = None
    INPUT_SHAPE = None

    def __len__(self):
        """Return the number of environment datasets."""
        environment_datasets = self.datasets
        return len(environment_datasets)

    def __getitem__(self, index):
        """Return the environment dataset stored at ``index``."""
        environment_datasets = self.datasets
        return environment_datasets[index]

class MultipleEnvironmentMNIST(MultipleDomainDataset):
    """MNIST (train + test merged) shuffled and dealt out across environments."""

    def __init__(self, root, environments, dataset_transform, input_shape,
                 num_classes):
        """Build one dataset per environment.

        Args:
            root: Directory passed to ``MNIST`` (downloaded there if needed).
            environments: Sequence of environment parameters; entry ``i`` is
                handed to ``dataset_transform`` for the i-th shard.
            dataset_transform: Callable ``(images, labels, environment)``
                returning the dataset for one environment.
            input_shape: Stored as ``self.input_shape``.
            num_classes: Stored as ``self.num_classes``.

        Raises:
            ValueError: If ``root`` is None.
        """
        super().__init__()
        if root is None:
            raise ValueError('Data directory not specified!')

        mnist_train = MNIST(root, train=True, download=True)
        mnist_test = MNIST(root, train=False, download=True)

        all_images = torch.cat((mnist_train.data, mnist_test.data))
        all_labels = torch.cat((mnist_train.targets, mnist_test.targets))

        # One shared permutation keeps images and labels aligned.
        order = torch.randperm(len(all_images))
        all_images = all_images[order]
        all_labels = all_labels[order]

        self.datasets = []
        n_envs = len(environments)

        # Environment k receives every n_envs-th sample starting at offset k.
        for offset, environment in enumerate(environments):
            shard_images = all_images[offset::n_envs]
            shard_labels = all_labels[offset::n_envs]
            self.datasets.append(
                dataset_transform(shard_images, shard_labels, environment))

        self.input_shape = input_shape
        self.num_classes = num_classes

######## CMNIST Dataset
class ColoredMNIST(MultipleEnvironmentMNIST):
    """Colored MNIST with three environments that differ in color/label correlation.

    The environments are built with color-flip probabilities 0.1, 0.2 and 0.9.
    Environments listed in ``test_envs`` become test data; every other
    environment is split into train and validation parts.
    """
    ENVIRONMENTS = ['+90%', '+80%', '-90%']

    def __init__(self, root, test_envs=None, hparams=None, transform=None, val_frac=0.2, seed=42):
        """
        Arguments:
            root: Directory holding (or receiving) the MNIST download.
            test_envs: Indices of the environments used for testing.
                Defaults to ``[2]`` when omitted.
            hparams: Unused by this class.
            transform (callable, optional): Per-sample transform forwarded to
                each ``CMNISTDataset``.
            val_frac: Fraction of every non-test environment held out for
                validation.
            seed: Seed passed to ``set_seed`` before any data is generated.
        """
        set_seed(seed)
        self.transform = transform
        super(ColoredMNIST, self).__init__(root, [0.1, 0.2, 0.9],
                                           self.color_dataset, (3, 28, 28,), 2)
        self.input_shape = (3, 28, 28,)
        self.num_classes = 2

        # A fresh list per instance: a list literal in the signature would be
        # shared (and mutable) across every instance relying on the default.
        self.test_envs = [2] if test_envs is None else test_envs
        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []

        for i, dataset in enumerate(self.datasets):
            if i in self.test_envs:
                self.test_datasets.append(dataset)
            else:
                val_len = int(len(dataset) * val_frac)
                train_len = len(dataset) - val_len
                train_set, val_set = random_split(dataset, [train_len, val_len])
                self.train_datasets.append(train_set)
                self.val_datasets.append(val_set)

        # Optional: Flatten for convenience
        self.train_dataset = ConcatDataset(self.train_datasets)
        self.val_dataset = ConcatDataset(self.val_datasets)
        self.test_dataset = ConcatDataset(self.test_datasets)

    def color_dataset(self, images, labels, environment):
        """Turn one shard of grayscale digits into a colored binary-label dataset.

        Arguments:
            images (tensor): Grayscale digit images, one per row.
            labels (tensor): Digit labels aligned with ``images``.
            environment (float): Probability of flipping the color away from
                the (noisy) label.

        Returns:
            CMNISTDataset with 3-channel images, the color as metadata and the
            binary label as output.
        """
        # Assign a binary label based on the digit
        labels = (labels < 5).float()
        # Flip label with probability 0.25
        labels = self.torch_xor_(labels,
                                 self.torch_bernoulli_(0.25, len(labels)))

        # Assign a color based on the label; flip the color with probability e
        colors = self.torch_xor_(labels,
                                 self.torch_bernoulli_(environment,
                                                       len(labels)))
        # Three channels so the images fit a ResNet50-style input.
        images = torch.stack([images, images, images], dim=1)

        # Apply the color to the image by zeroing out the other color channel
        sample_idx = torch.arange(len(images))
        images[sample_idx, (1 - colors).long(), :, :] = 0

        # The third channel is always blank.
        images[:, 2, :, :] = 0

        x = images
        y = labels.view(-1).long()
        z = colors.long()

        return CMNISTDataset(x, z, y, self.transform)

    def torch_bernoulli_(self, p, size):
        """Return ``size`` float samples that are 1.0 with probability ``p``, else 0.0."""
        return (torch.rand(size) < p).float()

    def torch_xor_(self, a, b):
        """Element-wise XOR of two 0/1 float tensors, computed as ``|a - b|``."""
        return (a - b).abs()

class CMNISTDataset(Dataset):
    """
    Dataset for CMNIST
    """
    def __init__(self, inputs, metadata, outputs, transform=None):
        """
        Arguments:
            inputs (tensor): Images
            metadata (tensor): Metadata
            outputs (tensor): Labels
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.inputs = inputs
        self.metadata = metadata
        # One row per sample: [metadata, label].
        self.groups = torch.hstack((metadata.reshape(-1, 1), outputs.reshape(-1, 1)))
        self.outputs = outputs
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of ``outputs``)."""
        return len(self.outputs)

    def __getitem__(self, idx):
        """Return ``(image, label, group)`` for ``idx``.

        ``idx`` may be an int or a tensor index; ``group`` is the
        ``[metadata, label]`` row of the sample. The transform, when set, is
        applied to the image only.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image = self.inputs[idx, :, :, :]
        label = self.outputs[idx]
        group = self.groups[idx]
        if self.transform:
            image = self.transform(image)

        return image, label, group

##### WaterBirds Dataset
class WaterbirdsDataset_Base(WILDSDataset):
    """
    The Waterbirds dataset.
    This dataset is not part of the official WILDS benchmark.
    We provide it for convenience and to facilitate comparisons to previous work.

    Supported `split_scheme`:
        'official'

    Input (x):
        Images of birds against various backgrounds that have already been cropped and centered.

    Label (y):
        y is binary. It is 1 if the bird is a waterbird (e.g., duck), and 0 if it is a landbird.

    Metadata:
        Each image is annotated with whether the background is a land or water background.

    Original publication:
        @inproceedings{sagawa2019distributionally,
          title = {Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization},
          author = {Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy},
          booktitle = {International Conference on Learning Representations},
          year = {2019}
        }

    The dataset was constructed from the CUB-200-2011 dataset and the Places dataset:
        @techreport{WahCUB_200_2011,
            Title = {{The Caltech-UCSD Birds-200-2011 Dataset}},
            Author = {Wah, C. and Branson, S. and Welinder, P. and Perona, P. and Belongie, S.},
            Year = {2011},
            Institution = {California Institute of Technology},
            Number = {CNS-TR-2011-001}
        }
        @article{zhou2017places,
          title = {Places: A 10 million Image Database for Scene Recognition},
          author = {Zhou, Bolei and Lapedriza, Agata and Khosla, Aditya and Oliva, Aude and Torralba, Antonio},
          journal ={IEEE Transactions on Pattern Analysis and Machine Intelligence},
          year = {2017},
          publisher = {IEEE}
        }

    License:
        The use of this dataset is restricted to non-commercial research and educational purposes.
    """

    _dataset_name = 'waterbirds'
    _versions_dict = {
        '1.0': {
            'download_url': 'https://worksheets.codalab.org/rest/bundles/0x505056d5cdea4e4eaa0e242cbfe2daa4/contents/blob/',
            'compressed_size': None}}

    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', seed=42):
        """Load ``metadata.csv`` and set up labels, metadata, splits and the eval grouper.

        Raises:
            ValueError: If the data directory does not exist, or if
                ``split_scheme`` is anything other than 'official'.
        """
        set_seed(seed)
        self._version = version
        self._data_dir = self.initialize_data_dir(root_dir, download)

        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. Please generate the dataset first.')

        # Read in metadata
        # Note: metadata_df is one-indexed.
        metadata_path = os.path.join(self.data_dir, 'metadata.csv')
        metadata_df = pd.read_csv(metadata_path)

        # Labels
        self._y_array = torch.LongTensor(metadata_df['y'].values)
        self._y_size = 1
        self._n_classes = 2

        # Metadata rows are (background, y)
        background = torch.LongTensor(metadata_df['place'].values)
        self._metadata_array = torch.stack((background, self._y_array), dim=1)
        self._metadata_fields = ['background', 'y']
        self._metadata_map = {
            'background': [' land', 'water'], # Padding for str formatting
            'y': [' landbird', 'waterbird']
        }

        # Image filenames, relative to the data directory
        self._input_array = metadata_df['img_filename'].values
        self._original_resolution = (224, 224)

        # Splits
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(f'Split scheme {self._split_scheme} not recognized')
        self._split_array = metadata_df['split'].values

        self._eval_grouper = CombinatorialGrouper(
            dataset=self,
            groupby_fields=(['background', 'y']))

        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        """
        Returns x for a given idx.
        """
        image_path = os.path.join(self.data_dir, self._input_array[idx])
        return Image.open(image_path).convert('RGB')

    def get_environments(self, transform=None):
        """Split the training data into a 'land' and a 'water' environment.

        Each (metadata row) group of training samples is divided at random:
        groups with background 0 send 55% of their samples to 'land', all
        other groups send 35%; the remainder goes to 'water'. If the two
        environments end up with different sizes, half of the difference
        (rounded up) is moved out of the largest chunk of the bigger
        environment into the matching chunk of the smaller one.

        Returns:
            dict with the two ``WILDSSubset`` objects ("land", "water"), their
            sizes ("sizes") and the list of group keys ("groups").
        """
        in_train = torch.tensor(self.split_array == self.split_dict['train'])

        masks = {}
        counts = {}
        for row in torch.unique(self.metadata_array, dim=0):
            # NOTE(review): column 2 is not built in this class; presumably a flag
            # appended by WILDSDataset.__init__ — confirm against the installed WILDS.
            if row[2] == 0:
                continue
            group = tuple(row.tolist())
            masks[group] = torch.all(self.metadata_array == row, dim=1) & in_train
            counts[group] = torch.sum(masks[group]).item()

        # land Environment - (LL - 0.55, LW - 0.55, WW - 0.35, WL - 0.35)
        # Water Environment - (LL - 0.45, LW - 0.45, WW - 0.65, WL - 0.65)
        land_parts = []
        water_parts = []
        for group, mask in masks.items():
            members = np.where(mask)[0]
            land_fraction = 0.55 if group[0] == 0 else 0.35
            n_land = np.round(len(members) * land_fraction).astype(int)
            to_land = np.random.choice(members, n_land, replace=False)
            land_parts.append(to_land)
            water_parts.append(np.setdiff1d(members, to_land))

        total_land = sum(len(part) for part in land_parts)
        total_water = sum(len(part) for part in water_parts)

        if total_land != total_water:
            if total_land > total_water:
                donor, receiver = land_parts, water_parts
            else:
                donor, receiver = water_parts, land_parts
            gap = abs(total_land - total_water)
            largest = np.argmax([len(part) for part in donor])
            moved = np.random.choice(donor[largest], gap // 2 + gap % 2, replace=False)
            # donor/receiver alias the two part lists, so these writes update them in place.
            donor[largest] = np.setdiff1d(donor[largest], moved)
            receiver[largest] = np.append(receiver[largest], moved)

        land_array = np.concatenate(land_parts).flatten()
        water_array = np.concatenate(water_parts).flatten()

        return {
            "land": WILDSSubset(self, land_array, transform),
            "water": WILDSSubset(self, water_array, transform),
            "sizes": {"land": len(land_array), "water": len(water_array)},
            "groups": list(counts)
        }


class WaterbirdsDataset(WaterbirdsDataset_Base):
    """Waterbirds with ready-made train/val/test subsets and training environments."""

    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', seed=42, transforms=None):
        """Build the base dataset, then the split subsets and ``train_environments``.

        ``transforms`` is forwarded to ``get_subset`` only when truthy; it is
        always forwarded to ``get_environments``.
        """
        super().__init__(version, root_dir, download, split_scheme, seed)

        # Leave get_subset's own default in place when no transform was given.
        subset_kwargs = {'transform': transforms} if transforms else {}
        self.train_dataset = self.get_subset('train', **subset_kwargs)
        self.val_dataset = self.get_subset('val', **subset_kwargs)
        self.test_dataset = self.get_subset('test', **subset_kwargs)

        self.train_environments = self.get_environments(transform=transforms)

### CelebA Dataset
class CelebADataset_Base(WILDSDataset):
    """
    A variant of the CelebA dataset.
    This dataset is not part of the official WILDS benchmark.
    We provide it for convenience and to facilitate comparisons to previous work.

    Supported `split_scheme`:
        'official'

    Input (x):
        Images of celebrity faces that have already been cropped and centered.

    Label (y):
        y is binary. It is 1 if the celebrity in the image has blond hair, and is 0 otherwise.

    Metadata:
        Each image is annotated with whether the celebrity has been labeled 'Male' or 'Female'.

    Website:
        http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

    Original publication:
        @inproceedings{liu2015faceattributes,
          title = {Deep Learning Face Attributes in the Wild},
          author = {Liu, Ziwei and Luo, Ping and Wang, Xiaogang and Tang, Xiaoou},
          booktitle = {Proceedings of International Conference on Computer Vision (ICCV)},
          month = {December},
          year = {2015}
        }

    This variant of the dataset is identical to the setup in:
        @inproceedings{sagawa2019distributionally,
          title = {Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization},
          author = {Sagawa, Shiori and Koh, Pang Wei and Hashimoto, Tatsunori B and Liang, Percy},
          booktitle = {International Conference on Learning Representations},
          year = {2019}
        }

    License:
        This version of the dataset was originally downloaded from Kaggle
        https://www.kaggle.com/jessicali9530/celeba-dataset

        It is available for non-commercial research purposes only.
    """
    _dataset_name = 'celebA'
    _versions_dict = {
        '1.0': {
            'download_url': 'https://worksheets.codalab.org/rest/bundles/0xfe55077f5cd541f985ebf9ec50473293/contents/blob/',
            'compressed_size': 1_308_557_312}}

    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', seed=42):
        """Load the attribute and partition CSVs and set up labels, metadata and splits.

        The label is the 'Blond_Hair' attribute and the only confounder is
        'Male'; attribute values of -1 are mapped to 0.

        Raises:
            ValueError: If ``split_scheme`` is anything other than 'official'.
        """
        set_seed(seed)
        self._version = version
        self._data_dir = self.initialize_data_dir(root_dir, download)
        target_name = 'Blond_Hair'
        confounder_names = ['Male']

        # Read in attributes
        attrs_df = pd.read_csv(
            os.path.join(self.data_dir, 'list_attr_celeba.csv'))

        # Split out filenames and attribute names
        # Note: idx and filenames are off by one.
        self._input_array = attrs_df['image_id'].values
        self._original_resolution = (178, 218)
        attrs_df = attrs_df.drop(labels='image_id', axis='columns')
        attr_names = attrs_df.columns.copy()
        def attr_idx(attr_name):
            return attr_names.get_loc(attr_name)

        # Then cast attributes to numpy array and set them to 0 and 1
        # (originally, they're -1 and 1)
        attrs_df = attrs_df.values
        attrs_df[attrs_df == -1] = 0

        # Get the y values
        target_idx = attr_idx(target_name)
        self._y_array = torch.LongTensor(attrs_df[:, target_idx])
        self._y_size = 1
        self._n_classes = 2

        # Get metadata
        confounder_idx = [attr_idx(a) for a in confounder_names]
        confounders = attrs_df[:, confounder_idx]

        # Metadata rows are (male, y)
        self._metadata_array = torch.cat(
            (torch.LongTensor(confounders), self._y_array.reshape((-1, 1))),
            dim=1)
        confounder_names = [s.lower() for s in confounder_names]
        self._metadata_fields = confounder_names + ['y']
        self._metadata_map = {
            'y': ['not blond', '    blond'] # Padding for str formatting
        }

        self._eval_grouper = CombinatorialGrouper(
            dataset=self,
            groupby_fields=(confounder_names + ['y']))

        # Extract splits
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(f'Split scheme {self._split_scheme} not recognized')
        split_df = pd.read_csv(
            os.path.join(self.data_dir, 'list_eval_partition.csv'))
        self._split_array = split_df['partition'].values

        super().__init__(root_dir, download, split_scheme)

    def get_input(self, idx):
        """Return the RGB image for ``idx`` from the ``img_align_celeba`` folder."""
        # Note: idx and filenames are off by one.
        img_filename = os.path.join(
            self.data_dir,
            'img_align_celeba',
            self._input_array[idx])
        x = Image.open(img_filename).convert('RGB')
        return x

    def get_environments(self, transform=None):
        """Split the training data into a 'male' and a 'female' environment.

        Each (metadata row) group of training samples is divided at random
        using a group-specific fraction for the 'male' environment; the rest
        goes to 'female'. If the two environments end up with different
        sizes, half of the difference (rounded up) is moved out of the largest
        chunk of the bigger environment into the matching chunk of the smaller
        one. When the sizes already match, nothing is moved.

        Returns:
            dict with the two ``WILDSSubset`` objects ("male", "female"), their
            sizes ("sizes") and the list of group keys ("groups").
        """
        masks = {}
        counts = {}

        for i in torch.unique(self.metadata_array, dim=0):
            # NOTE(review): column 2 is not built in this class; presumably a flag
            # appended by WILDSDataset.__init__ — confirm against the installed WILDS.
            if i[2] == 0:
                continue
            key = tuple(i.tolist())
            masks[key] = torch.all(self.metadata_array == i, dim=1) & torch.tensor(self.split_array == self.split_dict['train'])
            counts[key] = torch.sum(masks[key]).item()

        male_idx = []
        female_idx = []

        # Male-biased - (FNB - 0.45, FB - 0.35, MB - 0.55, MNB - 0.6)
        # Female-biased - (FNB - 0.55, FB - 0.65, MB - 0.45, MNB - 0.4)
        for key in counts:
            arr = np.where(masks[key])[0]

            # key[0] is the 'male' flag, key[1] the label.
            if key[0] == 0:
                male_fraction = 0.35 if key[0] != key[1] else 0.45
            else:
                male_fraction = 0.6 if key[0] != key[1] else 0.55

            male_len = np.round(len(arr) * male_fraction).astype(int)
            male_choice = np.random.choice(arr, male_len, replace=False)
            female_choice = np.setdiff1d(arr, male_choice)

            male_idx.append(male_choice)
            female_idx.append(female_choice)

        len_male = sum(len(i) for i in male_idx)
        len_female = sum(len(i) for i in female_idx)

        if len_male != len_female:
            over, under = (male_idx, female_idx) if len_male > len_female else (female_idx, male_idx)
            diff = abs(len_male - len_female)
            idx = np.argmax([len(i) for i in over])
            filler = np.random.choice(over[idx], diff // 2 + diff % 2, replace=False)
            # over/under alias male_idx/female_idx, so these writes update them in place.
            over[idx] = np.setdiff1d(over[idx], filler)
            under[idx] = np.append(under[idx], filler)

        # Reading the two lists directly works whether or not rebalancing ran.
        # The previous code branched on a flag that was only assigned inside the
        # `if` above and therefore failed when the sizes were already equal.
        male_array = np.concatenate(male_idx).flatten()
        female_array = np.concatenate(female_idx).flatten()

        return {
            "male": WILDSSubset(self, male_array, transform),
            "female": WILDSSubset(self, female_array, transform),
            "sizes": {"male": len(male_array), "female": len(female_array)},
            "groups": list(counts.keys())
        }

class CelebADataset(CelebADataset_Base):
    """CelebA with ready-made train/val/test subsets and training environments."""

    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', seed=42, transforms=None):
        """Build the base dataset, then the split subsets and ``train_environments``.

        ``transforms`` is forwarded to ``get_subset`` only when truthy; it is
        always forwarded to ``get_environments``.
        """
        super().__init__(version, root_dir, download, split_scheme, seed)

        # Leave get_subset's own default in place when no transform was given.
        subset_kwargs = {'transform': transforms} if transforms else {}
        self.train_dataset = self.get_subset('train', **subset_kwargs)
        self.val_dataset = self.get_subset('val', **subset_kwargs)
        self.test_dataset = self.get_subset('test', **subset_kwargs)

        self.train_environments = self.get_environments(transform=transforms)

##### Spawrious Dataset
## Spawrious base classes
class CustomImageFolder(Dataset):
    """
    A class that takes one folder at a time and loads a set number of images in a folder and assigns them a specific class
    """
    def __init__(self, folder_path, class_index, location_index, env_index, limit=None, transform=None):
        """
        Arguments:
            folder_path: Directory scanned (non-recursively) for images.
            class_index: Label returned for every image in the folder.
            location_index: First entry of the returned group tensor.
            env_index: Second entry of the returned group tensor.
            limit: If truthy, keep only the first ``limit`` images (in sorted
                filename order).
            transform (callable, optional): Applied to each loaded image.
        """
        self.folder_path = folder_path
        self.class_index = class_index
        self.location_index = location_index
        self.env_index = env_index
        # os.listdir returns entries in arbitrary order; sorting makes both the
        # sample order and the subset kept by `limit` reproducible across runs
        # and machines.
        self.image_paths = [
            os.path.join(folder_path, img)
            for img in sorted(os.listdir(folder_path))
            if img.endswith(('.png', '.jpg', '.jpeg'))
        ]
        if limit:
            self.image_paths = self.image_paths[:limit]
        self.transform = transform

    def __len__(self):
        """Return the number of images kept from the folder."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, label, group)``; ``group`` is ``[location_index, env_index]``."""
        img_path = self.image_paths[index]
        img = Image.open(img_path).convert('RGB')

        if self.transform:
            img = self.transform(img)

        label = torch.tensor(self.class_index, dtype=torch.long)
        group = torch.tensor([self.location_index, self.env_index], dtype=torch.long)
        return img, label, group

class SpawriousBenchmark(MultipleDomainDataset):
    """Base class for the Spawrious benchmarks.

    ``self.datasets`` holds the concatenated test data at index 0 followed by
    one dataset per training environment. Each training environment is split
    into train/validation parts, also exposed flattened as ``train_dataset``,
    ``val_dataset`` and ``test_dataset``.
    """
    ENVIRONMENTS = ["Test", "SC_group_1", "SC_group_2"]
    input_shape = (3, 224, 224)
    num_classes = 4
    class_list = ["bulldog", "corgi", "dachshund", "labrador"]
    location_list = ["beach", "desert", "dirt", "jungle", "mountain", "snow"]

    def __init__(self, train_combinations, test_combinations, root_dir, augment, val_frac=0.2, type1=False):
        """
        Arguments:
            train_combinations (dict): Maps a tuple of class names to a list of
                ``(location, limit)`` entries, one per training environment.
            test_combinations: Dict in the same layout (entries may be plain
                location names), or an iterable of location names.
            root_dir: Root folder of the extracted Spawrious images.
            augment: Forwarded to ``_prepare_data_lists``.
            val_frac: Fraction of each training environment held out for
                validation.
            type1: When True, the environment index is used as the top-level
                sub-folder; otherwise folder ``0`` is always used.
        """
        self.type1 = type1
        self.location_adj = []
        self.has_filler = False

        # Location indices: first location of every single-class group, in
        # dict order, followed by the filler location (multi-class group).
        for list_animals, list_combinations in train_combinations.items():
            if len(list_animals) > 1:
                self.filler = list_combinations[0][0]
                self.has_filler = True
            else:
                self.location_adj.append(list_combinations[0][0])

        if self.has_filler:
            self.location_adj.append(self.filler)

        # Register every remaining train/test location after the ones above.
        # Without this, a test location that is never the first training
        # location of any class makes `location_adj.index()` raise ValueError
        # in `_create_data_list`. Existing indices are left unchanged.
        for combinations in (train_combinations, test_combinations):
            for location in self._iter_locations(combinations):
                if location not in self.location_adj:
                    self.location_adj.append(location)

        self.env_list = [(class_, location_) for location_ in self.location_list for class_ in self.class_list]
        train_datasets, test_datasets = self._prepare_data_lists(train_combinations, test_combinations, root_dir, augment)
        self.datasets = [ConcatDataset(test_datasets)] + train_datasets

        self.train_datasets = []
        self.val_datasets = []
        self.test_datasets = []

        for i, dataset in enumerate(self.datasets):
            if i == 0:
                self.test_datasets.append(dataset)
            else:
                val_len = int(len(dataset) * val_frac)
                train_len = len(dataset) - val_len
                train_set, val_set = random_split(dataset, [train_len, val_len])
                self.train_datasets.append(train_set)
                self.val_datasets.append(val_set)

        # Optional: Flatten for convenience
        self.train_dataset = ConcatDataset(self.train_datasets)
        self.val_dataset = ConcatDataset(self.val_datasets)
        self.test_dataset = ConcatDataset(self.test_datasets)

    @staticmethod
    def _iter_locations(combinations):
        """Yield every location named in a combinations dict (nothing for non-dicts)."""
        if not isinstance(combinations, dict):
            return
        for comb_list in combinations.values():
            for entry in comb_list:
                # Entries are either (location, limit) tuples or bare location names.
                yield entry[0] if isinstance(entry, tuple) else entry

    # Prepares the train and test data lists by applying the necessary transformations.
    def _prepare_data_lists(self, train_combinations, test_combinations, root_dir, augment):
        """Return ``(train_data_list, test_data_list)`` built with the ResNet50 preset transforms."""
        test_transforms = ResNet50_Weights.DEFAULT.transforms()

        # NOTE: both branches currently use the same ResNet50 preset; `augment`
        # only decides whether training gets its own preset instance.
        if augment:
            train_transforms = ResNet50_Weights.DEFAULT.transforms()
        else:
            train_transforms = test_transforms

        train_data_list = self._create_data_list(train_combinations, root_dir, train_transforms)
        test_data_list = self._create_data_list(test_combinations, root_dir, test_transforms)

        return train_data_list, test_data_list

    # Creates a list of datasets based on the given combinations and transformations.
    def _create_data_list(self, combinations, root_dir, transforms):
        """Return one dataset per environment for ``combinations``.

        For a dict, the k-th dataset concatenates the k-th entry of every class
        group. Otherwise each location becomes an ``ImageFolder`` rooted at
        ``<root_dir>/0/<location>/``.
        """
        data_list = []
        if isinstance(combinations, dict):

            # Build class groups for a given set of combinations, root directory, and transformations.
            for_each_class_group = []
            for classes, comb_list in combinations.items():
                class_group = []
                for_each_class_group.append(class_group)
                for ind, location_limit in enumerate(comb_list):
                    if isinstance(location_limit, tuple):
                        location, limit = location_limit
                    else:
                        location, limit = location_limit, None

                    cg_data_list = []
                    for cls in classes:
                        path = os.path.join(root_dir, f"{0 if not self.type1 else ind}/{location}/{cls}")
                        data = CustomImageFolder(folder_path=path, class_index=self.class_list.index(cls),
                                                 location_index=self.location_adj.index(location),
                                                 env_index=self.env_list.index((cls, location)),
                                                 limit=limit, transform=transforms) # Added location
                        cg_data_list.append(data)

                    class_group.append(ConcatDataset(cg_data_list))

            # Environment `group` gathers the group-th dataset of every class group.
            for group in range(len(for_each_class_group[0])):
                data_list.append(
                    ConcatDataset(
                        [for_each_class_group[k][group] for k in range(len(for_each_class_group))]
                    )
                )
        else:
            for location in combinations:
                path = os.path.join(root_dir, f"{0}/{location}/")
                data = ImageFolder(root=path, transform=transforms)
                data_list.append(data)

        return data_list


    # Builds combination dictionary for o2o datasets
    def build_type1_combination(self,group,test,filler):
        """Return the train/test combination dicts for a one-to-one benchmark.

        Each class is paired with one location (``group[i]``) in both training
        environments, using 97% and 87% of ``total`` images; the ``filler``
        location supplies the remaining images for all classes.
        """
        total = 3168
        counts = [int(0.97*total),int(0.87*total)]
        combinations = {}
        combinations['train_combinations'] = {
            ## correlated class
            ("bulldog",):[(group[0],counts[0]),(group[0],counts[1])],
            ("dachshund",):[(group[1],counts[0]),(group[1],counts[1])],
            ("labrador",):[(group[2],counts[0]),(group[2],counts[1])],
            ("corgi",):[(group[3],counts[0]),(group[3],counts[1])],
            ## filler
            ("bulldog","dachshund","labrador","corgi"):[(filler,total-counts[0]),(filler,total-counts[1])],
        }
        ## TEST
        combinations['test_combinations'] = {
            ("bulldog",):[test[0], test[0]],
            ("dachshund",):[test[1], test[1]],
            ("labrador",):[test[2], test[2]],
            ("corgi",):[test[3], test[3]],
        }
        return combinations

    # Builds combination dictionary for m2m datasets
    def build_type2_combination(self,group,test):
        """Return the train/test combination dicts for a many-to-many benchmark.

        Classes are paired up and swap locations between the two environments
        (bulldog/dachshund over ``group[0:2]``, labrador/corgi over
        ``group[2:4]``); the test dict follows the same layout over ``test``.
        """
        total = 3168
        counts = [total,total]
        combinations = {}
        combinations['train_combinations'] = {
            ## correlated class
            ("bulldog",):[(group[0],counts[0]),(group[1],counts[1])],
            ("dachshund",):[(group[1],counts[0]),(group[0],counts[1])],
            ("labrador",):[(group[2],counts[0]),(group[3],counts[1])],
            ("corgi",):[(group[3],counts[0]),(group[2],counts[1])],
        }
        combinations['test_combinations'] = {
            ("bulldog",):[test[0], test[1]],
            ("dachshund",):[test[1], test[0]],
            ("labrador",):[test[2], test[3]],
            ("corgi",):[test[3], test[2]],
        }
        return combinations

# SPAWRIOUS #############################################################

def download_spawrious(data_dir, remove=True):
    """Download the Spawrious archive into ``data_dir`` and extract it there.

    Args:
        data_dir: Existing directory that receives ``spawrious.tar.gz`` and
            the extracted contents.
        remove: When True, delete the archive after extraction.
    """
    # `import urllib` alone does not guarantee the `request` submodule is
    # loaded, so import it explicitly where it is used.
    import urllib.request

    dst = os.path.join(data_dir, "spawrious.tar.gz")
    urllib.request.urlretrieve('https://www.dropbox.com/s/e40j553480h3f3s/spawrious224.tar.gz?dl=1', dst)
    # The context manager closes the archive even if extraction fails.
    # NOTE(review): extractall trusts the member paths of a downloaded archive;
    # consider an extraction filter if the source is not fully trusted.
    with tarfile.open(dst, "r:gz") as tar:
        tar.extractall(os.path.dirname(dst))
    if remove:
        os.remove(dst)

## Spawrious classes for each Spawrious dataset 
class SpawriousO2O_easy(SpawriousBenchmark):
    """One-to-one Spawrious benchmark, 'easy' location assignment."""

    def __init__(self, root_dir, test_envs, hparams, val_frac=0.2, seed=42):
        """Seed the RNGs, build the type-1 combinations and initialise the benchmark.

        ``test_envs`` is accepted but not used; ``hparams['data_augmentation']``
        is forwarded as the ``augment`` flag.
        """
        set_seed(seed)
        combos = self.build_type1_combination(
            group=["desert", "jungle", "dirt", "snow"],
            test=["dirt", "snow", "desert", "jungle"],
            filler="beach",
        )
        super().__init__(
            train_combinations=combos['train_combinations'],
            test_combinations=combos['test_combinations'],
            root_dir=root_dir,
            augment=hparams['data_augmentation'],
            val_frac=val_frac,
            type1=True,
        )

class SpawriousO2O_medium(SpawriousBenchmark):
    """One-to-one Spawrious benchmark, 'medium' location assignment."""

    def __init__(self, root_dir, test_envs, hparams, val_frac=0.2, seed=42):
        """Seed the RNGs, build the type-1 combinations and initialise the benchmark.

        ``test_envs`` is accepted but not used; ``hparams['data_augmentation']``
        is forwarded as the ``augment`` flag.
        """
        set_seed(seed)
        combos = self.build_type1_combination(
            group=['mountain', 'beach', 'dirt', 'jungle'],
            test=['jungle', 'dirt', 'beach', 'snow'],
            filler="desert",
        )
        super().__init__(
            train_combinations=combos['train_combinations'],
            test_combinations=combos['test_combinations'],
            root_dir=root_dir,
            augment=hparams['data_augmentation'],
            val_frac=val_frac,
            type1=True,
        )

class SpawriousO2O_hard(SpawriousBenchmark):
    """One-to-one Spawrious benchmark, 'hard' location assignment."""

    def __init__(self, root_dir, test_envs, hparams, val_frac=0.2, seed=42):
        """Seed the RNGs, build the type-1 combinations and initialise the benchmark.

        ``test_envs`` is accepted but not used; ``hparams['data_augmentation']``
        is forwarded as the ``augment`` flag.
        """
        set_seed(seed)
        combos = self.build_type1_combination(
            group=['jungle', 'mountain', 'snow', 'desert'],
            test=['mountain', 'snow', 'desert', 'jungle'],
            filler="beach",
        )
        super().__init__(
            train_combinations=combos['train_combinations'],
            test_combinations=combos['test_combinations'],
            root_dir=root_dir,
            augment=hparams['data_augmentation'],
            val_frac=val_frac,
            type1=True,
        )

class SpawriousM2M_easy(SpawriousBenchmark):
    """Many-to-many Spawrious benchmark, 'easy' location assignment."""

    def __init__(self, root_dir, test_envs, hparams, val_frac=0.2, seed=42):
        """Seed the RNGs, build the type-2 combinations and initialise the benchmark.

        ``test_envs`` is accepted but not used; ``hparams['data_augmentation']``
        is forwarded as the ``augment`` flag.
        """
        set_seed(seed)
        combos = self.build_type2_combination(
            group=['desert', 'mountain', 'dirt', 'jungle'],
            test=['dirt', 'jungle', 'mountain', 'desert'],
        )
        super().__init__(
            train_combinations=combos['train_combinations'],
            test_combinations=combos['test_combinations'],
            root_dir=root_dir,
            augment=hparams['data_augmentation'],
            val_frac=val_frac,
        )

class SpawriousM2M_medium(SpawriousBenchmark):
    """Many-to-many Spawrious benchmark, 'medium' location assignment."""

    def __init__(self, root_dir, test_envs, hparams, val_frac=0.2, seed=42):
        """Seed the RNGs, build the type-2 combinations and initialise the benchmark.

        ``test_envs`` is accepted but not used; ``hparams['data_augmentation']``
        is forwarded as the ``augment`` flag.
        """
        set_seed(seed)
        combos = self.build_type2_combination(
            group=['beach', 'snow', 'mountain', 'desert'],
            test=['desert', 'mountain', 'beach', 'snow'],
        )
        super().__init__(
            train_combinations=combos['train_combinations'],
            test_combinations=combos['test_combinations'],
            root_dir=root_dir,
            augment=hparams['data_augmentation'],
            val_frac=val_frac,
        )
        
class SpawriousM2M_hard(SpawriousBenchmark):
    """Many-to-many Spawrious benchmark, 'hard' location assignment."""

    ENVIRONMENTS = ["Test", "SC_group_1", "SC_group_2"]

    def __init__(self, root_dir, test_envs, hparams, val_frac=0.2, seed=42):
        """Seed the RNGs, build the type-2 combinations and initialise the benchmark.

        ``test_envs`` is accepted but not used; ``hparams['data_augmentation']``
        is forwarded as the ``augment`` flag.
        """
        set_seed(seed)
        combos = self.build_type2_combination(
            group=["dirt", "jungle", "snow", "beach"],
            test=["snow", "beach", "dirt", "jungle"],
        )
        super().__init__(
            train_combinations=combos['train_combinations'],
            test_combinations=combos['test_combinations'],
            root_dir=root_dir,
            augment=hparams['data_augmentation'],
            val_frac=val_frac,
        )