import argparse
import os
from typing import List

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import yaml
from typing import Optional

def load_images(
        img_path,
        attr_map_path,
        class_idx=None,
        img_idx=None,
        seed=None,
        config_attribution_methods=None,
        use_array: bool = True,
        load_gt_mask: bool = True):
    """
    Load one image, optionally its ground-truth mask, and its attribution maps.

    Parameters
    ---------
    img_path: Dataset root. Images are read from ``<img_path>/images/<class>/``
              and, when ``load_gt_mask`` is set, masks from
              ``<img_path>/gt_masks/<class>/`` under the same file name.

    attr_map_path: Attribution root. Maps are read from
                   ``<attr_map_path>/<method>/<class>/``.

    class_idx: Class folder to use. When None a class is drawn at random.

    img_idx: Position of the file inside the (sorted) class folder. When None
             a file is drawn at random.

    seed: Value handed to ``np.random.seed`` before any random draw.

    config_attribution_methods: Names of the attribution-method folders to
                                load. At least one is required because the
                                first one is used to enumerate the classes.

    use_array: If True the maps are loaded from ``.npy`` files and the
               ground-truth mask is re-labelled (see below); otherwise the
               maps are read as images with ``plt.imread``.

    load_gt_mask: Whether the ground-truth mask is loaded and listed.

    Returns
    ---------
    (img_list, attribution_methods): ``img_list`` holds the image, then the
    ground-truth mask (if requested), then one map per attribution method.
    ``attribution_methods`` holds the matching names, starting with 'image'
    (and 'ground_truth' if requested).

    Raises
    ---------
    ValueError: if no attribution method is given.
    """
    # (re)seed the global numpy RNG so that the draws below are repeatable
    np.random.seed(seed)

    # a None default avoids sharing one mutable list between calls
    config_attribution_methods = (
        [] if config_attribution_methods is None else list(config_attribution_methods))
    if not config_attribution_methods:
        raise ValueError('config_attribution_methods must contain at least one attribution method')

    leading_entries = ['image', 'ground_truth'] if load_gt_mask else ['image']
    attribution_methods = leading_entries + config_attribution_methods

    # the class folders are enumerated from the first attribution method;
    # os.listdir returns entries in arbitrary order, so sort them to make
    # seeded / index-based selection reproducible
    cls_names = sorted(os.listdir(os.path.join(attr_map_path, config_attribution_methods[0])))
    if class_idx is not None:
        assert class_idx < len(cls_names), 'class index exceed the number of classes'
    elif load_gt_mask:
        # with ground-truth masks the class folder name is the drawn index itself
        class_idx = str(np.random.randint(0, len(cls_names)))
    else:
        class_idx = cls_names[np.random.randint(0, len(cls_names))]
    class_dir = str(class_idx)

    # random select one file if img_idx is not given, else select the img_idx-th file
    file_names = sorted(os.listdir(os.path.join(img_path, 'images', class_dir)))
    if img_idx is None:
        img_idx = np.random.randint(0, len(file_names))
    file_name = file_names[img_idx]

    print(f'load image {file_name} from class {class_idx} under path {attr_map_path}')

    # load the image (and the ground-truth mask) as np.array
    img_list = []
    dataset_folder = ['images', 'gt_masks'] if load_gt_mask else ['images']
    for folder_name in dataset_folder:
        img = plt.imread(os.path.join(img_path, folder_name, class_dir, file_name))
        if folder_name == 'gt_masks' and use_array:
            # locate negative ground truth: values strictly between 0.1 and 1 become -1
            img[(0.1 < img) & (img < 1)] = -1
        img_list.append(img)

    # array maps are stored under the image's base name with a .npy extension,
    # image maps under the image's own file name
    map_file_name = file_name.split('.')[0] + '.npy' if use_array else file_name
    for attribution_method in config_attribution_methods:
        map_path = os.path.join(attr_map_path, attribution_method, class_dir, map_file_name)
        img_list.append(np.load(map_path) if use_array else plt.imread(map_path))
    return img_list, attribution_methods


