import numpy as np
from math import isclose
import ctypes
import os

import torch

print("import ot_solver.py")


# Locations of the compiled C++ helpers.  The defaults are the original
# hard-coded paths; set OT_SOLVER_DP_LIB / OT_SOLVER_FLOW_LIB to point at the
# libraries on other machines.
dp_dll_path = os.path.abspath(
    os.environ.get("OT_SOLVER_DP_LIB", "/home_new/shiliangliang/taorui/base/cpp/sp.so"))
flow_dll_path = os.path.abspath(
    os.environ.get("OT_SOLVER_FLOW_LIB", "/home_new/shiliangliang/taorui/base/cpp/flow.so"))
assert os.path.exists(dp_dll_path), f"dp.so file not exists at {dp_dll_path}"
# Check the flow library itself (this line previously re-checked dp_dll_path).
assert os.path.exists(flow_dll_path), f"flow.so file not exists at {flow_dll_path}"

handle = True
if not handle:
    assert False
else:
    # -------------------------------------------------
    # Pointer types (one, two and three levels) used by the C interfaces.
    DoublePtr = ctypes.POINTER(ctypes.c_double)
    DoublePtrPtr = ctypes.POINTER(DoublePtr)
    DoublePtrPtrPtr = ctypes.POINTER(DoublePtrPtr)
    IntPtr = ctypes.POINTER(ctypes.c_int)
    IntPtrPtr = ctypes.POINTER(IntPtr)
    IntPtrPtrPtr = ctypes.POINTER(IntPtrPtr)

    # mtype -> (numpy dtype with the same layout as the C element type,
    #           element pointer type, row-pointer type).
    # np.intc is the numpy alias of C `int` (ctypes.c_int).
    _CTYPES_LAYOUT = {
        'double': (np.float64, DoublePtr, DoublePtrPtr),
        'int': (np.intc, IntPtr, IntPtrPtr),
    }


    def _as_c_buffer(arr, mtype):
        """Return ``arr`` as a C-contiguous array of the C element type named by ``mtype``.

        Also returns the matching (element pointer, row pointer) ctypes types.
        No copy is made when ``arr`` already has that dtype and is C-contiguous.
        Otherwise the data is copied, e.g. numpy's default int64 is narrowed to
        C ``int``. In that case C code writing through the resulting pointers
        modifies the copy, not ``arr``.

        Raises:
            ValueError: if ``mtype`` is neither 'double' nor 'int'.
        """
        try:
            dtype, elem_ptr, row_ptr = _CTYPES_LAYOUT[mtype]
        except KeyError:
            raise ValueError(
                f"unsupported mtype {mtype!r}; expected 'double' or 'int'") from None
        return np.ascontiguousarray(arr, dtype=dtype), elem_ptr, row_ptr


    def numpy_to_ctypes_3d(arr, mtype):
        """Build a ``T***`` view (``T`` = double or int) of a 3-D array.

        The result is an array of ``dim1`` pointers. Each of those points to an
        array of ``dim2`` pointers, one to the start of each innermost row
        ``arr[i, j, :]``.
        """
        buf, elem_ptr, row_ptr = _as_c_buffer(arr, mtype)
        dim1, dim2, _ = buf.shape  # enforces a 3-D input
        top_level = (row_ptr * dim1)()
        for i in range(dim1):
            second_level = (elem_ptr * dim2)()
            for j in range(dim2):
                second_level[j] = buf[i, j].ctypes.data_as(elem_ptr)
            # ctypes keeps `second_level` alive inside `top_level` on assignment.
            top_level[i] = second_level
        # The raw row pointers do not own `buf` (which may be a fresh copy),
        # so pin it to the returned object for as long as that object lives.
        top_level._keepalive = buf
        return top_level


    def numpy_to_ctypes_2d(arr, mtype):
        """Build a ``T**`` view (``T`` = double or int): one pointer per ``arr[i]``."""
        buf, elem_ptr, _ = _as_c_buffer(arr, mtype)
        arr_ptrs = (elem_ptr * buf.shape[0])()
        for i in range(buf.shape[0]):
            arr_ptrs[i] = buf[i].ctypes.data_as(elem_ptr)
        # Same lifetime concern as in numpy_to_ctypes_3d.
        arr_ptrs._keepalive = buf
        return arr_ptrs


    def numpy_to_ctypes_1d(arr, mtype):
        """Return a ``T*`` (``T`` = double or int) to the start of ``arr``'s data."""
        buf, elem_ptr, _ = _as_c_buffer(arr, mtype)
        # ndarray.ctypes.data_as() makes the returned pointer hold a reference
        # to `buf`, so a converted copy stays alive with the pointer.
        return buf.ctypes.data_as(elem_ptr)


    # Load the libraries and declare argument/return types of the C functions.
    dp_lib = ctypes.CDLL(dp_dll_path)
    dp_lib.create_short_path.argtypes = [DoublePtrPtrPtr, IntPtr, ctypes.c_int, ctypes.c_int, DoublePtrPtr]
    dp_lib.create_short_path.restype = None
    flow_lib = ctypes.CDLL(flow_dll_path)
    flow_lib.minCostFlow.argtypes = [DoublePtrPtrPtr, DoublePtr, DoublePtr, IntPtr, ctypes.c_int]
    flow_lib.minCostFlow.restype = ctypes.c_double
    # -------------------------------------------------


