import torch
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import gp_solver as gps
import mlot_simple as mlot
import mlot_virtual as mlot_v


def format_func(x, pos):
    """Render *x* in scientific notation with zero decimals in the mantissa.

    The *pos* argument is ignored; it exists so the function matches the
    ``(value, position)`` callback signature expected by a FuncFormatter.
    """
    return f'{x:.0e}'

# generate data --------------------
def _generate_M(n, width, gap):
    points = []
    M = []
    for i in range(len(n)):
        x = np.random.uniform(-width/2, width/2, n[i]) + gap*i
        y = np.random.uniform(-10, 10, n[i])
        points.append(np.stack([x, y], axis=1))

        if i > 0:
            tmp = points[i-1][:, np.newaxis, :] - points[i]
            euclidean = np.sqrt(np.sum(tmp**2, axis=-1))
            M.append(euclidean)
    return points, M
def _generate_s_t(n):
    _s = np.random.rand(n[0])
    _t = np.random.rand(n[-1])
    return _s / np.sum(_s), _t / np.sum(_t)
def generateData(n, width, gap):
    """Build a random multilayer transport instance.

    Returns a dict with:
        'M':      list of inter-layer distance matrices from ``_generate_M``
                  (the sampled point coordinates themselves are discarded),
        'source': normalized source masses of length ``n[0]``,
        'target': normalized target masses of length ``n[-1]``.
    """
    _, cost_matrices = _generate_M(n, width, gap)
    source, target = _generate_s_t(n)
    return {'M': cost_matrices, 'source': source, 'target': target}
# -----------------------------------


def groundTruth(s, t, n, M):
    """Solve the multilayer problem with ``gps.MultilayerSolver`` and collect results.

    Returns a dict with:
        'P':         the solver's ``getP()`` result,
        'objective': the first value returned by ``solve()``,
        'layers':    the solver's ``getLayers()`` result.
    """
    solver = gps.MultilayerSolver(s, t, n, M)
    solver.create_variable()
    solver.initProblem()
    # solve() returns (objective, variables); only the objective is kept.
    objective, _ = solver.solve()

    return {
        'P': solver.getP(),
        'objective': objective,
        'layers': solver.getLayers(),
    }


def TtoLayers(T):
    """Recover per-layer mass vectors from a chain of coupling matrices.

    The row sums of each coupling ``T[k]`` give the mass on layer ``k``; the
    column sums of the last coupling give the mass on the final layer. The
    result therefore has ``len(T) + 1`` entries.
    """
    row_marginals = [torch.sum(coupling, dim=1) for coupling in T]
    return row_marginals + [torch.sum(T[-1], dim=0)]


def MLOTsingle(s, t, M, reg, virtual=False, delta=100):
    """Run the single-regularizer multilayer Sinkhorn solver and report its result.

    Every cost matrix is divided by the largest entry across all of ``M``. The
    per-layer ``record_max`` list holding that scale is passed to the solver
    and to ``sinkhorn_distance``. (Presumably this lets them undo the
    normalization — TODO confirm in mlot.)

    Args:
        s, t: source and target mass vectors.
        M: list of numpy cost matrices between consecutive layers.
        reg: regularization strength forwarded to the solver.
        virtual: if True, use ``mlot_v.mlot_virtual_single`` and take the layer
            masses from its log; otherwise use ``mlot.multi_sinkhorn_single``
            and derive them from the couplings with ``TtoLayers``.
        delta: forwarded to the virtual solver only.

    Returns:
        The solver's log dict, updated with 'distance', 'layers' and 'T'.
    """
    scale = max([np.max(cost) for cost in M])
    record_max = [scale] * len(M)
    normalized = [torch.tensor(cost) / scale for cost in M]

    if virtual:
        T, log = mlot_v.mlot_virtual_single(s, t, normalized, reg, delta, numItermax=1000)
        layers = log['layers']
        distance = mlot.sinkhorn_distance(T, normalized, record_max)
    else:
        T, log = mlot.multi_sinkhorn_single(s, t, normalized, reg, record_max, numItermax=1000)
        distance = mlot.sinkhorn_distance(T, normalized, record_max)
        layers = TtoLayers(T)

    print("MLOT single distance=", distance)

    log.update({
        'distance': distance,
        'layers': layers,
        'T': T,
    })
    return log


