# NOT PyTorch version of MLOT-Sinkhorn algorithm
# It is slower, but more space-efficient.
# This implementation is recommended for large-scale retrieval.

import torch
# Pick the compute device once at import time: CUDA when torch reports it
# as available, otherwise the CPU. The choice is announced on stdout.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("[MLOT work on {}]".format(device))


def multi_sinkhorn_single(s, t, M, lbd, numItermax=2000):
    """Run the multi-layer Sinkhorn scaling iteration on a chain of cost matrices.

    Every layer ``i`` gets a kernel ``K[i] = exp(-M[i] / lbd)`` and a pair of
    scaling vectors ``u[i]``, ``v[i]``.  The first marginal is fixed to ``s``
    and the last one to ``t``; the intermediate marginals start uniform and
    are re-estimated on every iteration.  All work is done in float64 on the
    module-level ``device``.

    Args:
        s: 1-D tensor used as the marginal on the rows of ``M[0]``
            (expected length ``M[0].shape[0]``).
        t: 1-D tensor used as the marginal on the columns of ``M[-1]``
            (expected length ``M[-1].shape[1]``).
        M: Non-empty sequence of 2-D cost tensors.  Consecutive matrices must
            chain, i.e. ``M[i].shape[1] == M[i + 1].shape[0]``.
        lbd: Scalar dividing the costs inside the exponential (presumably the
            entropic regularisation strength of Sinkhorn).
        numItermax: Number of iterations to run.  There is no convergence
            test; the loop only ends early on a numerical error.

    Returns:
        A list with one float64 tensor per layer,
        ``P[i] = diag(u[i]) @ K[i] @ diag(v[i])``, located on ``device``.

    Side effects:
        Prints the summed change of ``u`` and ``v`` every 1000 iterations and
        a message if a non-finite value forces an early stop.
    """
    n = [Mi.shape[0] for Mi in M]
    n.append(M[-1].shape[1])
    _N = len(n)
    num_layers = _N - 1

    ##### initialization #####
    # Cast to float64 *before* exponentiating.  Evaluating exp() in the input
    # dtype (typically float32) flushes exp(-M / lbd) to a denormal or to 0 as
    # soon as M / lbd exceeds roughly 87-103, which corrupts the kernel before
    # the upcast can help.
    K = [torch.exp(Mi.to(device, dtype=torch.float64) / (-lbd)) for Mi in M]
    u = [torch.ones(n[i], 1, device=device, dtype=torch.float64) / n[i]
         for i in range(num_layers)]
    v = [torch.ones(n[i + 1], 1, device=device, dtype=torch.float64) / n[i + 1]
         for i in range(num_layers)]

    # b[0] and b[-1] are the fixed end marginals; the ones in between start
    # uniform and are updated inside the loop.
    b = [s.unsqueeze(1).to(device, dtype=torch.float64)]
    b.extend(torch.ones(n[i], 1, device=device, dtype=torch.float64) / n[i]
             for i in range(1, num_layers))
    b.append(t.unsqueeze(1).to(device, dtype=torch.float64))

    ##### Iterations #####
    for ii in range(numItermax):
        # Shallow copies suffice: the updates below rebind list slots to new
        # tensors and never modify a tensor in place.
        u_pre = u.copy()
        v_pre = v.copy()

        # All u updates use the previous v; all v updates use the fresh u.
        for i in range(num_layers):
            u[i] = b[i] / torch.mm(K[i], v[i])
        for i in range(num_layers):
            v[i] = b[i + 1] / torch.mm(K[i].T, u[i])
        # Intermediate marginal: element-wise geometric mean of what the two
        # adjacent layers report for the shared dimension.
        for i in range(1, num_layers):
            b[i] = torch.sqrt((K[i - 1].T @ u[i - 1]) * (K[i] @ v[i]))

        # isfinite() is False for both NaN and +/-inf.
        if not all(bool(torch.isfinite(x).all()) for x in (*u, *v)):
            print("Numerical error at iteration {}".format(ii))
            # Fall back to the last iterate that was still finite.
            u = u_pre
            v = v_pre
            break

        if ii % 1000 == 0:
            error_u = sum(torch.norm(cur - old).item() for cur, old in zip(u, u_pre))
            error_v = sum(torch.norm(cur - old).item() for cur, old in zip(v, v_pre))
            print("{}\t{} {}".format(ii, error_u, error_v))

    P = [u[i].reshape(-1, 1) * K[i] * v[i].reshape(1, -1) for i in range(num_layers)]

    return P


if __name__ == '__main__':
    # Small test example (the result must be accurate); disabled by default.
    #     s = torch.ones(2, dtype=torch.float64).to(device) / 2
    #     t = torch.ones(2, dtype=torch.float64).to(device) / 2
    #     M = [
    #         torch.tensor([[10, 10, 1, 2], [2, 1, 10, 10]]).to(device),
    #         torch.tensor([[1, 10], [10, 1], [10, 1], [5, 5]]).to(device)
    #     ]
    #     M = [Mi / 10 for Mi in M]
    #     T = multi_sinkhorn_single(s, t, M, 0.01, numItermax=2000)
    #     print(T)

    # Large test example (the iterations must run to completion).
    num_queries = 250
    middle_size = 150000

    # Uniform end marginals of length num_queries.
    source = torch.ones(num_queries, dtype=torch.float64).to(device) / num_queries
    target = torch.ones(num_queries, dtype=torch.float64).to(device) / num_queries

    # Two random cost matrices, drawn in this order and then moved to `device`.
    first_cost = torch.rand(num_queries, middle_size).to(device)
    second_cost = torch.rand(middle_size, num_queries).to(device)
    costs = [first_cost, second_cost]

    # Rescale both matrices by the single largest entry across them.
    largest = max(cost.max().item() for cost in costs)
    costs = [cost / largest for cost in costs]

    plans = multi_sinkhorn_single(source, target, costs, 0.01, numItermax=10000)
    print(plans[0].shape, plans[1].shape)