def fill_M(_M, nfull):
    """Zero-pad every 2-D matrix in ``_M`` to ``(nfull, nfull)`` and stack them.

    Padding is added after the existing rows and columns (bottom/right).
    The result has shape ``(len(_M), nfull, nfull)`` and the dtype numpy infers
    from the inputs. A matrix larger than ``nfull`` in either dimension makes
    ``np.pad`` raise.
    """
    padded = []
    for block in _M:
        extra_rows = nfull - block.shape[0]
        extra_cols = nfull - block.shape[1]
        padded.append(np.pad(block, ((0, extra_rows), (0, extra_cols)), 'constant'))
    return np.array(padded)


def cpp_create_short_path(_M, _n):
    """Run the C++ ``create_short_path`` routine on a layered cost structure.

    Args:
        _M: sequence of 2-D cost matrices, presumably one per pair of
            consecutive layers (``len(_n) - 1`` of them — TODO confirm against
            the C code). Each is zero-padded to ``(nfull, nfull)``.
        _n: sequence of ints, the number of nodes in each layer.

    Returns:
        ndarray of shape ``(_n[0], _n[-1])``: the part of the buffer that the
        C routine fills in place.
    """
    # The C side expects `int*` and `double***`. Build buffers of exactly those
    # element types: np.array of Python ints is int64 on Linux, which C would
    # misread as pairs of 32-bit ints (e.g. [5, 5, 4] -> 5, 0, 5).
    _n = np.ascontiguousarray(np.array(_n).reshape(-1), dtype=np.intc)
    nfull = int(_n.max())
    layer = len(_n)
    _M = np.ascontiguousarray(fill_M(_M, nfull), dtype=np.float64)
    # Output buffer; the C code writes into it in place through row pointers.
    short_path = np.ones((int(_n[0]), nfull), dtype=np.float64)

    M_ptr = numpy_to_ctypes_3d(_M, 'double')
    n_ptr = numpy_to_ctypes_1d(_n, 'int')
    short_path_ptr = numpy_to_ctypes_2d(short_path, 'double')
    dp_lib.create_short_path(M_ptr, n_ptr, layer, nfull, short_path_ptr)

    return short_path[:, :int(_n[-1])]


def cpp_mincost_flow(_M, _n, _s, _t):
    """Run the C++ ``minCostFlow`` routine on a layered cost structure.

    Args:
        _M: sequence of 2-D cost matrices, each zero-padded to ``(nfull, nfull)``.
        _n: sequence of ints, the number of nodes in each layer.
        _s: source weights (array-like of floats), presumably of length ``_n[0]``.
        _t: target weights (array-like of floats), presumably of length ``_n[-1]``.

    Returns:
        float: the value returned by the C function (its restype is c_double).
    """
    # Match the C element types exactly: `int*` for the layer sizes (numpy's
    # default int64 would be misread) and contiguous `double` buffers elsewhere.
    _n = np.ascontiguousarray(np.array(_n).reshape(-1), dtype=np.intc)
    nfull = int(_n.max())
    layer = len(_n)
    _M = np.ascontiguousarray(fill_M(_M, nfull), dtype=np.float64)
    _s = np.ascontiguousarray(_s, dtype=np.float64)
    _t = np.ascontiguousarray(_t, dtype=np.float64)

    n_ptr = numpy_to_ctypes_1d(_n, 'int')
    source_ptr = numpy_to_ctypes_1d(_s, 'double')
    target_ptr = numpy_to_ctypes_1d(_t, 'double')
    M_ptr = numpy_to_ctypes_3d(_M, 'double')

    return flow_lib.minCostFlow(M_ptr, source_ptr, target_ptr, n_ptr, layer)


def list_to_array(*lst):
    r"""Convert each positional argument with ``np.array``.

    A single argument returns a single ndarray; two or more return a list of
    ndarrays in argument order. Calling with no arguments raises IndexError.
    """
    if len(lst) <= 1:
        return np.array(lst[0])
    return list(map(np.array, lst))


