# Copyright (c) 2021 Copyright holder of the paper "Test-Time Adaptation to Distribution Shifts by Confidence Maximization and Input Transformation" submitted to NeurIPS 2021 for review
# All rights reserved.

import os
import itertools
import copy

import numpy as np
import h5py
import torch
from torchvision import transforms
from torch.utils.data import Dataset
import pandas as pd


def corrupted_imagenet(root, corruption, severity):
    """Open the HDF5 dataset for one corruption at one severity level.

    The file is looked up at ``<root>/<corruption>/<severity>/dataset.h5``.

    Arguments:
        root: Directory that contains one sub-directory per corruption.
        corruption: Name of the corruption sub-directory.
        severity: Severity level; converted with ``str`` to form the path.

    Returns:
        An ``H5Dataset`` reading from the resolved file.
    """
    path_parts = (root, corruption, str(severity), "dataset.h5")
    h5_path = os.path.join(*path_parts)
    return H5Dataset(h5_path)


def get_datasets(config) -> dict:
    """Build the evaluation datasets described by ``config``.

    Two configuration shapes are recognised, checked in this order:

    1. ``config["data_dir"]`` is a path ending in ``.h5``: a single
       ``H5Dataset`` is returned. Its key is the name of the file's parent
       directory with ``-`` replaced by ``_`` and ``-0`` appended.
    2. ``config["imagenet_c"]`` is present: one dataset is created for every
       combination of ``corruptions`` and ``severities``, read from
       ``config["imagenet_c"]["data_dir"]`` and keyed ``"<corruption>-<severity>"``.

    Arguments:
        config: Mapping-like configuration supporting ``get``, ``in`` and
            item access.

    Returns:
        Dict mapping dataset name to dataset. Empty if ``config`` matches
        neither shape, so the declared ``dict`` return type always holds.
    """
    # ``or ""`` also covers an explicit ``data_dir: None`` entry, which would
    # otherwise fail on ``.endswith``.
    data_dir = config.get("data_dir") or ""
    if data_dir.endswith(".h5"):
        dataset_name = data_dir.split("/")[-2]
        dataset_name = dataset_name.replace("-", "_")
        # Fixed suffix in the position where the ImageNet-C keys carry the severity.
        dataset_name += "-0"
        return {dataset_name: H5Dataset(data_dir)}

    if "imagenet_c" in config:
        imagenet_c_config = config["imagenet_c"]
        imagenet_c_datasets = {}
        for corruption, severity in itertools.product(
            imagenet_c_config["corruptions"], imagenet_c_config["severities"]
        ):
            imagenet_c_datasets[f"{corruption}-{severity}"] = corrupted_imagenet(
                imagenet_c_config["data_dir"], corruption, severity
            )
        return imagenet_c_datasets

    # Nothing configured: return an empty mapping instead of an implicit None.
    return {}


def split_dataset(dataset, by: str, frac: float, random_state=None):
    """Split dataset

    Arguments:
        dataset: Dataset to split. Either an ``H5Dataset`` or an object with a
            ``samples`` list of ``(item, class)`` pairs (presumably a
            torchvision ``DatasetFolder``-style dataset; verify against callers).
        by: Split on ``class`` or ``samples_per_class``
            class: The classes will be split between the datasets
            samples_per_class: Both datasets will have a fraction of samples of all classes
        frac: Fraction of items in dataset 1
        random_state: Random seed passed to pandas sample

    Returns:
        Two datasets split either by samples or class. For an ``H5Dataset``
        these are ``H5DatasetSubset`` views; otherwise they are shallow copies
        of ``dataset`` with ``samples``, ``imgs`` and (when present)
        ``targets`` replaced.

    Raises:
        ValueError: If ``by`` is neither ``class`` nor ``samples_per_class``.
    """
    # Reject a bad mode before touching the dataset.
    if by not in ("class", "samples_per_class"):
        raise ValueError(
            f"Can only split by `class` or `samples_per_class`, got `{by}`"
        )

    is_h5 = isinstance(dataset, H5Dataset)
    if is_h5:
        df = pd.DataFrame({"class": list(dataset.labels)})
    else:
        # The first column holds the per-sample item, the second its class.
        df = pd.DataFrame(dataset.samples, columns=["samples_per_class", "class"])

    if by == "samples_per_class":
        # Draw ``frac`` of the rows of every class; the remainder forms split 2.
        samples_1 = df.groupby("class").sample(frac=frac, random_state=random_state)
        samples_2 = df.drop(samples_1.index)
    else:
        # Draw ``frac`` of the distinct classes; rows follow their class.
        all_classes = pd.Series(df["class"].unique())
        classes_1 = all_classes.sample(frac=frac, random_state=random_state)
        mask = df["class"].isin(classes_1)
        samples_1 = df[mask]
        samples_2 = df[~mask]

    if is_h5:
        # The frame uses the default positional index, so ``.index`` holds the
        # row positions inside the HDF5 file.
        return (
            H5DatasetSubset(dataset, samples_1.index),
            H5DatasetSubset(dataset, samples_2.index),
        )

    dataset_1 = copy.copy(dataset)
    dataset_2 = copy.copy(dataset)
    for ds, samples in ((dataset_1, samples_1), (dataset_2, samples_2)):
        ds.samples = list(samples.itertuples(index=False, name=None))
        ds.imgs = ds.samples
        # Keep a per-sample ``targets`` list in step with the new samples so
        # it does not keep describing the unsplit dataset.
        if hasattr(ds, "targets"):
            ds.targets = [label for _, label in ds.samples]

    return dataset_1, dataset_2


