from scipy.optimize import linprog
import numpy as np
import ot
from scipy.optimize.linesearch import scalar_search_armijo
from ot.lp import emd
from ot.bregman import sinkhorn
from scipy import stats
import cvxpy as cp



   
##################################
####### PROXIMAL ALGORITHM #######
##################################

def solve_weak_proximal(x, y, Pk, a, b, gamma, method_proj, batch_size_proj, nb_iter_proj, nb_iter_prox = 100, objectif = False):
    """Projected-gradient iterations on ``fct_f`` with a backtracking step size.

    At every outer iteration a gradient step of size ``gamma`` is taken from
    the current base point, the result is mapped back with the selected
    projection routine, and ``gamma`` is shrunk until
    ``fct_f(P) <= fct_fhat(P, base, gamma)`` holds (at most 10 trials; the
    last trial is accepted even if the test never passes).  ``gamma`` is
    never reset, so a reduced step carries over to later iterations.

    Parameters
    ----------
    x, y : forwarded unchanged to ``fct_f``, ``fct_fhat`` and ``grad_f``.
    Pk : starting plan (presumably of shape ``(len(a), len(b))`` -- confirm
        against the caller).
    a, b : marginals handed to the projection routine.
    gamma : initial step size.
    method_proj : one of ``'proj_eucl'``, ``'proj_eucl_acc'``,
        ``'proj_eucl_sto'``, ``'proj_eucl_sto_acc'``.  The ``_sto`` variants
        project with ``proj_eucl_sto``, the others with ``proj_eucl``; the
        ``_acc`` variants extrapolate with weight ``k/(k+3)`` before the step.
    batch_size_proj : batch size, used by the ``_sto`` variants only.
    nb_iter_proj : iteration budget of the projection routine.
    nb_iter_prox : number of outer iterations.
    objectif : when true, also return the objective recorded per iteration.

    Returns
    -------
    ``Pk`` or ``(Pk, obj)`` where ``obj[k]`` is ``fct_f`` after iteration ``k``.

    Raises
    ------
    ValueError
        If ``method_proj`` is not one of the supported names.
    """
    nb_iter_linesearch = 10

    deterministic_methods = ('proj_eucl', 'proj_eucl_acc')
    stochastic_methods = ('proj_eucl_sto', 'proj_eucl_sto_acc')
    stochastic = method_proj in stochastic_methods
    if not stochastic and method_proj not in deterministic_methods:
        # Previously an unknown name fell through every branch and either
        # returned the untouched starting plan or failed on the missing `obj`.
        raise ValueError(
            "Unknown method_proj {!r}; expected one of {}".format(
                method_proj, deterministic_methods + stochastic_methods))
    accelerated = method_proj in ('proj_eucl_acc', 'proj_eucl_sto_acc')

    # Shrink factor of the line search: the plain stochastic variant has
    # always used 0.5, every other variant 0.9.
    beta = 0.5 if method_proj == 'proj_eucl_sto' else 0.9

    obj = np.zeros(nb_iter_prox)
    Pk_1 = Pk
    for k in np.arange(nb_iter_prox):
        if accelerated:
            # Extrapolated point built from the last two iterates.
            base = Pk + (k/(k+3))*(Pk-Pk_1)
            Pk_1 = Pk
        else:
            base = Pk

        # The base point is fixed during the line search, so is its gradient.
        grad = grad_f(x, y, base, a)

        for _ in np.arange(nb_iter_linesearch):
            target = base-gamma*grad
            if stochastic:
                P = proj_eucl_sto(a, b, target, batch_size_proj, nb_iter_proj, lr=1, log=False)
            else:
                P = proj_eucl(a, b, target, nb_iter_proj)
            if fct_f(x, y, P, a) <= fct_fhat(x, y, P, base, a, gamma):
                break
            gamma = beta*gamma

        Pk = P
        obj[k] = fct_f(x, y, Pk, a)

    if objectif:
        return Pk, obj
    else:
        return Pk


## For the linesearch

def fct_f(x, y, P, a):
    """Return ``sum_i a[i] * ||x[i] - (P @ y)[i] / a[i]||^2``.

    ``a`` is indexed as a 1-D array whose length matches the rows of ``P``;
    the squared norm is taken over axis 1 of ``x``.
    """
    # Rows of P @ y rescaled by the corresponding entry of a.
    rescaled = np.matmul(P, y) / a[:, None]
    row_sq_norm = ((x - rescaled) ** 2).sum(1)
    return np.sum(a[None, :] * row_sq_norm)

