import os
import copy
import time
import torch
import numpy as np
from tqdm import tqdm

import ot_solver
import gp_solver as gps
import mlot_simple as mlot


def _generate_line_M(_n):
    points = []
    _M = []
    for i in range(len(_n)):
        x = np.random.uniform(-1e-10, 1e-10, _n[i]) + 0.001*i
        y = np.random.uniform(-10, 10, _n[i])
        points.append(np.stack([x, y], axis=1))

        if i > 0:
            tmp = points[i-1][:, np.newaxis, :] - points[i]
            euclidean = np.sqrt(np.sum(tmp**2, axis=-1))
            _M.append(euclidean)
    return points, _M


def _generate_s_t(_n):
    _s = np.random.rand(_n[0])
    _t = np.random.rand(_n[-1])
    return _s / np.sum(_s), _t / np.sum(_t)


def generateLineData(_n):
    """Build a random 'line' benchmark instance.

    Args:
        _n: number of points per layer.

    Returns:
        dict with ``'M'`` (list of Euclidean cost matrices between consecutive
        layers), ``'source'`` and ``'target'`` (normalized weight vectors).
    """
    _, cost_matrices = _generate_line_M(_n)  # point coordinates are not needed here
    src, tgt = _generate_s_t(_n)
    data = dict(M=cost_matrices, source=src, target=tgt)
    print("Finish generate Line data")
    return data


def compo1(r, a):
    r"""First term of the spiral arc-length antiderivative.

    Returns :math:`\frac{r}{2}\sqrt{1 + a^2 r^2}`. The operation works
    element-wise on numpy inputs.
    """
    return 0.5 * r * np.sqrt(1 + a*a*r*r)
def compo2(r, a):
    r"""Second term of the spiral arc-length antiderivative.

    Returns :math:`\frac{\operatorname{asinh}(a r)}{2a}`, which is identical to
    :math:`\frac{1}{2a}\ln\left(\sqrt{1 + a^2 r^2} + a r\right)`.

    - ``np.arcsinh`` avoids the cancellation in ``sqrt(1 + x^2) + x`` for
      negative ``x``.
    - At ``a == 0`` the expression has a removable singularity. Its limit
      ``r / 2`` is returned there instead of dividing by zero.

    The function works element-wise on scalars or broadcastable arrays. Scalar
    inputs return a numpy scalar.
    """
    a_arr = np.asarray(a, dtype=np.float64)
    r_arr = np.asarray(r, dtype=np.float64)
    is_zero = a_arr == 0
    # Use a harmless divisor where a == 0; those entries are replaced by r/2 below.
    safe_a = np.where(is_zero, 1.0, a_arr)
    out = np.where(is_zero, 0.5 * r_arr, np.arcsinh(safe_a * r_arr) / (2 * safe_a))
    return out[()]  # unwrap 0-d results to a numpy scalar; arrays pass through
def helicalDist(p1, p2):
    r"""
    Return the Archimedean-spiral arc length between two points given in polar
    coordinates ``(r, theta)``. The angle ``theta`` is in radians.

    - An Archimedean spiral satisfies :math:`\frac{dr}{d\theta} = b`, i.e.
      :math:`r = b (\theta - \theta_0)`.
    - Substituting both points gives :math:`b = \frac{1}{\theta_2 - \theta_1}`.
      This assumes a unit radial step :math:`r_2 - r_1 = 1`, as for consecutive
      ring layers built in this file. The closed form uses
      :math:`a = 1/b = \theta_2 - \theta_1`.
    - The winding direction is not fixed. If the angles are more than
      :math:`\pi` apart, one of them is shifted by :math:`2\pi` so that
      :math:`|a| \le \pi`.

    When the two angles coincide (``a == 0``), the spiral degenerates to a
    radial segment and ``|r2 - r1|`` is returned, which is the limit of the
    closed form.

    The inputs are not modified.
    """
    r1, theta1 = p1[0], p1[1]
    r2, theta2 = p2[0], p2[1]
    if abs(theta2 - theta1) > np.pi:
        # Rebind rather than use +=, so array/tensor inputs are never mutated in place.
        if theta2 > theta1:
            theta1 = theta1 + 2 * np.pi  # clockwise
        else:
            theta2 = theta2 + 2 * np.pi  # counter-clockwise

    a = theta2 - theta1
    if a == 0:
        # Limit of the expression below as a -> 0; avoids 1/(2a) in compo2.
        return float(abs(r2 - r1))
    l = compo1(r2, a) - compo1(r1, a) + compo2(r2, a) - compo2(r1, a)
    return abs(l)
