# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os, pickle
from io import BytesIO

import torch
from torchvision import transforms
import torchvision.datasets.folder
from torch.utils.data import TensorDataset, Subset
from torchvision.datasets import MNIST, ImageFolder
from torchvision.transforms.functional import rotate
import torch.distributions as tdist

import numpy as np
import random
import scipy.stats as stats
import pandas as pd
from PIL import Image, ImageFile

from collections import Counter

from small_norb.smallnorb.dataset import SmallNORBDataset

# Let PIL load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Names of the datasets advertised by this module.
# NOTE(review): presumably each entry is resolved to a class of the same name
# through get_dataset_class(); confirm every name listed has a matching class.
DATASETS = [
    # Debug
    # "Debug28",
    # "Debug224",
    # Small images
    # MNIST
    "MNIST_Acause",
    "MNIST_Aind",
    "MNIST_AcauseUAind",
    # Big images
    # small NORB
    "SmallNORB",
    "SmallNORB_Acause",
    "SmallNORB_Aind",
    "SmallNORB_AcauseUAind",
    # Waterbirds
    "Waterbirds_Acause",
    "Waterbirds_Multiattr",
   
]

def get_dataset_class(dataset_name):
    """Look up ``dataset_name`` among this module's global names.

    Returns whatever object is bound to that name; raises
    NotImplementedError when the module defines no such global.
    """
    module_names = globals()
    if dataset_name in module_names:
        return module_names[dataset_name]
    raise NotImplementedError("Dataset not found: {}".format(dataset_name))


def num_environments(dataset_name):
    """Return the length of the named dataset class's ENVIRONMENTS list."""
    dataset_cls = get_dataset_class(dataset_name)
    return len(dataset_cls.ENVIRONMENTS)


class MultipleDomainDataset:
    """Base container exposing a list of per-environment datasets.

    Subclasses are expected to populate ``self.datasets``; indexing and
    ``len()`` are forwarded to that list.
    """
    N_STEPS = 5001           # Default, subclasses may override
    CHECKPOINT_FREQ = 100    # Default, subclasses may override
    N_WORKERS = 8            # Default, subclasses may override
    ENVIRONMENTS = None      # Subclasses should override
    INPUT_SHAPE = None       # Subclasses should override

    def __getitem__(self, index):
        """Return the dataset stored at position ``index``."""
        return self.datasets[index]

    def __len__(self):
        """Return the number of stored datasets."""
        return len(self.datasets)

    def get_transform(self, input_size, normalize, scheme):
        """Build the training augmentation pipeline for ``scheme``.

        Every scheme is: random resized crop -> horizontal flip ->
        (optional colour jitter + random grayscale) -> ToTensor -> normalize.
        The schemes differ only in the crop scale range and jitter strength.

        :param input_size: size passed to RandomResizedCrop
        :param normalize: final transform appended after ToTensor
        :param scheme: one of 'domainbed', 'jigen', 'decaug_nico',
                       'jigen_wo_color_aug'
        :raises NotImplementedError: for any other scheme
        """
        # (crop scale range, colour-jitter strength); None means "not used",
        # i.e. the library default crop scale / no colour augmentation.
        if scheme == 'domainbed':
            crop_scale, jitter = (0.7, 1.0), 0.3
        elif scheme == 'jigen':
            crop_scale, jitter = (0.8, 1.0), 0.4
        elif scheme == 'decaug_nico':
            crop_scale, jitter = None, None
        elif scheme == 'jigen_wo_color_aug':
            crop_scale, jitter = (0.8, 1.0), None
        else:
            raise NotImplementedError

        if crop_scale is None:
            crop = transforms.RandomResizedCrop(input_size)
        else:
            crop = transforms.RandomResizedCrop(input_size, scale=crop_scale)

        steps = [crop, transforms.RandomHorizontalFlip()]
        if jitter is not None:
            steps.append(transforms.ColorJitter(jitter, jitter, jitter, jitter))
            steps.append(transforms.RandomGrayscale())
        steps.append(transforms.ToTensor())
        steps.append(normalize)

        return transforms.Compose(steps)


class Debug(MultipleDomainDataset):
    """Synthetic dataset: three environments of 16 random samples each.

    Inputs are standard-normal tensors of shape ``INPUT_SHAPE`` (supplied by
    subclasses) and targets are random integers in ``[0, num_classes)``.
    The ``root``, ``test_envs`` and ``hparams`` arguments are not used.
    """
    def __init__(self, root, test_envs, hparams):
        super().__init__()
        self.input_shape = self.INPUT_SHAPE
        self.num_classes = 2
        samples_per_env = 16
        self.datasets = []
        for _env in range(3):
            inputs = torch.randn(samples_per_env, *self.INPUT_SHAPE)
            targets = torch.randint(0, self.num_classes, (samples_per_env,))
            self.datasets.append(TensorDataset(inputs, targets))

class Debug28(Debug):
    """Debug dataset with 3x28x28 inputs and three environments."""
    ENVIRONMENTS = ['0', '1', '2']
    INPUT_SHAPE = (3, 28, 28)

class Debug224(Debug):
    """Debug dataset with 3x224x224 inputs and three environments."""
    ENVIRONMENTS = ['0', '1', '2']
    INPUT_SHAPE = (3, 224, 224)


class MultipleEnvironmentMNIST(MultipleDomainDataset):
    """MNIST (train + test) shuffled and dealt round-robin into environments.

    :param root: MNIST data directory (downloaded there if missing)
    :param environments: sequence of environment descriptors; one dataset is
                         built per entry
    :param dataset_transform: callable ``(images, labels, environment)``
                              returning the dataset for that environment
    :param input_shape: stored as ``self.input_shape``
    :param num_classes: stored as ``self.num_classes``
    :raises ValueError: if ``root`` is None
    """
    def __init__(self, root, environments, dataset_transform, input_shape,
                 num_classes):
        super().__init__()
        if root is None:
            raise ValueError('Data directory not specified!')

        train_split = MNIST(root, train=True, download=True)
        test_split = MNIST(root, train=False, download=True)

        # Pool both official splits, then apply one shared random permutation
        # so images and labels stay aligned.
        pooled_images = torch.cat((train_split.data, test_split.data))
        pooled_labels = torch.cat((train_split.targets, test_split.targets))
        order = torch.randperm(len(pooled_images))
        pooled_images = pooled_images[order]
        pooled_labels = pooled_labels[order]

        # Environment k receives every n_envs-th sample starting at offset k.
        n_envs = len(environments)
        self.datasets = [
            dataset_transform(pooled_images[k::n_envs],
                              pooled_labels[k::n_envs],
                              env)
            for k, env in enumerate(environments)
        ]

        self.input_shape = input_shape
        self.num_classes = num_classes


# single-attribute Causal
class MNIST_Acause(MultipleDomainDataset):
    """Two-channel colored MNIST with a color attribute tied to the label.

    The MNIST training images are shuffled; the first 50000 are split
    alternately between environments 0 and 1, the rest form environment 2.
    In each environment the color agrees with the (noisy) binary label
    except with probability 0.1, 0.2 and 0.9 respectively.

    Each environment is a TensorDataset of ``(x, y, colors, colors)``.
    """
    N_STEPS = 5001
    CHECKPOINT_FREQ = 500
    ENVIRONMENTS = ['+90%', '+80%', '-90%']
    INPUT_SHAPE = (2, 14, 14)

    def __init__(self, root, test_envs, hparams):
        """
        :param root: MNIST data directory (downloaded there if missing)
        :param test_envs: unused here
        :param hparams: unused here
        :raises ValueError: if ``root`` is None
        """
        super().__init__()
        if root is None:
            raise ValueError('Data directory not specified!')

        original_dataset_tr = MNIST(root, train=True, download=True)

        # Use `.data` / `.targets` (as MultipleEnvironmentMNIST does) rather
        # than the deprecated `train_data` / `train_labels` aliases.
        original_images = original_dataset_tr.data
        original_labels = original_dataset_tr.targets

        shuffle = torch.randperm(len(original_images))
        original_images = original_images[shuffle]
        original_labels = original_labels[shuffle]

        self.datasets = []

        # Probability that the color disagrees with the label, per environment.
        environments = (0.1, 0.2, 0.9)
        for i, env in enumerate(environments[:-1]):
            # Even positions go to environment 0, odd positions to environment 1.
            images = original_images[:50000][i::2]
            labels = original_labels[:50000][i::2]
            self.datasets.append(self.color_dataset(images, labels, env))
        images = original_images[50000:]
        labels = original_labels[50000:]
        self.datasets.append(self.color_dataset(images, labels, environments[-1]))

        self.input_shape = self.INPUT_SHAPE
        self.num_classes = 2

    def color_dataset(self, images, labels, environment):
        """Build one colored environment.

        :param images: tensor reshapeable to (N, 28, 28) with values in 0..255
        :param labels: digit labels, length N
        :param environment: probability of flipping the color away from the label
        :return: TensorDataset(x, y, colors, colors) where x is (N, 2, 14, 14)
                 float in [0, 1], y is the int64 binary label and colors is
                 the float 0/1 color attribute
        """
        # Subsample 2x for computational convenience
        images = images.reshape((-1, 28, 28))[:, ::2, ::2]
        # Assign a binary label based on the digit
        labels = (labels < 5).float()
        # Flip label with probability 0.25
        labels = self.torch_xor_(labels,
                                 self.torch_bernoulli_(0.25, len(labels)))

        # Assign a color based on the label; flip the color with probability e
        colors = self.torch_xor_(labels,
                                 self.torch_bernoulli_(environment,
                                                       len(labels)))
        images = torch.stack([images, images], dim=1)
        # Apply the color to the image by zeroing out the other color channel
        sample_idx = torch.arange(len(images))
        images[sample_idx, (1 - colors).long(), :, :] *= 0

        x = images.float().div_(255.0)
        y = labels.view(-1).long()

        # The color attribute is stored twice (third and fourth tensors).
        return TensorDataset(x, y, colors, colors)

    def torch_bernoulli_(self, p, size):
        """Return ``size`` float samples that are 1.0 with probability ``p``."""
        return (torch.rand(size) < p).float()

    def torch_xor_(self, a, b):
        """Element-wise XOR of two 0/1 float tensors."""
        return (a - b).abs()