def fct_fhat(x, y, P, Pk, a, gamma_prox):
    """Quadratic model of ``fct_f`` built at ``Pk`` and evaluated at ``P``.

    It is the sum of ``fct_f`` at ``Pk``, the first-order term
    ``<grad_f(Pk), P - Pk>`` and ``||P - Pk||^2 / (2 * gamma_prox)``.
    The line searches compare ``fct_f(P)`` against this value.
    """
    displacement = P - Pk
    value_at_base = fct_f(x, y, Pk, a)
    first_order = (grad_f(x, y, Pk, a) * displacement).sum()
    proximity = (1 / (2 * gamma_prox)) * (displacement ** 2).sum()
    return value_at_base + first_order + proximity



##################################################################################
############################## QUADRATIC PROGRAMMING #############################
##################################################################################


def solve_weak_QP(x, y, a, b):
    """Solve the weak-transport problem as a quadratic program with cvxpy.

    Minimises ``sum_squares(sqrt(a) * x - (pi @ y) / sqrt(a))`` (i.e.
    ``sum_i a[i] * ||x[i] - (pi @ y)[i] / a[i]||^2``) over ``pi >= 0`` with
    row sums ``a`` and column sums ``b``.

    Parameters
    ----------
    x, y : point arrays; ``a[:, None]`` is broadcast against ``x`` and
        ``pi @ y``, so both are expected to be 2-D here.
    a, b : 1-D marginals; their lengths give the shape of the plan.

    Returns
    -------
    The solver's plan with non-positive entries set to zero and the whole
    array rescaled to sum to one.

    Raises
    ------
    RuntimeError
        If the solver leaves the variable without a value (for example when
        the problem is reported infeasible).
    """
    ns = len(a)
    nt = len(b)

    # Construct the problem.
    pi = cp.Variable((ns,nt))
    sqrt_a = np.sqrt(a[:,None])
    cost = cp.Minimize(cp.sum_squares((sqrt_a*x-(pi@y)/sqrt_a)))
    constraints = [0 <= pi , pi@np.ones((nt,)) == a, (pi.T)@np.ones((ns,)) == b]
    prob = cp.Problem(cost, constraints)
    prob.solve()

    plan_QP = pi.value
    if plan_QP is None:
        # Without this check the clean-up below fails with an opaque TypeError.
        raise RuntimeError(
            "solve_weak_QP: the solver returned no plan (status: {})".format(prob.status))

    # Remove tiny negative entries left by the solver, then renormalise.
    plan_QP[plan_QP<=0] = 0
    plan_QP /= np.sum(plan_QP)

    return plan_QP


###################################################################################
####### PROJECTION ONTO THE SPACE OF TRANSPORT PLAN VIA DYKSTRA'S ALGORITHM #######
###################################################################################

# A. Dessein, N. Papadakis, and J.-L. Rouas. Regularized optimal transport and the ROT mover’s distance.
# The Journal of Machine Learning Research, 19(1):590–642, 2018.

## Solve min_gamma <-P,gamma> + (reg/2) |gamma|^2, i.e. min_gamma |gamma - P/reg|^2 (same minimiser, over transport plans gamma)

def proj_eucl(a, b, P, nb_iter_proj):
    """Iteratively map ``P`` towards a non-negative plan with marginals ``a``, ``b``.

    Each sweep shifts every row by a constant so the truncated iterate would
    match the row sums ``a``, truncates at zero, then does the same for the
    columns with ``b``.  The loop stops after ``nb_iter_proj`` sweeps, or
    earlier once the change between two consecutive truncated iterates
    (measured every 10th sweep) drops to ``1e-9`` or below.

    Parameters
    ----------
    a, b : 1-D row and column marginals.
    P : 2-D array of shape ``(len(a), len(b))``; it is not modified.  Integer
        input is accepted and converted to floating point.
    nb_iter_proj : maximum number of sweeps.

    Returns
    -------
    The last truncated iterate (all zeros when ``nb_iter_proj`` is 0).
    """
    dim_a = len(a)
    dim_b = len(b)

    # Work on a floating-point copy (reg = 1, so P is not rescaled).  The
    # former in-place divide into a buffer of P's dtype failed for integer P.
    P = np.asarray(P)
    work_dtype = P.dtype if np.issubdtype(P.dtype, np.inexact) else np.float64
    xi = P.astype(work_dtype)
    psip = np.zeros((dim_a, dim_b))

    cpt = 0
    err = 1

    while (err > 1e-9 and cpt < nb_iter_proj):
        psipprev = psip

        # Row offsetting: one constant per row, broadcast over the columns.
        xi = xi + ((a-psip.sum(1))/dim_b)[:,None]
        psip = np.maximum(xi, 0)   # Entry truncation
        # Column offsetting: one constant per column, broadcast over the rows.
        xi = xi + ((b-psip.sum(0))/dim_a)[None,:]
        psip = np.maximum(xi, 0)   # Entry truncation

        # The stopping criterion is only refreshed every 10th sweep.
        if cpt % 10 == 0:
            err = np.linalg.norm(psipprev - psip)
        cpt = cpt + 1

    return psip


    
    