def to_dataloader(dataset, config, shuffle=False, num_workers=12):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Arguments:
        dataset: Dataset to load from.
        config: Mapping providing the batch size under ``"test_time_batch_size"``.
        shuffle: Whether to reshuffle the data; off by default. It is
            important to shuffle the data for entropy minimization.
        num_workers: Number of loader worker processes. Defaults to 12;
            pass 0 to load in the calling process.

    Returns:
        A ``DataLoader`` with pinned-memory batches enabled.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config["test_time_batch_size"],
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=True,
    )


class H5Dataset(Dataset):
    """Map-style dataset backed by an HDF5 file holding ``imgs`` and ``labels``.

    Indexing returns ``(transform(img), label)``. Images are read from the
    file on access; labels are loaded into memory once at construction.

    Arguments:
        file_path: Path of the HDF5 file; opened read-only.
        transform: Callable applied to every image. When ``None``, images are
            converted with ``ToTensor`` and normalised with the fixed
            per-channel mean/std given in ``__init__``.

    The instance can be used as a context manager, or ``close()`` can be
    called, to release the file handle deterministically.

    NOTE(review): the file is opened eagerly in ``__init__``. h5py advises
    that each process opens the file itself when reading from several
    processes, so behaviour with multi-worker ``DataLoader``s should be
    confirmed.
    """

    # Size of the label space that ``unused_labels`` is computed against.
    # Override on a subclass or instance for datasets with another class count.
    num_classes = 1000

    def __init__(self, file_path, transform=None):
        self.f = h5py.File(file_path, mode="r")
        self.imgs = self.f["imgs"]
        # Casting to np.array so that fancy-indexing is faster
        self.labels = np.array(self.f["labels"])

        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
        else:
            self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.transform(self.imgs[idx]), self.labels[idx]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying HDF5 file.

        Images can no longer be read afterwards; ``labels`` stays available
        because it was copied into memory.
        """
        self.f.close()

    @property
    def used_labels(self):
        """Distinct labels present in the dataset, in ascending order."""
        # np.unique is vectorised and sorted, unlike building a Python set
        # element by element from the array.
        return list(np.unique(self.labels))

    @property
    def unused_labels(self):
        """Labels in ``range(num_classes)`` that never occur, in ascending order."""
        present = set(np.unique(self.labels).tolist())
        return [label for label in range(self.num_classes) if label not in present]


class H5DatasetSubset(Dataset):
    """A thin wrapper around an H5Dataset that only indexes a subset defined by `self.idx_map`

    Arguments:
        dataset: The wrapped dataset. It must support integer indexing and
            expose a ``labels`` array.
        idx_map: Sequence of positions into ``dataset``; element ``i`` of the
            subset is ``dataset[idx_map[i]]``.
    """

    def __init__(self, dataset: "H5Dataset", idx_map):
        self.dataset = dataset
        # Normalise to an integer array so lookups are always positional
        # (whatever container was passed in) and an empty selection still
        # works as a fancy index.
        self.idx_map = np.asarray(idx_map, dtype=np.intp)

    def __len__(self):
        return len(self.idx_map)

    def __getitem__(self, idx):
        idx = self.idx_map[idx]
        return self.dataset[idx]

    @property
    def used_labels(self):
        """Distinct labels among the selected items, in ascending order."""
        labels = self.dataset.labels[self.idx_map]
        return list(np.unique(labels))

    @property
    def unused_labels(self):
        """Labels of the label space that no selected item carries, ascending.

        The label space is ``range(dataset.num_classes)`` when the wrapped
        dataset defines ``num_classes``, otherwise ``range(1000)``.
        """
        num_classes = getattr(self.dataset, "num_classes", 1000)
        present = set(np.unique(self.dataset.labels[self.idx_map]).tolist())
        return [label for label in range(num_classes) if label not in present]
