import os
import torch
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt

import mlot_detail_multilayer as mlot


def _generate_M(_n):
    """Sample one 2-D point cloud per layer and the costs between neighbours.

    Layer ``k`` gets ``_n[k]`` points whose x coordinate is ``0.001 * k`` plus
    uniform jitter in [-1e-10, 1e-10) and whose y coordinate is uniform in
    [-10, 10).  Draws come from NumPy's global random state, x before y for
    each layer.

    Returns:
        (points, _M): ``points`` is a list of ``(_n[k], 2)`` arrays, and
        ``_M[k]`` is the ``(_n[k], _n[k+1])`` matrix of Euclidean distances
        between the points of layer ``k`` and layer ``k+1``.
    """
    points = []
    _M = []
    for layer_idx, count in enumerate(_n):
        xs = np.random.uniform(-1e-10, 1e-10, count) + 0.001 * layer_idx
        ys = np.random.uniform(-10, 10, count)
        points.append(np.stack([xs, ys], axis=1))

        if layer_idx == 0:
            # The first layer has no predecessor to measure against.
            continue

        # Broadcast (prev, 1, 2) - (cur, 2) -> (prev, cur, 2) pairwise offsets.
        offsets = points[layer_idx - 1][:, np.newaxis, :] - points[layer_idx]
        _M.append(np.sqrt(np.sum(offsets**2, axis=-1)))
    return points, _M


def _generate_s_t(_n):
    """Draw random source/target weight vectors, each normalised to sum to 1.

    The source has ``_n[0]`` entries and the target ``_n[-1]`` entries; both
    are sampled (source first) from NumPy's global random state.
    """
    source = np.random.rand(_n[0])
    target = np.random.rand(_n[-1])
    source_total = np.sum(source)
    target_total = np.sum(target)
    return source / source_total, target / target_total


def generateData(_n):
    """Build a random problem instance for the layer sizes in ``_n``.

    Returns a dict with the list of inter-layer cost matrices under ``'M'``
    and the normalised marginals under ``'source'`` and ``'target'``.  The
    sampled point coordinates themselves are not returned.
    """
    _, cost_matrices = _generate_M(_n)
    source, target = _generate_s_t(_n)
    return {
        'M': cost_matrices,
        'source': source,
        'target': target,
    }


def _single_solve(_s, _t, _M, lbds, radio, numItermax):
    """Run one solver configuration and attach the resulting distance.

    ``lbds`` selects the solver: a one-element sequence is passed to
    ``mlot.multi_sinkhorn_single``; otherwise it is unpacked as two values
    and passed to ``mlot.multi_sinkhorn``.  ``radio`` is forwarded to the
    solver and also multiplies ``mlot.sinkhorn_distance(T, _M)`` (the caller
    passes the factor it divided the costs by).

    Returns:
        The log object returned by the solver, updated with the scaled
        distance under the key ``'dis'``.
    """
    if len(lbds) == 1:
        plan, log = mlot.multi_sinkhorn_single(
            _s, _t, _M, lbds[0], radio, numItermax=numItermax)
    else:
        eps, tau = lbds
        plan, log = mlot.multi_sinkhorn(
            _s, _t, _M, eps, tau, radio, numItermax=numItermax)

    # Both solver variants report the distance the same way.
    log.update(dis=radio * mlot.sinkhorn_distance(plan, _M))
    return log


def solve(_s, _t, _M, test_lbds, numItermax):
    """Solve the same problem once per parameter setting in ``test_lbds``.

    Every cost matrix is divided by the largest entry found across all of
    ``_M`` before solving; that factor is handed to ``_single_solve`` so it
    can scale the reported distance back.

    Returns:
        dict mapping each element of ``test_lbds`` (used as the key, so it
        must be hashable) to the log returned by ``_single_solve``.
    """
    radio = max(np.max(cost) for cost in _M)
    scaled_costs = [cost / radio for cost in _M]
    return {
        lbds: _single_solve(_s, _t, scaled_costs, lbds, radio, numItermax)
        for lbds in test_lbds
    }


