import torch
import numpy as np

# Module-wide compute device: CUDA when torch reports an available GPU,
# otherwise the CPU. The choice is announced once, at import time.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("[Work on {}]".format(device))

def multi_sinkhorn(s, t, M, lbd, tau, numItermax=1000):
    """Iterate Sinkhorn-style scalings along a chain of cost matrices.

    Each cost matrix ``M[i]`` becomes a kernel ``exp(-M[i] / lbd)`` written
    into the top-left corner of a zero-initialised ``nfull x nfull`` slab,
    where ``nfull`` is the largest marginal size plus one. The part of the
    slab past both the matrix's rows and its columns is set to ``1 / nfull``.

    A marginal table ``a`` (one row per marginal, initialised to ones) holds
    ``s`` in the leading entries of its first row and ``t`` in the leading
    entries of its last row. On every iteration the intermediate rows are
    overwritten with ``(u[1:] * v[:-1]) ** (-lbd / tau)``.

    Args:
        s: Weights of the first marginal (array-like or tensor); squeezed
            before use.
        t: Weights of the last marginal (array-like or tensor); squeezed
            before use.
        M: Sequence of 2-D cost matrices. The code assumes they chain, i.e.
            the column count of ``M[i]`` equals the row count of ``M[i+1]``
            and the last column count equals ``len(t)``; this is not
            validated here.
        lbd: Divisor of the costs inside the exponential kernel; also the
            numerator of the exponent ``-lbd / tau``.
        tau: Denominator of the exponent used to refresh the intermediate
            marginals.
        numItermax: The loop body runs ``numItermax + 1`` times unless a NaN
            is detected first.

    Returns:
        tuple: ``(P, log)`` where ``P[i]`` is ``u[i][:, None] * K[i] *
        v[i][None, :]`` cropped to ``n[i] x n[i+1]``, and ``log['iter']`` is
        the number of iterations that completed without a NaN.
    """
    s, t = list_to_array(s, t)
    s = s.squeeze()
    t = t.squeeze()
    M = [list_to_array(cost) for cost in M]

    # Size of every marginal: row count of each cost matrix, then len(t).
    n = [cost.shape[0] for cost in M]
    n.append(t.shape[0])
    num_marginals = len(n)
    num_links = num_marginals - 1
    nfull = max(n) + 1

    # Integer 0/1 mask flagging the real (unpadded) entries of each marginal.
    n_mask = torch.tensor(
        [[int(col < size) for col in range(nfull)] for size in n]
    ).to(device)

    def _masked_log_change(current, previous, mask):
        # Norm of the change in log-scalings, restricted to unpadded entries.
        return torch.linalg.norm(
            (torch.log(current) - torch.log(previous)) * mask
        ).item()

    # Padded kernels ---------------------------------------------------
    K = torch.zeros(num_links, nfull, nfull).to(device, dtype=torch.float32)
    for idx, cost in enumerate(M):
        rows, cols = cost.shape[0], cost.shape[1]
        K[idx, :rows, :cols] = torch.exp(cost / -lbd)
        K[idx, rows:, cols:] = 1 / nfull
    # Transposed view of the kernels; K is not modified after this point.
    K_t = K.transpose(1, 2)

    # Scalings and marginal table --------------------------------------
    u = torch.ones(num_links, nfull, device=device, dtype=torch.float32) / nfull
    v = u.clone()
    a = torch.ones(num_marginals, nfull, device=device, dtype=torch.float32)
    a[0, :s.shape[0]] = s
    a[-1, :t.shape[0]] = t

    iter_num = 0
    for ii in range(numItermax + 1):
        # u and v are rebound (not mutated) below, so these keep the old values.
        prev_u, prev_v = u, v

        # Scaling updates; the statement order is significant.
        Kv = torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        u = a[:-1] / Kv
        KTu = torch.bmm(K_t, u.unsqueeze(2)).squeeze(2)
        v = a[1:] / KTu
        a[1:-1] = torch.pow(u[1:] * v[:-1], -lbd / tau)

        # On NaN, fall back to the last finite scalings and stop.
        if torch.isnan(u).any() or torch.isnan(v).any():
            u, v = prev_u, prev_v
            print("Numerical error at iteration {}".format(ii))
            break

        # NOTE: the log-change norms are only reported; no early stopping
        # is performed on them.
        if ii % 5000 == 0:
            error_u = _masked_log_change(u, prev_u, n_mask[:-1])
            error_v = _masked_log_change(v, prev_v, n_mask[1:])
            print("{}\t{} {}".format(ii, error_u, error_v))

        iter_num += 1

    # Assemble the padded plans and crop each one to its real size.
    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    P = [
        P_compute[i, :rows, :cols]
        for i, (rows, cols) in enumerate(zip(n[:-1], n[1:]))
    ]

    return P, {'iter': iter_num}


