import click
from src.tools import fit_and_time

from src.models.frot import Frot
from src.models.group_wasserstein import WassersteinByGroup
from src.models.mmd_distances import MMD_distances
from src.models.linear_correlation import LinearCoor

from src.data.bio_loader import BioLoaders
from src.evaluate.biological import BiologicalEvaluator

from tqdm import tqdm


@click.command()
@click.argument('dataset', type=click.Choice(list(BioLoaders.keys())), default="toy")
@click.option('--folder', default="biological_experiments")
@click.option('--eps', default=0.01, help="Sinkhorn parameter")
@click.option('--niter', default=10, help="Number of iterations")
@click.option('--nepochs', default=50, help="Number of epochs")
@click.option('--gpu', default=None, type=int, help="GPU number")
@click.option('--show/--no-show', default=True, help="show matching")
def main(dataset, folder, eps, niter, nepochs, gpu, show):
    """Fit all comparison models on DATASET over several seeds and save the results.

    \f
    For every seed in ``range(nepochs)`` the dataset is reloaded (normalized,
    seeded, placed on ``device``). Fresh model instances are then built and
    fitted: five ``Frot`` models (one per ``eta`` value) plus
    ``WassersteinByGroup``, ``MMD_distances`` and ``LinearCoor``. Each
    experiment is handed to a ``BiologicalEvaluator``. All experiments are
    saved to ``folder`` at the end.

    Args:
        dataset: Key into ``BioLoaders`` selecting the data loader.
        folder: Output folder passed to ``evaluator.save_experiments``.
        eps: Sinkhorn parameter, forwarded to ``Frot`` and ``WassersteinByGroup``.
        niter: Iteration count forwarded to ``Frot``.
        nepochs: Number of repetitions; each one uses its index as the seed.
        gpu: CUDA device index, or ``None`` to run on CPU.
        show: Currently unused, because figure creation is disabled below.
    """
    device = "cpu" if gpu is None else "cuda:{}".format(gpu)
    evaluator = BiologicalEvaluator(dataset)
    # Values of Frot's `eta` hyperparameter swept in every repetition.
    etas = (0.3, 0.5, 1.0, 2.0, 5.0)

    pbar = tqdm(range(nepochs))
    for seed in pbar:
        # Each "epoch" is an independent repetition whose data depends on `seed`.
        data = BioLoaders[dataset](normalization=True, seed=seed, device=device)

        # Models are rebuilt for every seed, so each fit starts from a fresh instance.
        frots = [Frot(eta=eta, niter=niter, eps=eps) for eta in etas]
        others = [WassersteinByGroup(eps=eps), MMD_distances(), LinearCoor()]

        # NOTE(review): models that share a `modelname` would overwrite each
        # other in this dict; this assumes every name is unique.
        models = {model.modelname: model for model in frots + others}

        for modelname, model in models.items():
            pbar.set_description("Fitting {}".format(modelname))
            # Presumably fits `model` and records its runtime (see src.tools) —
            # the return value is intentionally discarded here.
            fit_and_time(model)(data.X, data.Y, data.groups, platform=data.platform)

        evaluator.add_experiment(seed, data, models)

    evaluator.save_experiments(folder=folder)

    # NOTE: figure creation is disabled, so --show/--no-show currently has no
    # effect. To re-enable it, call:
    #     evaluator.create_figure(folder=folder, show=show)
    

if __name__ == "__main__":
    # Run as a script: click parses sys.argv and dispatches to the command.
    main()
