import numpy as np

import sys
sys.path.append('path/to/colmap/build')
from colmap.scripts.python.read_write_model import read_model, qvec2rotmat
from colmap.scripts.python.read_write_dense import read_array
from imageio import imread
from tqdm import tqdm

import h5py
import deepdish as dd
from time import time

import cv2 as cv
from scipy.optimize import linear_sum_assignment

import argparse
import random
import json

# Command-line interface of the benchmark.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True,
                    help='name of the sequence directory under ./data')
parser.add_argument('--skip_pairs', default=1, type=int,
                    help='stride used when iterating over the valid image pairs')
parser.add_argument('--skip_kps', default=1, type=int,
                    help='stride used when sub-sampling the reprojected keypoints')
parser.add_argument('--outlier_rate', default=0.0, type=float,
                    help='fraction of the train keypoints that have no match, in [0, 1)')
# argparse only applies `type` to string defaults, so the default itself has to
# be an int; a float default would later be used as a slice bound and raise.
parser.add_argument('--num_query_kps', default=100, type=int,
                    help='number of query keypoints evaluated per image pair')
parser.add_argument('--log_file', required=True,
                    help='path of the JSON file the per-pair accuracies are written to')
args = parser.parse_args()

# The number of outliers is derived as n * rate / (1 - rate), which is undefined
# for rate == 1 and negative outside [0, 1).
if not 0.0 <= args.outlier_rate < 1.0:
    parser.error('--outlier_rate must be in the range [0, 1)')

def dist(x, y):
    """Return the squared Euclidean distance between `x` and `y`.

    The result is floored at 1e-7 whenever it is not strictly positive, so
    that callers can safely take its logarithm or divide by it.
    """
    squared = np.linalg.norm(x - y) ** 2
    if squared > 0:
        return squared
    return 0.0000001

def greedy(X, X_hash):
    """Greedily assign each row of `X` to its closest unused row of `X_hash`.

    Rows of `X` are processed in order; each one takes the column with the
    smallest `dist` cost, and that column is then blocked (set to +inf) for all
    later rows. Once every column is blocked, argmin falls back to column 0.

    Returns an int32 array whose i-th entry is the column chosen for X[i].
    """
    cost = np.zeros((len(X), len(X_hash)))
    for row, query in enumerate(X):
        cost[row, :] = [dist(query, candidate) for candidate in X_hash]

    assignment = np.ones(len(X), dtype=np.int32)
    for row in range(len(X)):
        best = cost[row, :].argmin()
        assignment[row] = best
        # make the chosen column unavailable to the remaining rows
        cost[:, best] = np.inf
    return assignment

def LSS(X, X_hash):
    """Match `X` to `X_hash` by minimising the sum of squared distances.

    Builds the pairwise `dist` cost matrix and solves the linear assignment
    problem on it. Returns the column indices of the optimal assignment.
    """
    cost = np.zeros((len(X), len(X_hash)))
    for row, query in enumerate(X):
        cost[row, :] = [dist(query, candidate) for candidate in X_hash]

    _, columns = linear_sum_assignment(cost)
    return columns

def LSNS(X, X_hash, sigma, sigma_hash):
    """Match `X` to `X_hash` using squared distances normalised per pair.

    Each cost is `dist(X[i], X_hash[j])` divided by
    sqrt(sigma[i]**2 + sigma_hash[j]**2); `sigma` and `sigma_hash` are indexed
    like `X` and `X_hash` (presumably per-descriptor noise levels; confirm
    with the caller). Returns the column indices of the optimal assignment.
    """
    cost = np.zeros((len(X), len(X_hash)))
    for row, query in enumerate(X):
        for col, candidate in enumerate(X_hash):
            spread = (sigma_hash[col] ** 2 + sigma[row] ** 2) ** (1/2)
            cost[row, col] = dist(query, candidate) / spread

    _, columns = linear_sum_assignment(cost)
    return columns

