import os
import argparse
import torch
import numpy as np
from pathlib import Path
from time import time

# Configure PyTorch's CUDA caching allocator through the environment. This is
# done here, before any of the CUDA queries or model code below run.
# NOTE(review): presumably enabled to limit memory fragmentation -- confirm
# against the PyTorch CUDA memory-management notes for the installed version.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
from icecream import ic
# Debug output at import time: whether CUDA is usable and how many devices exist.
ic(torch.cuda.is_available())  
ic(torch.cuda.device_count())

from mast3r.model import AsymmetricMASt3R, fast_reciprocal_NNs

from dust3r.image_pairs import make_pairs
from dust3r.inference import inference
from dust3r.utils.device import to_numpy
from dust3r.utils.geometry import inv
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from utils.sfm_utils import (save_intrinsics, save_extrinsic, save_points3D, save_time, save_images_and_masks,
                             init_filestructure, get_sorted_image_files, split_train_test, load_images, compute_co_vis_masks,
                             project_points, warp_corners_and_draw_matches)
from PIL import Image
import cv2
from utils.scale_metric_depth_calculate import fit_scale_and_shift_multiple, plot_depth_distributions, fit_scale_and_shift_multiple2
from gim.demo import gim_run
import imageio
import matplotlib.pyplot as plt
from vggt.vggt_run import vggt_run