# single-attribute Independent
class MNIST_Aind(MultipleDomainDataset):
    """Single-channel rotated MNIST; each environment has its own angle.

    The MNIST training images are shuffled; the first 50000 are split
    alternately between environments 0 and 1, the rest form environment 2.
    The environments are rotated by 15, 60 and 90 degrees respectively.

    Each environment is a TensorDataset of ``(x, y, y, y)``.
    """
    N_STEPS = 5001
    CHECKPOINT_FREQ = 500
    ENVIRONMENTS = ['+90%', '+80%', '-90%']
    INPUT_SHAPE = (1, 14, 14) #  (3, 14, 14)

    def __init__(self, root, test_envs, hparams):
        """
        :param root: MNIST data directory (downloaded there if missing)
        :param test_envs: unused here
        :param hparams: unused here
        :raises ValueError: if ``root`` is None
        """
        super().__init__()
        if root is None:
            raise ValueError('Data directory not specified!')

        original_dataset_tr = MNIST(root, train=True, download=True)

        # Use `.data` / `.targets` (as MultipleEnvironmentMNIST does) rather
        # than the deprecated `train_data` / `train_labels` aliases.
        original_images = original_dataset_tr.data
        original_labels = original_dataset_tr.targets

        shuffle = torch.randperm(len(original_images))
        original_images = original_images[shuffle]
        original_labels = original_labels[shuffle]

        self.datasets = []

        angles = ['15', '60', '90']
        # Iterate over the angles themselves: the previous loop referenced an
        # undefined `environments` name and raised NameError.
        for i, angle in enumerate(angles[:-1]):
            # Even positions go to environment 0, odd positions to environment 1.
            images = original_images[:50000][i::2]
            labels = original_labels[:50000][i::2]
            self.datasets.append(self.rotate_dataset(images, labels, angle))
        images = original_images[50000:]
        labels = original_labels[50000:]
        self.datasets.append(self.rotate_dataset(images, labels, angles[-1]))

        self.input_shape = self.INPUT_SHAPE
        self.num_classes = 2

    def rotate_dataset(self, images, labels, angle):
        """Build one rotated environment.

        :param images: tensor reshapeable to (N, 28, 28) with values in 0..255
        :param labels: digit labels, length N
        :param angle: rotation in degrees (anything ``int()`` accepts)
        :return: TensorDataset(x, y, y, y) where x is (N, 1, 14, 14) float and
                 y is the int64 noisy binary label
        """
        # NOTE(review): newer torchvision releases replace the `resample`
        # keyword of `rotate` with `interpolation`; confirm against the
        # torchvision version this project pins.
        rotation = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Lambda(lambda x: rotate(x, int(angle), fill=(0,),
                                               resample=Image.BICUBIC)),
            transforms.ToTensor()])

        # Subsample 2x for computational convenience
        images = images.reshape((-1, 28, 28))[:, ::2, ::2]
        # Assign a binary label based on the digit
        labels = (labels < 5).float()
        # Flip label with probability 0.25
        labels = self.torch_xor_(labels,
                                 self.torch_bernoulli_(0.25, len(labels)))

        x = torch.zeros(len(images), 1, 14, 14)
        for i in range(len(images)):
            x[i] = rotation(images[i].float().div_(255.0))

        y = labels.view(-1).long()

        # The label is repeated in the two attribute slots.
        return TensorDataset(x, y, y, y)

    def torch_bernoulli_(self, p, size):
        """Return ``size`` float samples that are 1.0 with probability ``p``."""
        return (torch.rand(size) < p).float()

    def torch_xor_(self, a, b):
        """Element-wise XOR of two 0/1 float tensors."""
        return (a - b).abs()

# multi-attribute Causal + Independent
class MNIST_AcauseUAind(MultipleDomainDataset):
    """Two-channel MNIST combining a color attribute and a per-environment rotation.

    The MNIST training images are shuffled; the first 50000 are split
    alternately between environments 0 and 1, the rest form environment 2.
    Environment k uses color-flip probability (0.1, 0.2, 0.9)[k] and rotation
    angle (15, 60, 90)[k] degrees.

    Each environment is a TensorDataset of ``(x, y, colors, colors)``.
    """
    N_STEPS = 5001
    CHECKPOINT_FREQ = 500
    ENVIRONMENTS = ['+90%', '+80%', '-90%']
    INPUT_SHAPE = (2, 14, 14)

    def __init__(self, root, test_envs, hparams):
        """
        :param root: MNIST data directory (downloaded there if missing)
        :param test_envs: unused here
        :param hparams: unused here
        :raises ValueError: if ``root`` is None
        """
        super().__init__()
        if root is None:
            raise ValueError('Data directory not specified!')

        original_dataset_tr = MNIST(root, train=True, download=True)

        # Use `.data` / `.targets` (as MultipleEnvironmentMNIST does) rather
        # than the deprecated `train_data` / `train_labels` aliases.
        original_images = original_dataset_tr.data
        original_labels = original_dataset_tr.targets

        shuffle = torch.randperm(len(original_images))
        original_images = original_images[shuffle]
        original_labels = original_labels[shuffle]

        self.datasets = []

        # Probability that the color disagrees with the label, per environment.
        environments = (0.1, 0.2, 0.9)
        angles = ['15', '60', '90']
        for i, env in enumerate(environments[:-1]):
            # Even positions go to environment 0, odd positions to environment 1.
            images = original_images[:50000][i::2]
            labels = original_labels[:50000][i::2]
            self.datasets.append(self.color_dataset(images, labels, env, angles[i]))
        images = original_images[50000:]
        labels = original_labels[50000:]
        self.datasets.append(self.color_dataset(images, labels, environments[-1], angles[-1]))

        self.input_shape = self.INPUT_SHAPE
        self.num_classes = 2

    def color_dataset(self, images, labels, environment, angle):
        """Build one rotated and colored environment.

        :param images: tensor reshapeable to (N, 28, 28) with values in 0..255
        :param labels: digit labels, length N
        :param environment: probability of flipping the color away from the label
        :param angle: rotation in degrees (anything ``int()`` accepts)
        :return: TensorDataset(x, y, colors, colors) where x is (N, 2, 14, 14)
                 float, y is the int64 binary label and colors is the float
                 0/1 color attribute
        """
        # Subsample 2x for computational convenience
        images = images.reshape((-1, 28, 28))[:, ::2, ::2]
        # rotate the image by angle in parameter
        images = self.rotate_dataset(images, angle)
        # Assign a binary label based on the digit
        labels = (labels < 5).float()
        # Flip label with probability 0.25
        labels = self.torch_xor_(labels,
                                 self.torch_bernoulli_(0.25, len(labels)))

        # Assign a color based on the label; flip the color with probability e
        colors = self.torch_xor_(labels,
                                 self.torch_bernoulli_(environment,
                                                       len(labels)))
        images = torch.stack([images, images], dim=1)
        # Apply the color to the image by zeroing out the other color channel
        sample_idx = torch.arange(len(images))
        images[sample_idx, (1 - colors).long(), :, :] *= 0

        # rotate_dataset already produced floats via ToTensor, so no further
        # division by 255 is applied here.
        x = images
        y = labels.view(-1).long()

        # The color attribute is stored twice (third and fourth tensors).
        return TensorDataset(x, y, colors, colors)

    def rotate_dataset(self, images, angle):
        """Rotate every image by ``angle`` degrees.

        :param images: (N, 14, 14) tensor with values in 0..255
        :param angle: rotation in degrees (anything ``int()`` accepts)
        :return: (N, 14, 14) float tensor of rotated images
        """
        rotation = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Lambda(lambda x: transforms.functional.rotate(x, int(angle), fill=(0,))),
            transforms.ToTensor()])

        x = torch.zeros(len(images), 14, 14)
        for i in range(len(images)):
            x[i] = rotation(images[i].float().div_(255.0))
        return x

    def torch_bernoulli_(self, p, size):
        """Return ``size`` float samples that are 1.0 with probability ``p``."""
        return (torch.rand(size) < p).float()

    def torch_xor_(self, a, b):
        """Element-wise XOR of two 0/1 float tensors."""
        return (a - b).abs()