def MLOT(s, t, M, lbd, tau, virtual=False, delta=100):
    """Run the multilayer Sinkhorn solver (two-parameter variant) and report its result.

    Every cost matrix is divided by the largest entry across all of ``M``. The
    per-layer ``record_max`` list holding that scale is passed to the solver
    and to ``sinkhorn_distance``. (Presumably this lets them undo the
    normalization — TODO confirm in mlot.)

    Args:
        s, t: source and target mass vectors.
        M: list of numpy cost matrices between consecutive layers.
        lbd, tau: regularization parameters forwarded to the solver.
        virtual: if True, use ``mlot_v.mlot_virtual`` and take the layer masses
            from its log; otherwise use ``mlot.multi_sinkhorn`` and derive them
            from the couplings with ``TtoLayers``.
        delta: forwarded to the virtual solver only.

    Returns:
        The solver's log dict, updated with 'distance', 'layers' and 'T'.
    """
    scale = max([np.max(cost) for cost in M])
    record_max = [scale] * len(M)
    normalized = [torch.tensor(cost) / scale for cost in M]

    if virtual:
        T, log = mlot_v.mlot_virtual(s, t, normalized, lbd, tau, delta, numItermax=1000)
        layers = log['layers']
        distance = mlot.sinkhorn_distance(T, normalized, record_max)
    else:
        T, log = mlot.multi_sinkhorn(s, t, normalized, lbd, tau, record_max, numItermax=1000)
        distance = mlot.sinkhorn_distance(T, normalized, record_max)
        layers = TtoLayers(T)

    print("MLOT distance=", distance)

    log.update({
        'distance': distance,
        'layers': layers,
        'T': T,
    })
    return log


def plot_double_histo_heat(source, middle, target, P, name, heatmap_max: list, ishow=True,
                           middle_ymax=0.12):
    """Plot the three layer marginals next to the two coupling heatmaps.

    Grid layout (3 rows x 2 columns):
        row 0: source masses (horizontal bars) | heatmap of ``P[0]``
        row 1: "Middle" label                  | middle masses (vertical bars)
        row 2: target masses (horizontal bars) | heatmap of ``P[1].T``

    Args:
        source, middle, target: mass vectors; each is ``.flatten()``-ed before
            plotting (numpy arrays, or anything else with ``.flatten()`` and
            ``len()``).
        P: two couplings; ``P[0]`` is drawn as-is and ``P[1]`` transposed.
            Presumably ``P[0]`` is source x middle and ``P[1]`` is middle x
            target — TODO confirm against the solvers.
        name: currently unused (the title code is commented out); kept for
            API compatibility.
        heatmap_max: colour-scale upper limits ``[vmax for P[0], vmax for P[1]]``.
        ishow: call ``plt.show()`` at the end when True.
        middle_ymax: upper y-limit of the middle bar chart (default 0.12, the
            previously hard-coded value); ``None`` lets matplotlib autoscale.

    Returns:
        The created matplotlib Figure. It is also the current figure, so a
        subsequent ``plt.savefig`` still saves it.
    """
    src_vals = source.flatten()
    mid_vals = middle.flatten()
    tgt_vals = target.flatten()
    n_src, n_mid, n_tgt = len(src_vals), len(mid_vals), len(tgt_vals)

    total = n_src + n_tgt + n_mid
    frac_src, frac_tgt, frac_mid = n_src / total, n_tgt / total, n_mid / total
    base = 10
    # The heatmap column width is driven by frac_mid and the row heights by
    # frac_src / frac_tgt (see width_ratios / height_ratios below). So the
    # figure width must follow the middle layer and its height source + target.
    fig = plt.figure(figsize=(frac_mid * base + 3, (frac_src + frac_tgt) * base + 4))
    gs = gridspec.GridSpec(
        3, 2,
        width_ratios=[1, frac_mid * base],
        height_ratios=[frac_src * base, 1, frac_tgt * base],
        hspace=0.05, wspace=0.05,
    )

    # Optional title, currently disabled:
    # fig.suptitle(name, fontsize=17, weight='bold')
    # fig.text(0.5, 0.03, name, ha='center', va='center', fontsize=27, weight='bold')

    # Row 0, left: source masses as horizontal bars, growing leftwards.
    ax0 = plt.subplot(gs[0])
    ax0.barh(np.arange(n_src), src_vals, color=cm.YlGnBu_r(0.25))
    ax0.invert_xaxis()
    ax0.set_ylim(-0.5, n_src - 0.5)
    ax0.set_ylabel('Source', fontsize=25, weight='bold')

    # Row 0, right: first coupling heatmap (YlGnBu_r) with its own colorbar.
    ax1 = plt.subplot(gs[1])
    im = ax1.imshow(P[0], cmap='YlGnBu_r', vmin=0, vmax=heatmap_max[0], aspect='auto')
    ax1.axis('off')
    # Raw string: '\m' is not a valid escape sequence in a normal literal.
    ax1.text(0.5, 1.07, r'First coupling $\mathbf{P_1}$', ha='center', va='center',
             transform=ax1.transAxes, fontsize=25, weight='bold')
    pos1 = ax1.get_position()
    cax1 = fig.add_axes([pos1.x1 + 0.01, pos1.y0, 0.02, pos1.height])
    # For scientific-notation ticks, keep the returned colorbar and call
    # cbar.ax.yaxis.set_major_formatter(FuncFormatter(format_func)).
    fig.colorbar(im, cax=cax1)

    # Row 1, left: empty cell that only carries the "Middle" label.
    ax_middle = plt.subplot(gs[2])
    ax_middle.text(0.5, 0.5, 'Middle', ha='center', va='center', fontsize=25, weight='bold')
    ax_middle.axis('off')

    # Row 1, right: middle-layer masses as vertical bars.
    ax2 = plt.subplot(gs[3])
    ax2.bar(np.arange(n_mid), mid_vals, color=(0.35, 0, 0.65))
    if middle_ymax is not None:
        ax2.set_ylim(0, middle_ymax)
    ax2.set_xlim(-0.5, n_mid - 0.5)

    # Row 2, left: target masses as horizontal bars, growing leftwards.
    ax3 = plt.subplot(gs[4])
    ax3.barh(np.arange(n_tgt), tgt_vals, color=cm.hot(0.25))
    ax3.invert_xaxis()
    ax3.set_ylim(-0.5, n_tgt - 0.5)
    ax3.set_ylabel('Target', fontsize=25, weight='bold')

    # Row 2, right: second coupling heatmap (hot), transposed, with its own colorbar.
    ax4 = plt.subplot(gs[5])
    im = ax4.imshow(P[1].T, cmap='hot', vmin=0, vmax=heatmap_max[1], aspect='auto')
    ax4.axis('off')
    ax4.text(0.5, -0.1, r'Second coupling $\mathbf{P_2}$', ha='center', va='center',
             transform=ax4.transAxes, fontsize=25, weight='bold')
    pos4 = ax4.get_position()
    cax4 = fig.add_axes([pos4.x1 + 0.01, pos4.y0, 0.02, pos4.height])
    fig.colorbar(im, cax=cax4)

    # Strip the x and y ticks from the three bar charts.
    for ax in (ax0, ax2, ax3):
        ax.set_xticks([])
        ax.set_yticks([])
    # Alternative (disabled): one shared colorbar for both heatmaps.
    # fig.colorbar(im, ax=[ax0, ax1, ax2, ax3, ax4])

    if ishow:
        plt.show()
    return fig


