import os
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from PIL import Image, ImageFile
from torchvision import transforms
from transformers import (
    BertTokenizer,
    AutoTokenizer,
    DistilBertTokenizer,
    GPT2Tokenizer,
)
from torchvision import datasets

# PIL option: allow image files that are truncated on disk to be loaded
# instead of failing when they are decoded.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Dataset names exposed by this module. Each name is resolved to the class of
# the same name by ``get_dataset_class`` (a lookup in this module's globals),
# so classes not listed here (e.g. ``COCO``, ``CivilComments``) can still be
# requested by name.
DATASETS = [
    # Synthetic dataset
    "CMNIST",
    # Current subpop datasets
    "Waterbirds",
    "CelebA",
    "CivilCommentsFine",  # "CivilComments"
    "MultiNLI",
    "MetaShift",
    "ImagenetBG",
    "NICOpp",
    "MIMICNoFinding",
    "MIMICNotes",
    "CXRMultisite",
    "CheXpertNoFinding",
    "Living17",
    "Entity13",
    "Entity30",
    "Nonliving26",
]
# TODO: adapt cmnist and civilcomments


def get_dataset_class(dataset_name):
    """Look up ``dataset_name`` in this module's namespace and return it.

    Raises:
        NotImplementedError: if no module-level name ``dataset_name`` exists.
    """
    namespace = globals()
    if dataset_name in namespace:
        return namespace[dataset_name]
    raise NotImplementedError(f"Dataset not found: {dataset_name}")


def num_environments(dataset_name):
    """Return ``len(cls.ENVIRONMENTS)`` for the dataset class named ``dataset_name``.

    NOTE(review): none of the dataset classes defined in this file declare an
    ``ENVIRONMENTS`` attribute, so for them this raises ``AttributeError``.
    """
    dataset_cls = get_dataset_class(dataset_name)
    environments = dataset_cls.ENVIRONMENTS
    return len(environments)


class SubpopDataset:
    """Base class for subpopulation-shift datasets described by a metadata CSV.

    Every row has an input ``x``, a label ``y`` and an attribute ``a``; the pair
    ``(y, a)`` identifies its group, flattened as ``num_attributes * y + a``.
    ``idx`` lists the active row positions into ``x``/``y``/``a`` and is what
    ``subsample`` and ``duplicate`` rewrite. Subclasses supply ``transform(x)``,
    which ``__getitem__`` applies to the stored input.
    """

    N_STEPS = 5001  # Default, subclasses may override
    CHECKPOINT_FREQ = 100  # Default, subclasses may override
    N_WORKERS = 8  # Default, subclasses may override
    INPUT_SHAPE = None  # Subclasses should override
    SPLITS = {"tr": 0, "va": 1, "te": 2}  # Default, subclasses may override
    EVAL_SPLITS = ["te"]  # Default, subclasses may override

    def __init__(
        self,
        root,
        split,
        metadata,
        transform,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
    ):
        """Load the rows of ``metadata`` that belong to ``split``.

        Args:
            root: Prefix joined (``os.path.join``) onto every ``filename`` entry.
            split: Key into ``SPLITS``; rows whose ``split`` column equals
                ``SPLITS[split]`` are kept.
            metadata: Path of a CSV with ``filename``, ``y``, ``split`` columns
                and an ``a`` column (read when ``train_attr == "yes"``).
            transform: Stored as ``transform_`` for use by subclass ``transform``.
            train_attr: Any value other than ``"yes"`` replaces attributes by 0.
            subsample_type: ``"group"``/``"class"`` to balance with ``subsample``.
            duplicates: Optional per-sample repeat counts for ``duplicate``.
        """
        df = pd.read_csv(metadata)
        df = df[df["split"] == self.SPLITS[split]]

        self.idx = list(range(len(df)))
        self.x = (
            df["filename"].astype(str).map(lambda x: os.path.join(root, x)).tolist()
        )
        self.y = df["y"].tolist()
        self.a = df["a"].tolist() if train_attr == "yes" else [0] * len(df)
        self.transform_ = transform
        self._count_groups()

        if subsample_type is not None:
            # NOTE: an earlier variant capped the validation split for DFR by
            # passing max_size = 20% of the training split size.
            self.subsample(subsample_type)

        if duplicates is not None:
            self.duplicate(duplicates)

    def _group_index(self, y, a):
        """Return the flat group id of label ``y`` and attribute ``a``."""
        return self.num_attributes * int(y) + int(a)

    def _count_groups(self):
        """Recompute group/class sizes and inverse-frequency weights over ``idx``.

        ``num_attributes`` and ``num_labels`` are code cardinalities
        (``max + 1``) over all rows, so codes need not all be present. Using
        ``len(set(...))`` made flat group ids collide or overflow
        ``group_sizes`` whenever a low-valued code was missing.
        """
        self.num_attributes = int(max(self.a, default=-1)) + 1
        self.num_labels = int(max(self.y, default=-1)) + 1
        self.group_sizes = [0] * (self.num_attributes * self.num_labels)
        self.class_sizes = [0] * self.num_labels

        for i in self.idx:
            self.group_sizes[self._group_index(self.y[i], self.a[i])] += 1
            self.class_sizes[int(self.y[i])] += 1

        # Every active row's group/class has size >= 1, so no division by zero.
        total = len(self)
        self.weights_g = [
            total / self.group_sizes[self._group_index(self.y[i], self.a[i])]
            for i in self.idx
        ]
        self.weights_y = [total / self.class_sizes[int(self.y[i])] for i in self.idx]

    def subsample(self, subsample_type, max_size=None):
        """Randomly keep the same number of rows from every non-empty group/class.

        Args:
            subsample_type: ``"group"`` balances ``(y, a)`` groups, ``"class"``
                balances labels. The kept count is the size of the smallest
                non-empty group/class (empty ones are ignored; otherwise the
                minimum would be 0 and the whole split would be dropped).
            max_size: If given, afterwards keep a uniform random subset of at
                most ``max_size`` rows.

        Raises:
            ValueError: if ``subsample_type`` is not ``"group"`` or ``"class"``.
        """
        if subsample_type not in {"group", "class"}:
            raise ValueError(f"Unknown subsample_type: {subsample_type!r}")
        sizes = self.group_sizes if subsample_type == "group" else self.class_sizes
        min_size = min((s for s in sizes if s > 0), default=0)

        counts_g = [0] * (self.num_attributes * self.num_labels)
        counts_y = [0] * self.num_labels
        new_idx = []
        for p in torch.randperm(len(self)).tolist():
            row = self.idx[p]
            y, a = int(self.y[row]), int(self.a[row])
            g = self._group_index(y, a)
            if (subsample_type == "group" and counts_g[g] < min_size) or (
                subsample_type == "class" and counts_y[y] < min_size
            ):
                counts_g[g] += 1
                counts_y[y] += 1
                new_idx.append(row)

        self.idx = new_idx
        if max_size is not None:
            # Uniform random subsample; never request more rows than remain.
            size = min(max_size, len(self.idx))
            self.idx = np.random.choice(self.idx, size, replace=False).tolist()
        self._count_groups()

    def duplicate(self, duplicates):
        """Repeat each active row ``duplicates[k]`` times (paired by position).

        NOTE(review): pairing uses ``zip``, so if ``duplicates`` and ``idx``
        differ in length the surplus entries are silently ignored.
        """
        new_idx = []
        for i, n in zip(self.idx, duplicates):
            new_idx.extend([i] * n)
        self.idx = new_idx
        self._count_groups()

    def __getitem__(self, index):
        """Return ``(row, x, y, a)``; ``row`` is the underlying row position."""
        i = self.idx[index]
        x = self.transform(self.x[i])
        y = torch.tensor(self.y[i], dtype=torch.long)
        a = torch.tensor(self.a[i], dtype=torch.long)
        return i, x, y, a

    def __len__(self):
        """Number of active rows (length of ``idx``)."""
        return len(self.idx)


