import gc
import os
import os.path as osp
from argparse import ArgumentParser
from collections import OrderedDict

import torch
import yaml
from torch.utils.data import DataLoader, random_split
from torchvision.models import vgg16
from torchvision.transforms import Compose, Grayscale, ToTensor
from tqdm import tqdm

from sde.evaluation import AttributionDataset
from sde.evaluation.sanity_check import SanityCheck
from sde.utils import DictAction, merge_from_options, read_config, set_random_seed


def parse_args():
    """Parse the command-line arguments of the sanity-check script.

    Returns:
        argparse.Namespace with the positional ``config``, ``image_dir``,
        ``heatmap_dir``, ``work_dir`` and ``file_name`` entries, plus the
        optional ``num_samples``, ``save_heatmaps``, ``gpu_id``, ``seed`` and
        ``cfg_options`` entries.
    """
    cli = ArgumentParser('Sanity check')

    # Positional arguments, listed in the order they must appear on the command line.
    positionals = (
        ('config', 'config file of the attribution method'),
        ('image_dir', 'directory of the images'),
        ('heatmap_dir', 'directory of the heatmaps'),
        ('work_dir', 'directory to save the result file'),
        ('file_name', 'file name for saving the results'),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, help=help_text)

    cli.add_argument(
        '--num-samples', type=int, default=0,
        help='Number of samples to check, 0 means checking all the samples')
    cli.add_argument(
        '--save-heatmaps', action='store_true', default=False,
        help='Whether to save the heatmaps produced by the perturbed models')
    cli.add_argument('--gpu-id', type=int, default=0, help='GPU id')
    cli.add_argument('--seed', type=int, default=2021, help='random seed')

    cfg_options_help = ('Override some settings in the used config, the key-value pair in xxx=yyy '
                        'will be merged into config file.')
    cli.add_argument('--cfg-options', nargs='+', action=DictAction, help=cfg_options_help)

    return cli.parse_args()


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with the first dotted component dropped from every key.

    For example, ``'model.features.0.weight'`` becomes ``'features.0.weight'``. The
    checkpoint keys carry one extra leading component that torchvision's ``vgg16``
    does not expect. This is presumably the attribute name of the wrapped network in
    the training module (TODO: confirm).
    """
    # key.partition('.')[2] is equivalent to '.'.join(key.split('.')[1:]),
    # including for keys without a dot (both yield '').
    return OrderedDict((key.partition('.')[2], value) for key, value in state_dict.items())


def _build_model(cfg, device):
    """Build a VGG16 on *device* and load the weights named in ``cfg['model_args']``.

    Reads ``cfg['model_args']['pretrained_weights']`` (checkpoint path; the file must
    hold a ``'state_dict'`` entry) and ``cfg['model_args']['num_classes']``.
    """
    model_args = cfg['model_args']
    # map_location='cpu' lets a checkpoint saved on some GPU be loaded on a machine that
    # lacks that GPU; load_state_dict then copies the weights onto the model's device.
    checkpoint = torch.load(model_args["pretrained_weights"], map_location='cpu')
    model = vgg16(num_classes=model_args["num_classes"]).to(device)
    model.load_state_dict(_strip_module_prefix(checkpoint["state_dict"]))
    return model


def _dump_results(results, path):
    """Write *results* to *path* as YAML, overwriting any existing file."""
    with open(path, 'w') as f:
        yaml.dump(results, f)


def sanity_check(
        cfg,
        work_dir,
        file_name,
        image_folder,
        attribution_map_folder,
        num_samples=0,
        save_heatmaps=False,
        device='cuda:0'):
    """Run the model-parameter sanity check over a folder of images and attribution maps.

    For every sample, ``SanityCheck.evaluate`` is called with the layers listed in
    ``cfg['sanity_check']['perturb_layers']``. The ``ssim_dict['ssim_all']`` entry it
    returns is recorded under the sample's name. All recorded values are written as
    YAML to ``work_dir/file_name``.

    Args:
        cfg: Config dict. Reads ``model_args``, ``sanity_check`` (``weight_init``,
            ``perturb_layers``) and ``attribution_method_cfg``.
        work_dir: Output directory, created (with parents) if missing. Each sample's
            ``save_dir`` is ``work_dir/<input_name>``.
        file_name: Name of the YAML result file inside *work_dir*.
        image_folder: Folder passed to ``AttributionDataset`` as ``image_folder``.
        attribution_map_folder: Folder passed to ``AttributionDataset`` as
            ``attribution_map_folder``.
        num_samples: Number of randomly chosen samples to check. ``0`` checks all
            samples. Values larger than the dataset are clamped to its size.
        save_heatmaps: Forwarded to ``SanityCheck.evaluate``.
        device: Torch device string for the model and the inputs.

    Raises:
        ValueError: If *num_samples* is negative.

    On ``KeyboardInterrupt`` the results gathered so far are still written, and the
    function returns normally.
    """
    if num_samples < 0:
        # Without this check random_split silently accepts a negative length (the
        # lengths still sum to len(dataset)) and returns a nonsensical subset.
        raise ValueError(f'num_samples must be >= 0, got {num_samples}')

    # makedirs(exist_ok=True) also creates missing parents, unlike os.mkdir.
    os.makedirs(work_dir, exist_ok=True)
    result_path = osp.join(work_dir, file_name)

    dataset = AttributionDataset(
        image_folder=image_folder,
        attribution_map_folder=attribution_map_folder,
        transform=Compose([
            Grayscale(num_output_channels=3),
            ToTensor(),
        ]),
    )

    if num_samples > 0:
        # random_split raises when a requested length exceeds the dataset. Clamp so
        # that asking for "more than available" simply checks every sample.
        num_samples = min(num_samples, len(dataset))
        dataset_to_test, _ = random_split(dataset, [num_samples, len(dataset) - num_samples])
    else:
        dataset_to_test = dataset
    dataloader = DataLoader(dataset_to_test, batch_size=8, num_workers=10, persistent_workers=True)

    model = _build_model(cfg, device)
    evaluator = SanityCheck(model, device, perturb_mode=cfg["sanity_check"]["weight_init"])

    # Loop-invariant config lookups.
    perturb_layers = cfg["sanity_check"]['perturb_layers']
    attribution_method_cfg = cfg["attribution_method_cfg"]

    results = {}
    try:
        for inputs, attribution_maps, targets, input_names in tqdm(dataloader, total=len(dataloader)):
            for input_tensor, attribution_map, target, input_name in zip(
                    inputs, attribution_maps, targets, input_names):
                ssim_dict = evaluator.evaluate(
                    heatmap=attribution_map,
                    input_tensor=input_tensor.to(device),
                    target=target.item(),
                    perturb_layers=perturb_layers,
                    save_dir=osp.join(work_dir, input_name),
                    attribution_method_cfg=attribution_method_cfg,
                    save_heatmaps=save_heatmaps)
                results[input_name] = ssim_dict['ssim_all']
                gc.collect()
    except KeyboardInterrupt:
        # Keep the partial results instead of losing the whole run.
        print("Abort! Save intermediate result.")

    _dump_results(results, result_path)


def main():
    """Script entry point.

    Parses the CLI, reads the config and applies any ``--cfg-options`` overrides,
    seeds the RNGs, then runs ``sanity_check`` on the requested GPU.
    """
    args = parse_args()

    cfg = read_config(args.config)
    overrides = args.cfg_options
    if overrides is not None:
        cfg = merge_from_options(cfg, overrides)

    set_random_seed(args.seed)

    device = f'cuda:{args.gpu_id}'
    sanity_check(
        cfg=cfg,
        work_dir=args.work_dir,
        file_name=args.file_name,
        image_folder=args.image_dir,
        attribution_map_folder=args.heatmap_dir,
        num_samples=args.num_samples,
        save_heatmaps=args.save_heatmaps,
        device=device,
    )


# Run the sanity check only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