# NOTE: this seeds numpy's global RNG whenever the module is imported, not
# only when it is run as a script.
np.random.seed(0)
if __name__ == "__main__":
    import os

    results = {}

    # Three layers: 25 source nodes, 50 middle nodes, 25 target nodes.
    n = [25, 50, 25]
    data = generateData(n, 0.001, 5)

    results['gt'] = groundTruth(data['source'], data['target'], n, data['M'])
    # results['mlot-single'] = MLOTsingle(data['source'], data['target'], data['M'], 1e-3, virtual=True, delta=5e-3)
    # results['mlot'] = MLOT(data['source'], data['target'], data['M'], 1e-3, 2e-2, virtual=True, delta=5e-3)

    s, t = data['source'], data['target']
    _heat_max = [0.1, 0.1]

    # savefig does not create missing directories; make sure the output folder exists.
    os.makedirs("results/reg", exist_ok=True)

    plot_double_histo_heat(s, results['gt']['layers'][1], t, results['gt']['P'], 'Gurobi', _heat_max, False)
    plt.savefig("results/reg/Gurobi.png")
    # plot_double_histo_heat(s, results['mlot-single']['layers'][1], t, results['mlot-single']['T'], r'$\mathbf{\tau=0}$', _heat_max, False)
    # plt.savefig("results/reg/[1e-3]0.png")
    # plot_double_histo_heat(s, results['mlot']['layers'][1], t, results['mlot']['T'], r'$\mathbf{\tau=2\times 10^{-2}}$', _heat_max, False)
    # plt.savefig("results/reg/[1e-3]2e-2.png")
    # plt.show()

