from dust3r.inference import inference, load_model
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

from glob import glob
import numpy as np
import torch

from dust3r.utils.geometry import find_reciprocal_matches, xy_grid

import cv2
from scipy.optimize import linear_sum_assignment


def transfer_pos(coordinates, img_shape):
    """Rasterise a collection of (x, y) integer coordinates into a boolean map.

    Args:
        coordinates: sequence or array of integer ``(x, y)`` pairs, shape (N, 2).
            May be empty.
        img_shape: shape tuple; it is reversed to build the output, so passing
            ``(H, W)`` yields a map of shape ``(W, H)`` indexed as ``map[x, y]``.

    Returns:
        Boolean ndarray of shape ``img_shape[::-1]`` that is True at every
        given coordinate and False elsewhere.

    Raises:
        ValueError: if ``coordinates`` is non-empty and not of shape (N, 2).
        IndexError: if a coordinate lies outside the map.
    """
    bool_array = np.full(img_shape[::-1], False)
    coords = np.asarray(coordinates)
    if coords.size == 0:
        # Nothing to mark; mirrors the original loop over an empty input.
        return bool_array
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f'expected (N, 2) coordinates, got shape {coords.shape}')
    # One vectorised assignment instead of a Python-level loop per point.
    bool_array[coords[:, 0], coords[:, 1]] = True
    return bool_array

def find_matching_area(imgs, confidence_masks, pts3d):
    """Find reciprocal 3D matches between two frames and map them to pixels.

    Every pixel of both frames takes part in the matching: the confidence
    masks are only used for their shape, their values are ignored.

    Args:
        imgs: two images; ``imgs[i].shape[:2]`` is read as ``(H, W)``.
        confidence_masks: two per-pixel masks with the same (H, W) layout as
            the images. They are not modified.
        pts3d: two per-pixel 3D point tensors (``.detach().cpu().numpy()`` is
            called on each).

    Returns:
        Tuple ``(matrix_0, matrix_1, matchings)``:
        * ``matrix_0`` / ``matrix_1``: boolean maps indexed ``[x, y]`` marking
          the matched pixels of image 0 / image 1.
        * ``matchings``: dict mapping an image-1 pixel ``(x, y)`` tuple to the
          matched image-0 pixel (array ``[x, y]``).
    """
    pts2d_list, pts3d_list = [], []
    for i in range(2):
        # Build a fresh all-True selector rather than overwriting the
        # confidence mask: ``tensor.cpu().numpy()`` shares memory with a CPU
        # tensor, so writing into it would clobber the caller's mask.
        keep_all = np.ones(tuple(confidence_masks[i].shape), dtype=bool)
        pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[keep_all])  # imgs[i].shape[:2] = (H, W)
        pts3d_list.append(pts3d[i].detach().cpu().numpy()[keep_all])
    reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)

    print(f'found {num_matches} matches')
    matches_im1 = pts2d_list[1][reciprocal_in_P2]
    matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
    matrix_0 = transfer_pos(matches_im0, imgs[0].shape[:2])
    # Image-1 coordinates must be rasterised with image 1's own shape.
    matrix_1 = transfer_pos(matches_im1, imgs[1].shape[:2])

    # current-frame pixel -> previous-frame pixel
    matchings = {tuple(cur): pre for cur, pre in zip(matches_im1, matches_im0)}
    return matrix_0, matrix_1, matchings

def find_pre_points(matrix_full, mask_0, img, pts3d_input):
    """Gather, for each object of the previous frame, its matched pixels.

    Args:
        matrix_full: match map indexed ``[x, y]``; multiplied element-wise
            with each transposed object mask.
        mask_0: mapping from object key to a record holding a ``'mask'``.
        img: image whose ``shape[:2]`` (reversed) is the size every mask is
            resized to.
        pts3d_input: per-pixel 3D point tensor, indexed here as ``[y, x]``.

    Returns:
        ``(pts2d, pts3d)``: two dicts keyed like ``mask_0``. ``pts2d[k]`` is an
        (N, 2) array of ``(x, y)`` pixels that are both inside the object's
        mask and set in ``matrix_full``; ``pts3d[k]`` holds the corresponding
        3D points rounded to 4 decimals.
    """
    target_size = img.shape[:2][::-1]
    points_2d, points_3d = {}, {}

    for key in list(mask_0.keys()):
        resized = cv2.resize(np.uint8(mask_0[key]['mask']), target_size)
        # The mask is laid out [y, x]; transpose it to line up with the [x, y] match map.
        overlap = resized.T * matrix_full

        xs, ys = np.where(overlap)
        points_2d[key] = np.array([xs, ys]).T
        # The 3D tensor is indexed row-first, hence (ys, xs).
        selected = pts3d_input[ys, xs].detach().cpu().numpy()
        points_3d[key] = np.round(selected, 4)

    return points_2d, points_3d

