import numpy as np
import ot
import scipy as sp
import pickle
import datetime as dt
import multiprocessing as mp

def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m_1, reg_m_2, numItermax=1000,
                              stopThr=1e-6, verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic unbalanced OT problem with separate marginal penalties.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
        \mathrm{reg}_{m,1} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg}_{m,2} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    Unlike POT's ``ot.unbalanced.sinkhorn_knopp_unbalanced``, the source and
    target marginals are relaxed with independent weights ``reg_m_1`` and
    ``reg_m_2``. With ``reg_m_1 == reg_m_2`` the iterations are the same as
    POT's. The algorithm is the generalized Sinkhorn-Knopp matrix scaling
    algorithm of :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`.

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`. If empty, a uniform
        histogram is used.
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT losses (a, b_i). If empty, a uniform
        histogram is used.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m_1 : float
        Relaxation weight > 0 of the source marginal (KL to `a`)
    reg_m_2 : float
        Relaxation weight > 0 of the target marginal (KL to `b`)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the relative change of the scalings (> 0)
    verbose : bool, optional
        Print the error along iterations
    log : bool, optional
        Record log if True
    **kwargs
        Accepted for call compatibility and ignored.

    Returns
    -------
    if `b` is one-dimensional:
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`
    else:
        - ot_distance : (n_hists,) array-like
            :math:`\langle \gamma_k, \mathbf{M} \rangle_F` for each histogram
            :math:`\mathbf{b}_k`
        - log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> a = [.5, .5]
    >>> b = [.5, .5]
    >>> M = [[0., 1.], [1., 0.]]
    >>> sinkhorn_knopp_unbalanced(a, b, M, 1., 1., 1.)
    array([[0.51122823, 0.18807035],
           [0.18807035, 0.51122823]])


    .. _references-sinkhorn-knopp-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    M, a, b = ot.utils.list_to_array(M, a, b)
    nx = ot.backend.get_backend(M, a, b)

    dim_a, dim_b = M.shape

    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    # A 2-D `b` means several target histograms are solved at once.
    n_hists = b.shape[1] if len(b.shape) > 1 else 0

    # Keep the caller's flag and the log dictionary in separate names.
    log_dict = {'err': []} if log else None

    if n_hists:
        u = nx.ones((dim_a, 1), type_as=M) / dim_a
        v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
        a = a.reshape(dim_a, 1)
    else:
        u = nx.ones(dim_a, type_as=M) / dim_a
        v = nx.ones(dim_b, type_as=M) / dim_b

    K = nx.exp(M / (-reg))

    # Exponents of the scaling updates. As reg_m -> inf they tend to 1, which
    # is the balanced Sinkhorn update.
    fi_1 = reg_m_1 / (reg_m_1 + reg)
    fi_2 = reg_m_2 / (reg_m_2 + reg)

    for i in range(numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        u = (a / Kv) ** fi_1
        Ktu = nx.dot(K.T, u)
        v = (b / Ktu) ** fi_2

        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # Machine precision reached: restore the previous (finite)
            # scalings and stop silently.
            u = uprev
            v = vprev
            break

        err_u = nx.max(nx.abs(u - uprev)) / max(
            nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
        )
        err_v = nx.max(nx.abs(v - vprev)) / max(
            nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
        )
        err = 0.5 * (err_u + err_v)
        if log_dict is not None:
            log_dict['err'].append(err)
        # Printing depends only on `verbose`; it used to require log=True too.
        if verbose:
            if i % 50 == 0:
                print(
                    '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(i, err))
        if err < stopThr:
            break

    if log_dict is not None:
        log_dict['logu'] = nx.log(u + 1e-300)
        log_dict['logv'] = nx.log(v + 1e-300)

    if n_hists:
        # Return only the transport cost <gamma_k, M> for each histogram k.
        res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
    else:
        # Return the transport plan diag(u) K diag(v).
        res = u[:, None] * K * v[None, :]

    if log_dict is not None:
        return res, log_dict
    return res


def bi(a, b, M):
    """Column-normalize ``M`` and scale column ``j`` by ``b[j]``.

    Computes ``M[i, j] / sum_k M[k, j] * b[j]`` without modifying ``M``.

    Parameters
    ----------
    a : ignored
        Unused. It is accepted only so that ``bi`` has the same ``(a, b, M)``
        positional signature as the OT solvers in this file.
    b : array-like, shape (M.shape[1],)
        Per-column weights.
    M : array-like, 2-D
        Non-negative matrix. A column summing to zero produces nan/inf in
        that column (numpy division semantics).

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``M``.
    """
    M = np.asarray(M)
    # Out-of-place division: the old in-place ``/=`` raised for integer M and
    # failed outright for list input.
    return M / np.sum(M, axis=0) * b


class LearnerOT:
    """Learner that updates a belief vector ``theta`` from observed data indices.

    On every call to :meth:`learn`, the learner builds a matrix from
    ``(eta, theta, M)`` with the selected ``method``. It then takes the row
    of the observed datum, normalized, as its new ``theta``.

    Parameters
    ----------
    M : numpy.ndarray, shape (n, m)
        Matrix handed to the update method. Rows are indexed by data, and
        ``theta`` has one entry per column.
    theta : array-like, shape (m,)
        Initial belief.
    eta : array-like, shape (n,), optional
        Data distribution. If given, it is used as-is (``exact_eta=True``).
        Otherwise the learner starts from a uniform ``eta`` and updates it
        after each observation.
    log : bool
        If True, record observed data in ``self.data`` (starting with a -1
        placeholder) and beliefs in ``self.posterior`` (starting with the
        initial ``theta``).
    method : {"ot", "uot", "bi"}
        ``"ot"`` uses ``ot.sinkhorn``, ``"uot"`` uses
        ``sinkhorn_knopp_unbalanced`` and ``"bi"`` uses ``bi``.
    param : tuple
        Extra positional arguments for the solver, e.g. ``(epsilon,)`` for
        "ot" and ``(epsilon, tau1, tau2)`` for "uot". Ignored by "bi".
    update_M : bool
        If True, ``M`` is replaced by the matrix computed at each step.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """

    def __init__(self, M, theta,
                 eta=None,
                 log=False,
                 method="ot",
                 param=(1,),
                 update_M=False):
        """Initialize the learner; see the class docstring for parameters."""
        self.M = M
        self.theta = theta
        self.n, self.m = M.shape
        self.update_M = update_M
        if eta is None:
            self.exact_eta = False
            self.eta = np.ones(self.n, dtype=np.float64) / self.n
        else:
            self.exact_eta = True
            self.eta = eta

        # Number of observations so far, counting the initial eta as one
        # pseudo-observation.
        self.total_data = 1
        self.log = log
        if log:
            self.data = [-1, ]
            self.posterior = [theta]
        self.method_flag = method
        if method == "ot":
            self.method = lambda x: ot.sinkhorn(*x, *param)
        elif method == "uot":
            self.method = lambda x: sinkhorn_knopp_unbalanced(*x, *param)
        elif method == "bi":
            self.method = lambda x: bi(*x)
        else:
            # Fail fast. Otherwise `self.method` would be missing and
            # `learn` would raise an AttributeError later.
            raise ValueError(
                "unknown method %r; expected 'ot', 'uot' or 'bi'" % (method,))

    def learn(self, data):
        """Update ``theta`` after observing the datum with row index ``data``.

        Returns the new ``theta``.
        """
        mat = self.method([self.eta, self.theta, self.M])
        if self.update_M:
            self.M = mat.copy()

        row = mat[data, :]
        if self.method_flag == "ot":
            # Rows of the balanced plan target the marginal eta, so dividing
            # by eta[data] normalizes the row (up to solver convergence).
            self.theta = row / self.eta[data]
        elif self.method_flag in ("uot", "bi"):
            # Row sums are not fixed for these methods; normalize explicitly.
            self.theta = row / np.sum(row)
        if self.log:
            self.data += [data]
            self.posterior += [self.theta]
        if not self.exact_eta:
            # Running mean: eta <- (t * eta + e_data) / (t + 1).
            self.eta[data] += 1./self.total_data
            self.eta *= self.total_data/(self.total_data+1)
        self.total_data += 1
        return self.theta


class TeacherNonStationary:
    """Stub for a non-stationary teacher.

    It currently defines no attributes or methods (in particular no
    ``teach``), so it cannot yet be used as a ``teacherClass``.
    """

    pass


class TeacherCyclic:
    """Teacher that samples data from the columns of ``M`` in turn.

    Each column is used for ``phase_len`` consecutive calls before moving to
    the next one. The walk starts at column ``phase`` and wraps around after
    ``phase_len * m`` calls.
    """

    def __init__(self, M, phase_len=5, phase=0):
        """Store ``M`` (shape (n, m)) and the cycle parameters."""
        self.n, self.m = M.shape
        self.current = -1
        self.M = M
        self.phase_len = phase_len
        self.phase = phase

    def teach(self):
        """Advance one step and draw a data index from the active column."""
        cycle_length = self.phase_len * self.m
        self.current = (self.current + 1) % cycle_length
        column = (self.current // self.phase_len + self.phase) % self.m
        return np.random.choice(self.n, p=self.M[:, column])


def area_of_cycle(seq):
    """Return the area of the triangle fan spanned by a closed sequence of points.

    The sequence is triangulated as ``(seq[0], seq[i-1], seq[i])`` for
    ``i = 2 .. len(seq)-1``. Each triangle's area is computed with the
    dimension-independent formula
    ``0.5 * sqrt(|v1|^2 |v2|^2 - (v1 . v2)^2)``, and the areas are summed.
    For a planar convex polygon this equals its enclosed area.

    Parameters
    ----------
    seq : array-like, shape (k, d) or (k,)
        Points of the cycle. A 1-D input is treated as ``k`` points in 1-D.

    Returns
    -------
    float
        Total (unsigned) area; ``0.0`` when fewer than three points are given.
    """
    pts = np.asarray(seq, dtype=np.float64)
    if len(pts) < 3:
        return 0.0
    pts = pts.reshape(len(pts), -1)
    rel = pts[1:] - pts[0]
    prev, curr = rel[:-1], rel[1:]
    gram = (np.sum(curr ** 2, axis=1) * np.sum(prev ** 2, axis=1)
            - np.sum(curr * prev, axis=1) ** 2)
    # For (nearly) collinear points the Gram determinant can round to a tiny
    # negative value; clamp it so sqrt does not return nan.
    return np.sum(np.sqrt(np.maximum(gram, 0.0))) / 2


class TeacherOnCircle:
    """Teacher that walks around the closed polygon through the columns of ``M``.

    Between column ``j`` and column ``(j + 1) % m`` it inserts ``phase_len``
    linearly interpolated distributions. ``hypo`` therefore has
    ``m * phase_len`` rows, and ``teach`` samples from them in order,
    starting at row ``phase``.
    """

    def __init__(self, M, phase_len=5, phase=0):
        self.M = M
        self.n, self.m = M.shape
        self.current = phase - 1
        self.phase = phase
        self.phase_len = phase_len
        self.period = int(self.m * self.phase_len)
        # Column vector of interpolation weights 0, 1/phase_len, ... (< 1).
        self.l = np.linspace(0, 1, phase_len, endpoint=False).reshape(-1, 1)
        weights = self.l
        segments = [
            M[:, col] * (1 - weights) + M[:, (col + 1) % self.m] * weights
            for col in range(self.m)
        ]
        self.hypo = np.concatenate(segments, axis=0)
        # The area is computed here, but its value is not stored.
        self.path_area()

    def teach(self):
        """Advance to the next row of ``hypo`` and sample a data index from it."""
        self.current = (self.current + 1) % self.period
        return np.random.choice(self.n, p=self.hypo[self.current])

    def path_area(self):
        """Return ``area_of_cycle`` evaluated on the interpolated path."""
        return area_of_cycle(self.hypo)


class TeacherPeriod(TeacherOnCircle):
    """Teacher that cycles through an explicit sequence of distributions.

    If ``path`` is None, this behaves exactly like ``TeacherOnCircle``.
    Otherwise ``np.array(path)`` becomes ``hypo``. Each of its rows is used,
    in order, as the sampling distribution over the ``n`` data indices, and
    the walk starts at row ``phase``.
    """

    def __init__(self, M, phase_len=5, phase=0, path=None):
        if path is None:
            # Zero-argument super(): the former `super(self.__class__, self)`
            # recursed forever when TeacherPeriod was subclassed.
            super().__init__(M, phase_len, phase)
            return
        self.M = M
        self.n, self.m = M.shape
        self.phase = phase
        self.phase_len = phase_len
        self.current = phase - 1
        self.period = len(path)
        self.hypo = np.array(path)
        # The area is computed here, but its value is not stored.
        self.path_area()


class TeacherNaive:
    """Teacher that draws each data index independently from ``eta``."""

    def __init__(self, eta):
        """Store the sampling distribution ``eta`` (a 1-D probability vector)."""
        self.eta = eta
        # Number of data indices to choose from (the length of eta).
        self.m = eta.shape[0]

    def teach(self):
        """Sample one data index according to ``eta``."""
        sample = np.random.choice(self.m, 1, replace=False, p=self.eta)
        return sample[0]

# PROGRESS = mp.shared_memory.Shareable_List([0.,] * process_cnt)

def single(pack):
    """Run ``episode_size`` teaching episodes in one worker and return the learner logs.

    Parameters
    ----------
    pack : tuple
        ``(pid, seed, M, eta, theta, max_step, episode_size, exact_eta,
        method, param, update_M, teacherClass, teacher_args)`` where

        - pid: index of this process (used in progress messages only)
        - seed: seed passed to ``np.random.seed`` for this process
        - M: matrix of size n*m
        - eta: vector of size n
        - theta: vector of size m
        - max_step: length of one episode
        - episode_size: amount of episodes in Monte Carlo
        - exact_eta: if True the learner is given ``eta``; otherwise it
          estimates it from the data it observes
        - method: "ot", "uot" or "bi"
        - param: tuple, (epsilon,) for ot and (epsilon, tau1, tau2) for uot
        - update_M: True or False, whether learner updates M in an episode
        - teacherClass: called as ``teacherClass(M.copy(), *teacher_args)``
          at the start of each episode
        - teacher_args: extra positional arguments for ``teacherClass``

    Returns
    -------
    list of tuple
        One ``(learner.data, learner.posterior)`` pair per episode.
    """
    (pid, seed, M, eta, theta, max_step, episode_size, exact_eta,
     method, param, update_M, teacherClass, teacher_args) = pack
    np.random.seed(seed)

    logs = []
    for k in range(episode_size):
        teacher = teacherClass(M.copy(), *teacher_args)
        # Passing eta=None makes the learner estimate eta itself.
        learner = LearnerOT(M.copy(), theta.copy(),
                            eta.copy() if exact_eta else None,
                            log=True, method=method, param=param,
                            update_M=update_M)
        # Teach and learn
        for _ in range(max_step):
            learner.learn(teacher.teach())

        logs.append((learner.data, learner.posterior))
        if k % 20 == 0:
            print("process %d: %.1f%% completed" % (pid, k / episode_size * 100))
    return logs


def mp_test(M, eta, theta,
            max_step, episode_size,
            process_cnt=4,
            exact_eta=False,
            method="ot", param=(1,),
            update_M=False, suffix="",
            phase_len=5, phase=0, teacherClass=TeacherCyclic,
            teacher_args=None):
    """
    Run ``process_cnt`` independent Monte Carlo batches in a process pool.

    The inputs and the results are pickled to ``./data/mp_<suffix>.dat``.
    The directory is created if it does not exist.

    M: matrix of size n*m
    eta: vector of size n
    theta: vector of size m
    max_step: length of one episode
    episode_size: amount of episodes in Monte Carlo (per task)
    process_cnt: number of tasks; each gets its own random seed
    exact_eta: True or False, whether the learner is given eta
    method: "ot", "uot" or "bi"
    param: tuple, (epsilon,) for ot and (epsilon, tau1, tau2) for uot
    update_M: True or False, whether learner updates M in an episode
    suffix: appended to the output file name
    phase_len, phase: first two positional arguments for teacherClass
    teacherClass: teacher class used by every episode
    teacher_args: optional sequence of extra positional arguments for
        teacherClass, passed after phase_len and phase

    Returns the list of per-task results of ``single``.
    """
    import os

    # None instead of a shared mutable default list.
    extra_args = [] if teacher_args is None else list(teacher_args)
    teacher_args = [phase_len, phase] + extra_args
    seeds = np.random.randint(2**32-1, size=process_cnt, dtype=np.int64)
    pack_list = [(pid, seeds[pid], M, eta, theta, max_step, episode_size,
                  exact_eta, method, param, update_M, teacherClass, teacher_args)
                 for pid in range(process_cnt)]
    # The context manager releases the worker processes once map() returns.
    with mp.Pool() as pool:
        results = pool.map(single, pack_list)

    out_path = "./data/mp_" + suffix + ".dat"
    # Create the directory so a missing folder does not discard the results.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "wb") as fp:
        pickle.dump({"pack_list": pack_list, "result": results}, fp)

    return results


if __name__ == "__main__":
    episode_size = 100000
    max_step = 50
    n, m = 3, 3
    epsilon = 1
