# Make sure detectron2 is importable before the `from detectron2...` imports
# below run; if it is missing, install it from the upstream repository.
try:
    import detectron2
except ImportError:
    # Only a failed import should trigger the install. A bare `except:` would
    # also swallow KeyboardInterrupt/SystemExit and unrelated errors.
    import os
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
    # Alternative route (editable install from a local clone):
    # os.system('git clone https://github.com/facebookresearch/detectron2.git')
    # os.system('python -m pip install -e detectron2')
    
import gradio as gr
import numpy as np
import cv2
import torch

from detectron2.config import get_cfg
from GLEE.glee.models.glee_model import GLEE_Model
from GLEE.glee.config_deeplab import add_deeplab_config
from GLEE.glee.config import add_glee_config
import torch.nn.functional as F
import torchvision
import math
from obj365_name import categories as OBJ365_CATEGORIESV2


# Report at start-up whether PyTorch can see a CUDA device.
print(f"Is CUDA available: {torch.cuda.is_available()}")
# Example output: True
if torch.cuda.is_available():
    # Name of the currently selected CUDA device.
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
# Example output: Tesla T4

def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    The last dimension of ``x`` must hold the four box values; every leading
    dimension is preserved. Returns a tensor of the same shape whose last
    dimension is (x_min, y_min, x_max, y_max).
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)



def scribble2box(img):
    """Return the bounding box and colour of the non-zero pixels of ``img``.

    ``img`` is indexed as (row, column, channel) and must have four channels
    (they are unpacked as R, G, B, A).

    Returns ``(None, None)`` when the maximum of ``img`` is 0. Otherwise
    returns ``(box, (R, G, B))`` where ``box`` is
    ``np.array([xmin, ymin, xmax, ymax])`` with inclusive pixel indices and
    the colour is read from the first non-zero pixel in row-major order.
    """
    if img.max() == 0:
        return None, None

    # A pixel counts as painted when any of its channels is non-zero.
    painted = np.any(img, axis=2)
    ys, xs = np.nonzero(painted)  # row-major order, so ys is sorted ascending

    R, G, B, A = img[ys[0], xs[0]].tolist()  # colour of the first painted pixel
    ymin, ymax = ys[0], ys[-1]
    xmin, xmax = xs.min(), xs.max()
    return np.array([xmin, ymin, xmax, ymax]), (R, G, B)

 
def LSJ_box_postprocess(out_bbox, padding_size, crop_size, img_h, img_w):
    """Map predicted (cx, cy, w, h) boxes to corner boxes in output pixels.

    The boxes are converted to (x1, y1, x2, y2), multiplied by
    ``padding_size`` (height, width), divided by ``crop_size``
    (height, width), clamped to [0, 1] and finally scaled by
    ``(img_w, img_h)``.

    Returns a tensor with the same dtype/device as ``out_bbox``.
    """
    def _xyxy_scale(height, width):
        # Per-coordinate multiplier (w, h, w, h) matching out_bbox's dtype/device.
        return torch.tensor([width, height, width, height]).to(out_bbox)

    corners = box_cxcywh_to_xyxy(out_bbox)
    corners = corners * _xyxy_scale(padding_size[0], padding_size[1])
    corners = corners / _xyxy_scale(crop_size[0], crop_size[1])
    corners = corners.clamp(0, 1)
    return corners * _xyxy_scale(img_h, img_w)

# Palette of twelve three-component float colours with values in [0, 1].
# The last two entries repeat the first two.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
                [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933],
                [0.494, 0.000, 0.556], [0.494, 0.000, 0.000], [0.000, 0.745, 0.000],
                [0.700, 0.300, 0.600],[0.000, 0.447, 0.741], [0.850, 0.325, 0.098]]



# NOTE(review): this 80-name COCO list is overwritten by the very next
# statement, so `coco_class_name` actually holds the Objects365 names at
# runtime. Presumably intentional (larger vocabulary) -- confirm.
coco_class_name = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
# Category names taken from the `name` field of each Objects365 entry.
coco_class_name = [cat['name'] for cat in OBJ365_CATEGORIESV2]
OBJ365_class_names = [cat['name'] for cat in OBJ365_CATEGORIESV2]
# Single generic label used when no specific vocabulary is selected.
class_agnostic_name = ['object']

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('use cuda' if device == 'cuda' else 'use cpu')

