import json
import os
import random
from collections import defaultdict

import albumentations
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from albumentations.pytorch import ToTensor
from torch.utils.data import Dataset


def load_corruption(path):
    """Load a ``.npy`` file and split it along axis 0 into five chunks.

    The chunks are stacked on a new leading axis, so an array with N rows
    (N divisible by 5) comes back with shape ``(5, N // 5, ...)``.
    NOTE(review): the five chunks presumably correspond to the five
    CIFAR-10-C severity levels stored back to back — confirm against the data.

    Args:
        path: Path of the ``.npy`` file to load.

    Returns:
        ``np.ndarray`` whose first axis indexes the five chunks.
    """
    loaded = np.load(path)
    pieces = np.array_split(loaded, 5)
    return np.array(pieces)


class CIFARDataset(Dataset):
    """CIFAR-10-C images grouped by (corruption file, chunk) for paired sampling.

    Every corruption ``.npy`` file is split into 5 chunks by
    ``load_corruption`` and each chunk becomes one group (domain). The chunks
    of ``spatter`` and ``jpeg_compression`` are divided between the splits by
    index; the remaining corruptions are split-specific. Each item is a
    (support, query) pair drawn from the same group.

    Args:
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        root_dir: Prefix prepended verbatim to the split sub-directory, so it
            should end with a path separator.

    Raises:
        ValueError: If ``split`` is unknown, or if the number of loaded labels
            does not match the number of loaded images.
    """

    # Corruption files whose 10 concatenated chunks are shared out across
    # splits via the per-split index lists below.
    _SHARED_CORRUPTIONS = ('spatter', 'jpeg_compression')

    # split -> (sub-directory, split-specific corruptions, indices into the
    # concatenated _SHARED_CORRUPTIONS chunks). The indices must stay lists:
    # a tuple would be interpreted by NumPy as a multi-axis index.
    _SPLITS = {
        'train': ('CIFAR-10-C-new/train/',
                  ('gaussian_noise', 'shot_noise', 'defocus_blur', 'glass_blur',
                   'zoom_blur', 'snow', 'frost', 'brightness', 'contrast',
                   'pixelate'),
                  [0, 1, 2, 5, 6, 7]),
        'val': ('CIFAR-10-C-new/val/',
                ('speckle_noise', 'gaussian_blur', 'saturate'),
                [3, 9]),
        'test': ('CIFAR-10-C/',
                 ('impulse_noise', 'motion_blur', 'fog', 'elastic_transform'),
                 [4, 8]),
    }

    def __init__(self, split, root_dir):
        if split not in self._SPLITS:
            raise ValueError(
                f"split must be one of {sorted(self._SPLITS)}, got {split!r}")
        sub_dir, corruptions, other_idx = self._SPLITS[split]
        self.root_dir = root_dir + sub_dir

        print("Loading CIFAR-10-C")
        # Axis 0 indexes groups, axis 1 the images within a group.
        grouped_x = self._load_grouped(corruptions, other_idx, '.npy')

        self.n_groups = grouped_x.shape[0]
        # Derived from the data rather than hard-coded, so group ids and
        # boundaries always agree with what was actually loaded.
        n_images = grouped_x.shape[1]
        self.groups = list(range(self.n_groups))
        self.image_shape = (3, 32, 32)
        self._X = grouped_x.reshape((-1, 32, 32, 3))
        self.num_classes = 10

        if split == 'test':
            # labels.npy is assumed to repeat one label order for every chunk,
            # so the first chunk's labels are tiled across all groups.
            labels = np.load(self.root_dir + 'labels.npy')[:n_images]
            self._y = np.tile(labels, self.n_groups)
        else:
            self._y = self._load_grouped(
                corruptions, other_idx, '_labels.npy').flatten()

        if len(self._y) != len(self._X):
            raise ValueError(
                f"Loaded {len(self._y)} labels for {len(self._X)} images "
                f"from {self.root_dir!r}")

        # Group id of every flattened example, plus each group's inclusive
        # (first, last) index range into self._X.
        self.group_ids = np.repeat(np.arange(self.n_groups), n_images)
        self.group_ids_boundaries = [
            (g * n_images, (g + 1) * n_images - 1) for g in range(self.n_groups)]

        self._len = len(self.group_ids)
        print("Loaded")

        self.group_counts = np.bincount(self.group_ids, minlength=self.n_groups)
        self.transform = get_transform(split)

        # Optional seeded generator for support sampling; None -> global random.
        self.rng = None

        print("Split: ", split)
        print("n groups: ", self.n_groups)
        print("Dataset size: ", len(self._y))
        print("Smallest group: ", np.min(self.group_counts))
        print("Largest group: ", np.max(self.group_counts))

    def _load_grouped(self, corruptions, other_idx, suffix):
        """Load and stack the per-group arrays stored as ``<name><suffix>``.

        The selected chunks of ``_SHARED_CORRUPTIONS`` come first, followed by
        every chunk of each name in ``corruptions``; all are concatenated on
        axis 0, so axis 0 of the result indexes groups.
        """
        shared = np.concatenate(
            [load_corruption(self.root_dir + name + suffix)
             for name in self._SHARED_CORRUPTIONS], axis=0)[other_idx]
        own = np.concatenate(
            [load_corruption(self.root_dir + name + suffix)
             for name in corruptions], axis=0)
        return np.concatenate([shared, own], axis=0)

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        """Return ``(x_s, x_q), (y_s, y_q), (g_s, g_q)`` for query ``index``.

        The support example is drawn uniformly from the query's group. If
        ``self.rng`` is set, it is used so that validation/test tasks are
        reproducible. NOTE(review): ``randint`` is called with inclusive
        bounds (``random.Random`` semantics); a NumPy ``RandomState`` would
        never pick the last index of a group — confirm the type callers set.
        """
        x_q = self.transform(image=self._X[index])['image']
        y_q = torch.tensor(self._y[index], dtype=torch.long)
        g_q = torch.tensor(self.group_ids[index], dtype=torch.long)

        low, high = self.group_ids_boundaries[g_q.item()]
        sampler = self.rng if self.rng is not None else random
        s_idx = sampler.randint(low, high)

        x_s = self.transform(image=self._X[s_idx])['image']
        y_s = torch.tensor(self._y[s_idx], dtype=torch.long)
        g_s = torch.tensor(self.group_ids[s_idx], dtype=torch.long)

        assert g_s == g_q, (g_s, g_q)

        return (x_s, x_q), (y_s, y_q), (g_s, g_q)


def get_transform(split):
    """Build the albumentations pipeline used for ``split``.

    Every split normalises with the given mean/std (``max_pixel_value=255``)
    and converts to a tensor. ``'train'`` additionally gets random
    crop/flip/rotate augmentation in front of those steps.

    Args:
        split: Dataset split name; only ``'train'`` enables augmentation.

    Returns:
        An ``albumentations.Compose`` pipeline.
    """
    # Tail shared by all splits: normalisation followed by tensor conversion.
    pipeline = [
        albumentations.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225], max_pixel_value=255,
                                 p=1.0, always_apply=True),
        ToTensor(),
    ]
    if split == 'train':
        # Augment only while training; evaluation sees unmodified images.
        augmentations = [
            albumentations.RandomResizedCrop(32, 32, p=0.5),
            albumentations.HorizontalFlip(p=0.5),
            albumentations.Rotate(30, p=0.5),
        ]
        pipeline = augmentations + pipeline
    return albumentations.Compose(pipeline)