################ SmallNORB: corr_lighting ########################
class SmallNORB_corr_lighting(torch.utils.data.Dataset):
    """SmallNORB subset in which category is correlated with lighting.

    For the 'train' split, examples whose category equals their lighting are
    kept up to 99% of the reference count per category and all other examples
    up to 1%. For any other split the roles are swapped: mismatching examples
    get the 99% cap and matching ones the 1% cap.
    """
    def __init__(self, dataset, split, category_feat_counts):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: nested counts; only
                                     ``[0]['lighting'][0]`` is read, as the
                                     reference count for the caps
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        reference_count = category_feat_counts[0]['lighting'][0]
        minority_cap = reference_count * 0.01
        majority_cap = reference_count * 0.99
        print('uncorr count max', minority_cap)
        print('corr count max', majority_cap)

        # Per-category tallies for the majority and minority groups.
        majority_seen = dict.fromkeys(range(5), 0)
        minority_seen = dict.fromkeys(range(5), 0)
        training = self.split == 'train'

        self.dataset = []
        for example in self.orig_dataset.data[self.split]:
            category = example.category
            if training:
                in_majority = category == example.lighting
            else:
                in_majority = category != example.lighting

            if in_majority:
                if majority_seen[category] < majority_cap:
                    self.dataset.append(example)
                    majority_seen[category] += 1
            elif minority_seen[category] < minority_cap:
                self.dataset.append(example)
                minority_seen[category] += 1

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: result of ``self.transform`` on the left image
                 label: 0-dim int64 tensor holding the category
        """
        example = self.dataset[index]
        raw_image, category = example.image_lt, example.category

        image = self.transform(raw_image)
        label = torch.from_numpy(np.array(category)).to(torch.int64)
        return image, label

    def __len__(self):
        """Number of selected examples."""
        return len(self.dataset)

class SmallNORB_spurious_lighting2(torch.utils.data.Dataset):
    """One of two SmallNORB domains with a category/lighting correlation.

    Domain 0 draws from positions ``<= total/2`` of the split and domain 1
    from positions ``> total/2``. Per category, a random sample of examples
    whose category differs from their lighting is added first, followed by
    matching examples (in split order) up to the correlated cap.
    """
    def __init__(self, dataset, split, category_feat_counts, correlation, domain_idx):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: nested counts; only
                                     ``[0]['lighting'][0]`` is read
        :param correlation: fraction of the per-domain budget given to
                            examples whose category equals their lighting
        :param domain_idx: 0 or 1, selects which half of the split is used
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        # Half of the reference count is the per-category budget of one domain.
        per_domain_budget = (category_feat_counts[0]['lighting'][0]) / 2
        uncorr_count_max = per_domain_budget * (1 - correlation)
        corr_count_max = per_domain_budget * correlation
        print('corr count max', corr_count_max)
        print('uncorr count max', uncorr_count_max)

        examples = self.orig_dataset.data[self.split]
        total_examples = sum(1 for _ in examples)
        print('Total examples', total_examples)
        midpoint = total_examples / 2

        def belongs_to_domain(position):
            # Domain 0 owns the first half (inclusive midpoint), domain 1 the rest.
            if domain_idx == 0:
                return position <= midpoint
            if domain_idx == 1:
                return position > midpoint
            return False

        self.dataset = []
        group_uncorr_counts = dict.fromkeys(range(5), 0)
        group_corr_counts = dict.fromkeys(range(5), 0)

        # Candidate positions of mismatching examples, per category.
        uncorr_indices = {
            cat: [pos for pos, ex in enumerate(examples)
                  if ex.category == cat
                  and ex.category != ex.lighting
                  and belongs_to_domain(pos)]
            for cat in range(5)
        }

        # randomly choose uncorr_count_max indices for choosing uncorrelated examples for all categories
        chosen_uncorr_indices = {
            cat: random.sample(uncorr_indices[cat], int(uncorr_count_max))
            for cat in range(5)
        }

        uncorr_counts_total = 0
        for cat in range(5):
            for picked in chosen_uncorr_indices[cat]:
                self.dataset.append(examples[picked])
                group_uncorr_counts[cat] += 1
                uncorr_counts_total += 1

        corr_counts_total = 0
        for pos, ex in enumerate(examples):
            if (ex.category == ex.lighting
                    and group_corr_counts[ex.category] < corr_count_max
                    and belongs_to_domain(pos)):
                self.dataset.append(ex)
                group_corr_counts[ex.category] += 1
                corr_counts_total += 1

        print('Total corr samples = ', corr_counts_total, 'Total uncorr samples =', uncorr_counts_total)
        print('group uncorr counts', group_uncorr_counts)
        print('group corr counts', group_corr_counts)

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: result of ``self.transform`` on the left image
                 label: 0-dim int64 tensor holding the category
                 lighting: the example's lighting attribute
        """
        example = self.dataset[index]
        raw_image, category = example.image_lt, example.category

        image = self.transform(raw_image)
        label = torch.from_numpy(np.array(category)).to(torch.int64)

        # Attribute labels; azimuth is read but not part of the returned tuple.
        azimuth = example.azimuth
        lighting = example.lighting
        return image, label, lighting

    def __len__(self):
        """Number of selected examples."""
        return len(self.dataset)

from collections import defaultdict

class SmallNORB_corr_lighting_noise(torch.utils.data.Dataset):
    """One of two SmallNORB domains with label noise and a lighting correlation.

    ``list_indices`` partitions the split: domain 0 uses the positions listed
    in its first ``int(total/2)`` entries, domain 1 the positions listed in
    the remainder. Within the domain, 5% of each category's examples get
    their category replaced by a different random one, then a category /
    lighting correlated subset is selected.

    NOTE: the label noise is written into ``dataset.data[split]`` in place, so
    it is visible to every other object sharing the same SmallNORB dataset.
    """
    def __init__(self, dataset, split, category_feat_counts, correlation, domain_idx, list_indices):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: nested counts; only
                                     ``[0]['lighting'][0]`` is read
        :param correlation: fraction of the per-domain budget given to
                            examples whose category equals their lighting
        :param domain_idx: 0 or 1, selects which part of ``list_indices`` is used
        :param list_indices: sliceable sequence of integer positions into the split
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        # Half of the reference count is the per-category budget of one domain.
        per_domain_budget = (category_feat_counts[0]['lighting'][0]) / 2
        uncorr_count_max = per_domain_budget * (1 - correlation)
        corr_count_max = per_domain_budget * correlation

        examples = self.orig_dataset.data[self.split]
        total_examples = sum(1 for _ in examples)
        half = int(total_examples / 2)

        # Build the domain membership once as a set. Testing
        # `j in list_indices[...]` per example re-sliced and linearly scanned
        # the list every time, which is quadratic in the split size.
        if domain_idx == 0:
            domain_members = {int(k) for k in list_indices[0:half]}
        elif domain_idx == 1:
            domain_members = {int(k) for k in list_indices[half:]}
        else:
            domain_members = set()

        # add noise to labels
        cat_indices = {
            cat: [pos for pos, ex in enumerate(examples)
                  if ex.category == cat and pos in domain_members]
            for cat in range(5)
        }
        # The noise budget for every category is 5% of category 0's size.
        count_noise = len(cat_indices[0]) * 0.05
        chosen_noise_indices = {
            cat: random.sample(cat_indices[cat], int(count_noise))
            for cat in range(5)
        }
        for cat in range(5):
            other_categories = [c for c in range(0, 5) if c != cat]
            for picked in chosen_noise_indices[cat]:
                # In-place relabel on the shared dataset object.
                examples[picked].category = random.sample(other_categories, 1)[0]

        group_uncorr_counts = dict.fromkeys(range(5), 0)
        group_corr_counts = dict.fromkeys(range(5), 0)
        azimuth_counts = dict.fromkeys(range(0, 36, 2), 0)
        self.dataset = []

        # Candidate positions of mismatching examples (after noise), per category.
        uncorr_indices = {
            cat: [pos for pos, ex in enumerate(examples)
                  if ex.category == cat
                  and ex.category != ex.lighting
                  and pos in domain_members]
            for cat in range(5)
        }

        # randomly choose uncorr_count_max indices for choosing uncorrelated examples for all categories
        chosen_uncorr_indices = {
            cat: random.sample(uncorr_indices[cat], int(uncorr_count_max))
            for cat in range(5)
        }

        uncorr_counts_total, corr_counts_total = 0, 0
        for cat in range(5):
            for picked in chosen_uncorr_indices[cat]:
                chosen = examples[picked]
                self.dataset.append(chosen)
                uncorr_counts_total += 1
                group_uncorr_counts[cat] += 1
                azimuth_counts[chosen.azimuth] += 1

        for pos, ex in enumerate(examples):
            if (ex.category == ex.lighting
                    and group_corr_counts[ex.category] < corr_count_max
                    and pos in domain_members):
                self.dataset.append(ex)
                group_corr_counts[ex.category] += 1
                corr_counts_total += 1
                # Count the azimuth only for examples actually added; it was
                # previously also counted for examples outside this domain.
                azimuth_counts[ex.azimuth] += 1

        print('Total corr samples = ', corr_counts_total, 'Total uncorr samples =', uncorr_counts_total)
        print('group uncorr counts', group_uncorr_counts)
        print('group corr counts', group_corr_counts)
        print('Azimuth counts', azimuth_counts)

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: transformed left and right images concatenated on dim 0
                 label: 0-dim int64 tensor holding the category
                 lighting: the example's lighting attribute
                 azimuth: the example's azimuth attribute
        """
        example = self.dataset[index]

        left = self.transform(example.image_lt)
        label = torch.from_numpy(np.array(example.category)).to(torch.int64)

        right = self.transform(example.image_rt)
        image = torch.cat((left, right), 0)

        # return attribute labels
        azimuth = example.azimuth
        lighting = example.lighting

        return image, label, lighting, azimuth

    def __len__(self):
        """Number of selected examples."""
        return len(self.dataset)

class SmallNORB_corr_lighting_noise_ablation(torch.utils.data.Dataset):
    """Single-domain SmallNORB subset with label noise and a lighting correlation.

    5% of each category's examples get their category replaced by a different
    random one (written in place into ``dataset.data[split]``); afterwards a
    category / lighting correlated subset of the whole split is selected.
    """
    def __init__(self, dataset, split, category_feat_counts, correlation):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: nested counts; only
                                     ``[0]['lighting'][0]`` is read
        :param correlation: fraction of the budget given to examples whose
                            category equals their lighting
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        budget = (category_feat_counts[0]['lighting'][0])
        uncorr_count_max = budget * (1 - correlation)
        corr_count_max = budget * correlation

        examples = self.orig_dataset.data[self.split]

        # add noise to labels
        cat_indices = {
            cat: [pos for pos, ex in enumerate(examples) if ex.category == cat]
            for cat in range(5)
        }
        # The noise budget for every category is 5% of category 0's size.
        count_noise = len(cat_indices[0]) * 0.05
        chosen_noise_indices = {
            cat: random.sample(cat_indices[cat], int(count_noise))
            for cat in range(5)
        }
        for cat in range(5):
            other_categories = [c for c in range(0, 5) if c != cat]
            for picked in chosen_noise_indices[cat]:
                examples[picked].category = random.sample(other_categories, 1)[0]

        self.dataset = []

        # Candidate positions of mismatching examples (after noise), per category.
        uncorr_indices = {
            cat: [pos for pos, ex in enumerate(examples)
                  if ex.category == cat and ex.category != ex.lighting]
            for cat in range(5)
        }

        # randomly choose uncorr_count_max indices for choosing uncorrelated examples for all categories
        chosen_uncorr_indices = {
            cat: random.sample(uncorr_indices[cat], int(uncorr_count_max))
            for cat in range(5)
        }
        for cat in range(5):
            for picked in chosen_uncorr_indices[cat]:
                self.dataset.append(examples[picked])

        # Matching examples are taken in split order up to the per-category cap.
        group_corr_counts = dict.fromkeys(range(5), 0)
        for ex in examples:
            if (ex.category == ex.lighting
                    and group_corr_counts[ex.category] < corr_count_max):
                self.dataset.append(ex)
                group_corr_counts[ex.category] += 1

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: transformed left and right images concatenated on dim 0
                 label: 0-dim int64 tensor holding the category
                 lighting: the example's lighting attribute
                 azimuth: the example's azimuth attribute
        """
        example = self.dataset[index]
        raw_left, category = example.image_lt, example.category

        left = self.transform(raw_left)
        label = torch.from_numpy(np.array(category)).to(torch.int64)

        right = self.transform(example.image_rt)
        image = torch.cat((left, right), 0)

        # return attribute labels
        azimuth = example.azimuth
        lighting = example.lighting
        return image, label, lighting, azimuth

    def __len__(self):
        """Number of selected examples."""
        return len(self.dataset)

