import sys
import os
from itertools import *
import numpy as np
import cmath
import math
import scipy.integrate as integrate


import time
from copy import deepcopy





######## Linear algebra tools ###########


def updateInvK(rKernel,invK,seq,newNode,eval_matrix):
    """Return the inverse Gram matrix after appending ``newNode`` as the last point.

    The old inverse ``invK`` is grown by one row/column through a
    Schur-complement block inversion (``inverseSchur``), avoiding a full
    re-inversion.

    Parameters
    ----------
    rKernel : object exposing ``N`` and ``eval_at_X`` (e.g. a
        ``multi_Fourier_projection_kernel``).
    invK : (L, L) ndarray, inverse Gram matrix of the current points.
    seq : unused; kept for interface compatibility.
    newNode : point whose feature vector is ``rKernel.eval_at_X(newNode)``.
    eval_matrix : (L, N) ndarray; row i is the feature vector of point i.

    Returns
    -------
    (L + 1, L + 1) ndarray.
    """
    n_old = invK.shape[0]
    phi_new = rKernel.eval_at_X(newNode).reshape((rKernel.N, 1))
    # Cross terms <phi(x_i), phi(newNode)> between old points and the new one.
    cross = np.dot(eval_matrix, phi_new).reshape((n_old, 1))
    # New diagonal entry ||phi(newNode)||^2, as a 1x1 block.
    diag = np.array(np.power(np.linalg.norm(phi_new), 2)).reshape((1, 1))
    return inverseSchur(diag, cross.T, cross, invK)

def inverseSchur(A,B,C,invD):
    """Invert the block matrix [[D, C], [B, A]] given the inverse of D.

    With the Schur complement S = A - B D^{-1} C:

        [[D, C], [B, A]]^{-1} =
            [[D^{-1} + D^{-1} C S^{-1} B D^{-1},  -D^{-1} C S^{-1}],
             [-S^{-1} B D^{-1},                    S^{-1}         ]]

    Parameters
    ----------
    A : (a, a) array, bottom-right block.
    B : (a, d) array, bottom-left block.
    C : (d, a) array, top-right block.
    invD : (d, d) array, inverse of the top-left block D.

    Returns
    -------
    (d + a, d + a) ndarray. Its dtype is at least float64, and it is complex
    when any block is complex. Writing complex blocks into a plain float buffer
    would drop the imaginary part.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the Schur complement S is singular.
    """
    d = invD.shape[0]
    a = A.shape[0]
    BinvD = np.dot(B, invD)
    invDC = np.dot(invD, C)
    # Reuse B D^{-1} for the Schur complement instead of recomputing it.
    S = A - np.dot(BinvD, C)
    invS = np.linalg.inv(S)
    top_left = invD + np.dot(np.dot(invDC, invS), BinvD)
    top_right = -np.dot(invDC, invS)
    bottom_left = -np.dot(invS, BinvD)
    # Keep float64 as the floor (previous behaviour) but promote to complex if needed.
    out_dtype = np.result_type(np.float64, top_left, top_right, bottom_left, invS)
    arrOutput = np.zeros((d + a, d + a), dtype=out_dtype)
    arrOutput[:d, :d] = top_left
    arrOutput[:d, d:] = top_right
    arrOutput[d:, :d] = bottom_left
    arrOutput[d:, d:] = invS
    return arrOutput


#######################################################

########### Projection kernel tools ###################




class Fourier_Element():
    """Element of the real Fourier basis on [0, 1], indexed by ``order``.

    order 0      -> 1
    order 2k     -> sqrt(2) * cos(2*pi*k*X)
    order 2k - 1 -> sqrt(2) * sin(2*pi*k*X)      (k >= 1)
    """
    def __init__(self,order):
        self.order = order  # basis index; must be >= 0 to be evaluated
    def evaluate_at_X(self,X):
        """Evaluate the basis function at the scalar point ``X``.

        Returns the int 1 for order 0; otherwise a float whose magnitude is
        at most sqrt(2).

        Raises
        ------
        ValueError
            If ``order`` is negative. Previously this silently returned None.
        """
        if self.order == 0:
            return 1
        if self.order < 0:
            raise ValueError(
                "Fourier_Element order must be non-negative, got %r" % (self.order,))
        if self.order % 2 == 0:
            return np.sqrt(2)*math.cos(2*math.pi*(self.order/2)*X)
        return np.sqrt(2)*math.sin(2*math.pi*((self.order+1)/2)*X)


