import torch
import random
from tqdm import tqdm
import mlot_simple as mlot
import mlot_virtual as mlot_v
import matplotlib.pyplot as plt
import matplotlib.patheffects as patheffects
from PIL import Image, ImageDraw, ImageFont
import torchvision.transforms as transforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("[Work on {}]".format(device))

# Seed handed to setSeed() in the __main__ block.
seed = 3407
def setSeed(seed):
    """Seed Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: Value passed to both ``random.seed`` and ``torch.manual_seed``.
    """
    random.seed(seed)
    torch.manual_seed(seed)


def normalizedGroundMetric(size=64):
    """Build the normalized squared-Euclidean ground metric of a square grid.

    The pixel coordinates of a ``size x size`` grid are divided by their
    largest coordinate value, the pairwise squared Euclidean distances are
    computed, divided by their maximum, and finally shifted by ``1e-20``,
    which makes the otherwise-zero diagonal strictly positive.

    Args:
        size: Side length of the grid in pixels. Defaults to 64, the image
            resolution used by the rest of this script.

    Returns:
        A tensor of shape ``(size**2, size**2)``, created on torch's default
        device.

    Raises:
        ValueError: If ``size`` is smaller than 2; a single pixel has a
            maximum coordinate of 0 and the normalization would divide by
            zero.
    """
    if size < 2:
        raise ValueError("size must be at least 2, got {}".format(size))

    t = torch.linspace(0, size - 1, size)
    Y, X = torch.meshgrid(t, t, indexing='ij')
    position = torch.vstack((X.reshape(-1), Y.reshape(-1)))
    # Shape (size*size, 2): one (x, y) row per pixel, i.e. (4096, 2) for 64x64.
    position = position.T / torch.max(position)

    # Broadcast to all pairwise coordinate differences, then reduce over (x, y).
    diff = position.unsqueeze(1) - position.unsqueeze(0)
    Dist = torch.sum(diff**2, dim=-1)
    print("[Distance metric computed] shape =", Dist.shape)
    _max_num = torch.max(Dist)
    _min_num = torch.min(Dist)
    print("dist max={}, min={}".format(_max_num, _min_num))
    Dist = Dist / _max_num
    Dist += 1e-20
    return Dist


# Preliminary image preparation -- already done
def draw_handwritten_letter(letter='A', color='green', thickness=10):
    """Render one letter on a small canvas and save it to ``./cache/<letter>.png``.

    The figure is 1 inch square at 64 dpi with axes limits 0..64 and the axes
    hidden. The letter is drawn with a stroke path effect in the same color
    to thicken its outline.

    Args:
        letter: Text to draw; also used as the output file name.
        color: Matplotlib color used for both the glyph and its stroke.
        thickness: Line width of the stroke path effect.
    """
    fig, ax = plt.subplots(figsize=(1, 1), dpi=64)
    try:
        ax.set_xlim(0, 64)
        ax.set_ylim(0, 64)
        ax.axis('off')

        # Draw the letter with a stroke effect to imitate a hand-drawn look.
        ax.text(32, 25, letter, fontsize=50, color=color,
                ha='center', va='center', fontweight='normal', fontfamily='sans-serif',
                path_effects=[patheffects.withStroke(linewidth=thickness, foreground=color)])

        fig.savefig("./cache/{}.png".format(letter), pad_inches=0)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
# draw_handwritten_letter('S', 'black', 4)
# draw_handwritten_letter('T', 'black', 4)

