import os
from copy import deepcopy

import numpy as np
import skimage.io
import torch
import torchvision
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from src.dataset.Birdsnap import get_birdsnap_object
from src.roar.roar_utils import remove
from src.utils.sysutils import get_cores_count


class CompoundImageFolderDataset(torch.utils.data.Dataset):
    """
    Pair image folders with their attribution-map folders and remove pixels by attribution.

    Images and attribution maps are read with ``torchvision.datasets.ImageFolder`` from
    parallel directory trees, one tree per split (train / validation / test). Sample ``i``
    of an image folder is paired with sample ``i`` of the matching attribution folder, so
    both trees are assumed to enumerate their files in the same order.

    For every sample, ``remove`` (from ``src.roar.roar_utils``) modifies pixels selected by
    the attribution map: it is called with ``keep=False`` when ``roar=True`` and with
    ``keep=True`` otherwise.

    The active split is selected by ``self.mode``: 'training', 'validation', and any other
    value is treated as the test split.
    """

    # Per-dataset normalization statistics and class names.
    # NOTE: ``get_birdsnap_object()`` is evaluated once, when this class is defined (import time).
    datasets_config = {
        'Birdsnap': dict(
            mean=(0.491, 0.506, 0.451),
            std=(0.229, 0.226, 0.267),
            classes=get_birdsnap_object().classes,
        ),
    }

    def __init__(self, dataset_name,
                 image_files_train_path,
                 image_files_validation_path,
                 image_files_test_path,
                 attribution_files_train_path,
                 attribution_files_validation_path,
                 attribution_files_test_path,
                 roar=True,
                 percentile=0.1):
        """
        Args:
            dataset_name: key into ``datasets_config`` (currently only 'Birdsnap').
            image_files_train_path / _validation_path / _test_path: ImageFolder roots
                holding the images of each split.
            attribution_files_train_path / _validation_path / _test_path: ImageFolder
                roots holding the attribution maps of each split.
            roar: if True, ``remove`` is called with ``keep=False``; otherwise ``keep=True``.
            percentile: forwarded to ``remove`` as its percentile argument.

        Raises:
            ValueError: if ``dataset_name`` is not a key of ``datasets_config``.
        """
        if dataset_name not in self.datasets_config:
            raise ValueError(f'Invalid dataset_name {dataset_name}')

        self.dataset_name = dataset_name
        self.attribution_files_train_path = attribution_files_train_path
        self.attribution_files_validation_path = attribution_files_validation_path
        self.attribution_files_test_path = attribution_files_test_path
        self.percentile = percentile
        self.roar = roar

        self.mean = self.datasets_config[dataset_name]['mean']
        self.std = self.datasets_config[dataset_name]['std']
        # Inverse of Normalize(mean, std): x * std + mean == (x - (-mean/std)) / (1/std).
        self.demean = [-m / s for m, s in zip(self.mean, self.std)]
        self.destd = [1 / s for s in self.std]

        # Training images are only converted to tensors here; normalization is applied
        # later by ``train_augmentation_transform``, after pixels have been removed.
        self.train_normalize_transform = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor()])
        self.evaluation_normalize_transform = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize(self.mean, self.std)])
        # Used for visualization of preprocessed images.
        self.denormalize_transform = torchvision.transforms.Normalize(self.demean, self.destd)

        # Built once here rather than on every __getitem__ call.
        self.train_augmentation_transform = torchvision.transforms.Compose([
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
            # Redundant after RandomResizedCrop(224) (image is already 224x224); kept as-is.
            torchvision.transforms.RandomCrop(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(self.mean, self.std)
        ])

        self.training_images_dataset = torchvision.datasets.ImageFolder(
            root=image_files_train_path, transform=self.train_normalize_transform)
        self.validation_images_dataset = torchvision.datasets.ImageFolder(
            root=image_files_validation_path, transform=self.evaluation_normalize_transform)
        self.test_images_dataset = torchvision.datasets.ImageFolder(
            root=image_files_test_path, transform=self.evaluation_normalize_transform)

        self.training_attribution_map_dataset = torchvision.datasets.ImageFolder(
            root=attribution_files_train_path, transform=torchvision.transforms.ToTensor())
        self.validation_attribution_map_dataset = torchvision.datasets.ImageFolder(
            root=attribution_files_validation_path, transform=torchvision.transforms.ToTensor())
        self.test_attribution_map_dataset = torchvision.datasets.ImageFolder(
            root=attribution_files_test_path, transform=torchvision.transforms.ToTensor())

        self.mode = 'training'

    def __getitem__(self, index, return_debug_info=False):
        """
        Return sample ``index`` of the active split with attribution-selected pixels replaced.

        Args:
            index: sample index within the split selected by ``self.mode``.
            return_debug_info: if True, also return the input image, the channel-max
                attribution map and the mask returned by ``remove``.

        Returns:
            ``(processed_image, label)``, or
            ``(image, processed_image, label, attribution_map, mask)`` when
            ``return_debug_info`` is True. ``label`` comes from the attribution-map folder.
        """
        if self.mode == 'training':
            image, _ = self.training_images_dataset[index]
            attribution_map, label = self.training_attribution_map_dataset[index]
            # Training images are not normalized yet, so fill with the raw channel mean.
            fill_value = self.mean
        elif self.mode == 'validation':
            image, _ = self.validation_images_dataset[index]
            attribution_map, label = self.validation_attribution_map_dataset[index]
            # Evaluation images are already normalized: 0 corresponds to the channel mean.
            fill_value = [0, 0, 0]
        else:
            image, _ = self.test_images_dataset[index]
            attribution_map, label = self.test_attribution_map_dataset[index]
            fill_value = [0, 0, 0]

        image = np.array(image)
        # Collapse the attribution map to one channel (max over channels), keeping the axis.
        attribution_map = np.max(attribution_map.numpy(), axis=0, keepdims=True)
        output = remove(image, attribution_map, fill_value, self.percentile, keep=not self.roar,
                        gray=True,
                        return_mask=return_debug_info)

        if return_debug_info:
            processed_image, mask = output
        else:
            processed_image = output

        if self.mode == 'training':
            # Convert CHW floats to HWC uint8 so the PIL-based augmentations can run.
            # NOTE(review): assumes ``remove`` returns a CHW numpy array in [0, 1]; the clip
            # guards against uint8 wrap-around for any out-of-range value.
            hwc_image = np.clip(processed_image.transpose(1, 2, 0), 0.0, 1.0)
            processed_image = self.train_augmentation_transform(
                Image.fromarray((hwc_image * 255).astype(np.uint8)))

        if return_debug_info:
            return image, processed_image, label, attribution_map, mask
        return processed_image, label

    def __len__(self):
        """Return the number of samples in the split selected by ``self.mode``."""
        if self.mode == 'training':
            return self.train_dataset_size
        if self.mode == 'validation':
            return self.val_dataset_size
        return self.test_dataset_size

    def _make_dataloader(self, mode, data_args) -> DataLoader:
        """
        Build a DataLoader over a shallow copy of this dataset pinned to ``mode``.

        ``self.mode`` is still switched to ``mode`` (as callers may rely on it), but the
        loader gets its own copy so later mode switches (requesting another split's loader,
        or calling ``debug``) do not change what an existing loader iterates over. The copy
        shares the underlying ImageFolder datasets and transforms with ``self``.
        """
        from copy import copy as shallow_copy

        self.mode = mode
        pinned_dataset = shallow_copy(self)
        return DataLoader(pinned_dataset,
                          batch_size=data_args['batch_size'],
                          shuffle=data_args['shuffle'],
                          num_workers=get_cores_count())

    def get_train_dataloader(self, data_args) -> DataLoader:
        """Return a DataLoader over the training split (reads 'batch_size', 'shuffle')."""
        return self._make_dataloader('training', data_args)

    def get_validation_dataloader(self, data_args) -> DataLoader:
        """Return a DataLoader over the validation split (reads 'batch_size', 'shuffle')."""
        return self._make_dataloader('validation', data_args)

    def get_test_dataloader(self, data_args) -> DataLoader:
        """Return a DataLoader over the test split (reads 'batch_size', 'shuffle')."""
        return self._make_dataloader('test', data_args)

    @property
    def classes(self):
        """Class names configured for ``dataset_name``."""
        return self.datasets_config[self.dataset_name]['classes']

    @property
    def train_dataset_size(self):
        """Number of training images."""
        return len(self.training_images_dataset)

    @property
    def val_dataset_size(self):
        """Number of validation images."""
        return len(self.validation_images_dataset)

    @property
    def test_dataset_size(self):
        """Number of test images."""
        return len(self.test_images_dataset)

    @staticmethod
    def _show_panel(axes, image, display_mode):
        """Render one debug panel on ``axes`` using ``display_mode``."""
        if display_mode == 'default':
            axes.imshow(image)
        elif display_mode == 'float':
            # Min-max rescale to [0, 255]; skip the division for a constant panel (max == 0).
            image = np.float64(image)
            image = image - image.min()
            peak = image.max()
            if peak > 0:
                image /= peak
            axes.imshow(np.uint8(image * 255))
        elif display_mode == 'RedBlueHeatmap':
            image = np.float64(image)
            positive = np.copy(image)
            positive[positive < 0.] = 0.0
            negative = np.copy(image)
            negative[negative > 0.] = 0.0
            axes.imshow(positive, cmap='Reds')
            axes.imshow(np.abs(negative), cmap='Blues', alpha=0.5)

    def debug(self, outdir, name, train, indices):
        """
        Plot input image, attribution map, removal mask and processed image for ``indices``.

        Args:
            outdir: directory in which ``{name}.png`` is saved (created if missing);
                falsy to skip saving.
            name: figure title and output file stem.
            train: visualize the training split if True, otherwise the test split.
            indices: sample indices to visualize, one figure row each.

        ``self.mode`` is restored to its previous value before returning.
        """
        assert name is not None, "Name is used as title as well as to save image if outdir provided"
        fig = plt.figure()
        fig.suptitle(f'{name}')
        # One row per sample plus a header row that carries the column titles.
        plot_rows = len(indices) + 1
        plot_columns = 4

        titles = ['Input Image', 'Saliency Map', 'Thresholded-Saliency', 'Output Image']
        for column, title in enumerate(titles, start=1):
            header = fig.add_subplot(plot_rows, plot_columns, column)
            header.axis('off')
            header.set_title(title, fontsize=12, fontweight='bold')

        previous_mode = self.mode
        self.mode = 'training' if train else 'test'
        try:
            for sample, index in enumerate(indices):
                image, processed_image, _, attribution_map, mask = self.__getitem__(
                    index, return_debug_info=True)

                image = torch.from_numpy(image)
                if self.mode != 'training':
                    # Evaluation images were normalized on load; training inputs were not.
                    image = self.denormalize_transform(image)
                if not isinstance(processed_image, torch.Tensor):
                    processed_image = torch.from_numpy(processed_image)

                panels = [
                    (np.transpose(image.numpy(), (1, 2, 0)), 'default'),
                    (np.transpose(attribution_map, (1, 2, 0)), 'default'),
                    (np.transpose(mask, (1, 2, 0)) if len(mask.shape) == 3 else mask, 'float'),
                    (np.transpose(self.denormalize_transform(processed_image).numpy(), (1, 2, 0)),
                     'default'),
                ]
                first_location = (sample + 1) * plot_columns + 1
                for location, (panel, display_mode) in enumerate(panels, start=first_location):
                    axes = fig.add_subplot(plot_rows, plot_columns, location)
                    axes.axis('off')
                    self._show_panel(axes, np.squeeze(panel), display_mode)
        finally:
            self.mode = previous_mode

        fig.tight_layout()
        # Save first so the file is written without waiting for show() to return.
        if outdir:
            os.makedirs(outdir, exist_ok=True)
            fig.savefig(f'{outdir}/{name}.png')
        plt.show()