def main(source_path, model_path, ckpt_path, device, batch_size, image_size, schedule, lr, niter,
         min_conf_thr, llffhold, n_views, co_vis_dsp, depth_thre, conf_aware_ranking=False, focal_avg=False, infer_video=False):
    """Initialise scene geometry for an image pair and export it as COLMAP-style files.

    Pipeline, in order:
      1. Load the training images and estimate focal lengths with ``vggt_run``.
      2. Run MASt3R pairwise inference and global alignment using those focals.
      3. Match the two views with ``gim_run`` and drop matches closer than
         3 px to an image border.
      4. Predict depth with ``metric3d_depth_cal_save``, lift view 0 to 3D with
         an identity pose, and recover the pose of view 1 with PnP + RANSAC
         from the resulting 2D-3D correspondences.
      5. Project view 0's points into view 1, fit a scale/shift for view 1's
         depth on the covered pixels, and lift view 1 to 3D.
      6. Compute co-visibility masks and initial test poses, then save
         extrinsics, intrinsics, points and images.

    The function unpacks exactly two images from ``scene.imgs`` and indexes
    views 0 and 1 throughout, so it only works when two training images are
    loaded.

    Side effects: writes ``./matched_images.png`` and
    ``sparse_depth_visualization.png`` to the current working directory, plus
    the outputs of the ``save_*`` helpers.

    Args:
        source_path: Scene directory; images are read from ``<source_path>/images``.
        model_path: Output location passed to ``save_time`` and ``save_points3D``.
        ckpt_path: Checkpoint given to ``AsymmetricMASt3R.from_pretrained``.
        device: Torch device used for the model, VGGT and the global aligner.
        batch_size: Batch size forwarded to ``inference``.
        image_size: Target size forwarded to ``load_images``.
        schedule: Learning-rate schedule for the global alignment.
        lr: Learning rate for the global alignment.
        niter: Number of global-alignment iterations.
        min_conf_thr: Confidence threshold set on the scene before masks are read.
        llffhold: Forwarded to ``split_train_test``.
        n_views: Number of training views; forwarded to the file-structure,
            split and save helpers.
        co_vis_dsp: Whether ``save_points3D`` uses the co-visibility masks.
            Forced to ``False`` when ``depth_thre <= 0``.
        depth_thre: Depth threshold for the co-visibility masks; a value
            ``<= 0`` disables them.
        conf_aware_ranking: Order views by mean confidence instead of by index.
        focal_avg: Forwarded to ``compute_global_alignment``.
        infer_video: If true, use every image for training and skip the
            train/test split and the test-pose export.

    Raises:
        AssertionError: If ``cv2.solvePnPRansac`` reports failure.
    """
    _, sparse_0_path, sparse_1_path = init_filestructure(Path(source_path), n_views)
    model = AsymmetricMASt3R.from_pretrained(ckpt_path).to(device)
    image_dir = Path(source_path) / 'images'
    image_files, image_suffix = get_sorted_image_files(image_dir)
    if infer_video:
        train_img_files = image_files
    else:
        train_img_files, test_img_files = split_train_test(image_files, llffhold, n_views, verbose=True)

    image_files = train_img_files
    images, org_imgs_shape, img_names = load_images(image_files, size=image_size)

    # Focal prior from VGGT: average the first two views' focal lengths, then
    # rescale by (loaded image extent) / (2 * principal point).
    # NOTE(review): this assumes VGGT's principal point sits at the image
    # centre and that true_shape is (H, W) -- confirm against vggt_run/load_images.
    intrinsic, _ = vggt_run(image_files, device=device)
    fx = (intrinsic[0][0][0] + intrinsic[1][0][0]) / 2 * (images[0]['true_shape'][0][1] / (intrinsic[0, 0, 2] * 2))
    fy = (intrinsic[0][1][1] + intrinsic[1][1][1]) / 2 * (images[0]['true_shape'][0][0] / (intrinsic[0, 1, 2] * 2))

    start_time = time()
    print(f'>> Making pairs...')
    pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
    print(f'>> Inference...')
    output, _ = inference(pairs, model, device, batch_size=batch_size, verbose=True)
    print(f'>> Global alignment...')
    # Use the function's own parameters here (not the CLI namespace) so that
    # main() also works when imported and called from other code.
    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
    scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr,
                                   focal_avg=focal_avg, known_focal=[fx, fy])

    # ---- Dense matching between the two views ----
    x1, x2 = [cv2.cvtColor((img * 255).astype(np.uint8), cv2.COLOR_RGB2BGR) for img in scene.imgs]
    mkpts_0, mkpts_1, valid_matches = gim_run(x1, x2)

    # Discard matches that fall within 3 px of either image's border.
    H0, W0 = images[0]['img'].shape[-2:]
    valid_matches_im0 = (mkpts_0[:, 0] >= 3) & (mkpts_0[:, 0] < int(W0) - 3) & (
        mkpts_0[:, 1] >= 3) & (mkpts_0[:, 1] < int(H0) - 3)
    H1, W1 = images[1]['img'].shape[-2:]
    valid_matches_im1 = (mkpts_1[:, 0] >= 3) & (mkpts_1[:, 0] < int(W1) - 3) & (
        mkpts_1[:, 1] >= 3) & (mkpts_1[:, 1] < int(H1) - 3)
    valid_matches = valid_matches & valid_matches_im0 & valid_matches_im1
    mkpts_0, mkpts_1 = mkpts_0[valid_matches], mkpts_1[valid_matches]

    # Debug visualisation of the surviving matches.
    canvas = warp_corners_and_draw_matches(mkpts_0, mkpts_1, x1, x2)
    plt.figure(figsize=(12, 12))
    plt.imshow(canvas[..., ::-1])
    plt.axis('off')
    plt.savefig('./matched_images.png', bbox_inches='tight', pad_inches=0)
    plt.close()  # release the figure; pyplot otherwise keeps it alive

    # Only the shapes of the MASt3R depth maps are used below.
    mast3rdepth = scene.get_depthmaps()

    scene.min_conf_thr = min_conf_thr
    msks = scene.get_masks()

    # ---- Metric depth prediction ----
    from metric_3d_v2 import metric3d_depth_cal_save
    focals = to_numpy(scene.get_focals())
    metric_3d_depths = metric3d_depth_cal_save(img_names, focals.repeat(2, 1),
                                               scene.get_principal_points().detach().cpu().numpy(), org_imgs_shape)
    # NOTE: the depth used from here on is the Metric3D prediction; the fit
    # below therefore compares the Metric3D depth of view 0 with itself.
    mast3r_depths = metric_3d_depths

    mask_ones = np.ones((H1, W1), dtype=bool)
    avg_scale, avg_shift = fit_scale_and_shift_multiple([mast3r_depths[0]], [metric_3d_depths[0]],
                                                        [msks[0].detach().cpu().numpy()], [mask_ones])
    mast3r_depth_0 = avg_scale * mast3r_depths[0] + avg_shift

    # ---- Lift view 0 to 3D; it defines the world frame (identity pose) ----
    extrinsics_w2c_0 = np.eye(4)
    pts3d_0 = to_numpy(scene.get_pts3d_0(inv(extrinsics_w2c_0), mast3r_depth_0))[0]
    pts3d_0_flat = pts3d_0.reshape(-1, 3)

    # Build 2D-3D correspondences: the 3D point under each view-0 keypoint,
    # paired with the matched pixel in view 1.
    H, W = mast3rdepth[0].shape
    imagePoints = []
    objectPoints = []
    for pt0, pt1 in zip(mkpts_0, mkpts_1):
        col, row = int(pt0[0]), int(pt0[1])
        if 0 <= row < H and 0 <= col < W:
            imagePoints.append(pt1)
            objectPoints.append(pts3d_0_flat[row * W + col])

    imagePoints = np.array(imagePoints, dtype=np.float32)
    objectPoints = np.array(objectPoints, dtype=np.float32)

    # ---- Pose of view 1 via PnP + RANSAC, refined on the inliers ----
    k = scene.get_intrinsics()[1].cpu().numpy()
    success, R, T, inliers = cv2.solvePnPRansac(
        objectPoints, imagePoints, k, None,
        iterationsCount=50000,
        confidence=0.999,
        reprojectionError=3.0,
        flags=cv2.SOLVEPNP_SQPNP,
    )
    assert success, "cv2.solvePnPRansac failed to estimate the pose of the second view"

    has_inliers = inliers is not None and len(inliers) > 0
    if has_inliers:
        inlier_indices = inliers.ravel()
        inlier_objectPoints = objectPoints[inlier_indices]
        inlier_imagePoints = imagePoints[inlier_indices]
        R, T = cv2.solvePnPRefineLM(inlier_objectPoints, inlier_imagePoints, k, None, R, T)

    num_inliers = len(inliers) if inliers is not None else 0
    inlier_ratio = num_inliers / len(objectPoints) if inliers is not None else 0
    print(objectPoints.shape, imagePoints.shape)
    print(f"Number of points used for PnP: {len(objectPoints)}")
    print(f"Number of inliers: {num_inliers}")
    print(f"Inlier ratio: {inlier_ratio:.3f}")

    if has_inliers:
        # Report the error over the RANSAC inliers only, as the message states.
        projected_points, _ = cv2.projectPoints(inlier_objectPoints, R, T, k, None)
        error = np.mean(np.linalg.norm(projected_points.reshape(-1, 2) - inlier_imagePoints, axis=1))
        print(f"Mean reprojection error for inliers: {error:.3f} pixels")
    R = cv2.Rodrigues(R)[0]  # rotation vector -> 3x3 rotation matrix

    extrinsics_w2c_1 = np.r_[np.c_[R, T], [(0, 0, 0, 1)]]
    pose = inv(extrinsics_w2c_1)  # inverse of the world-to-camera matrix

    # ---- Sparse depth for view 1 from view 0's points ----
    points_2d_1, sparse_depth_1 = project_points(pts3d_0_flat, k, extrinsics_w2c_1)

    h, w = mast3rdepth[1].shape
    sparse_depth_map = np.zeros((h, w), dtype=np.float32)
    sparse_depth_mask = np.zeros((h, w), dtype=bool)

    # Keep projections that land inside the image with a positive, finite depth.
    valid_points = (points_2d_1[:, 0] >= 0) & (points_2d_1[:, 0] < w) & \
                   (points_2d_1[:, 1] >= 0) & (points_2d_1[:, 1] < h) & \
                   (sparse_depth_1 > 0) & (~np.isnan(sparse_depth_1))

    # Rounding can push a coordinate onto w or h, hence the clip.
    x_indices = np.clip(np.round(points_2d_1[valid_points, 0]).astype(int), 0, w - 1)
    y_indices = np.clip(np.round(points_2d_1[valid_points, 1]).astype(int), 0, h - 1)

    sparse_depth_map[y_indices, x_indices] = sparse_depth_1[valid_points]
    sparse_depth_mask[y_indices, x_indices] = True

    # Debug visualisation of the sparse depth map and its coverage mask.
    plt.figure(figsize=(16, 8))

    plt.subplot(1, 2, 1)
    masked_depth = np.copy(sparse_depth_map)
    masked_depth[~sparse_depth_mask] = np.nan  # uncovered pixels are left blank
    plt.imshow(masked_depth, cmap='viridis')
    plt.colorbar(label='Depth')
    plt.title(f'Sparse Depth Map (Valid points: {np.sum(sparse_depth_mask)})')

    plt.subplot(1, 2, 2)
    plt.imshow(sparse_depth_mask, cmap='gray')
    plt.title('Sparse Depth Mask')

    plt.tight_layout()
    plt.savefig('sparse_depth_visualization.png')
    plt.close()
    print(f"Sparse depth visualization saved. Valid points: {np.sum(sparse_depth_mask)}")

    avg_scale2, avg_shift2 = fit_scale_and_shift_multiple2([mast3r_depths[1]], [metric_3d_depths[1]],
                                                           [sparse_depth_mask], [mask_ones])

    another_scale_depth = mast3r_depths[1] * avg_scale2 + avg_shift2
    pts3d_1 = to_numpy(scene.get_pts3d_0(pose, another_scale_depth))[0]

    # ---- Assemble per-view arrays for masking and export ----
    scale_depth = [mast3r_depth_0, another_scale_depth]
    pts3d = np.array([pts3d_0, pts3d_1])
    depthmaps = np.log(scale_depth)
    extrinsics_w2c = np.array([extrinsics_w2c_0, extrinsics_w2c_1])
    intrinsics = to_numpy(scene.get_intrinsics())
    imgs = np.array(scene.imgs)
    confs = np.array([param.detach().cpu().numpy() for param in scene.im_conf])

    focals = to_numpy(scene.get_focals())

    if conf_aware_ranking:
        print(f'>> Confiden-aware Ranking...')
        avg_conf_scores = confs.mean(axis=(1, 2))
        sorted_conf_indices = np.argsort(avg_conf_scores)[::-1]
        sorted_conf_avg_conf_scores = avg_conf_scores[sorted_conf_indices]
        print("Sorted indices:", sorted_conf_indices)
        print("Sorted average confidence scores:", sorted_conf_avg_conf_scores)
    else:
        sorted_conf_indices = np.arange(n_views)
        print("Sorted indices:", sorted_conf_indices)

    print(f'>> Calculate the co-visibility mask...')
    if depth_thre > 0:
        overlapping_masks = compute_co_vis_masks(sorted_conf_indices, depthmaps, pts3d, intrinsics, extrinsics_w2c,
                                                 imgs.shape, depth_threshold=depth_thre)
        overlapping_masks = ~overlapping_masks
    else:
        # Without a threshold there are no masks, so masking is disabled on save.
        co_vis_dsp = False
        overlapping_masks = None
    end_time = time()
    Train_Time = end_time - start_time
    print(f"Time taken for {n_views} views: {Train_Time} seconds")
    save_time(model_path, '[1] coarse_init_TrainTime', Train_Time)

    # ---- Initial poses for the held-out test views ----
    if not infer_video:
        n_train = len(train_img_files)
        n_test = len(test_img_files)

        if n_train < n_test:
            # Interpolate between consecutive training poses, then sample
            # n_test poses evenly from the interpolated path.
            n_interp = (n_test // (n_train - 1)) + 1
            all_inter_pose = []
            for i in range(n_train - 1):
                # NOTE(review): generate_interpolated_path is not imported in the
                # visible part of this file -- confirm where it is defined.
                tmp_inter_pose = generate_interpolated_path(poses=extrinsics_w2c[i:i + 2], n_interp=n_interp)
                all_inter_pose.append(tmp_inter_pose)
            all_inter_pose = np.concatenate(all_inter_pose, axis=0)
            all_inter_pose = np.concatenate([all_inter_pose, extrinsics_w2c[-1][:3, :].reshape(1, 3, 4)], axis=0)
            indices = np.linspace(0, all_inter_pose.shape[0] - 1, n_test, dtype=int)
            sampled_poses = all_inter_pose[indices]
            sampled_poses = np.array(sampled_poses).reshape(-1, 3, 4)
            assert sampled_poses.shape[0] == n_test
            # Promote each 3x4 pose to a homogeneous 4x4 matrix.
            inter_pose_list = []
            for p in sampled_poses:
                tmp_view = np.eye(4)
                tmp_view[:3, :3] = p[:3, :3]
                tmp_view[:3, 3] = p[:3, 3]
                inter_pose_list.append(tmp_view)
            pose_test_init = np.stack(inter_pose_list, 0)
        else:
            indices = np.linspace(0, extrinsics_w2c.shape[0] - 1, n_test, dtype=int)
            pose_test_init = extrinsics_w2c[indices]

        save_extrinsic(sparse_1_path, pose_test_init, test_img_files, image_suffix)
        test_focals = np.repeat(focals[0], n_test)
        save_intrinsics(sparse_1_path, test_focals, org_imgs_shape, imgs.shape, save_focals=False)

    # ---- Export the training-view reconstruction ----
    focals = np.repeat(focals[0], n_views)
    print(f'>> Saving results...')
    end_time = time()
    save_time(model_path, '[1] init_geo', end_time - start_time)
    save_extrinsic(sparse_0_path, extrinsics_w2c, image_files, image_suffix)
    save_intrinsics(sparse_0_path, focals, org_imgs_shape, imgs.shape, save_focals=True)
    pts_num = save_points3D(sparse_0_path, imgs, pts3d, confs.reshape(pts3d.shape[0], -1), overlapping_masks,
                            use_masks=co_vis_dsp, save_all_pts=True, save_txt_path=model_path,
                            depth_threshold=depth_thre)
    save_images_and_masks(sparse_0_path, n_views, imgs, overlapping_masks, image_files, image_suffix)
    print(f'[INFO] MASt3R Reconstruction is successfully converted to COLMAP files in: {str(sparse_0_path)}')
    print(f'[INFO] Number of points: {pts3d.reshape(-1, 3).shape[0]}')    
    print(f'[INFO] Number of points after downsampling: {pts_num}')

    
    
    
    
    
    
    
    
    

    
    
    
    

if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and hand
    # every value to main() by keyword.
    parser = argparse.ArgumentParser(description='Process images and save results.')

    # Required input/output locations.
    parser.add_argument('--source_path', '-s', type=str, required=True, help='Directory containing images')
    parser.add_argument('--model_path', '-m', type=str, required=True, help='Directory to save the results')

    # Model and optimisation settings.
    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='./mast3r/checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth',
        help='Path to the model checkpoint',
    )
    parser.add_argument('--device', type=str, default='cuda', help='Device to use for inference')
    parser.add_argument('--batch_size', type=int, default=1, help='Batch size for processing images')
    parser.add_argument('--image_size', type=int, default=512, help='Size to resize images')
    parser.add_argument('--schedule', type=str, default='cosine', help='Learning rate schedule')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--niter', type=int, default=300, help='Number of iterations')
    parser.add_argument('--min_conf_thr', type=float, default=0.7, help='Minimum confidence threshold')

    # View selection.
    parser.add_argument('--llffhold', type=int, default=8, help='')
    parser.add_argument('--n_views', type=int, default=3, help='')

    # Boolean switches (all default to False) and the depth threshold,
    # registered in the same order as before so --help output is unchanged.
    for flag in ('--focal_avg', '--conf_aware_ranking', '--co_vis_dsp'):
        parser.add_argument(flag, action="store_true")
    parser.add_argument('--depth_thre', type=float, default=0.1, help='Depth threshold')
    parser.add_argument('--infer_video', action="store_true")

    args = parser.parse_args()
    main(
        source_path=args.source_path,
        model_path=args.model_path,
        ckpt_path=args.ckpt_path,
        device=args.device,
        batch_size=args.batch_size,
        image_size=args.image_size,
        schedule=args.schedule,
        lr=args.lr,
        niter=args.niter,
        min_conf_thr=args.min_conf_thr,
        llffhold=args.llffhold,
        n_views=args.n_views,
        co_vis_dsp=args.co_vis_dsp,
        depth_thre=args.depth_thre,
        conf_aware_ranking=args.conf_aware_ranking,
        focal_avg=args.focal_avg,
        infer_video=args.infer_video,
    )