# The simple version of MLOT Sinkhorn
# No shape checking, no error detection, only pure iteration

import torch
import numpy as np

# Pick the compute device once at import time; the solver functions in this
# module read this global when allocating their tensors.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"[MLOT-simple Work on {device}]")


def multi_sinkhorn(s, t, M, lbd, tau, numItermax=1000):
    """Compute a chain of transport plans with batched Sinkhorn-style scaling.

    Every cost matrix is embedded in a square ``nfull x nfull`` kernel so that
    all layers can be updated together with ``torch.bmm``. The intermediate
    marginals are not given; they are re-estimated on every iteration from
    the current scaling vectors.

    No shape checking is performed. The code assumes that consecutive cost
    matrices are chainable (``M[i].shape[1] == M[i+1].shape[0]``) and that
    ``s`` / ``t`` match the first / last matrix -- TODO confirm with callers.

    Args:
        s: Source marginal (array-like or tensor); flattened to 1-D.
        t: Target marginal (array-like or tensor); flattened to 1-D.
        M: Sequence of 2-D cost matrices, one per pair of adjacent layers.
        lbd: Kernel temperature; the kernel is ``exp(-M / lbd)``.
        tau: Controls the intermediate-marginal update
            ``a = (u[1:] * v[:-1]) ** (-lbd / tau)``.
        numItermax: Number of scaling iterations to run.

    Returns:
        A list with one float64 tensor per cost matrix, equal to
        ``u_i[:, None] * K_i * v_i[None, :]`` cropped to ``n[i] x n[i+1]``,
        where ``n`` holds the row count of each ``M[i]`` followed by the
        length of ``t``.

    Side effects:
        Prints progress every 1000 iterations, a message if NaNs appear, and
        a "not converged" message whenever the loop runs to completion.
    """
    s, t = list_to_array(s, t)
    # reshape(-1) instead of squeeze(): squeeze() collapses a one-element
    # marginal to a 0-d tensor, after which ``shape[0]`` raises IndexError.
    s, t = s.reshape(-1), t.reshape(-1)
    M = [list_to_array(Mi) for Mi in M]

    # n[i] is the size of layer i; the last layer size comes from t.
    n = [Mi.shape[0] for Mi in M] + [t.shape[0]]
    nfull = max(n) + 1
    # n_mask[row, col] is 1 for real entries of layer `row`, 0 for padding.
    # It is only used to keep padding out of the logged error norms.
    n_mask = torch.tensor([
        [1 if col < n[row] else 0 for col in range(nfull)] for row in range(len(n))
    ]).to(device)

    # Initialization K, u, v, a ------------------------------------
    # Top-left block holds the real kernel, bottom-right block is a constant
    # 1/nfull, and the two off-diagonal blocks stay zero.
    K = torch.zeros(len(n)-1, nfull, nfull).to(device, dtype=torch.float64)
    for i in range(len(n)-1):
        _shape = M[i].shape
        K[i, :_shape[0], :_shape[1]] = torch.exp(M[i] / -lbd)
        K[i, _shape[0]:, _shape[1]:] = 1/nfull

    u = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    v = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    # a[i] is the marginal of layer i; only the first and last rows are
    # fixed (to s and t), padding entries start at 1.
    a = torch.ones(len(n), nfull, device=device, dtype=torch.float64)
    a[0, :s.shape[0]] = s
    a[-1, :t.shape[0]] = t

    for ii in range(numItermax):
        # u and v are rebound (not mutated) below, so these keep the
        # previous iterates for NaN rollback and for the logged error.
        u_pre = u
        v_pre = v

        # Update ---------------------------------------------------
        Kv = torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        u = a[:-1] / Kv
        KTu = torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2)
        v = a[1:] / KTu
        # Re-estimate the free intermediate marginals from the scalings.
        a[1:-1] = torch.pow(u[1:] * v[:-1], -lbd/tau)
        # ----------------------------------------------------------

        if torch.isnan(u).any() or torch.isnan(v).any():
            # Fall back to the last NaN-free scalings and stop.
            u = u_pre
            v = v_pre
            print("Numerical error at iteration {}".format(ii))
            break

        if ii % 1000 == 0:
            # Progress report only; it is not used as a stopping criterion.
            error_u = torch.linalg.norm((torch.log(u)-torch.log(u_pre))*n_mask[:-1]).item()
            error_v = torch.linalg.norm((torch.log(v)-torch.log(v_pre))*n_mask[1:]).item()
            print("{}\t{:.9f} {:.9f}".format(ii, error_u, error_v))

    else:
        # NOTE: there is no convergence test, so this branch runs whenever
        # the loop finishes without a numerical error.
        print("multi_sinkhorn not converged after {} iterations".format(numItermax))

    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    # Drop the padding rows/columns of every layer.
    P = [P_compute[i, :n[i], :n[i+1]] for i in range(len(n)-1)]

    return P


