"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.

Portions of the source code are from the OLTR project, whose license
notice appears below and in the LICENSE file in the root directory of
this source tree.

Copyright (c) 2019, Zhongqi Miao
All rights reserved.
"""

import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 whose training split is subsampled into a long-tailed distribution.

    With ``phase == "train"`` the training split is subsampled per class
    according to ``imb_type`` / ``imbalance_ratio`` and random crop + flip
    augmentation is applied; any other phase loads the full test split with
    normalization only. ``__getitem__`` returns ``(image, label, index)``.
    """

    cls_num = 10

    def __init__(
        self, phase, imbalance_ratio, root="/gruntdata5/kaihua/datasets", imb_type="exp"
    ):
        """Load (downloading if needed) CIFAR and build the split for ``phase``.

        Args:
            phase: ``"train"`` selects the training split and makes it
                imbalanced; any other value selects the test split.
            imbalance_ratio: ``imb_factor`` passed to ``get_img_num_per_cls``;
                only read when ``phase == "train"``.
            root: Directory holding (or receiving) the dataset files.
            imb_type: ``"exp"`` or ``"step"``; any other value keeps the
                training split balanced.
        """
        train = phase == "train"
        super().__init__(
            root, train, transform=None, target_transform=None, download=True
        )
        self.train = train
        # Same normalization for both phases; built once and shared by the pipeline.
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
        if self.train:
            img_num_list = self.get_img_num_per_cls(
                self.cls_num, imb_type, imbalance_ratio
            )
            self.gen_imbalanced_data(img_num_list)
            self.transform = transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:
            self.transform = transforms.Compose([transforms.ToTensor(), normalize])

        # Alias of ``targets`` (same list object) read by __getitem__, __len__
        # and get_annotations.
        self.labels = self.targets

        print("{} Mode: Contain {} images".format(phase, len(self.data)))

    def _get_class_dict(self):
        """Map each class id to the list of sample indices carrying that label."""
        class_dict = {}
        for i, anno in enumerate(self.get_annotations()):
            class_dict.setdefault(anno["category_id"], []).append(i)
        return class_dict

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        """Return the requested number of images for each of ``cls_num`` classes.

        ``img_max`` is the per-class count if ``self.data`` were split evenly.

        * ``"exp"``: class ``i`` gets ``img_max * imb_factor ** (i / (cls_num - 1))``,
          so (before truncation) the last class has ``imb_factor`` times as many
          images as the first.
        * ``"step"``: the first ``cls_num // 2`` classes get ``img_max`` and the
          remaining classes get ``img_max * imb_factor``.
        * anything else: every class gets ``img_max``.

        Counts are truncated with ``int``.
        """
        img_max = len(self.data) / cls_num
        if imb_type == "exp":
            # Guard the single-class case, which would otherwise divide by zero.
            denom = max(cls_num - 1.0, 1.0)
            return [
                int(img_max * (imb_factor ** (cls_idx / denom)))
                for cls_idx in range(cls_num)
            ]
        if imb_type == "step":
            n_head = cls_num // 2
            # ``cls_num - n_head`` (not ``n_head``) so an odd cls_num still yields
            # exactly one count per class.
            return [int(img_max)] * n_head + [int(img_max * imb_factor)] * (
                cls_num - n_head
            )
        return [int(img_max)] * cls_num

    def gen_imbalanced_data(self, img_num_per_cls):
        """Subsample ``self.data`` / ``self.targets`` in place to per-class counts.

        Classes are taken in ascending label order (``np.unique``) and paired
        positionally with ``img_num_per_cls``. Each class's indices are shuffled
        with the global NumPy RNG before the first ``n`` are kept. Also sets
        ``self.num_per_cls_dict`` (class id -> number of images actually kept).
        """
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)

        self.num_per_cls_dict = {}
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            idx = np.where(targets_np == the_class)[0]
            np.random.shuffle(idx)
            selec_idx = idx[:the_img_num]
            # Count what was actually selected: if the requested number exceeds
            # the class size the slice is shorter, and targets must stay aligned
            # with data row-for-row.
            kept = len(selec_idx)
            class_id = int(the_class)
            self.num_per_cls_dict[class_id] = kept
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([class_id] * kept)
        self.data = np.vstack(new_data)
        self.targets = new_targets

    def __getitem__(self, index):
        """Return ``(transformed_image, label, index)`` for sample ``index``."""
        img, label = self.data[index], self.labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label, index

    def __len__(self):
        return len(self.labels)

    def get_num_classes(self):
        """Return the number of classes (``cls_num``)."""
        return self.cls_num

    def get_annotations(self):
        """Return one ``{"category_id": label}`` dict per sample, in order."""
        return [{"category_id": int(label)} for label in self.labels]

    def get_cls_num_list(self):
        """Return the image count for each class id ``0 .. cls_num - 1``.

        Uses the counts recorded by ``gen_imbalanced_data`` when present
        (training phase); otherwise counts ``self.labels`` directly, so the
        method also works on the test split instead of raising AttributeError.
        """
        counts = getattr(self, "num_per_cls_dict", None)
        if counts is None:
            counts = {}
            for label in self.labels:
                counts[int(label)] = counts.get(int(label), 0) + 1
        return [counts.get(i, 0) for i in range(self.cls_num)]


