# Detailed version of the MLOT Sinkhorn solver.
# It records detailed information during the iterations (every 100 iterations),
# such as the transport cost and the scaling error.
#
# Assumes a 3-layer MLOT problem,
# so only the second (middle) layer's history is recorded.

import numpy as np
import torch

# Run every solver on the GPU when one is visible, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[MLOT-detail Work on {device}]")


def multi_sinkhorn(s, t, M, lbd, tau, radio, numItermax=1000, stopThr=None):
    """Multi-layer Sinkhorn-style scaling solver that records per-iteration details.

    All layers are padded to a common size ``nfull = max(n) + 1`` so that the
    scalings of every coupling can be updated together with batched matmuls.
    The first and last marginals are fixed to ``s`` and ``t``; the
    intermediate marginals are re-computed every iteration as
    ``a[j] = (u[j] * v[j - 1]) ** (-lbd / tau)``.

    Every 100 iterations the following are appended to the log:
      * ``cost_ii``      -- cost of the current plans, multiplied by ``radio``;
      * ``error_uv``     -- L2 change of log(u) plus log(v) since the previous
                            iteration, restricted to unpadded entries;
      * ``second_layer`` -- column sums of the first plan.

    Args:
        s: weights of the first layer (anything ``list_to_array`` accepts).
        t: weights of the last layer.
        M: list of cost matrices; ``M[i]`` links layer ``i`` to layer ``i + 1``
            (presumably ``M[i].shape[1] == M[i + 1].shape[0]``; not checked).
        lbd: kernel temperature, ``K = exp(-M / lbd)``.
        tau: appears in the intermediate-marginal exponent ``-lbd / tau``.
        radio: factor applied to the logged cost. Either a scalar, or a
            list/tuple with one factor per cost matrix (e.g. the values the
            matrices were normalised by).
        numItermax: maximum number of iterations.
        stopThr: optional tolerance. When given, iteration stops at the first
            check (every 100 iterations) where ``error_uv < stopThr``, and a
            warning is printed if that never happens. With the default
            ``None`` all ``numItermax`` iterations are run.

    Returns:
        ``(P, log)``: ``P[i]`` is the ``(n[i], n[i + 1])`` plan between layers
        ``i`` and ``i + 1``; ``log`` holds the lists above plus ``iter``, the
        number of iterations performed.

    Raises:
        ValueError: if ``radio`` is a list/tuple whose length differs from ``len(M)``.
    """
    s, t = list_to_array(s, t)
    s, t = s.squeeze(), t.squeeze()
    M = [list_to_array(Mi) for Mi in M]

    # `radio` is either one factor for the whole cost or one factor per layer.
    per_layer_radio = isinstance(radio, (list, tuple))
    if per_layer_radio and len(radio) != len(M):
        raise ValueError(
            "radio has {} entries but M has {} layers".format(len(radio), len(M)))

    def _scaled_cost(P_layers):
        # Transport cost of `P_layers` under `M`, multiplied by `radio`.
        if per_layer_radio:
            return sum(r * sinkhorn_distance([Mk], [Pk])
                       for r, Mk, Pk in zip(radio, M, P_layers))
        return radio * sinkhorn_distance(M, P_layers)

    _log = {
        'cost_ii': [],
        'error_uv': [],
        'second_layer': [],
    }

    # Layer sizes; the last layer takes its size from t.
    n = [Mi.shape[0] for Mi in M] + [t.shape[0]]
    # Common padded size; the "+1" guarantees every layer has a padded slot.
    nfull = max(n) + 1
    # n_mask[i, j] == 1 iff slot j of layer i is a real (unpadded) entry.
    n_mask = torch.tensor([
        [1 if col < n[row] else 0 for col in range(nfull)] for row in range(len(n))
    ]).to(device)

    # Initialization K, u, v, a ------------------------------------
    # Top-left block of K[i] is exp(-M[i] / lbd). The padded bottom-right block
    # is 1/nfull, so padded rows/columns never sum to zero in the divisions below.
    K = torch.zeros(len(n)-1, nfull, nfull).to(device, dtype=torch.float64)
    for i in range(len(n)-1):
        _shape = M[i].shape
        K[i, :_shape[0], :_shape[1]] = torch.exp(M[i] / -lbd)
        K[i, _shape[0]:, _shape[1]:] = 1/nfull

    # u[i], v[i]: left/right scalings of the coupling between layers i and i+1.
    u = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    v = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    # a[j]: marginal of layer j; padded slots of the fixed outer layers stay 1.
    a = torch.ones(len(n), nfull, device=device, dtype=torch.float64)
    a[0, :s.shape[0]] = s
    a[-1, :t.shape[0]] = t

    iter_num = 0

    for ii in range(numItermax):
        u_pre = u
        v_pre = v

        # Update ---------------------------------------------------
        Kv = torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        u = a[:-1] / Kv
        KTu = torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2)
        v = a[1:] / KTu
        # Intermediate marginals: a[j] = (u[j] * v[j-1]) ** (-lbd / tau).
        a[1:-1] = torch.pow(u[1:] * v[:-1], -lbd/tau)
        # ----------------------------------------------------------
        iter_num += 1

        if ii % 100 == 0:
            # Change of the log-scalings, ignoring padded slots.
            error_u = torch.linalg.norm((torch.log(u)-torch.log(u_pre))*n_mask[:-1]).item()
            error_v = torch.linalg.norm((torch.log(v)-torch.log(v_pre))*n_mask[1:]).item()
            error = error_u + error_v

            P_compute_ii = u.unsqueeze(2) * K * v.unsqueeze(1)
            P_ii = [P_compute_ii[i, :n[i], :n[i + 1]] for i in range(len(n) - 1)]
            # Column sums of the first plan = mass on the second layer.
            second = torch.sum(P_ii[0], dim=0)

            _log['cost_ii'].append(_scaled_cost(P_ii))
            _log['error_uv'].append(error)
            _log['second_layer'].append(second)

            if stopThr is not None and error < stopThr:
                break
    else:
        # Reached only when the loop was not broken; the warning is meaningful
        # only when a convergence threshold was actually requested.
        if stopThr is not None:
            print("multi_sinkhorn not converged after {} iterations".format(numItermax))

    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    P = [P_compute[i, :n[i], :n[i+1]] for i in range(len(n)-1)]

    _log['iter'] = iter_num
    return P, _log