def multi_sinkhorn_single(s, t, M, lbd, numItermax=1000):
    """Compute a chain of transport plans without the ``tau`` parameter.

    Uses the same padded, batched kernel layout as ``multi_sinkhorn``, but
    the intermediate marginals are updated as the element-wise geometric
    mean ``sqrt((K_{i-1}^T u_{i-1}) * (K_i v_i))``.

    No shape checking is performed. The code assumes that consecutive cost
    matrices are chainable (``M[i].shape[1] == M[i+1].shape[0]``) and that
    ``s`` / ``t`` match the first / last matrix -- TODO confirm with callers.

    Args:
        s: Source marginal (array-like or tensor); flattened to 1-D.
        t: Target marginal (array-like or tensor); flattened to 1-D.
        M: Sequence of 2-D cost matrices, one per pair of adjacent layers.
        lbd: Kernel temperature; the kernel is ``exp(-M / lbd)``.
        numItermax: Number of scaling iterations to run.

    Returns:
        A list with one float64 tensor per cost matrix, equal to
        ``u_i[:, None] * K_i * v_i[None, :]`` cropped to ``n[i] x n[i+1]``,
        where ``n`` holds the row count of each ``M[i]`` followed by the
        length of ``t``. Only this list is returned (no log object).

    Side effects:
        Prints progress every 1000 iterations, a message if NaNs appear, and
        a "not converged" message whenever the loop runs to completion.
    """
    s, t = list_to_array(s, t)
    # reshape(-1) instead of squeeze(): squeeze() collapses a one-element
    # marginal to a 0-d tensor, after which ``shape[0]`` raises IndexError.
    s, t = s.reshape(-1), t.reshape(-1)
    M = [list_to_array(Mi) for Mi in M]

    # n[i] is the size of layer i; the last layer size comes from t.
    n = [Mi.shape[0] for Mi in M] + [t.shape[0]]
    nfull = max(n) + 1
    # n_mask[row, col] is 1 for real entries of layer `row`, 0 for padding.
    # It is only used to keep padding out of the logged error norms.
    n_mask = torch.tensor([
        [1 if col < n[row] else 0 for col in range(nfull)] for row in range(len(n))
    ]).to(device)

    # Initialization K, u, v, b ------------------------------------
    # Top-left block holds the real kernel, bottom-right block is a constant
    # 1/nfull, and the two off-diagonal blocks stay zero.
    K = torch.zeros(len(n)-1, nfull, nfull).to(device, dtype=torch.float64)
    for i in range(len(n)-1):
        _shape = M[i].shape
        K[i, :_shape[0], :_shape[1]] = torch.exp(M[i] / -lbd)
        K[i, _shape[0]:, _shape[1]:] = 1/nfull

    u = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    v = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    # b[i] is the marginal of layer i; only the first and last rows are
    # fixed (to s and t), padding entries start at 1.
    b = torch.ones(len(n), nfull).to(device, dtype=torch.float64)
    b[0, :s.shape[0]] = s
    b[-1, :t.shape[0]] = t
    # ---------------------------------------------------------------

    for ii in range(numItermax):
        # u and v are rebound (not mutated) below, so these keep the
        # previous iterates for NaN rollback and for the logged error.
        u_pre = u
        v_pre = v

        # Update ---------------------------------------------------
        u = b[:-1] / torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        v = b[1:] / torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2)
        # Column sums seen from the previous layer and row sums seen from the
        # next layer; their geometric mean becomes the new inner marginal.
        KTu = torch.bmm(K[:-1].transpose(1, 2), u[:-1].unsqueeze(2)).squeeze(2)
        Kv = torch.bmm(K[1:], v[1:].unsqueeze(2)).squeeze(2)
        b[1:-1] = torch.sqrt(KTu * Kv)
        # ----------------------------------------------------------

        if torch.isnan(u).any() or torch.isnan(v).any():
            # Fall back to the last NaN-free scalings and stop.
            u = u_pre
            v = v_pre
            print("Numerical error at iteration {}".format(ii))
            break

        if ii % 1000 == 0:
            # Progress report only; it is not used as a stopping criterion.
            error_u = torch.linalg.norm((torch.log(u)-torch.log(u_pre))*n_mask[:-1]).item()
            error_v = torch.linalg.norm((torch.log(v)-torch.log(v_pre))*n_mask[1:]).item()
            print("{}\t{} {}".format(ii, error_u, error_v))

    else:
        # NOTE: there is no convergence test, so this branch runs whenever
        # the loop finishes without a numerical error.
        print("multi_sinkhorn_single not converged after {} iterations".format(numItermax))

    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    # Drop the padding rows/columns of every layer.
    P = [P_compute[i, :n[i], :n[i+1]] for i in range(len(n)-1)]

    return P