class IMBALANCECIFAR100(IMBALANCECIFAR10):
    """Long-tailed `CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.

    Subclass of ``IMBALANCECIFAR10``: all imbalance and transform logic is
    inherited; only the class count and the archive/file metadata below
    differ.
    """

    # Number of classes used when building the per-class image counts.
    cls_num = 100

    # Archive to fetch, its checksum, and the folder it extracts to.
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"
    base_folder = "cifar-100-python"

    # [file name, md5] entries for each split.
    train_list = [["train", "16019d7e3df5f24257cddd939b257f8d"]]
    test_list = [["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"]]

    # Metadata file entry; "key" names the fine-grained label list.
    meta = {
        "filename": "meta",
        "key": "fine_label_names",
        "md5": "7973b15100ade9c7d40fb424638fde48",
    }


# Image statistics
RGB_statistics = {
    "iNaturalist18": {"mean": [0.466, 0.471, 0.380], "std": [0.195, 0.194, 0.192]},
    "default": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
}

# Data transformation with augmentation
def get_data_transform(split, rgb_mean, rbg_std, key="default"):
    """Return the torchvision transform pipeline for one data split.

    Args:
        split: ``"train"``, ``"val"`` or ``"test"``; any other value raises
            ``KeyError``.
        rgb_mean: Per-channel means given to ``transforms.Normalize``.
        rbg_std: Per-channel standard deviations given to
            ``transforms.Normalize`` (parameter name kept as-is so keyword
            callers keep working).
        key: Statistics key; ``"iNaturalist18"`` omits colour jitter from the
            training pipeline.

    Returns:
        ``transforms.Compose``: random-resized-crop/flip (+ jitter) for
        ``"train"``, resize-256 + center-crop-224 for ``"val"``/``"test"``,
        each followed by ``ToTensor`` and ``Normalize``.
    """
    if split == "train":
        steps = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
        ]
        if key != "iNaturalist18":
            steps.append(
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0
                )
            )
    elif split in ("val", "test"):
        # Validation and test share the same deterministic pipeline.
        steps = [transforms.Resize(256), transforms.CenterCrop(224)]
    else:
        raise KeyError(split)

    steps += [transforms.ToTensor(), transforms.Normalize(rgb_mean, rbg_std)]
    return transforms.Compose(steps)


