import cv2
import numpy as np
import supervision as sv
import argparse
import torch
import torchvision

from groundingdino.util.inference import Model
from segment_anything import SamPredictor
from MobileSAM.setup_mobile_sam import setup_model

def parse_args(argv=None):
    """Parse command-line options for the Grounded-MobileSAM demo.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic use).

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``DEVICE`` is always a string (e.g. ``"cuda"``, ``"cuda:1"``, ``"cpu"``),
        whether it comes from the default or from the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--MOBILE_SAM_CHECKPOINT_PATH", type=str, default="./EfficientSAM/mobile_sam.pt",
        help="path to the MobileSAM checkpoint"
    )
    parser.add_argument(
        "--SOURCE_IMAGE_PATH", type=str, default="./assets/demo2.jpg", help="path to image file"
    )
    parser.add_argument(
        "--CAPTION", type=str, default="The running dog", help="text prompt for GroundingDINO"
    )
    parser.add_argument(
        "--OUT_FILE_BOX", type=str, default="groundingdino_annotated_image.jpg",
        help="output filename for the GroundingDINO box annotations (before NMS)"
    )
    parser.add_argument(
        "--OUT_FILE_SEG", type=str, default="grounded_mobile_sam_annotated_image.jpg",
        help="output filename for the mask + box annotations (after NMS)"
    )
    parser.add_argument(
        "--OUT_FILE_BIN_MASK", type=str, default="grounded_mobile_sam_bin_mask.jpg",
        help="output filename for the binary mask of the first detection"
    )
    parser.add_argument("--BOX_THRESHOLD", type=float, default=0.25,
                        help="GroundingDINO box confidence threshold")
    parser.add_argument("--TEXT_THRESHOLD", type=float, default=0.25,
                        help="GroundingDINO text confidence threshold")
    parser.add_argument("--NMS_THRESHOLD", type=float, default=0.8,
                        help="IoU threshold for non-maximum suppression")

    # Keep the default a string so args.DEVICE has the same type whether or not
    # the flag is given (argparse does not apply `type` to non-string defaults).
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument(
        "--DEVICE", type=str, default=default_device, help="cuda:[0,1,2,3,4] or cpu"
    )
    return parser.parse_args(argv)

def _format_labels(classes, confidences, class_ids):
    """Return one ``"<class name> <confidence>"`` label per detection.

    Args:
        classes: Class names that were passed to GroundingDINO.
        confidences: Per-detection confidence scores.
        class_ids: Per-detection indices into ``classes``. An entry may be
            ``None`` (NOTE(review): GroundingDINO's ``predict_with_classes``
            appears to emit ``None`` when a predicted phrase cannot be mapped
            back to a class — confirm for the installed version). Such
            detections are labelled with the whole prompt instead of raising
            ``TypeError``.

    Returns:
        list[str] of labels, in detection order.
    """
    fallback = ", ".join(classes)
    return [
        f"{fallback if class_id is None else classes[class_id]} {confidence:0.2f}"
        for confidence, class_id in zip(confidences, class_ids)
    ]


def _apply_nms(detections: "sv.Detections", iou_threshold: float) -> None:
    """Filter ``detections`` in place with torchvision non-maximum suppression.

    ``xyxy``, ``confidence`` and ``class_id`` are re-indexed with the same kept
    indices, so they stay aligned.
    """
    # Cast both to float32 so boxes and scores share a dtype for the nms kernel.
    keep = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy).float(),
        torch.from_numpy(detections.confidence).float(),
        iou_threshold,
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[keep]
    detections.confidence = detections.confidence[keep]
    if detections.class_id is not None:
        detections.class_id = detections.class_id[keep]


def _segment(sam_predictor: "SamPredictor", image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Prompt SAM with each box and keep the highest-scoring candidate mask.

    Args:
        sam_predictor: Predictor wrapping the MobileSAM model.
        image: RGB image (the caller converts from OpenCV's BGR).
        xyxy: Boxes, one row per detection.

    Returns:
        The selected masks stacked along axis 0, one per box. With no boxes,
        an empty ``(0, H, W)`` boolean array is returned and the image
        embedding is not computed.
    """
    if len(xyxy) == 0:
        height, width = image.shape[:2]
        return np.empty((0, height, width), dtype=bool)

    sam_predictor.set_image(image)
    best_masks = []
    for box in xyxy:
        # multimask_output=True returns several candidates with scores;
        # keep the one SAM scored highest.
        masks, scores, _ = sam_predictor.predict(box=box, multimask_output=True)
        best_masks.append(masks[np.argmax(scores)])
    return np.array(best_masks)