def show_images(images, cols=1, titles=None, save_path='test.png', use_array: bool = True, figure_size = None):
    """
    display a grid of images in a single figure with matplotlib and save it.

    Parameters
    ---------
    images: List of samples, one per figure row. Each sample is a List of
            np.arrays compatible with plt.imshow; the number of images per
            sample is taken from the first sample.

    cols (Default = 1): Divisor for the number of grid columns, which is
                        set to np.ceil(n_images/float(cols)).
                        NOTE: the subplot index counts n_images per row, so
                        the grid is only large enough when this equals
                        n_images (e.g. cols = 1).

    titles: List of titles corresponding to each image of a sample. Must have
            the same length as one sample of images. Titles are only drawn
            above the first row. Defaults to 'Image (1)', 'Image (2)', ...

    save_path: File the figure is written to with plt.savefig.

    use_array: If False every image is drawn with the gray colormap. If True
               the image titled 'Image' is drawn as is and every other image
               is passed through unbiasd_normalize and drawn with the
               'seismic' colormap.

    figure_size: Figure size in inches. When not given, the default figure
                 size is multiplied by the number of images per sample.
    """
    assert ((titles is None) or (len(images[0]) == len(titles)))
    n_samples = len(images)
    n_images = len(images[0])
    if titles is None:
        titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
    n_cols = int(np.ceil(n_images / float(cols)))

    # diverging colormap so that positive and negative values are highlighted
    # differently; plt.get_cmap is used because matplotlib.cm.get_cmap was
    # deprecated in Matplotlib 3.7 and removed in 3.9
    cmap = plt.get_cmap('seismic')

    fig = plt.figure()
    for sample_idx, images_per_sample in enumerate(images):
        for n, (image, title) in enumerate(zip(images_per_sample, titles)):
            a = fig.add_subplot(n_samples, n_cols, sample_idx * n_images + n + 1)
            if not use_array:
                plt.gray()
                plt.imshow(image, aspect='auto')
            else:
                aspect_ratio = image.shape[1] / float(image.shape[0])
                if title == 'Image':
                    plt.imshow(image, aspect=aspect_ratio)
                else:
                    # unbiasd_normalize output is drawn on a fixed [0, 1] colour scale
                    plt.imshow(unbiasd_normalize(image), cmap=cmap, vmin=0, vmax=1, aspect=aspect_ratio)

            a.axes.get_yaxis().set_visible(False)
            a.axes.get_xaxis().set_visible(False)
            # only the first row carries the column titles
            if sample_idx == 0:
                a.set_title(title, fontsize=40)
    plt.subplots_adjust(wspace=0.1)
    if figure_size:
        fig.set_size_inches(figure_size)
    else:
        fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
    plt.savefig(save_path, bbox_inches='tight')


def unbiasd_normalize(image: np.ndarray):
    """
    Map an array to [0, 1] for display while keeping zero at 0.5.

    The non-negative values are divided by the largest non-negative value and
    the negative values by the magnitude of the most negative value, so the
    two sides are scaled independently into [0, 1] and [-1, 0). The result is
    then shifted and halved, which puts -1 at 0, 0 at 0.5 and 1 at 1.

    The input is never modified; a new floating point array is returned.
    A side whose peak is zero (or that has no values at all) is left
    unscaled instead of producing NaN or raising.

    Parameters
    ---------
    image: Array of values to normalize.

    Returns
    ---------
    np.ndarray with the same shape as ``image`` and values in [0, 1].
    """
    # work on a floating point copy so that the caller's array stays untouched
    normalized = np.array(image)
    if not np.issubdtype(normalized.dtype, np.floating):
        normalized = normalized.astype(np.float64)

    non_negative = normalized >= 0
    negative = ~non_negative

    if non_negative.any():
        positive_peak = normalized[non_negative].max()
        # an all-zero positive side needs no scaling (and would divide by zero)
        if positive_peak > 0:
            normalized[non_negative] = normalized[non_negative] / positive_peak

    if negative.any():
        # strictly negative by construction, so the magnitude is never zero
        negative_peak = np.abs(normalized[negative].min())
        normalized[negative] = normalized[negative] / negative_peak

    # shift [-1, 1] to [0, 1]; zero lands on 0.5
    return (normalized + 1) / 2


def plot_attribution_map_images(
        img_path: str,
        attr_map_path: str,
        config_attribution_methods: List[str],
        file_to_title: dict,
        save_path: str,
        use_array: bool,
        samples: int = 1,
        seed: Optional[int] = None):
    """
    Plot one row per sample: the image, its ground-truth mask and its attribution maps.

    Parameters
    ---------
    img_path: Dataset root passed to load_images.

    attr_map_path: Attribution-map root passed to load_images.

    config_attribution_methods: Names of the attribution methods to load.

    file_to_title: Maps every name returned by load_images ('image',
                   'ground_truth' and each attribution method) to the title
                   shown above its column.

    save_path: File the figure is saved to.

    use_array: Forwarded to load_images and show_images.

    samples: Number of rows. With a single sample the first file of the
             class folder is used, otherwise files are drawn at random.

    seed: Base random seed. Sample ``i`` is loaded with ``seed + i``.

    Raises
    ---------
    ValueError: if ``samples`` is smaller than 1.
    """
    if samples < 1:
        raise ValueError('samples must be at least 1')

    img_index = 0 if samples == 1 else None
    total_images = []
    attribution_methods = []
    for sample_idx in range(samples):
        # load_images re-seeds the RNG on every call; reusing the same seed
        # would draw the very same image for each row, so offset it per sample
        # (the first sample keeps the seed that was passed in)
        sample_seed = None if seed is None else seed + sample_idx
        images, attribution_methods = load_images(img_path, attr_map_path, img_idx=img_index,
                                                  config_attribution_methods=config_attribution_methods,
                                                  use_array=use_array, seed=sample_seed)
        total_images.append(images)

    attribution_method_titles = [file_to_title[attribution_method]
                                 for attribution_method in attribution_methods]
    show_images(total_images, titles=attribution_method_titles, save_path=save_path, use_array=use_array)


