import json
import os.path as osp
from argparse import ArgumentParser
from logging import Logger
from typing import Dict, Optional

import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from sde.datasets import ImageFolderWithAttribution
from sde.evaluation import build_evaluator
from sde.models import load_synthetic_model
from sde.utils import DictAction, merge_from_options, mkdir_or_exist, read_config, set_random_seed, setup_logger


def parse_args(argv=None):
    """Parse the command-line arguments of the attribution-map evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, matching the previous behavior.

    Returns:
        argparse.Namespace with ``config``, ``out_dir``, ``subset_size``,
        ``seed``, ``gpu_id`` and ``cfg_options`` attributes.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in the usage line), so the text must be passed as ``description``.
    parser = ArgumentParser(description='Evaluating attribution maps.')

    parser.add_argument('config', help='Config file to use.')
    parser.add_argument('-o', '--out-dir', default='output/dummy/', help='Directory to save output files.')
    parser.add_argument('-s', '--subset-size', type=int, help='Subset size. If not set, the full dataset will be used.')
    parser.add_argument('--seed', type=int, help='Random seed.')
    parser.add_argument('--gpu-id', type=int, default=0, help='GPU device ID.')
    parser.add_argument(
        '--cfg-options', action=DictAction, nargs='+', help='Override config. Key-value pairs must be format xxx=yyy.')

    return parser.parse_args(argv)


@torch.no_grad()
def evaluate_attribution_maps(
        cfg: Dict, out_dir: str, subset_size: Optional[int], seed: Optional[int], device: torch.device,
        logger: Logger) -> None:
    if seed is not None:
        set_random_seed(seed)
    mkdir_or_exist(out_dir)

    dataset = ImageFolderWithAttribution(**cfg['dataset'])
    if subset_size is not None:
        rng = np.random.default_rng()
        inds = rng.integers(0, len(dataset), size=subset_size)
        dataset = Subset(dataset, indices=inds)
    logger.info(f'Dataset length: {len(dataset)}')
    data_loader = DataLoader(dataset, **cfg['data_loader'])

    model = load_synthetic_model(**cfg['model'])
    model.eval()
    model.to(device)

    evaluator = build_evaluator(cfg['evaluator'])
    for i in tqdm(range(evaluator.num_steps)):
        evaluator.reset_cache()

        for batch_data in data_loader:
            imgs = batch_data['img'].to(device)
            labels = batch_data['label'].to(device)
            attr_maps = batch_data['attr_map'].to(device)
            gt_masks = batch_data['gt_mask'].to(device)
            img_files = batch_data['img_file']

            for img, label, attr_map, gt_mask, img_file in zip(imgs, labels, attr_maps, gt_masks, img_files):
                evaluator.evaluate(
                    model=model, img=img, label=label, attr_map=attr_map, gt_mask=gt_mask, img_file=img_file)

        evaluator.summarize_step()
        evaluator.increment_step()

    result = evaluator.summarize_total()
    out_file = osp.join(out_dir, 'evaluation_result.json')
    with open(out_file, 'w') as f:
        json.dump(result, f)
        logger.info(f'Result is dumped to: {out_file}')

    if cfg['visualize_result']:
        vis_file = osp.join(out_dir, 'visualized_result.jpeg')
        evaluator.visualize_result(result, save_path=vis_file)
        logger.info(f'Visualized result is dumped to: {vis_file}')


if __name__ == '__main__':
    # Script entry point: parse CLI arguments, load the config (optionally
    # overridden by --cfg-options), log it, then run the evaluation.
    cli_args = parse_args()

    config = read_config(cli_args.config)
    if cli_args.cfg_options is not None:
        config = merge_from_options(config, cli_args.cfg_options)

    sde_logger = setup_logger('sde')
    separator = '-' * 60
    sde_logger.info(f'Using config:\n{yaml.dump(config, indent=4, sort_keys=False)}\n' + separator)

    evaluate_attribution_maps(
        config,
        out_dir=cli_args.out_dir,
        subset_size=cli_args.subset_size,
        seed=cli_args.seed,
        device=torch.device(f'cuda:{cli_args.gpu_id}'),
        logger=sde_logger)
