# modification to recombination.py

import numpy as np
from numpy import linalg
import scipy as sp
from scipy import linalg
from scipy.linalg import qr

import copy, timeit
#
import numba
from numba import jit
# NOTE(review): module-level side effect — this message is printed every time
# the module is imported. Left unchanged in case scripts rely on it; consider
# using logging instead.
print("Import recombination2")


####################################################################
# Tchernychova_Lyons* functions are the algorithms presented in
# Tchernychova, Lyons - Caratheodory Cubature Measures, PhD Thesis,
#                       University of Oxford, 2016
####################################################################

def Tchernychova_Lyons(X, mu=0, DEBUG=False):
    """Reduce a weighted point cloud to at most n+1 points with the same barycenter.

    Recursive recombination: the support is split into ``2*(n+1)`` interleaved
    sets, every set is collapsed to its weighted barycenter, the barycenters
    are reduced with ``Tchernychova_Lyons_CAR_jit``, and only the points of the
    surviving sets are kept (with rescaled weights). This repeats until at
    most ``2*(n+1)`` points remain, which are then reduced directly.

    The barycenter of ``X`` with respect to ``mu`` does not need to be 0.

    Parameters
    ----------
    X : ndarray, shape (N, n)
        Support points.
    mu : array_like of shape (N,), optional
        Non-negative weights of the points in ``X``. If ``mu`` does not have
        shape ``(N,)``, is all zeros, or has a negative entry, uniform weights
        ``1/N`` are used instead. The caller's object is never modified.
    DEBUG : bool, optional
        Unused; kept for interface compatibility.

    Returns
    -------
    w_star : ndarray
        Positive weights of the retained points.
    idx_star : ndarray of int
        Indices into ``X`` of the retained points.
    x_star : ndarray
        ``X[idx_star]``.
    toc : float
        Elapsed wall-clock time in seconds (``timeit.default_timer``).
    ERR : float
        Error value returned by the last reduction call (0 if none ran).
    nan, nan
        Always ``np.nan`` (placeholders).
    """
    N, n = X.shape
    tic = timeit.default_timer()

    number_of_sets = 2 * (n + 1)
    ERR = 0

    # Private float copy: the weights are overwritten in place below, which
    # must neither modify the caller's array nor be truncated by an integer
    # dtype. Lists and scalars are accepted as well.
    mu = np.array(mu, dtype=float)
    if mu.shape != (N,) or np.all(mu == 0) or np.any(mu < 0):
        mu = np.ones(N) / N

    # Indices (into X) of the points that still carry weight.
    idx_story = np.flatnonzero(mu != 0)

    while True:
        remaining_points = len(idx_story)

        if remaining_points <= n + 1:
            # Already at most n+1 points: nothing left to reduce.
            idx_star = np.flatnonzero(mu > 0)
            toc = timeit.default_timer() - tic
            return mu[idx_star], idx_star, X[idx_star], toc, ERR, np.nan, np.nan

        if remaining_points <= number_of_sets:
            # Few enough points for a single direct Caratheodory reduction.
            w_star, idx_star, _, _, ERR, _, _ = Tchernychova_Lyons_CAR_jit(
                X[idx_story], np.copy(mu[idx_story]))
            idx_story = idx_story[idx_star]
            mu[:] = 0.
            mu[idx_story] = w_star
            toc = timeit.default_timer() - tic
            return mu[mu > 0], idx_story, X[idx_story], toc, ERR, np.nan, np.nan

        # Column j of idx is the j-th set: points j, j + number_of_sets, ...
        # of idx_story. number_of_el >= 1 because remaining_points exceeds
        # number_of_sets here. Leftover points are attached to the last set.
        number_of_el = remaining_points // number_of_sets
        idx = idx_story[:number_of_el * number_of_sets].reshape(number_of_el, -1)
        idx_last_part = idx_story[number_of_el * number_of_sets:]

        # Total mass and weighted barycenter of every set.
        tot_weights = mu[idx].sum(axis=0)
        tot_weights[-1] += mu[idx_last_part].sum()
        X_tmp = (X[idx] * mu[idx, np.newaxis]).sum(axis=0)
        X_tmp[-1] += (X[idx_last_part] * mu[idx_last_part, np.newaxis]).sum(axis=0)
        X_tmp = X_tmp / tot_weights[:, np.newaxis]

        # Reduce the barycenters: idx_star are the surviving sets and w_star
        # their new total masses.
        w_star, idx_star, _, _, ERR, _, _ = Tchernychova_Lyons_CAR_jit(
            X_tmp, np.copy(tot_weights))

        # Points of dropped sets lose all their weight ...
        dropped = np.ones(number_of_sets, dtype=bool)
        dropped[idx_star] = False
        mu[idx[:, dropped].reshape(-1)] = 0.

        # ... points of surviving sets are all scaled by the same factor, so
        # each set keeps its barycenter and now has total mass w_star.
        idx_tomaintain = idx[:, idx_star].reshape(-1)
        mu[idx_tomaintain] = (
            mu[idx[:, idx_star]] * w_star / tot_weights[idx_star]).reshape(-1)

        # The leftover points belong to the last set; treat them the same way.
        pos_last = np.flatnonzero(idx_star == number_of_sets - 1)
        if pos_last.size > 0:
            mu[idx_last_part] = mu[idx_last_part] * w_star[pos_last] / tot_weights[-1]
            idx_tomaintain = np.append(idx_tomaintain, idx_last_part)
        else:
            mu[idx_last_part] = 0.

        idx_story = idx_tomaintain



 
