"""
    Taken from https://github.com/yikun-baio/sliced_opt
"""

import numpy as np
from typing import Tuple #,List
import numba as nb 
from numba.typed import List

# Exponent of the ground cost used by the routines below.
# A `global` statement at module level is a no-op, so a plain assignment suffices.
# NOTE: the jitted functions read `p` as a global, which Numba treats as a
# compile-time constant; rebinding it after they are compiled does not affect them.
p = 2

@nb.njit(cache=True,fastmath=True)
def solve_opt(X,Y,lam): #,verbose=False):
    """
    Primal-dual sweep for a 1D assignment problem between X and Y with threshold lam.

    Rows (entries of X) are processed in index order. For row K the column j
    minimising (X[K]-Y[j])**p-psi[j] is looked up; once a column has been
    assigned, only columns from the last assigned one onwards are searched.
    Three cases follow:
      case 1: the minimal value is >= lam: row K stays unassigned, phi[K]=lam.
      case 2: column j is free: row K is assigned to it.
      case 3: column j is already taken: the duals of the contiguous chain of
              rows iMin..K are raised by v until the first of these events:
        3.1  an entry of phi reaches lam: row lamInd is unassigned and the
             rows after it shift one column down (row K takes column j);
        3.2  the constraint with column j+1 becomes active: row K takes j+1;
        3.3a the constraint (iMin, jMin-1) becomes active and column jMin-1
             is free: the whole chain shifts one column down;
        3.3b same constraint but column jMin-1 is taken: the chain is
             extended by one row/column and the search continues.

    Parameters
    ----------
    X : 1D array of length n
    Y : 1D array of length m
        Presumably both are sorted in ascending order -- TODO confirm with caller.
    lam : scalar threshold. psi is created with np.full(..., fill_value=lam)
        and therefore takes its dtype from lam, so pass a float.

    Returns
    -------
    objective : np.sum(phi)+np.sum(psi)
    phi : length-n array of row duals
    psi : length-m array of column duals
    piRow : length-n int64 array, piRow[i] is the column assigned to row i or -1
    piCol : length-m int64 array, piCol[j] is the row assigned to column j or -1

    NOTE(review): fastmath=True lets LLVM assume no inf/nan values while this
    routine relies on np.inf sentinels -- confirm this is intended.
    """
    n,m=X.shape[0],Y.shape[0]
    phi=np.full(shape=n,fill_value=-np.inf)
    psi=np.full(shape=m,fill_value=lam)
    # to which cols/rows are rows/cols currently assigned? -1: unassigned
    piRow=np.full(n,-1,dtype=np.int64)
    piCol=np.full(m,-1,dtype=np.int64)
    # K is the index of the row that is currently processed
    K=0
    # Dijkstra distance array, used and initialized on demand in the case 3 subroutine
    dist=np.full(n,np.inf)

    # last column that received an assignment in case 2 / 3.2; -1 while none exists
    jLast=-1
    while K<n:
        x=X[K]
        if jLast==-1:
            val,j=closest_y_opt(x,Y,psi)
        else:
            # restrict the search to columns jLast..m-1 and shift the index back
            val,j=closest_y_opt(x,Y[jLast:],psi[jLast:])
            j+=jLast
        # val equals (X[K]-Y[j])**p-psi[j]
        if val>=lam:
            # case 1: leaving row K unassigned is at least as cheap
            phi[K]=lam
            K+=1
        elif piCol[j]==-1:
            # case 2: best column is free
            piCol[j]=K
            piRow[K]=j
            phi[K]=val
            K+=1
            jLast=j
        else:
            # case 3: best column is already taken
            # (expected invariant, not checked: piCol[j]==K-1)
            phi[K]=val
            # Dijkstra distance vector and currently explored radius
            dist[K]=0.
            dist[K-1]=0.
            v=0.

            # iMin and jMin indicate lower end of range of contiguous rows and cols
            # that are currently examined in subroutine;
            # upper end is always K and j
            iMin=K-1
            jMin=j
            # threshold until an entry of phi hits lam
            if phi[K]>phi[K-1]:
                lamDiff=lam-phi[K]
                lamInd=K
            else:
                lamDiff=lam-phi[K-1]
                lamInd=K-1
            resolved=False
            while not resolved:
                # threshold until constr iMin,jMin-1 becomes active
                if jMin>0:
                    lowEndDiff=(X[iMin]-Y[jMin-1])**p-phi[iMin]-psi[jMin-1]
                    # catch: empty rows in between that could numerically be skipped
                    if iMin>0:
                        if piRow[iMin-1]==-1:
                            lowEndDiff=np.inf
                else:
                    lowEndDiff=np.inf
                # threshold for upper end
                if j<m-1:
                    hiEndDiff=(X[K]-Y[j+1])**p-phi[K]-psi[j+1]-v
                else:
                    hiEndDiff=np.inf
                if hiEndDiff<=min((lowEndDiff,lamDiff)):
                    # case 3.2: column j+1 becomes available for row K first
                    v+=hiEndDiff
                    for i in range(iMin,K):
                        phi[i]+=v-dist[i]
                        psi[piRow[i]]-=v-dist[i]
                    
                    phi[K]+=v
                    piRow[K]=j+1
                    piCol[j+1]=K
                    jLast=j+1
                    resolved=True
                elif lowEndDiff<=min((hiEndDiff,lamDiff)):
                    if piCol[jMin-1]==-1:
                        # case 3.3a: free column below the chain
                        v+=lowEndDiff

                        for i in range(iMin,K):
                            phi[i]+=v-dist[i]
                            psi[piRow[i]]-=v-dist[i]
                        phi[K]+=v
                        # "flip" assignment along whole chain
                        jPrime=jMin
                        piCol[jMin-1]=iMin
                        piRow[iMin]-=1
    
                        
                        for i in range(iMin+1,K):
                            piCol[jPrime]+=1
                            piRow[i]-=1
                            jPrime+=1
                        piRow[K]=j #jPrime
                        piCol[j]+=1 #jPrime
                        resolved=True
                    else:
                        # case 3.3b: column below is taken, extend the chain by one
                        # (expected invariant, not checked: piCol[jMin-1]==iMin-1)
                        v+=lowEndDiff
                        dist[iMin-1]=v
                        # adjust distance to threshold
                        lamDiff-=lowEndDiff
                        iMin-=1
                        jMin-=1
                        if lam-phi[iMin]<lamDiff:
                            lamDiff=lam-phi[iMin]
                            lamInd=iMin

                else:
                    # case 3.1: an entry of phi (row lamInd) reaches lam first
                    v+=lamDiff
                    for i in range(iMin,K):
                        phi[i]+=v-dist[i]
                        psi[piRow[i]]-=v-dist[i]
                    phi[K]+=v
                    # "flip" assignment from lambda touching row onwards
                    if lamInd<K:
                        jPrime=piRow[lamInd]
                        piRow[lamInd]=-1
                        
                        for i in range(lamInd+1,K):
                            piCol[jPrime]+=1
                            piRow[i]-=1
                            jPrime+=1
                        piRow[K]=j #jPrime
                        piCol[j]+=1 #jPrime
                    resolved=True
            K+=1
    objective=np.sum(phi)+np.sum(psi)
    return objective,phi,psi,piRow,piCol


