import numpy as np
import ot
def myEmdRigid(location_a,weight_a,location_c,weight_c,maxIterTimes = 5,verbose = True):
    """Alternate optimal transport and rigid rotation to align two point sets.

    Each iteration:
      1. solves the exact optimal-transport (EMD) problem between the current
         ``location_a`` and ``location_c`` under the cost from ``ot.dist``
         (squared Euclidean by default in POT);
      2. with that transport plan held fixed, finds the proper rotation
         (determinant +1) ``R`` that best maps ``location_a`` onto
         ``location_c`` via an SVD of the weighted cross-covariance matrix
         (Kabsch / orthogonal Procrustes);
      3. replaces ``location_a`` by ``location_a.dot(R)``.

    No translation is estimated; callers that need one should centre both
    point sets beforehand.

    Args:
        location_a: array-like of shape (n1, d), source point coordinates.
            The caller's array is not modified.
        weight_a: length-n1 mass vector for the source points.
        location_c: array-like of shape (n2, d), target point coordinates.
        weight_c: length-n2 mass vector for the target points.
        maxIterTimes: number of transport/rotation iterations; must be >= 1.
        verbose: when true (the default), print the transport cost of every
            iteration.

    Returns:
        Tuple ``(flow, loss, location_a)``:
            flow: (n1, n2) transport plan from the last iteration.
            loss: ``sum(cost * flow)`` from the last iteration. It is
                evaluated *before* that iteration's rotation is applied, so
                it does not account for the final rotation of the returned
                coordinates.
            location_a: (n1, d) source coordinates after all rotations.

    Raises:
        ValueError: if ``maxIterTimes`` is smaller than 1 (there would be no
            transport plan or loss to return).
    """
    if maxIterTimes < 1:
        raise ValueError(
            "maxIterTimes must be at least 1, got {!r}".format(maxIterTimes))

    # Accept lists/tuples as well as arrays; `.T` and `.dot` below need arrays.
    location_a = np.asarray(location_a, dtype=float)
    location_c = np.asarray(location_c, dtype=float)

    for _ in range(maxIterTimes):
        # Transport step: optimal plan for the current source placement.
        costMatrix = ot.dist(location_a,location_c)
        flowMatrix = ot.emd(weight_a, weight_c, costMatrix)
        loss = np.sum(costMatrix*flowMatrix)
        if verbose:
            print("loss = ",loss)

        # Rotation step: B = A^T F C is the (d, d) flow-weighted
        # cross-covariance; maximising trace(R^T B) over rotations gives
        # R = U diag(1, ..., 1, s) V^T with s = sign(det(U) * det(V^T)).
        matrixB = location_a.T.dot(flowMatrix).dot(location_c)
        matrixU,_,matrixVT = np.linalg.svd(matrixB)

        # Use an exact +/-1 for the reflection correction so that R stays
        # orthogonal instead of inheriting round-off from the determinants.
        correction = np.ones(matrixB.shape[0])
        if np.linalg.det(matrixU)*np.linalg.det(matrixVT) < 0:
            correction[-1] = -1.0

        # Scaling the columns of U is the same as U.dot(np.diag(correction)).
        matrixR = (matrixU*correction).dot(matrixVT)
        location_a = location_a.dot(matrixR)

    return flowMatrix,loss,location_a




def test_myEmdRigid():
    """Smoke-test myEmdRigid on two random, weighted-centred 2-D point sets.

    Checks that the returned transport plan has the requested marginals,
    that the loss is a finite non-negative number, and that the returned
    source coordinates differ from the input only by an orthogonal map
    (every point keeps its distance to the origin).
    """
    n1 = 100
    n2 = 8
    d = 2

    # Private, seeded generator: reproducible and leaves the global RNG alone.
    rng = np.random.RandomState(0)
    location_a = rng.rand(n1,d)*100
    location_c = rng.rand(n2,d)*100
    weight_a = rng.rand(n1)
    weight_a = weight_a / weight_a.sum()
    weight_c = rng.rand(n2)
    weight_c = weight_c / weight_c.sum()

    # Move each weighted centroid to the origin; myEmdRigid only rotates.
    location_a = location_a - weight_a.dot(location_a)
    location_c = location_c - weight_c.dot(location_c)

    flow,loss,aligned = myEmdRigid(location_a,weight_a,location_c,weight_c)

    assert flow.shape == (n1,n2)
    assert np.allclose(flow.sum(axis=1),weight_a)
    assert np.allclose(flow.sum(axis=0),weight_c)
    assert np.isfinite(loss) and loss >= 0
    assert aligned.shape == (n1,d)
    assert np.allclose(np.linalg.norm(aligned,axis=1),
                       np.linalg.norm(location_a,axis=1))


# Run the smoke test only when this file is executed as a script, so that
# importing the module (e.g. during test collection) has no side effects.
if __name__ == "__main__":
    test_myEmdRigid()