class Multi_Fourier_Element():
    """Tensor-product basis function on [0, 1]^dim.

    Its value at a point ``X`` is the product over coordinates ``i`` of the
    1-D ``Fourier_Element(multi_order[i])`` evaluated at ``X[i]``.
    """
    def __init__(self,multi_order):
        self.dim = len(multi_order)      # number of coordinates
        self.multi_order = multi_order   # per-coordinate 1-D orders (a list)
        self.elements = self.calculate_elements()
    def calculate_elements(self):
        """Return the 1-D factors, one ``Fourier_Element`` per coordinate.

        An empty ``multi_order`` yields an empty list, so ``evaluate_at_X``
        returns 1.

        Raises
        ------
        ValueError
            If any order is negative. Negative orders are not valid basis
            indices, and must not be silently mapped to another element.
        """
        if any(order < 0 for order in self.multi_order):
            raise ValueError(
                "Multi_Fourier_Element orders must be non-negative, got %r"
                % (self.multi_order,))
        return [Fourier_Element(order) for order in self.multi_order]
    def evaluate_at_X(self,X):
        """Return prod_i phi_{multi_order[i]}(X[i]); ``X`` needs at least ``dim`` entries."""
        output_var = 1
        for i, element in enumerate(self.elements):
            output_var = output_var*element.evaluate_at_X(X[i])
        return output_var



class multi_Fourier_projection_kernel:
    """Projection kernel spanned by the first ``N`` basis functions.

    With phi_n = ``elements_list[n].evaluate_at_X``,

        K_N(X, Y) = sum_{n < N} phi_n(X) * phi_n(Y).

    ``elements_list`` only needs objects exposing ``evaluate_at_X``.
    """
    def __init__(self,elements_list,N):
        self.N = N                          # number of basis functions kept
        self.elements_list = elements_list  # must hold at least N entries
        self.my_kernel = self.Fourier_K_NN_kernel()

    def Fourier_K_NN_function(self):
        """Return the diagonal X -> K_N(X, X) = sum_n phi_n(X)**2."""
        def Fourier_K_NN_function_aux(X):
            return np.sum(np.power(self.eval_at_X(X), 2))
        return Fourier_K_NN_function_aux

    def Fourier_K_NN_kernel(self):
        """Return the two-point kernel (X, Y) -> K_N(X, Y)."""
        def Fourier_K_NN_kernel_aux(X,Y):
            return np.dot(self.eval_at_X(X), self.eval_at_X(Y))
        return Fourier_K_NN_kernel_aux

    def get_kernel_matrix_projection_kernel(self,seq):
        """Return the (M, M) float Gram matrix [K_N(seq[i], seq[j])].

        Each point's feature vector is evaluated once and the matrix is formed
        as Phi @ Phi.T. This avoids re-evaluating two feature vectors for every
        one of the M**2 pairs.
        """
        M = len(seq)
        # reshape keeps a (0, N) shape for an empty seq, so the result is (0, 0).
        features = np.asarray([self.eval_at_X(x) for x in seq], dtype=float).reshape((M, self.N))
        return np.dot(features, features.T)

    def eval_at_X(self,X):
        """Return the feature vector [phi_0(X), ..., phi_{N-1}(X)] as an ndarray."""
        return np.asarray([self.elements_list[n].evaluate_at_X(X) for n in range(self.N)])

    def Fourier_K_NN_function_conditional(self,seq,invK,eval_matrix):
        """Return X -> ||phi(X)||^2 - k(X)^T invK k(X), with k(X) = eval_matrix @ phi(X).

        Parameters
        ----------
        seq : unused; kept for interface compatibility.
        invK : (L, L) array used as the inverse Gram matrix of the
            conditioning points.
        eval_matrix : (L, N) array whose rows are the conditioning points'
            feature vectors.
        """
        def Fourier_K_NN_function_conditional_aux(X):
            eval_X_vector = (self.eval_at_X(X)).reshape((self.N,1))
            kernel_vector = np.dot(eval_matrix,eval_X_vector)
            Fourier_K_NN_function_conditional_evaluation = (np.power(np.linalg.norm(eval_X_vector),2) - np.dot(np.dot(kernel_vector.T,invK),kernel_vector))
            return Fourier_K_NN_function_conditional_evaluation[0][0]
        return Fourier_K_NN_function_conditional_aux




########### DPP sampler ###################