class SmallNORB_spurious_lighting_test(torch.utils.data.Dataset):
    """SmallNORB subset with a chosen category/lighting correlation.

    Optionally relabels 5% of each category's examples with a different random
    category first (written in place into ``dataset.data[split]``). Then, per
    category, a random sample of mismatching examples is added, followed by
    matching examples in split order up to the correlated cap.
    """
    def __init__(self, dataset, split, category_feat_counts, correlation, noise=False):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: nested counts; only
                                     ``[0]['lighting'][0]`` is read
        :param correlation: fraction of the budget given to examples whose
                            category equals their lighting
        :param noise: when truthy, apply the 5% label noise before selecting
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.noise = noise

        examples = self.orig_dataset.data[self.split]

        if self.noise:
            # add noise to labels
            cat_indices = {
                cat: [pos for pos, ex in enumerate(examples) if ex.category == cat]
                for cat in range(5)
            }
            # The noise budget for every category is 5% of category 0's size.
            count_noise = len(cat_indices[0]) * 0.05
            chosen_noise_indices = {
                cat: random.sample(cat_indices[cat], int(count_noise))
                for cat in range(5)
            }
            for cat in range(5):
                other_categories = [c for c in range(0, 5) if c != cat]
                for picked in chosen_noise_indices[cat]:
                    examples[picked].category = random.sample(other_categories, 1)[0]

        # calculate counts for `correlation` spurious correlation of category with lighting
        budget = (category_feat_counts[0]['lighting'][0])
        uncorr_count_max = budget * (1 - correlation)
        corr_count_max = budget * correlation
        print('test corr count max', corr_count_max)
        print('test uncorr count max', uncorr_count_max)

        total_examples = sum(1 for _ in examples)
        print('Total examples', total_examples)

        self.dataset = []
        group_uncorr_counts = dict.fromkeys(range(5), 0)
        group_corr_counts = dict.fromkeys(range(5), 0)

        # Candidate positions of mismatching examples, per category.
        uncorr_indices = {
            cat: [pos for pos, ex in enumerate(examples)
                  if ex.category == cat and ex.category != ex.lighting]
            for cat in range(5)
        }

        # randomly choose uncorr_count_max indices for choosing uncorrelated examples for all categories
        chosen_uncorr_indices = {
            cat: random.sample(uncorr_indices[cat], int(uncorr_count_max))
            for cat in range(5)
        }

        uncorr_counts_total = 0
        for cat in range(5):
            for picked in chosen_uncorr_indices[cat]:
                self.dataset.append(examples[picked])
                group_uncorr_counts[cat] += 1
                uncorr_counts_total += 1

        corr_counts_total = 0
        for ex in examples:
            if (ex.category == ex.lighting
                    and group_corr_counts[ex.category] < corr_count_max):
                self.dataset.append(ex)
                group_corr_counts[ex.category] += 1
                corr_counts_total += 1

        print('Total corr samples = ', corr_counts_total, 'Total uncorr samples =', uncorr_counts_total)
        print('group uncorr counts', group_uncorr_counts)
        print('group corr counts', group_corr_counts)

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: result of ``self.transform`` on the left image
                 label: 0-dim int64 tensor holding the category
                 lighting: the example's lighting attribute
        """
        example = self.dataset[index]
        raw_image, category = example.image_lt, example.category

        image = self.transform(raw_image)
        label = torch.from_numpy(np.array(category)).to(torch.int64)

        # Attribute labels; azimuth is read but not part of the returned tuple.
        azimuth = example.azimuth
        lighting = example.lighting
        return image, label, lighting

    def __len__(self):
        """Number of selected examples."""
        return len(self.dataset)

###########################################################################

################# SmallNORB: azimuth/lighting shifts ######################
class SmallNORB_lighting_shift(torch.utils.data.Dataset):
    """
    Lighting-shift view of one SmallNORB split.

    Keeps only the examples whose ``lighting`` attribute is <= 1.
    """

    def __init__(self, dataset, split):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        # lighting shift: retain the subset {lighting <= 1} of the split
        self.dataset = [
            example
            for example in self.orig_dataset.data[self.split]
            if example.lighting <= 1
        ]

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor produced by ``self.transform`` from the
                        ``image_lt`` view
                 label: int64 Tensor holding the example's category
        """
        example = self.dataset[index]
        left_view = example.image_lt
        category = example.category

        image = self.transform(left_view)
        label = torch.from_numpy(np.array(category)).to(torch.int64)
        return image, label

    def __len__(self):
        """Number of examples that passed the lighting filter."""
        return len(self.dataset)

class SmallNORB_azimuth_shift(torch.utils.data.Dataset):
    """
    Azimuth-shift view of one SmallNORB split.

    With ``multiple_domain`` left at ``False`` the view keeps every example
    whose ``azimuth`` is <= 6. Otherwise it keeps the examples whose
    ``azimuth`` is contained in ``azimuth_vals``.
    """

    def __init__(self, dataset, split, multiple_domain=False, azimuth_vals=None):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param multiple_domain: ``False`` selects the fixed ``azimuth <= 6``
                                subset; any other value selects by
                                ``azimuth_vals``
        :param azimuth_vals: container of azimuth values to keep; required
                             whenever ``multiple_domain`` is not ``False``
        :raises ValueError: if ``multiple_domain`` is enabled without
                            ``azimuth_vals``
        """
        # Fail early with a clear message instead of the opaque
        # "argument of type 'NoneType' is not iterable" raised mid-loop.
        if multiple_domain is not False and azimuth_vals is None:
            raise ValueError(
                "azimuth_vals must be provided when multiple_domain is enabled"
            )

        self.orig_dataset = dataset
        self.split = split
        self.multiple_domain = multiple_domain
        self.transform = transforms.Compose([transforms.ToTensor()])

        examples = self.orig_dataset.data[self.split]
        if self.multiple_domain is False:
            # single-domain case: {azimuth <= 6}
            self.dataset = [ex for ex in examples if ex.azimuth <= 6]
        else:
            self.dataset = [ex for ex in examples if ex.azimuth in azimuth_vals]

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor, the transformed ``image_lt`` and ``image_rt``
                        views concatenated along dim 0 (presumably the
                        channel axis -- confirm against the transform)
                 label: int64 Tensor holding the example's category
                 azimuth: the example's ``azimuth`` attribute, unconverted
        """
        small_norb_example = self.dataset[index]

        label = torch.from_numpy(np.array(small_norb_example.category)).to(torch.int64)

        # Stack the left and right views into a single tensor.
        image = torch.cat(
            (
                self.transform(small_norb_example.image_lt),
                self.transform(small_norb_example.image_rt),
            ),
            0,
        )

        return image, label, small_norb_example.azimuth

    def __len__(self):
        """Number of examples that passed the azimuth filter."""
        return len(self.dataset)