# Dataset
class LT_Dataset(Dataset):
    """Image-list dataset for long-tailed benchmarks.

    Each non-blank line of ``txt`` is ``<relative_path> <integer_label>``;
    paths are joined onto ``root``. ``__getitem__`` returns
    ``(sample, label, index)`` where ``sample`` is the image converted to RGB
    and passed through ``transform`` (if any).

    When ``top_k`` is truthy, only the ``top_k`` most frequent classes are kept
    and their labels are remapped to ``0 .. top_k - 1`` in descending frequency
    order. The ranking is computed when the ``txt`` path contains ``"train"``
    and saved to ``template + "_top_{top_k}_mapping"``; other splits load that
    file so every split shares one mapping (build the train split first).
    """

    def __init__(self, root, txt, transform=None, template=None, top_k=None):
        self.img_path = []
        self.labels = []
        self.transform = transform
        with open(txt, encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if not parts:  # tolerate blank / trailing empty lines
                    continue
                self.img_path.append(os.path.join(root, parts[0]))
                self.labels.append(int(parts[1]))
        # select top k class
        if top_k:
            self._select_top_k(txt, template, top_k)

    def _select_top_k(self, txt, template, top_k):
        """Keep only the ``top_k`` most frequent classes and remap their labels.

        Raises:
            ValueError: if ``template`` is None (the mapping file needs a path).
        """
        if template is None:
            raise ValueError("`template` is required when `top_k` is set")
        mapping_file = template + "_top_{}_mapping".format(top_k)

        # Only rank classes on the training list, so train/val/test agree.
        if "train" in os.fspath(txt):
            dist = [[i, 0] for i in range(max(self.labels, default=-1) + 1)]
            for label in self.labels:
                dist[label][-1] += 1
            # list.sort is stable even with reverse=True, so ties keep
            # ascending class-id order.
            dist.sort(key=lambda x: x[1], reverse=True)
            torch.save(dist, mapping_file)
        else:
            # NOTE: torch.load unpickles; only point `template` at trusted files.
            dist = torch.load(mapping_file)

        selected_labels = {item[0]: rank for rank, item in enumerate(dist[:top_k])}

        # new_img_path / new_labels are kept as attributes for compatibility.
        self.new_img_path = []
        self.new_labels = []
        for path, label in zip(self.img_path, self.labels):
            if label in selected_labels:
                self.new_img_path.append(path)
                self.new_labels.append(selected_labels[label])
        self.img_path = self.new_img_path
        self.labels = self.new_labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply ``transform``, return ``(sample, label, index)``."""
        path = self.img_path[index]
        label = self.labels[index]

        with open(path, "rb") as f:
            # convert() forces the pixel data to load before the file closes.
            sample = Image.open(f).convert("RGB")

        if self.transform is not None:
            sample = self.transform(sample)

        return sample, label, index


# Load datasets
def load_data(
    data_root,
    dataset,
    phase,
    batch_size,
    top_k_class=None,
    sampler_dic=None,
    num_workers=4,
    shuffle=True,
    cifar_imb_ratio=None,
):
    """Build the dataset for ``dataset`` / ``phase`` and wrap it in a DataLoader.

    ``"CIFAR100_LT"`` is built with ``IMBALANCECIFAR100`` (made imbalanced
    with ``cifar_imb_ratio`` in the train phase). Every other dataset is read
    from ``./data/<dataset>/<dataset>_<phase>.txt`` through ``LT_Dataset``,
    using the ``"train"`` or ``"val"`` transform for those phases and the
    ``"test"`` transform for anything else.

    A sampler is used only when ``sampler_dic`` is truthy and
    ``phase == "train"``: it is built as
    ``sampler_dic["sampler"](dataset, **sampler_dic["params"])`` and shuffling
    is turned off. Otherwise ``shuffle`` is passed through unchanged.

    Returns:
        torch.utils.data.DataLoader yielding the dataset's items.
    """
    txt = f"./data/{dataset}/{dataset}_{phase}.txt"
    print(f"Loading data from {txt}")

    is_inat = dataset == "iNaturalist18"
    if is_inat:
        print("===> Loading iNaturalist18 statistics")
    key = "iNaturalist18" if is_inat else "default"

    if dataset != "CIFAR100_LT":
        stats = RGB_statistics[key]
        split = phase if phase in ("train", "val") else "test"
        transform = get_data_transform(split, stats["mean"], stats["std"], key)
        print("Use data transformation:", transform)

        template = f"./data/{dataset}/{dataset}"
        set_ = LT_Dataset(
            data_root, txt, transform, template=template, top_k=top_k_class
        )
    else:
        print("====> CIFAR100 Imbalance Ratio: ", cifar_imb_ratio)
        set_ = IMBALANCECIFAR100(phase, imbalance_ratio=cifar_imb_ratio, root=data_root)

    print(len(set_))

    if not (sampler_dic and phase == "train"):
        print("=====> No sampler.")
        print("=====> Shuffle is %s." % (shuffle))
        loader_kwargs = {"shuffle": shuffle}
    else:
        print("=====> Using sampler: ", sampler_dic["sampler"])
        print("=====> Sampler parameters: ", sampler_dic["params"])
        # DataLoader forbids shuffle=True together with an explicit sampler.
        loader_kwargs = {
            "shuffle": False,
            "sampler": sampler_dic["sampler"](set_, **sampler_dic["params"]),
        }

    return DataLoader(
        dataset=set_,
        batch_size=batch_size,
        num_workers=num_workers,
        **loader_kwargs,
    )
