# A simple version of the MLOT Sinkhorn solver.
# No shape checking, no error detection, only pure iteration.
#
# This is also the expanded-M version:
# each cost matrix M is expanded by one extra (virtual) row & column, to keep values non-zero during iteration.

import torch

# Compute device shared by the solvers below: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("[MLOT-virtual Work on {}]".format(device))


def expandM(M):
    """Pad every cost matrix with one extra (virtual) row and column.

    For each matrix ``Mi`` of shape ``(r, c)`` the result has shape
    ``(r + 1, c + 1)``. The original block sits in the top-left corner, and the
    new last row and last column are zero. The new bottom-right corner is
    ``inf``: in the Gibbs kernel ``exp(-M / lbd)`` (for positive ``lbd``), this
    makes the virtual-to-virtual entry zero.

    Args:
        M: sequence of 2-D tensors (may live on any device).

    Returns:
        A list of padded tensors placed on the module-level ``device``. Each
        dtype is ``Mi.dtype`` promoted with torch's default dtype. Float64
        inputs stay float64, and integer inputs become floating point, which
        is needed to store ``inf``.
    """
    expanded = []
    for Mi in M:
        # Take the pad sizes from Mi itself, not from its neighbour in the chain.
        rows, cols = Mi.shape
        # Same dtype torch.cat would produce when joining Mi with a default-dtype
        # zeros tensor; also guarantees a floating dtype for the inf below.
        dtype = torch.promote_types(Mi.dtype, torch.get_default_dtype())
        out = torch.zeros(rows + 1, cols + 1, dtype=dtype, device=device)
        out[:rows, :cols] = Mi  # copy_ handles any dtype cast / device transfer
        out[-1, -1] = torch.inf
        expanded.append(out)
    return expanded


