import argparse
import itertools
import os
import random
from collections import defaultdict
from pprint import pformat

import numpy as np
import skimage.io
import torch
import torchvision
from tqdm import tqdm

from src.dataset.factory import create_dataset, MAP_DATASET_TO_ENUM
from src.generate_attribution import generate_attribution
from src.models.utils import get_model, append_linear_layer_transform

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='CIFAR10', help='The dataset to choose',
                    choices=['CIFAR10', 'Birdsnap'])
parser.add_argument('--model', type=str, default='Resnet8', help='The model to use',
                    choices=['Resnet8', 'Resnet8v2', 'Resnet10', 'Resnet18', 'Resnet50', 'ConvNetSimple'])
parser.add_argument('--model_weights', type=str, required=False, default=None, help='Pretrained weights')
parser.add_argument('--output_dir', type=str, required=False, default='./roar_kar_experiments_abs_no_abs/',
                    help='Output directory path')
opt = parser.parse_args()

# Default checkpoint for each (dataset, model) pair, used only when --model_weights is not given.
# NOTE(review): 'Food101' is not one of the --dataset choices above, so its entry is currently unreachable.
_DEFAULT_MODEL_WEIGHTS_PATHS = {
    'CIFAR10': dict(
        ConvNetSimple='./logs/model_ConvNetSimple_dataset_CIFAR10_bs_256_name_torch.optim.Adam_lr_0.001/epoch_0010-model-val_accuracy_76.0.pth',
        Resnet8='./logs/model_ResNet8_dataset_CIFAR10_bs_256_name_torch.optim.Adam_lr_0.001/epoch_0024-model-val_accuracy_84.36.pth',
        Resnet8v2='./logs/2019-11-13T22:01:56.569700_model_ResNet8v2_dataset_CIFAR10_bs_256_name_torch.optim.Adam_lr_0.001/epoch_0021-model-val_accuracy_81.56.pth',
        Resnet10='./logs/2019-11-12T16:31:15.904242_model_ResNet10_dataset_CIFAR10_bs_256_name_torch.optim.Adam_lr_0.001/epoch_0024-model-val_accuracy_84.34.pth',
        # NOTE(review): this is the *string* 'None', not the None object - confirm get_model treats it as "no weights".
        Resnet18='None',
    ),
    'Food101': dict(
        Resnet18='./logs/2019-11-10T19:57:59.721028_model_resnet18_dataset_Food101_bs_128_name_torch.optim.Adam_lr_0.0001_weight_decay_0.0001/epoch_0019-model-val_accuracy_72.52805280528052.pth',
    ),
    'Birdsnap': dict(
        Resnet18='./logs_birdsnap/2020-02-15T08:00:52.435319_NoWeights/epoch_0021-model-val_acc_53.33943275388838.pth',
        Resnet50='./logs_birdsnap/2020-02-26T12:35:01.546063/epoch_0011-model-val_acc_60.910338517840806.pth',  # Accuracy of the network on the 4977 test images: 58.62969660438015 %%
    ),
}
# Per-model default checkpoints for the selected dataset.
model_weights_paths = _DEFAULT_MODEL_WEIGHTS_PATHS.get(opt.dataset, {})

if not opt.model_weights:
    if opt.model not in model_weights_paths:
        # Exit with a usage message instead of a bare KeyError for combinations without a known checkpoint.
        parser.error(f'no default --model_weights for --model {opt.model} with --dataset {opt.dataset}; '
                     f'pass --model_weights explicitly')
    opt.model_weights = model_weights_paths[opt.model]


def convert_float_to_percentiled_3channel_image(arr):
    """
    Min-max rescale ``arr`` to the full 8-bit range [0, 255].

    The smallest value maps to 0 and the largest to 255. Values in between are scaled linearly
    and truncated toward zero by the ``uint8`` cast. Despite the function name, no percentiles
    are computed, and the shape of ``arr`` is preserved as-is (it is not forced to 3 channels).

    :param arr: numpy array of any shape with a real-valued dtype.
    :return: ``uint8`` array with the same shape as ``arr``. When ``arr`` is constant
             (max == min), an all-zero array is returned to avoid dividing by zero.
    """
    # Reduce once; the original recomputed min/max several times over the full array.
    arr_min = arr.min()
    arr_max = arr.max()
    if arr_max == arr_min:
        return np.zeros(arr.shape, dtype=np.uint8)

    scaled = (arr - arr_min) / (arr_max - arr_min)  # now in [0, 1]
    return np.uint8(scaled * 255)


# Pruning thresholds evaluated for each supported (dataset, model) pair.
_PRUNE_THRESHOLDS = {
    ('CIFAR10', 'Resnet8'): [92],
    ('Birdsnap', 'Resnet50'): [64],
}


def _make_dirs(paths):
    """Create every directory in ``paths`` (including parents), keeping ones that already exist."""
    for path in paths:
        os.makedirs(path, exist_ok=True)