def LSL(X, X_hash):
    """Match `X` to `X_hash` by minimising the sum of log squared distances.

    The cost of a pair is `np.log(dist(...))`; `dist` never returns a
    non-positive value, so the logarithm stays finite. Returns the column
    indices of the optimal assignment.
    """
    cost = np.zeros((len(X), len(X_hash)))
    for row, query in enumerate(X):
        cost[row, :] = [np.log(dist(query, candidate)) for candidate in X_hash]

    _, columns = linear_sum_assignment(cost)
    return columns

def mask_random(arr, rate):
    """Pick a random subset of `arr` and return it in its original order.

    Args:
        arr: sequence to sample from.
        rate: values <= 1 are read as the fraction of `arr` to keep, values
            > 1 as an absolute number of elements (floats are truncated).
            Note that exactly 1 always means "keep everything", so a single
            element cannot be requested by count.

    Returns:
        A pair `(indices, elements)` of lists: the selected positions in
        ascending order and the matching elements of `arr`. Both lists are
        empty when nothing is selected. Because the result is ordered by
        index, `rate == 1` returns `arr` unchanged rather than shuffled.
    """
    if rate == 0 or len(arr) == 0:
        return [], []
    if rate <= 1:
        res_len = int(len(arr) * rate)
    else:
        # a float count (e.g. 100.0) must not reach the slice below
        res_len = int(rate)
    if res_len <= 0:
        # zip(*[]) would yield nothing and break the caller's unpacking
        return [], []

    order = list(range(len(arr)))
    random.shuffle(order)
    indices = sorted(order[:res_len])
    return indices, [arr[i] for i in indices]

def shuffle_kps(query_kps, inlier_kps, outlier_kps, correct_matching):
    """Mix inlier and outlier train keypoints and remap the ground truth.

    Args:
        query_kps: query keypoints; returned untouched.
        inlier_kps: train keypoints that have a query counterpart.
        outlier_kps: train keypoints without a counterpart.
        correct_matching: for every query keypoint, the index of its match
            inside `inlier_kps`.

    Returns:
        `(query_kps, train_kps, matching)` where `train_kps` is a random
        permutation of `inlier_kps + outlier_kps` and `matching[i]` is the
        position in `train_kps` of the match of query keypoint `i`.
    """
    train_kps = list(inlier_kps) + list(outlier_kps)

    # order[new] == old: position `new` of the result holds element `old`
    order = list(range(len(train_kps)))
    random.shuffle(order)
    shuffled_kps = [train_kps[old] for old in order]

    # The ground truth needs the inverse mapping (old position -> new position)
    new_position = {old: new for new, old in enumerate(order)}
    matching = [new_position[old] for old in correct_matching]

    return query_kps, shuffled_kps, matching
    
def compute_accs(statistics):
    """Turn the raw per-pair counters into accuracy figures.

    `statistics` is the 11-tuple
    (overall_cv, overall_query, overall_train, cv_corr, greedy_corr, lss_corr,
     lsl_corr, cv_outliers, greedy_outliers, lss_outliers, lsl_outliers).

    OpenCV figures are normalised by the number of OpenCV matches, all other
    figures by the number of query keypoints. `*_acc` is the share of correct
    matches, `*_outliers_acc` is one minus the share of matches that landed on
    an outlier.
    """
    (n_cv, n_query, _n_train,
     cv_hits, greedy_hits, lss_hits, lsl_hits,
     cv_wrong, greedy_wrong, lss_wrong, lsl_wrong) = statistics

    accs = {}
    accs['OpenCV_acc'] = cv_hits / n_cv
    accs['Greedy_acc'] = greedy_hits / n_query
    accs['LSS_acc'] = lss_hits / n_query
    accs['LSL_acc'] = lsl_hits / n_query

    accs['OpenCV_outliers_acc'] = 1 - cv_wrong / n_cv
    accs['Greedy_outliers_acc'] = 1 - greedy_wrong / n_query
    accs['LSS_outliers_acc'] = 1 - lss_wrong / n_query
    accs['LSL_outliers_acc'] = 1 - lsl_wrong / n_query
    return accs