def KL(p, q):
    """Return sum(p*log(p/q)) - sum(p) + sum(q) over the entries where p != 0.

    ``p`` and ``q`` are expected to be tensors of matching shape (the result
    of the division must provide ``.log()``).  Entries where ``p`` is zero
    are dropped from all three sums, so mass that ``q`` places outside the
    support of ``p`` does not contribute.  The result is a Python float.
    """
    support = p != 0
    p_sup = p[support]
    q_sup = q[support]

    log_ratio = (p_sup / q_sup).log()
    cross_term = (p_sup * log_ratio).sum().item()
    p_mass = p_sup.sum().item()
    q_mass = q_sup.sum().item()
    return cross_term - p_mass + q_mass


def _plot_one_line_chart(ax, names, groups, xlabel='Iteration', ylabel='Error', ylim=None):
    """Draw one line per entry of ``groups`` on ``ax``.

    The x positions and marker style depend on ``ylabel``:

    * ``'Convergence Error'``: a fixed 20-point grid (every 500 up to 5000,
      every 1000 up to 10000, every 2000 up to 20000), triangle markers.
    * ``'KL Divergence'``: a fixed 18-point grid (every 500 up to 5000,
      every 2000 up to 20000), circle markers.
    * anything else: ``0, 100, 200, ...`` matching the length of each group,
      no markers.

    For the two fixed grids each group must have as many values as the grid.
    Line ``i`` is labelled ``names[i]`` and coloured with the module-level
    ``colors[i]``; the first line is drawn on top of the others.

    Args:
        ax: matplotlib Axes to draw on.
        names: legend labels, one per group.
        groups: sequence of y-value sequences.
        xlabel: x-axis label.
        ylabel: y-axis label; also selects the x grid and marker style.
        ylim: optional ``(low, high)`` y-limits.  When None the limits are
            left to matplotlib (previously this raised a TypeError).
    """
    if ylabel == 'Convergence Error':
        # Samples are dense early and sparse later (tau=0 algorithm).
        fixed_x = (list(range(0, 5000, 500))
                   + list(range(5000, 10000, 1000))
                   + list(range(10000, 20000, 2000)))
        marker_style = {'marker': '^', 'markersize': 5}
    elif ylabel == 'KL Divergence':
        fixed_x = list(range(0, 5000, 500)) + list(range(5000, 20000, 2000))
        marker_style = {'marker': 'o', 'markersize': 3}
    else:
        fixed_x = None
        marker_style = {}

    for i, group in enumerate(groups):
        if fixed_x is not None:
            x = fixed_x
        else:
            x = [ii * 100 for ii in range(len(group))]
        # Keep the first series visible above the rest.
        z = 2 if i == 0 else 1
        ax.plot(x, group, color=colors[i], label=names[i], zorder=z, **marker_style)

    if ylim is not None:
        ax.set_ylim(ylim[0], ylim[1])
    ax.legend(loc='upper right', prop={'size': 14})
    ax.set_xlabel(xlabel, fontsize=17.5, fontweight='normal')
    ax.set_ylabel(ylabel, fontsize=17.5, fontweight='normal')


def _plot_mean_std(ax, names, group_means, group_stds, xlabel='Iteration', ylabel='KL Divergence', ylim=None):
    """Plot mean curves with a shaded +/- one-std band on ``ax``.

    Curve ``i`` uses ``group_means[i]`` / ``group_stds[i]``, is labelled
    ``names[i]`` and coloured with the module-level ``colors_std[i]``.  The
    x positions are ``0, 100, 200, ...`` matching the length of the means.

    Args:
        ax: matplotlib Axes to draw on.
        names: legend labels, one per curve.
        group_means: sequence of mean sequences (lists or arrays).
        group_stds: sequence of std sequences, same lengths as the means.
        xlabel: x-axis label.
        ylabel: y-axis label.
        ylim: optional ``(low, high)`` y-limits; skipped when falsy.
    """
    # Debug output kept as in the original implementation.
    print(names)
    print(colors)
    for i, (means, stds) in enumerate(zip(group_means, group_stds)):
        # Plain lists do not support elementwise +/-; coerce so that both
        # list and array inputs work for the band computation below.
        means = np.asarray(means)
        stds = np.asarray(stds)
        x = [ii*100 for ii in range(len(means))]
        print(i)
        ax.plot(x, means, color=colors_std[i], label=names[i])
        ax.fill_between(x, means - stds, means + stds, color=colors_std[i], alpha=0.1)
    if ylim:
        ax.set_ylim(ylim[0], ylim[1])
    ax.legend(loc='upper right')
    ax.set_xlabel(xlabel, fontsize=15)
    ax.set_ylabel(ylabel, fontsize=15)