def transform_ground_metric(ground_metric, area_index):
    """Slice one ground metric into the blocks linking consecutive areas.

    For each adjacent pair of entries in ``area_index``, the first two
    values of the earlier entry give the row range and the first two values
    of the later entry give the column range of the extracted block.

    Args:
        ground_metric: 2-D object supporting ``[r0:r1, c0:c1]`` indexing.
        area_index: Indexable sequence of ``(start, stop, ...)`` entries.

    Returns:
        A list with ``len(area_index) - 1`` blocks (empty when there are
        fewer than two areas).
    """
    def _block(pos):
        rows, cols = area_index[pos], area_index[pos + 1]
        return ground_metric[rows[0]:rows[1], cols[0]:cols[1]]

    return [_block(pos) for pos in range(len(area_index) - 1)]


def list_to_array(*lst):
    """Return the arguments as torch tensors on the active device.

    Any argument that is not already a ``torch.Tensor`` is wrapped with
    ``torch.tensor``; every value is then moved to CUDA when it is
    available and to the CPU otherwise.

    Returns:
        A list of tensors when two or more arguments are given, or a single
        bare tensor when exactly one argument is given.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _place(item):
        tensor = item if isinstance(item, torch.Tensor) else torch.tensor(item)
        return tensor.to(target)

    if len(lst) > 1:
        return [_place(item) for item in lst]
    return _place(lst[0])


def sinkhorn_distance(M, P, addition=None):
    """Sum the transport costs ``sum(P[k] * M[k])`` over all layers.

    Args:
        M: Sequence of cost matrices (torch tensors or numpy arrays).
        P: Sequence of transport plans (torch tensors or numpy arrays), one
            per entry of ``M`` with matching shapes.
        addition: Optional per-layer multipliers; when given, the cost of
            layer ``k`` is multiplied by ``addition[k]``.

    Returns:
        The accumulated cost: a 0-d CPU tensor, or the int ``0`` when ``P``
        is empty or every layer failed.

    Notes:
        A layer whose cost cannot be computed (e.g. a shape mismatch) is
        reported on stdout and skipped rather than raising.
    """
    # All arithmetic is done on the CPU regardless of where the inputs live.
    M = [Mi.cpu() if isinstance(Mi, torch.Tensor) else torch.from_numpy(Mi).cpu() for Mi in M]
    P = [Pi.cpu() if isinstance(Pi, torch.Tensor) else torch.from_numpy(Pi).cpu() for Pi in P]
    distance = 0
    for k in range(len(P)):
        try:
            cost = torch.sum(P[k] * M[k])
            # Identity test, not ``== None``: comparing a numpy array with
            # ``==`` is element-wise and its truth value raises ValueError,
            # which the handler below would swallow for every layer.
            if addition is not None:
                cost = cost * addition[k]
            distance += cost
        except Exception as e:
            print("Error at {}-th: {}".format(k, e))
            print("Please check shape: P[k] {}, M[k] {}".format(P[k].shape, M[k].shape))
    return distance


if __name__ == '__main__':
    # Small three-layer demo: 3 source points -> 3 intermediate points -> 2
    # target points.
    s = np.array([0.2, 0.6, 0.2])
    t = np.array([0.1, 0.9])
    M = [torch.tensor([[1, 0.1, 1], [0.1, 1, 1], [1, 1, 0.1]], dtype=torch.float64),
         torch.tensor([[1, 0.1], [1, 0.1], [0.1, 1]], dtype=torch.float64)]

    # Rescale every cost matrix so its largest entry is 1.
    M = [Mi / Mi.max() for Mi in M]

    # multi_sinkhorn_single takes (s, t, M, lbd, numItermax) and returns only
    # the list of plans, so there is no extra positional argument and no log
    # object to unpack.
    T = multi_sinkhorn_single(s, t, M, 0.001, numItermax=1000)
    dis = sinkhorn_distance(M, T)
    print("distance: {}".format(dis))
    for i in range(len(T)):
        print("T[{}]: {}".format(i, T[i]))