class SmallNORB_azimuth_shift_noise(torch.utils.data.Dataset):
    """
    One training domain of the azimuth-shift setup, with 5% label noise.

    The split is divided in two using ``list_indices``: indices found in its
    first half belong to domain 0, the remaining ones to domain 1. From the
    chosen half only examples whose ``azimuth`` is in ``azimuth_vals`` are
    kept, and 5% of the kept examples of every category get a different,
    randomly drawn category.
    """

    def __init__(self, dataset, split, azimuth_vals, domain_idx, list_indices):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param azimuth_vals: container of azimuth values to keep
        :param domain_idx: 0 or 1, which half of ``list_indices`` to use; any
                           other value yields an empty dataset
        :param list_indices: sequence of example indices into the split
                             (callers in this file pass a shuffled
                             ``range(len(split))``)
        """
        self.orig_dataset = dataset
        self.split = split
        self.domain_idx = domain_idx
        self.transform = transforms.Compose([transforms.ToTensor()])

        examples = self.orig_dataset.data[self.split]
        half = int(len(examples) / 2)

        # Build the membership set once; slicing and scanning the list on
        # every iteration made the selection loop quadratic.
        if domain_idx == 0:
            domain_members = set(list_indices[:half])
        elif domain_idx == 1:
            domain_members = set(list_indices[half:])
        else:
            domain_members = set()

        # cat_indices maps category -> positions inside self.dataset
        cat_indices = {cat: [] for cat in range(5)}
        self.dataset = []
        self.dataset_indices = []  # position in the original split, per kept example
        for i, small_norb_example in enumerate(examples):
            if small_norb_example.azimuth in azimuth_vals and i in domain_members:
                self.dataset.append(small_norb_example)
                cat_indices[small_norb_example.category].append(len(self.dataset) - 1)
                self.dataset_indices.append(i)

        # add noise to labels: pick 5% of each category ...
        chosen_noise_indices = {}
        for cat in range(5):
            count_noise = len(cat_indices[cat]) * 0.05
            chosen_noise_indices[cat] = random.sample(cat_indices[cat], int(count_noise))

        # ... and relabel each pick with one of the four other categories.
        # NOTE(review): this writes to the example objects themselves, which
        # are the same objects held by ``dataset`` -- the noise is therefore
        # visible to anything else reading that split.
        for cat in range(5):
            indices_shuffle_choice = [other for other in range(5) if other != cat]
            for chosen_idx in chosen_noise_indices[cat]:
                self.dataset[chosen_idx].category = random.sample(indices_shuffle_choice, 1)[0]

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor, the transformed ``image_lt`` and ``image_rt``
                        views concatenated along dim 0
                 label: int64 Tensor holding the (possibly noised) category
                 lighting: the example's ``lighting`` attribute
                 azimuth: the example's ``azimuth`` attribute
        """
        small_norb_example = self.dataset[index]

        label = torch.from_numpy(np.array(small_norb_example.category)).to(torch.int64)

        image = torch.cat(
            (
                self.transform(small_norb_example.image_lt),
                self.transform(small_norb_example.image_rt),
            ),
            0,
        )

        return image, label, small_norb_example.lighting, small_norb_example.azimuth

    def __len__(self):
        """Number of examples selected for this domain."""
        return len(self.dataset)

################# SmallNORB: combination of shifts ###############################
class SmallNORB_corr_lighting_plus_azimuth_shift(torch.utils.data.Dataset):
    """
    Category/lighting correlation combined with an azimuth-defined domain.

    Examples with ``category == lighting`` are admitted up to a per-category
    cap of 99% of ``category_feat_counts[0]['lighting'][0]``; all other
    examples up to a cap of 1% of that number. Only azimuths belonging to the
    requested domain are admitted (domain 0: ``azimuth <= 4``, domain 1:
    ``4 < azimuth <= 10``).
    """

    def __init__(self, dataset, split, category_feat_counts, domain_idx=0):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: nested counts indexed as
                                     ``[category][feature][value]``
        :param domain_idx: 0 or 1, selects the azimuth range; other values
                           admit nothing
        """
        self.orig_dataset = dataset
        self.split = split
        self.domain_idx = domain_idx
        self.transform = transforms.Compose([transforms.ToTensor()])

        reference_count = category_feat_counts[0]['lighting'][0]
        uncorr_count_max = reference_count * 0.01
        corr_count_max = reference_count * 0.99
        print('uncorr count max', uncorr_count_max)
        print('corr count max', corr_count_max)

        group_uncorr_counts = {cat: 0 for cat in range(5)}
        group_corr_counts = {cat: 0 for cat in range(5)}

        def azimuth_in_domain(azimuth):
            # Azimuth range owned by the configured domain.
            if self.domain_idx == 0:
                return azimuth <= 4
            if self.domain_idx == 1:
                return 4 < azimuth <= 10
            return False

        self.dataset = []
        for example in self.orig_dataset.data[self.split]:
            category = example.category
            # Correlated and uncorrelated examples are capped independently.
            if category == example.lighting:
                tally, cap = group_corr_counts, corr_count_max
            else:
                tally, cap = group_uncorr_counts, uncorr_count_max

            if tally[category] < cap and azimuth_in_domain(example.azimuth):
                self.dataset.append(example)
                tally[category] += 1

        print('group uncorr counts', group_uncorr_counts)
        print('group corr counts', group_corr_counts)

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor produced by ``self.transform`` from the
                        ``image_lt`` view
                 label: int64 Tensor holding the example's category
        """
        example = self.dataset[index]
        left_view = example.image_lt
        category = example.category

        image = self.transform(left_view)
        label = torch.from_numpy(np.array(category)).to(torch.int64)
        return image, label

    def __len__(self):
        """Number of admitted examples."""
        return len(self.dataset)

