import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import os
import time
import torch
import math

import mlot_detail as mlot


def _generate_M(_n):
    points = []
    _M = []
    for i in range(len(_n)):
        x = np.random.uniform(-1e-10, 1e-10, _n[i]) + 0.001*i
        y = np.random.uniform(-10, 10, _n[i])
        points.append(np.stack([x, y], axis=1))

        if i > 0:
            tmp = points[i-1][:, np.newaxis, :] - points[i]
            euclidean = np.sqrt(np.sum(tmp**2, axis=-1))
            _M.append(euclidean)
    return points, _M


def _generate_s_t(_n):
    _s = np.random.rand(_n[0])
    _t = np.random.rand(_n[-1])
    return _s / np.sum(_s), _t / np.sum(_t)


def generateData(_n):
    """Build one synthetic multi-layer transport instance.

    Args:
        _n: sequence of per-layer point counts.

    Returns:
        dict with keys:
        ``'M'``: list of cost matrices between consecutive layers.
        ``'source'``: normalised weights on the first layer.
        ``'target'``: normalised weights on the last layer.
    """
    # Point coordinates are only needed to build the cost matrices; discard them.
    _, distances = _generate_M(_n)
    source, target = _generate_s_t(_n)
    return {'M': distances, 'source': source, 'target': target}


def _single_solve(_s, _t, _M, lbds, radio, numItermax):
    """Run one mlot solver configuration and attach the rescaled distance.

    Args:
        _s, _t: source and target weights.
        _M: list of (already normalised) cost matrices.
        lbds: ``(eps,)`` selects ``mlot.multi_sinkhorn_single``;
            ``(eps, tau)`` selects ``mlot.multi_sinkhorn``.
        radio: scale factor the costs were divided by; the distance is
            multiplied back by it.
        numItermax: iteration cap forwarded to the solver.

    Returns:
        The solver's log mapping, updated with ``'dis'``, the rescaled
        ``mlot.sinkhorn_distance`` of the returned plan.
    """
    if len(lbds) == 1:
        T, log = mlot.multi_sinkhorn_single(_s, _t, _M, lbds[0], radio, numItermax=numItermax)
    else:
        lbd, tau = lbds
        T, log = mlot.multi_sinkhorn(_s, _t, _M, lbd, tau, radio, numItermax=numItermax)

    # Both solvers are evaluated the same way, so compute the distance once.
    log.update(dis=radio * mlot.sinkhorn_distance(T, _M))
    return log


def solve(_s, _t, _M, test_lbds, numItermax):
    """Solve the instance once per parameter tuple in ``test_lbds``.

    All cost matrices are divided by the largest entry across every matrix
    before solving. That factor is passed to ``_single_solve`` so the reported
    distances are in the original units.

    Returns:
        dict mapping each tuple in ``test_lbds`` to its solver log.
    """
    scale = max(np.max(block) for block in _M)
    normalized = [block / scale for block in _M]
    return {
        lbds: _single_solve(_s, _t, normalized, lbds, scale, numItermax)
        for lbds in test_lbds
    }


def KL(p, q):
    """Generalised KL divergence D(p || q) between two torch tensors.

    Computes ``sum_i p_i*log(p_i/q_i) - sum_i p_i + sum_i q_i``.

    Entries with ``p_i == 0`` contribute only their ``q_i`` term, by the
    convention ``0 * log(0/q) = 0``. The ``sum q`` term must therefore run
    over ALL of ``q``, not just over the support of ``p``. With that,
    normalised ``p`` and ``q`` give the ordinary KL divergence, and the
    result is never negative.

    Args:
        p, q: tensors of the same shape; ``q`` should be positive wherever
            ``p`` is non-zero (otherwise the result is ``inf``).

    Returns:
        Python float.
    """
    support = p != 0
    p_s = p[support]
    q_s = q[support]
    cross = (p_s * (p_s / q_s).log()).sum().item()
    # p.sum() equals p_s.sum() (zeros dropped), but q must keep its full mass.
    return cross - p.sum().item() + q.sum().item()


def _plot_one_line_chart(ax, names, groups, xlabel='Iteration', ylabel='Error', ylim=None):
    """Plot each series in ``groups`` as one line on ``ax``.

    ``ylabel`` selects the x positions (iteration numbers):

    * ``'Convergence Error'``: 20 fixed points, dense early and sparse later
      (0-4500 step 500, 5000-9000 step 1000, 10000-18000 step 2000). Each
      group must have 20 values.
    * ``'KL Divergence'``: 15 fixed points (0-4500 step 500, 5000-13000
      step 2000). Each group must have 15 values.
    * anything else: one point every 100 iterations (0, 100, 200, ...).

    Args:
        ax: matplotlib Axes to draw on.
        names: legend label per group (``names[i]`` labels ``groups[i]``).
        groups: sequence of y-value sequences, one line each.
        xlabel, ylabel: axis labels; ``ylabel`` also selects the x layout.
        ylim: optional ``(bottom, top)`` y-limits. ``None`` keeps matplotlib's
            autoscaled limits.

    Line colours are taken from the module-level ``colors`` list
    (``colors[i]`` for group ``i``).
    """
    for i, group in enumerate(groups):
        if ylabel == 'Convergence Error':
            # Dense-then-sparse sampling (originally used for the tau=0 algorithm).
            x = (list(range(0, 5000, 500)) + list(range(5000, 10000, 1000))
                 + list(range(10000, 20000, 2000)))
            # Draw the first series above the others so it stays visible where lines overlap.
            z = 2 if i == 0 else 1
            ax.plot(x, group, color=colors[i], label=names[i], marker='^', markersize=5, zorder=z)
        elif ylabel == 'KL Divergence':
            x = list(range(0, 5000, 500)) + list(range(5000, 15000, 2000))
            ax.plot(x, group, color=colors[i], label=names[i], marker='o', markersize=3)
        else:
            x = [ii * 100 for ii in range(len(group))]
            ax.plot(x, group, color=colors[i], label=names[i])
    # Apply limits once after all lines are drawn. Indexing ylim unconditionally
    # would make the documented default ylim=None raise TypeError.
    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    ax.legend(loc='upper right', prop={'size': 14})
    ax.set_xlabel(xlabel, fontsize=17.5, fontweight='normal')
    ax.set_ylabel(ylabel, fontsize=17.5, fontweight='normal')