def multi_sinkhorn_single(s, t, M, lbd, numItermax=1000):
    """Iterate chained Sinkhorn-style scalings with geometric-mean marginals.

    Each cost matrix ``M[i]`` becomes a kernel ``exp(-M[i] / lbd)`` (the
    exponential is evaluated in double precision and stored as float32) in
    the top-left corner of a zero-initialised ``nfull x nfull`` slab, where
    ``nfull`` is the largest marginal size plus one. The part of the slab
    past both the matrix's rows and its columns is set to ``1 / nfull``.

    A marginal table ``b`` (one row per marginal, initialised to ones) holds
    ``s`` in the leading entries of its first row and ``t`` in the leading
    entries of its last row. On every iteration the intermediate rows are
    overwritten with ``sqrt((K[:-1]^T u[:-1]) * (K[1:] v[1:]))``.

    Args:
        s: Weights of the first marginal (array-like or tensor); squeezed
            before use.
        t: Weights of the last marginal (array-like or tensor); squeezed
            before use.
        M: Sequence of 2-D cost tensors/arrays. The code assumes they chain,
            i.e. the column count of ``M[i]`` equals the row count of
            ``M[i+1]`` and the last column count equals ``len(t)``; this is
            not validated here.
        lbd: Divisor of the costs inside the exponential kernel.
        numItermax: The loop body runs ``numItermax + 1`` times unless a NaN
            is detected first.

    Returns:
        tuple: ``(P, log)`` where ``P[i]`` is ``u[i][:, None] * K[i] *
        v[i][None, :]`` cropped to ``n[i] x n[i+1]``, and ``log['iter']`` is
        the number of iterations that completed without a NaN.
    """
    s, t = list_to_array(s, t)
    s = s.squeeze()
    t = t.squeeze()
    M = [list_to_array(cost) for cost in M]

    # Size of every marginal: row count of each cost matrix, then len(t).
    n = [cost.shape[0] for cost in M]
    n.append(t.shape[0])
    num_marginals = len(n)
    num_links = num_marginals - 1
    nfull = max(n) + 1

    # Integer 0/1 mask flagging the real (unpadded) entries of each marginal.
    n_mask = torch.tensor(
        [[int(col < size) for col in range(nfull)] for size in n]
    ).to(device)

    def _masked_log_change(current, previous, mask):
        # Norm of the change in log-scalings, restricted to unpadded entries.
        return torch.linalg.norm(
            (torch.log(current) - torch.log(previous)) * mask
        ).item()

    # Padded kernels ---------------------------------------------------
    K = torch.zeros(num_links, nfull, nfull).to(device, dtype=torch.float32)
    for idx, cost in enumerate(M):
        rows, cols = cost.shape[0], cost.shape[1]
        K[idx, :rows, :cols] = torch.exp(cost.double() / -lbd)
        K[idx, rows:, cols:] = 1 / nfull

    # Scalings and marginal table --------------------------------------
    u = torch.ones(num_links, nfull, device=device, dtype=torch.float32) / nfull
    v = u.clone()
    b = torch.ones(num_marginals, nfull).to(device, dtype=torch.float32)
    b[0, :s.shape[0]] = s
    b[-1, :t.shape[0]] = t

    iter_num = 0
    for ii in range(numItermax + 1):
        # u and v are rebound (not mutated) below, so these keep the old values.
        prev_u, prev_v = u, v

        # Scaling updates; both use the marginals from before this
        # iteration's refresh of b, so the statement order is significant.
        u = b[:-1] / torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        v = b[1:] / torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2)

        # Refresh the intermediate marginals from the neighbouring links.
        KTu = torch.bmm(K[:-1].transpose(1, 2), u[:-1].unsqueeze(2)).squeeze(2)
        Kv = torch.bmm(K[1:], v[1:].unsqueeze(2)).squeeze(2)
        b[1:-1] = torch.sqrt(KTu * Kv)

        # On NaN, fall back to the last finite scalings and stop.
        if torch.isnan(u).any() or torch.isnan(v).any():
            u, v = prev_u, prev_v
            print("Numerical error at iteration {}".format(ii))
            break

        # NOTE: the log-change norms are only reported; no early stopping
        # is performed on them.
        if ii % 1000 == 0:
            error_u = _masked_log_change(u, prev_u, n_mask[:-1])
            error_v = _masked_log_change(v, prev_v, n_mask[1:])
            print("{}\t{} {}".format(ii, error_u, error_v))

        iter_num += 1

    # Assemble the padded plans and crop each one to its real size.
    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    P = [
        P_compute[i, :rows, :cols]
        for i, (rows, cols) in enumerate(zip(n[:-1], n[1:]))
    ]
    print("P shape:", [Pi.shape for Pi in P])

    _log = {'iter': iter_num}

    # Drop the local references to the large work tensors, then ask the CUDA
    # caching allocator to release unused blocks.
    del u, v, K, P_compute
    torch.cuda.empty_cache()

    return P, _log


