import argparse
import os
from tqdm import tqdm
import numpy as np
import torch

def batch_pairwise_dist(x, y):
    """Return squared Euclidean distances between every pair of points in two batches.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``. Because of floating-point
    cancellation, entries that should be zero can come out marginally negative.

    Args:
        x: batched points, shape (B, N_x, D).
        y: batched points, shape (B, N_y, D).

    Returns:
        Tensor of shape (B, N_x, N_y) whose [b, i, j] entry is |x[b, i] - y[b, j]|^2.
    """
    gram_x = torch.bmm(x, x.transpose(2, 1))  # (B, N_x, N_x)
    gram_y = torch.bmm(y, y.transpose(2, 1))  # (B, N_y, N_y)
    cross = torch.bmm(x, y.transpose(2, 1))   # (B, N_x, N_y)
    # The squared norms of the points sit on the diagonals of the Gram matrices.
    sq_norm_x = torch.diagonal(gram_x, dim1=1, dim2=2)  # (B, N_x)
    sq_norm_y = torch.diagonal(gram_y, dim1=1, dim2=2)  # (B, N_y)
    # Broadcast: rows carry |x_i|^2, columns carry |y_j|^2.
    return sq_norm_x.unsqueeze(2) + sq_norm_y.unsqueeze(1) - 2 * cross

def chamfer_distance(preds, gts):
    """Return the symmetric Chamfer distance (squared, summed) for each batch element.

    For every predicted point, the squared distance to its nearest ground-truth point is
    summed. The same is done for every ground-truth point against the predictions, and
    the two sums are added. Distances are not averaged over point counts.

    Args:
        preds: batched points, shape (B, N_preds, D).
        gts: batched points, shape (B, N_gts, D).

    Returns:
        Tensor of shape (B,).
    """
    dist = batch_pairwise_dist(gts, preds)  # (B, N_gts, N_preds)
    pred_to_gt = dist.min(dim=1).values.sum(dim=1)  # nearest gt for each pred point
    gt_to_pred = dist.min(dim=2).values.sum(dim=1)  # nearest pred for each gt point
    return pred_to_gt + gt_to_pred

def _read_index_file(index_file_path):
    """Return the whitespace-stripped, non-empty lines of the index file."""
    with open(index_file_path, "r") as f:
        # Skip blank lines (e.g. a trailing newline) so np.load is never called with "".
        return [line.strip() for line in f if line.strip()]


def _confusion_save_path(path):
    """Map a point-cloud path to the path its confusion matrix is saved under.

    ``<prefix>point_cloud.npy`` becomes ``<prefix>confusion_mtx.npy``. Any other path gets
    ``_confusion_mtx.npy`` appended to its stem. The result therefore never equals the
    input path, so the point cloud itself can never be overwritten.
    """
    suffix = "point_cloud.npy"
    if path.endswith(suffix):
        # Replace only the trailing suffix, not occurrences elsewhere in the path.
        return path[: -len(suffix)] + "confusion_mtx.npy"
    return os.path.splitext(path)[0] + "_confusion_mtx.npy"


def _flip_confusion_matrix(point_cloud, cube_flips):
    """Return the K x K matrix of Chamfer distances between all flipped copies of a cloud.

    Copy k is ``point_cloud @ cube_flips[k].T``. Entry [i, j] equals
    ``chamfer_distance(preds=copy_i, gts=copy_j)``.
    """
    # (K, N, D): one transformed copy of the cloud per flip matrix.
    flipped = torch.stack([point_cloud @ flip.t() for flip in cube_flips])
    num_flips = flipped.shape[0]
    matrix = torch.zeros(num_flips, num_flips).to(point_cloud)
    for i in range(num_flips):
        # One batched call per row instead of K separate calls. preds holds K views of
        # copy i and gts holds all K copies, so row[j] == chamfer(copy_i, copy_j).
        # Trade-off: the intermediates are K * N * N floats each (~384 MB for K=24, N=2000).
        preds = flipped[i].unsqueeze(0).expand_as(flipped)
        matrix[i] = chamfer_distance(preds, flipped)
    return matrix


def main():
    """Compute and save a cube-flip Chamfer confusion matrix for every indexed point cloud.

    For each ``.npy`` path listed in the index file:

    1. Load the cloud and move it to the GPU.
    2. Keep only its first ``--num_points`` rows.
    3. Transform it by every matrix in the cube-flip tensor.
    4. Compute the pairwise Chamfer distances between those copies.
    5. Save the K x K result with ``np.save`` next to the input file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--index_file", "-i", default="/home/ubuntu/orienter-3d/data/shapenet_index_files/all_point_clouds/all.txt", help="Path to index file for all point clouds")
    parser.add_argument("--cube_flips", default="/home/ubuntu/orienter-3d/utils/24_cube_flips.pt", help="Path to the saved tensor of cube-flip matrices")
    parser.add_argument("--num_points", type=int, default=2000, help="Number of leading points of each cloud to use")
    args = parser.parse_args()
    index_file_path = args.index_file
    print(f"{index_file_path=}")

    index_file = _read_index_file(index_file_path)
    # NOTE(review): assumed to be a (K, D, D) tensor of flip matrices; confirm against the saved file.
    cube_flips = torch.load(args.cube_flips).cuda()

    # No gradients are needed. This also guarantees .numpy() below works even if the
    # loaded flips carry requires_grad.
    with torch.no_grad():
        for path in tqdm(index_file):
            point_cloud = torch.from_numpy(np.load(path)).float().cuda()
            # Subsample: keep only the first num_points points.
            point_cloud = point_cloud[: args.num_points]
            matrix = _flip_confusion_matrix(point_cloud, cube_flips)
            np.save(_confusion_save_path(path), matrix.cpu().numpy())

if __name__ == "__main__":
    # Run the confusion-matrix export only when executed as a script, not on import.
    main()