# class CMNIST(SubpopDataset):
#     N_STEPS = 15001
#     CHECKPOINT_FREQ = 200
#     # INPUT_SHAPE = (
#     #     3,
#     #     224,
#     #     224,
#     # )
#     INPUT_SHAPE = (
#         3,
#         28,
#         28,
#     )
#     N_WORKERS = 4
#     data_type = "images"

#     def __init__(
#         self,
#         data_path,
#         split,
#         hparams,
#         train_attr="yes",
#         subsample_type=None,
#         duplicates=None,
#         metadata=None,
#     ):
#         root = Path(data_path) / "cmnist"
#         mnist = datasets.MNIST(root, train=True)
#         X, y = mnist.data, mnist.targets

#         if split == "tr":
#             if hparams["data_size"]:
#                 X, y = X[: hparams["data_size"]], y[: hparams["data_size"]]
#             else:
#                 X, y = X[:30000], y[:30000]
#         elif split == "va":
#             X, y = X[30000:40000], y[30000:40000]
#         elif split == "te":
#             X, y = X[40000:], y[40000:]
#         else:
#             raise NotImplementedError

#         rng = np.random.default_rng(hparams["data_seed"])

#         self.binary_label = np.bitwise_xor(
#             y >= 5, (rng.random(len(y)) < hparams["cmnist_flip_prob"])
#         ).numpy()
#         self.color = np.bitwise_xor(
#             self.binary_label, (rng.random(len(y)) < hparams["cmnist_spur_prob"])
#         )
#         self.imgs = torch.stack([X, X, torch.zeros_like(X)], dim=1).numpy()
#         self.imgs[list(range(len(self.imgs))), (1 - self.color), :, :] *= 0

#         if split == "tr":
#             # subsample color = 0
#             if hparams["cmnist_attr_prob"] > 0.5:
#                 n_samples_0 = int(
#                     (self.color == 1).sum()
#                     * (1 - hparams["cmnist_attr_prob"])
#                     / hparams["cmnist_attr_prob"]
#                 )
#                 self._subsample(self.color == 0, n_samples_0, rng)
#             # subsample color = 1
#             elif hparams["cmnist_attr_prob"] < 0.5:
#                 n_samples_1 = int(
#                     (self.color == 0).sum()
#                     * hparams["cmnist_attr_prob"]
#                     / (1 - hparams["cmnist_attr_prob"])
#                 )
#                 self._subsample(self.color == 1, n_samples_1, rng)

#             # subsample y = 0
#             if hparams["cmnist_label_prob"] > 0.5:
#                 n_samples_0 = int(
#                     (self.binary_label == 1).sum()
#                     * (1 - hparams["cmnist_label_prob"])
#                     / hparams["cmnist_label_prob"]
#                 )
#                 self._subsample(self.binary_label == 0, n_samples_0, rng)
#             # subsample y = 1
#             elif hparams["cmnist_label_prob"] < 0.5:
#                 n_samples_1 = int(
#                     (self.binary_label == 0).sum()
#                     * hparams["cmnist_label_prob"]
#                     / (1 - hparams["cmnist_label_prob"])
#                 )
#                 self._subsample(self.binary_label == 1, n_samples_1, rng)

#         self.idx = list(range(len(self.color)))
#         self.x = torch.from_numpy(self.imgs).float() / 255.0
#         self.y = self.binary_label
#         self.a = self.color