def mlot_virtual(s, t, M, lbd, tau, delta=1e-8, numItermax=1000):
    """Multi-layer Sinkhorn iteration on virtually-expanded cost matrices.

    ``expandM`` pads each cost matrix with one virtual node per layer; the
    virtual node is always the last index of a layer. The first and last layer
    marginals are ``s`` and ``t``. Every iteration re-estimates the intermediate
    layer marginals as ``(u[i] * v[i-1]) ** (-lbd / tau)`` and then clamps them.

    Args:
        s: source marginal (anything ``list_to_array`` accepts; squeezed to 1-D).
        t: target marginal (same requirements as ``s``).
        M: sequence of cost matrices; M[i] maps layer i to layer i+1.
        lbd: entropic regularization; the kernel is ``exp(-M / lbd)``.
        tau: sets the exponent ``-lbd / tau`` of the intermediate-marginal update.
        delta: extra mass. It is spread evenly over the real bins of s and t,
            and it is also the mass of each appended virtual bin.
        numItermax: maximum number of iterations.

    Returns:
        (P, log). P[i] is the plan between the real nodes of layers i and i+1
        (virtual row/column removed). log['P_virtual'] holds the plans
        including the virtual row/column. log['layers'] holds per-layer mass
        estimates derived from P.

    Prints progress at iterations 0, 1000, 2000, ... and diagnostics about the
    mass routed through the virtual nodes.
    """
    s, t = list_to_array(s, t)
    s, t = s.squeeze(), t.squeeze()
    s, t = s + delta/len(s), t + delta/len(t)   # spread delta evenly over every position
    mass = s.sum()  # total source mass after the delta offset; upper clamp for intermediate marginals
    s = torch.cat([s, torch.tensor([delta]).to(device)], dim=0)    # append one extra (virtual) position holding delta in total
    t = torch.cat([t, torch.tensor([delta]).to(device)], dim=0)

    M = expandM([list_to_array(Mi) for Mi in M])
    # n[i] = number of nodes in layer i, including its virtual node (last index).
    n = [Mi.shape[0] for Mi in M] + [M[-1].shape[1]]
    _log = {}

    # All layers are padded to a common width nfull; n_mask marks each layer's used slots.
    nfull = max(n) + 1
    n_mask = torch.tensor([
        [1 if col < n[row] else 0 for col in range(nfull) ] for row in range(len(n))
    ]).to(device)

    # Initialization K, u, v, a ------------------------------------
    K = torch.zeros(len(n)-1, nfull, nfull).to(device, dtype=torch.float64)
    for i in range(len(n)-1):
        _shape = M[i].shape
        K[i, :_shape[0], :_shape[1]] = torch.exp(M[i] / -lbd)
        # Padding block beyond the expanded matrix; presumably non-zero so the
        # padded slots of u/v stay finite (NOTE(review): confirm intent).
        K[i, _shape[0]:, _shape[1]:] = 1/nfull

    u = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    v = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    # a[i] is the marginal of layer i: fixed for the first/last layer, re-estimated for the rest.
    a = torch.ones(len(n), nfull, device=device, dtype=torch.float64)
    a[0, :s.shape[0]] = s
    a[-1, :t.shape[0]] = t


    for ii in range(numItermax):
        # u/v are rebound (not modified in place) below, so these keep the previous values.
        u_pre = u
        v_pre = v

        # Update ---------------------------------------------------
        Kv = torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        u = a[:-1] / Kv
        KTu = torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2)
        v = a[1:] / KTu
        # Intermediate marginals from the scalings on both sides of each layer.
        a[1:-1] = torch.pow(u[1:] * v[:-1], -lbd/tau)

        for i in range(1, len(a)-1):
            # Real nodes: between the per-bin delta floor and the total mass.
            a[i][:n[i]-1] = torch.clamp(a[i][:n[i]-1], min=delta/(n[i]-1), max=mass)
            # Virtual node: at most delta.
            a[i][n[i]-1] = torch.clamp(a[i][n[i]-1], min=0, max=delta)
        # ----------------------------------------------------------

        # Disabled NaN/inf guard, kept for reference:
        # if torch.isnan(u).any() or torch.isnan(v).any() or torch.isinf(u).any() or torch.isinf(v).any():
        #     u = u_pre
        #     v = v_pre
        #     print("Numerical error at iteration {}".format(ii))
        #     break

        if ii % 1000 == 0:
            # Norm of the change in log-scalings over each layer's used slots only.
            error_u = torch.linalg.norm((torch.log(u)-torch.log(u_pre))*n_mask[:-1]).item()
            error_v = torch.linalg.norm((torch.log(v)-torch.log(v_pre))*n_mask[1:]).item()
            print("{}\t{:.9f} {:.9f}".format(ii, error_u, error_v))
            # Very strict threshold; in practice the loop usually runs to numItermax.
            if error_u < 1e-20 and error_v < 1e-20:
                print("Error OK after {} iterations".format(ii))
                break

    else:
        # for-else: reached only when the loop did not break.
        print("multi_sinkhorn not converged after {} iterations".format(numItermax))

    # Plans diag(u[i]) K[i] diag(v[i]) for every layer pair, then cropped to real sizes.
    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    P_virtual = [P_compute[i, :n[i], :n[i+1]] for i in range(len(n)-1)]
    P = [P_compute[i, :n[i]-1, :n[i+1]-1] for i in range(len(n)-1)]
    _log['P_virtual'] = P_virtual

    # Per pair: mass leaving the virtual source row, and mass entering the virtual target column.
    for i in range(len(P_virtual)):
        print("{}: {}, {}".format(i, P_virtual[i][-1, :-1].sum(), P_virtual[i][:-1, -1].sum()))

    # Relative errors: the source's virtual outflow vs the uniform delta share,
    # each intermediate layer's virtual inflow vs its virtual outflow, and the
    # target's virtual inflow vs the uniform delta share.
    print("Check virtual values")
    source_out = P_virtual[0][:-1, -1].reshape(-1).cpu()
    print("source error: {}".format(torch.norm(source_out - torch.ones(n[0]-1)*delta/(n[0]-1)) / torch.norm(source_out)))
    for i in range(1, len(n)-1):
        virtual1 = P_virtual[i-1][-1, :-1].reshape(-1)
        virtual2 = P_virtual[i][:-1, -1].reshape(-1)
        print("error {}: {}, radio={}".format(i, torch.norm(virtual1 - virtual2) / torch.norm(virtual1), torch.norm(virtual1)/torch.norm(virtual2)))
    target_in = P_virtual[-1][-1, :-1].reshape(-1).cpu()
    print("target error: {}".format(torch.norm(target_in - torch.ones(n[-1]-1)*delta/(n[-1]-1)) / torch.norm(target_in)))


    # Per-layer mass estimates. The first/last entries remove the uniform delta
    # offset added to s/t. Intermediate layers remove the mass injected from the
    # previous layer's virtual node, clamped at 0.
    virtual_layers = [torch.sum(P[0], dim=1) - delta/(n[0]-1)]
    for i in range(len(P)-1):
        tmp = torch.sum(P[i], dim=0)
        tmp = torch.clamp(tmp - P_virtual[i][-1, :-1], min=0)
        # tmp = torch.clamp(tmp - P_virtual[i+1][:-1, -1], min=0)
        virtual_layers.append(tmp)
    virtual_layers.append(torch.sum(P[-1], dim=0) - delta/(n[-1]-1))
    _log['layers'] = virtual_layers