def _calculate_distance_matrices(layers):
    """Compute helical distances between every pair of points in consecutive layers.

    Args:
        layers: a list of layers. Each layer is a sequence of ``(r, theta)``
            points accepted by :func:`helicalDist`.

    Returns:
        A list of ``len(layers) - 1`` float64 torch tensors. Entry ``i`` has
        shape ``(len(layers[i]), len(layers[i+1]))`` with
        ``[j, k] = helicalDist(layers[i][j], layers[i+1][k])``.
    """
    _M = []
    for src, dst in zip(layers[:-1], layers[1:]):
        # Fill a numpy buffer row by row. Per-element torch __setitem__ is far
        # slower than this in an O(len(src) * len(dst)) loop.
        dist = np.empty((len(src), len(dst)), dtype=np.float64)
        for j, p in enumerate(tqdm(src)):
            dist[j] = np.fromiter((helicalDist(p, q) for q in dst),
                                  dtype=np.float64, count=len(dst))
        _M.append(torch.from_numpy(dist))
    return _M


def _generate_ring_M(_n):
    """Sample points on concentric rings and cache their helical distance matrices.

    Layer ``i`` gets ``_n[i]`` points ``(i + 1, angle)``, with each angle drawn
    uniformly from ``[0, 2*pi)``. The result of
    :func:`_calculate_distance_matrices` is saved to
    ``./results/cache/ring<sum(_n)>.pt``. The cache directory is created if it
    is missing.

    Returns:
        The list of distance tensors between consecutive layers.
    """
    angle_points = [
        [(i + 1, np.random.uniform(0, 2 * np.pi)) for _ in range(count)]
        for i, count in enumerate(_n)
    ]

    print("Begin cal helical M")
    _M = _calculate_distance_matrices(angle_points)
    cache_path = './results/cache/ring{}.pt'.format(sum(_n))
    # torch.save does not create parent directories.
    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
    torch.save(_M, cache_path)
    return _M


def generateRingData(_n):
    """Build a random 'ring' benchmark instance with helical (spiral) costs.

    The distance matrices are cached under ``./results/cache/ring<sum(_n)>.pt``.
    They are regenerated when the file is missing or when its shapes do not
    match ``_n``.

    Args:
        _n: number of points per layer.

    Returns:
        dict with ``'M'`` (list of cost tensors between consecutive layers),
        ``'source'`` and ``'target'`` (normalized weight vectors).
    """
    # Draw the marginals before any angle sampling. For a given seed, s/t are
    # then the same whether M comes from the cache or is regenerated.
    _s, _t = _generate_s_t(_n)

    cache_path = './results/cache/ring{}.pt'.format(sum(_n))
    expected_shapes = [(_n[i], _n[i + 1]) for i in range(len(_n) - 1)]
    _M = None
    if os.path.exists(cache_path):
        _M = torch.load(cache_path)
        # The cache key is only the total point count. A different layer split
        # with the same total must not reuse mismatched matrices.
        if [tuple(Mi.shape) for Mi in _M] != expected_shapes:
            _M = None
    if _M is None:
        _M = _generate_ring_M(_n)

    data = {
        'M': _M,
        'source': _s,
        'target': _t
    }
    print("Finish generate ring data")
    return data


def solve_with_pot_dp(M, n, s, t, reg, max_iter=10000):
    """Collapse the layered costs to shortest paths, then solve one Sinkhorn problem.

    The shortest-path cost matrix is scaled by its maximum before calling
    ``ot_solver.solve``. The reported distance is scaled back by that maximum.

    Returns:
        ``(T, distance, seconds)``. ``seconds`` covers the shortest-path
        computation and the solve, but not the final distance sum.
    """
    t_begin = time.perf_counter()
    paths = ot_solver.cpp_create_short_path(M, n)
    print("Calculated shortest path")
    scale = paths.max()
    normalized = paths / scale
    plan = ot_solver.solve(s, t, normalized, reg, max_iter)
    print("Finish POT sinkhorn")
    elapsed = time.perf_counter() - t_begin
    distance = scale * np.sum(plan * normalized)
    return plan, distance, elapsed