def find_cur_points(matrix_full, mask_0, img, pts3d_input, matching_dict):
    """Gather each current-frame object's matched pixels, mapped to the previous frame.

    Args:
        matrix_full: match map of the current frame, indexed ``[x, y]``.
        mask_0: mapping from object key to a record holding a ``'mask'``.
        img: image whose ``shape[:2]`` (reversed) is the size every mask is
            resized to.
        pts3d_input: per-pixel 3D point tensor, indexed here as ``[y, x]``.
        matching_dict: maps a current-frame ``(x, y)`` tuple to the matched
            previous-frame ``(x, y)``, e.g. ``matching_dict[(303, 217)]`` ->
            ``array([252, 281])``.

    Returns:
        ``(pts2d, pts3d)``: dicts keyed like ``mask_0``. ``pts2d[k]`` is an
        (N, 2) array of previous-frame ``(x, y)`` pixels matched to the
        object's current pixels; ``pts3d[k]`` holds the 3D points of the
        current pixels rounded to 4 decimals.

    Raises:
        KeyError: if a selected pixel has no entry in ``matching_dict``.
    """
    target_size = img.shape[:2][::-1]
    points_2d, points_3d = {}, {}

    for key in list(mask_0.keys()):
        resized = cv2.resize(np.uint8(mask_0[key]['mask']), target_size)
        overlap = resized.T * matrix_full

        cur_xs, cur_ys = np.where(overlap)
        # 3D points are taken at the current-frame location, before remapping.
        selected = pts3d_input[cur_ys, cur_xs].detach().cpu().numpy()
        points_3d[key] = np.round(selected, 4)

        # Translate every current pixel into its previous-frame counterpart.
        pre_xs = np.empty_like(cur_xs)
        pre_ys = np.empty_like(cur_ys)
        for k, cur_xy in enumerate(zip(cur_xs, cur_ys)):
            pre_xs[k], pre_ys[k] = matching_dict[cur_xy]  # matching[cur_x, cur_y] = pre_x, pre_y

        points_2d[key] = np.array([pre_xs, pre_ys]).T

    return points_2d, points_3d

def init_Matching(confidence_masks, imgs, pts3d, mask_0, mask_1, id_start=1, min_shared_points=20):
    """Initialise object tracks by associating the masks of two frames.

    Objects of frame 0 and frame 1 are paired with the Hungarian algorithm on
    the number of matched pixels they share. Pairs sharing more than
    ``min_shared_points`` pixels become one matched object; all remaining
    masks become separate unmatched objects.

    Args:
        confidence_masks: the two frames' confidence masks (passed through to
            ``find_matching_area``).
        imgs: the two frames' images; ``shape[:2]`` is read as (H, W).
        pts3d: the two frames' per-pixel 3D point tensors.
        mask_0: mapping object key -> ``{'mask': ..., 'box': ...}`` for frame 0.
        mask_1: same for frame 1.
        id_start: first object id to hand out.
        min_shared_points: a pair is accepted only if it shares strictly more
            matched pixels than this.

    Returns:
        ``(matched_dict, unmatched_dict, matched_ids, unmatched_ids, obj_id)``
        where both dicts have the layout
        ``{obj_id: {frame: {'mask', 'box', 'pts2d', 'pts3d'}}}`` with frame
        0 or 1, the id lists name the ids created in each dict, and ``obj_id``
        is the next unused id.
    """
    # Pixel-level reciprocal matches between the two frames.
    matrix_pre, matrix_cur, matching_dict = find_matching_area(imgs, confidence_masks, pts3d)

    # 1. Matched pixels per object, both expressed in previous-frame coordinates.
    pre_pts2d, _ = find_pre_points(matrix_pre, mask_0, imgs[0], pts3d[0])
    cur_pts2d, _ = find_cur_points(matrix_cur, mask_1, imgs[1], pts3d[1], matching_dict)

    # Rows/columns of the overlap matrix are positions in these key lists, so
    # object keys do not have to be 0..n-1.
    pre_keys = list(pre_pts2d.keys())
    cur_keys = list(cur_pts2d.keys())

    # 2. Count shared pixels for every (previous, current) pair. The point
    # sets are built once per object, not once per pair.
    pre_sets = [set(map(tuple, pre_pts2d[key])) for key in pre_keys]
    cur_sets = [set(map(tuple, cur_pts2d[key])) for key in cur_keys]
    matching_num = np.zeros((len(pre_keys), len(cur_keys)))
    for row, pre_set in enumerate(pre_sets):
        for col, cur_set in enumerate(cur_sets):
            matching_num[row, col] = len(pre_set & cur_set)

    # 3. Hungarian algorithm on (max - overlap) to maximise total overlap.
    match_details = []
    if matching_num.size:  # .max() raises on an empty matrix (a frame without masks)
        cost_matrix = matching_num.max() - matching_num
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        # Drop assignments backed by too few shared pixels.
        match_details = [
            (row, col)
            for row, col in zip(row_ind, col_ind)
            if matching_num[row, col] > min_shared_points
        ]

    matched_rows = {row for row, _ in match_details}
    matched_cols = {col for _, col in match_details}
    pre_non_match = [i for i in range(len(pre_keys)) if i not in matched_rows]
    cur_non_match = [i for i in range(len(cur_keys)) if i not in matched_cols]

    def _frame_record(entry, frame, resize):
        # Build the per-frame record stored for one object.
        # NOTE(review): matched objects index pts3d with the raw mask while
        # unmatched ones resize it to the image first; both are kept as they
        # were — confirm which one is intended.
        mask = entry['mask']
        if resize:
            selector = cv2.resize(np.uint8(mask), imgs[frame].shape[:2][::-1])
        else:
            selector = mask
        return {
            'mask': mask,
            'box': entry['box'],
            'pts2d': np.array(np.where(selector)).T,  # (N, 2)
            'pts3d': pts3d[frame][selector.astype(bool)].detach().cpu().numpy(),
        }

    # Re-identify the objects.
    # matched_dict[obj_id] = {0: {'mask', 'box', 'pts2d', 'pts3d'}, 1: {...}}; keys 0/1 are frame ids
    matched_ids = []
    unmatched_ids = []
    matched_dict = {}
    unmatched_dict = {}

    obj_id = id_start
    for row, col in match_details:
        matched_dict[obj_id] = {
            0: _frame_record(mask_0[pre_keys[row]], 0, resize=False),
            1: _frame_record(mask_1[cur_keys[col]], 1, resize=False),
        }
        matched_ids.append(obj_id)
        obj_id += 1

    # Unpaired masks get fresh ids and are remembered with their 3D points.
    for i in pre_non_match:
        unmatched_dict[obj_id] = {0: _frame_record(mask_0[pre_keys[i]], 0, resize=True)}
        unmatched_ids.append(obj_id)
        obj_id += 1

    for i in cur_non_match:
        unmatched_dict[obj_id] = {1: _frame_record(mask_1[cur_keys[i]], 1, resize=True)}
        unmatched_ids.append(obj_id)
        obj_id += 1

    return matched_dict, unmatched_dict, matched_ids, unmatched_ids, obj_id

