import gc
import os
import os.path as osp
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader, random_split
from torchvision.models import vgg16
from torchvision.transforms import Compose, Grayscale, ToTensor
from tqdm import tqdm

from sde.evaluation import AttributionGroundTruthDataset, GroundTruthAttributionEvaluation
from sde.utils import DictAction, merge_from_options, read_config, set_random_seed


def parse_args():
    """Parse the command-line arguments of the ground-truth evaluation script.

    Positional arguments are, in order: the attribution-method config file,
    the image directory, the heatmap directory, the output directory and the
    result file name. Optional flags control how many samples are evaluated,
    which GPU is used, the random seed and ad-hoc config overrides.

    Returns:
        argparse.Namespace: the parsed arguments (read from ``sys.argv``).
    """
    parser = ArgumentParser('Ground Truth Evaluation')

    # Positional arguments; their order defines the CLI call signature.
    positionals = [
        ('config', 'config file of the attribution method'),
        ('image_dir', 'directory of the images'),
        ('heatmap_dir', 'directory of the heatmaps'),
        ('work_dir', 'directory to save the result file'),
        ('file_name', 'file name for saving the results'),
    ]
    for name, description in positionals:
        parser.add_argument(name, help=description)

    parser.add_argument('--num-samples', type=int, default=0,
                        help='Number of samples to check, 0 means checking all the samples')
    parser.add_argument('--gpu-id', type=int, default=0, help='GPU id')
    parser.add_argument('--seed', type=int, default=2021, help='random seed')
    cfg_options_help = ('Override some settings in the used config, the key-value pair in xxx=yyy '
                        'will be merged into config file.')
    parser.add_argument('--cfg-options', nargs='+', action=DictAction, help=cfg_options_help)

    return parser.parse_args()


def ground_truth_eval(cfg, work_dir, file_name, image_folder, attribution_map_folder, num_samples=0, device='cuda:0'):
    """Score attribution maps against ground truth and dump per-image results.

    Loads a torchvision VGG16 from ``cfg['model_args']['pretrained_weights']``,
    iterates over an ``AttributionGroundTruthDataset`` built from
    ``image_folder`` / ``attribution_map_folder``, evaluates every sample with
    ``GroundTruthAttributionEvaluation`` and writes the per-image
    ``gt_coverage`` values, plus their mean under the key ``"overall"``, as
    YAML to ``osp.join(work_dir, file_name)``.

    Args:
        cfg: Config mapping; ``cfg['model_args']['pretrained_weights']`` and
            ``cfg['model_args']['num_classes']`` are read.
        work_dir: Output directory; created (including parents) if missing.
        file_name: Name of the YAML result file inside ``work_dir``.
        image_folder: Directory of the input images.
        attribution_map_folder: Directory of the attribution maps (heatmaps).
        num_samples: Evaluate a random subset of this many samples; 0 means
            all samples.
        device: Torch device string that the model and inputs are moved to.

    If the run is interrupted with Ctrl-C, the scores collected so far are
    still written (exactly once).
    """
    # makedirs(exist_ok=True) also creates missing parent directories, which
    # a plain os.mkdir would fail on.
    os.makedirs(work_dir, exist_ok=True)

    # get dataset
    # TODO: now the attribution maps are 2-D, here needs to be correspondingly updated
    dataset = AttributionGroundTruthDataset(
        image_folder=image_folder,
        attribution_map_folder=attribution_map_folder,
        transform=Compose([Grayscale(num_output_channels=3), ToTensor()]))
    if num_samples != 0:
        dataset_to_test, _ = random_split(dataset, [num_samples, len(dataset) - num_samples])
    else:
        dataset_to_test = dataset
    dataloader = DataLoader(dataset_to_test, batch_size=8, num_workers=10, persistent_workers=True)

    # init model
    # Load to CPU first so a checkpoint saved from another GPU still loads;
    # load_state_dict then copies the weights onto the model's device.
    checkpoint = torch.load(cfg['model_args']["pretrained_weights"], map_location='cpu')
    model_checkpoint = checkpoint["state_dict"]
    model = vgg16(num_classes=cfg['model_args']["num_classes"]).to(device)
    # Drop the first dotted component of every key (presumably a wrapper-module
    # prefix) so the names match the plain torchvision VGG16 parameters.
    adjusted_model_checkpoint = OrderedDict(
        (".".join(key.split(".")[1:]), value) for key, value in model_checkpoint.items())
    model.load_state_dict(adjusted_model_checkpoint)
    # VGG16 contains Dropout layers; switch to inference mode so evaluation is
    # deterministic (harmless if the evaluator also does this itself).
    model.eval()

    evaluator = GroundTruthAttributionEvaluation(model, device)

    results = {}
    try:
        for batch in tqdm(dataloader, total=len(dataloader)):
            inputs, attribution_maps, ground_truths, targets, input_names = batch

            for input_tensor, attribution_map, ground_truth, target, input_name in zip(
                    inputs, attribution_maps, ground_truths, targets, input_names):
                input_tensor = input_tensor.to(device)
                target = target.item()

                result_dict = evaluator.evaluate(
                    heatmap=attribution_map,
                    input_tensor=input_tensor,
                    target=target,
                    ground_truth=ground_truth,
                )
                results[input_name] = result_dict['gt_coverage']
                gc.collect()
    except KeyboardInterrupt:
        print("Abort! Save intermediate result.")

    # Dump exactly once, whether the loop completed or was interrupted.
    # Dumping twice would fold the first "overall" mean into the second one.
    count_overall_and_dump(results, work_dir, file_name)


def count_overall_and_dump(results, work_dir, file_name):
    """Add the mean score under ``"overall"`` and write ``results`` as YAML.

    Args:
        results: Mapping of sample name -> score. It is updated in place with
            an ``"overall"`` entry; any pre-existing ``"overall"`` value is
            excluded from the mean, so calling this twice is safe.
        work_dir: Directory of the output file (must already exist).
        file_name: Name of the YAML file written inside ``work_dir``.

    Returns:
        The updated ``results`` mapping.
    """
    score_list = [score for name, score in results.items() if name != "overall"]
    # The mean of an empty array warns and yields nan; report nan explicitly
    # when there is nothing to average (e.g. interrupted before the first sample).
    results["overall"] = np.array(score_list).mean().item() if score_list else float('nan')
    with open(osp.join(work_dir, file_name), 'w', encoding='utf-8') as f:
        yaml.dump(results, f)
    return results


def main():
    """Run the ground-truth evaluation from the command line.

    Parses the CLI, loads the config (applying ``--cfg-options`` overrides
    when given), seeds the RNGs and evaluates on ``cuda:<gpu-id>``.
    """
    args = parse_args()

    config = read_config(args.config)
    overrides = args.cfg_options
    if overrides is not None:
        config = merge_from_options(config, options=overrides)

    set_random_seed(args.seed)

    eval_kwargs = dict(
        cfg=config,
        work_dir=args.work_dir,
        file_name=args.file_name,
        image_folder=args.image_dir,
        attribution_map_folder=args.heatmap_dir,
        num_samples=args.num_samples,
        device=f'cuda:{args.gpu_id}',
    )
    ground_truth_eval(**eval_kwargs)


if __name__ == "__main__":
    # Execute the evaluation only when run as a script, not on import.
    main()