def diagonal_sampler(H,N,M,d,mode,multi_indices,Fourier_elements_list,N_perfect):
    """Draw M points on [0, 1]^d with density proportional to K_N(x, x).

    Let phi_n = ``Fourier_elements_list[n].evaluate_at_X`` and
    K_N(x, x) = sum_{n < N} phi_n(x)**2. Each point is drawn from a
    two-component mixture:

    * with probability N_perfect / N, a uniform point on [0, 1]^d;
    * otherwise, a point with density proportional to the residual
      R(x) = sum_{N_perfect <= n < N} phi_n(x)**2. It is obtained by rejection
      from uniform proposals with envelope R_bound = 2**d * (N - N_perfect + 1).

    This mixture equals K_N(x, x) / N only if K_{N_perfect}(x, x) is the
    constant N_perfect on [0, 1]^d. The caller presumably chooses N_perfect
    that way; it is not checked here. The envelope assumes |phi_n| <= 2**(d/2),
    which holds for products of d ``Fourier_Element`` factors (each bounded by
    sqrt(2)).

    Parameters
    ----------
    H, mode, multi_indices : unused; kept for interface compatibility.
    N : number of basis functions in the target kernel.
    M : number of points to draw.
    d : dimension.
    Fourier_elements_list : basis objects exposing ``evaluate_at_X``.
    N_perfect : number of leading basis functions covered by the uniform part.

    Returns
    -------
    A single ndarray of shape (d,) when M == 1, otherwise a list of M such arrays.

    Raises
    ------
    IndexError
        If N_perfect < N and ``Fourier_elements_list`` has fewer than N entries.
    """
    ratio_to_perfect = N_perfect/N
    # Envelope of the rejection step (loop-invariant). Using a float power
    # avoids the int64 wrap-around of np.power(2, d) for large d.
    R_bound = (2.0 ** d)*(N-N_perfect+1)
    residual_elements = [Fourier_elements_list[n] for n in range(N_perfect, N)]

    def _residual(x):
        # Summing only the residual terms, instead of computing K_N - K_{N_perfect},
        # keeps R(x) >= 0. Cancellation can therefore never produce a tiny negative
        # probability, which np.random.binomial would reject.
        return np.sum([np.power(element.evaluate_at_X(x), 2) for element in residual_elements])

    def _draw_residual():
        # Rejection sampling: uniform proposal, accepted with probability R(x) / R_bound.
        while True:
            x = np.random.uniform(low=0,high=1,size=d)
            if np.random.binomial(1, _residual(x)/R_bound, 1)[0] != 0:
                return x

    output_list = [0]*M
    for m in range(M):
        u = np.random.uniform(low=0,high=1)
        if u > ratio_to_perfect:
            output_list[m] = _draw_residual()
        else:
            output_list[m] = np.random.uniform(low=0,high=1,size=d)

    if M == 1:
        return output_list[0]
    return output_list




def dpp_sampler(N,d,mode,multi_indices,N_perfect):
    """Draw N points on [0, 1]^d sequentially for the projection kernel K_N.

    The first point comes straight from ``diagonal_sampler``. Each later
    candidate x also comes from ``diagonal_sampler``. It is accepted with
    probability |K_cond(x)| / K_N(x, x), where K_cond is
    ``Fourier_K_NN_function_conditional`` built from the points accepted so far.
    After each acceptance, the inverse Gram matrix is grown with ``updateInvK``.

    Parameters
    ----------
    N : number of basis functions, and number of points returned.
    d : dimension.
    mode : forwarded to ``diagonal_sampler``.
    multi_indices : its first N entries give the multi-orders of the
        ``Multi_Fourier_Element`` basis functions.
    N_perfect : forwarded to ``diagonal_sampler``.

    Returns
    -------
    list of the N sampled points (as returned by ``diagonal_sampler``).
    """
    H = np.power(np.sqrt(2),d)  # forwarded to diagonal_sampler, which does not read it
    basis = [Multi_Fourier_Element(list(multi_indices[n])) for n in range(N)]
    kernel = multi_Fourier_projection_kernel(basis, N)
    K_diag = kernel.Fourier_K_NN_function()

    def propose():
        return diagonal_sampler(H,N,1,d,mode,multi_indices,basis,N_perfect)

    first = propose()
    seq = [first]
    features = np.zeros((N, N))  # row i = feature vector of seq[i]
    features[0, :] = kernel.eval_at_X(first)
    invK = np.array(1/(K_diag(first))).reshape((1,1))

    for count in range(1, N):
        K_cond = kernel.Fourier_K_NN_function_conditional(seq, invK, features[0:count, :])
        while True:
            x = propose()
            acceptance = np.abs(K_cond(x)/K_diag(x))
            if np.random.binomial(1, acceptance, 1) != 0:
                break
        features[count, :] = kernel.eval_at_X(x)
        invK = updateInvK(kernel, invK, seq, x, features[0:count, :])
        seq.append(np.asarray(x))
    return seq


def get_N_perfect_1D(multi_indices):
    """Map each 1-D multi-index to a shifted 1-tuple.

    For each entry ``e``, let n = e[0]. The output is (0,) when n is 0, 1
    or 2; (n - 2,) when n is even; and (n - 1,) when n is odd.
    NOTE(review): presumably this is the last order of the preceding
    complete cos/sin pair. Confirm against the caller.
    """
    def shifted(n):
        if n in (0, 1, 2):
            return 0
        return n - 2 if n % 2 == 0 else n - 1

    return [(shifted(e[0]),) for e in multi_indices]

