from segment_anything import SamPredictor, sam_model_registry
import torch
import cv2
import numpy as np
from tqdm import tqdm
def test():
    """Run SAM box-prompted prediction once on a hard-coded sample image.

    Loads the ViT-H checkpoint onto the GPU, embeds a single image and
    predicts one mask per box for two fixed boxes. The predicted masks are
    neither returned nor written anywhere; this only exercises the pipeline.
    """
    checkpoint_path = "./checkpoints/sam_vit_h_4b8939.pth"
    backbone = "vit_h"

    build_sam = sam_model_registry[backbone]
    sam = build_sam(checkpoint=checkpoint_path)
    sam.to(device="cuda")
    predictor = SamPredictor(sam)

    # OpenCV decodes to BGR; the predictor is fed the RGB version.
    bgr = cv2.imread('/Code/dust3r/croco/assets/test/color_0000914.jpg')
    image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    box_prompts = torch.tensor(
        [
            [346, 24, 762, 549],
            [711, 124, 1028, 616],
        ],
        device=predictor.device,
    )
    height_width = image.shape[:2]
    transformed_boxes = predictor.transform.apply_boxes_torch(
        box_prompts, height_width
    )
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )


import numpy as np

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``x1, y1, x2, y2`` boxes.

    Only the first four entries of each box are read, so boxes may carry
    extra trailing fields. Returns ``0`` when the union area is zero.
    """
    # Corners of the overlap rectangle (may be inverted if boxes are disjoint).
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    # Clamp at zero so disjoint boxes contribute no overlap.
    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    intersection_area = overlap_w * overlap_h

    area_first = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_second = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = area_first + area_second - intersection_area

    if union_area != 0:
        return intersection_area / union_area
    return 0

def nms(boxes, confidence_threshold=0.06, iou_threshold=0.7):
    """Greedy non-maximum suppression over ``x1, y1, x2, y2, score`` boxes.

    Boxes whose score (index 4) is below ``confidence_threshold`` are
    discarded first. The remaining boxes are visited from highest to lowest
    score; each visited box is kept and every later box whose IoU with it is
    at least ``iou_threshold`` is dropped. The input sequence is not modified.

    Returns:
        The kept boxes as a list, ordered by descending score.
    """
    candidates = sorted(
        (box for box in boxes if box[4] >= confidence_threshold),
        key=lambda candidate: candidate[4],
        reverse=True,
    )

    kept = []
    while candidates:
        best, remaining = candidates[0], candidates[1:]
        kept.append(best)
        # Survivors are those that do not overlap the kept box too much.
        candidates = [
            other
            for other in remaining
            if calculate_iou(best, other) < iou_threshold
        ]

    return kept



def main_ablation(img_root, det_root, save_root, img_size=(512,288)):
    """Segment NMS-filtered detections from one ``.npy`` file with SAM.

    Every ``*.jpg`` under ``img_root`` is resized to ``img_size``. The
    detections stored under the integer parsed from the file stem are
    filtered with ``nms``, rescaled from the original resolution to the
    resized image and used as box prompts. One ``<stem>.npy`` file per image
    is written into ``save_root`` containing a dict
    ``{index: {'mask': ..., 'box': ...}}``; images for which no box survives
    NMS get an empty dict.

    Args:
        img_root: Directory containing the ``.jpg`` frames.
        det_root: Path of a ``.npy`` file holding a pickled dict that maps an
            integer frame id to rows whose first five columns are
            ``x1, y1, x2, y2, score`` in original-image pixels.
        save_root: Output directory; created if it does not exist.
        img_size: ``(width, height)`` handed to ``cv2.resize``.

    Raises:
        FileNotFoundError: If an image cannot be decoded by OpenCV.
        KeyError: If the detection dict has no entry for a frame id.
    """
    import os
    from glob import glob

    sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
    model_type = "vit_h"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    device = "cuda"
    sam.to(device=device)
    predictor = SamPredictor(sam)

    imgs = sorted(glob(img_root + '/*.jpg'))

    os.makedirs(save_root, exist_ok=True)
    # NOTE: allow_pickle=True unpickles arbitrary objects; only load
    # detection files that come from a trusted source.
    dets = np.load(det_root, allow_pickle=True).item()

    for img in tqdm(imgs):
        stem = os.path.basename(img).split('.')[0]
        save_path = os.path.join(save_root, stem + '.npy')

        origin_image = cv2.imread(img)
        if origin_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {img}")
        h, w = origin_image.shape[:2]

        # Filter the frame's detections before paying for the image embedding.
        bboxes = np.array(dets[int(float(stem))])
        kept_boxes = nms(bboxes[:, :5]) if bboxes.ndim == 2 else []
        if not kept_boxes:
            # Nothing to prompt SAM with: still emit a file so every image
            # has an output, but skip the (expensive) prediction.
            np.save(save_path, {})
            continue

        image = cv2.resize(origin_image, img_size)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)

        input_boxes = torch.tensor(
            np.array(kept_boxes)[:, :4], device=predictor.device
        )
        # Map boxes from the original resolution to the resized image:
        # x-coordinates (columns 0, 2) scale with width, y (1, 3) with height.
        input_boxes[:, 0:4:2] = input_boxes[:, 0:4:2] * img_size[0] / w
        input_boxes[:, 1:4:2] = input_boxes[:, 1:4:2] * img_size[1] / h

        transformed_boxes = predictor.transform.apply_boxes_torch(
            input_boxes[:, :4], image.shape[:2]
        )
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        # Move results to the host once, then pair each mask with its box.
        masks_np = masks.cpu().numpy()
        boxes_np = input_boxes.cpu().numpy()
        mask_dict = {
            i: {'mask': masks_np[i][0], 'box': boxes_np[i]}
            for i in range(len(masks_np))
        }
        np.save(save_path, mask_dict)
    print('segmentation done')



def main(img_root, det_root, save_root, img_size=(512,288)):
    """Segment per-frame ``.txt`` detections with SAM box prompts.

    The frames to process are derived from the ``*.txt`` files found in
    ``det_root``: for each ``<stem>.txt`` the image ``<stem>.jpg`` is read
    from ``img_root`` and resized to ``img_size``. Each non-blank line of the
    detection file is split on commas and fields 2..7 are parsed as floats;
    the first four of those are rescaled from the original resolution to the
    resized image and used as the box prompt. One ``<stem>.npy`` file per
    frame is written into ``save_root`` containing a dict
    ``{index: {'mask': ..., 'box': ...}}`` where ``box`` holds all parsed
    fields. A detection file without any detections yields an empty dict.

    Args:
        img_root: Directory containing the ``.jpg`` frames.
        det_root: Directory containing the ``.txt`` detection files.
        save_root: Output directory; created if it does not exist.
        img_size: ``(width, height)`` handed to ``cv2.resize``.

    Raises:
        FileNotFoundError: If an image cannot be decoded by OpenCV or a
            detection file is missing.
        ValueError: If the parsed detections do not provide at least four
            box values per line.
    """
    import os
    from glob import glob

    sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
    model_type = "vit_h"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    device = "cuda"
    sam.to(device=device)
    predictor = SamPredictor(sam)

    det_files = glob(det_root + '/*.txt')
    imgs = sorted(
        os.path.join(img_root, os.path.basename(path).split('.')[0] + '.jpg')
        for path in det_files
    )

    os.makedirs(save_root, exist_ok=True)

    for img in imgs:
        stem = os.path.basename(img).split('.')[0]
        save_path = os.path.join(save_root, stem + '.npy')

        origin_image = cv2.imread(img)
        if origin_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {img}")
        h, w = origin_image.shape[:2]

        # Read the detections; the context manager closes the file handle.
        with open(os.path.join(det_root, stem + '.txt'), 'r') as f:
            dets = [line.strip().split(',') for line in f if line.strip()]

        if not dets:
            # No boxes to prompt SAM with: still emit a file so every frame
            # has an output, but skip the (expensive) prediction.
            np.save(save_path, {})
            continue

        image = cv2.resize(origin_image, img_size)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)

        input_boxes = torch.tensor(
            [[float(x) for x in det[2:8]] for det in dets],
            device=predictor.device,
        )
        if input_boxes.ndim != 2 or input_boxes.shape[1] < 4:
            raise ValueError(
                f"detections for {stem} need at least 4 box values per line"
            )

        # Map boxes from the original resolution to the resized image:
        # x-coordinates (columns 0, 2) scale with width, y (1, 3) with height.
        # Columns beyond the fourth are carried along unchanged.
        input_boxes[:, 0:4:2] = input_boxes[:, 0:4:2] * img_size[0] / w
        input_boxes[:, 1:4:2] = input_boxes[:, 1:4:2] * img_size[1] / h

        transformed_boxes = predictor.transform.apply_boxes_torch(
            input_boxes[:, :4], image.shape[:2]
        )
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

        # Move results to the host once, then pair each mask with its box.
        masks_np = masks.cpu().numpy()
        boxes_np = input_boxes.cpu().numpy()
        mask_dict = {
            i: {'mask': masks_np[i][0], 'box': boxes_np[i]}
            for i in range(len(masks_np))
        }
        np.save(save_path, mask_dict)
    print('segmentation done')
            
    
if __name__ == '__main__':
    import argparse

    def _parse_img_size(text):
        """Parse ``"W,H"`` (or ``"WxH"``) from the command line into two ints."""
        parts = text.lower().replace('x', ',').split(',')
        try:
            width, height = (int(part) for part in parts)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"expected WIDTH,HEIGHT (e.g. 512,384), got {text!r}"
            )
        return width, height

    parser = argparse.ArgumentParser(description='SAM demo')
    parser.add_argument('--img_root', type=str, default='/Code/dust3r/croco/assets/test/')
    parser.add_argument('--save_root', type=str, default='/Code/dust3r/outputs/')
    parser.add_argument('--ab_root', type=str, default='/Code/dust3r/outputs/')
    parser.add_argument('--exp_name', type=str, default='test10')
    # type=tuple would split the argument string into single characters, so
    # a dedicated parser is used; the tuple default is passed through as-is.
    parser.add_argument('--img_size', type=_parse_img_size, default=(512,384))
    parser.add_argument('--data_type', type=str, default='.txt',
                        choices=['.txt', '.npy'])

    args = parser.parse_args()

    # The dataset name is taken as the third-from-last '/'-separated piece of
    # img_root (i.e. two levels above the trailing slash). It is empty when
    # the path is too short to contain one.
    path_parts = args.img_root.split('/')
    img_set = path_parts[-3] if len(path_parts) >= 3 else ''

    # A '1' in the dataset name selects the 512x288 resolution; otherwise any
    # of '2'..'6' selects 512x384. With no digit match --img_size is kept.
    if '1' in img_set:
        args.img_size = (512, 288)
    elif any(digit in img_set for digit in '23456'):
        args.img_size = (512, 384)

    if args.data_type == '.txt':
        det_root = args.save_root + args.exp_name + '/dets/'
        save_root = args.save_root + args.exp_name + '/masks/'
        print('start segmentation')
        main(args.img_root, det_root, save_root, img_size=args.img_size)
    elif args.data_type == '.npy':
        if not img_set:
            parser.error(
                '--img_root is too short to derive the dataset name '
                'needed for --data_type .npy'
            )
        det_root = args.ab_root + 'dets/npy/' + img_set + '.npy'
        save_path = args.ab_root + 'masks/' + img_set + '/'
        print('start segmentation')
        main_ablation(args.img_root, det_root, save_path, img_size=args.img_size)