##########################
####### GRADIENT F #######
##########################

def grad_f(x, y, pi,a):
    """Return the matrix with entries ``-2 * <y[j], x[i] - (pi @ y)[i] / a[i]>``.

    The result has one row per row of ``x`` and one column per row of ``y``;
    it is used by the solvers as the gradient of ``fct_f`` in ``pi``.
    """
    rescaled = np.matmul(pi, y) / a[:, None]
    # Shape (rows of x, 1, dim): residual of each source point.
    residual = x[:, None] - rescaled[:, None]
    # Broadcast against y (1, rows of y, dim) and contract the last axis.
    pairwise = (y[None, :] * residual).sum(2)
    return -2 * pairwise



################################################################################
####### PROJECTION ONTO THE SPACE OF TRANSPORT PLAN VIA STOCHASTIC TOOLS #######
################################################################################

# V. Seguy, B.B. Damodaran, R. Flamary, N. Courty, A. Rolet, and M. Blondel. Large-scale optimal transport
# and mapping estimation. In International Conference on Learning Representations (ICLR), 2018.

def proj_eucl_sto(a, b, M, batch_size, nb_iter_proj = 10000, lr=1, log=False):
    r'''
    Build a plan from dual variables obtained by stochastic gradient steps.

    The dual variables ``alpha`` and ``beta`` come from ``sgd_proj_eucl_sto``
    run with ``reg = 0.5``; the returned plan is
    ``max(alpha[i] + beta[j] + M[i, j], 0) / (2 * reg)``.

    Parameters: ``a``, ``b`` marginals; ``M`` the matrix to project;
    ``batch_size``, ``nb_iter_proj`` and ``lr`` are forwarded to the SGD
    routine.  When ``log`` is true a dict holding ``'alpha'`` and ``'beta'``
    is returned alongside the plan.
    '''

    reg = 0.5 # In order to obtain the right projection

    dual_rows, dual_cols = sgd_proj_eucl_sto(a, b, M, reg, batch_size, nb_iter_proj, lr)

    shifted = dual_rows[:, None] + dual_cols[None, :] + M[:, :]
    plan = np.maximum(shifted, 0) / (2 * reg)

    if not log:
        return plan
    return plan, {'alpha': dual_rows, 'beta': dual_cols}
    
    
def batch_grad_proj_eucl_sto(a, b, M, reg, alpha, beta, batch_size, batch_alpha, batch_beta):
    r'''
    Partial gradient of the dual objective restricted to a batch.

    Only the entries listed in ``batch_alpha`` (rows) and ``batch_beta``
    (columns) are filled; every other entry of the two returned vectors is
    zero.  ``batch_size`` is accepted but not used here.

    Returns ``(grad_alpha, grad_beta)`` of lengths ``M.shape[0]`` and
    ``M.shape[1]``.
    '''

    n_rows = np.shape(M)[0]
    n_cols = np.shape(M)[1]

    # Negated, truncated plan on the sampled rows x columns.
    sub_M = M[batch_alpha, :][:, batch_beta]
    potentials = alpha[batch_alpha, None] + beta[None, batch_beta]
    G = -np.maximum(potentials + sub_M, 0) / (2 * reg)

    grad_alpha = np.zeros(n_rows)
    grad_beta = np.zeros(n_cols)
    # Marginal terms are scaled by the fraction of the other side sampled.
    grad_alpha[batch_alpha] = a[batch_alpha] * len(batch_beta) / n_cols + G.sum(1)
    grad_beta[batch_beta] = b[batch_beta] * len(batch_alpha) / n_rows + G.sum(0)

    return grad_alpha, grad_beta


    