def _save_input_image(dataset, preprocessed_image, path):
    """Undo the dataset normalisation of ``preprocessed_image``, clamp it to [0, 1] and write it to ``path``."""
    rgb_image = dataset.denormalization_transform(preprocessed_image.cpu())
    rgb_image = torch.clamp(rgb_image, 0.0, 1.0).numpy()
    # Move axis 0 last (presumably CxHxW -> HxWxC, the layout skimage expects).
    skimage.io.imsave(path, rgb_image.transpose(1, 2, 0), check_contrast=False)


def dump_saliency_data():
    """
    Generate and save attribution maps for every configured attribution method and every
    dataset split. The matching input images are saved once, during the first method's pass.

    Output layout, relative to ``opt.output_dir``:
      * CIFAR10:  ``CIFAR10_<split>/input/<split>/<idx>.png`` and
                  ``CIFAR10_<split>/<ModelArch>_<method>/<idx>.png``
      * Birdsnap: ``Birdsnap_<split>/input/<label>/<idx>.png`` and
                  ``Birdsnap_<split>/<ModelArch>_<method>/<label>/<idx>.png``

    ``<idx>`` is zero-padded to 5 digits. It restarts at 0 for every split (and, for Birdsnap,
    for every label), so an input image and its attribution maps share a file name.

    Reads the module-level ``opt`` (dataset, model, model_weights, output_dir).

    Raises:
        ValueError: if no pruning thresholds are configured for the selected dataset/model pair.
    """
    # Validate up front rather than failing after the model and dataset have been loaded.
    prune_thresholds = _PRUNE_THRESHOLDS.get((opt.dataset, opt.model))
    if prune_thresholds is None:
        raise ValueError(f'No pruning thresholds configured for dataset {opt.dataset!r} and model {opt.model!r}; '
                         f'supported pairs: {sorted(_PRUNE_THRESHOLDS)}')

    attribution_methods = ['VanillaSaliency', 'input_x_grad', 'gbp', 'gradcam', 'rectgard', 'integrad',
                           'sparsity0_prune_pgd']
    prune_methods = ['prune_grad_abs', 'prune_pgd_abs', 'PruneGrad-Mid']
    attribution_methods.extend(f'{prune_method}_{prune_threshold}'
                               for prune_method, prune_threshold in itertools.product(prune_methods, prune_thresholds))

    dataset_args = dict(
        name=MAP_DATASET_TO_ENUM[opt.dataset],
    )

    if opt.dataset == 'CIFAR10':
        # NOTE(review): presumably keeps the whole training set (no validation split) - confirm in create_dataset.
        dataset_args['split_ratio'] = 1

    if opt.dataset == 'Birdsnap':
        mean = (0.491, 0.506, 0.451)
        std = (0.229, 0.226, 0.267)
        dataset_args['train_transform'] = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
                                                                          torchvision.transforms.CenterCrop(224),
                                                                          torchvision.transforms.ToTensor(),
                                                                          torchvision.transforms.Normalize(mean, std)])
        dataset_args['eval_transform'] = dataset_args['train_transform']

    val_data_args = dict(
        batch_size=4,
        shuffle=False,
        validate_step_size=1,
    )

    # Shuffling must be off: file indices are assigned in iteration order.
    assert not val_data_args['shuffle']

    model_config = dict(
        ConvNetSimple=dict(
            model_arch_name='src.models.classification.ConvNetSimple.ConvNetSimple',
            model_weights_path=opt.model_weights,
            model_constructor_args=dict(
                input_size=dataset_args['name'].value['image_size'],
                number_of_input_channels=dataset_args['name'].value['channels'],
                number_of_classes=dataset_args['name'].value['labels_count'],
            )),
        Resnet8=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet8',
            model_weights_path=opt.model_weights,
            model_constructor_args=dict()),
        Resnet8v2=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet8v2',
            # Previously hard-coded to None, which silently ignored --model_weights and the default checkpoint.
            model_weights_path=opt.model_weights,
            model_constructor_args=dict()),
        Resnet10=dict(
            model_arch_name='src.models.classification.PytorchCifarResnet.ResNet10',
            model_weights_path=opt.model_weights,
            model_constructor_args=dict()),
        Resnet18=dict(
            model_arch_name='torchvision.models.resnet18',
            model_weights_path=opt.model_weights,
            model_constructor_args=dict(pretrained=False)),
        Resnet50=dict(
            model_arch_name='torchvision.models.resnet50',
            model_weights_path=opt.model_weights,
            model_constructor_args=dict(pretrained=False)),
    )

    model_args = model_config[opt.model]
    if opt.dataset == 'Birdsnap':
        model_args['model_transformer'] = append_linear_layer_transform

    arguments = dict(
        dataset_args=dataset_args,
        val_data_args=val_data_args,
        model_args=model_args,
        outdir=opt.output_dir,
        random_seed=42
    )

    # Setup result directory
    outdir = arguments.get("outdir")
    print('Arguments:\n{}'.format(pformat(arguments)))

    # Set random seed throughout python (before the model and dataset are created).
    print('Using Random Seed value as: %d' % arguments['random_seed'])
    torch.manual_seed(arguments['random_seed'])  # Set for pytorch, used for cuda as well.
    random.seed(arguments['random_seed'])  # Set for python
    np.random.seed(arguments['random_seed'])  # Set for numpy

    # Set device - cpu or gpu
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Using device - {device}')

    # Load Model with weights (if available)
    model: torch.nn.Module = get_model(arguments.get('model_args'), device, arguments['dataset_args']).to(device)

    # Load parameters for the Dataset. val_data_args is reused as train_data_args
    # so that the train split gets the same batch size and no shuffling.
    dataset = create_dataset(arguments['dataset_args'],
                             arguments['val_data_args'],
                             arguments['val_data_args'],
                             )

    # Sample and view the inputs to the model
    dataset.debug()

    has_separate_validation_set = opt.dataset == 'Birdsnap'

    dataloaders = [dataset.train_dataloader, dataset.test_dataloader]
    dataset_modes = ['train', 'test']
    if has_separate_validation_set:
        dataloaders.append(dataset.validation_dataloader)
        dataset_modes.append('validation')

    # Input images are the same for every attribution method, so they are written only on the first pass.
    save_input_images = True
    save_attribution = True

    print(attribution_methods)

    arch_name = arguments["model_args"]["model_arch_name"].split(".")[-1]

    for attribution_method in attribution_methods:
        for dataloader, dataset_mode in zip(dataloaders, dataset_modes):
            split_dir = os.path.join(outdir, f'{opt.dataset}_{dataset_mode}')  # e.g. <outdir>/CIFAR10_train

            # Saving the cropped inputs is not strictly needed, but keeps images and attributions in parallel folders.
            if save_input_images:
                images_output_dir = os.path.join(split_dir, 'input/')
                if opt.dataset == 'CIFAR10':
                    os.makedirs(images_output_dir, exist_ok=True)
                    _make_dirs(os.path.join(images_output_dir, mode) for mode in dataset_modes)
                else:
                    # NOTE(review): exist_ok=False makes a re-run fail if this folder already exists (unlike the
                    # CIFAR10 branch) - presumably to avoid overwriting earlier inputs; confirm before changing.
                    os.makedirs(images_output_dir, exist_ok=False)
                    _make_dirs(os.path.join(images_output_dir, str(class_id))
                               for class_id in range(len(dataset.classes)))

            print(f"Generating attribution for {attribution_method} in {dataset_mode}")
            # e.g. <outdir>/CIFAR10_train/<ModelArch>_<method>/ (makedirs also creates split_dir)
            attribution_output_dir = os.path.join(split_dir, f'{arch_name}_{attribution_method}')
            os.makedirs(attribution_output_dir, exist_ok=True)
            counter = 0  # CIFAR10: running file index within this split
            counters = defaultdict(int)  # Birdsnap: running file index per label within this split
            if opt.dataset == 'Birdsnap':
                _make_dirs(os.path.join(attribution_output_dir, str(class_id))
                           for class_id in range(len(dataset.classes)))

            model.eval()
            for inputs, labels in tqdm(dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                # Attributions are computed for the predicted class; labels only pick the output folder.
                _, max_prob_indices = torch.max(outputs.data, 1)

                # ToDo: Add Batch Support.
                #  Not possible for now since Pruner needs to be adapted to allow for NxCxHxW modes(focus on N)
                for preprocessed_image, max_prob_index, label in zip(inputs, max_prob_indices, labels):
                    if save_attribution:
                        attribution_map = generate_attribution(opt.dataset,
                                                               model,
                                                               opt.model,
                                                               preprocessed_image.unsqueeze(0),
                                                               max_prob_index,
                                                               device,
                                                               attribution_method)
                        # Saved as an 8-bit image
                        percentiled_image = convert_float_to_percentiled_3channel_image(attribution_map)

                    if opt.dataset == 'CIFAR10':
                        file_name = f'{counter:05d}.png'
                        if save_attribution:
                            skimage.io.imsave(os.path.join(attribution_output_dir, file_name),
                                              percentiled_image.transpose(1, 2, 0),
                                              check_contrast=False)
                        if save_input_images:
                            _save_input_image(dataset, preprocessed_image,
                                              os.path.join(images_output_dir, dataset_mode, file_name))
                        counter += 1
                    elif opt.dataset == 'Birdsnap':
                        label_id = label.item()
                        label_dir = str(label_id)
                        file_name = f'{counters[label_id]:05d}.png'
                        if save_input_images:
                            _save_input_image(dataset, preprocessed_image,
                                              os.path.join(images_output_dir, label_dir, file_name))
                        if save_attribution:
                            skimage.io.imsave(os.path.join(attribution_output_dir, label_dir, file_name),
                                              percentiled_image.transpose(1, 2, 0),
                                              check_contrast=False)
                        counters[label_id] += 1
        save_input_images = False  # No need to resave input images for later methods


if __name__ == '__main__':
    dump_saliency_data()