class SmallNORB_corr_lighting_plus_azimuth_shift_multiple_domains(torch.utils.data.Dataset):
    """
    Category/lighting correlation of configurable strength on an azimuth subset.

    Considering only examples whose ``azimuth`` is in ``azimuth_vals``, the
    dataset holds, per category, a random sample of examples with
    ``category != lighting`` followed by the first examples (in split order)
    with ``category == lighting``. Both group sizes derive from
    ``category_counts[0] / 5`` scaled by ``1 - correlation`` and
    ``correlation`` respectively.
    """

    def __init__(self, dataset, split, category_feat_counts, correlation, azimuth_vals):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: not used by this class
        :param correlation: fraction of the per-category budget given to
                            examples with ``category == lighting``
        :param azimuth_vals: container of azimuth values to keep
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        examples = self.orig_dataset.data[self.split]

        group_uncorr_counts = {cat: 0 for cat in range(5)}
        group_corr_counts = {cat: 0 for cat in range(5)}

        category_counts = {cat: 0 for cat in range(5)}
        for small_norb_example in examples:
            if small_norb_example.azimuth in azimuth_vals:
                category_counts[small_norb_example.category] += 1
        print('cat counts', category_counts)
        uncorr_count_max = (category_counts[0]/5) * (1-correlation)
        corr_count_max = (category_counts[0]/5) * correlation
        print('uncorr count max', uncorr_count_max)
        print('corr count max', corr_count_max)

        # Candidate pool of uncorrelated examples per category, gathered in a
        # single pass; indices stay in split order within each category.
        uncorr_indices = {cat: [] for cat in range(5)}
        for j, small_norb_example in enumerate(examples):
            category = small_norb_example.category
            if category in uncorr_indices and category != small_norb_example.lighting \
                    and small_norb_example.azimuth in azimuth_vals:
                uncorr_indices[category].append(j)

        # randomly choose uncorr_count_max indices for choosing uncorrelated examples for all categories
        # (random.sample raises ValueError when a pool is smaller than that)
        chosen_uncorr_indices = {}
        for cat in range(5):
            chosen_uncorr_indices[cat] = random.sample(uncorr_indices[cat], int(uncorr_count_max))

        self.dataset = []
        uncorr_counts_total, corr_counts_total = 0, 0
        for cat in range(5):
            for chosen_idx in chosen_uncorr_indices[cat]:
                uncorr_counts_total += 1
                self.dataset.append(examples[chosen_idx])
                # Keep the per-category tally in sync so the summary printed
                # below reports real numbers rather than all zeros.
                group_uncorr_counts[cat] += 1

        for small_norb_example in examples:
            category = small_norb_example.category
            if category == small_norb_example.lighting \
                    and group_corr_counts[category] < corr_count_max \
                    and small_norb_example.azimuth in azimuth_vals:
                self.dataset.append(small_norb_example)
                group_corr_counts[category] += 1
                corr_counts_total += 1

        print('Total corr samples = ', corr_counts_total, 'Total uncorr samples =', uncorr_counts_total)
        print('group uncorr counts', group_uncorr_counts)
        print('group corr counts', group_corr_counts)

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor produced by ``self.transform`` from the
                        ``image_lt`` view (the right view is not used here)
                 label: int64 Tensor holding the example's category
                 lighting: the example's ``lighting`` attribute, unconverted
        """
        small_norb_example = self.dataset[index]

        image = self.transform(small_norb_example.image_lt)
        label = torch.from_numpy(np.array(small_norb_example.category)).to(torch.int64)

        return image, label, small_norb_example.lighting

    def __len__(self):
        """Number of selected examples (uncorrelated plus correlated)."""
        return len(self.dataset)

class SmallNORB_corr_lighting_plus_azimuth_shift_multiple_domains_noise(torch.utils.data.Dataset):
    """
    Noisy-label variant of the category/lighting correlation dataset.

    The split is divided in two using ``list_indices`` (first half: domain 0,
    rest: domain 1). Inside the chosen half, label noise is injected first;
    afterwards, among examples whose ``azimuth`` is in ``azimuth_vals``, a
    random sample of examples with ``category != lighting`` and the first
    examples (in split order) with ``category == lighting`` are collected.
    """

    def __init__(self, dataset, split, category_feat_counts, correlation, azimuth_vals, domain_idx, list_indices):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param category_feat_counts: not used by this class
        :param correlation: target share of ``category == lighting`` examples;
                            uncorrelated examples per category are capped at
                            ``max_count * (1 - correlation) / correlation``
        :param azimuth_vals: container of azimuth values to keep
        :param domain_idx: 0 or 1, which half of ``list_indices`` to use; any
                           other value yields an empty dataset
        :param list_indices: sequence of example indices into the split
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        examples = self.orig_dataset.data[self.split]
        half = int(len(examples) / 2)

        # Build the membership set once; slicing and scanning the list for
        # every example in every pass made construction quadratic.
        if domain_idx == 0:
            domain_members = set(list_indices[:half])
        elif domain_idx == 1:
            domain_members = set(list_indices[half:])
        else:
            domain_members = set()

        # add noise to labels
        cat_indices = {cat: [] for cat in range(5)}
        for j, small_norb_example in enumerate(examples):
            if j in domain_members and small_norb_example.category in cat_indices:
                cat_indices[small_norb_example.category].append(j)

        # The noise budget of every category is 5% of category 0's size.
        count_noise = len(cat_indices[0]) * 0.05
        chosen_noise_indices = {}
        for cat in range(5):
            chosen_noise_indices[cat] = random.sample(cat_indices[cat], int(count_noise))

        # NOTE(review): relabeling is written into the shared example objects
        # of ``dataset``; everything below deliberately reads the noised
        # categories.
        for cat in range(5):
            indices_shuffle_choice = [other for other in range(5) if other != cat]
            for chosen_idx in chosen_noise_indices[cat]:
                examples[chosen_idx].category = random.sample(indices_shuffle_choice, 1)[0]

        # Examples of this domain with an admissible azimuth, in split order.
        candidates = [
            (j, small_norb_example)
            for j, small_norb_example in enumerate(examples)
            if j in domain_members and small_norb_example.azimuth in azimuth_vals
        ]

        # Largest per-lighting count among (noised) category-0 candidates.
        lighting_vals = {val: 0 for val in range(7)}
        for _, small_norb_example in candidates:
            if small_norb_example.category == 0:
                lighting_vals[small_norb_example.lighting] += 1
        max_cat_lighting = max(lighting_vals.values())

        uncorr_count_max = (max_cat_lighting) * ((1-correlation)/correlation)
        corr_count_max = max_cat_lighting

        uncorr_indices = {cat: [] for cat in range(5)}
        for j, small_norb_example in candidates:
            category = small_norb_example.category
            if category in uncorr_indices and category != small_norb_example.lighting:
                uncorr_indices[category].append(j)

        # randomly choose uncorr_count_max indices for choosing uncorrelated examples for all categories
        chosen_uncorr_indices = {}
        for cat in range(5):
            chosen_uncorr_indices[cat] = random.sample(uncorr_indices[cat], int(uncorr_count_max))

        self.dataset = []
        for cat in range(5):
            for chosen_idx in chosen_uncorr_indices[cat]:
                self.dataset.append(examples[chosen_idx])

        group_corr_counts = {cat: 0 for cat in range(5)}
        for _, small_norb_example in candidates:
            category = small_norb_example.category
            if category == small_norb_example.lighting \
                    and group_corr_counts[category] < corr_count_max:
                self.dataset.append(small_norb_example)
                group_corr_counts[category] += 1

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor, the transformed ``image_lt`` and ``image_rt``
                        views concatenated along dim 0
                 label: int64 Tensor holding the (possibly noised) category
                 lighting: the example's ``lighting`` attribute
                 azimuth: the example's ``azimuth`` attribute
        """
        small_norb_example = self.dataset[index]

        label = torch.from_numpy(np.array(small_norb_example.category)).to(torch.int64)

        # cat both images
        image = torch.cat(
            (
                self.transform(small_norb_example.image_lt),
                self.transform(small_norb_example.image_rt),
            ),
            0,
        )

        return image, label, small_norb_example.lighting, small_norb_example.azimuth

    def __len__(self):
        """Number of selected examples (uncorrelated plus correlated)."""
        return len(self.dataset)

###########################################################################

class SmallNORB_lightning_shift_plus_azimuth_shift(torch.utils.data.Dataset):
    """
    Combined lighting and azimuth shift view of one SmallNORB split.

    Single-domain mode keeps ``azimuth <= 6`` with ``lighting <= 2``.
    Multi-domain mode keeps, for domain 0, ``azimuth <= 4`` with
    ``lighting <= 0`` and, for domain 1, ``4 < azimuth <= 10`` with
    ``0 < lighting <= 2``.
    """

    def __init__(self, dataset, split, domain_idx=0, multiple_domain=False):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param domain_idx: 0 or 1, only consulted in multi-domain mode; other
                           values select nothing
        :param multiple_domain: ``False`` for the single-domain filter, any
                                other value for the per-domain filters
        """
        self.orig_dataset = dataset
        self.split = split
        self.domain_idx = domain_idx
        self.multiple_domain = multiple_domain
        self.transform = transforms.Compose([transforms.ToTensor()])

        # unseen lighting shift and azimuth shift
        self.dataset = [
            example
            for example in self.orig_dataset.data[self.split]
            if self._keeps(example)
        ]

    def _keeps(self, example):
        """Tell whether ``example`` belongs to the configured subset."""
        if self.multiple_domain is False:
            return example.azimuth <= 6 and example.lighting <= 2
        if self.domain_idx == 0:
            return example.azimuth <= 4 and example.lighting <= 0
        if self.domain_idx == 1:
            return 4 < example.azimuth <= 10 and 0 < example.lighting <= 2
        return False

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor produced by ``self.transform`` from the
                        ``image_lt`` view
                 label: int64 Tensor holding the example's category
        """
        example = self.dataset[index]
        left_view = example.image_lt
        category = example.category

        image = self.transform(left_view)
        label = torch.from_numpy(np.array(category)).to(torch.int64)
        return image, label

    def __len__(self):
        """Number of examples that passed the filter."""
        return len(self.dataset)

###################### SmallNORB: base datasets ###########################
class SmallNORB_base_noise(torch.utils.data.Dataset):
    """
    A whole SmallNORB split with label noise injected at construction time.

    For every category, a number of examples equal to 5% of category 0's size
    is drawn at random and relabeled with a different category. The new labels
    are written into the example objects held by ``dataset``.
    """

    def __init__(self, dataset, split):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        """
        self.dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        # add noise to labels
        cat_indices = {
            cat: [
                position
                for position, example in enumerate(self.dataset.data[self.split])
                if example.category == cat
            ]
            for cat in range(5)
        }

        count_noise = len(cat_indices[0]) * 0.05
        chosen_noise_indices = {}
        for cat in range(5):
            chosen_noise_indices[cat] = random.sample(cat_indices[cat], int(count_noise))

        for cat in range(5):
            alternatives = [other for other in range(0, 5) if other != cat]
            for position in chosen_noise_indices[cat]:
                replacement = random.sample(alternatives, 1)[0]
                self.dataset.data[self.split][position].category = replacement

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor, the transformed ``image_lt`` and ``image_rt``
                        views concatenated along dim 0
                 label: int64 Tensor holding the (possibly noised) category
                 lighting: the example's ``lighting`` attribute
                 azimuth: the example's ``azimuth`` attribute
        """
        example = self.dataset.data[self.split][index]
        left_view = example.image_lt
        category = example.category

        left_tensor = self.transform(left_view)
        label = torch.from_numpy(np.array(category)).to(torch.int64)

        right_tensor = self.transform(example.image_rt)
        image = torch.cat((left_tensor, right_tensor), 0)

        # attribute labels, azimuth read first and returned last
        azimuth = example.azimuth
        lighting = example.lighting
        return image, label, lighting, azimuth

    def __len__(self):
        """Number of examples in the wrapped split."""
        split_examples = self.dataset.data[self.split]
        return len(split_examples)