@nb.njit(fastmath=True,cache=True)
def closest_y_opt(x,Y,psi):
    """
    Scan Y for the index j with the smallest value of (x-Y[j])**p-psi[j].

    Parameters
    ----------
    x : scalar
    Y : 1D array
    psi : 1D array indexed in parallel with Y

    Returns
    -------
    (value, index) : the smallest value found and the first index attaining it.
        For an empty Y the result is (np.inf, 0).
    """
    best_val=np.inf
    best_j=0
    for col in range(Y.shape[0]):
        reduced=(x-Y[col])**p-psi[col]
        # strict comparison keeps the first minimiser on ties
        if reduced<best_val:
            best_val,best_j=reduced,col
    return best_val,best_j


@nb.njit(['int64[:](int64[:],int64[:],int64[:])'],cache=True)
def recover_indice(indice_X,indice_Y,L):
    '''
    Translate a plan expressed on sorted points back to the original ordering.

    input:
        indice_X: 1D int64 array of length n; in the example below it is the
                  argsort of X (position k of sorted X -> original index).
        indice_Y: 1D int64 array of length m, the same for Y.
        L: 1D int64 array of length n, whose entry could be 0,1,2,... and -1.
        L is the transportation plan for sorted X,Y:
        L[k]=j denotes x_k->y_j and L[k]=-1 denotes that x_k is destroyed.
    output:
        mapping_final: 1D int64 array, the plan for the original unsorted X,Y:
        mapping_final[indice_X[k]] = indice_Y[L[k]], or -1 where L[k]=-1.

        Eg. X=[2,1,3], indice_X=[1,0,2]
            Y=[3,1,2], indice_Y=[1,2,0]
            L=[0,1,2] which means the mapping 1->1, 2->2, 3->3
        return:
            [2,1,0], which also means the mapping 2->2, 1->1, 3->3.

    '''
    # indice_Y[L] is evaluated for every entry; where L is -1 it reads the last
    # element of indice_Y, and np.where then replaces that value by -1.
    indice_Y_mapped=np.where(L>=0,indice_Y[L],-1)
    # argsort of indice_X gives, for each original x index, its sorted position
    return indice_Y_mapped.take(indice_X.argsort())