def Tracker(matched_dict, unmatched_dict, imgs, confidence_masks, pts3d, mask_cur, fid, obj_id, memory=None):
    '''
    Associate the masks of the current frame with the already tracked objects.

    The stored 3D points of every tracked object are matched against the 3D
    points of the current frame; each current mask then takes the object id
    most of its matched pixels point to.

    input:
        matched_dict: {obj_id: {frame_id: {'mask': mask, 'box': box, 'pts2d': pts2d, 'pts3d': pts3d}}}
        unmatched_dict: same layout, for objects seen without a confirmed match so far
        imgs: [img_prev, img_cur]; only imgs[1] is used
        confidence_masks: [mask_prev, mask_cur]; only the shape of confidence_masks[1] is used
        pts3d: [pts3d_prev, pts3d_cur]; only pts3d[1] is used
        mask_cur: {key: {'mask': mask, 'box': box}} for the current frame
        fid: frame id under which new records are stored
        obj_id: last object id handed out; a new object receives obj_id + 1
        memory: if truthy, keep only the `memory` highest frame ids per object
    output:
        matched_dict: updated in place (objects matched again get a record for fid;
                      re-found unmatched objects are moved here)
        unmatched_dict: updated in place (new objects are added here)
        obj_id: last object id handed out
    '''
    from collections import Counter  # local import keeps this block self-contained

    min_votes = 50          # a mask adopts an id only with more than this many supporting pixels
    min_new_mask_area = 100  # a weakly supported mask must be larger than this to start a new object

    def _trim_history(tracks, keep):
        # Keep only the `keep` most recent (highest) frame ids of every object.
        for track_id in list(tracks.keys()):
            frames = tracks[track_id]
            if len(frames.keys()) > keep:
                newest = sorted(frames.items(), key=lambda item: item[0], reverse=True)[:keep]
                tracks[track_id] = dict(newest)

    def _frame_record(entry):
        # Record stored for one object in the current frame.
        mask = entry['mask']
        return {
            'mask': mask,
            'box': entry['box'],
            'pts2d': np.array(np.where(mask)).T,
            'pts3d': pts3d[1][mask.astype(bool)].detach().cpu().numpy(),
        }

    def _find_matching_area():
        # Match all stored object points against the current frame and return
        # (match map indexed [x, y], {current (x, y) -> object id}).
        pts3d_chunks, id_chunks = [], []
        for tracks in (matched_dict, unmatched_dict):
            for track_id, frames in tracks.items():
                if not frames:
                    continue
                track_pts = np.concatenate([rec['pts3d'] for rec in frames.values()], axis=0)
                pts3d_chunks.append(track_pts)
                # One id per stored point, aligned with the point array by construction.
                id_chunks.append(np.ones(len(track_pts)) * track_id)

        if not pts3d_chunks or sum(len(chunk) for chunk in pts3d_chunks) == 0:
            # Nothing is tracked yet: no pixel of the current frame can match.
            return transfer_pos([], imgs[1].shape[:2]), {}

        prev_pts3d = np.concatenate(pts3d_chunks, axis=0)
        id_mask = np.concatenate(id_chunks, axis=0)

        # All current pixels take part. A fresh all-True selector is used so
        # the caller's confidence mask is never overwritten
        # (``tensor.cpu().numpy()`` shares memory with a CPU tensor).
        keep_all = np.ones(tuple(confidence_masks[1].shape), dtype=bool)
        cur_pts2d = xy_grid(*imgs[1].shape[:2][::-1])[keep_all]  # imgs[i].shape[:2] = (H, W)
        cur_pts3d = pts3d[1].detach().cpu().numpy()[keep_all]

        reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(prev_pts3d, cur_pts3d)

        print(f'found {num_matches} matches')
        matches_im1 = cur_pts2d[reciprocal_in_P2]
        corresponding_ids = id_mask[nn2_in_P1][reciprocal_in_P2]

        # Current-frame coordinates are rasterised with the current image's shape.
        match_map = transfer_pos(matches_im1, imgs[1].shape[:2])
        matchings = {tuple(pix): track_id for pix, track_id in zip(matches_im1, corresponding_ids)}
        return match_map, matchings

    # e.g. memory=100 matches against the last 100 stored frames only
    if memory:
        _trim_history(matched_dict, memory)
        _trim_history(unmatched_dict, memory)

    # 1. Match stored object points to the current frame.
    matrix_cur, matchings = _find_matching_area()

    # 2. Assign every current mask to an object.
    for key in list(mask_cur.keys()):
        entry = mask_cur[key]
        resized = cv2.resize(np.uint8(entry['mask']), imgs[1].shape[:2][::-1])
        seg = resized.T * matrix_cur

        xs, ys = np.where(seg)
        id_list = [matchings[(px, py)] for px, py in zip(xs, ys)]

        if len(id_list) == 0:
            # No matched pixel inside this mask: start a new object.
            obj_id += 1
            unmatched_dict[int(obj_id)] = {fid: _frame_record(entry)}
            continue

        # Most frequent id among the matched pixels. Iterating the same set as
        # before keeps tie-breaking unchanged; Counter makes each lookup O(1).
        votes = Counter(id_list)
        best_id = max(set(id_list), key=votes.__getitem__)
        if votes[best_id] > min_votes:
            tgt_obj_id = int(best_id)
            if tgt_obj_id not in matched_dict:
                # Seen again: move the object from unmatched to matched.
                matched_dict[tgt_obj_id] = unmatched_dict.pop(tgt_obj_id)
            matched_dict[tgt_obj_id][fid] = _frame_record(entry)
        elif entry['mask'].sum() > min_new_mask_area:
            # Too little support for any known id: create a new object.
            obj_id += 1
            unmatched_dict[int(obj_id)] = {fid: _frame_record(entry)}

    return matched_dict, unmatched_dict, obj_id

