from argparse import ArgumentParser
from logging import Logger
from typing import Dict

import yaml

from sde.datasets import ImageFolderWithAttribution
from sde.utils import DictAction, merge_from_options, read_config, setup_logger


def parse_args(argv=None):
    """Parse command-line arguments for the abnormal-attribution-map check.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous behaviour.

    Returns:
        argparse.Namespace with ``config`` (str), ``check_ind`` (int) and
        ``cfg_options`` (whatever ``DictAction`` stores, or ``None`` when the
        option is not given).
    """
    # ArgumentParser's first positional parameter is ``prog`` (the program name
    # shown in the usage line), not the description, so pass it by keyword.
    # The trailing space before the implicit concatenation avoids
    # "attributionpercentages" being glued together.
    parser = ArgumentParser(
        description='Check abnormal attribution maps for which the attribution '
                    'percentages are zero.')

    parser.add_argument('config', help='Config file to compute the attribution percentage.')
    parser.add_argument('check_ind', type=int, help='Sample index to check.')
    parser.add_argument(
        '--cfg-options', action=DictAction, nargs='+', help='Override config. Key-value pairs must be format xxx=yyy.')

    return parser.parse_args(argv)


def check_abnormal_attr_maps(cfg: Dict, check_ind: int, logger: Logger) -> None:
    """Log which image file backs a single sample of the attribution dataset.

    Args:
        cfg: Parsed config; its ``'dataset'`` entry is unpacked as keyword
            arguments to ``ImageFolderWithAttribution``.
        check_ind: Index of the sample to look up in the dataset.
        logger: Logger that receives the one-line report.
    """
    # NOTE(review): samples are assumed to be mapping-like with an 'img_file'
    # key; only that field is inspected here.
    img_file = ImageFolderWithAttribution(**cfg['dataset'])[check_ind]['img_file']
    logger.info(f'sample index: {check_ind}, img_file: {img_file}')


if __name__ == '__main__':
    args = parse_args()

    # Load the base config, then layer any --cfg-options overrides on top.
    cfg = read_config(args.config)
    overrides = args.cfg_options
    if overrides is not None:
        cfg = merge_from_options(cfg, overrides)

    logger = setup_logger('sde')
    # Echo the effective (post-override) config into the log before running.
    cfg_dump = yaml.dump(cfg, indent=4, sort_keys=False)
    separator = '-' * 60
    logger.info(f'Using config:\n{cfg_dump}\n' + separator)

    check_abnormal_attr_maps(cfg, check_ind=args.check_ind, logger=logger)