@nb.njit(cache=True)
def pot(X,Y): 
    '''
    Assign every entry of X to a distinct entry of Y with increasing indices,
    accumulating the cost abs(x-y)**p, by a sequential sweep over X.

    Each x_k is first matched to its nearest y (closest_y). If that index is
    larger than the one used by x_{k-1} it is accepted. Otherwise there is a
    conflict and two candidate plans are compared:
      1. x_k takes the next column j_last+1 (if it exists);
      2. the tail of the plan starting at row i_act is shifted one column
         down into the free column j_act reported by unassign_y, and x_k
         takes j_last.
    The cheaper plan is kept (ties go to plan 2).

    Parameters
    ----------
    X : 1D float array of length n
    Y : 1D float array of length m
        Presumably both sorted ascending with n<=m -- TODO confirm with caller.

    Returns
    -------
    cost : float, total cost of the plan (0.0 for empty X)
    L : int64 array of length n, L[k] is the index of the y assigned to x_k
    '''
    n,m=X.shape[0],Y.shape[0]
    L=np.zeros(n,dtype=np.int64) # the plan
    cost=0.0 # running cost of the plan stored in L
    if n==0:
        # nothing to assign; X[0] below must not be read on an empty array
        return cost,L

    # first point: nothing to conflict with, take its nearest y
    cost_xk_yjk,jk=closest_y(X[0],Y)
    cost+=cost_xk_yjk
    L[0]=jk
    for k in range(1,n):
        x=X[k]
        cost_xk_yjk,jk=closest_y(x,Y)
        j_last=L[k-1]
    
        if jk>j_last:
            # no conflict: nearest y lies beyond the last assigned column
            cost+=cost_xk_yjk
            L[k]=jk
        else:
            # conflict: nearest y is at or before the last assigned column

            # plan 1: keep the previous plan and put x_k on the next column
            if j_last+1<=m-1:
                cost1=cost+abs(x-Y[j_last+1])**p
            else:
                cost1=np.inf 
            # plan 2: shift rows i_act..k-1 one column down, x_k takes j_last
            i_act,j_act=unassign_y(L[0:k])
            if j_act>=0:                        
                cost2=0.
                # explicit loops: faster than the vectorised sum under numba
                for ind in range(0,i_act):
                    cost2+=abs(X[ind]-Y[L[ind]])**p
                for ind in range(i_act,k):
                    cost2+=abs(X[ind]-Y[L[ind]-1])**p
                cost2+=abs(x-Y[j_last])**p
                
            else:
                # no free column below the plan
                cost2=np.inf
            if cost1<cost2:
                cost=cost1
                L[k]=j_last+1
            elif cost2<=cost1:
                cost=cost2
                for ind in range(i_act,k):
                    L[ind]=L[ind]-1
                L[k]=j_last
                
    return cost,L


@nb.njit(['Tuple((int64,int64))(int64[:])'],cache=True)
def unassign_y(L1):
    '''
    Find the closest unused column below the end of a plan.

    Parameters
    ----------
    L1 : 1D int64 array, whose entries are 0,1,2,......
            transportation plan. L1[i]=j denotes that x_i is assigned to y_j.
            The caller must make sure L1 is not empty and contains no -1;
            the entries are expected to be in increasing order.

    Returns
    -------
    i_act: integer>=0
    j_act: integer>=0 or -1
    j_act=max{j: j not in L1, j<L1[end]}, or -1 if no such j exists.
    i_act=min{i: L1[i]>j_act}.

    Eg. input: L1=[1,3,5]
    return: 2,4
    input: L1=[2,3,4]
    return: 0,1
    input: L1=[0,1,2,3]
    return: 0,-1

    '''
    size=L1.shape[0]
    top=L1[size-1]
    # walk backwards; a contiguous tail satisfies L1[pos]==top-(size-1-pos)
    for pos in range(size-1,-1,-1):
        candidate=top-(size-1-pos)
        if candidate>L1[pos]:
            # gap found: column `candidate` is free, rows from pos+1 sit above it
            return pos+1,candidate
    # the whole plan is contiguous: the only possible free column is just below it
    below=top-size
    return 0,max(below,-1)
    

@nb.njit(fastmath=True,cache=True) 
def closest_y(x,Y):
    '''
    Scan Y for the entry with the smallest cost abs(x-Y[j])**p.

    Parameters
    ----------
    x : float number, xk
    Y : 1D float np array

    Returns
    -------
    min_cost : float number
        the smallest value of abs(x-Y[j])**p (np.inf if Y is empty)
    min_index : integer >=0
        the first index j attaining that value (0 if Y is empty)

    '''
    best_cost=np.inf
    best_index=0
    for idx in range(Y.shape[0]):
        candidate=abs(x-Y[idx])**p
        # strict comparison keeps the first minimiser on ties
        if candidate<best_cost:
            best_cost,best_index=candidate,idx
    return best_cost,best_index

@nb.njit(cache=True)
def cost_function(x,y): 
    ''' 
    Ground cost abs(x-y)**p, with p the module-level exponent.

    case 1:
        input:
            x: float number
            y: float number 
        output:
            abs(x-y)**p: float number 
    case 2: 
        input: 
            x: float np array
            y: float np array
        output:
            float np array whose i-th entry is abs(x_i-y_i)**p
    '''
    return np.abs(x-y)**p