import os
import os.path as osp
import warnings
from argparse import ArgumentParser
from logging import Logger
from typing import Dict

import numpy as np
import skimage
import skimage.io
import torch
import yaml
from torch.utils.data import DataLoader, random_split
from torchvision.models import vgg16
from torchvision.transforms import Compose, Lambda, PILToTensor, ToTensor
from tqdm import tqdm

from sde.attribution_methods import AttributionGenerator
from sde.datasets import IndexedImageFolderPath, ImageFolderWithAttribution
from sde.models import load_synthetic_model, load_trained_vgg
from sde.utils import DictAction, merge_from_options, mkdir_or_exist, read_config, set_random_seed, setup_logger


def parse_args():
    """Parse the command line of the attribution-dumping script.

    Returns:
        argparse.Namespace with the attributes ``config``, ``out_dir``,
        ``gpu_id``, ``seed`` and ``cfg_options``.
    """
    parser = ArgumentParser()
    add = parser.add_argument

    # Positional: path of the config file.
    add("config", help="Configuration file to use.")
    # Where attribution maps / arrays are written.
    add('-o', '--out-dir', default='attribution_maps/default_out_dir/', help='Output directory')
    add('--gpu-id', default=0, type=int, help='GPU device ID.')
    add('--seed', type=int, help='Random seed. If None, the program will not set seed explicitly.')
    # key=value pairs collected by DictAction and merged over the loaded config.
    add('--cfg-options',
        nargs='+',
        action=DictAction,
        help=('Override some settings in the used config, the key-value pair in xxx=yyy '
              'will be merged into config file.'))

    return parser.parse_args()


def _build_dataset(dataset_cfg: Dict):
    """Instantiate the dataset described by ``dataset_cfg``.

    The script-level ``imagenet`` flag is removed before the remaining keys are
    forwarded to the dataset constructor.

    Returns:
        Tuple ``(dataset, ind_to_cls, is_imagenet)`` where ``ind_to_cls`` maps a
        class index back to its class name.
    """
    # Shallow copy so popping the flag does not mutate the caller's config, and so
    # a falsy ``imagenet: False`` entry is not forwarded to IndexedImageFolderPath.
    dataset_cfg = dict(dataset_cfg)
    is_imagenet = bool(dataset_cfg.pop("imagenet", False))

    if is_imagenet:
        dataset = ImageFolderWithAttribution(**dataset_cfg)
        ind_to_cls = {v: k for k, v in dataset.cls_to_ind_dict.items()}
    else:
        if dataset_cfg['num_channels'] == 3:
            # Keep 0-255 pixel values, only cast to float.
            transform = Compose([PILToTensor(), Lambda(lambda x: x.float())])
        else:
            transform = ToTensor()
        dataset = IndexedImageFolderPath(transform=transform, **dataset_cfg)
        ind_to_cls = {v: k for k, v in dataset.class_to_idx.items()}
    return dataset, ind_to_cls, is_imagenet


def _select_split(dataset, split_cfg, logger: Logger):
    """Return the split of ``dataset`` selected by ``split_cfg``.

    ``dataset`` is returned unchanged when ``split_cfg`` is None or requests at most
    one split. Otherwise the dataset is divided into ``num_splits`` near-equal parts
    (the last one absorbs the remainder) and the 1-based ``run_on_split`` part is
    returned.

    Raises:
        ValueError: if ``run_on_split`` is not in ``[1, num_splits]``.
    """
    if split_cfg is None or split_cfg['num_splits'] <= 1:
        return dataset

    num_splits = split_cfg['num_splits']
    base_len = len(dataset) // num_splits
    split_lens = [base_len] * (num_splits - 1) + [len(dataset) - (num_splits - 1) * base_len]
    logger.info(f"Dataset is split into datasets of length: {split_lens}")
    # Fixed generator seed so every split is identical across runs, regardless of --seed.
    splited_datasets = random_split(dataset, split_lens, generator=torch.Generator().manual_seed(42))

    dataset_ind = split_cfg["run_on_split"]
    # run_on_split is 1-based; 0 would otherwise silently select the last split via [-1].
    if not 1 <= dataset_ind <= num_splits:
        raise ValueError(f"split_cfg['run_on_split'] must be in [1, {num_splits}], got {dataset_ind}")
    logger.info(f"Use the {dataset_ind}. splited dataset")
    return splited_datasets[dataset_ind - 1]


def _load_model(model_cfg: Dict, device: torch.device):
    """Build the model described by ``model_cfg`` on ``device`` in eval mode."""
    if model_cfg['is_synthetic']:
        return load_synthetic_model(**model_cfg['parameter']).to(device).eval()
    if model_cfg['imagenet_model']:
        return vgg16(pretrained=True).to(device).eval()
    # lets just assume all models are trained VGG16
    return load_trained_vgg(model_cfg["pretrained_weights"], model_cfg["num_classes"], device=device).eval()