def update_res(res, accs, idx1, idx2):
    """Record the accuracies of one image pair in `res` and return `res`.

    `res` maps a metric name to a dict keyed by "idx1 idx2". A value that is
    already stored for the pair is kept and a notice is printed instead.
    `res` is modified in place.
    """
    pair_key = "{} {}".format(idx1, idx2)
    for metric in accs:
        per_pair = res[metric]
        if pair_key not in per_pair:
            per_pair[pair_key] = accs[metric]
            continue
        print("{} {} appearing twice, skipping...".format(idx1, idx2))
    return res
    
# Location of the selected sequence on disk
root = './data'
seq = args.dataset  # e.g. 'reichstag' or 'temple_nara_japan'
src = '/'.join((root, seq))
print(f'Doing {seq} data.')

# Load the sparse COLMAP reconstruction of the sequence
cameras, images, points = read_model(path=src + '/dense/sparse', ext='.bin')

print(f'Cameras: {len(cameras)}')
print(f'Images: {len(images)}')
print(f'3D points: {len(points)}')

# Keys of the camera dictionary
indices = list(cameras)

# Retrieve one image, the depth map, and 2D points
def get_image(idx, verbose=False):
    """Load the image, depth maps, pose and 2D observations of one image.

    Reads the module-level `images`, `cameras` and `src`.

    Args:
        idx: key of the image in `images`.
        verbose: when true, print how many 2D points have a 3D point attached.

    Returns:
        dict with keys
            'image':     pixels read from dense/images
            'depth_raw': photometric depth map clipped to its 5th-95th percentile
            'depth':     cleaned depth map read from the .h5 file
            'K':         3x3 intrinsics matrix
            'q', 'R':    rotation as stored quaternion and as rotation matrix
            'T':         translation vector
            'xys':       2D keypoint locations
            'ids':       3D point id of every keypoint
            'valid':     boolean mask of keypoints with id >= 0
    """
    image = images[idx]

    im = imread(src + '/dense/images/' + image.name)
    depth = read_array(src + '/dense/stereo/depth_maps/' + image.name + '.photometric.bin')
    # clip the raw depth to its 5th-95th percentile range (in place)
    min_depth, max_depth = np.percentile(depth, [5, 95])
    depth[depth < min_depth] = min_depth
    depth[depth > max_depth] = max_depth

    # reformat data
    q = image.qvec
    R = qvec2rotmat(q)
    T = image.tvec
    p = image.xys
    # Intrinsics belong to the camera referenced by the image, which is only
    # equal to the image id when every image has its own camera.
    pars = cameras[image.camera_id].params
    # assumes params are ordered (fx, fy, cx, cy), i.e. a pinhole model -- TODO confirm
    K = np.array([[pars[0], 0, pars[2]], [0, pars[1], pars[3]], [0, 0, 1]])
    pids = image.point3D_ids
    v = pids >= 0
    if verbose:
        print('Number of (valid) points: {}'.format((pids > -1).sum()))
        print('Number of (total) points: {}'.format(v.size))

    # get also the clean depth maps
    base = '.'.join(image.name.split('.')[:-1])
    with h5py.File(src + '/dense/stereo/depth_maps_clean_300_th_0.10/' + base + '.h5', 'r') as f:
        # Dataset.value was removed in h5py 3; [()] reads the whole dataset
        depth_clean = f['depth'][()]

    return {
        'image': im,
        'depth_raw': depth,
        'depth': depth_clean,
        'K': K,
        'q': q,
        'R': R,
        'T': T,
        'xys': p,
        'ids': pids,
        'valid': v}


# We can just retrieve all the 3D points
# Collect position and colour of every reconstructed 3D point
xyz = np.array([points[pid].xyz for pid in points])
rgb = np.array([points[pid].rgb for pid in points])

print(xyz.shape)