def plot_results(names, indicators, ylabels, ylims, path):
    """Draw one side-by-side panel per indicator and save the figure to ``path``.

    Args:
        names: legend labels shared by every panel.
        indicators: sequence of per-panel data. ``indicators[i]`` is a list of
            series passed to ``_plot_one_line_chart``.
        ylabels: y-axis label per panel; it also selects that panel's x layout.
        ylims: ``(bottom, top)`` y-limits (or ``None``) per panel.
        path: output image file. Its parent directory is created if missing.

    The figure is closed after saving so repeated calls do not accumulate
    open figures.
    """
    # One panel per indicator (5.5 in wide each; two panels give the original 11x5 figure).
    n_panels = max(len(indicators), 1)
    # squeeze=False keeps axs 2-D even for a single panel, so axs[0, i] always works.
    fig, axs = plt.subplots(1, n_panels, figsize=(5.5 * n_panels, 5), squeeze=False)
    fig.subplots_adjust(wspace=0.25)
    for i, indicator in enumerate(indicators):
        _plot_one_line_chart(axs[0, i], names, indicator, ylabel=ylabels[i], ylim=ylims[i])

    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(path)
    plt.close(fig)


def num2latex(num):
    r"""Format ``num`` as a one-significant-digit LaTeX scientific literal.

    The number is first rounded by ``'{:.0e}'`` formatting, e.g.
    ``8e-4 -> '8\times 10^{-4}'`` and ``12345 -> '1\times 10^{4}'``.
    No surrounding ``$`` delimiters are added.
    """
    mantissa, _, exponent = "{:.0e}".format(num).partition("e")
    return rf"{int(mantissa)}\times 10^{{{int(exponent)}}}"


SEED = 3407
# Seed NumPy's global RNG so generateData draws the same synthetic instance every run.
np.random.seed(SEED)
# Restrict CUDA to physical GPU 3 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# Line colours for _plot_one_line_chart: a dark-to-light ramp of the Reds colormap.
# Other palettes tried earlier: fixed named colours, cm.Blues, cm.coolwarm and cm.cividis ramps.
colors = [cm.Reds(level) for level in (0.99, 0.8, 0.6, 0.45, 0.3)]


if __name__ == '__main__':
    # Three layers of points; N is the total point count (used in cache/output names).
    n = [250, 500, 250]
    layer = len(n)
    N = sum(n)
    print("+-----------------+")
    print("N={}, n={}".format(N, n))
    print("+-----------------+")

    # Two-element (eps, tau) tuples select the two-parameter mlot.multi_sinkhorn solver.
    # Alternative sweep kept for reference:
    # test_lbds = [(8e-4, 2e-3), (8e-4, 4e-3), (8e-4, 8e-3), (8e-4, 2e-2), (8e-4, 4e-2)]

    # Single-element tuples select mlot.multi_sinkhorn_single with eps = the element.
    test_lbds = [
        (8e-4, ),
        (2e-3, ),
        (4e-3, ),
        (8e-3, ),
        (2e-2, ),
    ]

    data = generateData(n)
    M = data['M']
    s, t = data['source'], data['target']

    results = solve(s, t, M, test_lbds, numItermax=20000)
    for key, value in results.items():
        if len(key) == 1:
            print("[MLOT single] eps= {}".format(key[0]))
        else:
            print("[MLOT multi] eps= {}, tau= {}".format(key[0], key[1]))
        print("\titer= {}\tdistance= {}".format(value['iter'], value['dis']))

    # Reference middle-layer values cached by an earlier run. map_location='cpu'
    # lets a cache saved from a GPU tensor load on a machine without that GPU.
    gt_second_layer = torch.load('./results/cache/middleLayerN={}({}).pt'.format(N, SEED),
                                 map_location='cpu')
    gt_second_layer = gt_second_layer.reshape(-1).cpu()

    names = []
    error_uvs = []
    KL_second_layers = []
    for k, v in results.items():
        name = r'$\tau={}$'.format(num2latex(k[1])) if len(k) == 2 else r'$\epsilon={}$'.format(num2latex(k[0]))
        names.append(name)
        # Thin the error trace: every 5th of the first 50 entries, every 10th of
        # the next 50, then every 20th (matches the x layout used for plotting).
        _dense_sparse = v['error_uv'][0:50:5] + v['error_uv'][50:100:10] + v['error_uv'][100::20]
        error_uvs.append(_dense_sparse)
        # Flatten each iterate like the ground truth so KL's boolean mask indexes
        # it element-wise (assumes equal element counts — TODO confirm against mlot).
        _KL_for_one_para = [KL(gt_second_layer, layer_ii.reshape(-1).cpu())
                            for layer_ii in v['second_layer']]
        _dense_sparse = _KL_for_one_para[:50:5] + _KL_for_one_para[50:150:20]
        KL_second_layers.append(_dense_sparse)

    plot_results(names,
                 [error_uvs, KL_second_layers],
                 ['Convergence Error', 'KL Divergence'],
                 [(-0.05, 0.5), (0.29, 0.44)],
                 './results/diff_layer/tau=0K={}({}).png'.format(layer, SEED))