# Build the GLEE model described by GLEE/configs/R50.yaml and load its weights.
cfg_r50 = get_cfg()
add_deeplab_config(cfg_r50)
add_glee_config(cfg_r50)
conf_files_r50 = 'GLEE/configs/R50.yaml'
# map_location='cpu' lets a checkpoint saved from GPU tensors be read on a
# host without CUDA (the script falls back to device='cpu' in that case).
# load_state_dict() below copies the weights onto the model's own device.
checkpoints_r50 = torch.load('GLEE_R50_Scaleup10m.pth', map_location='cpu')
cfg_r50.merge_from_file(conf_files_r50)
GLEEmodel_r50 = GLEE_Model(cfg_r50, None, device, None, True).to(device)
# strict=False: keys missing from or unexpected in the checkpoint are tolerated.
GLEEmodel_r50.load_state_dict(checkpoints_r50, strict=False)
GLEEmodel_r50.eval()


# Build the GLEE model described by GLEE/configs/SwinL.yaml and load its weights.
cfg_swin = get_cfg()
add_deeplab_config(cfg_swin)
add_glee_config(cfg_swin)
conf_files_swin = 'GLEE/configs/SwinL.yaml'
# map_location='cpu' lets a checkpoint saved from GPU tensors be read on a
# host without CUDA (the script falls back to device='cpu' in that case).
# load_state_dict() below copies the weights onto the model's own device.
checkpoints_swin = torch.load('GLEE_SwinL_Scaleup10m.pth', map_location='cpu')
cfg_swin.merge_from_file(conf_files_swin)
GLEEmodel_swin = GLEE_Model(cfg_swin, None, device, None, True).to(device)
# strict=False: keys missing from or unexpected in the checkpoint are tolerated.
GLEEmodel_swin.load_state_dict(checkpoints_swin, strict=False)
GLEEmodel_swin.eval()

# Per-channel mean / std, shaped (3, 1, 1) so they broadcast over C x H x W images.
pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(3, 1, 1).to(device)
pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(3, 1, 1).to(device)


def normalizer(x):
    """Standardize ``x`` channel-wise: subtract pixel_mean, divide by pixel_std."""
    return (x - pixel_mean) / pixel_std


# Target size handed to the resize transform below.
inference_size = 800
inference_type = 'resize_shot'  # or LSJ
# Network inputs are zero-padded so height and width are multiples of this.
size_divisibility = 32

# Drawing-related scale factors.
FONT_SCALE = 1.5e-3
THICKNESS_SCALE = 1e-3
TEXT_Y_OFFSET_SCALE = 1e-2


# The resize transform only exists outside 'LSJ' mode.
if inference_type != 'LSJ':
    resizer = torchvision.transforms.Resize(inference_size)


