import json
import os
import random
import re
from collections import defaultdict

import albumentations
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from albumentations.pytorch import ToTensor
from PIL import Image
from torch.utils.data import Dataset


class ImageNetDataset(Dataset):
    """Tiny-ImageNet-C dataset that yields same-group (support, query) pairs.

    Every ``(corruption, level)`` directory loaded for a split forms one
    group (domain). ``__getitem__`` returns the indexed query example
    together with a support example drawn at random from the same group.

    Attributes set by ``__init__``:
        root_dir: directory holding the corruption sub-directories.
        num_classes: maximum number of class directories used per group.
        n_groups / groups: number of groups and the list of group ids.
        image_shape: fixed to ``(3, 64, 64)``.
        group_ids: integer array with the group id of every example.
        group_ids_boundaries: per group, the inclusive ``(first, last)``
            index range of its examples.
        group_counts: number of examples per group.
        transform: pipeline returned by ``get_transform(split)``.
        rng: ``None`` by default. When set to an object exposing
            ``randint(low, high)``, it replaces the global ``random``
            module for support sampling.
    """

    # Per split: sub-directory below ``root_dir``, corruptions loaded at all
    # five levels, and the individual frost / jpeg_compression levels.
    _SPLITS = {
        'train': (
            'Tiny-ImageNet-C-new/train/',
            ['gaussian_noise', 'shot_noise', 'defocus_blur', 'glass_blur',
             'zoom_blur', 'snow', 'brightness', 'contrast', 'pixelate'],
            [1, 2, 3],
            [1, 2, 3],
        ),
        'val': (
            'Tiny-ImageNet-C-new/val',
            ['speckle_noise', 'gaussian_blur', 'saturate'],
            [4],
            [5],
        ),
        'test': (
            'Tiny-ImageNet-C/',
            ['impulse_noise', 'motion_blur', 'fog', 'elastic_transform'],
            [5],
            [4],
        ),
    }
    _ALL_LEVELS = (1, 2, 3, 4, 5)

    def __init__(self, split, root_dir):
        """Index all images of ``split`` found below ``root_dir``.

        Args:
            split: one of ``'train'``, ``'val'`` or ``'test'``.
            root_dir: directory containing the Tiny-ImageNet-C folders. It
                is joined with ``os.path.join``, so a trailing separator
                is optional.

        Raises:
            ValueError: if ``split`` is not a known split name.
        """
        if split not in self._SPLITS:
            raise ValueError(
                "Unknown split {!r}; expected one of {}".format(
                    split, sorted(self._SPLITS)))
        sub_dir, corruptions, frost_idx, jpeg_idx = self._SPLITS[split]
        self.root_dir = os.path.join(root_dir, sub_dir)
        print("Loading TinyImageNet-c")

        self.num_classes = 200

        # Group order: frost levels, jpeg_compression levels, then every
        # remaining corruption at all five levels.
        group_specs = [('frost', level) for level in frost_idx]
        group_specs += [('jpeg_compression', level) for level in jpeg_idx]
        group_specs += [(corruption, level)
                        for corruption in corruptions
                        for level in self._ALL_LEVELS]

        # Group ids and boundaries are derived from what was actually
        # loaded instead of an assumed images-per-class count, so they can
        # never drift out of sync with ``self._X``.
        data = []
        group_ids = []
        boundaries = []
        for group, (corruption, level) in enumerate(group_specs):
            imdb = self.construct_imdb(corruption, level)
            first = len(data)
            data.extend(imdb)
            group_ids.extend([group] * len(imdb))
            # Inclusive (first, last) indices of this group in ``data``.
            boundaries.append((first, len(data) - 1))

        self._X = data
        self.n_groups = len(group_specs)
        self.groups = list(range(self.n_groups))

        self.image_shape = (3, 64, 64)

        # Store information about domains (groups) of examples
        self.group_ids = np.array(group_ids, dtype=np.int_)
        self.group_ids_boundaries = boundaries

        self._len = len(self.group_ids)
        print("Loaded")

        self.group_counts, _ = np.histogram(self.group_ids,
                                            bins=range(self.n_groups + 1),
                                            density=False)

        self.transform = get_transform(split)

        self.rng = None

        print("Split: ", split)
        print("n groups: ", self.n_groups)
        print("Dataset size: ", len(self._X))

        print("Smallest group: ", np.min(self.group_counts))
        print("Largest group: ", np.max(self.group_counts))

    def construct_imdb(self, corruption, level):
        """Return the image entries of one ``(corruption, level)`` group.

        Args:
            corruption: name of the corruption sub-directory.
            level: corruption level; converted with ``str`` to form the
                directory name.

        Returns:
            A list of ``{"im_path": str, "class": int}`` dicts, ordered by
            class directory name and then by image file name. ``class`` is
            the position of the class directory in the sorted listing.
        """
        # Compile the split data path
        split_path = os.path.join(self.root_dir, corruption, str(level))
        re_pattern = r"^n[0-9]+$"
        # Images are stored per class in subdirs (format: n<number>)
        class_ids = sorted(
            name for name in os.listdir(split_path)
            if re.match(re_pattern, name))
        # select a subset of the classes
        class_ids = class_ids[:self.num_classes]
        # Construct the image db, mapping class ids to contiguous ids
        imdb = []
        for cont_id, class_id in enumerate(class_ids):
            im_dir = os.path.join(split_path, class_id)
            # os.listdir returns names in arbitrary order; sorting keeps the
            # index -> image mapping stable across runs and machines, which
            # seeded support sampling relies on.
            for im_name in sorted(os.listdir(im_dir)):
                imdb.append(
                    {"im_path": os.path.join(im_dir, im_name),
                     "class": cont_id}
                )
        return imdb

    def __len__(self):
        """Return the number of indexed examples."""
        return self._len

    def _load_example(self, index):
        """Load, transform and label the example stored at ``index``.

        Returns:
            ``(image, class_label, group_label)`` where the image is the
            output of ``self.transform`` and both labels are ``torch.long``
            scalar tensors.
        """
        entry = self._X[index]
        # The context manager releases the file handle deterministically.
        # convert('RGB') guarantees three channels, matching the
        # three-channel statistics the transform normalises with.
        with Image.open(entry["im_path"]) as image:
            pixels = np.array(image.convert('RGB'))
        x = self.transform(image=pixels)['image']
        y = torch.tensor(entry["class"], dtype=torch.long)
        g = torch.tensor(self.group_ids[index], dtype=torch.long)
        return x, y, g

    def __getitem__(self, index):
        """Return the query example at ``index`` and a same-group support.

        Returns:
            ``((x_s, x_q), (y_s, y_q), (g_s, g_q))``: images, class labels
            and group labels of the support (``_s``) and query (``_q``)
            examples.
        """
        # Index determines the query example
        x_q, y_q, g_q = self._load_example(index)

        # Support examples are generated randomly considering the domain
        # of the query example
        # We use seeded generators to ensure we use exactly the same tasks
        # during validation and testing
        first, last = self.group_ids_boundaries[g_q.item()]
        sampler = self.rng if self.rng is not None else random
        # NOTE(review): ``random.randint`` includes ``last``. If ``rng`` is
        # a numpy RandomState its ``randint`` excludes the upper bound, so
        # the final example of each group would never be drawn -- confirm
        # which generator type callers assign to ``rng``.
        s_idx = sampler.randint(first, last)

        x_s, y_s, g_s = self._load_example(s_idx)

        assert g_s == g_q, (g_s, g_q)

        return (x_s, x_q), (y_s, y_q), (g_s, g_q)


def get_transform(split):
    """Build the albumentations pipeline used for ``split``.

    The ``'train'`` split gets random augmentation (resized crop to 64x64,
    horizontal flip, rotation of up to 30 degrees, each applied with
    probability 0.5) ahead of the shared steps. Every other split gets the
    shared steps only: normalisation with the fixed per-channel mean/std
    below, followed by ``ToTensor``.

    Args:
        split: name of the dataset split.

    Returns:
        An ``albumentations.Compose`` instance.
    """
    # We use data augmentation during training but not during evaluation
    steps = []
    if split == 'train':
        steps = [
            albumentations.RandomResizedCrop(64, 64, p=0.5),
            albumentations.HorizontalFlip(p=0.5),
            albumentations.Rotate(30, p=0.5),
        ]

    # Normalisation and tensor conversion are identical for all splits.
    steps.append(
        albumentations.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225],
                                 max_pixel_value=255,
                                 p=1.0, always_apply=True))
    steps.append(ToTensor())
    return albumentations.Compose(steps)