#     tmp = torch.sum(P[0], dim=0)
#     tmp = torch.clamp(tmp - P_virtual[1][:-1, -1], min=0)
#     _log['another_layer'] = tmp

    return P, _log


def mlot_virtual_single(s, t, M, lbd, delta=1e-8, numItermax=1000):
    """Multi-layer Sinkhorn iteration on virtually-expanded cost matrices.

    This is a variant of ``mlot_virtual``. The intermediate marginals are
    re-estimated as the geometric mean ``sqrt(K^T u * K v)`` of the estimates
    from the neighbouring layers, instead of a tau-based power update.

    Args:
        s: source marginal. Must already be a 1-D tensor on ``device`` (it is
            concatenated with a tensor on ``device`` and not converted here).
        t: target marginal, same requirements as ``s``.
        M: sequence of cost matrices; M[i] maps layer i to layer i+1.
        lbd: entropic regularization; the kernel is ``exp(-M / lbd)``.
        delta: extra mass. It is spread evenly over the real bins of s and t,
            and it is also the mass of each appended virtual bin.
        numItermax: number of iterations. The loop stops early only when u or v
            become NaN/inf, in which case the previous u/v are kept.

    Returns:
        (P, log). P[i] is the plan between the real nodes of layers i and i+1.
        log['layers'] holds per-layer mass estimates. log['e_u'] and
        log['cost_ii'] are kept for compatibility and are always empty.
    """
    # Spread delta evenly over the real bins, then append one virtual bin holding delta.
    s, t = s + delta/len(s), t + delta/len(t)
    mass = s.sum()
    _extra = torch.tensor([delta]).to(device)
    s = torch.cat([s, _extra], dim=0)
    t = torch.cat([t, _extra], dim=0)

    M = expandM([list_to_array(Mi) for Mi in M])
    # n[i] = number of nodes in layer i, including its virtual node (last index).
    n = [Mi.shape[0] for Mi in M] + [M[-1].shape[1]]
    _log = {
        'e_u': [],
        'cost_ii': [],
    }

    # All layers are padded to a common width nfull; n_mask marks each layer's used slots.
    nfull = max(n) + 1
    n_mask = torch.tensor([
        [1 if col < n[row] else 0 for col in range(nfull)] for row in range(len(n))
    ]).to(device)

    # Initialization K, u, v, b ------------------------------------
    K = torch.zeros(len(n)-1, nfull, nfull).to(device, dtype=torch.float64)
    for i in range(len(n)-1):
        _shape = M[i].shape
        K[i, :_shape[0], :_shape[1]] = torch.exp(M[i] / -lbd)
        # Padding block beyond the expanded matrix.
        K[i, _shape[0]:, _shape[1]:] = 1/nfull
    u = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    v = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    # b[i] is the marginal of layer i: fixed for the first/last layer, re-estimated for the rest.
    b = torch.ones(len(n), nfull).to(device, dtype=torch.float64) * delta
    b[0, :s.shape[0]] = s
    b[-1, :t.shape[0]] = t

    for ii in range(numItermax):
        # u/v are rebound (not modified in place) below, so these keep the previous values.
        u_pre = u
        v_pre = v

        # Update ---------------------------------------------------
        u = b[:-1] / torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        v = b[1:] / torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2)
        KTu = torch.bmm(K[:-1].transpose(1, 2), u[:-1].unsqueeze(2)).squeeze(2)
        Kv = torch.bmm(K[1:], v[1:].unsqueeze(2)).squeeze(2)
        # Intermediate marginals: geometric mean of the estimates from both sides.
        b[1:-1] = torch.sqrt(KTu * Kv)

        for i in range(1, len(b)-1):
            # Real nodes: between the per-bin delta floor and the total mass.
            b[i][:n[i]-1] = torch.clamp(b[i][:n[i]-1], min=delta/(n[i]-1), max=mass)
            # Virtual node: at most delta.
            b[i][n[i]-1] = torch.clamp(b[i][n[i]-1], min=0, max=delta)
        # ----------------------------------------------------------

        if torch.isnan(u).any() or torch.isnan(v).any() or torch.isinf(u).any() or torch.isinf(v).any():
            u = u_pre
            v = v_pre
            print("Numerical error at iteration {}".format(ii))
            break
        if ii % 1000 == 0:
            # Norm of the change in log-scalings over each layer's used slots only.
            error_u = torch.linalg.norm((torch.log(u)-torch.log(u_pre))*n_mask[:-1]).item()
            error_v = torch.linalg.norm((torch.log(v)-torch.log(v_pre))*n_mask[1:]).item()
            print("{}\t{:.9f} {:.9f}".format(ii, error_u, error_v))

    # Plans diag(u[i]) K[i] diag(v[i]) for every layer pair, then cropped to real sizes.
    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    P_virtual = [P_compute[i, :n[i], :n[i+1]] for i in range(len(n)-1)]
    P = [P_compute[i, :n[i]-1, :n[i+1]-1] for i in range(len(n)-1)]

    virtual_layers = [torch.sum(P[0], dim=1) - delta/(n[0]-1)]
    for i in range(len(P)-1):
        tmp = torch.sum(P[i], dim=0)
        # Remove what layer i+1's real nodes send to the virtual node of layer i+2.
        # That is the last column of P_virtual[i+1]. P[i+1] has the virtual
        # column stripped, so P[i+1][:, -1] would be the last *real* node's column.
        tmp = torch.clamp(tmp - P_virtual[i+1][:-1, -1], min=0)
        virtual_layers.append(tmp)
    _log['layers'] = virtual_layers

    return P, _log