def multi_sinkhorn_single(s, t, M, lbd, radio, numItermax=1000, stopThr=None):
    """Variant of ``multi_sinkhorn`` with a geometric-mean intermediate update.

    Same padding, scaling updates and logging as ``multi_sinkhorn``, except
    that the intermediate marginals are updated as
    ``b[j] = sqrt((K[j-1]^T u[j-1]) * (K[j] v[j]))`` and no ``tau`` is used.

    Every 100 iterations ``cost_ii`` (``radio``-scaled cost), ``error_uv``
    (L2 change of log(u) plus log(v) on unpadded entries) and ``second_layer``
    (column sums of the first plan) are appended to the log. The ``'P'`` key
    is created empty and is never populated by this function.

    Args:
        s: weights of the first layer (anything ``list_to_array`` accepts).
        t: weights of the last layer.
        M: list of cost matrices; ``M[i]`` links layer ``i`` to layer ``i + 1``.
        lbd: kernel temperature, ``K = exp(-M / lbd)``.
        radio: factor applied to the logged cost. Either a scalar, or a
            list/tuple with one factor per cost matrix.
        numItermax: maximum number of iterations.
        stopThr: optional tolerance. When given, iteration stops at the first
            check (every 100 iterations) where ``error_uv < stopThr``, and a
            warning is printed if that never happens. With the default
            ``None`` all ``numItermax`` iterations are run.

    Returns:
        ``(P, log)``: ``P[i]`` is the ``(n[i], n[i + 1])`` plan between layers
        ``i`` and ``i + 1``; ``log`` also holds ``iter``, the iteration count.

    Raises:
        ValueError: if ``radio`` is a list/tuple whose length differs from ``len(M)``.
    """
    s, t = list_to_array(s, t)
    s, t = s.squeeze(), t.squeeze()
    M = [list_to_array(Mi) for Mi in M]

    # `radio` is either one factor for the whole cost or one factor per layer.
    per_layer_radio = isinstance(radio, (list, tuple))
    if per_layer_radio and len(radio) != len(M):
        raise ValueError(
            "radio has {} entries but M has {} layers".format(len(radio), len(M)))

    def _scaled_cost(P_layers):
        # Transport cost of `P_layers` under `M`, multiplied by `radio`.
        if per_layer_radio:
            return sum(r * sinkhorn_distance([Mk], [Pk])
                       for r, Mk, Pk in zip(radio, M, P_layers))
        return radio * sinkhorn_distance(M, P_layers)

    _log = {
        'cost_ii': [],
        'error_uv': [],
        'second_layer': [],
        'P': [],
    }

    # Layer sizes; the last layer takes its size from t.
    n = [Mi.shape[0] for Mi in M] + [t.shape[0]]
    # Common padded size; the "+1" guarantees every layer has a padded slot.
    nfull = max(n) + 1
    # n_mask[i, j] == 1 iff slot j of layer i is a real (unpadded) entry.
    n_mask = torch.tensor([
        [1 if col < n[row] else 0 for col in range(nfull)] for row in range(len(n))
    ]).to(device)

    # Initialization K, u, v, b ------------------------------------
    # Top-left block of K[i] is exp(-M[i] / lbd). The padded bottom-right block
    # is 1/nfull, so padded rows/columns never sum to zero in the divisions below.
    K = torch.zeros(len(n)-1, nfull, nfull).to(device, dtype=torch.float64)
    for i in range(len(n)-1):
        _shape = M[i].shape
        K[i, :_shape[0], :_shape[1]] = torch.exp(M[i] / -lbd)
        K[i, _shape[0]:, _shape[1]:] = 1/nfull

    u = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    v = torch.ones(len(n)-1, nfull, device=device, dtype=torch.float64) / nfull
    # b[j]: marginal of layer j; padded slots of the fixed outer layers stay 1.
    b = torch.ones(len(n), nfull).to(device, dtype=torch.float64)
    b[0, :s.shape[0]] = s
    b[-1, :t.shape[0]] = t
    # ---------------------------------------------------------------

    iter_num = 0

    for ii in range(numItermax):
        u_pre = u
        v_pre = v

        # Update ---------------------------------------------------
        u = b[:-1] / torch.bmm(K, v.unsqueeze(2)).squeeze(2)
        # K^T u is needed both for v and for the intermediate update; u does
        # not change in between, so compute it once.
        KTu = torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2)
        v = b[1:] / KTu
        Kv = torch.bmm(K[1:], v[1:].unsqueeze(2)).squeeze(2)
        b[1:-1] = torch.sqrt(KTu[:-1] * Kv)
        # ----------------------------------------------------------
        iter_num += 1

        if ii % 100 == 0:
            # Change of the log-scalings, ignoring padded slots.
            error_u = torch.linalg.norm((torch.log(u)-torch.log(u_pre))*n_mask[:-1]).item()
            error_v = torch.linalg.norm((torch.log(v)-torch.log(v_pre))*n_mask[1:]).item()
            error = error_u + error_v

            P_compute_ii = u.unsqueeze(2) * K * v.unsqueeze(1)
            P_ii = [P_compute_ii[i, :n[i], :n[i + 1]] for i in range(len(n) - 1)]
            # Column sums of the first plan = mass on the second layer.
            second = torch.sum(P_ii[0], dim=0)

            _log['cost_ii'].append(_scaled_cost(P_ii))
            _log['error_uv'].append(error)
            _log['second_layer'].append(second)

            if stopThr is not None and error < stopThr:
                break
    else:
        # Reached only when the loop was not broken; the warning is meaningful
        # only when a convergence threshold was actually requested.
        if stopThr is not None:
            print("multi_sinkhorn_single not converged after {} iterations".format(numItermax))

    P_compute = u.unsqueeze(2) * K * v.unsqueeze(1)
    P = [P_compute[i, :n[i], :n[i+1]] for i in range(len(n)-1)]

    _log['iter'] = iter_num

    return P, _log


