import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
from tqdm import tqdm
from glob import glob
import supervision as sv
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import time 
import argparse


parser = argparse.ArgumentParser(description="Extract quadric segmentation masks from images.")

# Defaults are given inline via ``default=``. The previous ``set_defaults`` calls
# targeted the wrong dest, which left --model_type defaulting to None.
# The model type is a key of segment_anything's ``sam_model_registry``. It must
# match the checkpoint: the default checkpoint below is the ViT-H one.
parser.add_argument(
    "--model_type",
    dest="model_type",
    default="vit_h",
    choices=sorted(sam_model_registry),
    help="SAM model type (key of sam_model_registry); must match the checkpoint",
)

parser.add_argument(
    "--pretrained_models",
    dest="pretrained_models",
    default="ckpts/sam_vit_h_4b8939.pth",
    help="path to the pretrained SAM checkpoint",
)

parser.add_argument(
    "--data_root",
    dest="data_root",
    default="/data/nerf_dataset/raw_replica/Replica_full",
    help="dataset root containing one directory per scene, each with an rgb/ subfolder",
)

parser.add_argument(
    "--output_path",
    dest="output_path",
    default="/data/nerf_dataset/raw_replica/Replica_full",
    help="root under which <scene>/segmentation/ outputs are written",
)


args = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def label_image(anns):
    """Paint a list of SAM-style annotations into one uint8 label map.

    Annotations are painted from largest to smallest ``'area'``. Where masks
    overlap, the smaller (later-painted) annotation therefore wins. Labels
    start at 1 in painting order, and 0 is left for pixels no mask covers.
    Annotations with equal area keep their input order.

    Args:
        anns: sequence of dicts, each with an ``'area'`` sort key and a boolean
            ``'segmentation'`` mask. NOTE(review): all masks are assumed to
            share the first mask's shape; this is not checked here.

    Returns:
        A ``np.uint8`` array shaped like the first mask, or ``None`` when
        ``anns`` is empty. Labels above 255 cannot be represented in uint8,
        so callers are expected to bound the number of annotations.
    """
    if len(anns) == 0:
        return None

    by_area_desc = sorted(anns, key=lambda a: a['area'], reverse=True)
    label_map = np.zeros_like(by_area_desc[0]['segmentation'], dtype=np.uint8)

    label = 0
    for annotation in by_area_desc:
        label += 1
        label_map[annotation['segmentation']] = label

    return label_map

def resort_segment(SegResult):
    """Renumber a label map so that segment ids follow descending pixel area.

    Each id ``k`` in ``1..max(SegResult)`` is replaced by its 1-based rank by
    pixel count, largest first. Equal areas keep ascending original-id order.
    Ids that cover no pixels drop out, so the surviving segments get
    consecutive ids. Pixels with value 0 (background), or values that do not
    equal any id in ``1..max``, become 0.

    Args:
        SegResult: label map where 0 marks background. NOTE(review): it is
            presumably an integer array such as the uint8 maps written by
            process_all.

    Returns:
        A ``np.uint8`` array with the same shape as ``SegResult``. If the map
        is empty or has no positive ids, the result is all zeros. (This used
        to return ``None``, which lost the shape.) More than 255 segments
        cannot be represented in uint8.
    """
    seg = np.asarray(SegResult)
    relabeled = np.zeros(seg.shape, dtype=np.uint8)

    # Highest id present. An empty array or a background-only map has
    # nothing to reorder. np.max would raise on an empty array.
    n_ids = int(np.max(seg)) if seg.size else 0
    if n_ids <= 0:
        return relabeled

    # Pixel count per original id 1..n_ids. Missing ids count as 0.
    areas = [int(np.count_nonzero(seg == seg_id)) for seg_id in range(1, n_ids + 1)]
    # Python's sort stays stable with reverse=True, so ties keep ascending-id order.
    ranked_ids = sorted(range(1, n_ids + 1), key=lambda seg_id: areas[seg_id - 1], reverse=True)

    # Segments are disjoint, so assignment order does not matter.
    for new_id, seg_id in enumerate(ranked_ids, start=1):
        relabeled[seg == seg_id] = new_id

    return relabeled
      

