import numpy as np
import argparse
import os
import sys

sys.path.append('../utils/SGW/lib')
sys.path.append('..')
from QDOT.QDOT_numpy import *
from joblib import Parallel, delayed
from itertools import product
from sgw_numpy import sgw_cpu

def compute_loss(X1, X2, method = 'IQDOT'):
    """Compute the loss between two inputs with the selected method.

    Args:
        X1: First input; forwarded unchanged to the selected routine.
        X2: Second input; forwarded unchanged to the selected routine.
        method: ``'IQDOT'`` dispatches to ``QDOT`` with the fixed settings
            below; ``'SGW'`` dispatches to ``sgw_cpu`` with ``nproj = 50``.

    Returns:
        Whatever the selected routine returns.

    Raises:
        ValueError: If ``method`` is neither ``'IQDOT'`` nor ``'SGW'``.
    """
    if method == 'IQDOT':
        # NOTE(review): the 'intergal' spelling presumably matches the keyword
        # in the QDOT signature -- confirm before renaming it.
        return QDOT(X1, X2, sigma = 200, intergal = True, tor = 0, scale = False)
    if method == 'SGW':
        return sgw_cpu(X1, X2, nproj = 50)
    # Fail loudly instead of hitting an unbound local at the return statement.
    raise ValueError(
        f"Unknown method {method!r}; expected 'IQDOT' or 'SGW'."
    )

def main():
    """Compute and save the pairwise loss matrix between the two point sets.

    Command-line options:
        --method:     'IQDOT' (default) or 'SGW'; forwarded to ``compute_loss``.
        --point_step: 1 (default) or 2; stride used to subsample axis 1 of
                      both loaded arrays.
        --n_jobs:     number of joblib workers (default 100, the value that
                      was previously hard-coded).

    Reads ``modelnet40_7cls_points.npy`` and ``shapenetpart_7cls_points.npy``
    from the current working directory, evaluates ``compute_loss`` for every
    (test sample, train sample) pair in parallel, and writes the resulting
    ``(len(X_test), len(X_train))`` matrix to
    ``losses/Loss-<method>-<axis-1 size>.npy``.
    """
    parser = argparse.ArgumentParser(description='Compute pairwise losses with selectable method.')
    parser.add_argument('--method', type=str, default='IQDOT', choices=['IQDOT', 'SGW'],
                        help='Distance')
    parser.add_argument('--point_step', type=int, help='sample points', default=1, choices=[1, 2])
    # Expose the worker count so the script can run on machines where
    # spawning 100 processes is not appropriate; the default keeps the
    # previous behaviour.
    parser.add_argument('--n_jobs', type=int, default=100,
                        help='number of parallel joblib workers')
    args = parser.parse_args()
    point_step = args.point_step

    # Both arrays are indexed as 3-D; only axis 1 is subsampled.
    X_train = np.load('modelnet40_7cls_points.npy')[:, ::point_step, :]
    X_test = np.load('shapenetpart_7cls_points.npy')[:, ::point_step, :]

    print('ModelNet40 Dataset:', X_train.shape)
    print('ShapenetPart Dataset:', X_test.shape)

    n1 = len(X_test)
    n2 = len(X_train)
    # Rows index test samples, columns index train samples.
    Loss = np.zeros((n1, n2), dtype=float)

    def _pair_losses(i, j, method):
        """Return ``(i, j, loss)`` for test sample ``i`` and train sample ``j``."""
        return (
            i, j,
            compute_loss(X_test[i], X_train[j], method=method),
        )

    results = Parallel(
        n_jobs=args.n_jobs,
        batch_size=512,
        backend='loky',
        verbose=5
    )(
        delayed(_pair_losses)(i, j, args.method)
        for i, j in product(range(n1), range(n2))
    )

    # Each result carries its own indices, so the fill does not depend on
    # the order in which results are returned.
    for i, j, val in results:
        Loss[i, j] = val

    os.makedirs('losses', exist_ok=True)
    out_path = os.path.join('losses', f'Loss-{args.method.lower()}-{X_train.shape[1]}.npy')
    np.save(out_path, Loss)
    print(f'Saved to {out_path}')

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()