def plot_results(names, indicators, ylabels, ylims, path):
    """Plot each indicator in its own side-by-side panel and save the figure.

    Args:
        names: legend labels shared by all panels.
        indicators: one entry per panel.  For a panel whose label is
            ``'KL Divergence'`` only ``indicator[0]`` is plotted (the caller
            passes ``[means, stds]``; the stds are not drawn).  For any other
            label the indicator itself is the list of series.
        ylabels: y-axis label per panel.
        ylims: y-limits per panel, forwarded to ``_plot_one_line_chart``.
        path: output file path handed to ``savefig``.
    """
    # squeeze=False always yields a 2-D array of Axes, so a single indicator
    # works too (otherwise plt.subplots returns a bare Axes that cannot be
    # indexed).
    fig, axs = plt.subplots(1, len(indicators), figsize=(11, 5), squeeze=False)
    fig.subplots_adjust(wspace=0.25)
    for i, indicator in enumerate(indicators):
        series = indicator[0] if ylabels[i] == 'KL Divergence' else indicator
        _plot_one_line_chart(axs[0, i], names, series, ylabel=ylabels[i], ylim=ylims[i])
    fig.savefig(path)


def compute_av_sigma(arrays):
    """Return the per-index mean and standard deviation across several runs.

    ``arrays`` holds repeated measurements of the same metric: each element
    is the sequence of values recorded in one run.  Position ``i`` of the
    output aggregates the ``i``-th value of every run that is long enough to
    have one, so runs of different lengths are supported (the output is as
    long as the longest run).

    Args:
        arrays: sequence of value sequences, one per run.

    Returns:
        (av, sigma): two lists of equal length holding ``np.mean`` and
        ``np.std`` of the values at each index.  Both are empty when
        ``arrays`` is empty.
    """
    max_len = max((len(arr) for arr in arrays), default=0)
    av = []
    sigma = []
    for i in range(max_len):
        # Shorter runs simply do not contribute past their own end.
        values = [arr[i] for arr in arrays if len(arr) > i]
        av.append(np.mean(values))
        sigma.append(np.std(values))
    return av, sigma


def num2latex(num):
    """Format ``num`` as a LaTeX ``c\\times 10^{e}`` string.

    The number is rounded to a single significant digit (``.0e`` format), so
    the coefficient and the exponent are both rendered as plain integers.
    """
    mantissa, _, exponent = format(num, '.0e').partition('e')
    return r'{}\times 10^{{{}}}'.format(int(mantissa), int(exponent))


# Seed NumPy's global random state so the generated data is reproducible.
SEED = 0
np.random.seed(SEED)

# Restrict CUDA to physical GPU 3 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Colormap sample positions shared by both palettes, darkest first.
_PALETTE_LEVELS = (0.99, 0.8, 0.6, 0.45, 0.3)

# Palette used by _plot_mean_std.
colors_std = [cm.rainbow(level) for level in _PALETTE_LEVELS]

# Palette used by _plot_one_line_chart.
# Other palettes that were tried for `colors`:
#   named colors: 'blue', 'red', 'orange', 'green', 'purple', 'black', 'gray'
#   cm.Blues at the same five levels
#   cm.coolwarm at 0.99, 0.8, 0.6, 0.4, 0.2, 0.01
#   cm.cividis at 0.01, 0.2, 0.4, 0.6, 0.8, 0.99
colors = [cm.Reds(level) for level in _PALETTE_LEVELS]