# A measure of how images overlap is also provided, based on the bounding
# boxes of the 2D points they have in common
t = time()
# each pair contains [bbox1, bbox2, visibility1, visibility2, # of shared matches]
pairs = dd.io.load(src + '/dense/stereo/pairs-dilation-0.00-fixed2.h5')
print(f'Done ({time() - t:.2f} s.)')

# Keep only the pairs whose overlap exceeds the threshold in both images
# pairs[p][0]: ratio between the area of the bounding box containing common points and that of image 1
# pairs[p][1]: same for image 2
th = 0.3

filtered = [p for p in pairs if pairs[p][0] >= th and pairs[p][1] >= th]
print(f'Valid pairs: {len(filtered)}/{len(pairs)}')
pairs = filtered

# One "idx1 idx2" -> value dictionary per reported metric
res = {
    metric: {}
    for metric in (
        'OpenCV_acc',
        'Greedy_acc',
        'LSS_acc',
        'LSL_acc',
        'OpenCV_outliers_acc',
        'Greedy_outliers_acc',
        'LSS_outliers_acc',
        'LSL_outliers_acc',
    )
}

# Neither the extractor nor the matcher depends on the image pair, so they are
# built once instead of once per iteration.
sift = cv.xfeatures2d.SIFT_create()
bf = cv.BFMatcher(cv.NORM_L2, crossCheck=False)

# Sample sizes must be integers even if the option arrives as a float.
num_query_kps = int(args.num_query_kps)

