import os
import argparse
import pdd

if __name__ == '__main__':
    # For each (id1, id2) pair below, read both CIF files from the
    # 'identical_by_all_atoms' directory, compute a PDD for each with
    # pdd.pdd(crystal, k), and print pdd.emd between the two PDDs.
    parser = argparse.ArgumentParser(
        description='Print the EMD between the PDDs of listed crystal pairs.'
    )
    # Declaring the default here replaces the manual "if args.k is None"
    # fallback; the effective value is unchanged (100).
    parser.add_argument(
        '--k', type=int, default=100,
        help='value of k passed to pdd.pdd (default: 100)',
    )
    args = parser.parse_args()
    k = args.k

    cif_dir = 'identical_by_all_atoms'
    pairs = [
        ('AFIBOH', 'NENCUF'),
        ('COLYEI', 'POCLOK'),
        ('DTBIPT', 'DTHBPD10'),
        ('HIFCAB', 'JEPLIA'),
        ('LALNET', 'POCPAA'),
    ]

    for id1, id2 in pairs:
        # Each crystal must be read from its own file. Previously the id2 path
        # overwrote the id1 path and both crystals were read from it, so every
        # comparison was a crystal against itself.
        path_1 = os.path.join(cif_dir, id1 + '.cif')
        path_2 = os.path.join(cif_dir, id2 + '.cif')

        crystal_1 = pdd.read_cif(path_1)
        crystal_2 = pdd.read_cif(path_2)

        pdd_1 = pdd.pdd(crystal_1, k)
        pdd_2 = pdd.pdd(crystal_2, k)

        emd = pdd.emd(pdd_1, pdd_2)
        print(f'EMD distance (k={k}) between {id1} and {id2}:', emd)