def inference(img_path=None, prompt_mode='categories', categoryname='COCO-80', custom_category=None, expressiong=None, results_select=('box', 'mask', 'name', 'score'), num_inst_select=15, threshold_select=0.3, mask_image_mix_ration=0.65):
    '''
    Run the SwinL GLEE model on one image file and return its top-scoring instances.

    img_path: path of the image, read with cv2.imread.
    prompt_mode: 'categories' or 'expression'. Any other value makes the
        function return None. The score threshold is applied only in
        'categories' mode.
    categoryname: 'COCO-80' selects the module-level `coco_class_name` list;
        any other value selects `class_agnostic_name`.
    custom_category, expressiong, results_select, mask_image_mix_ration:
        accepted for interface compatibility; not read by this function.
    num_inst_select: maximum number of instances to keep (at least 1, and
        never more than the number of predictions the model returns).
    threshold_select: minimum score an instance needs in 'categories' mode.

    Returns (results, pred_masks, features_pred):
        results: tensor (K, 6) with columns x1, y1, x2, y2 (pixels of the
            original image), score, class index.
        pred_masks: boolean numpy array (K, image_height, image_width).
        features_pred: rows of outputs['pred_track_embed'] for the kept instances.

    Raises FileNotFoundError when the image cannot be read.
    '''

    GLEEmodel = GLEEmodel_swin
    print('use GLEE-Plus')

    bgr_image = cv2.imread(img_path)
    if bgr_image is None:
        # cv2.imread signals failure with None instead of raising.
        raise FileNotFoundError(f'cannot read image: {img_path!r}')
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    ori_image = torch.as_tensor(np.ascontiguousarray(rgb_image.transpose(2, 0, 1)))

    ori_image = normalizer(ori_image.to(device))[None,]
    _, _, ori_height, ori_width = ori_image.shape

    if inference_type == 'LSJ':
        # NOTE(review): this branch only fits images that are exactly
        # inference_size x inference_size and never defines padding_size /
        # re_size, which the post-processing below needs. It is not reachable
        # with the module's current inference_type -- confirm before enabling.
        infer_image = torch.zeros(1, 3, 1024, 1024).to(ori_image)
        infer_image[:, :, :inference_size, :inference_size] = ori_image
    else:
        resize_image = resizer(ori_image)
        re_size = resize_image.shape[-2:]
        image_size = torch.as_tensor((re_size[0], re_size[1]))
        # Round H and W up to the next multiple of the stride and zero-pad.
        # A stride of 1 leaves the size unchanged, so infer_image and
        # padding_size are always defined.
        stride = max(size_divisibility, 1)
        padding_size = ((image_size + (stride - 1)).div(stride, rounding_mode="floor") * stride).tolist()
        infer_image = torch.zeros(1, 3, padding_size[0], padding_size[1]).to(resize_image)
        infer_image[0, :, :re_size[0], :re_size[1]] = resize_image[0]

    if prompt_mode not in ('categories', 'expression'):
        return None

    if categoryname == "COCO-80":
        batch_category_name = coco_class_name
    else:
        batch_category_name = class_agnostic_name

    prompt_list = []
    with torch.no_grad():
        (outputs, _) = GLEEmodel(infer_image, prompt_list, task="coco", batch_name_list=batch_category_name, is_train=False)

    # Keep only the first (and only) image of the batch.
    mask_pred = outputs['pred_masks'][0]
    mask_cls = outputs['pred_logits'][0]
    boxes_pred = outputs['pred_boxes'][0]
    features = outputs['pred_track_embed'][0]

    # Score of a prediction = its highest per-class sigmoid probability.
    scores = mask_cls.sigmoid().max(-1)[0]
    # topk raises when asked for more entries than exist, so cap the request.
    topK_instance = min(max(num_inst_select, 1), scores.shape[0])
    scores_per_image, topk_indices = scores.topk(topK_instance, sorted=True)
    if prompt_mode == 'categories':
        valid = scores_per_image > threshold_select
        topk_indices = topk_indices[valid]
        scores_per_image = scores_per_image[valid]

    class_ids = mask_cls[topk_indices].max(-1)[1]
    boxes = LSJ_box_postprocess(boxes_pred[topk_indices], padding_size, re_size, ori_height, ori_width)

    if topk_indices.numel() == 0:
        # Nothing survived the threshold: skip resampling an empty tensor.
        pred_masks = np.zeros((0, ori_height, ori_width), dtype=bool)
    else:
        # Upsample to the padded input size, crop the padding away, then
        # resample to the original image size and binarize at logit 0.
        pred_masks = F.interpolate(mask_pred[topk_indices][None,], size=(padding_size[0], padding_size[1]), mode="bilinear", align_corners=False)
        pred_masks = pred_masks[:, :, :re_size[0], :re_size[1]]
        pred_masks = F.interpolate(pred_masks, size=(ori_height, ori_width), mode="bilinear", align_corners=False)
        pred_masks = (pred_masks > 0).detach().cpu().numpy()[0]

    features_pred = features[topk_indices]
    results = torch.cat((boxes, scores_per_image.unsqueeze(1), class_ids.unsqueeze(1).to(boxes)), 1)
    return results, pred_masks, features_pred

