import os
from copy import deepcopy

import numpy as np
import skimage.io
import torch
import torchvision
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from src.roar.roar_utils import remove
from src.utils.sysutils import get_cores_count


class CompoundCifarDataset(torch.utils.data.Dataset):
    """
    Image dataset (CIFAR10 or Food101) paired with precomputed per-sample attribution maps.

    Attribution maps are loaded with ``read_attribution_map`` from one PNG file per sample,
    named after the sample's index in the underlying torchvision dataset. Each item is the
    image after ``remove`` has been applied with its attribution map (``keep=not roar``),
    normalized with the dataset mean/std.

    ``mode`` selects the split served by ``__getitem__``/``__len__``: ``'training'`` and
    ``'validation'`` index disjoint, per-class stratified subsets of the torchvision training
    set; any other value serves the test set.
    """

    datasets_config = {
        'CIFAR10': dict(
            mean=(0.4914, 0.4822, 0.4465),
            std=(0.2470, 0.2435, 0.2616),
            classes=['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'],
        ),

        'Food101': dict(
            mean=(0.561, 0.440, 0.312),
            std=(0.252, 0.256, 0.259),
            classes=['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad',
                     'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad',
                     'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheesecake', 'cheese_plate',
                     'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse',
                     'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame',
                     'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots',
                     'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup',
                     'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi',
                     'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger',
                     'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna',
                     'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup',
                     'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters', 'pad_thai', 'paella', 'pancakes',
                     'panna_cotta', 'peking_duck', 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib',
                     'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi',
                     'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', 'spaghetti_carbonara',
                     'spring_rolls', 'steak', 'strawberry_shortcake', 'sushi', 'tacos', 'takoyaki', 'tiramisu',
                     'tuna_tartare', 'waffles']
        )
    }

    def __init__(self, dataset_name,
                 attribution_files_train_path, attribution_files_test_path,
                 train_val_split=0.9, roar=True, percentile=0.1):
        """
        Args:
            dataset_name: key of ``datasets_config`` ('CIFAR10' or 'Food101').
            attribution_files_train_path: folder holding the attribution PNGs of the training set
                (used for both the training and the validation split).
            attribution_files_test_path: folder holding the attribution PNGs of the test set.
            train_val_split: fraction of every class of the training set assigned to the training
                split; the remainder forms the validation split.
            roar: forwarded to ``remove`` as ``keep=not roar``.
            percentile: forwarded to ``remove`` unchanged.

        Raises:
            ValueError: if ``dataset_name`` is not a key of ``datasets_config``.
        """
        if dataset_name not in self.datasets_config:
            raise ValueError(f'Invalid dataset_name {dataset_name}')

        self.dataset_name = dataset_name
        config = self.datasets_config[dataset_name]
        self.mean = config['mean']
        self.std = config['std']
        # Normalize(demean, destd) computes (x + m / s) * s == x * s + m, undoing Normalize(mean, std).
        self.demean = [-m / s for m, s in zip(self.mean, self.std)]
        self.destd = [1 / s for s in self.std]

        self.normalize_transform = torchvision.transforms.Compose([torchvision.transforms.Normalize(self.mean,
                                                                                                    self.std)])
        self.denormalize_transform = torchvision.transforms.Normalize(self.demean, self.destd)

        to_tensor = torchvision.transforms.ToTensor()
        if dataset_name == 'CIFAR10':
            root = './data/' + dataset_name
            self.training_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True,
                                                             transform=to_tensor)
            self.test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True,
                                                         transform=to_tensor)
        else:
            food101_dataset_dir = './data/food-101/'
            self.training_set = torchvision.datasets.ImageFolder(root=food101_dataset_dir + 'train/',
                                                                 transform=to_tensor)
            # Must be named ``test_set``: ``__getitem__`` and ``test_dataset_size`` read this attribute.
            self.test_set = torchvision.datasets.ImageFolder(root=food101_dataset_dir + 'test/',
                                                             transform=to_tensor)

        self.roar = roar

        self.attribution_files_train_path = attribution_files_train_path
        self.attribution_files_test_path = attribution_files_test_path

        self.percentile = percentile
        self.mode = 'training'

        # Training and validation both index into ``self.training_set`` through these lists.
        self.training_indices, self.validation_indices = self._uniform_train_val_split(split_ratio=train_val_split)

    def __getitem__(self, index, return_debug_info=False):
        """
        Return the attribution-processed, normalized sample at ``index`` of the current split.

        Args:
            index: position within the split selected by ``mode``. For the training/validation
                splits it is mapped through ``training_indices``/``validation_indices``.
            return_debug_info: also return the intermediate arrays (used by ``debug``).

        Returns:
            ``(processed_image, label)``; with ``return_debug_info`` instead
            ``(image, processed_image, label, attribution_map, mask)`` where ``image`` is the
            unprocessed image as a numpy array and ``attribution_map`` is the channel-wise
            maximum of the loaded map.
        """
        if self.mode == 'training':
            source, attribution_dir = self.training_set, self.attribution_files_train_path
            dataset_index = self.training_indices[index]
        elif self.mode == 'validation':
            source, attribution_dir = self.training_set, self.attribution_files_train_path
            dataset_index = self.validation_indices[index]
        else:
            source, attribution_dir = self.test_set, self.attribution_files_test_path
            dataset_index = index

        image, label = source[dataset_index]
        attribution_map = read_attribution_map(attribution_dir, dataset_index)

        image = np.array(image)
        # Collapse the map's channels into one by taking the per-pixel maximum.
        attribution_map = np.max(attribution_map, axis=0, keepdims=True)
        output = remove(image, attribution_map, self.mean, self.percentile, keep=not self.roar,
                        gray=True,
                        return_mask=return_debug_info)

        if return_debug_info:
            processed_image, mask = output
        else:
            processed_image = output
        processed_image = self.normalize_transform(torch.from_numpy(processed_image))

        if return_debug_info:
            return image, processed_image, label, attribution_map, mask
        return processed_image, label

    def __len__(self):
        """Number of samples in the split selected by ``mode``."""
        if self.mode == 'training':
            return self.train_dataset_size
        if self.mode == 'validation':
            return self.val_dataset_size
        return self.test_dataset_size

    def get_train_dataloader(self, data_args) -> DataLoader:
        """DataLoader over the training split; ``data_args`` must provide 'batch_size' and 'shuffle'."""
        return self._make_dataloader('training', data_args)

    def get_validation_dataloader(self, data_args) -> DataLoader:
        """DataLoader over the validation split; ``data_args`` must provide 'batch_size' and 'shuffle'."""
        return self._make_dataloader('validation', data_args)

    def get_test_dataloader(self, data_args) -> DataLoader:
        """DataLoader over the test split; ``data_args`` must provide 'batch_size' and 'shuffle'."""
        return self._make_dataloader('test', data_args)

    def _make_dataloader(self, mode, data_args) -> DataLoader:
        """
        Switch ``self`` to ``mode`` and wrap a deep copy of it in a DataLoader.

        The copy freezes ``mode`` for the returned loader, so later mode switches on ``self``
        (e.g. building another loader) do not affect it. ``self`` is left in ``mode``.
        """
        self.mode = mode
        return torch.utils.data.DataLoader(deepcopy(self),
                                           batch_size=data_args['batch_size'],
                                           shuffle=data_args['shuffle'],
                                           num_workers=get_cores_count())

    @property
    def classes(self):
        return self.datasets_config[self.dataset_name]['classes']

    @property
    def train_dataset_size(self):
        return len(self.training_indices)

    @property
    def val_dataset_size(self):
        return len(self.validation_indices)

    @property
    def test_dataset_size(self):
        return len(self.test_set)

    def _uniform_train_val_split(self, split_ratio):
        """
        Split the training set per class into training and validation indices.

        For every class, the first ``int(split_ratio * n_class)`` samples (in dataset order) go
        to the training split and the rest to the validation split, preserving class balance.

        Args:
            split_ratio: fraction of each class assigned to the training split.

        Returns:
            ``(training_indices, validation_indices)``: disjoint lists of Python ints indexing
            ``self.training_set``, grouped by class.
        """
        targets = self.training_set.targets
        # torchvision datasets expose targets as a list or a tensor; accept any array-like.
        if isinstance(targets, torch.Tensor):
            labels = targets.cpu().numpy()
        else:
            labels = np.asarray(targets)

        training_indices = []
        validation_indices = []
        for class_id in range(len(self.classes)):
            # flatnonzero always yields a 1-D array; squeezing argwhere's (k, 1) output would
            # collapse a single-sample slice to a scalar and break ``list.extend``.
            label_indices = np.flatnonzero(labels == class_id)
            samples_per_label = int(split_ratio * len(label_indices))
            training_indices.extend(label_indices[:samples_per_label].tolist())
            validation_indices.extend(label_indices[samples_per_label:].tolist())

        assert not set(training_indices) & set(validation_indices)
        return training_indices, validation_indices

    def debug(self, outdir, name, train, indices):
        """
        Plot input image, attribution map, removal mask and processed image for ``indices``.

        Args:
            outdir: if truthy, the figure is also saved as ``<outdir>/<name>.png``.
            name: figure title and file stem; must not be None.
            train: read samples from the training split if True, otherwise from the test split.
            indices: split positions to plot, one figure row each.

        ``self.mode`` is restored to its previous value afterwards.
        """
        assert name is not None, "Name is used as title as well as to save image if outdir provided"
        fig = plt.figure()
        fig.suptitle(f'{name}')
        plot_rows = len(indices) + 1  # One extra row for the column titles
        plot_columns = 4

        titles = ['Input Image', 'Saliency Map', 'Thresholded-Saliency', 'Output Image']
        for location, title in enumerate(titles, start=1):
            header = fig.add_subplot(plot_rows, plot_columns, location)
            header.axis('off')
            header.set_title(title, fontsize=12, fontweight='bold')

        display_modes = ['default', 'default', 'float', 'default']
        previous_mode = self.mode
        self.mode = 'training' if train else 'test'
        try:
            for row, index in enumerate(indices, start=1):
                image, processed_image, _, attribution_map, mask = self.__getitem__(index, return_debug_info=True)

                # All panels are converted from channel-first to channel-last for imshow.
                panels = [np.transpose(image, (1, 2, 0)),
                          np.transpose(attribution_map, (1, 2, 0)),
                          np.transpose(mask, (1, 2, 0)) if len(mask.shape) == 3 else mask,
                          np.transpose(self.denormalize_transform(processed_image).numpy(), (1, 2, 0))]

                for column, (panel, display_mode) in enumerate(zip(panels, display_modes), start=1):
                    ax = fig.add_subplot(plot_rows, plot_columns, row * plot_columns + column)
                    ax.axis('off')
                    self._show_panel(ax, np.squeeze(panel), display_mode)
        finally:
            self.mode = previous_mode

        fig.tight_layout()
        # Save before showing: once a blocking ``show()`` window is closed, the figure may be
        # torn down and ``savefig`` would write an empty image.
        if outdir:
            os.makedirs(outdir, exist_ok=True)
            fig.savefig(f'{outdir}/{name}.png')
        plt.show()

    @staticmethod
    def _show_panel(ax, panel, display_mode):
        """Draw one ``debug`` panel on ``ax`` using the given display mode."""
        if display_mode == 'default':
            ax.imshow(panel)
        elif display_mode == 'float':
            # Min-max rescale to 0..255; a constant panel stays all zeros instead of becoming NaN.
            panel = np.float64(panel)
            panel = panel - panel.min()
            peak = panel.max()
            if peak > 0:
                panel /= peak
            ax.imshow(np.uint8(panel * 255))
        elif display_mode == 'RedBlueHeatmap':
            # Positive values in red, magnitudes of negative values overlaid in blue.
            panel = np.float64(panel)
            ax.imshow(np.clip(panel, 0.0, None), cmap='Reds')
            ax.imshow(np.abs(np.clip(panel, None, 0.0)), cmap='Blues', alpha=0.5)


def read_attribution_map(attribution_folder_path, index):
    """
    Load the attribution map saved for sample ``index`` as a channel-first array.

    The file read is ``<attribution_folder_path>/<index zero-padded to 5 characters>.png``.
    Single-channel images, which ``skimage.io.imread`` returns as 2-D arrays, get a channel
    axis, so the result is always ``(C, H, W)``.

    Raises:
        Whatever reading or reshaping raises (e.g. a missing file). The folder and index are
        printed first so the failing sample is easy to locate.
    """
    try:
        attribution_map = skimage.io.imread(f'{attribution_folder_path}/'
                                            f'{str(index).zfill(5)}.png')
        # (H, W) -> (H, W, 1); (H, W, C) is left unchanged.
        return np.atleast_3d(attribution_map).transpose(2, 0, 1)
    except Exception:
        print(attribution_folder_path, index)
        raise