class SmallNORB_base_noise_azimuth_filter(torch.utils.data.Dataset):
    """
    Azimuth-filtered SmallNORB split with 5% label noise per category.

    Examples whose ``azimuth`` is in ``azimuth_vals`` are kept; afterwards 5%
    of the kept examples of every category are relabeled with a different,
    randomly drawn category. The relabeling is written into the example
    objects, which are the ones held by ``dataset``.
    """

    def __init__(self, dataset, split, azimuth_vals):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        :param azimuth_vals: container of azimuth values to keep
        """
        self.orig_dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

        self.dataset = []
        self.dataset_indices = []  # position in the original split, per kept example
        # category -> positions inside self.dataset
        cat_indices = {cat: [] for cat in range(5)}
        for orig_position, example in enumerate(self.orig_dataset.data[self.split]):
            if example.azimuth not in azimuth_vals:
                continue
            self.dataset.append(example)
            cat_indices[example.category].append(len(self.dataset) - 1)
            self.dataset_indices.append(orig_position)

        # add noise to labels
        chosen_noise_indices = {}
        for cat in range(5):
            noise_budget = int(len(cat_indices[cat]) * 0.05)
            chosen_noise_indices[cat] = random.sample(cat_indices[cat], noise_budget)

        for cat in range(5):
            alternatives = [other for other in range(0, 5) if other != cat]
            for position in chosen_noise_indices[cat]:
                self.dataset[position].category = random.sample(alternatives, 1)[0]

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor, the transformed ``image_lt`` and ``image_rt``
                        views concatenated along dim 0
                 label: int64 Tensor holding the (possibly noised) category
                 lighting: the example's ``lighting`` attribute
                 azimuth: the example's ``azimuth`` attribute
        """
        example = self.dataset[index]
        left_view = example.image_lt
        category = example.category

        left_tensor = self.transform(left_view)
        label = torch.from_numpy(np.array(category)).to(torch.int64)

        right_tensor = self.transform(example.image_rt)
        image = torch.cat((left_tensor, right_tensor), 0)

        # attribute labels, azimuth read first and returned last
        azimuth = example.azimuth
        lighting = example.lighting
        return image, label, lighting, azimuth

    def __len__(self):
        """Number of examples that passed the azimuth filter."""
        return len(self.dataset)

class SmallNORB_base(torch.utils.data.Dataset):
    """
    Unfiltered view of one SmallNORB split.

    Every example of the split is served as a two-view image tensor, its
    category and its lighting attribute.
    """

    def __init__(self, dataset, split):
        """
        :param dataset: SmallNORB dataset
        :param split: dataset split/domain
        """
        self.dataset = dataset
        self.split = split
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):
        """
        :param index: int
        :return: image: Tensor, the transformed ``image_lt`` and ``image_rt``
                        views concatenated along dim 0 (presumably the
                        channel axis -- confirm against the transform)
                 label: int64 Tensor holding the example's category
                 lighting: the example's ``lighting`` attribute, unconverted
        """
        small_norb_example = self.dataset.data[self.split][index]

        label = torch.from_numpy(np.array(small_norb_example.category)).to(torch.int64)

        # Stack the left and right views into a single tensor.
        image = torch.cat(
            (
                self.transform(small_norb_example.image_lt),
                self.transform(small_norb_example.image_rt),
            ),
            0,
        )

        # Lighting is the only attribute label exposed by this dataset.
        return image, label, small_norb_example.lighting

    def __len__(self):
        """Number of examples in the wrapped split."""
        return len(self.dataset.data[self.split])

class SmallNORB_Aind(MultipleDomainDataset):
    """
    SmallNORB benchmark with an azimuth shift between training and test.

    The training split is shuffled and halved into two noisy training
    domains restricted to azimuths {0, 2, 4} and {6, 8, 10}; the test domain
    is the noisy test split restricted to the even azimuths 24..34.
    """
    ENVIRONMENTS = [ "tr1", "tr2", "test"]
    N_STEPS = 2000
    CHECKPOINT_FREQ = 200

    def __init__(self, root, test_envs, hparams):
        """
        :param root: directory containing the ``small_norb_root`` folder
        :param test_envs: not used by this class
        :param hparams: not used by this class
        """
        super().__init__()
        dataset = SmallNORBDataset(dataset_root=f'{root}/small_norb_root')
        self.input_shape = (2, 96, 96,)
        self.num_classes = 5

        # randomly shuffle data, then divide into domains
        list_indices = list(range(len(dataset.data['train'])))
        random.shuffle(list_indices)

        # azimuth shift
        tr_dataset1 = SmallNORB_azimuth_shift_noise(dataset, 'train', [0, 2, 4], 0, list_indices)
        tr_dataset2 = SmallNORB_azimuth_shift_noise(dataset, 'train', [6, 8, 10], 1, list_indices)

        # test dataset variants
        te_dataset = SmallNORB_base_noise_azimuth_filter(dataset, 'test', list(range(24, 36, 2)))
        self.datasets = [tr_dataset1, tr_dataset2, te_dataset]


class SmallNORB_Acause(MultipleDomainDataset):
    """
    SmallNORB benchmark with a category/lighting correlation in training.

    The training split is shuffled and halved into two noisy training
    domains built with correlation 0.90 and 0.95; the test domain is the
    noisy, otherwise unfiltered test split.
    """
    ENVIRONMENTS = [ "tr1", "tr2", "test"]
    N_STEPS = 2000
    CHECKPOINT_FREQ = 200

    def __init__(self, root, test_envs, hparams):
        """
        :param root: directory containing the ``small_norb_root`` folder
        :param test_envs: not used by this class
        :param hparams: not used by this class
        """
        super().__init__()
        dataset = SmallNORBDataset(dataset_root=f'{root}/small_norb_root')
        self.input_shape = (2, 96, 96,)
        self.num_classes = 5

        # Per-category histograms of elevation / azimuth / lighting.
        # NOTE(review): the dict is rebuilt for every split, so the value that
        # reaches the training datasets below is the one from the last split
        # ('test') -- confirm this is intended.
        for dataset_split in ('train', 'test'):
            category_feat_counts = {}
            for example in dataset.data[dataset_split]:
                per_category = category_feat_counts.setdefault(
                    example.category,
                    {'elevation': {}, 'azimuth': {}, 'lighting': {}},
                )
                observed = (
                    ('elevation', example.elevation),
                    ('azimuth', example.azimuth),
                    ('lighting', example.lighting),
                )
                for feat_key, feat_value in observed:
                    histogram = per_category[feat_key]
                    histogram[feat_value] = histogram.get(feat_value, 0) + 1

        # corr-lighting_noise
        total_examples = sum(1 for _ in dataset.data['train'])

        # randomly shuffle data, then divide into domains
        list_indices = list(range(0, total_examples))
        random.shuffle(list_indices)

        train_domains = [
            SmallNORB_corr_lighting_noise(dataset, 'train', category_feat_counts, 0.90, 0, list_indices),
            SmallNORB_corr_lighting_noise(dataset, 'train', category_feat_counts, 0.95, 1, list_indices),
        ]

        # test dataset variants
        te_dataset = SmallNORB_base_noise(dataset, 'test')
        self.datasets = train_domains + [te_dataset]


class SmallNORB_AcauseUAind(MultipleDomainDataset):
    """Small NORB with two training domains and one azimuth-filtered test domain.

    ``self.datasets`` holds, in the order of ``ENVIRONMENTS``, two
    ``SmallNORB_corr_lighting_plus_azimuth_shift_multiple_domains_noise``
    training sets and one ``SmallNORB_base_noise_azimuth_filter`` test set.
    """
    ENVIRONMENTS = [ "tr1", "tr2", "test"]
    N_STEPS = 2000
    CHECKPOINT_FREQ = 200

    def __init__(self, root, test_envs, hparams):
        """Build the three environments from ``{root}/small_norb_root``.

        ``test_envs`` and ``hparams`` are accepted for interface
        compatibility but are not read here.
        """
        super().__init__()
        norb = SmallNORBDataset(dataset_root=f'{root}/small_norb_root')
        self.input_shape = (2, 96, 96,)
        self.num_classes = 5

        # Per category, count how often each elevation / azimuth / lighting
        # value occurs.
        # NOTE(review): the tally is rebuilt from scratch for every split, so
        # the counts handed to the training domains below are those of the
        # last split ('test') -- confirm this is intended.
        for split_name in ('train', 'test'):
            category_feat_counts = {}
            for example in norb.data[split_name]:
                per_category = category_feat_counts.setdefault(
                    example.category,
                    {'elevation': {}, 'azimuth': {}, 'lighting': {}})
                for feat_name in ('elevation', 'azimuth', 'lighting'):
                    feat_value = getattr(example, feat_name)
                    tally = per_category[feat_name]
                    tally[feat_value] = tally.get(feat_value, 0) + 1

        # One shuffled permutation of the training indices is shared by both
        # training domains.
        total_examples = sum(1 for _ in norb.data['train'])
        list_indices = list(range(total_examples))
        random.shuffle(list_indices)

        # corr-lighting + azimuth shift
        # NOTE(review): the positional arguments are presumably (correlation
        # strength, allowed azimuth values, domain index) -- confirm against
        # the helper class.
        train_specs = (
            (0.90, [0, 2, 4], 0),
            (0.95, [6, 8, 10], 1),
        )
        train_domains = [
            SmallNORB_corr_lighting_plus_azimuth_shift_multiple_domains_noise(
                norb, 'train', category_feat_counts, corr, azimuths, domain_idx, list_indices)
            for corr, azimuths, domain_idx in train_specs
        ]

        # test dataset variant: filtered with the even azimuth values 24..34
        held_out_azimuths = list(range(24, 36, 2))
        te_dataset = SmallNORB_base_noise_azimuth_filter(norb, 'test', held_out_azimuths)

        self.datasets = train_domains + [te_dataset]