def solver_with_mincost_flow(M, n, s, t):
    """Solve the layered problem with ``ot_solver.cpp_mincost_flow`` and time it.

    Timing uses ``time.perf_counter`` (monotonic, high resolution), consistent
    with the other solver wrappers in this file.

    Returns:
        ``(mincost, seconds)``.
    """
    start = time.perf_counter()
    mincost = ot_solver.cpp_mincost_flow(M, n, s, t)
    elapsed = time.perf_counter() - start
    return mincost, elapsed


def solve_with_multi_sinkhorn(s, t, M, reg, numItermax):
    """Time ``mlot.multi_sinkhorn_single``; return ``(plan, log, seconds)``."""
    tic = time.perf_counter()
    plan, info = mlot.multi_sinkhorn_single(s, t, M, reg, numItermax)
    toc = time.perf_counter()
    return plan, info, toc - tic


def solve_with_multi_sinkhorn2(s, t, M, lbd, tau, numItermax):
    """Time ``mlot.multi_sinkhorn`` (``lbd``/``tau`` variant); return ``(plan, log, seconds)``."""
    tic = time.perf_counter()
    plan, info = mlot.multi_sinkhorn(s, t, M, lbd, tau, numItermax)
    toc = time.perf_counter()
    return plan, info, toc - tic


def solve_with_gurobi(gp_model, write=False):
    """Build and solve a Gurobi model, timing the whole pipeline.

    Args:
        gp_model: an object exposing ``create_variable``, ``initProblem`` and
            ``solve``, such as ``gps.MultilayerSolver``.
        write: forwarded to ``gp_model.initProblem``.

    Returns:
        ``(min_cost, variables, seconds)``. The first two values are what
        ``gp_model.solve()`` returns. ``seconds`` is measured with
        ``time.perf_counter``, consistent with the other solver wrappers.
    """
    start = time.perf_counter()
    gp_model.create_variable()
    gp_model.initProblem(write)
    min_cost, variables = gp_model.solve()
    elapsed = time.perf_counter() - start
    return min_cost, variables, elapsed


def _extract_gurobi_plans(variables, n, l):
    """Rebuild the per-layer transport plans from solved Gurobi variables.

    Only variables whose name starts with ``"P"`` are used. Each name is
    parsed as ``P<sep><k>[<i>,<j>]`` (one separator character after ``P``,
    e.g. ``P_0[3,4]`` — TODO confirm against gp_solver), and ``var.x`` is
    written to ``plans[k][i, j]``.

    Entries without a matching variable stay 0. A name that cannot be parsed
    is reported. It is skipped when its value is not positive; otherwise a
    ValueError is raised, because mass would be silently lost.
    """
    plans = [torch.zeros((n[k], n[k + 1]), dtype=torch.float64) for k in range(l - 1)]
    for var in variables:
        name = var.varName
        if not name.startswith("P"):
            continue
        try:
            head, _, index = name.partition("[")
            k = int(head[2:])  # supports multi-digit layer indices
            if not 0 <= k < l - 1:
                raise ValueError(f"index k should be in range({l-1})")
            i, j = map(int, index[:-1].split(","))
            plans[k][i, j] = var.x
        except (ValueError, IndexError, TypeError) as e:
            print(f"error: {e} at {name}")
            if var.x > 0:
                print(f"var.x: {var.x}")
                raise ValueError(f"cannot place positive plan variable {name}") from e
    return plans