def save_results(matched_dict, unmatched_dict, lines, frame_id=None, is_init=False):
    """Append track rows to the result table.

    Each row is ``[frame_id, obj_id, box[0], box[1], box[2], box[3], -1, -1, -1, -1]``.
    Rows of ``matched_dict`` come first, then those of ``unmatched_dict``.

    Args:
        matched_dict: ``{obj_id: {frame_id: {'box': box, ...}}}``.
        unmatched_dict: same layout.
        lines: 2D array with 10 columns to append to. With ``is_init`` its
            first row is treated as a placeholder and dropped.
        frame_id: when ``is_init`` is False, only records stored under this
            frame id are appended.
        is_init: when True, records of every frame are appended and the first
            row of the result is removed.

    Returns:
        The extended table. ``lines`` itself is not modified.
    """
    rows = []
    for tracks in (matched_dict, unmatched_dict):
        for obj_id, frames in tracks.items():
            for fid, record in frames.items():
                if is_init or fid == frame_id:
                    box = record['box']
                    rows.append(np.array([fid, obj_id, box[0], box[1], box[2], box[3], -1, -1, -1, -1]))

    if rows:
        # One stack for all rows instead of re-copying the table per row.
        lines = np.vstack([lines, *rows])
    if is_init:
        # Drop the placeholder first row of the initial table.
        lines = lines[1:,]
    return lines