def transform_ground_metric(ground_metric, area_index):
    """Cut per-layer cost blocks out of one global ground-metric matrix.

    The first two entries of ``area_index[k]`` give the ``[start, stop)``
    index range of layer k inside ``ground_metric``. For every pair of
    consecutive layers, the block with the rows of layer k and the columns of
    layer k+1 is returned. The result therefore has ``len(area_index) - 1``
    entries, or none when fewer than two layers are given.
    """
    blocks = []
    for here, nxt in zip(area_index, area_index[1:]):
        row_slice = slice(here[0], here[1])
        col_slice = slice(nxt[0], nxt[1])
        blocks.append(ground_metric[row_slice, col_slice])
    return blocks


def list_to_array(*lst):
    """Move every argument onto the compute device as a torch tensor.

    Non-tensor inputs (lists, numpy arrays, scalars) are wrapped with
    ``torch.tensor``; tensors are moved with ``.to``. With several arguments,
    a list of tensors is returned. Otherwise the single converted tensor is
    returned (calling with no arguments raises ``IndexError``).
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _convert(obj):
        if isinstance(obj, torch.Tensor):
            return obj.to(target)
        return torch.tensor(obj).to(target)

    if len(lst) > 1:
        return [_convert(obj) for obj in lst]
    return _convert(lst[0])


def sinkhorn_distance(M, P, addition=None):
    """Return the total transport cost ``sum_k sum(P[k] * M[k])``, optionally weighted.

    Args:
        M: sequence of cost matrices.
        P: sequence of transport plans; P[k] is paired with M[k].
        addition: optional sequence of per-term weights (list, tensor or numpy
            array). When given, term k is multiplied by ``addition[k]``.

    Returns:
        The accumulated cost: a tensor when the terms are tensors, or ``0``
        when ``P`` is empty.

    A term that raises (e.g. a shape mismatch or a missing M[k]) is reported
    and skipped, so the result then excludes that term.
    """
    distance = 0
    for k in range(len(P)):
        try:
            term = (P[k] * M[k]).sum()
            # `is None`, not `== None`: for an array of weights `==` is
            # elementwise and its truth value raises, which used to skip every term.
            if addition is not None:
                term = term * addition[k]
            distance += term
        except Exception as e:
            print("Error at {}-th: {}".format(k, e))
            # Do not index M[k] blindly here: it may be the very thing that failed.
            p_shape = getattr(P[k], "shape", None)
            m_shape = getattr(M[k], "shape", None) if k < len(M) else None
            print("Please check shape: P[k] {}, M[k] {}".format(p_shape, m_shape))
    return distance


if __name__ == '__main__':
    # Toy problem with three layers of sizes 3 -> 3 -> 2.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    as_f64 = {'device': device, 'dtype': torch.float64}

    s = torch.tensor([0.2, 0.6, 0.2])    # source marginal
    t = torch.tensor([0.1, 0.9])         # target marginal
    cost_01 = torch.tensor([[1, 0.1, 1], [0.1, 1, 1], [1, 1, 0.1]]).to(**as_f64)
    cost_12 = torch.tensor([[1, 0.1], [1, 0.1], [0.1, 1]]).to(**as_f64)
    M = [cost_01, cost_12]

    T_v, log_v = mlot_virtual(s, t, M, lbd=1e-2, tau=1e-1, delta=1e-1, numItermax=2000)
    print(log_v['P_virtual'])

