import os
import os.path as osp
from argparse import ArgumentParser
from glob import glob
from typing import List, Optional

import cv2
import numpy as np
from tqdm import tqdm

from sde.utils import mkdir_or_exist


def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, as before.

    Returns:
        ``argparse.Namespace`` with ``attr_array_roots`` (list of paths, at
        least one) and ``out_path`` (output directory).
    """
    # The first positional parameter of ArgumentParser is ``prog`` (the
    # program name shown in the usage line), so the explanatory text must be
    # passed as ``description`` to show up in ``--help``.
    parser = ArgumentParser(
        description='Compute the difference of IG attribution arrays, which are obtained '
        'using different baselines.')

    parser.add_argument('attr_array_roots', nargs='+', help='Paths of the attribution arrays')
    parser.add_argument('-o', '--out-path', default='workdirs/debug', help='output directory to save the result.')

    return parser.parse_args(argv)


def compute_mae(attr_array_roots: List[str], out_path: str) -> None:
    """Save the element-wise spread of attribution arrays as PNG images.

    Every ``.npy`` file found (at any depth) under the first root of
    ``attr_array_roots`` defines a relative path. The file with that same
    relative path is loaded from every root, the arrays are stacked, and the
    element-wise range ``max - min`` across roots is computed. Despite the
    function name, this is a range, not a mean absolute error. The range is
    multiplied by 255, clipped to ``[0, 255]``, cast to ``uint8`` and written
    to ``<out_path>/<relative path without .npy>.png``, mirroring the input
    directory layout.

    Args:
        attr_array_roots: Directories holding attribution arrays with an
            identical relative file layout. The first one decides which
            files are processed.
        out_path: Output directory; created (with any needed sub-directories)
            if missing.

    Raises:
        RuntimeError: If ``cv2.imwrite`` reports that an image was not
            written.
    """
    mkdir_or_exist(out_path)
    one_array_root = attr_array_roots[0]

    # ``recursive=True`` is required for ``**`` to match directories at any
    # depth; without it ``**`` acts like ``*`` and only files exactly one
    # level below the root are found. Sorting makes the processing order
    # deterministic.
    pattern = osp.join(one_array_root, '**', '*.npy')
    files = sorted(
        osp.relpath(x, one_array_root) for x in glob(pattern, recursive=True))

    for file in tqdm(files):
        attr_arrays = np.stack(
            [np.load(osp.join(root, file)) for root in attr_array_roots], 0)
        # Element-wise range across the arrays from the different roots.
        diff = attr_arrays.max(0) - attr_arrays.min(0)
        # NOTE(review): the *255 scaling presumably assumes attribution values
        # in roughly [0, 1]; larger spreads saturate at 255 — confirm.
        diff = np.clip(diff * 255, 0, 255).astype(np.uint8)
        out_file = osp.join(out_path, osp.splitext(file)[0] + '.png')
        # cv2.imwrite does not create directories, so make the parent of this
        # particular output (at whatever depth) before writing.
        out_dir = osp.dirname(out_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        if not cv2.imwrite(out_file, diff):
            raise RuntimeError('Result is not written')


if __name__ == '__main__':
    # Script entry point: read CLI options and forward them to compute_mae.
    cli = parse_args()
    compute_mae(attr_array_roots=cli.attr_array_roots, out_path=cli.out_path)