# Read an image, resize it to 64x64 and display it
def show64Image(path):
    """Load an image, resize it to 64x64 and display it with matplotlib.

    Args:
        path: Location of the image file to open.
    """
    # convert() returns a new image, so the file handle can be closed right away.
    with Image.open(path) as handle:
        img = handle.convert('RGB')
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img = img.resize((64, 64), Image.LANCZOS)
    pixels = transforms.ToTensor()(img)
    plt.figure(figsize=(1, 1), dpi=64)
    # ToTensor yields channels-first data; imshow expects channels-last.
    plt.imshow(pixels.permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.show()
# show64Image("./letter/S.png")
# show64Image("./letter/T.png")



def solveMLOT(M, reg, virtual, numItermax):
    """Dispatch to the solver matching the regularization and virtual settings.

    The marginals are read from the module-level globals ``s`` and ``t``.

    Args:
        M: Cost matrices forwarded unchanged to the solver.
        reg: Sequence of regularization values; a length of 1 selects the
            ``*_single`` solver variant, otherwise the first two entries
            are passed.
        virtual: ``0`` selects the solvers from ``mlot``; any other value
            selects the ``mlot_v`` solvers and is forwarded to them.
        numItermax: Iteration limit forwarded to the solver.

    Returns:
        Whatever the selected solver returns.
    """
    single_reg = len(reg) == 1
    reg_args = (reg[0],) if single_reg else (reg[0], reg[1])

    if virtual == 0:
        solver = mlot.multi_sinkhorn_single if single_reg else mlot.multi_sinkhorn
        return solver(s, t, M, *reg_args, numItermax)

    solver = mlot_v.mlot_virtual_single if single_reg else mlot_v.mlot_virtual
    return solver(s, t, M, *reg_args, virtual, numItermax)


def MLOT(k):
    """Run the multi-layer experiment with ``k`` layers and save a figure.

    Reads the module-level globals ``Dist``, ``s``, ``t``, ``TEST_VIRTUAL``,
    ``DIRECT`` and ``TEMPORARY``. Every parameter set is solved through
    ``solveMLOT`` using ``k - 1`` copies of ``Dist`` as cost matrices, and
    its layers are drawn as one row of the figure, which is written under
    ``./results/test_bary/``.

    Args:
        k: Total number of layers, i.e. the number of columns in the figure
            (source, ``k - 2`` intermediate layers, target).
    """
    M = [Dist for _ in range(k-1)]
    # Each entry: [regularization values, iteration limit].
    test_no_virtual = [
        [[5e-5, 5e-4], 5000],
        [[5e-5, 5e-3], 5000],
        [[5e-6, 5e-5], 5000],
        [[5e-6, 5e-4], 5000],
    ]

    # Each entry: [regularization values, virtual parameter, iteration limit].
    test_virtual = [
        [[1e-3], 5e-2, 2000],
        [[1e-4, 5e-4], 5e-2, 2000]
    ]

    results_layers = []
    if not TEST_VIRTUAL:
        for reg, numIter in test_no_virtual:
            print("[MLOT no-Virtual]--reg={}".format(reg))
            T = solveMLOT(M, reg, 0, numIter)
            # Intermediate layer i is taken as the sum of T[i] over dim 0.
            results_layers.append([T[i].sum(0) for i in range(k-2)])
    else:
        for reg, delta, numIter in test_virtual:
            print("[MLOT Virtual]--reg={}--virtual={}".format(reg, delta))
            T, log = solveMLOT(M, reg, delta, numIter)
            # Drop the first and last logged layers; list() keeps the
            # concatenation below valid for any sequence type.
            results_layers.append(list(log['layers'][1:-1]))

    settings = test_virtual if TEST_VIRTUAL else test_no_virtual

    # Plotting --------------------
    # squeeze=False keeps axs two-dimensional even for a single row or column.
    fig, axs = plt.subplots(len(results_layers), k, squeeze=False)
    for i, layers in enumerate(results_layers):
        # The plot always shows letter 'S' first and letter 'T' last.
        if DIRECT == 'ST':
            draw_layers = [s] + layers + [t]
        else:
            draw_layers = [t] + layers[::-1] + [s]

        for j, layer in enumerate(draw_layers):
            ax = axs[i, j]
            ax.imshow(layer.reshape(64, 64).cpu().numpy(), cmap='gray')
            ax.axis('off')
        # Put the parameters in the title of the first subplot of the row.
        ax = axs[i, 0]
        if not TEST_VIRTUAL:
            ax.set_title("{}".format(settings[i][0]))
        else:
            ax.set_title("{}\n{}".format(settings[i][0], settings[i][1]))
    # -------------------------

    if TEMPORARY:
        # NOTE(review): the original passed an unused `k` to format() here;
        # the file name is kept as-is. Add k to the name if runs with
        # different k should not overwrite each other.
        plt.savefig("./results/test_bary/TEMP-{}.png".format(DIRECT))
        print("Save MLOT temp results")
    else:
        suffix = "vir" if TEST_VIRTUAL else "NOvir"
        out_path = "./results/test_bary/{}_{}_{}.png".format(DIRECT, k, suffix)
        plt.savefig(out_path)
        print("Save fig to {}".format(out_path))


def barycenter():
    """Compute weighted intermediate layers between ``s`` and ``t`` and plot them.

    Reads the module-level globals ``Dist``, ``s``, ``t``, ``TEST_VIRTUAL``,
    ``DIRECT`` and ``TEMPORARY``. For every parameter set and every weight
    ``lbd`` in ``lambdas``, a problem with cost matrices
    ``[Dist*(1-lbd), Dist*lbd]`` is solved through ``solveMLOT`` and its
    middle layer is kept. Each parameter set becomes one row of the figure,
    which is written under ``./results/test_bary/``.
    """
    # Each entry: [regularization values, iteration limit].
    test_no_virtual = [
        [[5e-5, 5e-4], 5000],
        [[5e-5, 5e-3], 5000],
        [[5e-6, 5e-5], 5000],
        [[5e-6, 5e-4], 5000],
    ]

    # Each entry: [regularization values, virtual parameter, iteration limit].
    test_virtual = [
        [[1e-4], 5e-2, 2000],
        [[1e-4, 5e-4], 5e-2, 2000]
    ]

    lambdas = [0.25, 0.5, 0.75]
    results_layers = []
    if not TEST_VIRTUAL:
        for reg, numIter in test_no_virtual:
            batch_layers = []
            for lbd in lambdas:
                M = [Dist*(1-lbd), Dist*lbd]
                print("[Barycenter no-Virtual]--weight={}--reg={}".format([1-lbd,lbd], reg))
                T = solveMLOT(M, reg, 0, numIter)
                # Middle layer is taken as the sum of T[0] over dim 0.
                batch_layers.append(T[0].sum(0))
            results_layers.append(batch_layers)
    else:
        for reg, delta, numIter in test_virtual:
            batch_layers = []
            for lbd in lambdas:
                M = [Dist*(1-lbd), Dist*lbd]
                print("[Barycenter Virtual]--weight={}--reg={}--virtual={}".format([1-lbd,lbd], reg, delta))
                T, log = solveMLOT(M, reg, delta, numIter)
                batch_layers.append(log['layers'][1])
            results_layers.append(batch_layers)

    settings = test_virtual if TEST_VIRTUAL else test_no_virtual
    # One column per weight plus the two end images.
    n_cols = len(lambdas) + 2
    # Column labels are the same for both drawing directions.
    draw_lbd = [0] + lambdas + [1]

    # Plotting --------------------
    # squeeze=False keeps axs two-dimensional even for a single row.
    fig, axs = plt.subplots(len(results_layers), n_cols, squeeze=False)
    for i, layers in enumerate(results_layers):
        # The plot always shows letter 'S' first and letter 'T' last.
        if DIRECT == 'ST':
            draw_layers = [s] + layers + [t]
        else:
            draw_layers = [t] + layers[::-1] + [s]
        for j, layer in enumerate(draw_layers):
            ax = axs[i, j]
            print("layer sum={}".format(layer.sum().item()))
            ax.imshow(layer.reshape(64, 64).cpu().numpy(), cmap='gray')
            ax.axis('off')
            # Only the interior columns get a lambda title.
            if 0 < j < n_cols - 1:
                ax.set_title(r"$\lambda=${}".format(draw_lbd[j]))
        # Put the parameters in the title of the first subplot of the row.
        ax = axs[i, 0]
        if not TEST_VIRTUAL:
            ax.set_title("{}".format(settings[i][0]))
        else:
            ax.set_title("{}\n{}".format(settings[i][0], settings[i][1]))

    if TEMPORARY:
        plt.savefig("./results/test_bary/TEMP-{}-bary.png".format(DIRECT))
        print("Save BC temp results")
    else:
        suffix = "vir" if TEST_VIRTUAL else "NOvir"
        out_path = "./results/test_bary/{}_bary_{}.png".format(DIRECT, suffix)
        plt.savefig(out_path)
        print("Save fig to {}".format(out_path))


if __name__ == '__main__':
    # Script entry point: build the ground metric, load the two letter images
    # as normalized vectors, and run the selected experiment. The names
    # Dist, s, t, DIRECT, TEST_VIRTUAL and TEMPORARY are module globals read
    # by solveMLOT(), MLOT() and barycenter().
    setSeed(seed)
    # Keep the ground metric on the same device as the marginals s and t.
    Dist = normalizedGroundMetric().to(device)
    # Only the first channel of each image is used.
    img_s = plt.imread("./cache/S.png")[:, :, 0]
    img_t = plt.imread("./cache/T.png")[:, :, 0]
    # Alternative target image:
    #     img_t = plt.imread("./cache/cheetah_64.png")[:,:,0]
    print("Image shape:", img_s.shape, img_t.shape)

    DIRECT = 'ST'           # 'ST': S is the source; anything else: T is the source
    TEST_VIRTUAL = False    # True selects the mlot_v solvers
    TEMPORARY = False       # True writes to the TEMP-* output file instead

    if DIRECT == 'ST':
        s = torch.tensor(img_s).reshape(-1).to(device)
        t = torch.tensor(img_t).reshape(-1).to(device)
        print("[Source=s] sum={}, [Target=t] sum={}".format(s.sum().item(), t.sum().item()))
    else:
        s = torch.tensor(img_t).reshape(-1).to(device)
        t = torch.tensor(img_s).reshape(-1).to(device)
        print("[Source=t] sum={}, [Target=s] sum={}".format(s.sum().item(), t.sum().item()))
    # Normalize both images to unit total mass; Tensor.sum() avoids the
    # element-by-element Python loop of the builtin sum().
    s = s / s.sum().item()
    t = t / t.sum().item()
    # Tiny offset, presumably to avoid exact zeros in the marginals.
    # NOTE(review): 1e-40 is below the float32 normal range -- confirm the
    # solvers tolerate subnormal values if these tensors are float32.
    s += 1e-40
    t += 1e-40

    MLOT(4)
    # barycenter()
    print("Direct={}, Virtual={}, Temporary={}".format(DIRECT, TEST_VIRTUAL, TEMPORARY))