def _frame_index(image_path):
    """Return the integer after the last '_' in the file stem, e.g. ``rgb_12.png`` -> 12."""
    # basename instead of split("/") so the host OS path separator is honored.
    stem = os.path.basename(image_path).split(".")[0]
    return int(stem.split("_")[-1])


def _segment_image(image_path, mask_generator, mask_annotator):
    """Run the SAM mask generator on one image.

    Returns:
        A ``(label_map, visualization)`` pair: a uint8 label map (0 =
        background) and the BGR image with the masks drawn on it.

    Raises:
        OSError: the image cannot be read or decoded.
        ValueError: SAM produced more segments than a uint8 map can hold.
    """
    cv_image = cv2.imread(image_path)
    if cv_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError("could not read image: {}".format(image_path))

    image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)

    # Labels are stored as uint8 with 0 reserved for background, so at most
    # 255 segments fit. This is an explicit check, not an assert, so it
    # survives `python -O`.
    if len(masks) > 255:
        raise ValueError("too many segments, image path:{}".format(image_path))

    if not masks:
        # Save an all-background map instead of None. np.save(None) would write
        # a pickled object array that needs allow_pickle=True to load.
        return np.zeros(cv_image.shape[:2], dtype=np.uint8), cv_image

    seg_result = label_image(masks)
    detections = sv.Detections.from_sam(masks)
    annotated_image = mask_annotator.annotate(cv_image, detections)
    return seg_result, annotated_image


def process_all():
    """Segment every RGB frame of the known Replica scenes with SAM.

    For each scene directory that exists under ``args.data_root``, this reads
    the images matching ``<scene>/rgb/rgb*``. It writes two files per frame
    into ``<args.output_path>/<scene>/segmentation/``:

    - ``segment{index}.npy``: a uint8 label map (0 = background).
    - ``segment_vis{index}.png``: the masks drawn over the input image.

    Scenes whose directory is missing are skipped.

    Raises:
        ValueError: ``args.model_type`` is not a ``sam_model_registry`` key,
            or an image yields more than 255 segments.
        OSError: an image cannot be read or a visualization cannot be written.
    """
    model_type = args.model_type
    if model_type not in sam_model_registry:
        raise ValueError(
            "unknown model type {!r}; expected one of {}".format(model_type, sorted(sam_model_registry))
        )

    sam = sam_model_registry[model_type](checkpoint=args.pretrained_models)
    sam.to(device=device)

    mask_generator = SamAutomaticMaskGenerator(sam, min_mask_region_area=200)
    # Built once and reused for every image.
    mask_annotator = sv.MaskAnnotator()

    scenes = ["room_0", "room_1", "room_2", "office_0", "office_1", "office_2", "office_3", "office_4"]

    for scene in scenes:
        scene_path = os.path.join(args.data_root, scene)
        if not os.path.exists(scene_path):
            continue

        print("Processing scene:{}".format(scene))
        all_images = sorted(glob(os.path.join(scene_path, "rgb", "rgb*")))

        save_path = os.path.join(args.output_path, scene, "segmentation")
        os.makedirs(save_path, exist_ok=True)

        for image_path in tqdm(all_images):
            index = _frame_index(image_path)
            seg_result, annotated_image = _segment_image(image_path, mask_generator, mask_annotator)

            np.save(os.path.join(save_path, "segment{}.npy".format(index)), seg_result)
            vis_path = os.path.join(save_path, "segment_vis{}.png".format(index))
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(vis_path, annotated_image):
                raise OSError("could not write visualization: {}".format(vis_path))
        

if __name__ == "__main__":
    process_all()