def _save_attribution(out_dir: str, attr_method_name: str, cls_name: str, path: str, attr_map, attr_array) -> None:
    """Write one attribution map image and its raw array.

    Layout: ``<out_dir>/attr_maps/<method>/<class>/<basename of path>`` for the image and
    ``<out_dir>/attr_arrays/<method>/<class>/<stem of path>.npy`` for the array.
    """
    # assume that attribution_map is already normalized and converted to uint8 image
    attr_map_file = osp.basename(path)
    attr_map_sub_dir = osp.join(out_dir, 'attr_maps', attr_method_name, cls_name)
    mkdir_or_exist(attr_map_sub_dir)
    skimage.io.imsave(osp.join(attr_map_sub_dir, attr_map_file), attr_map, check_contrast=False)

    attr_array_file = osp.splitext(attr_map_file)[0] + '.npy'
    attr_array_sub_dir = osp.join(out_dir, 'attr_arrays', attr_method_name, cls_name)
    mkdir_or_exist(attr_array_sub_dir)
    np.save(osp.join(attr_array_sub_dir, attr_array_file), attr_array)


def dump_saliency_data(cfg: Dict, out_dir: str, device: torch.device, logger: Logger) -> None:
    """Generate and save attribution maps for every sample and every selected method.

    Args:
        cfg: Full config. Keys read: ``dataset``, optional ``split_cfg``, ``data_loader``,
            ``attribution_methods``, ``run_on`` (``'all'`` or a list of method names)
            and ``model``. ``cfg`` itself is not modified.
        out_dir: Root output directory; ``attr_maps/`` and ``attr_arrays/`` are
            created beneath it.
        device: Device the model and inputs are moved to.
        logger: Logger for progress messages.
    """
    dataset, ind_to_cls, is_imagenet = _build_dataset(cfg['dataset'])
    dataset = _select_split(dataset, cfg.get('split_cfg', None), logger)
    dataloader = DataLoader(dataset, **cfg['data_loader'])

    attribution_methods = cfg['attribution_methods']
    run_on = cfg['run_on']
    if run_on != 'all':
        assert isinstance(run_on, list), "run_on must be a list of str"
        attribution_methods = [a for a in attribution_methods if a['name'] in run_on]

    # Only perform shape checking for the first 10 samples of each method
    max_checked_samples = 10

    for attribution_method in attribution_methods:
        # Rebuild the model for each method so the softmax removal below does not
        # carry over to the next method (presumably also isolates any hooks).
        model = _load_model(cfg['model'], device)

        attr_method_name = str(attribution_method["name"])
        if attr_method_name in ('extremal_perturbation', 'lrp'):
            if hasattr(model, 'softmax') and model.softmax is not None:
                model.softmax = None
                warnings.warn(
                    f'Running attribution method name : {attr_method_name}, and '
                    'model.softmax is not None. For Extremal Perturbation or LRP, the '
                    'attribution method is applied to the layer before softmax, so we '
                    'will explicitly set the softmax to None here. You can also edit '
                    'the config to avoid building softmax layer for the model.')

        logger.info(f"Generating attribution for:\n"
                    f"{yaml.dump(attribution_method, indent=4, sort_keys=False)}")
        attribution_generator = AttributionGenerator(attribution_method=attribution_method)

        num_checked_samples = 0
        for data in tqdm(dataloader):
            if is_imagenet:
                inputs, labels, paths = data['img'], data['label'], data['img_file']
            else:
                inputs, labels, paths = data
            inputs = inputs.float().to(device)
            labels = labels.to(device)

            for preprocessed_image, label, path in zip(inputs, labels, paths):
                # Strict '<' so exactly max_checked_samples samples are shape-checked.
                need_check_shape = num_checked_samples < max_checked_samples
                attr_array, attr_map = attribution_generator.generate_attribution(
                    model, preprocessed_image.unsqueeze(0), label, need_check_shape)
                _save_attribution(out_dir, attr_method_name, ind_to_cls[label.item()], path, attr_map, attr_array)
                num_checked_samples += 1


if __name__ == '__main__':
    # Script entry point: parse CLI, load/merge config, then dump attribution data.
    args = parse_args()
    logger = setup_logger("sde")

    cfg = read_config(args.config)
    if args.cfg_options is not None:
        cfg = merge_from_options(cfg, options=args.cfg_options)
    logger.info(f'Using config:\n{yaml.dump(cfg, indent=4, sort_keys=False)}\n' + '-' * 60)

    if args.seed is not None:
        set_random_seed(args.seed)
        # f-prefix was missing, so the literal text "{args.seed}" used to be logged.
        logger.info(f'Using seed: {args.seed}')

    # Setup result directory
    out_dir = args.out_dir
    mkdir_or_exist(out_dir)
    if len(os.listdir(out_dir)) > 0:
        # Existing files may be silently overwritten by same-named outputs.
        warnings.warn(f'The output directory: {out_dir} is not empty! Please specify an empty out_dir')

    device = torch.device(f'cuda:{args.gpu_id}')
    dump_saliency_data(cfg, out_dir, device=device, logger=logger)