#         self.transform_ = transforms.Compose(
#             [
#                 # transforms.Resize((224, 224)),
#                 transforms.Normalize((0.1307, 0.1307, 0.0), (0.3081, 0.3081, 0.3081)),
#             ]
#         )
#         self._count_groups()

#         print("group size: ", self.group_sizes)
#         print("class size: ", self.class_sizes)
#         print("Total size:", sum(self.group_sizes))

#         if subsample_type is not None:
#             self.subsample(subsample_type)

#         if duplicates is not None:
#             self.duplicate(duplicates)

#     def _subsample(self, mask, n_samples, rng):
#         assert n_samples <= mask.sum()
#         idxs = np.concatenate(
#             (
#                 np.nonzero(~mask)[0],
#                 rng.choice(np.nonzero(mask)[0], size=n_samples, replace=False),
#             )
#         )
#         rng.shuffle(idxs)
#         self.imgs = self.imgs[idxs]
#         self.color = self.color[idxs]
#         self.binary_label = self.binary_label[idxs]

#     def transform(self, x):
#         return self.transform_(x)


class CMNIST(SubpopDataset):
    """Colored MNIST built from the MNIST training images.

    The label is ``digit >= 5``, flipped with probability
    ``hparams["cmnist_flip_prob"]``; the attribute is a color bit selecting the
    channel that carries the digit. Splits: ``tr`` = first 30000 images,
    ``va`` = 30000-40000, ``te`` = 40000 onward. For ``tr`` the four
    (label, color) groups are resampled so that, out of ``data_size`` rows,
    fractions ``cmnist_spur_prob`` have color == label, ``cmnist_attr_prob``
    have color == 0 and ``cmnist_label_prob`` have label == 0.
    """

    N_STEPS = 5001
    CHECKPOINT_FREQ = 200
    # INPUT_SHAPE = (
    #     3,
    #     224,
    #     224,
    # )
    # NOTE(review): with the default downsample_pixel=True the stored inputs are
    # (3, 14, 14), not the (3, 28, 28) declared here; confirm which shape the
    # model-building code expects.
    INPUT_SHAPE = (
        3,
        28,
        28,
    )
    N_WORKERS = 4
    data_type = "images"

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata=None,
        downsample_pixel=True,
    ):
        """Build one split.

        ``hparams`` keys read: ``data_seed``, ``cmnist_flip_prob``,
        ``oversample``, ``undersample`` and, for ``tr``, ``data_size``,
        ``cmnist_spur_prob``, ``cmnist_attr_prob``, ``cmnist_label_prob``.
        ``train_attr`` and ``metadata`` are accepted for interface parity but
        unused. Raises NotImplementedError for an unknown ``split``.
        """
        root = Path(data_path) / "cmnist"
        mnist = datasets.MNIST(root, train=True)
        X, y = mnist.data, mnist.targets

        if split == "tr":
            X, y = X[:30000], y[:30000]
        elif split == "va":
            X, y = X[30000:40000], y[30000:40000]
        elif split == "te":
            X, y = X[40000:], y[40000:]
        else:
            raise NotImplementedError

        # All randomness below (label noise, colors, group resampling) draws
        # from this generator so the data are reproducible from data_seed.
        rng = np.random.default_rng(hparams["data_seed"])

        self.binary_label = np.bitwise_xor(
            y >= 5, (rng.random(len(y)) < hparams["cmnist_flip_prob"])
        ).numpy()
        # label XOR fair coin: the color is independent of the label before
        # any resampling.
        self.color = np.bitwise_xor(self.binary_label, (rng.random(len(y)) < 0.5))
        # Channels are (X, X, 0); zeroing channel 1 - color leaves the digit
        # only in channel ``color``.
        self.imgs = torch.stack([X, X, torch.zeros_like(X)], dim=1).numpy()
        self.imgs[list(range(len(self.imgs))), (1 - self.color), :, :] *= 0

        if split == "tr":
            self._resample_train_groups(hparams, rng)

        self.idx = list(range(len(self.color)))
        self.x = torch.from_numpy(self.imgs).float() / 255.0
        if downsample_pixel:
            # Keep every other row and column (halves each spatial dimension).
            self.x = self.x[:, :, ::2, ::2]
        self.y = self.binary_label
        self.a = self.color

        if hparams["oversample"] and split == "tr":
            print("oversampling to balance...")
            # get group labels
            group = self.y * 2 + self.a
            idxs = oversample(group, 4)
            self.x = self.x[idxs]
            self.y = self.y[idxs]
            self.a = self.a[idxs]
            self.idx = list(range(len(self.y)))
        elif hparams["undersample"] and split == "tr":
            print("undersampling to balance...")
            group = self.y * 2 + self.a
            idxs = undersample(group, 4)
            self.x = self.x[idxs]
            self.y = self.y[idxs]
            self.a = self.a[idxs]
            self.idx = list(range(len(self.y)))

        self.transform_ = transforms.Compose(
            [
                # transforms.Resize((224, 224)),
                transforms.Normalize((0.1307, 0.1307, 0.0), (0.3081, 0.3081, 0.3081)),
            ]
        )
        self._count_groups()

        print("group size: ", self.group_sizes)
        print("class size: ", self.class_sizes)
        print("Total size:", sum(self.group_sizes))

        if subsample_type is not None:
            self.subsample(subsample_type)

        if duplicates is not None:
            self.duplicate(duplicates)

    def _resample_train_groups(self, hparams, rng):
        """Keep, per (label, color) group, the count implied by the cmnist hparams.

        Group id is ``2 * label + color``. With N = ``data_size`` the counts
        n0..n3 solve::

            n0           + n3 = cmnist_spur_prob  * N   (color == label)
            n0      + n2      = cmnist_attr_prob  * N   (color == 0)
            n0 + n1           = cmnist_label_prob * N   (label == 0)
            n0 + n1 + n2 + n3 = N

        Raises:
            ValueError: if a required count is negative or exceeds the rows
                available in that group.
        """
        group_matrix = np.array(
            [
                [1, 0, 0, 1],
                [1, 0, 1, 0],
                [1, 1, 0, 0],
                [1, 1, 1, 1],
            ]
        )
        n_total = hparams["data_size"]
        b = [
            hparams["cmnist_spur_prob"] * n_total,
            hparams["cmnist_attr_prob"] * n_total,
            hparams["cmnist_label_prob"] * n_total,
            n_total,
        ]
        solution = np.linalg.solve(group_matrix, b)
        # Truncate to whole samples; the epsilon stops LU round-off just below
        # an integer (e.g. 449.99999999999994) from losing a sample.
        num_per_group = np.floor(solution + 1e-6).astype(int)
        print("number per group to be sampled:", num_per_group)

        groups = self.binary_label.astype(int) * 2 + self.color.astype(int)
        chosen = []
        for g, n in enumerate(num_per_group):
            pool = np.flatnonzero(groups == g)
            if not 0 <= n <= len(pool):
                raise ValueError(
                    f"CMNIST group {g} needs {n} samples but {len(pool)} are "
                    "available; check cmnist_*_prob and data_size"
                )
            # Seeded generator (not the global np.random state) so the training
            # subset is reproducible from data_seed.
            chosen.append(rng.choice(pool, n, replace=False))
        group_idx = np.concatenate(chosen)
        self.color = self.color[group_idx]
        self.binary_label = self.binary_label[group_idx]
        self.imgs = self.imgs[group_idx]

    def _subsample(self, mask, n_samples, rng):
        """Keep all rows outside ``mask`` plus ``n_samples`` random rows inside it.

        Not called by ``__init__`` in this class. The kept rows are shuffled
        with ``rng``, and ``imgs``/``color``/``binary_label`` are reindexed
        together.
        """
        assert n_samples <= mask.sum()
        idxs = np.concatenate(
            (
                np.nonzero(~mask)[0],
                rng.choice(np.nonzero(mask)[0], size=n_samples, replace=False),
            )
        )
        rng.shuffle(idxs)
        self.imgs = self.imgs[idxs]
        self.color = self.color[idxs]
        self.binary_label = self.binary_label[idxs]

    def transform(self, x):
        """Normalize an already-tensorized image."""
        return self.transform_(x)