def plot_attribution_map_images_imagenet(
        img_path: str,
        attr_map_path: str,
        img_path_imagenet: str,
        attr_map_path_imagenet: str,
        config_attribution_methods: List[str],
        file_to_title: dict,
        save_path: str,
        use_array: bool,
        figure_size: tuple = (60, 20),
        seed: Optional[int] = None):
    """
    Plot two rows, first row is synthetic image result, second row is imagenet result.

    Both rows are loaded without ground-truth masks and with a randomly drawn
    class and image.

    Parameters
    ---------
    img_path: Dataset root of the synthetic images.

    attr_map_path: Attribution-map root of the synthetic images.

    img_path_imagenet: Dataset root of the imagenet images.

    attr_map_path_imagenet: Attribution-map root of the imagenet images.

    config_attribution_methods: Names of the attribution methods to load.

    file_to_title: Maps 'image' and each attribution method name to the
                   title shown above its column.

    save_path: File the figure is saved to.

    use_array: Forwarded to load_images and show_images.

    figure_size (Default = (60, 20)): Figure size in inches, forwarded to
                                      show_images.

    seed: Random seed used for loading each of the two rows.
    """
    # load synthetic result; the method names are the same for both rows,
    # so the list returned here is also used for the titles
    synthetic_images, attribution_methods = load_images(
        img_path, attr_map_path, img_idx=None,
        config_attribution_methods=config_attribution_methods,
        use_array=use_array, seed=seed, load_gt_mask=False)

    # load imagenet result
    imagenet_images, _ = load_images(
        img_path_imagenet, attr_map_path_imagenet, img_idx=None,
        config_attribution_methods=config_attribution_methods,
        use_array=use_array, seed=seed, load_gt_mask=False)

    total_images = [synthetic_images, imagenet_images]
    attribution_method_titles = [file_to_title[attribution_method]
                                 for attribution_method in attribution_methods]
    show_images(total_images, titles=attribution_method_titles, save_path=save_path,
                use_array=use_array, figure_size=figure_size)


def main():
    """
    Command-line entry point: read a YAML config and produce the requested figure.

    The config file is selected with ``--config``. If its ``imagenet`` entry
    is true the two-row synthetic/imagenet figure is plotted, otherwise the
    per-sample figure. The entries ``img_path``, ``attr_map_path``,
    ``attribution_methods``, ``file_to_title``, ``save_path`` and
    ``use_array`` are required (plus ``img_path_imagenet`` and
    ``attr_map_path_imagenet`` for the imagenet figure); ``seed``,
    ``samples`` (default 1) and ``figure_size`` (default (60, 20)) may be
    omitted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default="configs/plot_two_row.yml")
    args = parser.parse_args()

    # load config
    with open(args.config, 'r') as f:
        # NOTE(review): FullLoader is kept so existing configs keep loading unchanged;
        # yaml.safe_load would be the safer choice if configs may come from untrusted sources
        config = yaml.load(f, Loader=yaml.FullLoader)

    # optional entries fall back to defaults instead of raising KeyError
    seed = config.get('seed')

    if config.get('imagenet') == True:
        plot_attribution_map_images_imagenet(
            config['img_path'],
            config['attr_map_path'],
            config['img_path_imagenet'],
            config['attr_map_path_imagenet'],
            config['attribution_methods'],
            config['file_to_title'],
            config['save_path'],
            config['use_array'],
            config.get('figure_size', (60, 20)),
            seed,
        )
    else:
        plot_attribution_map_images(
            config['img_path'],
            config['attr_map_path'],
            config['attribution_methods'],
            config['file_to_title'],
            config['save_path'],
            config['use_array'],
            config.get('samples', 1),
            seed,
        )


# run the command-line entry point only when this file is executed as a script
if __name__ == '__main__':
    main()
