import torch
import argparse
import numpy as np
from matplotlib import pyplot as plt
import os
import json
from tqdm import tqdm
from lama_inpaint import inpaint_img_with_lama_loaded, load_lama_model
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
    show_mask


def main(args):
    """Run LaMa object removal for every record of the annotation file.

    The LaMa model is loaded once and reused for all records. Each non-blank
    line of ``args.annotation_file`` is parsed as a JSON object that must
    provide the keys ``"question_id"`` and ``"file_id"``. Records whose mask
    file is missing are reported on stdout and skipped.

    Args:
        args: Parsed command-line namespace with the attributes
            ``annotation_file``, ``image_folder``, ``mask_folder`` and
            ``output_folder``.

    Raises:
        FileNotFoundError: If the image of a COCO record does not exist.
    """
    lama_config = "lama/configs/prediction/default.yaml"
    lama_ckpt = "./pretrained_models/big-lama"
    model, predict_config = load_lama_model(lama_config, lama_ckpt)

    os.makedirs(args.output_folder, exist_ok=True)
    # Close the annotation file deterministically and tolerate blank lines
    # (e.g. a trailing empty line), which json.loads would reject.
    with open(args.annotation_file, "r", encoding="utf-8") as annotation_fh:
        data = [json.loads(line) for line in annotation_fh if line.strip()]

    for d in tqdm(data):
        question_id = d["question_id"]
        file_id = d["file_id"]
        if "COCO" in file_id:
            # COCO images are stored in a sub-folder named after the second
            # "_"-separated token of the file id.
            sub_folder = file_id.split("_")[1]
            image_path = os.path.join(args.image_folder, sub_folder, file_id + '.jpg')
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file {image_path} does not exist")
        else:
            image_path = os.path.join(args.image_folder, file_id + '.jpg')
        output_filename = f"{file_id}-{question_id}"
        # NOTE(review): the mask name has a space after the dash, unlike
        # output_filename above -- confirm this matches the files actually
        # written by the mask-generation step.
        mask_file = os.path.join(args.mask_folder, f'{file_id}- {question_id}_mask.npy')
        if not os.path.exists(mask_file):
            print(f"Mask file {mask_file} does not exist")
            continue
        remove(model, predict_config, image_path, mask_file, output_filename, args.output_folder)


def remove(model, predict_config, img_file, mask_file, output_filename, output_dir, dilate_kernel_size=15):
    """Inpaint the masked region(s) of one image with LaMa and save the results.

    The object masks stored in ``mask_file`` are merged along their first
    axis into a binary mask (non-zero -> 255), optionally dilated, and used
    to write the following PNG files into ``output_dir``:

    * ``{output_filename}_mask_{idx}.png``   -- the (dilated) mask,
    * ``{output_filename}_w_mask_{idx}.png`` -- the image with the mask overlaid,
    * ``{output_filename}_remove_{idx}.png`` -- the inpainted image.

    Files that already exist are left untouched and not recomputed.

    Args:
        model: Model object forwarded to ``inpaint_img_with_lama_loaded``.
        predict_config: Prediction config forwarded to
            ``inpaint_img_with_lama_loaded``.
        img_file: Path of the input image.
        mask_file: Path of a ``.npy`` file with the stacked object masks.
            NOTE(review): presumably shaped (num_objects, 1, H, W); a
            (num_objects, H, W) stack is also accepted -- confirm against
            the mask-generation step.
        output_filename: File-name stem used for every output file.
        output_dir: Directory the output files are written to.
        dilate_kernel_size: Kernel size handed to ``dilate_mask``. ``None``
            disables dilation. When dilation is enabled and the file holds
            more than one object, a kernel size of 20 is used instead.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    img = load_img_to_array(img_file)

    masks = np.load(mask_file)
    num_objects = len(masks)
    # Merge all objects into one mask. The builtin ``bool`` is used because
    # the ``np.bool`` alias was deprecated in NumPy 1.20 and removed in 1.24.
    masks = masks.sum(axis=0).astype(bool)
    masks = masks.astype(np.uint8) * 255
    if masks.ndim == 2:
        # A (num_objects, H, W) stack collapses to a single 2-D mask; add a
        # leading axis so the loops below iterate over masks, not pixel rows.
        masks = masks[np.newaxis]

    # dilate mask to avoid unmasked edge effect
    if dilate_kernel_size is not None:
        if num_objects > 1:
            # Several merged objects get a larger kernel; this must not
            # re-enable dilation when the caller disabled it with None.
            dilate_kernel_size = 20
        masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]

    # visualize the segmentation results
    for idx, mask in enumerate(masks):
        # path to the results
        mask_p = os.path.join(output_dir, f"{output_filename}_mask_{idx}.png")
        img_mask_p = os.path.join(output_dir, f"{output_filename}_w_mask_{idx}.png")
        if not os.path.exists(mask_p):
            # save the mask
            save_array_to_img(mask, mask_p)

        if not os.path.exists(img_mask_p):
            # save the masked image; the figure size is derived from the
            # image size and the current matplotlib dpi setting
            dpi = plt.rcParams['figure.dpi']
            height, width = img.shape[:2]
            plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
            plt.imshow(img)
            plt.axis('off')
            show_mask(plt.gca(), mask, random_color=False)
            plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
            # close the figure so figures do not accumulate across calls
            plt.close()

    # inpaint the masked image
    for idx, mask in enumerate(masks):
        img_inpainted_p = os.path.join(output_dir, f"{output_filename}_remove_{idx}.png")
        if not os.path.exists(img_inpainted_p):
            img_inpainted = inpaint_img_with_lama_loaded(
                model, predict_config, img, mask, device=device)
            save_array_to_img(img_inpainted, img_inpainted_p)


if __name__ == '__main__':
    # Command-line options as (flag, default) pairs; every option is a
    # string-valued path.
    cli_options = (
        ('--annotation_file', '../data/vqav2/vqa_k_test_noun_gpt4.jsonl'),
        ('--image_folder', '../data/vqav2/images'),
        ('--mask_folder', '../data/vqav2/images/remove_anything/gsam_masks'),
        ('--output_folder', '../data/vqav2/images/remove_anything/lama'),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, default_path in cli_options:
        arg_parser.add_argument(flag, type=str, default=default_path)
    main(arg_parser.parse_args())
