import click

from src.models.frot import FrotSinkhorn,FrotLP,FrotEMD
from src.data.toy import ToyLoader
from src.evaluate.toy import ToyEvaluator

import numpy as np
import matplotlib.pyplot as plt

import torch


# Directory the figures were originally hard-coded to. Kept as the default of
# --outdir so existing invocations keep writing to the same place.
_DEFAULT_OUTDIR = '/home/myamada/Dropbox/Apps/Overleaf/FROT/synthetic_exp'

# Number of grid points in each of the eta and eps sweeps.
_N_STEPS = 20


def _abs_error(loss, reference):
    """Return the absolute difference between two tensor-valued losses.

    Equivalent to sqrt((loss - reference) ** 2) but without the intermediate
    square, which can underflow/overflow for very small/large differences.
    """
    return torch.abs(loss - reference)


def _plot_curves(x, curves, xlabel, ylabel, path, show):
    """Plot the ``(y, label)`` pairs in *curves* against *x* and save to *path*.

    The figure is displayed only when *show* is true and is always closed
    afterwards so repeated calls do not accumulate open figures.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    for y, label in curves:
        ax.plot(x, y, label=label)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend(loc='best')
    fig.savefig(path, bbox_inches='tight')
    if show:
        plt.show()
    plt.close(fig)


@click.command()
@click.option('--eta', default=0.5, help="Value of eta")
@click.option('--eps', default=0.1, help="Sinkhorn parameter")
@click.option('--niter', default=10, help="Number of iterations")
@click.option('--show/--no-show', default=False,
              help="Display each figure after saving it")
@click.option('--outdir', default=_DEFAULT_OUTDIR, show_default=True,
              help="Directory the PDF figures are written to")
def main(eta, eps, niter, show, outdir=_DEFAULT_OUTDIR):
    """Compare FW-EMD and FW-Sinkhorn against the LP solution on toy data.

    Two sweeps of 20 points each are run: one over eta (1.0, 1.5, ...) using
    the given eps, and one over eps (0.1, 0.15, ...) with eta fixed at 1.0.
    Three PDF figures are written to OUTDIR.

    Note: the --eta option is accepted but not used by either sweep.
    """
    import os

    # Fail fast on an unusable output directory instead of after all fits.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    data = ToyLoader(device="cpu")

    # Oracle: LP solution used as the reference objective value.
    modellp = FrotLP()
    modellp.fit(data.X, data.Y, data.groups, platform=data.platform)
    lp_loss = modellp.losses_

    lp_score = np.ones(_N_STEPS) * lp_loss
    emd_score = np.zeros(_N_STEPS)
    sh_score = np.zeros(_N_STEPS)
    emd_error_eta = np.zeros(_N_STEPS)
    sh_error_eta = np.zeros(_N_STEPS)
    sh_error_eps = np.zeros(_N_STEPS)

    # --- Sweep over eta (eps taken from the command line) -------------------
    x = np.zeros(_N_STEPS)
    for i in range(_N_STEPS):
        # Loop-local name so the CLI parameters are not silently rebound.
        eta_i = 1.0 + i * 0.5
        x[i] = 1 / eta_i

        modelemd = FrotEMD(eta=eta_i, niter=niter)
        modelemd.fit(data.X, data.Y, data.groups, platform=data.platform)

        # Sinkhorn
        modelsh = FrotSinkhorn(eta=eta_i, niter=niter, eps=eps)
        modelsh.fit(data.X, data.Y, data.groups, platform=data.platform)

        # The last recorded loss of each solver is used as its final score.
        emd_score[i] = modelemd.losses_[-1]
        sh_score[i] = modelsh.losses_[-1]

        emd_error_eta[i] = _abs_error(modelemd.losses_[-1], lp_loss)
        sh_error_eta[i] = _abs_error(modelsh.losses_[-1], lp_loss)

    # --- Sweep over eps with eta held fixed ---------------------------------
    x_eps = np.zeros(_N_STEPS)
    eta_fixed = 1.0
    for i in range(_N_STEPS):
        eps_i = 0.1 + i * 0.05
        x_eps[i] = 1 / eps_i

        # Sinkhorn
        modelsh = FrotSinkhorn(eta=eta_fixed, niter=niter, eps=eps_i)
        modelsh.fit(data.X, data.Y, data.groups, platform=data.platform)

        sh_error_eps[i] = _abs_error(modelsh.losses_[-1], lp_loss)

    # --- Figures ------------------------------------------------------------
    plt.rcParams["font.size"] = 18

    _plot_curves(
        x,
        [(lp_score, 'LP'), (emd_score, 'FW-EMD'), (sh_score, 'FW-Sinkhorn')],
        r'$\eta^{-1}$',
        'Objective score',
        os.path.join(outdir, 'toy_convergence.pdf'),
        show,
    )

    # NOTE(review): the plotted quantity is an absolute error, not a mean
    # squared error; the 'MSE' axis label is kept unchanged so the published
    # figures stay identical -- confirm whether it should be renamed.
    _plot_curves(
        x,
        [(emd_error_eta, 'FW-EMD'), (sh_error_eta, 'FW-Sinkhorn')],
        r'$\eta^{-1}$',
        'MSE',
        os.path.join(outdir, 'toy_convergence_PI_eta.pdf'),
        show,
    )

    _plot_curves(
        x_eps,
        [(sh_error_eps, 'FW-Sinkhorn')],
        r'$\epsilon^{-1}$',
        'MSE',
        os.path.join(outdir, 'toy_convergence_PI_eps.pdf'),
        show,
    )

# Invoke the click command only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
