import time
import logging
logging.basicConfig(encoding='utf-8', level=logging.INFO)

import solvers
from utils.data.meshes.load import load_pointcloud_from_mesh
from utils.math.costs import euclidean
from utils.viz.transports import plot_transfer

# Settings nested under ARGS['SINK_ARGS'] and so handed to every solver built
# from ARGS (presumably configuring an inner Sinkhorn loop -- TODO confirm).
SINK_ARGS = dict(
    numItermax=50,
    symmetrize=True,
)

# Keyword arguments shared by every solver in this script (spread as **ARGS).
ARGS = dict(
    eps=1e-3,
    numItermax=100,
    stop_criterion='energy',
    stopThr=1e-5,
    SINK_ARGS=SINK_ARGS,
)

# Intended sample sizes for the source (N) and target (M) point clouds.
N = 5000
M = 5000

def main():
    """Benchmark the exact and approximate solvers from ``solvers`` on two meshes.

    Samples a source cloud of ``N`` points and a target cloud of ``M`` points,
    runs each solver with the shared ``ARGS`` configuration, prints its
    wall-clock solve time and plots the resulting transport plan.

    Solvers are constructed one at a time and released after use, so only a
    single solver (and its buffers) is alive at any moment.
    """
    X = load_pointcloud_from_mesh('data/00049424_ferrari.ply', N=N)
    # The target cloud uses its own sample size M.
    Y = load_pointcloud_from_mesh('data/muybridge_014_01.ply', N=M)

    def cost(u, v):
        """Ground cost shared by all solvers: ``euclidean`` with ``p=1``."""
        return euclidean(u, v, p=1)

    # Dense pairwise intra-domain cost matrices; only the exact solvers use them.
    Cx = cost(X[:, None], X[None, :])
    Cy = cost(Y[:, None], Y[None, :])

    # Each solver is built inside its loop iteration (not all up front) and
    # deleted at the end of it. Keeping them in a list would keep every solver
    # referenced, so nothing could be reclaimed between runs.
    for solver_cls in (solvers.EntropicGW, solvers.KernelGW):
        solver = solver_cls(Cx=Cx, Cy=Cy, **ARGS)
        start_time = time.perf_counter()
        solver.to('cuda')  # the device transfer is part of the reported time
        solver.solve(verbose=True)
        elapsed = time.perf_counter() - start_time
        print(f"Solver: {type(solver).__name__} \t Time: {elapsed}")
        P = solver.transport_plan(lazy=False).cpu()
        solver.to('cpu')
        del solver  # drop the last reference before the next solver is built
        plot_transfer(X, Y, P, lazy=False)

    # The dense cost matrices are not needed by the approximate solvers.
    del Cx, Cy

    # (solver class, extra constructor kwargs) for the approximate solvers.
    approx_specs = (
        (solvers.CntGW, {}),
        (solvers.LowRankGW, {}),
        (solvers.MultiscaleCntGW, {'ratio': 0.1}),
    )
    for solver_cls, extra_kwargs in approx_specs:
        solver = solver_cls(X, Y, cost, approx_dims=20, **extra_kwargs, **ARGS)
        start_time = time.perf_counter()
        solver.solve(verbose=True)
        elapsed = time.perf_counter() - start_time
        print(f"Solver: {type(solver).__name__} \t Time: {elapsed}")
        P = solver.transport_plan(lazy=True)
        plot_transfer(X, Y, P, lazy=True)
        # The lazy plan may reference solver internals; release both together.
        del solver, P


if __name__ == "__main__":
    main()