def Tchernychova_Lyons_CAR_jit(X, mu, method="qr"):
    """Caratheodory reduction of a weighted point cloud to at most n+1 points.

    Finds non-negative weights supported on at most ``n+1`` of the ``N``
    points of ``X`` (shape ``(N, n)``) that keep the total mass ``sum(mu)``
    and the first moment ``mu @ X`` of the input.

    Every step takes a vector ``phi`` of the null space of ``[1 | X].T``
    (moving ``mu`` along ``phi`` changes neither mass nor first moment),
    moves ``mu`` along ``-phi`` as far as non-negativity allows, which zeroes
    one weight, and then updates the remaining null-space vectors so they
    vanish at the cancelled point.

    Parameters
    ----------
    X : ndarray, shape (N, n)
        Support points.
    mu : array_like, shape (N,)
        Non-negative weights. Not modified (a float copy is used).
    method : {"qr", "svd"}, optional
        How the null-space basis is computed: ``"qr"`` (default) uses a
        column-pivoted QR of ``[1 | X]``, ``"svd"`` an SVD of its transpose.

    Returns
    -------
    tuple
        ``(w_star, idx_star, nan, nan, 0., nan, nan)`` where ``w_star`` are
        the positive remaining weights and ``idx_star`` their indices into
        ``X``; the other entries are constant placeholders.

    Raises
    ------
    ValueError
        If ``method`` is neither ``"qr"`` nor ``"svd"``.
    """
    X = np.hstack((np.ones((X.shape[0], 1)), X))
    N, n = X.shape  # n is now the original dimension + 1
    mu = np.array(mu, dtype=float)

    # Phi: N x (N - n) basis of the null space of X.T, i.e. every column phi
    # satisfies sum(phi) == 0 and phi @ X_original == 0.
    if method == "qr":
        # X[:, P] = Q @ R with R zero below row n, so the trailing N - n
        # *columns* of Q are orthogonal to the column space of X.
        Q, _, _ = qr(X, mode='full', pivoting=True)
        Phi = Q[:, n:]
    elif method == "svd":
        # The trailing rows of Vh are right singular vectors of X.T with zero
        # singular value.
        _, _, Vh = np.linalg.svd(X.T)
        Phi = Vh[n:, :].T
    else:
        raise ValueError(f"unknown method {method!r}; expected 'qr' or 'svd'")

    for _ in range(N - n):
        phi = Phi[:, 0]

        # Largest step along -phi that keeps mu >= 0 is the smallest ratio
        # mu_i / phi_i over entries with phi_i > 0 (restricting to those
        # entries also avoids dividing by zero).
        candidates = np.flatnonzero(phi > 0)
        ratios = mu[candidates] / phi[candidates]
        j = np.argmin(ratios)
        k = candidates[j]

        mu -= ratios[j] * phi
        mu[k] = 0.

        # Drop phi and make every remaining basis vector vanish at k, so later
        # steps cannot give weight back to the cancelled point.
        Phi = Phi[:, 1:]
        Phi = Phi - np.outer(phi, Phi[k]) / phi[k]
        Phi[k, :] = 0.

    idx_star = np.flatnonzero(mu > 0)
    return mu[idx_star], idx_star, np.nan, np.nan, 0., np.nan, np.nan