def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
                   verbose=False, log=False, warn=False, **kwargs):
    r"""Entropic-regularized optimal transport via Sinkhorn-Knopp matrix scaling.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        Source marginal.
    b : array-like, shape (dim_b,) or (dim_b, n_hists)
        Target marginal. When 2-D, each column is a separate target and only
        the per-column transport losses are returned.
    M : array-like, shape (dim_a, dim_b)
        Ground cost matrix.
    reg : float
        Regularization strength; the Gibbs kernel is ``K = exp(-M / reg)``.
    numItermax : int
        Maximum number of scaling iterations.
    stopThr : float
        Stop once the L2 violation of the target marginal is below this value
        (checked every 10 iterations).
    verbose : bool
        Print the marginal error at each check.
    log : bool
        If true, also return a dict with keys 'niter', 'u', 'v' and 'err'
        (errors recorded at each check).
    warn : bool
        Unused; accepted for call compatibility.
    **kwargs
        Ignored.

    Returns
    -------
    ndarray
        The transport plan ``diag(u) K diag(v)`` of shape (dim_a, dim_b), or
        the losses of shape (n_hists,) when ``b`` is 2-D.
    dict
        Only when ``log`` is true.
    """
    # np.array also accepts Python lists and CPU torch tensors.
    a, b, M = (np.array(x) for x in (a, b, M))

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]
    n_hists = b.shape[1] if len(b.shape) > 1 else 0

    # Keep the flag separate from the dict. Rebinding `log` to an empty dict
    # made it falsy, so the log was never filled or returned.
    return_log = bool(log)
    log_dict = {'err': []}

    if n_hists:
        u = np.ones((dim_a, n_hists), dtype=np.float64) / dim_a
        v = np.ones((dim_b, n_hists), dtype=np.float64) / dim_b
    else:
        u = np.ones(dim_a, dtype=np.float64) / dim_a
        v = np.ones(dim_b, dtype=np.float64) / dim_b

    K = np.exp(M / (-reg))
    # Rows of K pre-scaled by 1/a, so `1 / Kp @ v` equals `a / (K @ v)`.
    Kp = (1 / a).reshape(-1, 1) * K

    ii = 0  # defined even if the loop body never runs (numItermax <= 0)
    for ii in range(numItermax):
        uprev = u
        vprev = v
        KtransposeU = np.dot(K.T, u)
        v = b / KtransposeU
        u = 1. / np.dot(Kp, v)

        if (np.any(KtransposeU == 0)
                or np.any(np.isnan(u)) or np.any(np.isnan(v))
                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errors at iteration {}'.format(ii))
            u = uprev
            v = vprev
            break
        if ii % 10 == 0:
            # Column sums of diag(u) K diag(v), compared against b.
            if n_hists:
                tmp2 = np.einsum('ik,ij,jk->jk', u, K, v)
            else:
                tmp2 = np.einsum('i,ij,j->j', u, K, v)
            err = np.linalg.norm(tmp2 - b)  # violation of marginal
            log_dict['err'].append(err)

            if err < stopThr:
                break
            if verbose:
                if ii % 200 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
    else:
        print("Sinkhorn did not converge after {} iterations.".format(numItermax))

    if return_log:
        log_dict['niter'] = ii
        log_dict['u'] = u
        log_dict['v'] = v

    if n_hists:  # return only loss
        res = np.einsum('ik,ij,jk,ij->k', u, K, v, M)
    else:  # return OT matrix
        res = u.reshape((-1, 1)) * K * v.reshape((1, -1))

    return (res, log_dict) if return_log else res


def solve(s, t, short_path, reg=1e-2, max_iter=1000):
    """Return the Sinkhorn transport plan between ``s`` and ``t`` under cost ``short_path``.

    ``s`` and ``t`` are flattened with ``.reshape(-1)`` (numpy arrays or
    tensors), then passed to ``sinkhorn_knopp`` with regularization ``reg``
    and at most ``max_iter`` iterations.
    """
    return sinkhorn_knopp(s.reshape(-1), t.reshape(-1), short_path, reg,
                          numItermax=max_iter)


if __name__ == '__main__':
    # Toy problem: three layers with 5, 5 and 4 nodes, linked by two cost matrices.
    n = [5, 5, 4]
    cost_rows = (
        [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 2, 3, 4, 5]],
        [[1, 2, 3, 4], [5, 4, 3, 2], [1, 2, 3, 5], [5, 4, 3, 1], [1, 2, 3, 5]],
    )
    # Scale the integer costs 1..5 into [0.2, 1].
    M = [np.array(rows, dtype=np.float64) / 5 for rows in cost_rows]
    s = np.array([0.1, 0.3, 0.3, 0.2, 0.1])  # weights on the first layer (5 nodes)
    t = np.array([0.1, 0.1, 0.5, 0.3])       # weights on the last layer (4 nodes)

    # C++ short-path costs (first layer x last layer), then Sinkhorn OT on them.
    sp = cpp_create_short_path(M, n)
    print(f"cpp short_path: {sp}")
    plan = solve(s, t, sp, 1e-2)
    print(f"cpp SP+OT: {np.sum(plan * sp)}")

    # C++ min-cost-flow solver on the same layered costs.
    print(f"cpp mincost flow: {cpp_mincost_flow(M, n, s, t)}")
