import region_utils as utils
import os
import argparse
import json
import os
import copy
import cv2
from typing import Any, Dict, List
import argparse
from time import time
from tqdm import tqdm
import numpy as np
from typing import List, Optional, Tuple, Union
from PIL import Image
from datasets import load_dataset
import pandas as pd

def get_args_parser():
    """Build the command-line parser for SAM region generation.

    AMG options use ``nargs="*"`` so that, with ``--num-granularities > 1``,
    each option can be given either once (shared by all granularities) or once
    per granularity; see ``get_amg_kwargs``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser()

    # sam regions
    parser.add_argument(
        "--input",
        type=str,
        default=None,
        help=(
            "Path to either a single input image or folder of images. With a .json/.jsonl "
            "--data_path, the directory that the image paths in that file are relative to."
        ),
    )

    parser.add_argument(
        "--data_path",
        type=str,
        default=None,
        help=(
            "Optional image list. If it exists, it should be a .json/.jsonl file of records with an "
            "'image' key, a .tsv file with 'index' and base64 'image' columns, or one of the "
            "Hugging Face datasets 'Lin-Chen/MMStar' / 'echo840/OCRBench'."
        ),
    )

    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help=(
            "Path to the directory where masks will be output. Output will be either a folder "
            "of PNGs per image or, with --convert-to-rle, one json with COCO-style masks per image."
        ),
    )

    parser.add_argument(
        "--use-hq",
        action="store_true",
        help="Use HQ-SAM model for segmentation."
    )

    parser.add_argument(
        "--use-mobile",
        action="store_true",
        help="Use Mobile-SAM model for segmentation."
    )

    parser.add_argument(
        "--use-sam2",
        action="store_true",
        help="Use SAM2 model for segmentation."
    )

    parser.add_argument(
        "--sam2-model-cfg",
        type=str,
        default='sam2.1_hiera_l.yaml',
        choices=["sam2.1_hiera_l.yaml", "sam2.1_hiera_b+.yaml", "sam2.1_hiera_s.yaml", "sam2.1_hiera_t.yaml"],
        help="SAM2 model config, in ['sam2.1_hiera_l.yaml', 'sam2.1_hiera_b+.yaml', 'sam2.1_hiera_s.yaml', 'sam2.1_hiera_t.yaml']. ",
    )

    parser.add_argument(
        "--model-type",
        type=str,
        default='vit_h',
        choices=["default", "vit_h", "vit_l", "vit_b", "vit_t"],
        help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b', 'vit_t']. ",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="The path to the SAM checkpoint to use for mask generation.",
    )

    parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")

    parser.add_argument(
        "--convert-to-rle",
        action="store_true",
        help=(
            "Save masks as COCO RLEs in one json per image instead of as a folder of PNGs. "
            "Requires pycocotools."
        ),
    )

    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Visualize segmentation results instead of saving masks.",
    )

    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing mask files. If not set, existing mask files will be skipped.",
    )

    parser.add_argument(
        "--benchmark",
        action="store_true",
        # NOTE: must be a single string. A trailing comma inside the parentheses
        # would turn it into a tuple and crash `--help`.
        help=(
            "Evaluate how long it takes on average for the model to generate masks for an image without saving"
            " any output regions. Runs on at most --num-benchmark-trials images."
        ),
    )

    parser.add_argument(
        "--num-benchmark-trials",
        type=int,
        default=100,
        help="The maximum number of images to run mask generation on when benchmarking.",
    )

    amg_settings = parser.add_argument_group("AMG Settings")

    amg_settings.add_argument(
        "--num-granularities",
        type=int,
        default=1,
        help=(
            "The number of different granularities to generate masks at. Generation starts at the "
            "middle granularity and moves to the next/previous one to keep the mask count within "
            "--nmasks-soft-min/--nmasks-soft-max."
        ),
    )

    parser.add_argument(
        "--nmasks-soft-min",
        type=int,
        default=64,
        help="Soft lower bound on the number of masks to generate (used when --num-granularities > 1).",
    )

    parser.add_argument(
        "--nmasks-soft-max",
        type=int,
        default=160,
        help="Soft upper bound on the number of masks to generate (used when --num-granularities > 1).",
    )

    amg_settings.add_argument(
        "--points-per-side",
        type=int,
        nargs="*",
        default=None,
        help="Generate masks by sampling a grid over the image with this many points to a side.",
    )

    amg_settings.add_argument(
        "--points-per-batch",
        type=int,
        nargs="*",
        default=None,
        help="How many input points to process simultaneously in one batch.",
    )

    amg_settings.add_argument(
        "--pred-iou-thresh",
        type=float,
        nargs="*",
        default=None,
        help="Exclude masks with a predicted score from the model that is lower than this threshold.",
    )

    amg_settings.add_argument(
        "--stability-score-thresh",
        type=float,
        nargs="*",
        default=None,
        help="Exclude masks with a stability score lower than this threshold.",
    )

    amg_settings.add_argument(
        "--stability-score-offset",
        type=float,
        nargs="*",
        default=None,
        help="Larger values perturb the mask more when measuring stability score.",
    )

    amg_settings.add_argument(
        "--box-nms-thresh",
        type=float,
        nargs="*",
        default=None,
        help="The overlap threshold for excluding a duplicate mask.",
    )

    amg_settings.add_argument(
        "--crop-n-layers",
        type=int,
        nargs="*",
        default=None,
        help=(
            "If >0, mask generation is run on smaller crops of the image to generate more masks. "
            "The value sets how many different scales to crop at."
        ),
    )

    amg_settings.add_argument(
        "--crop-nms-thresh",
        type=float,
        nargs="*",
        default=None,
        help="The overlap threshold for excluding duplicate masks across different crops.",
    )

    amg_settings.add_argument(
        "--crop-overlap-ratio",
        # A fraction of the crop size (e.g. 0.34), so it must be parsed as float.
        type=float,
        nargs="*",
        default=None,
        help="Larger numbers mean image crops will overlap more.",
    )

    amg_settings.add_argument(
        "--crop-n-points-downscale-factor",
        type=int,
        nargs="*",
        default=None,
        help="The number of points-per-side in each layer of crop is reduced by this factor.",
    )

    amg_settings.add_argument(
        "--min-mask-region-area",
        type=int,
        nargs="*",
        default=None,
        help=(
            "Disconnected mask regions or holes with area smaller than this value "
            "in pixels are removed by postprocessing."
        ),
    )

    return parser

from io import BytesIO
import base64
def load_image_from_base64(image):
    """Decode a base64-encoded image payload and open it with PIL.

    Args:
        image: base64 text/bytes of an encoded image file.

    Returns:
        The ``PIL.Image.Image`` opened from the decoded bytes.
    """
    decoded_bytes = base64.b64decode(image)
    stream = BytesIO(decoded_bytes)
    return Image.open(stream)

def load_sam_modules(args):
    """Import the SAM variant selected on the CLI and validate its checkpoint.

    Picks HQ-SAM (``--use-hq``), Mobile-SAM (``--use-mobile``) or the original
    ``segment_anything`` package. It then checks that the checkpoint file name
    matches the chosen variant and adjusts ``args.model_type`` for it.

    Note: mutates ``args.model_type`` in place. HQ-SAM maps 'vit_t' to
    'vit_tiny', and Mobile-SAM forces 'vit_t'.

    Args:
        args: parsed CLI namespace; reads ``use_hq``, ``use_mobile``,
            ``checkpoint`` and ``model_type``.

    Returns:
        Tuple ``(sam_model_registry, SamAutomaticMaskGenerator)`` from the
        selected package.

    Raises:
        ValueError: conflicting flags, missing ``--checkpoint``, a checkpoint
            name that does not match the variant, or 'vit_t' requested for
            the default SAM.
        ImportError: the selected package is not installed.
    """
    if args.use_hq and args.use_mobile:
        raise ValueError("Cannot have both --use-hq and --use-mobile")

    # Fail with a clear message instead of a TypeError from os.path.basename(None).
    if not args.checkpoint:
        raise ValueError("--checkpoint is required to load a SAM model")

    checkpoint_basename = os.path.basename(args.checkpoint).lower()

    if args.use_hq:
        try:
            from segment_anything_hq import sam_model_registry, SamAutomaticMaskGenerator
        except ImportError as e:
            raise ImportError(
                "If segment_anything_hq is not installed, please install it via: pip install segment-anything-hq"
            ) from e

        if args.model_type == 'vit_t':
            args.model_type = 'vit_tiny'  # SAM-HQ calls it 'vit_tiny' instead of 'vit_t'

        # Verify that 'sam_hq_vit' is in the checkpoint name
        if "sam_hq_vit" not in checkpoint_basename:
            raise ValueError(
                f"Expected 'sam_hq_vit' in checkpoint name '{checkpoint_basename}'\n"
                + f"Please ensure that the checkpoint is downloaded from: https://github.com/SysCV/sam-hq#model-checkpoints"
            )

    elif args.use_mobile:
        try:
            from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator
        except ImportError as e:
            raise ImportError(
                "If mobile_sam is not installed, please install it via: pip install git+https://github.com/ChaoningZhang/MobileSAM.git"
            ) from e

        # Verify that 'mobile_sam' is in the checkpoint name
        if "mobile_sam" not in checkpoint_basename:
            raise ValueError(
                f"Expected 'mobile_sam' in checkpoint name '{checkpoint_basename}'\n"
                + f"Please ensure that the checkpoint is downloaded from: https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
            )

        if args.model_type != "vit_t":
            print("WARNING: Mobile-SAM uses the 'vit_t' model type; setting --model-type to 'vit_t'")
            args.model_type = "vit_t"

    else:  # Default SAM model
        try:
            from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
        except ImportError as e:
            raise ImportError(
                "If segment_anything is not installed, please install it via: pip install git+https://github.com/facebookresearch/segment-anything.git"
            ) from e

        # Verify that 'sam_vit' is in the checkpoint name
        if "sam_vit" not in checkpoint_basename:
            raise ValueError(
                f"Expected 'sam_vit' in checkpoint name '{checkpoint_basename}'\n"
                + f"Please ensure that the checkpoint is downloaded from: https://github.com/facebookresearch/segment-anything"
            )

        if args.model_type == "vit_t":
            raise ValueError(
                "The default SAM library does not support the 'vit_t' model type. "
                + "Please use --use-hq or --use-mobile to use a different model."
            )

    return sam_model_registry, SamAutomaticMaskGenerator

# Longest image side (pixels) tolerated for in-memory dataset images before they
# are downscaled by an integer factor in _prepare_image.
_REF_MAX_RESOLUTION = 2048


def _build_generators(args):
    """Load the selected SAM model and build its automatic mask generator(s).

    Returns:
        ``(generator, generator_list, output_mode)``. ``generator`` is the starting
        (middle-granularity) generator. ``generator_list`` holds one generator
        per granularity, or just ``[generator]`` when ``--num-granularities`` is 1.
    """
    if args.use_sam2:
        from sam2.build_sam import build_sam2
        from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator as generator_cls

        # e.g. 'sam2.1_hiera_l.yaml' -> 'configs/sam2.1/sam2.1_hiera_l.yaml'
        sam2_config = f"configs/{args.sam2_model_cfg.split('_')[0]}/{args.sam2_model_cfg}"
        sam = build_sam2(sam2_config, args.checkpoint, device=args.device)
    else:
        # basically a copy of segment-anything/scripts/amg.py
        sam_model_registry, generator_cls = load_sam_modules(args)
        sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
        sam.to(device=args.device)

    output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
    amg_kwargs = get_amg_kwargs(args)
    if args.num_granularities > 1:
        generator_list = [generator_cls(sam, output_mode=output_mode, **kwargs) for kwargs in amg_kwargs]
        generator = generator_list[(args.num_granularities - 1) // 2]
    else:
        generator = generator_cls(sam, output_mode=output_mode, **amg_kwargs)
        generator_list = [generator]
    return generator, generator_list, output_mode


def _collect_targets(args, dedup):
    """Resolve the images to process.

    Returns:
        ``(image_names, targets, img_loaded)``. ``image_names`` are relative names
        used to derive output paths. ``targets`` are file paths when
        ``img_loaded`` is False, or in-memory PIL images / base64 strings when
        ``img_loaded`` is True.

    Raises:
        ValueError: unsupported ``--data_path``, or ``--input`` missing where it
            is needed to locate image files.
    """
    data_path = args.data_path
    if data_path is not None:
        if data_path.endswith('.json'):
            with open(data_path, 'r') as f:
                records = json.load(f)
        elif data_path.endswith('.jsonl'):
            with open(data_path, 'r') as f:
                # Skip blank lines (e.g. a trailing newline) that json.loads rejects.
                records = [json.loads(line) for line in f if line.strip()]
        elif data_path.endswith('.tsv'):
            questions = pd.read_table(os.path.expanduser(data_path))
            return [f"{q}.png" for q in questions['index']], list(questions['image']), True
        elif data_path == "Lin-Chen/MMStar":
            questions = load_dataset("Lin-Chen/MMStar", "val")["val"]
            return [f"{q}.png" for q in questions['index']], list(questions['image']), True
        elif data_path == "echo840/OCRBench":
            questions = load_dataset("echo840/OCRBench")["test"]
            return [f"{q}.png" for q in range(len(questions))], list(questions['image']), True
        else:
            raise ValueError(
                f"Unsupported --data_path '{data_path}': expected a .json/.jsonl/.tsv file, "
                "'Lin-Chen/MMStar' or 'echo840/OCRBench'"
            )

        if args.input is None:
            raise ValueError("--input is required to resolve the image paths listed in --data_path")
        image_names = [d['image'] for d in records if 'image' in d]
        if dedup:
            # dict.fromkeys keeps first-seen order (unlike set), so runs are reproducible.
            image_names = list(dict.fromkeys(image_names))
        return image_names, [os.path.join(args.input, f) for f in image_names], False

    if args.input is None:
        raise ValueError("--input is required when --data_path is not given")
    if not os.path.isdir(args.input):
        return [os.path.basename(args.input)], [args.input], False
    image_names = [
        f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
    ]
    return image_names, [os.path.join(args.input, f) for f in image_names], False


def _prepare_image(target, img_loaded):
    """Turn one target into an RGB numpy array, or return None if it cannot be read."""
    if img_loaded:
        image = target if isinstance(target, Image.Image) else load_image_from_base64(target)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if max(image.size) > _REF_MAX_RESOLUTION:
            # Integer downscale factor, rounded up once the ratio exceeds x.2.
            scale = int(max(image.size) / _REF_MAX_RESOLUTION + 0.8)
            if scale > 1:
                new_size = (image.size[0] // scale, image.size[1] // scale)
                print(f"Resize image from {image.size} to {new_size} for processing.")
                image = image.resize(new_size, Image.LANCZOS)
        return np.array(image)

    image = cv2.imread(target)
    if image is None:
        print(f"Could not load '{target}' as an image, skipping...")
        return None
    # cv2 loads BGR; the generators are fed RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _generate_masks(image, generator, generator_list, args):
    """Generate masks, stepping through granularities to respect the soft mask-count bounds.

    Returns:
        ``(masks, idx, used)``. ``idx`` is the granularity index that produced
        ``masks``, or None with a single granularity. ``used`` lists every index
        that was tried, in order.
    """
    masks = generator.generate(image)
    if args.num_granularities <= 1:
        return masks, None, []

    idx = (args.num_granularities - 1) // 2
    used = [idx]
    while len(masks) < args.nmasks_soft_min and idx < args.num_granularities - 1:
        idx += 1
        masks = generator_list[idx].generate(image)
        used.append(idx)
    while len(masks) > args.nmasks_soft_max and idx > 0:
        idx -= 1
        masks = generator_list[idx].generate(image)
        used.append(idx)
    return masks, idx, used


def _add_region_ids(args, all_image_files):
    """Rewrite each per-image RLE json in ``args.output``, tagging regions with a unique ``region_id``."""
    if args.data_path is not None:
        sam_files = [os.path.splitext(f)[0] + ".json" for f in all_image_files]
    else:
        sam_files = os.listdir(args.output)
    for f in sam_files:
        try:
            all_regions = utils.open_file(os.path.join(args.output, f))
        except Exception:
            print(f"Could not load '{f}' as a json file, skipping...")
            continue
        image_id = f.replace('.json', '')
        new_sam_regions = []
        for i, region in enumerate(all_regions):
            new_region = copy.deepcopy(region)
            new_region['region_id'] = f'{image_id}_region_{i}'
            new_sam_regions.append(new_region)
        utils.save_file(os.path.join(args.output, f), new_sam_regions)


def get_sam_regions(args, dedup=True):
    """Run SAM automatic mask generation over a set of images and save the results.

    Depending on the flags, each image yields one of the following:
    - a visualization folder (``--visualize``);
    - a folder of PNG masks plus ``metadata.csv`` (default);
    - one COCO-RLE json whose regions are then tagged with ``region_id``
      (``--convert-to-rle``).
    With ``--benchmark`` only the generation time is measured.

    Existing outputs are skipped unless ``--overwrite`` is given.

    Args:
        args: namespace produced by ``get_args_parser()``.
        dedup: drop duplicate image entries from a .json/.jsonl ``--data_path``.

    Raises:
        ValueError: missing ``--output``/``--input`` or an unsupported ``--data_path``.
    """
    import torch  # needed to catch CUDA OOM below; torch is not imported at module level

    if args.output is None:
        raise ValueError("--output is required")

    generator, generator_list, output_mode = _build_generators(args)
    all_image_files, targets, img_loaded = _collect_targets(args, dedup)

    os.makedirs(args.output, exist_ok=True)

    if args.benchmark:
        targets = targets[:args.num_benchmark_trials]
        print(f'Benchmarking with {len(targets)} trials...')

    # Existing outputs can be skipped only when we would otherwise write them again.
    may_skip = not args.benchmark and not args.overwrite

    gen_times = []
    num_regions = []
    pbar = tqdm(zip(all_image_files, targets), total=len(targets))
    for im, t in pbar:
        pbar.set_description(f"Processing {im}")
        save_base = os.path.join(args.output, os.path.splitext(im)[0])

        if may_skip and args.convert_to_rle and os.path.isfile(save_base + ".json"):
            continue
        if may_skip and output_mode == "binary_mask" and not args.visualize and os.path.isdir(save_base):
            continue

        image = _prepare_image(t, img_loaded)
        if image is None:
            continue

        start_time = time()
        try:
            # Covers the extra per-granularity generate() calls too.
            masks, idx, used = _generate_masks(image, generator, generator_list, args)
        except torch.cuda.OutOfMemoryError:
            print('Cuda out of memory error')
            torch.cuda.empty_cache()
            continue
        if args.num_granularities > 1 and args.visualize:
            print(used, time() - start_time)
        gen_times.append(time() - start_time)

        if args.benchmark:
            continue

        if args.visualize:
            os.makedirs(save_base, exist_ok=True)
            cv2.imwrite(os.path.join(save_base, "img.png"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
            utils.show_image(image, masks, os.path.join(save_base, "masks.png"))
            # Empty marker files: the information is carried by the file name.
            with open(os.path.join(save_base, f"num={len(masks)}.txt"), "w"):
                pass
            if args.num_granularities > 1:
                with open(os.path.join(save_base, f"granularity={idx}.txt"), "w"):
                    pass
            metas = {
                "area": [m["area"] for m in masks],
                "predicted_iou": [m["predicted_iou"] for m in masks],
                "stability_score": [m["stability_score"] for m in masks],
            }
            with open(os.path.join(save_base, "meta.json"), "w") as f:
                json.dump(metas, f, indent=2)
            num_regions.append(len(masks))
        elif output_mode == "binary_mask":
            # exist_ok: the folder already exists when --overwrite is set.
            os.makedirs(save_base, exist_ok=True)
            write_masks_to_folder(masks, save_base)
        else:
            save_file = save_base + ".json"
            os.makedirs(os.path.dirname(save_file), exist_ok=True)
            with open(save_file, "w") as f:
                json.dump(masks, f)

    if gen_times:  # May not generate anything if everything already exists
        print(f"Average time per image with {len(gen_times)} trials (seconds): {sum(gen_times)/len(gen_times)}")

    if args.convert_to_rle:
        _add_region_ids(args, all_image_files)

    print("Done!")
    if args.visualize and num_regions:
        print(f"Average number of regions: {sum(num_regions)/len(num_regions)}")

def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
    """Save each mask as ``<i>.png`` inside *path*, plus a ``metadata.csv`` summary.

    Args:
        masks: mask records as produced by the generator in binary-mask mode;
            reads the keys ``segmentation``, ``area``, ``bbox``, ``point_coords``,
            ``predicted_iou``, ``stability_score`` and ``crop_box``.
        path: existing directory to write into.
    """
    lines = [
        "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h"  # noqa
    ]
    for mask_idx, record in enumerate(masks):
        # Scale the mask to 0/255 so it is visible as a PNG.
        cv2.imwrite(os.path.join(path, f"{mask_idx}.png"), record["segmentation"] * 255)

        fields = [str(mask_idx), str(record["area"])]
        fields.extend(str(value) for value in record["bbox"])
        fields.extend(str(value) for value in record["point_coords"][0])
        fields.append(str(record["predicted_iou"]))
        fields.append(str(record["stability_score"]))
        fields.extend(str(value) for value in record["crop_box"])
        lines.append(",".join(fields))

    csv_path = os.path.join(path, "metadata.csv")
    with open(csv_path, "w") as handle:
        handle.write("\n".join(lines))

def get_amg_kwargs(args):
    """Collect the automatic-mask-generator keyword arguments given on the CLI.

    Every AMG option is parsed with ``nargs="*"``, so it arrives as a list, or
    ``None`` when not given. Options that were not given, or were given with no
    value, are omitted so the generator falls back to its own defaults.

    Args:
        args: parsed CLI namespace; reads ``num_granularities`` and the AMG
            option attributes listed in ``option_names`` below.

    Returns:
        With ``num_granularities <= 1``: one kwargs dict built from the first
        value of each option. Otherwise: a list of ``num_granularities`` kwargs
        dicts. An option given once is shared by every granularity; an option
        given ``num_granularities`` times supplies one value per granularity.

    Raises:
        ValueError: an option has a number of values other than 1 or
            ``num_granularities`` (previously these were silently truncated).
    """
    option_names = (
        "points_per_side",
        "points_per_batch",
        "pred_iou_thresh",
        "stability_score_thresh",
        "stability_score_offset",
        "box_nms_thresh",
        "crop_n_layers",
        "crop_nms_thresh",
        "crop_overlap_ratio",
        "crop_n_points_downscale_factor",
        "min_mask_region_area",
    )
    # `if getattr(...)` drops both None (not given) and [] (flag given without values).
    provided = {name: getattr(args, name) for name in option_names if getattr(args, name)}

    num_granularities = args.num_granularities
    if num_granularities > 1:
        per_granularity = {}
        for name, values in provided.items():
            if len(values) == 1:
                values = values * num_granularities
            elif len(values) != num_granularities:
                raise ValueError(
                    f"--{name.replace('_', '-')} got {len(values)} values; expected 1 or "
                    f"{num_granularities} (one per granularity)"
                )
            per_granularity[name] = values
        # Always return exactly num_granularities dicts, even when no option was given.
        amg_kwargs = [
            {name: values[i] for name, values in per_granularity.items()}
            for i in range(num_granularities)
        ]
    else:
        amg_kwargs = {name: values[0] for name, values in provided.items()}

    print(json.dumps(amg_kwargs, indent=4))
    return amg_kwargs

if __name__ == '__main__':
    # Script entry point: parse the CLI and run region generation.
    cli_args = get_args_parser().parse_args()
    get_sam_regions(cli_args)