if __name__ == '__main__':
    # Number of points in each layer; consecutive layers share a cost matrix.
    n = [25, 50, 125, 150, 150, 150, 150, 125, 50, 25]
    layer = len(n)
    N = sum(n)
    print("+-----------------+")
    print("N={}, n={}".format(N, n))
    print("+-----------------+")

    # Alternative (eps, tau) settings; two-element tuples make _single_solve
    # call mlot.multi_sinkhorn instead of mlot.multi_sinkhorn_single.
    # test_lbds = [
    #     (8e-4, 2e-3),
    #     (1e-3, 4e-3),
    #     (1e-3, 8e-3),
    #     (1e-3, 2e-2),
    #     (3e-3, 4e-2)
    # ]

    # One-element tuples: eps only.
    test_lbds = [
        (8e-4, ),
        (2e-3, ),
        (4e-3, ),
        (8e-3, ),
        (2e-2, )
    ]

    data = generateData(n)
    M = data['M']
    s, t = data['source'], data['target']

    results = solve(s, t, M, test_lbds, numItermax=20000)
    for key, value in results.items():
        if len(key) == 1:
            print("[MLOT single] eps= {}".format(key[0]))
        else:
            print("[MLOT multi] eps= {}, tau= {}".format(key[0], key[1]))
        print("\titer= {}\tdistance= {}".format(value['iter'], value['dis']))

    # Cached reference tensors for the intermediate layers 1 .. layer-2,
    # flattened.  They are only used on the CPU, so load them there directly
    # (this also works when the saving device is not available).
    gt_layers = []
    for i in range(1, layer-1):
        tmp = torch.load('./results/cache/layer{}.pt'.format(i), map_location='cpu')
        gt_layers.append(tmp.reshape(-1).cpu())

    names = []
    error_uvs = []
    # _tmp_layers[p][i] is the KL sequence of intermediate layer i under the
    # p-th parameter setting.
    _tmp_layers = []
    for k, v in results.items():
        name = r'$\tau={}$'.format(num2latex(k[1])) if len(k) == 2 else r'$\epsilon={}$'.format(num2latex(k[0]))
        names.append(name)

        # Sub-sample the logged error: dense at the start, sparse later, to
        # line up with the x grid used for the 'Convergence Error' panel.
        error_uv = v['error_uv']
        error_uvs.append(error_uv[0:50:5] + error_uv[50:100:10] + error_uv[100::20])

        # KL divergence of every logged snapshot of each intermediate layer
        # against its cached reference.
        _KL_for_one_para = []
        for i in range(layer-2):
            _KL_for_one_layer = [KL(gt_layers[i], layer_ii.cpu()) for layer_ii in v['layers'][i]]
            _KL_for_one_para.append(_KL_for_one_layer)
        _tmp_layers.append(_KL_for_one_para)

    # Average the KL sequences over the layers for each parameter setting,
    # then sub-sample to line up with the 'KL Divergence' x grid.
    KL_means = []
    KL_stds = []
    for per_layer_kl in _tmp_layers:
        av, sigma = compute_av_sigma(per_layer_kl)
        KL_means.append(av[:50:5] + av[50:200:20])
        KL_stds.append(sigma[:50:5] + sigma[50:200:20])

    print("len(KL_means)={}".format(len(KL_means)))     # expected: number of parameter settings (not layers)
    print("len(KL_stds)={}".format(len(KL_stds)))
    print("len(KL_means[0])={}".format(len(KL_means[0])))   # expected: number of sampled iterations
    print("len(KL_stds[0])={}".format(len(KL_stds[0])))

    plot_results(names,
                 [error_uvs, [KL_means, KL_stds]],
                 ['Convergence Error', 'KL Divergence'],
                 [(-0.2, 3), (0.35, 1.25)],
                 './results/diff_layer/tau=0K=10(0).png'
    )