def sgd_proj_eucl_sto(a, b, M, reg, batch_size, nb_iter_proj, lr):
    r'''
    Stochastic gradient iterations on the dual variables.

    Starting from zero vectors, each of the ``nb_iter_proj`` iterations
    samples ``batch_size`` row indices and ``batch_size`` column indices
    without replacement, evaluates ``batch_grad_proj_eucl_sto`` on them and
    adds the sampled gradient entries scaled by ``lr / sqrt(iteration + 1)``.

    Returns the final ``(alpha, beta)`` pair.
    '''

    n_source = np.shape(M)[0]
    n_target = np.shape(M)[1]
    cur_alpha = np.zeros(n_source)
    cur_beta = np.zeros(n_target)

    for iteration in np.arange(nb_iter_proj):
        # Step size decays like 1 / sqrt(t).
        step = lr / np.sqrt(iteration + 1)
        rows = np.random.choice(n_source, batch_size, replace=False)
        cols = np.random.choice(n_target, batch_size, replace=False)
        g_alpha, g_beta = batch_grad_proj_eucl_sto(
            a, b, M, reg, cur_alpha, cur_beta, batch_size, rows, cols)
        cur_alpha[rows] += step * g_alpha[rows]
        cur_beta[cols] += step * g_beta[cols]

    return cur_alpha, cur_beta




######################################################
################# OBJECTIVE FUNCTION #################
######################################################

def objective(mi, final_S, wbar):
    """Evaluate the scalar objective (described in the file as sum V(mu, nu_i) / n).

    ``final_S`` is indexed with three axes and ``wbar`` weights its axis 1;
    ``mi`` is indexed with two axes.  Four terms are combined:
    the mean weighted squared norm of ``final_S``, minus the squared norm of
    its mean weighted sum, minus the mean squared norm of ``mi``, minus the
    squared norm of the mean of ``mi``.
    """
    # NOTE(review): the last term is subtracted, exactly as in the original
    # expression -- confirm this sign against the reference formula.
    sq_norm_S = (final_S ** 2).sum(2)
    mean_weighted_sq = (sq_norm_S * wbar[None, :]).sum(1).mean()

    weighted_sum_S = (final_S * wbar[None, :, None]).sum(1)
    sq_norm_of_mean_S = (weighted_sum_S.mean(0) ** 2).sum()

    mean_sq_mi = (mi ** 2).sum(1).mean()
    sq_norm_of_mean_mi = (mi.mean(0) ** 2).sum()

    res = mean_weighted_sq - sq_norm_of_mean_S - mean_sq_mi - sq_norm_of_mean_mi
    return res

#########################################################
################# GENERATE TOY EXAMPLES #################
#########################################################

def generate_ring(n, mean_x, mean_y, rad_ring = 2):
    """Sample ``n`` points on an axis-aligned ellipse centred at the origin.

    The angles are uniform on ``[0, 2*pi)``.  The two semi-axes are drawn
    once per call, uniformly from ``[mean_x - rad_ring, mean_x + rad_ring]``
    and ``[mean_y - rad_ring, mean_y + rad_ring]``.

    Returns an array of shape ``(n, 2)``.
    """
    angles = np.random.uniform(0, 2 * np.pi, n)
    # One random semi-axis per coordinate, shared by all n points.
    semi_axis_x = np.random.uniform(mean_x - rad_ring, mean_x + rad_ring)
    semi_axis_y = np.random.uniform(mean_y - rad_ring, mean_y + rad_ring)
    coords = np.stack((semi_axis_x * np.cos(angles), semi_axis_y * np.sin(angles)))
    return coords.T




#############################################################
################# ONE-DIMENSIONAL FUNCTIONS #################
#############################################################


def fct_V_1D(x, y, P, a):
    """Return ``sum_i a[i] * (x[i] - sum_j P[i, j] * y[j] / a[i])**2``.

    One-dimensional counterpart of ``fct_f``: ``y`` is broadcast along the
    rows of ``P`` and ``a`` rescales each row.
    """
    rescaled = np.sum(P * y / a[:, None], 1)
    residual = x - rescaled
    return np.sum(a * residual ** 2)

def grad_V_1D(x, y, P, a):
    """Return the matrix with entries ``-2 * (x[i] - sum_k P[i, k] * y[k] / a[i]) * y[j]``.

    For 1-D ``x`` and ``y`` the result has shape ``(len(x), len(y))``; the
    1-D solver uses it as the gradient of ``fct_V_1D`` in ``P``.
    """
    residual = x - np.sum((P * y) / a[:, None], 1)
    # Outer product of the per-row residual with y.
    df = -2 * (residual[:, None] * y[None, :])
    return df

def fct_Vhat_1D(x, y, P, Pk, a, gamma_prox):
    """Quadratic model of ``fct_V_1D`` built at ``Pk`` and evaluated at ``P``.

    Adds to ``fct_V_1D`` at ``Pk`` the first-order term
    ``<grad_V_1D(Pk), P - Pk>`` and ``||P - Pk||^2 / (2 * gamma_prox)``.
    """
    displacement = P - Pk
    value_at_base = fct_V_1D(x, y, Pk, a)
    first_order = (grad_V_1D(x, y, Pk, a) * displacement).sum()
    proximity = (1 / (2 * gamma_prox)) * (displacement ** 2).sum()
    return value_at_base + first_order + proximity

