import numpy as np
import ot
import scipy as sp
import pickle
import datetime as dt
import multiprocessing as mp

def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m_1, reg_m_2, numItermax=1000,
                              stopThr=1e-6, verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem
    with a separate relaxation strength for each marginal.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg}\cdot\Omega(\gamma) +
        \mathrm{reg\_m\_1} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg\_m\_2} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`.
        An empty `a` is replaced by the uniform histogram.
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`
        If many, compute all the OT distances (a, b_i).
        An empty `b` is replaced by the uniform histogram.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m_1 : float
        Marginal relaxation term > 0 for the first marginal (the one
        compared with `a`)
    reg_m_2 : float
        Marginal relaxation term > 0 for the second marginal (the one
        compared with `b`)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations (only takes effect when `log`
        is True, because the printing is nested under the logging branch)
    log : bool, optional
        record log if True
    **kwargs
        Accepted and ignored.


    Returns
    -------
    if `b` is one-dimensional:
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`
    else:
        - ot_distance : (n_hists,) array-like
            the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------
    With the same relaxation on both marginals the iteration coincides with
    the single-`reg_m` solver of POT:

    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> sinkhorn_knopp_unbalanced(a, b, M, 1., 1., 1.)
    array([[0.51122823, 0.18807035],
           [0.18807035, 0.51122823]])


    .. _references-sinkhorn-knopp-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    # Imported here so the numerical-error branch below can actually emit its
    # warning; the module header does not import `warnings`.
    import warnings

    M, a, b = ot.utils.list_to_array(M, a, b)
    nx = ot.backend.get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty histograms default to the uniform distribution.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    # n_hists == 0 flags the single-histogram case (1-D `b`).
    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if log:
        log = {'err': []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if n_hists:
        u = nx.ones((dim_a, 1), type_as=M) / dim_a
        v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
        a = a.reshape(dim_a, 1)
    else:
        u = nx.ones(dim_a, type_as=M) / dim_a
        v = nx.ones(dim_b, type_as=M) / dim_b

    # Gibbs kernel
    K = nx.exp(M / (-reg))

    # Exponents of the scaling updates: fi_1 is applied to the `a` side,
    # fi_2 to the `b` side.
    fi_1 = reg_m_1 / (reg_m_1 + reg)
    fi_2 = reg_m_2 / (reg_m_2 + reg)

    err = 1.

    for i in range(numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        u = (a / Kv) ** fi_1
        Ktu = nx.dot(K.T, u)
        v = (b / Ktu) ** fi_2

        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % i)
            u = uprev
            v = vprev
            break

        # Relative change of the scalings, with the denominator floored at 1.
        err_u = nx.max(nx.abs(u - uprev)) / max(
            nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
        )
        err_v = nx.max(nx.abs(v - vprev)) / max(
            nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
        )
        err = 0.5 * (err_u + err_v)
        if log:
            log['err'].append(err)
            if verbose:
                if i % 50 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(i, err))
        if err < stopThr:
            break

    if log:
        # Small offset keeps the logarithm finite for zero scalings.
        log['logu'] = nx.log(u + 1e-300)
        log['logv'] = nx.log(v + 1e-300)

    if n_hists:  # return only loss
        res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
        if log:
            return res, log
        else:
            return res

    else:  # return OT matrix

        if log:
            return u[:, None] * K * v[None, :], log
        else:
            return u[:, None] * K * v[None, :]

        
def bi(a, b, M):
    '''
    Normalise every column of `M` to sum to one, then scale the columns by `b`.

    a: unused; kept so the function accepts the same (eta, theta, M) triple
       as the other update rules
    b: vector broadcast against the columns of M (one weight per column)
    M: matrix of size n*m; it is not modified

    Returns the matrix `M / sum(M, axis=0) * b`.
    '''
    M = np.asarray(M)
    # Out-of-place division: the in-place form (`mat /= ...`) raises for
    # integer matrices because the float quotient cannot be stored back.
    return M / np.sum(M, axis=0) * b
        
class LearnerOT:
    '''
    Learner that updates a belief vector `theta` one observation at a time.

    Each call to `learn` builds a matrix from (eta, theta, M) with the
    selected update rule and replaces `theta` by the normalised row of that
    matrix indexed by the observed datum.
    '''
    _METHODS = ("ot", "uot", "bi")

    def __init__(self, M, theta, 
                 eta=None, 
                 log=False, 
                 method="ot", 
                 param=(1,), 
                 update_M=False):
        '''
        M: matrix of size n*m
        theta: vector of size m, the initial belief
        eta: vector of size n; if None the learner starts from the uniform
             vector and re-estimates it from the data it observes
        log: if True, keep the observed data and every posterior in
             `self.data` and `self.posterior`
        method: "ot", "uot" or "bi"
        param: tuple of extra positional arguments forwarded to the solver,
               (epsilon,) for ot and (epsilon, tau1, tau2) for uot
        update_M: if True, `M` is replaced by the computed matrix after
                  every step

        Raises ValueError if `method` is not one of the supported names.
        '''
        # Fail at construction instead of with an AttributeError on the
        # first call to `learn`.
        if method not in self._METHODS:
            raise ValueError(
                "unknown method %r, expected one of %r" % (method, self._METHODS))

        self.M = M
        self.theta = theta
        self.n, self.m = M.shape
        self.update_M = update_M
        if eta is None:
            self.exact_eta = False
            self.eta = np.ones(self.n,dtype=np.float64) / self.n
        else:
            self.exact_eta = True
            self.eta = eta
            
        # Starts at 1: the uniform initial eta counts as one pseudo-observation
        # in the running average maintained by `learn`.
        self.total_data = 1
        self.log = log
        if log:
            self.data = [-1,]
            self.posterior = [theta]
        self.method_flag = method
        if method == "ot":
            self.method = lambda x: ot.sinkhorn(*x, *param)
        elif method == "uot":
            self.method = lambda x: sinkhorn_knopp_unbalanced(*x, *param)
        else:  # "bi"
            self.method = lambda x: bi(*x)
            
    def learn(self, data):
        '''
        Update the belief with one observed datum and return the new theta.

        data: row index of the observation
        '''
        mat = self.method([self.eta, self.theta, self.M,])
        if self.update_M:
            self.M = mat.copy()
        
        row = mat[data, :]
        if self.method_flag == "ot":
            # Row of the plan divided by the mass assigned to this datum.
            self.theta = row / self.eta[data]
        elif self.method_flag in ("uot", "bi"):
            # Rows need not sum to eta[data] here, so normalise explicitly.
            self.theta = row / np.sum(row)
        if self.log:
            self.data += [data]
            self.posterior += [self.theta]
        if not self.exact_eta:
            # Running average:
            # eta <- (total_data * eta + onehot(data)) / (total_data + 1)
            self.eta[data] += 1./self.total_data
            self.eta *= self.total_data/(self.total_data+1)
        self.total_data += 1
        return self.theta
        
class TeacherNonStationary:
    '''
    Placeholder class: no attributes or methods are defined yet.
    '''

    
class TeacherCyclic:
    '''
    Teacher that samples data from the columns of `M` in a repeating cycle.

    Column 0 is used for `phase_len` consecutive calls, then column 1, and so
    on; after the last column the cycle restarts at column 0.
    '''
    def __init__(self, M, phase_len=5):
        '''
        M: matrix of size n*m; each column is passed to `np.random.choice`
           as the probability vector over the n data indices
        phase_len: number of consecutive samples drawn from one column
        '''
        self.n, self.m = M.shape
        self.current = -1
        self.M = M
        self.phase_len = phase_len
        
    def teach(self):
        '''
        Advance the cycle by one step and return a sampled data index.
        '''
        self.current = (self.current + 1) % (self.phase_len * self.m)
        column = self.current // self.phase_len
        print(column)
        # Sample from this teacher's own matrix, not a module-level `M`.
        return np.random.choice(self.n, p=self.M[:, column])
    

class TeacherNaive:
    '''
    Teacher that draws every datum independently from a fixed vector `eta`.
    '''
    def __init__(self, eta):
        '''
        eta: probability vector; its length is stored as `self.m`
        '''
        self.m = eta.shape[0]
        self.eta = eta

    def teach(self):
        '''
        Return one index in range(self.m) sampled with probabilities `eta`.
        '''
        (sample,) = np.random.choice(self.m, 1, replace=False, p=self.eta)
        return sample
    
# TODO: share per-process progress via multiprocessing.shared_memory.ShareableList([0.] * process_cnt)
    
def single(pack):
    '''
    Run the episodes assigned to one worker and return their logs.

    pack = (pid, seed, M, eta, theta, max_step, episode_size, exact_eta, method, param, update_M)
    pid: index of this process
    seed: seed passed to np.random.seed for this process
    M: matrix of size n*m
    eta: vector of size n
    theta: vector of size m
    max_step: length of one episode
    episode_size: amount of episodes in Monte Carlo
    exact_eta: True or False, whether learner uses exact data
    method: "ot", "uot" or "bi"
    param: tuple, (epsilon,) for ot and (epsilon, tau1, tau2) for uot
    update_M: True or False, whether learner updates M in an episode

    Returns a list with one (data, posterior) pair per episode, taken from
    the learner's log: the observed data indices and the belief after each
    step, both preceded by their initial entry.
    '''
    pid, seed, M, eta, theta, max_step, episode_size, exact_eta, method, param, update_M = pack
    np.random.seed(seed)

    teacher = TeacherNaive(eta.copy())

    logs = []
    for k in range(episode_size):
        # A fresh learner per episode; it gets the true eta only when
        # exact_eta is set, otherwise it estimates eta from the data.
        learner = LearnerOT(M.copy(), theta.copy(),
                            eta.copy() if exact_eta else None,
                            log=True,
                            method=method, param=param,
                            update_M=update_M)
        # Teach and learn
        for _ in range(max_step):
            learner.learn(teacher.teach())

        logs.append((learner.data, learner.posterior))
        if k % 20 == 0:
            print("process %d: %.1f%% completed"%(pid, k / episode_size * 100))
            # We could use shared_memory to maintain a list of progresses...

    return logs
    
    
    
def mp_test(M, eta, theta, 
            max_step, episode_size, 
            process_cnt=4,
            exact_eta=False,
            method="ot", param=(1,),
            update_M=False, suffix=""):
    
    '''
    Run the Monte Carlo simulation in parallel and pickle the results.

    M: matrix of size n*m
    eta: vector of size n
    theta: vector of size m
    max_step: length of one episode
    episode_size: amount of episodes in Monte Carlo (per task)
    process_cnt: number of independent tasks, each with its own random seed
    exact_eta: True or False, whether learner uses exact data
    method: "ot" or "uot"
    param: tuple, (epsilon,) for ot and (epsilon, tau1, tau2) for uot
    update_M: True or False, whether learner updates M in an episode
    suffix: appended to the output file name "./data/mp_<suffix>.dat"

    Returns the list of per-task results from `single`; the same list is
    written to the output file together with the task arguments.
    '''
    import os

    # Create the output directory up front so a missing folder cannot discard
    # the results after the whole simulation has run.
    os.makedirs("./data", exist_ok=True)

    seeds = np.random.randint(2**32-1, size=process_cnt, dtype=np.int64)
    pack_list = [(pid, seeds[pid], M, eta, theta, max_step, episode_size, 
                  exact_eta, method, param, update_M) for pid in range(process_cnt)]

    # process_cnt sets the number of tasks; the pool keeps its default number
    # of workers. The context manager releases the workers once map returns.
    with mp.Pool() as pool:
        results = pool.map(single, pack_list)

    with open("./data/mp_" + suffix + ".dat", "wb") as fp:
        pickle.dump({"pack_list": pack_list, "result": results}, fp)
        
    return results
    
    
    
if __name__ == "__main__":
    episode_size = 100000
    max_step = 50
    n, m = 3, 3
    epsilon = 1