def main(img_root=None, save_root=None, threshold_select=0.3):
    """Run `inference` on every .jpg under img_root and export the detections.

    img_root: path prefix of the frames; '*.jpg' is appended directly, so it
        must end with a path separator. File stems must be numeric because
        they are used as frame ids.
    save_root: output path prefix; it is created if missing and must also
        end with a path separator.
    threshold_select: score threshold forwarded to `inference`.

    For every frame that yields at least one detection with a non-zero class
    index, two files are written under save_root:
        <stem>.txt  one line per kept box (format: `save_format` below)
        <stem>.npy  pickled dict: mask index -> {'mask', 'box', 'feature'}
    After the loop, all kept boxes are stacked into an (N, 266) array
    (col 0 = frame id, cols 2..7 = x1, y1, x2, y2, score, class,
    cols 10.. = embedding) and saved as 'dets.npy' next to the 'dets/' folder.

    Detections with class index 0 and boxes smaller than 0.1% of the first
    frame's area are dropped.

    Raises FileNotFoundError if no image matches, ValueError if the first
    image cannot be read.
    """
    # demo.launch(inbrowser=True,)
    from glob import glob
    from tqdm import tqdm
    import os

    imgs = sorted(glob(img_root + '*.jpg'))
    if not imgs:
        raise FileNotFoundError(f"no images match '{img_root}*.jpg'")

    # The size filter is relative to the first frame's area.
    first_frame = cv2.imread(imgs[0])
    if first_frame is None:
        raise ValueError(f'cannot read image: {imgs[0]!r}')
    area = first_frame.shape[0] * first_frame.shape[1]
    min_box_area = 0.001 * area

    os.makedirs(save_root, exist_ok=True)
    save_format = '{frame},-1,{x1:.2f},{y1:.2f},{x2:.2f},{y2:.2f},{score:.2f},{cls},-1,-1\n'

    det_rows = []  # one (1, 266) row per kept box, stacked once at the end
    for img in tqdm(imgs):
        save_name = img.split('/')[-1].split('.')[0]
        try:
            bboxes, pred_masks, features_pred = inference(img_path=img, threshold_select=threshold_select)
        except Exception as exc:
            # Best effort: a failing frame is skipped, but the reason is
            # reported, and Ctrl-C / SystemExit are no longer swallowed.
            tqdm.write(f'skip {img}: inference failed ({exc!r})')
            continue

        # Skip the frame when every class index is 0 (also true when there
        # are no detections at all).
        if torch.sum(bboxes[:, -1]) == 0:
            continue

        # Move results to the CPU once instead of once per element.
        bboxes = bboxes.detach().cpu()
        features_pred = features_pred.detach().cpu()
        fid = int(float(save_name))

        # Text file and dets.npy rows share the same filter.
        with open(save_root + save_name + '.txt', 'w') as f:
            for boxid, box in enumerate(bboxes):
                if int(box[5]) == 0:
                    continue
                if (box[2] - box[0]) * (box[3] - box[1]) < min_box_area:
                    continue
                f.write(save_format.format(frame=fid, x1=box[0], y1=box[1], x2=box[2], y2=box[3], score=box[4], cls=int(box[5])))

                temp_data = np.zeros((1, 266))
                temp_data[:, 0] = fid
                temp_data[:, 2:8] = box.numpy()[:6]
                temp_data[:, 10:] = features_pred[boxid].numpy()
                det_rows.append(temp_data)

        # Masks of THIS frame only, each clipped to its (integer) box.
        # The dictionary used to be shared across frames and saved once after
        # the loop under the last frame's name, which mixed entries from
        # different frames.
        mask_dict = {}
        for mask_id, mask in enumerate(pred_masks):
            if int(bboxes[mask_id][-1]) == 0:
                continue
            x1 = int(bboxes[mask_id][0])
            y1 = int(bboxes[mask_id][1])
            x2 = int(bboxes[mask_id][2])
            y2 = int(bboxes[mask_id][3])
            if (x2 - x1) * (y2 - y1) < min_box_area:
                continue
            # Keep mask pixels inside the box and clear everything outside.
            clipped = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.bool_)
            clipped[y1:y2, x1:x2] = mask[y1:y2, x1:x2]

            mask_dict[mask_id] = {'mask': clipped, 'box': bboxes[mask_id].numpy(), 'feature': features_pred[mask_id].numpy()}

        # save mask dict to file
        np.save(save_root + save_name + '.npy', mask_dict)

    npy_data = np.vstack(det_rows) if det_rows else np.zeros((0, 266))
    np.save(save_root.split('dets/')[0] + 'dets.npy', npy_data)

if __name__ == '__main__':
    import argparse

    # Command line: image root, save root, experiment name and score threshold.
    parser = argparse.ArgumentParser(description='GLEE demo')
    option_specs = (
        ('--img_root', str, '/Code/dust3r/croco/assets/test/'),
        ('--save_root', str, '/Code/dust3r/outputs/'),
        ('--exp_name', str, 'test10'),
        ('--threshold_select', float, 0.3),
    )
    for flag, value_type, fallback in option_specs:
        parser.add_argument(flag, type=value_type, default=fallback)
    args = parser.parse_args()

    # Results of one experiment go to <save_root><exp_name>/dets/.
    args.save_root = args.save_root + args.exp_name + '/dets/'

    main(args.img_root, args.save_root, args.threshold_select)