class Waterbirds(SubpopDataset):
    """Waterbirds images with ``y``/``a`` read from a metadata CSV.

    Images live under
    ``<data_path>/waterbirds/waterbird_complete95_forest2water2`` and the CSV
    at ``<data_path>/waterbirds/<metadata>``.
    """

    CHECKPOINT_FREQ = 300
    INPUT_SHAPE = (3, 224, 224)

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_waterbirds.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        base_dir = os.path.join(data_path, "waterbirds")
        image_dir = os.path.join(base_dir, "waterbird_complete95_forest2water2")
        metadata_csv = os.path.join(base_dir, metadata)

        # Square resize to 224 * 256 / 224, then center-crop to 224.
        side = int(224 * (256 / 224))
        pipeline = transforms.Compose(
            [
                transforms.Resize((side, side)),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.data_type = "images"
        super().__init__(
            image_dir,
            split,
            metadata_csv,
            pipeline,
            train_attr,
            subsample_type,
            duplicates,
        )

    def transform(self, x):
        """Open the image at path ``x`` as RGB and run the preprocessing pipeline."""
        rgb = Image.open(x).convert("RGB")
        return self.transform_(rgb)


class COCO(SubpopDataset):
    """Cat/dog images whose metadata CSV path comes from ``hparams["metadata"]``.

    An image is looked up at ``<data_path>/<split dir>/dog/<filename>`` and,
    if that file does not exist, at ``<data_path>/<split dir>/cat/<filename>``.
    Only ``tr`` (``train``) and ``te`` (``val_out_of_domain``) have image
    directories. Class-level settings (including ``INPUT_SHAPE = None``) are
    inherited; earlier overrides are kept commented out below.
    """

    # # N_STEPS = 30001
    # N_STEPS = 1001
    # # CHECKPOINT_FREQ = 1000
    # CHECKPOINT_FREQ = 50
    # INPUT_SHAPE = (
    #     3,
    #     224,
    #     224,
    # )
    # # N_WORKERS = 8
    # N_WORKERS = 4

    # Image sub-directory used for each supported split.
    _SPLIT_DIRS = {"tr": "train", "te": "val_out_of_domain"}

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
    ):
        """Build one split.

        ``hparams`` keys read: ``metadata`` (required CSV path),
        ``oversample`` and ``undersample``.

        Raises:
            ValueError: if ``hparams["metadata"]`` is None, or if a row must be
                resolved for a split without an image directory.
        """
        if hparams["metadata"] is None:
            raise ValueError("COCO requires hparams['metadata'] (metadata CSV path)")
        root = data_path
        metadata = hparams["metadata"]
        print("use custom metadata file....")
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )
        self.data_type = "images"

        def extract_path(x, split):
            # Previously an unsupported split (e.g. "va") hit an unbound local.
            if split not in self._SPLIT_DIRS:
                raise ValueError(f"COCO has no image directory for split {split!r}")
            split_dir = os.path.join(root, self._SPLIT_DIRS[split])
            path = os.path.join(split_dir, "dog", x)
            if not os.path.exists(path):
                path = os.path.join(split_dir, "cat", x)
            return path

        df = pd.read_csv(metadata)
        df = df[df["split"] == (self.SPLITS[split])]
        print("metadata example: ", df.head())

        self.idx = list(range(len(df)))
        # Need to adapt to specific file structure to improve I/O performance
        self.x = (
            df["filename"].astype(str).map(lambda x: extract_path(x, split)).to_numpy()
        )
        self.y = df["y"].to_numpy()
        # A numpy array in both branches: the rebalancing code below
        # fancy-indexes self.a, which a plain list does not support.
        self.a = (
            df["a"].to_numpy() if train_attr == "yes" else np.zeros(len(df), dtype=int)
        )
        self.transform_ = transform

        if subsample_type is not None:
            # subsample() reads group_sizes / num_attributes / num_labels,
            # which only exist once _count_groups() has run.
            self._count_groups()
            self.subsample(subsample_type)

        if duplicates is not None:
            self.duplicate(duplicates)

        # NOTE(review): rebalancing regroups over *all* rows and resets
        # self.idx, discarding any subsample()/duplicate() applied above.
        if hparams["oversample"] and split == "tr":
            print("oversampling to balance...")
            # get group labels
            group = self.y * 2 + self.a
            idxs = oversample(group, 4)
            self.x = self.x[idxs]
            self.y = self.y[idxs]
            self.a = self.a[idxs]
            self.idx = list(range(len(self.y)))
        elif hparams["undersample"] and split == "tr":
            print("undersampling to balance...")
            group = self.y * 2 + self.a
            idxs = undersample(group, 4)
            self.x = self.x[idxs]
            self.y = self.y[idxs]
            self.a = self.a[idxs]
            self.idx = list(range(len(self.y)))

        self._count_groups()
        print("group size: ", self.group_sizes)
        print("class size: ", self.class_sizes)

    def transform(self, x):
        """Open the image at path ``x`` as RGB and apply the augmentation pipeline."""
        return self.transform_(Image.open(x).convert("RGB"))