def solve_OWT_1D(x, y, Pk, a, b, gamma, method_proj, nb_iter_proj, nb_iter_prox = 100):
    """One-dimensional projected-gradient solver with a backtracking step size.

    Every outer iteration takes a step of size ``gamma`` along
    ``-grad_V_1D`` from the base point, maps it back with ``proj_eucl`` and
    halves ``gamma`` until ``fct_V_1D(P) <= fct_Vhat_1D(P, base, gamma)``
    (at most 10 trials; the last trial is accepted regardless).  ``gamma``
    is never reset between outer iterations.

    Parameters
    ----------
    x, y : forwarded unchanged to ``fct_V_1D``, ``fct_Vhat_1D``, ``grad_V_1D``.
    Pk : starting plan.
    a, b : marginals handed to ``proj_eucl``.
    gamma : initial step size.
    method_proj : ``'proj_eucl'`` (plain iterations) or ``'proj_eucl_acc'``
        (extrapolation with weight ``k/(k+3)`` before each step).
    nb_iter_proj : iteration budget of ``proj_eucl``.
    nb_iter_prox : number of outer iterations.

    Returns
    -------
    ``(Pk, obj)`` where ``obj[k]`` is ``fct_V_1D`` after iteration ``k``.

    Raises
    ------
    ValueError
        If ``method_proj`` is not one of the two supported names.
    """
    beta = 0.5 # parameter for the line search
    nb_iter_linesearch = 10

    supported = ('proj_eucl', 'proj_eucl_acc')
    if method_proj not in supported:
        # Previously this ended in an UnboundLocalError on `obj` at return.
        raise ValueError(
            "Unknown method_proj {!r}; expected one of {}".format(method_proj, supported))
    accelerated = method_proj == 'proj_eucl_acc'

    obj = np.zeros(nb_iter_prox)
    Pk_1 = Pk
    for k in np.arange(nb_iter_prox):
        if accelerated:
            # Extrapolated point built from the last two iterates.
            base = Pk + (k/(k+3))*(Pk-Pk_1)
            Pk_1 = Pk
        else:
            base = Pk

        # The base point is fixed during the line search, so is its gradient.
        grad = grad_V_1D(x, y, base, a)

        for _ in np.arange(nb_iter_linesearch):
            P = proj_eucl(a, b, base-gamma*grad, nb_iter_proj)
            if fct_V_1D(x, y, P, a) <= fct_Vhat_1D(x, y, P, base, a, gamma):
                break
            gamma = beta*gamma

        Pk = P
        obj[k] = fct_V_1D(x, y, Pk, a)

    return Pk, obj


def solve_weak_QP_1D(x, y, a, b):
    """One-dimensional weak-transport problem solved as a quadratic program.

    Minimises ``sum_squares(sqrt(a) * x - (pi @ y) / sqrt(a))`` (i.e.
    ``sum_i a[i] * (x[i] - (pi @ y)[i] / a[i])**2``) over ``pi >= 0`` with
    row sums ``a`` and column sums ``b``.

    Parameters
    ----------
    x, y : 1-D point arrays (``x`` is combined elementwise with ``a``).
    a, b : 1-D marginals; their lengths give the shape of the plan.

    Returns
    -------
    The solver's plan with non-positive entries set to zero and the whole
    array rescaled to sum to one.

    Raises
    ------
    RuntimeError
        If the solver leaves the variable without a value (for example when
        the problem is reported infeasible).
    """
    ns = len(a)
    nt = len(b)

    # Construct the problem.
    pi = cp.Variable((ns,nt))
    sqrt_a = np.sqrt(a)
    cost = cp.Minimize(cp.sum_squares((sqrt_a*x-(pi@y)/sqrt_a)))
    constraints = [0 <= pi , pi@np.ones((nt,)) == a, (pi.T)@np.ones((ns,)) == b]
    prob = cp.Problem(cost, constraints)
    prob.solve()

    plan_QP = pi.value
    if plan_QP is None:
        # Without this check the clean-up below fails with an opaque TypeError.
        raise RuntimeError(
            "solve_weak_QP_1D: the solver returned no plan (status: {})".format(prob.status))

    # Remove tiny negative entries left by the solver, then renormalise.
    plan_QP[plan_QP<=0] = 0
    plan_QP /= np.sum(plan_QP)

    return plan_QP