for idx1, idx2 in tqdm(pairs[::args.skip_pairs]):

    # These two images should be matchable
    data1 = get_image(idx1)
    data2 = get_image(idx2)

    # Find the 3D points observed in both images
    v1 = data1['ids'][data1['ids'] > 0]
    v2 = data2['ids'][data2['ids'] > 0]
    common = np.intersect1d(v1, v2)

    depth1 = data1['depth']
    K1, R1, T1 = data1['K'], data1['R'], data1['T']
    K2, R2, T2 = data2['K'], data2['R'], data2['T']

    # 2D observations in image 1 of the shared points, as a 2 x N array.
    # The vectorised mask also yields a well-formed empty array when the two
    # images share nothing, so such a pair is skipped below instead of crashing.
    u_xy1s = data1['xys'][np.isin(data1['ids'], common)].T

    # Drop observations that fall outside the depth map
    height, width = depth1.shape[:2]
    inside = ((u_xy1s[0] >= 0) & (u_xy1s[0] < width) &
              (u_xy1s[1] >= 0) & (u_xy1s[1] < height))
    u_xy1s = u_xy1s[:, inside]

    # Convert to homogeneous coordinates
    u_xy1s = np.concatenate([u_xy1s, np.ones([1, u_xy1s.shape[1]])], axis=0)

    # Get depth (on image 1) for each point
    u_xy1s_int = u_xy1s.astype(np.int32)
    z1 = depth1[u_xy1s_int[1], u_xy1s_int[0]]

    # Eliminate points on occluded areas
    not_void = z1 > 0
    u_xy1s = u_xy1s[:, not_void]
    z1 = z1[not_void]

    # Move to world coordinates
    n_xyz1s = np.dot(np.linalg.inv(K1), u_xy1s)
    n_xyz1s = n_xyz1s * z1 / n_xyz1s[2, :]
    xyz_w = np.dot(R1.T, n_xyz1s - T1[:, None])

    # Reproject into image 2
    # NOTE(review): reprojections behind camera 2 or outside image 2 are not
    # filtered out -- confirm that this is intended.
    n_xyz2s = np.dot(R2, xyz_w) + T2[:, None]
    u_xy2s = np.dot(K2, n_xyz2s)
    z2 = u_xy2s[2, :]
    u_xy2s = u_xy2s / z2

    # Keypoints at corresponding locations: entry i of both lists belongs to
    # the same 3D point.
    query_kps = [cv.KeyPoint(x=xys[0], y=xys[1], _size=10)
                 for xys in u_xy1s[[0, 1], ::args.skip_kps].T]
    train_kps = [cv.KeyPoint(x=xys[0], y=xys[1], _size=10)
                 for xys in u_xy2s[[0, 1], ::args.skip_kps].T]

    if len(query_kps) < num_query_kps:
        print('Skipping pair {} - {} because of lack of kps'.format(idx1, idx2))
        continue

    # Keep exactly `num_query_kps` random correspondences. Sampling by count
    # directly avoids mask_random's "<= 1 means fraction" reading of the size.
    chosen = sorted(random.sample(range(len(query_kps)), num_query_kps))
    chosen_set = set(chosen)
    query_kps = [query_kps[i] for i in chosen]

    # Dividing train keypoints into inliers / outliers
    inlier_kps = [train_kps[i] for i in chosen]
    outlier_kps = [kp for i, kp in enumerate(train_kps) if i not in chosen_set]
    correct_matching = range(num_query_kps)

    # Number of outliers such that they make up `outlier_rate` of the train set
    num_outliers = int(num_query_kps * args.outlier_rate / (1 - args.outlier_rate))
    if len(outlier_kps) < num_outliers:
        print('Skipping pair {} - {} because of lack of kps to add as outliers'.format(idx1, idx2))
        continue
    outlier_kps = random.sample(outlier_kps, num_outliers)

    query_kps, train_kps, correct_matching = shuffle_kps(query_kps, inlier_kps, outlier_kps, correct_matching)

    # Get SIFT descriptors at the selected locations
    kp1, desc1 = sift.compute(data1['image'], query_kps)
    kp2, desc2 = sift.compute(data2['image'], train_kps)

    matches = bf.match(desc1, desc2)

    greedy_ = greedy(desc1, desc2)
    lss = LSS(desc1, desc2)
    lsl = LSL(desc1, desc2)

    overall_cv = len(matches)
    overall_query = len(kp1)
    overall_train = len(kp2)
    cv_corr = sum(1 for tmp in matches if tmp.trainIdx == correct_matching[tmp.queryIdx])
    greedy_corr = sum(1 for i, g in enumerate(greedy_) if g == correct_matching[i])
    lss_corr = sum(1 for i, g in enumerate(lss) if g == correct_matching[i])
    lsl_corr = sum(1 for i, g in enumerate(lsl) if g == correct_matching[i])

    # A match counts as an outlier hit when its train index is not the
    # counterpart of any query keypoint; a set makes each lookup O(1).
    inlier_ids = set(correct_matching)
    cv_outliers = sum(1 for tmp in matches if tmp.trainIdx not in inlier_ids)
    greedy_outliers = sum(1 for g in greedy_ if g not in inlier_ids)
    lss_outliers = sum(1 for g in lss if g not in inlier_ids)
    lsl_outliers = sum(1 for g in lsl if g not in inlier_ids)

    statistics = (overall_cv, overall_query, overall_train, cv_corr, greedy_corr, lss_corr,
                  lsl_corr, cv_outliers, greedy_outliers, lss_outliers, lsl_outliers)

    accs = compute_accs(statistics)

    res = update_res(res, accs, idx1, idx2)

#     except KeyboardInterrupt:
#         exit(0)
        
#     except:
#         print('Pair {} - {} is skipped'.format(idx1, idx2))
#         continue
    
    
# Persist the per-pair accuracies
with open(args.log_file, 'w') as f:
    json.dump(res, f, indent=1)

# Print the macro average (mean over image pairs) of every matcher's accuracy.
# `res` stores per-pair ratios only; the summed match counts a pooled accuracy
# would need are not kept, so only the macro figures can be reported here.
for label in ('OpenCV', 'Greedy', 'LSS', 'LSL'):
    per_pair = list(res[label + '_acc'].values())
    # np.mean of an empty list warns and returns nan; make that case explicit
    macro = np.mean(per_pair) if per_pair else float('nan')
    print(f'{label} Acc Macro: ', macro)
    