def _binary_mask(masks, image_shape) -> np.ndarray:
    """Return the first mask as a uint8 image with values 0/255.

    When there are no masks, return an all-zero uint8 image of size
    ``image_shape[:2]`` so an output file is still produced.
    """
    if masks is None or len(masks) == 0:
        return np.zeros(image_shape[:2], dtype=np.uint8)
    return masks[0].astype(np.uint8) * 255


def main(args):
    """Detect ``args.CAPTION`` with GroundingDINO and segment it with MobileSAM.

    Writes three images with ``cv2.imwrite``:
      * ``args.OUT_FILE_BOX``: GroundingDINO boxes, before NMS.
      * ``args.OUT_FILE_BIN_MASK``: 0/255 mask of the first detection kept
        after NMS (all zeros when nothing is detected).
      * ``args.OUT_FILE_SEG``: masks and boxes, after NMS.

    Raises:
        FileNotFoundError: if ``args.SOURCE_IMAGE_PATH`` cannot be read.
    """
    device = args.DEVICE

    # GroundingDINO config and checkpoint
    grounding_dino_config_path = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
    grounding_dino_checkpoint_path = "./groundingdino_swint_ogc.pth"

    classes = [args.CAPTION]

    # Read the image before loading any model so a bad path fails fast.
    # cv2.imread signals failure by returning None instead of raising.
    image = cv2.imread(args.SOURCE_IMAGE_PATH)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {args.SOURCE_IMAGE_PATH}")

    # Build GroundingDINO on the requested device.
    # NOTE(review): without an explicit device, upstream GroundingDINO's Model
    # defaults to "cuda", which breaks CPU-only runs.
    grounding_dino_model = Model(
        model_config_path=grounding_dino_config_path,
        model_checkpoint_path=grounding_dino_checkpoint_path,
        device=str(device),
    )

    # Build the MobileSAM predictor. map_location="cpu" lets a GPU-saved
    # checkpoint load on any machine; the model is moved to `device` afterwards.
    checkpoint = torch.load(args.MOBILE_SAM_CHECKPOINT_PATH, map_location="cpu")
    mobile_sam = setup_model()
    mobile_sam.load_state_dict(checkpoint, strict=True)
    mobile_sam.to(device=device)
    # Inference only: put normalization/dropout layers (if any) in eval mode.
    mobile_sam.eval()
    sam_predictor = SamPredictor(mobile_sam)

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=classes,
        box_threshold=args.BOX_THRESHOLD,
        text_threshold=args.TEXT_THRESHOLD,
    )

    # annotate and save the raw (pre-NMS) GroundingDINO detections
    box_annotator = sv.BoxAnnotator()
    labels = _format_labels(classes, detections.confidence, detections.class_id)
    annotated_frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
    cv2.imwrite(args.OUT_FILE_BOX, annotated_frame)

    # NMS post process
    print(f"Before NMS: {len(detections.xyxy)} boxes")
    _apply_nms(detections, args.NMS_THRESHOLD)
    print(f"After NMS: {len(detections.xyxy)} boxes")

    # convert detections to masks; OpenCV loads BGR, SAM is fed RGB
    detections.mask = _segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy,
    )

    cv2.imwrite(args.OUT_FILE_BIN_MASK, _binary_mask(detections.mask, image.shape))

    # annotate and save the grounded-sam result
    mask_annotator = sv.MaskAnnotator()
    labels = _format_labels(classes, detections.confidence, detections.class_id)
    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
    cv2.imwrite(args.OUT_FILE_SEG, annotated_image)
  
if __name__ == "__main__":
    # Script entry point: parse CLI options and run the full pipeline.
    main(parse_args())