def list_to_array(*lst):
    """Return the arguments as torch tensors placed on the active device.

    Arguments that are already tensors are moved with ``.to(device)``;
    anything else is first wrapped with ``torch.tensor``. The device is CUDA
    when available, otherwise the CPU.

    Returns:
        A list of tensors when more than one argument is given, otherwise
        the single converted tensor itself (not wrapped in a list).
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _on_target(item):
        # Existing tensors are only moved; other inputs are converted first.
        if isinstance(item, torch.Tensor):
            return item.to(target)
        return torch.tensor(item).to(target)

    if len(lst) > 1:
        return [_on_target(item) for item in lst]
    return _on_target(lst[0])


def sinkhorn_distance(M, P, addition=None):
    """Accumulate ``sum(P[k] * M[k])`` over all ``k``, optionally weighted.

    Every entry of ``M`` and ``P`` is brought to the CPU first; entries that
    are not tensors are converted with ``torch.from_numpy``.

    Args:
        M: Sequence of cost matrices (tensors or NumPy arrays).
        P: Sequence of matrices (tensors or NumPy arrays) multiplied
            elementwise with the matching entry of ``M``.
        addition: Optional per-term weights indexed by ``k`` (list, NumPy
            array, tensor, ...). When ``None`` the terms are summed
            unweighted.

    Returns:
        The accumulated value: a 0-d tensor once at least one term has been
        added, otherwise the int ``0``. A term that raises (for example on a
        shape mismatch) is reported with ``print`` and left out of the sum.
    """
    M = [Mi.cpu() if isinstance(Mi, torch.Tensor) else torch.from_numpy(Mi).cpu() for Mi in M]
    P = [Pi.cpu() if isinstance(Pi, torch.Tensor) else torch.from_numpy(Pi).cpu() for Pi in P]

    # Identity test on purpose: ``addition == None`` is evaluated elementwise
    # for NumPy arrays, and using that result in ``if`` raises for arrays
    # with more than one element, which silently zeroed the whole sum.
    weighted = addition is not None

    distance = 0
    for k, Pk in enumerate(P):
        try:
            term = torch.sum(Pk * M[k])
            if weighted:
                term = term * addition[k]
            distance += term
        except Exception as e:
            # Best effort: report the failing term and keep accumulating.
            print("Error at {}-th: {}".format(k, e))
            print("Please check shape: P[k] {}, M[k] {}".format(P[k].shape, M[k].shape))
    return distance


if __name__ == '__main__':
    # Small demo run: a 3-point first marginal, a 2-point last marginal and
    # two cost matrices (3x3 then 3x2) linking them.
    s = np.array([0.2, 0.6, 0.2])
    t = np.array([0.1, 0.9])
    M = [
        torch.tensor([[1, 0.1, 1], [0.1, 1, 1], [1, 1, 0.1]], dtype=torch.float32),
        torch.tensor([[1, 0.1], [1, 0.1], [0.1, 1]], dtype=torch.float32),
    ]

    # Largest cost over all matrices, repeated once per matrix.
    # record_max is not read again in this script.
    _max_num = max(Mi.max() for Mi in M)
    record_max = [_max_num] * len(M)
    # Scale every matrix by its own maximum so each has a largest entry of 1.
    M = [Mi / Mi.max() for Mi in M]

    T, log = multi_sinkhorn_single(s, t, M, 0.001, numItermax=1000)
    dis = sinkhorn_distance(M, T)

    print("iter: {}".format(log['iter']))
    print("distance: {}".format(dis))
    for i, plan in enumerate(T):
        print("T[{}]: {}".format(i, plan))