def transform_ground_metric(ground_metric, area_index):
    """Cut the blocks linking consecutive layers out of a full ground-metric matrix.

    ``area_index[k]`` holds the half-open ``[start, stop)`` index range of
    layer ``k`` inside ``ground_metric``. The block for the pair ``(k, k + 1)``
    is ``ground_metric[start_k:stop_k, start_{k+1}:stop_{k+1}]``.

    Returns a list of ``len(area_index) - 1`` blocks. The list is empty when
    fewer than two ranges are given.
    """
    return [
        ground_metric[rows[0]:rows[1], cols[0]:cols[1]]
        for rows, cols in zip(area_index, area_index[1:])
    ]


def list_to_array(*lst):
    """Convert the arguments to torch tensors on the active device.

    Tensors are moved with ``.to``; anything else goes through
    ``torch.tensor`` first. With several arguments a list of tensors is
    returned; with exactly one argument the tensor itself is returned.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _convert(obj):
        if isinstance(obj, torch.Tensor):
            return obj.to(target)
        return torch.tensor(obj).to(target)

    if len(lst) > 1:
        return [_convert(obj) for obj in lst]
    return _convert(lst[0])


def sinkhorn_distance(M, P):
    """Return the total transport cost ``sum_k <P[k], M[k]>`` computed on the CPU.

    Entries of ``M`` and ``P`` may be torch tensors or numpy arrays. The
    result is a 0-d tensor, or the integer 0 when ``P`` is empty.
    """
    def _on_cpu(x):
        return x.cpu() if isinstance(x, torch.Tensor) else torch.from_numpy(x).cpu()

    cost_mats = [_on_cpu(m) for m in M]
    plans = [_on_cpu(p) for p in P]
    total = 0
    for k, plan in enumerate(plans):
        total += torch.sum(plan * cost_mats[k])
    return total


if __name__ == '__main__':
    # Toy 3-layer problem: 3 source points -> 3 middle points -> 2 target points.
    s = np.array([0.2, 0.6, 0.2])
    t = np.array([0.1, 0.9])
    M = [torch.tensor([[1, 0.1, 1], [0.1, 1, 1], [1, 1, 0.1]], dtype=torch.float64),
         torch.tensor([[1, 0.1], [1, 0.1], [0.1, 1]], dtype=torch.float64)]

    # Normalise every cost matrix by the same global maximum, so that
    # `scale * cost(normalised M)` equals the cost under the original M.
    # The solver multiplies the logged cost by `radio`, so it must receive this
    # scalar: a plain Python list cannot be multiplied by the tensor cost.
    scale = max(Mi.max() for Mi in M).item()
    M = [Mi / scale for Mi in M]

    T, log = multi_sinkhorn_single(s, t, M, 0.001, scale, numItermax=1000)
    dis = sinkhorn_distance(M, T)  # cost w.r.t. the normalised matrices
    print("iter: {}".format(log['iter']))
    print("distance: {}".format(dis))
    for i, Ti in enumerate(T):
        print("T[{}]: {}".format(i, Ti))