def testTime(datatype, N, n, l, do_gurobi, do_dp, do_flow, do_mlot):
    """Generate one benchmark instance and run the selected solvers on it.

    Args:
        datatype: ``'line'`` selects :func:`generateLineData`; any other value
            selects :func:`generateRingData`.
        N: total number of points. It is only printed.
        n: number of points per layer.
        l: number of layers; must equal ``len(n)``.
        do_gurobi, do_dp, do_flow, do_mlot: flags choosing which solvers run.

    Returns:
        ``(results, gp_model)``. ``results`` holds the instance (``'M'``,
        ``'source'``, ``'target'``, ``'n'``) plus one entry per solver that
        ran. ``gp_model`` is the Gurobi model, or ``None`` if Gurobi was
        skipped.

    Raises:
        ValueError: if ``l != len(n)``.
    """
    print("+-----------------+")
    print("|       {}      |".format(datatype))
    print("+-----------------+")
    print("N={}, n={}".format(N, n))
    if l != len(n):
        raise ValueError("Layer number should == len(n)")

    results = {}

    data = generateLineData(n) if datatype == 'line' else generateRingData(n)
    _M = data['M']
    s, t = data['source'], data['target']
    results['M'] = _M
    results['source'] = s
    results['target'] = t
    results['n'] = n

    gp_model = None
    if do_gurobi:
        gp_model = gps.MultilayerSolver(s, t, n, _M)
        dis_gurobi, variables, time_gurobi = solve_with_gurobi(gp_model, write=False)
        P_gurobi = _extract_gurobi_plans(variables, n, l)
        results['gurobi'] = {'objective': dis_gurobi, 'time': time_gurobi, 'P': P_gurobi}

    if do_dp:
        T_dp, dis_dp, time_dp = solve_with_pot_dp(_M, n, s, t, reg=1e-3, max_iter=10000)
        print("dp distance:", dis_dp)
        results['dp'] = {'objective': dis_dp, 'time': time_dp}

    if do_flow:
        mincost, time_flow = solver_with_mincost_flow(_M, n, s, t)
        print("flow distance:", mincost)
        results['flow'] = {'objective': mincost, 'time': time_flow}

    if do_mlot:
        # The multi-layer solvers see costs scaled to a maximum of 1. Their
        # objectives are scaled back by _max_num.
        _max_num = max(Mi.max() for Mi in _M)
        M = [Mi / _max_num for Mi in _M]

        # Zero-iteration calls whose results are discarded. Presumably a
        # warm-up, so one-time setup cost is not counted in the timed runs.
        solve_with_multi_sinkhorn(s, t, M, 1, numItermax=0)
        solve_with_multi_sinkhorn2(s, t, M, 1, 1, numItermax=0)

        T, log, timer_multi = solve_with_multi_sinkhorn(s, t, M, 0.001, numItermax=20000)
        dis = _max_num * mlot.sinkhorn_distance(M, T).item()
        results['multi-ot'] = {
            'objective': dis,
            'time': timer_multi,
            'iter': log['iter'],
        }

        T2, log2, timer_multi2 = solve_with_multi_sinkhorn2(s, t, M, 0.001, 0.002, numItermax=20000)
        dis2 = _max_num * mlot.sinkhorn_distance(M, T2).item()
        results['multi-ot2'] = {
            'objective': dis2,
            'time': timer_multi2,
            'iter': log2['iter'],
        }

    return results, gp_model


SEED = 0

if __name__ == '__main__':
    np.random.seed(SEED)
    do_gurobi, do_dp, do_flow, do_mlot = False, False, False, True

    layer = 3
    n = [5000, 10000, 5000]
    total_num = sum(n)

    results, gp_model = testTime('line', total_num, n, layer, do_gurobi, do_dp, do_flow, do_mlot)

    # ------- Final result printout -------
    print("N={}, n={}".format(sum(n), n))
    summary = (
        (do_gurobi, "[Gurobi]:", 'gurobi'),
        (do_dp, "[DP]:", 'dp'),
        (do_flow, "[Flow]:", 'flow'),
        (do_mlot, "[MlOT single]:", 'multi-ot'),
        (do_mlot, "[MlOT2]:", 'multi-ot2'),
    )
    for enabled, label, key in summary:
        if not enabled:
            continue
        entry = results[key]
        print(label)
        print("\tobj = {:.4f}".format(entry['objective']))
        print("\ttime = {:.3f}".format(entry['time']))