def main(img_root, mask_root, vis_path, vis=False, Track=False):
    model_path = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
    device = 'cuda'
    batch_size = 1
    schedule = 'cosine'
    lr = 0.01
    niter = 300
    T_window = 10
    step = 7
    model = load_model(model_path, device)
    
    # img_list = glob(img_root + '/*.jpg')
    # img_list.sort()
    
    img_list = glob(mask_root + '/*.npy')
    img_list = [img_root + img.split('/')[-1].split('.')[0] + '.jpg' for img in img_list]
    img_list.sort()

    img_list = img_list[:50]
    
    obj_id = 1
    
    lines = np.ones((1, 10))
    save_format = 'frame_id, obj_id, x1, y1, x2, y2, -1, -1, -1, -1\n'
    ids_list = [i for i in range(len(img_list))]

    sampled_frames = [img_list[i:i+T_window] for i in range(0, len(img_list) - T_window + 1, step)]
    sampled_frames_id = [ids_list[i:i+T_window] for i in range(0, len(ids_list) - T_window + 1, step)]
    
    if sampled_frames_id[-1][-1] != ids_list[-1]:
        # combine the last frames to the last sampled frames
        sampled_frames[-1] += img_list[-(T_window - step):]
        sampled_frames_id[-1] += ids_list[-(T_window - step):]

    import time
    
    for sample_i, samples in enumerate(sampled_frames):
        # Time window is 5 and step is 2
        # if i=T_window-1, we should initialize the first T_window frames and the tracker by function init_Matching
        if sample_i == 0:
            sub_images = load_images(samples, size=512)
            pairs = make_pairs(sub_images, scene_graph='complete', prefilter=None, symmetrize=True)
            output = inference(pairs, model, device, batch_size=batch_size)

            # next we'll use the global_aligner to align the predictions
            # depending on your task, you may be fine with the raw output and not need it
            # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
            # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
            scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
            loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
            
            # retrieve useful values from scene:
            imgs = scene.imgs
            pts3d = scene.get_pts3d()
            confidence_masks = scene.get_masks()
            # we need to select the first 5 frames and initialize the tracker from the first 2 frames
            for j in range(0, len(sub_images)):
                frame_id = j
                if j == 0:
                    
                    masks_dict_pre = np.load(mask_root + samples[j].split('/')[-1].split('.')[0] + '.npy', allow_pickle=True).item()
                    masks_dict_cur = np.load(mask_root + samples[j+1].split('/')[-1].split('.')[0] + '.npy', allow_pickle=True).item()
                    # find 2D-2D matches between the two images
                    confidence_masks_input = [confidence_masks[j], confidence_masks[j+1]]
                    imgs_input = [imgs[j], imgs[j+1]]
                    pts3d_input = [pts3d[j], pts3d[j+1]]
                    matched_dict, unmatched_dict, matched_ids, unmatched_ids, obj_id = init_Matching(confidence_masks_input, imgs_input, pts3d_input, masks_dict_pre, masks_dict_cur, id_start=obj_id)
                    
                    lines = save_results(matched_dict, unmatched_dict, lines, frame_id=frame_id, is_init=True)    
                    
                elif j > 1:
                    
                    masks_dict_cur = np.load(mask_root + samples[j].split('/')[-1].split('.')[0] + '.npy', allow_pickle=True).item()
                    confidence_masks_input = [confidence_masks[j-1], confidence_masks[j]]
                    imgs_input = [imgs[j-1], imgs[j]]
                    pts3d_input = [pts3d[j-1], pts3d[j]]
                    matched_dict, unmatched_dict, obj_id = Tracker(matched_dict, unmatched_dict, imgs_input, confidence_masks_input, pts3d_input, masks_dict_cur, frame_id, obj_id)
                    lines = save_results(matched_dict, unmatched_dict, lines, frame_id=frame_id) 
            
            resize_w, resize_h = confidence_masks_input[0].shape[1], confidence_masks_input[0].shape[0] 
            for ind, img in enumerate(img_list):
                tracks = lines[lines[:, 0] == ind]
                # reset frame id of new_lines
                # new_lines[new_lines[:, 0] == ind, 0] = int(float((img.split('/')[-1].split('.')[0])))
                
                img = cv2.imread(img)
                h, w = img.shape[:2]
                for track in tracks:
                    x1, y1, x2, y2 = track[2:6]
                    # resize from (512, 384) to (1440, 1080)
                    # resize from (512, 384) to (1920, 1440)
                    # x1 = int(x1 * 1920 / 512)
                    # y1 = int(y1 * 1440 / 384)
                    # x2 = int(x2 * 1920 / 512)
                    # y2 = int(y2 * 1440 / 384)
                    x1 = int(x1 * w / resize_w)
                    y1 = int(y1 * h / resize_h)
                    x2 = int(x2 * w / resize_w)
                    y2 = int(y2 * h / resize_h)
                    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    temp_id = int(track[1])
                    cv2.putText(img, str(temp_id), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                cv2.imwrite(vis_path + str(ind) + '.jpg', img)     
            
                
        else:

            if len(samples) == T_window:
                start_time = time.time()
                old_pts3d = pts3d[step:T_window]
                # old_pts3d = [old_pts3d[0].clone().detach(), old_pts3d[1].clone().detach()]
                # old_pts3d = [old_pts3d[0].reshape(-1, 3), old_pts3d[1].reshape(-1, 3)]
                old_pts3d = [pts3d_i.clone().detach() for pts3d_i in old_pts3d]
                old_pts3d = [pts3d_i.reshape(-1, 3) for pts3d_i in old_pts3d]
                old_pts3d = torch.cat(old_pts3d, axis=0)
                
                sub_images = load_images(samples, size=512)
                pairs = make_pairs(sub_images, scene_graph='complete', prefilter=None, symmetrize=True)
                output = inference(pairs, model, device, batch_size=batch_size)

                # next we'll use the global_aligner to align the predictions
                # depending on your task, you may be fine with the raw output and not need it
                # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
                # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
                scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
                loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
                print('sample_{} cost time: {}'.format(sample_i, time.time()-start_time))
                start_time = time.time()
                # retrieve useful values from scene:
                imgs = scene.imgs
                pts3d = scene.get_pts3d()
                confidence_masks = scene.get_masks()
                
                def optimizing_projection_matrix(old_pts3d, pts3d, T_window, step, device):
                    cur_pts3d = pts3d[:T_window - step]
                    # cur_pts3d = [cur_pts3d[0].clone().detach(), cur_pts3d[1].clone().detach()]
                    # cur_pts3d = [cur_pts3d[0].reshape(-1, 3), cur_pts3d[1].reshape(-1, 3)]
                    cur_pts3d = [pts3d_i.clone().detach() for pts3d_i in cur_pts3d]
                    cur_pts3d = [pts3d_i.reshape(-1, 3) for pts3d_i in cur_pts3d]
                    cur_pts3d = torch.cat(cur_pts3d, axis=0)
                    
                    # Optimizing RT matrix between two time windows
                    params = torch.eye(4).to(device)
                    # on GPU
                    params.requires_grad = True
                    optimizer = torch.optim.Adam([params], lr=0.001)

                    for epoch in range(100):
                        optimizer.zero_grad()
                        temp_pts3d = torch.cat([cur_pts3d, torch.ones(cur_pts3d.shape[0], 1).to(device)], dim=1)
                        cur_pts3d_transformed = torch.mm(temp_pts3d, params.t())[:, :3]
                        loss = torch.mean((old_pts3d - cur_pts3d_transformed) ** 2)
                        loss.backward()
                        optimizer.step()
                        
                        if epoch % 10 == 0:
                            print(f'epoch {epoch}, loss {loss.item()}')
                    return params

                RT_matrix = optimizing_projection_matrix(old_pts3d, pts3d, T_window, step, device)
                print('sample_{} cost time: {}'.format(sample_i, time.time()-start_time))
                start_time = time.time()
                # transform all the pts3d to the previous frame
                pts3d_transformed = []
                for pts3d_i in range(len(pts3d)):
                    temp_pts3d = torch.cat([pts3d[pts3d_i].detach().reshape(-1, 3), torch.ones(pts3d[pts3d_i].reshape(-1, 3).shape[0], 1).to(device)], dim=1)  
                    pts3d_transformed_temp = torch.mm(temp_pts3d, RT_matrix.detach().t())[:, :3]
                    pts3d_transformed_temp = pts3d_transformed_temp.reshape(pts3d[pts3d_i].shape)    
                    pts3d_transformed.append(pts3d_transformed_temp)
                print('transform_{} cost time: {}'.format(sample_i, time.time()-start_time))
                start_time = time.time()
                # Tracking
                # 1. update the matched_dict and unmatched_dict
                overlap_frames = set(sampled_frames_id[sample_i]) & set(sampled_frames_id[sample_i-1])
                # to list
                overlap_frames = list(overlap_frames)
                
                for matched_id in matched_dict.keys():
                    for fid in matched_dict[matched_id].keys():
                        if fid in overlap_frames:
                            temp_mask = matched_dict[matched_id][fid]['mask']
                            matched_dict[matched_id][fid]['pts3d'] = pts3d_transformed[sampled_frames_id[sample_i].index(fid)][temp_mask.astype(bool)].detach().cpu().numpy()
                
                for unmatched_id in unmatched_dict.keys():
                    for fid in unmatched_dict[unmatched_id].keys():
                        if fid in overlap_frames:
                            temp_mask = unmatched_dict[unmatched_id][fid]['mask']
                            unmatched_dict[unmatched_id][fid]['pts3d'] = pts3d_transformed[sampled_frames_id[sample_i].index(fid)][temp_mask.astype(bool)].detach().cpu().numpy()
                pts3d = pts3d_transformed
                print('step1_{} cost time: {}'.format(sample_i, time.time()-start_time))
                start_time = time.time()
                # 2. start tracking
                for j in range(T_window-step,T_window):
                    frame_id = sampled_frames_id[sample_i][j]
                    print(f'frame_id: {frame_id}')
                    masks_dict_cur = np.load(mask_root + samples[j].split('/')[-1].split('.')[0] + '.npy', allow_pickle=True).item()
                    confidence_masks_input = [confidence_masks[j-1], confidence_masks[j]]
                    imgs_input = [imgs[j-1], imgs[j]]
                    pts3d_input = [pts3d[j-1], pts3d[j]]
                    matched_dict, unmatched_dict, obj_id = Tracker(matched_dict, unmatched_dict, imgs_input, confidence_masks_input, pts3d_input, masks_dict_cur, frame_id, obj_id, memory=30)
                    lines = save_results(matched_dict, unmatched_dict, lines, frame_id=frame_id)  
                print('step2_{} cost time: {}'.format(sample_i, time.time()-start_time))
            resize_w, resize_h = confidence_masks_input[0].shape[1], confidence_masks_input[0].shape[0] 
            for ind, img in enumerate(img_list):
                tracks = lines[lines[:, 0] == ind]
                # reset frame id of new_lines
                # new_lines[new_lines[:, 0] == ind, 0] = int(float((img.split('/')[-1].split('.')[0])))
                
                img = cv2.imread(img)
                h, w = img.shape[:2]
                for track in tracks:
                    x1, y1, x2, y2 = track[2:6]
                    # resize from (512, 384) to (1440, 1080)
                    # resize from (512, 384) to (1920, 1440)
                    # x1 = int(x1 * 1920 / 512)
                    # y1 = int(y1 * 1440 / 384)
                    # x2 = int(x2 * 1920 / 512)
                    # y2 = int(y2 * 1440 / 384)
                    x1 = int(x1 * w / resize_w)
                    y1 = int(y1 * h / resize_h)
                    x2 = int(x2 * w / resize_w)
                    y2 = int(y2 * h / resize_h)
                    cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    temp_id = int(track[1])
                    cv2.putText(img, str(temp_id), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
                cv2.imwrite(vis_path + str(ind) + '.jpg', img)   
            else:
                # T_window = len(samples)
                old_pts3d = pts3d[step:T_window]
                # old_pts3d = [old_pts3d[0].clone().detach(), old_pts3d[1].clone().detach()]
                # old_pts3d = [old_pts3d[0].reshape(-1, 3), old_pts3d[1].reshape(-1, 3)]
                old_pts3d = [pts3d_i.clone().detach() for pts3d_i in old_pts3d]
                old_pts3d = [pts3d_i.reshape(-1, 3) for pts3d_i in old_pts3d]
                old_pts3d = torch.cat(old_pts3d, axis=0)
                sub_images = load_images(samples, size=512)
                pairs = make_pairs(sub_images, scene_graph='complete', prefilter=None, symmetrize=True)
                output = inference(pairs, model, device, batch_size=batch_size)

                # next we'll use the global_aligner to align the predictions
                # depending on your task, you may be fine with the raw output and not need it
                # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
                # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
                scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
                loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
                
                # retrieve useful values from scene:
                imgs = scene.imgs
                pts3d = scene.get_pts3d()
                confidence_masks = scene.get_masks()
                
                def optimizing_projection_matrix(old_pts3d, pts3d, T_window, step, device, n_epochs=100, opt_lr=0.001):
                    """Fit a 4x4 matrix that maps the current window's points onto ``old_pts3d``.

                    The first ``T_window - step`` entries of ``pts3d`` are detached,
                    flattened to ``(-1, 3)`` and stacked. A 4x4 matrix initialised to
                    the identity is then optimised with Adam to minimise the mean
                    squared error between ``old_pts3d`` and the stacked points after
                    they are multiplied (in homogeneous form) by the matrix.

                    The matrix is unconstrained: nothing forces it to be a rigid
                    transform. Only its first three rows influence the loss, because
                    the homogeneous output column is discarded.

                    Args:
                        old_pts3d: Target tensor the transformed points are compared
                            against; must be broadcast-compatible with the ``(N, 3)``
                            stack of current points and live on ``device``.
                        pts3d: Sliceable sequence of tensors, each reshapeable to
                            ``(-1, 3)``.
                        T_window: Window length; together with ``step`` it selects
                            the leading ``T_window - step`` entries of ``pts3d``.
                        step: Window stride (see ``T_window``).
                        device: Torch device holding the parameters.
                        n_epochs: Number of Adam steps (default 100).
                        opt_lr: Adam learning rate (default 0.001).

                    Returns:
                        The optimised 4x4 tensor. It still has ``requires_grad=True``,
                        so callers should ``detach()`` it before reuse.
                    """
                    overlap_pts = torch.cat(
                        [frame_pts.clone().detach().reshape(-1, 3) for frame_pts in pts3d[:T_window - step]],
                        dim=0,
                    )
                    # The homogeneous coordinates do not depend on the parameters, so
                    # build them (and move the ones column to the device) only once
                    # instead of on every epoch.
                    ones_col = torch.ones(overlap_pts.shape[0], 1).to(device)
                    homog_pts = torch.cat([overlap_pts, ones_col], dim=1)

                    # Transform between the two time windows, starting from identity.
                    params = torch.eye(4).to(device)
                    params.requires_grad = True
                    optimizer = torch.optim.Adam([params], lr=opt_lr)

                    for epoch in range(n_epochs):
                        optimizer.zero_grad()
                        # Row-vector convention: p' = [p, 1] @ M^T, then drop the w column.
                        cur_pts3d_transformed = torch.mm(homog_pts, params.t())[:, :3]
                        loss = torch.mean((old_pts3d - cur_pts3d_transformed) ** 2)
                        loss.backward()
                        optimizer.step()

                        if epoch % 10 == 0:
                            print(f'epoch {epoch}, loss {loss.item()}')
                    return params
                
                RT_matrix = optimizing_projection_matrix(old_pts3d, pts3d, T_window, step, device)
                
                # transform all the pts3d to the previous frame
                pts3d_transformed = []
                for pts3d_i in range(len(pts3d)):
                    temp_pts3d = torch.cat([pts3d[pts3d_i].detach().reshape(-1, 3), torch.ones(pts3d[pts3d_i].reshape(-1, 3).shape[0], 1).to(device)], dim=1)  
                    pts3d_transformed_temp = torch.mm(temp_pts3d, RT_matrix.detach().t())[:, :3]
                    pts3d_transformed_temp = pts3d_transformed_temp.reshape(pts3d[pts3d_i].shape)    
                    pts3d_transformed.append(pts3d_transformed_temp)
                
                # Tracking
                # 1. update the matched_dict and unmatched_dict
                overlap_frames = set(sampled_frames_id[sample_i]) & set(sampled_frames_id[sample_i-1])
                # to list
                overlap_frames = list(overlap_frames)
                
                for matched_id in matched_dict.keys():
                    for fid in matched_dict[matched_id].keys():
                        if fid in overlap_frames:
                            temp_mask = matched_dict[matched_id][fid]['mask']
                            matched_dict[matched_id][fid]['pts3d'] = pts3d_transformed[sampled_frames_id[sample_i].index(fid)][temp_mask.astype(bool)].detach().cpu().numpy()
                
                for unmatched_id in unmatched_dict.keys():
                    for fid in unmatched_dict[unmatched_id].keys():
                        if fid in overlap_frames:
                            temp_mask = unmatched_dict[unmatched_id][fid]['mask']
                            unmatched_dict[unmatched_id][fid]['pts3d'] = pts3d_transformed[sampled_frames_id[sample_i].index(fid)][temp_mask.astype(bool)].detach().cpu().numpy()
                
                pts3d = pts3d_transformed
                # 2. start tracking
                for j in range(T_window - step, len(samples)):
                    frame_id = sampled_frames_id[sample_i][j]
                    print(f'frame_id: {frame_id}')
                    masks_dict_cur = np.load(mask_root + samples[j].split('/')[-1].split('.')[0] + '.npy', allow_pickle=True).item()
                    confidence_masks_input = [confidence_masks[j-1], confidence_masks[j]]
                    imgs_input = [imgs[j-1], imgs[j]]
                    pts3d_input = [pts3d[j-1], pts3d[j]]
                    matched_dict, unmatched_dict, obj_id = Tracker(matched_dict, unmatched_dict, imgs_input, confidence_masks_input, pts3d_input, masks_dict_cur, frame_id, obj_id, memory=30)
                    lines = save_results(matched_dict, unmatched_dict, lines, frame_id=frame_id) 
    resize_w, resize_h = confidence_masks_input[0].shape[1], confidence_masks_input[0].shape[0]           
    lines = sorted(lines, key=lambda x: x[0])
    # transform to np array
    lines = np.array(lines)
    new_lines = lines.copy()
    # visualization
    exp_id = img_list[0].split('/')[-3]
    for ind, img in enumerate(img_list):
        tracks = lines[lines[:, 0] == ind]
        # reset frame id of new_lines
        # new_lines[new_lines[:, 0] == ind, 0] = int(float((img.split('/')[-1].split('.')[0])))
        
        img = cv2.imread(img)
        h, w = img.shape[:2]
        for track in tracks:
            x1, y1, x2, y2 = track[2:6]
            # resize from (512, 384) to (1440, 1080)
            # resize from (512, 384) to (1920, 1440)
            # x1 = int(x1 * 1920 / 512)
            # y1 = int(y1 * 1440 / 384)
            # x2 = int(x2 * 1920 / 512)
            # y2 = int(y2 * 1440 / 384)
            x1 = int(x1 * w / resize_w)
            y1 = int(y1 * h / resize_h)
            x2 = int(x2 * w / resize_w)
            y2 = int(y2 * h / resize_h)
            cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            temp_id = int(track[1])
            cv2.putText(img, str(temp_id), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        cv2.imwrite(vis_path + str(ind) + '.jpg', img)
    
    # new_lines = new_lines.tolist()
    with open(mask_root.split('masks/')[0] + 'tracks.txt', 'w') as f:
        # save lines to txt
        for line in lines:
            f.write(','.join([str(i) for i in line]) + '\n')
    print('Tracks saved successfully!')
    f.close()
        

if __name__ == '__main__':
    # Script entry point: parse CLI options, prepare the output directories,
    # load the DUSt3R checkpoint and run the tracking pipeline.
    import argparse
    import os

    parser = argparse.ArgumentParser(description='DUSt3R demo')
    parser.add_argument('--img_root', type=str, default='/Code/dust3r/croco/assets/test/')
    parser.add_argument('--mask_root', type=str, default='/Code/dust3r/outputs/')
    parser.add_argument('--exp_name', type=str, default='test10')
    parser.add_argument('--vis', action='store_true')
    args = parser.parse_args()

    # Paths are later extended by plain string concatenation (and split on
    # 'masks/'), so make sure the root ends with exactly one '/' instead of
    # silently producing e.g. '.../outputstest10/masks/'.
    out_root = args.mask_root if args.mask_root.endswith('/') else args.mask_root + '/'
    vis_path = out_root + 'vis/' + args.exp_name + '/'
    os.makedirs(vis_path, exist_ok=True)
    args.mask_root = out_root + args.exp_name + '/masks/'

    # The names below are module-level globals; the pipeline code reads
    # model, device, batch_size, schedule, lr and niter directly.
    model_path = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
    # Prefer the GPU, but do not crash on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if device == 'cpu':
        print('CUDA is not available, running on CPU')
    batch_size = 1
    schedule = 'cosine'
    lr = 0.01
    niter = 300

    model = load_model(model_path, device)
    # Honour the --vis flag (it was previously parsed but ignored).
    main(args.img_root, args.mask_root, vis_path, vis=args.vis, Track=True)

    

