# -*- coding: utf-8 -*-
"""
Elastic Optimal Transport Solvers
"""

# Author: elot
# License: MIT License

import numpy as np
import ot
from ot.lp import emd


def elot_emd(a, b, M, nb_dummies=1, log=False, **kwargs):
    """Solve an elastic OT problem exactly through an extended balanced EMD.

    The problem is rewritten as a balanced OT problem. ``nb_dummies`` dummy
    target points sharing a total mass of ``sum(a)`` are appended to ``b``.
    ``nb_dummies`` dummy source points sharing a total mass of ``sum(b)`` are
    appended to ``a``. Both extended marginals therefore carry
    ``sum(a) + sum(b)``.

    Every cost involving a dummy point is zero; only the real block uses
    ``M``. Moving mass between real points therefore lowers the objective
    only where ``M`` is negative. With a non-negative ``M``, an all-zero
    real block is an optimal plan.

    Parameters
    ----------
    a : array-like, shape (ns,)
        Source weights.
    b : array-like, shape (nt,)
        Target weights.
    M : array-like, shape (ns, nt)
        Cost matrix between the real source and target points.
    nb_dummies : int, optional
        Number of dummy points appended on each side (default 1). The dummy
        mass is split evenly between them.
    log : bool, optional
        If True, also return the solver log dictionary.
    **kwargs
        Forwarded unchanged to ``emd``.

    Returns
    -------
    gamma : ndarray, shape (ns, nt)
        Transport plan restricted to the real (non-dummy) points.
    log : dict
        Only returned if ``log`` is True. This is the log returned by
        ``emd``, with an added ``'partial_w_dist'`` entry equal to
        ``sum(M * gamma)``.

    Raises
    ------
    ValueError
        If the EMD solver reports a warning. The solver's own message is
        included in the exception text.
    """
    n_src, n_tgt = len(a), len(b)

    # Equivalent balanced problem: each side gets dummy mass equal to the
    # total mass of the other side, so the extended marginals always match.
    b_extended = np.append(b, [np.sum(a) / nb_dummies] * nb_dummies)
    a_extended = np.append(a, [np.sum(b) / nb_dummies] * nb_dummies)
    M_extended = np.zeros((len(a_extended), len(b_extended)))
    M_extended[:n_src, :n_tgt] = M

    # Call the exact EMD solver; the log is always requested so the
    # solver's warning can be inspected.
    gamma, log_ot = emd(a_extended, b_extended, M_extended, log=True,
                        **kwargs)

    if log_ot['warning'] is not None:
        # Surface the solver's diagnostic instead of discarding it.
        raise ValueError("Error in the EMD resolution: try to increase the"
                         " number of dummy points"
                         " (solver warning: {})".format(log_ot['warning']))

    # Drop the dummy rows/columns and keep only real-to-real transport.
    gamma_real = gamma[:n_src, :n_tgt]
    log_ot['partial_w_dist'] = np.sum(M * gamma_real)

    if log:
        return gamma_real, log_ot
    return gamma_real


def elot_entropic(a, b, M, reg, nb_dummies=1, numItermax=1000,
                  stopThr=1e-100, verbose=False, log=False, **kwargs):
    """Solve an entropic elastic OT problem through an extended Sinkhorn.

    This uses the same balanced reformulation as the exact solver.
    ``nb_dummies`` dummy target points sharing a total mass of ``sum(a)``
    are appended to ``b``. ``nb_dummies`` dummy source points sharing a
    total mass of ``sum(b)`` are appended to ``a``. Every cost involving a
    dummy point is zero; only the real block uses ``M``.

    Parameters
    ----------
    a : array-like, shape (ns,)
        Source weights.
    b : array-like, shape (nt,)
        Target weights.
    M : array-like, shape (ns, nt)
        Cost matrix between the real source and target points.
    reg : float
        Entropic regularization strength, forwarded to ``ot.sinkhorn``.
    nb_dummies : int, optional
        Number of dummy points appended on each side (default 1). The dummy
        mass is split evenly between them.
    numItermax : int, optional
        Maximum number of Sinkhorn iterations, forwarded to ``ot.sinkhorn``.
    stopThr : float, optional
        Stopping threshold, forwarded to ``ot.sinkhorn``.
    verbose : bool, optional
        Forwarded to ``ot.sinkhorn``.
    log : bool, optional
        If True, also return the solver log dictionary.
    **kwargs
        Forwarded unchanged to ``ot.sinkhorn``.

    Returns
    -------
    gamma : ndarray, shape (ns, nt)
        Regularized transport plan restricted to the real (non-dummy) points.
    log : dict
        Only returned if ``log`` is True. This is the log returned by
        ``ot.sinkhorn``, with an added ``'partial_w_dist'`` entry equal to
        ``sum(M * gamma)``.
    """
    n_src, n_tgt = len(a), len(b)

    # Equivalent balanced problem: each side gets dummy mass equal to the
    # total mass of the other side, so the extended marginals always match.
    b_extended = np.append(b, [np.sum(a) / nb_dummies] * nb_dummies)
    a_extended = np.append(a, [np.sum(b) / nb_dummies] * nb_dummies)
    M_extended = np.zeros((len(a_extended), len(b_extended)))
    M_extended[:n_src, :n_tgt] = M

    # NOTE(review): the default stopThr=1e-100 is presumably unreachable in
    # float64, so Sinkhorn likely runs the full numItermax iterations.
    # Convergence is not checked here either -- confirm this is intended.
    gamma, log_ot = ot.sinkhorn(a_extended, b_extended, M_extended, reg,
                                numItermax=numItermax, stopThr=stopThr,
                                verbose=verbose, log=True, **kwargs)

    # Drop the dummy rows/columns and keep only real-to-real transport.
    gamma_real = gamma[:n_src, :n_tgt]
    log_ot['partial_w_dist'] = np.sum(M * gamma_real)

    if log:
        return gamma_real, log_ot
    return gamma_real