class CelebA(SubpopDataset):
    """CelebA face images with ``y``/``a`` read from a metadata CSV.

    Images are read from
    ``<data_path>/CelebFaces/img_align_celeba/<filename[:3]>/<filename>``.
    ``hparams`` keys read: ``metadata`` (custom CSV path or ``None``),
    ``oversample`` and ``undersample``.
    """

    # N_STEPS = 30001
    N_STEPS = 1001
    # CHECKPOINT_FREQ = 1000
    CHECKPOINT_FREQ = 50
    INPUT_SHAPE = (
        3,
        224,
        224,
    )
    # N_WORKERS = 8
    N_WORKERS = 4

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
    ):
        """Build one split.

        Rows are loaded here directly; ``SubpopDataset.__init__`` is not called.
        """
        root = os.path.join(data_path, "CelebFaces", "img_align_celeba")
        if hparams["metadata"] is not None:
            metadata = hparams["metadata"]
            print("use custom metadata file....")
        else:
            metadata = os.path.join(data_path, "celeba", "metadata_celeba.csv")
        transform = transforms.Compose(
            [
                transforms.CenterCrop(178),
                transforms.Resize(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        self.data_type = "images"

        df = pd.read_csv(metadata)
        df = df[df["split"] == (self.SPLITS[split])]
        print("metadata example: ", df.head())

        self.idx = list(range(len(df)))
        # Images are sharded into sub-directories named by the first three
        # characters of the filename.
        self.x = (
            df["filename"]
            .astype(str)
            .map(lambda x: os.path.join(root, x[:3], x))
            .to_numpy()
        )
        self.y = df["y"].to_numpy()
        # A numpy array in both branches: the rebalancing code below
        # fancy-indexes self.a, which a plain list does not support.
        self.a = (
            df["a"].to_numpy() if train_attr == "yes" else np.zeros(len(df), dtype=int)
        )
        self.transform_ = transform

        if subsample_type is not None:
            # subsample() reads group_sizes / num_attributes / num_labels,
            # which only exist once _count_groups() has run.
            self._count_groups()
            self.subsample(subsample_type)

        if duplicates is not None:
            self.duplicate(duplicates)

        # NOTE(review): rebalancing regroups over *all* rows and resets
        # self.idx, discarding any subsample()/duplicate() applied above.
        if hparams["oversample"] and split == "tr":
            print("oversampling to balance...")
            # get group labels
            group = self.y * 2 + self.a
            idxs = oversample(group, 4)
            self.x = self.x[idxs]
            self.y = self.y[idxs]
            self.a = self.a[idxs]
            self.idx = list(range(len(self.y)))
        elif hparams["undersample"] and split == "tr":
            print("undersampling to balance...")
            group = self.y * 2 + self.a
            idxs = undersample(group, 4)
            self.x = self.x[idxs]
            self.y = self.y[idxs]
            self.a = self.a[idxs]
            self.idx = list(range(len(self.y)))

        self._count_groups()
        print("group size: ", self.group_sizes)
        print("class size: ", self.class_sizes)

    def transform(self, x):
        """Open the image at path ``x`` as RGB and apply the preprocessing pipeline."""
        return self.transform_(Image.open(x).convert("RGB"))


class MultiNLI(SubpopDataset):
    """MultiNLI using pre-tokenized ``bert-base-uncased`` feature caches.

    The caches under ``<data_path>/multinli/glue_data/MNLI`` are concatenated
    in train, dev, dev-mismatched order; metadata ``filename`` entries index
    rows of ``x_array``.
    """

    N_STEPS = 30001
    CHECKPOINT_FREQ = 1000

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_multinli.csv",
    ):
        """Load the feature caches and the rows of ``split``.

        Only ``hparams["text_arch"] == "bert-base-uncased"`` is accepted.
        """
        multinli_dir = os.path.join(data_path, "multinli")
        cache_dir = os.path.join(multinli_dir, "glue_data", "MNLI")
        metadata = os.path.join(multinli_dir, metadata)

        self.features_array = []
        assert hparams["text_arch"] == "bert-base-uncased"
        cache_names = (
            "cached_train_bert-base-uncased_128_mnli",
            "cached_dev_bert-base-uncased_128_mnli",
            "cached_dev_bert-base-uncased_128_mnli-mm",
        )
        for cache_name in cache_names:
            # NOTE(review): these caches hold pickled feature objects; newer
            # torch releases default torch.load to weights_only=True, which may
            # refuse them. Confirm against the installed torch version.
            self.features_array.extend(torch.load(os.path.join(cache_dir, cache_name)))

        def field_tensor(field):
            # One long tensor per feature attribute, row-aligned with features_array.
            return torch.tensor(
                [getattr(feat, field) for feat in self.features_array],
                dtype=torch.long,
            )

        self.all_input_ids = field_tensor("input_ids")
        self.all_input_masks = field_tensor("input_mask")
        self.all_segment_ids = field_tensor("segment_ids")
        self.all_label_ids = field_tensor("label_id")
        # Stack ids / mask / segment ids along a new last dimension.
        self.x_array = torch.stack(
            (self.all_input_ids, self.all_input_masks, self.all_segment_ids), dim=2
        )
        self.data_type = "text"
        super().__init__(
            "", split, metadata, self.transform, train_attr, subsample_type, duplicates
        )

    def transform(self, i):
        """Return the stacked feature row for row id ``i``."""
        row = int(i)
        return self.x_array[row]


class CivilComments(SubpopDataset):
    """CivilComments text, tokenized on the fly with the tokenizer of ``text_arch``.

    Reads ``<data_path>/civilcomments/civilcomments_<granularity>.csv`` (the
    ``comment_text`` column) and
    ``<data_path>/civilcomments/metadata_civilcomments_<granularity>.csv``;
    metadata ``filename`` entries index ``text_array``.
    """

    N_STEPS = 30001
    CHECKPOINT_FREQ = 1000

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        granularity="coarse",
    ):
        """Load comments and build the tokenizer for ``hparams["text_arch"]``.

        Raises:
            NotImplementedError: for an unsupported ``text_arch``.
        """
        text = pd.read_csv(
            os.path.join(
                data_path, "civilcomments/civilcomments_{}.csv".format(granularity)
            )
        )
        metadata = os.path.join(
            data_path,
            "civilcomments",
            "metadata_civilcomments_{}.csv".format(granularity),
        )

        self.text_array = list(text["comment_text"])
        arch = hparams["text_arch"]
        if arch == "bert-base-uncased":
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        elif arch in ("xlm-roberta-base", "allenai/scibert_scivocab_uncased"):
            self.tokenizer = AutoTokenizer.from_pretrained(arch)
        elif arch == "distilbert-base-uncased":
            self.tokenizer = DistilBertTokenizer.from_pretrained(
                "distilbert-base-uncased"
            )
        elif arch == "gpt2":
            self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            # GPT-2 has no pad token; reuse EOS so max_length padding works.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            raise NotImplementedError(f"Unsupported text_arch for CivilComments: {arch}")
        self.data_type = "text"
        super().__init__(
            "", split, metadata, self.transform, train_attr, subsample_type, duplicates
        )

    def transform(self, i):
        """Tokenize comment ``i`` (padded/truncated to 220 tokens) and stack its fields.

        Fields are stacked in the order input_ids, attention_mask and, when the
        tokenizer produces it, token_type_ids. This checks for the key instead
        of inferring it from the number of fields returned. Assumes each field
        is shaped (1, 220) for a single string; the leading dim is squeezed.
        """
        text = self.text_array[int(i)]
        tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=220,
            return_tensors="pt",
        )
        fields = [tokens["input_ids"], tokens["attention_mask"]]
        if "token_type_ids" in tokens:
            fields.append(tokens["token_type_ids"])
        return torch.squeeze(torch.stack(fields, dim=2), dim=0)


