import cv2
import sys
import argparse
import numpy as np
from PIL import Image
from pathlib import Path
from matplotlib import pyplot as plt
from typing import Any, Dict, List
import glob

from utils import load_img_to_array, show_mask


def setup_args(parser):
    parser.add_argument(
        "--input_img", type=str, required=True,
        help="Path to a single input img",
    )
    parser.add_argument(
        "--input_mask_glob", type=str, required=True,
        help="Glob to input masks",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Output path to the directory with results.",
    )

if __name__ == "__main__":
    """Example usage:
    python visual_mask_on_img.py \
        --input_img FA_demo/FA1_dog.png \
        --input_mask_glob "results/FA1_dog/mask*.png" \
        --output_dir results
    """
    parser = argparse.ArgumentParser()
    setup_args(parser)
    args = parser.parse_args(sys.argv[1:])

    img = load_img_to_array(args.input_img)
    img_stem = Path(args.input_img).stem

    mask_ps = sorted(glob.glob(args.input_mask_glob))

    out_dir = Path(args.output_dir) / img_stem
    out_dir.mkdir(parents=True, exist_ok=True)

    for mask_p in mask_ps:
        mask = load_img_to_array(mask_p)
        mask = mask.astype(np.uint8)

        # path to the results
        img_mask_p = out_dir / f"with_{Path(mask_p).name}"

        # save the masked image
        dpi = plt.rcParams['figure.dpi']
        height, width = img.shape[:2]
        plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
        plt.imshow(img)
        plt.axis('off')
        show_mask(plt.gca(), mask, random_color=False)
        plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
        plt.close()