# Waterbirds dataset
class Waterbirds(torch.utils.data.Dataset):
    """Map-style dataset over a list of Waterbirds image files.

    Each item is the tuple ``(x, y, bgd, add_effect_flag)``: the (optionally
    transformed) RGB image, its label, its confounder value, and a flag that
    is 1 only when the training-time darkening augmentation was applied.

    Args:
        data_dir: Directory that the entries of ``root_images`` are relative to.
        root_images: Indexable collection of image file names.
        y_array: Indexable collection of labels, aligned with ``root_images``.
        confounder_array: Indexable collection of confounder values, aligned
            with ``root_images``.
        transform: Optional callable applied to the PIL image. When ``None``
            the PIL image is returned untransformed.
        train: Selects which augmentation is used when ``augment`` is true.
        augment: Whether to apply the weather augmentation at all.
    """

    def __init__(self, data_dir, root_images, y_array, confounder_array, transform=None, train=True, augment=False):
        self.data_dir = data_dir
        self.root_images = root_images
        self.y_array = y_array
        self.confounder_array = confounder_array
        self.transform = transform
        self.train = train
        self.augment = augment

    def __len__(self):
        """Return the number of examples (length of ``y_array``)."""
        return len(self.y_array)

    def __getitem__(self, idx):
        """Load example ``idx`` and return ``(x, y, bgd, add_effect_flag)``."""
        y = self.y_array[idx]
        bgd = self.confounder_array[idx]
        img_filename = os.path.join(self.data_dir, self.root_images[idx])

        # Image.open is lazy; convert() forces the pixel data to be read, and
        # the with-block releases the file handle deterministically.
        with Image.open(img_filename) as raw_img:
            img = raw_img.convert('RGB')

        # Apply weather augmentation
        # NOTE(review): `am` is not imported in the visible part of this file;
        # confirm the augmentation module is in scope before using augment=True.
        add_effect_flag = 0
        if self.augment:
            img = np.array(img)
            if self.train:
                # Darken a random half of the training images.
                add_effect_flag = np.random.choice([0,1])
                if add_effect_flag == 1:
                    img = am.darken(img, darkness_coeff=0.5)
            else:
                img = am.add_rain(img, rain_type='heavy', slant=20)
            img = Image.fromarray(img)

        # Apply transform; the constructor default is None, so calling it
        # unconditionally would raise TypeError.
        if self.transform is not None:
            img = self.transform(img)

        return img, y, bgd, add_effect_flag
    
class Waterbirds_Acause(MultipleDomainDataset):
    """Waterbirds split into per-confounder train domains and per-group test domains.

    ``self.datasets`` is filled in the order of ``ENVIRONMENTS``: one training
    set per value of the ``place`` confounder (0 and 1), the whole validation
    split, then one test set per ``(y, place)`` group (0..3).
    """
    ENVIRONMENTS = [ "tr1", "tr2", "val", "test1", "test2", "test3", "test4"]
    N_STEPS = 2001
    CHECKPOINT_FREQ = 200

    def __init__(self, root, test_envs, hparams):
        """Read ``metadata.csv`` under ``root`` and build every environment.

        ``test_envs`` and ``hparams`` are accepted for interface
        compatibility but are not read here.

        Raises:
            ValueError: If the generated dataset directory does not exist.
        """
        super().__init__()
        dataset_name = '_'.join(['waterbird_complete95'] + ['forest2water2'])
        self.data_dir = os.path.join(root, 'dataset', dataset_name)

        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. Please generate the dataset first.')

        # Read in metadata
        metadata_path = os.path.join(self.data_dir, 'metadata.csv')
        self.metadata_df = pd.read_csv(metadata_path)

        self.input_shape = (3, 224, 224,)
        self.num_classes = 2

        # Labels, and the single supported confounder (the 'place' column).
        self.y_array = self.metadata_df['y'].values
        self.confounder_array = self.metadata_df['place'].values
        self.n_confounders = 1

        # Group id = y * 2 + place, giving 4 (y, place) groups.
        self.n_groups = pow(2, 2)
        self.group_array = (self.y_array*(self.n_groups/2) + self.confounder_array).astype('int')

        # File names and the integer split code of every row.
        self.filename_array = self.metadata_df['img_filename'].values
        self.split_array = self.metadata_df['split'].values
        self.split_dict = {
            'train': 0,
            'val': 1,
            'test': 2
        }

        self.train_transform = get_transform_cub(train=True)
        self.eval_transform = get_transform_cub(train=False)

        def rows_of(split):
            # Row indices of the metadata table that belong to `split`.
            return np.where(self.split_array == self.split_dict[split])[0]

        def build(rows, transform):
            # Waterbirds dataset restricted to the given metadata rows.
            return Waterbirds(self.data_dir,
                              self.filename_array[rows],
                              self.y_array[rows],
                              self.confounder_array[rows],
                              transform)

        self.datasets = []

        # train domains based on |A|: one per confounder value
        train_rows = rows_of('train')
        train_confounders = self.confounder_array[train_rows]
        for confounder_value in range(self.n_confounders+1):
            selected = train_rows[train_confounders == confounder_value]
            self.datasets.append(build(selected, self.train_transform))

        # the validation split is kept as a single domain
        self.datasets.append(build(rows_of('val'), self.eval_transform))

        # test domains based on |A| x |Y|: one per group id
        test_rows = rows_of('test')
        test_groups = self.group_array[test_rows]
        for group_id in range(self.n_groups):
            selected = test_rows[test_groups == group_id]
            self.datasets.append(build(selected, self.eval_transform))

class Waterbirds_Acause_CACM(MultipleDomainDataset):
    """Waterbirds with a single training domain and per-group test domains.

    ``self.datasets`` is filled in the order of ``ENVIRONMENTS``: the whole
    training split, the whole validation split, then one test set per
    ``(y, place)`` group (0..3).
    """
    ENVIRONMENTS = [ "tr", "val", "test1", "test2", "test3", "test4"]
    N_STEPS = 2001
    CHECKPOINT_FREQ = 200

    def __init__(self, root, test_envs, hparams):
        """Read ``metadata.csv`` under ``root`` and build every environment.

        ``test_envs`` and ``hparams`` are accepted for interface
        compatibility but are not read here.

        Raises:
            ValueError: If the generated dataset directory does not exist.
        """
        super().__init__()
        # The opening parenthesis of this call was missing, which made the
        # whole module fail to parse.
        self.data_dir = os.path.join(
            root,
            'dataset',
            '_'.join(['waterbird_complete95'] + ['forest2water2']))

        if not os.path.exists(self.data_dir):
            raise ValueError(
                f'{self.data_dir} does not exist yet. Please generate the dataset first.')

        # Read in metadata
        self.metadata_df = pd.read_csv(
            os.path.join(self.data_dir, 'metadata.csv'))

        self.input_shape = (3, 224, 224,)
        self.num_classes = 2

        # Get the y values
        self.y_array = self.metadata_df['y'].values

        # We only support one confounder for CUB for now
        self.confounder_array = self.metadata_df['place'].values
        self.n_confounders = 1
        # Map to groups: group id = y * 2 + place, giving 4 (y, place) groups
        self.n_groups = pow(2, 2)
        self.group_array = (self.y_array*(self.n_groups/2) + self.confounder_array).astype('int')

        # Extract filenames and splits
        self.filename_array = self.metadata_df['img_filename'].values
        self.split_array = self.metadata_df['split'].values
        self.split_dict = {
            'train': 0,
            'val': 1,
            'test': 2
        }

        self.train_transform = get_transform_cub(
            train=True,
            )
        self.eval_transform = get_transform_cub(
            train=False,
            )

        self.datasets = []

        # train and val splits each become a single domain
        for split, transform in (('train', self.train_transform),
                                 ('val', self.eval_transform)):
            indices = np.where(self.split_array == self.split_dict[split])[0]
            self.datasets.append(Waterbirds(self.data_dir,
                                            self.filename_array[indices],
                                            self.y_array[indices],
                                            self.confounder_array[indices],
                                            transform))

        # get test subset
        indices = np.where(self.split_array == self.split_dict['test'])[0]
        test_filename_array = self.filename_array[indices]
        test_y_array = self.y_array[indices]
        test_group_array = self.group_array[indices]
        test_confounder_array = self.confounder_array[indices]

        # test domains based on |A| x |Y|
        for group_idx in range(self.n_groups):
            group_indices = np.where(test_group_array == group_idx)[0]
            self.datasets.append(Waterbirds(self.data_dir,
                                            test_filename_array[group_indices],
                                            test_y_array[group_indices],
                                            test_confounder_array[group_indices],
                                            self.eval_transform))

def get_transform_cub(train):
    """Return the torchvision preprocessing pipeline used for Waterbirds images.

    Args:
        train: When truthy, build the augmenting pipeline (random resized crop
            to 224x224 plus random horizontal flip). Otherwise build the
            deterministic one (resize by a 256/224 factor, then centre-crop
            to 224x224).

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and ``Normalize``.
    """
    target_resolution = (224, 224)

    def tensor_and_normalize():
        # Shared tail of both pipelines; the mean/std values are presumably
        # the usual ImageNet channel statistics -- confirm if it matters.
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    if train:
        return transforms.Compose([
            transforms.RandomResizedCrop(
                target_resolution,
                scale=(0.7, 1.0),
                ratio=(0.75, 1.3333333333333333),
                interpolation=2),
            transforms.RandomHorizontalFlip(),
            *tensor_and_normalize(),
        ])

    # Resizes the image to a slightly larger square then crops the center.
    scale = 256.0/224.0
    enlarged = tuple(int(side*scale) for side in target_resolution)
    return transforms.Compose([
        transforms.Resize(enlarged),
        transforms.CenterCrop(target_resolution),
        *tensor_and_normalize(),
    ])


    