class CivilCommentsFine(CivilComments):
    """CivilComments backed by the ``fine`` granularity CSV files."""

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
    ):
        """Same arguments as ``CivilComments`` with granularity fixed to ``fine``."""
        super().__init__(
            data_path,
            split,
            hparams,
            train_attr=train_attr,
            subsample_type=subsample_type,
            duplicates=duplicates,
            granularity="fine",
        )


class MetaShift(SubpopDataset):
    """MetaShift images listed in ``<data_path>/metashift/<metadata>``.

    The root passed to the base class is ``/``, so ``filename`` entries are
    presumably absolute paths (verify against the metadata CSV).
    """

    CHECKPOINT_FREQ = 300
    INPUT_SHAPE = (3, 224, 224)

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_metashift.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        metadata_csv = os.path.join(data_path, "metashift", metadata)

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        pipeline = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
        )
        self.data_type = "images"
        super().__init__(
            "/", split, metadata_csv, pipeline, train_attr, subsample_type, duplicates
        )

    def transform(self, x):
        """Open the image at path ``x`` as RGB and run the preprocessing pipeline."""
        rgb = Image.open(x).convert("RGB")
        return self.transform_(rgb)


class ImagenetBG(SubpopDataset):
    """Backgrounds-challenge images listed in ``<data_path>/backgrounds_challenge``.

    The metadata ``split`` column holds string names (see ``SPLITS``). Besides
    ``te``, the ``mixed_rand``, ``no_fg`` and ``only_fg`` splits are evaluated.
    """

    INPUT_SHAPE = (3, 224, 224)
    SPLITS = {
        "tr": "train",
        "va": "val",
        "te": "test",
        "mixed_rand": "mixed_rand",
        "no_fg": "no_fg",
        "only_fg": "only_fg",
    }
    EVAL_SPLITS = ["te", "mixed_rand", "no_fg", "only_fg"]
    N_STEPS = 10001
    CHECKPOINT_FREQ = 500

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        metadata_csv = os.path.join(data_path, "backgrounds_challenge", metadata)

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        pipeline = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(channel_mean, channel_std),
            ]
        )
        self.data_type = "images"
        super().__init__(
            "/", split, metadata_csv, pipeline, train_attr, subsample_type, duplicates
        )

    def transform(self, x):
        """Open the image at path ``x`` as RGB and run the preprocessing pipeline."""
        rgb = Image.open(x).convert("RGB")
        return self.transform_(rgb)


#####################################################################################
#####################################################################################
#####################################################################################
#####################################################################################
class BaseImageDataset(SubpopDataset):
    """Image dataset whose metadata CSV holds paths resolved against ``/``.

    Shared preprocessing: resize 256, center-crop 224, tensor, normalize.
    """

    def __init__(
        self, metadata, split, train_attr="yes", subsample_type=None, duplicates=None
    ):
        """Load ``split`` from the CSV at ``metadata`` with the shared pipeline."""
        self.data_type = "images"
        pipeline = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        super().__init__(
            "/", split, metadata, pipeline, train_attr, subsample_type, duplicates
        )

    def transform(self, x):
        """Load image ``x`` as RGB (preferring a downsampled PNG copy) and preprocess.

        For ``MIMICNoFinding``/``CXRMultisite`` paths containing
        ``MIMIC-CXR-JPG``, the fifth path component from the end is replaced by
        ``downsampled_files`` and the suffix by ``.png``. That file is used if
        it exists; otherwise the original path is used.
        """
        source = x
        is_mimic_class = self.__class__.__name__ in ("MIMICNoFinding", "CXRMultisite")
        if is_mimic_class and "MIMIC-CXR-JPG" in source:
            parts = list(Path(source).parts)
            parts[-5] = "downsampled_files"
            downsampled = Path(*parts).with_suffix(".png")
            if downsampled.is_file():
                source = str(downsampled.resolve())

        image = Image.open(source).convert("RGB")
        return self.transform_(image)


class NICOpp(BaseImageDataset):
    """NICOpp images listed in ``<data_path>/nicopp/<metadata>``."""

    N_STEPS = 30001
    CHECKPOINT_FREQ = 1000
    INPUT_SHAPE = (3, 224, 224)

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        csv_path = os.path.join(data_path, "nicopp", metadata)
        super().__init__(
            csv_path,
            split,
            train_attr=train_attr,
            subsample_type=subsample_type,
            duplicates=duplicates,
        )


class MIMICNoFinding(BaseImageDataset):
    """MIMIC-CXR images listed in ``<data_path>/MIMIC-CXR-JPG/subpop_bench_meta``."""

    N_STEPS = 20001
    CHECKPOINT_FREQ = 1000
    N_WORKERS = 16
    INPUT_SHAPE = (3, 224, 224)

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_no_finding.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        meta_dir = os.path.join(data_path, "MIMIC-CXR-JPG", "subpop_bench_meta")
        super().__init__(
            os.path.join(meta_dir, metadata),
            split,
            train_attr=train_attr,
            subsample_type=subsample_type,
            duplicates=duplicates,
        )


class CheXpertNoFinding(BaseImageDataset):
    """CheXpert images listed in ``<data_path>/chexpert/subpop_bench_meta``."""

    N_STEPS = 20001
    CHECKPOINT_FREQ = 1000
    N_WORKERS = 16
    INPUT_SHAPE = (3, 224, 224)

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_no_finding.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        meta_dir = os.path.join(data_path, "chexpert", "subpop_bench_meta")
        super().__init__(
            os.path.join(meta_dir, metadata),
            split,
            train_attr=train_attr,
            subsample_type=subsample_type,
            duplicates=duplicates,
        )


class CXRMultisite(BaseImageDataset):
    """Multi-site chest X-ray metadata from ``MIMIC-CXR-JPG/subpop_bench_meta``.

    Adds a ``deploy`` split (value 3) that is evaluated alongside ``te``.
    """

    N_STEPS = 20001
    CHECKPOINT_FREQ = 1000
    N_WORKERS = 16
    INPUT_SHAPE = (3, 224, 224)
    SPLITS = {"tr": 0, "va": 1, "te": 2, "deploy": 3}
    EVAL_SPLITS = ["te", "deploy"]

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_multisite.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        meta_dir = os.path.join(data_path, "MIMIC-CXR-JPG", "subpop_bench_meta")
        super().__init__(
            os.path.join(meta_dir, metadata),
            split,
            train_attr=train_attr,
            subsample_type=subsample_type,
            duplicates=duplicates,
        )


class MIMICNotes(SubpopDataset):
    """Tabular features from ``<data_path>/mimic_notes/features.npy``.

    Metadata ``filename`` entries are row indices into the feature matrix.
    """

    N_STEPS = 10001
    CHECKPOINT_FREQ = 200
    INPUT_SHAPE = (10000,)

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="yes",
        subsample_type=None,
        duplicates=None,
        metadata="metadata.csv",
    ):
        """Load the feature matrix and the rows of ``split``.

        Only ``hparams["text_arch"] == "bert-base-uncased"`` is accepted.
        """
        assert hparams["text_arch"] == "bert-base-uncased"
        notes_dir = os.path.join(data_path, "mimic_notes")
        metadata = os.path.join(notes_dir, "subpop_bench_meta", metadata)
        self.x_array = np.load(os.path.join(notes_dir, "features.npy"))
        self.data_type = "tabular"
        super().__init__(
            "", split, metadata, self.transform, train_attr, subsample_type, duplicates
        )

    def transform(self, x):
        """Return feature row ``x`` as a float32 copy."""
        row = int(x)
        return self.x_array[row, :].astype("float32")


class BREEDSBase(BaseImageDataset):
    """Shared settings for the BREEDS datasets defined below.

    Adds a ``zs`` split (value 3) that is evaluated alongside ``te``.
    """

    N_STEPS = 60001
    CHECKPOINT_FREQ = 2000
    N_WORKERS = 16
    INPUT_SHAPE = (3, 224, 224)
    SPLITS = {"tr": 0, "va": 1, "te": 2, "zs": 3}
    EVAL_SPLITS = ["te", "zs"]


class Living17(BREEDSBase):
    """BREEDS Living-17; attributes are zeroed unless ``train_attr="yes"``."""

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="no",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_living17.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        super().__init__(
            os.path.join(data_path, "breeds", metadata),
            split,
            train_attr=train_attr,
            subsample_type=subsample_type,
            duplicates=duplicates,
        )


class Entity13(BREEDSBase):
    """BREEDS Entity-13; attributes are zeroed unless ``train_attr="yes"``."""

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="no",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_entity13.csv",
    ):
        """Build one split; ``hparams`` is accepted for interface parity only."""
        super().__init__(
            os.path.join(data_path, "breeds", metadata),
            split,
            train_attr=train_attr,
            subsample_type=subsample_type,
            duplicates=duplicates,
        )


class Entity30(BREEDSBase):
    """BREEDS Entity-30; attributes are zeroed unless ``train_attr="yes"``.

    Accepts a ``metadata`` file name like the other BREEDS datasets; the
    default keeps the previously hard-coded ``metadata_entity30.csv``.
    """

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="no",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_entity30.csv",
    ):
        """Build one split from ``<data_path>/breeds/<metadata>``.

        ``hparams`` is accepted for interface parity only.
        """
        metadata = os.path.join(data_path, "breeds", metadata)
        super().__init__(metadata, split, train_attr, subsample_type, duplicates)


class Nonliving26(BREEDSBase):
    """BREEDS Non-living-26; attributes are zeroed unless ``train_attr="yes"``.

    Accepts a ``metadata`` file name like the other BREEDS datasets; the
    default keeps the previously hard-coded ``metadata_nonliving26.csv``.
    """

    def __init__(
        self,
        data_path,
        split,
        hparams,
        train_attr="no",
        subsample_type=None,
        duplicates=None,
        metadata="metadata_nonliving26.csv",
    ):
        """Build one split from ``<data_path>/breeds/<metadata>``.

        ``hparams`` is accepted for interface parity only.
        """
        metadata = os.path.join(data_path, "breeds", metadata)
        super().__init__(metadata, split, train_attr, subsample_type, duplicates)


def oversample(g, n_groups, rng=None):
    """Return row indices that upsample every non-empty group to the largest size.

    Args:
        g: 1-D numpy array of group ids.
        n_groups: Groups ``0 .. n_groups - 1`` are considered.
        rng: Optional ``numpy.random.Generator``; defaults to the global
            ``np.random`` state (the sequence of draws then matches the
            previous implementation).

    Returns:
        Index array ordered by group id. A group with ``count`` members, where
        ``count`` is below the largest group size ``target``, contributes
        ``target // count`` full copies of its indices plus
        ``target % count`` distinct random extras. Empty groups contribute
        nothing; previously they caused a ZeroDivisionError.
    """
    choice = np.random.choice if rng is None else rng.choice
    members = [np.where(g == group_idx)[0] for group_idx in range(n_groups)]
    target = max((len(m) for m in members), default=0)

    resampled_idx = []
    for idx in members:
        count = len(idx)
        if count == 0:
            continue
        if count < target:
            full_copies, extra = divmod(target, count)
            resampled_idx.extend([idx] * full_copies)
            # Called even when extra == 0 so the RNG stream is unchanged.
            resampled_idx.append(choice(idx, extra, replace=False))
        else:
            resampled_idx.append(idx)

    if not resampled_idx:
        return np.array([], dtype=np.intp)
    return np.concatenate(resampled_idx)


def undersample(g, n_groups, rng=None):
    """Return row indices that downsample every non-empty group to the smallest size.

    Args:
        g: 1-D numpy array of group ids.
        n_groups: Groups ``0 .. n_groups - 1`` are considered.
        rng: Optional ``numpy.random.Generator``; defaults to the global
            ``np.random`` state.

    Returns:
        Index array ordered by group id, holding ``m`` distinct random members
        of each non-empty group, where ``m`` is the smallest non-empty group
        size. Empty groups are ignored; previously one empty group made ``m``
        zero and the result empty.
    """
    choice = np.random.choice if rng is None else rng.choice
    members = [np.where(g == group_idx)[0] for group_idx in range(n_groups)]
    nonempty = [idx for idx in members if len(idx) > 0]
    if not nonempty:
        return np.array([], dtype=np.intp)
    target = min(len(idx) for idx in nonempty)
    return np.concatenate([choice(idx, target, replace=False) for idx in nonempty])
