import pickle
import sys
import os
import numpy as np
from numpy import random as rand
import time
import functools, multiprocessing
import pandas as pd


#HELPER FUNCTIONS

def G_optimal(X, threshold, iterations):
    """Compute a G-optimal design over the rows of X via Frank-Wolfe.

    Starts from the uniform design over all arms and runs FrankWolfe_XY with
    the arms themselves as both the sampling set and the target set.

    Returns:
        The tuple returned by FrankWolfe_XY: (cov_A, y_max_val, lambda_vec).
    """
    n_arms = X.shape[0]
    uniform_start = np.full(n_arms, 1 / n_arms)
    # Stack of per-arm outer products x x^T, shape (n_arms, d, d).
    outer_products = np.einsum('ni,nj->nij', X, X)
    return FrankWolfe_XY(outer_products, X, X, iterations, threshold, uniform_start)



def Sample_X(X, threshold, iterations, samples):
    """Return normalized G-optimal sampling distributions for theta and Sigma.

    Args:
        X: arm matrix, one arm per row.
        threshold, iterations: Frank-Wolfe stopping controls (see FrankWolfe_XY).
        samples: accepted for interface compatibility; not used here.

    Returns:
        (theta_lambda, Sigma_lambda): probability vectors over the arms. The
        first is the G-optimal design over X; the second is the G-optimal
        design over the lifted features make_Cov_space(X).
    """
    lifted = make_Cov_space(X)
    _, _, cov_weights = G_optimal(lifted, threshold, iterations)
    _, _, mean_weights = G_optimal(X, threshold, iterations)
    theta_lambda = mean_weights / np.sum(mean_weights)
    Sigma_lambda = cov_weights / np.sum(cov_weights)
    return theta_lambda, Sigma_lambda



def make_Cov_space(X):
    """Lift each arm x to the feature vector phi(x) used to regress on Sigma.

    phi(x) = [2*x_i*x_j for i<j in row-major upper-triangle order,
              x_1^2, ..., x_d^2],
    so that phi(x) . vech(Sigma) == x^T Sigma x when vech lists the strict
    upper-triangle entries (same order) followed by the diagonal.

    Returns:
        float64 array of shape (n_arms, d*(d+1)/2).
    """
    upper_rows, upper_cols = np.triu_indices(X.shape[1], 1)
    cross_terms = 2 * (X[:, upper_rows] * X[:, upper_cols])
    squares = X * X
    return np.concatenate([cross_terms, squares], axis=1).astype(float)

def inv_min_eig_Phi(X, Phi, iterations):
    """Randomly search for a well-conditioned square subset of Phi's rows.

    Used by the Separate Arm Estimator: it needs phi_d = d(d+1)/2 distinct arms
    whose lifted features form an invertible phi_d x phi_d matrix. Each
    iteration samples phi_d distinct row indices and keeps the set whose
    smallest singular value is largest (i.e. whose 1/min-singular-value is
    smallest).

    Candidates are drawn only from the first `pool_size` rows. This is
    presumably meant to favor high-magnitude arms, and relies on callers
    ordering X accordingly (two_spheres_samp puts the unit-norm arms first).

    Args:
        X: arm matrix (n_arms, d); only its shape is read.
        Phi: lifted features, one row per arm (see make_Cov_space).
        iterations: number of random subsets to try.

    Returns:
        (current_choice, current_inv_eig). current_choice is an index array of
        length phi_d. current_inv_eig is 1/min singular value of Phi[current_choice],
        or inf if no nonsingular subset was found (or iterations == 0).

    Raises:
        ValueError: if there are fewer arms than phi_d.
    """
    d = X.shape[1]
    n = X.shape[0]
    phi_d = d * (d + 1) // 2
    if n < phi_d:
        raise ValueError(
            f"need at least {phi_d} arms to pick {phi_d} distinct ones, got {n}")

    # Clamp to n: the heuristic max(n//10, phi_d+20) can exceed the number of
    # arms, which would produce out-of-range indices into Phi.
    pool_size = min(max(n // 10, phi_d + 20), n)

    current_inv_eig = float('inf')
    current_choice = np.arange(phi_d)
    for _ in range(iterations):
        test_choice = rand.choice(pool_size, size=phi_d, replace=False)
        singular_values = np.linalg.svd(Phi[test_choice], compute_uv=False)
        smallest = np.min(singular_values)
        if smallest != 0:
            test_inv_min_eig = 1 / smallest
            if test_inv_min_eig < current_inv_eig:
                current_inv_eig = test_inv_min_eig
                current_choice = test_choice
    print(len(current_choice))
    print(phi_d)
    return current_choice, current_inv_eig

#EXPERIMENTAL DESIGN

def _fw_inverse_design(outers, lambda_vec, sigmas):
    """Return the inverse (pseudo-inverse if singular) of sum_i lambda_i x_i x_i^T / sigma_i."""
    A_lambda = np.sum(outers * lambda_vec[:, np.newaxis, np.newaxis] / sigmas[:, np.newaxis, np.newaxis], axis=0)
    try:
        return np.linalg.inv(A_lambda)
    except np.linalg.LinAlgError:
        return np.linalg.pinv(A_lambda)


def FrankWolfe_XY(outers, X, Y, iterations, threshold, warm_start, sigmas=None):
    """Frank-Wolfe search for a design minimizing max_y y^T A(lambda)^{-1} y.

    A(lambda) = sum_i lambda_i x_i x_i^T / sigma_i over the rows x_i of X, and the
    max runs over the rows y of Y (the target directions).

    Args:
        outers: stack of x_i x_i^T for each row of X (as built by G_optimal).
        X: sampling arms, one per row.
        Y: target directions, one per row.
        iterations: the loop runs for k = 1 .. iterations-1.
        threshold: stop once the relative change of the max uncertainty drops
            below this value (or the max uncertainty is exactly 0).
        warm_start: initial weight vector over the rows of X. It is copied, never
            modified.
        sigmas: optional per-arm noise variances (default: all ones).

    Returns:
        (cov_A, y_max_val, lambda_vec).
        - cov_A: the (pseudo-)inverse design matrix from the last evaluated design.
        - y_max_val: the max uncertainty under that design.
        - lambda_vec: the weights after the final Frank-Wolfe step.
        If no iteration runs (iterations <= 1), the warm start itself is evaluated
        and returned unchanged.
    """
    # Work on a float copy so the caller's warm_start is never mutated in place.
    lambda_vec = np.array(warm_start, dtype=float)
    if sigmas is None:
        sigmas = np.ones(len(lambda_vec))
    else:
        sigmas = np.asarray(sigmas, dtype=float)
    # Elementwise equivalent of X.T @ sqrt(inv(diag(sigmas))), without building an n x n matrix.
    inv_sqrt_sigmas = np.sqrt(1.0 / sigmas)

    cov_A = None
    y_max_val = None
    old_y_max_val = 1
    for k in range(1, iterations):
        cov_A = _fw_inverse_design(outers, lambda_vec, sigmas)

        # Predictive uncertainty y^T cov_A y for every target row.
        diag_arg = (Y @ cov_A * Y).sum(-1)
        max_index = np.argmax(diag_arg)
        y_max = Y[max_index]
        y_max_val = diag_arg[max_index]
        lambda_derivative = -((y_max @ cov_A @ X.T) * inv_sqrt_sigmas) ** 2

        # Frank-Wolfe step toward the vertex with the steepest descent.
        alpha = 2 / (k + 2)
        min_lambda_derivative_index = np.argmin(lambda_derivative)
        lambda_vec -= alpha * lambda_vec
        lambda_vec[min_lambda_derivative_index] += alpha

        if y_max_val == 0 or abs((old_y_max_val - y_max_val) / old_y_max_val) < threshold:
            break
        old_y_max_val = y_max_val

    if cov_A is None:
        # Loop body never ran: report the warm-start design instead of failing.
        cov_A = _fw_inverse_design(outers, lambda_vec, sigmas)
        y_max_val = np.max((Y @ cov_A * Y).sum(-1))
    return cov_A, y_max_val, lambda_vec

#ESTIMATORS

def two_spaces(X, Sigma, theta, threshold, iterations, samples):
    """HEAD estimator: sample-split estimate of Sigma using G-optimal designs.

    The budget is split in half. One half is drawn from the G-optimal design
    over X and used to fit theta by OLS. The other half is drawn from the
    G-optimal design over the lifted features; its squared residuals against
    the fitted theta are regressed on phi(x) to estimate vech(Sigma).

    Returns:
        (Phi_Sigma, X_Sigma, X_theta, Sigma_hat_G): the lifted features and
        arms used for the Sigma regression, the arms used for theta, and the
        symmetric d x d estimate.
    """
    total_arms = X.shape[0]
    d = X.shape[1]
    Phi = make_Cov_space(X)
    half = samples // 2
    # Sample_X takes (X, threshold, iterations, samples) and returns the two
    # normalized designs. The previous call passed Phi as an extra argument
    # and unpacked four values, so it always raised TypeError.
    theta_lambda, Sigma_lambda = Sample_X(X, threshold, iterations, half)
    theta_choice = rand.choice(total_arms, size=half, p=theta_lambda)
    Sigma_choice = rand.choice(total_arms, size=half, p=Sigma_lambda)

    X_theta = X[theta_choice]
    X_Sigma = X[Sigma_choice]
    Phi_Sigma = Phi[Sigma_choice]

    # Pull theta arms with heteroskedastic noise of variance x^T Sigma x, then fit theta.
    X_theta_noise = (X_theta @ Sigma * X_theta).sum(-1)
    Y_theta = X_theta @ theta + np.multiply(np.sqrt(X_theta_noise), rand.randn(X_theta.shape[0]))
    theta_hat_G = np.linalg.inv(X_theta.T @ X_theta) @ X_theta.T @ Y_theta

    # Pull Sigma arms and regress their squared residuals on phi(x).
    X_Sigma_noise = (X_Sigma @ Sigma * X_Sigma).sum(-1)
    Y_Sigma = X_Sigma @ theta + np.multiply(np.sqrt(X_Sigma_noise), rand.randn(X_Sigma.shape[0]))
    Sigma_SE = np.square(Y_Sigma - X_Sigma @ theta_hat_G)
    vech_Sigma_hat_G = np.linalg.inv(Phi_Sigma.T @ Phi_Sigma) @ Phi_Sigma.T @ Sigma_SE

    # Unpack vech: strict upper triangle first, then the diagonal.
    upper_Sigma_hat_G = np.zeros((d, d))
    upper_Sigma_hat_G[np.triu_indices(d, 1)] = vech_Sigma_hat_G[0:-d]
    Sigma_hat_G = upper_Sigma_hat_G + upper_Sigma_hat_G.T
    np.fill_diagonal(Sigma_hat_G, vech_Sigma_hat_G[-d:])

    return Phi_Sigma, X_Sigma, X_theta, Sigma_hat_G

def White_Estimator(X, Sigma, theta, threshold, iterations, samples):
    """Uniform (White) estimator of Sigma without sample splitting.

    Draws `samples` arms uniformly at random with replacement. The same pulls
    are used first to fit theta by OLS and then to regress the squared
    residuals on phi(x). threshold and iterations are accepted for interface
    parity with the other estimators and are not used. No projection onto the
    positive semidefinite cone is applied.

    Returns:
        The symmetric d x d estimate Sigma_hat_G.
    """
    n_arms = X.shape[0]
    dim = X.shape[1]
    lifted = make_Cov_space(X)
    picks = rand.choice(n_arms, size=(1, samples))[0]
    pulled_phi = lifted[picks]
    pulled_x = X[picks]

    # Observe y = x^T theta + noise with variance x^T Sigma x.
    noise_var = np.sum(pulled_x @ Sigma * pulled_x, axis=-1)
    observations = pulled_x @ theta + np.multiply(np.sqrt(noise_var), rand.randn(samples))
    theta_hat = np.linalg.inv(pulled_x.T @ pulled_x) @ pulled_x.T @ observations

    squared_residuals = np.square(observations - pulled_x @ theta_hat)
    vech = np.linalg.inv(pulled_phi.T @ pulled_phi) @ pulled_phi.T @ squared_residuals

    off_diagonal = vech[0:-dim]
    diagonal = vech[-dim:len(vech)]
    half_matrix = np.zeros((dim, dim))
    half_matrix[np.triu_indices(dim, 1)] = off_diagonal
    estimate = half_matrix + half_matrix.T
    np.fill_diagonal(estimate, diagonal)
    return estimate

def Alt_est(X, Sigma, theta, threshold, iterations, samples):
    """Separate Arm Estimator: pick well-conditioned arms, then estimate Sigma.

    The arm subset comes from inv_min_eig_Phi. The pulling and regression step
    is exactly Alt_est_static applied to that subset. The chosen subset's
    1/min-singular-value is printed after estimation.

    Returns:
        (current_choice, Sigma_hat): the selected arm indices and the
        symmetric d x d estimate.
    """
    lifted = make_Cov_space(X)
    chosen_arms, inv_min_singular = inv_min_eig_Phi(X, lifted, iterations)
    estimate = Alt_est_static(chosen_arms, X, Sigma, theta, threshold, iterations, samples)
    print(inv_min_singular)
    return chosen_arms, estimate

#SAMPLING

def unit_sphere_samp(arms, d):
    """Return `arms` points on the unit sphere in R^d, one per row.

    Standard-normal draws are rescaled to unit Euclidean length, which yields
    points uniformly distributed on the sphere.
    """
    gaussian = rand.multivariate_normal([0] * d, np.diag([1] * d), arms)
    return gaussian / np.linalg.norm(gaussian, axis=1, keepdims=True)

def two_spheres_samp(arms, d, p):
    """Arm sampler for the two-spheres setting.

    Returns exactly `arms` rows. The first int(arms*p) rows lie on the unit
    sphere; the remaining rows lie on the sphere of radius 0.1. Both groups
    are normalized standard-normal draws.

    Raises:
        ValueError: if p is outside [0, 1].
    """
    if not 0 <= p <= 1:
        raise ValueError(f"p must lie in [0, 1], got {p}")
    big_arms = int(arms * p)
    # Use the complement rather than int(arms*(1-p)). Truncating both products
    # can lose a row (e.g. arms=100, p=0.29 gave 28 + 71 = 99 arms).
    small_arms = arms - big_arms

    def _sphere_points(n, scale):
        # n standard-normal draws normalized to norm 1/scale.
        draws = rand.multivariate_normal([0] * d, np.diag([1] * d), n)
        return draws / (scale * np.linalg.norm(draws, axis=1, keepdims=True))

    big_X = _sphere_points(big_arms, 1)
    small_X = _sphere_points(small_arms, 10)
    return np.concatenate((big_X, small_X), axis=0)


#MASTER FUNCTION

def run_sample_sim(arms, homogeneous, threshold, iterations, two_spheres, dimension_vec, p_vec, sim=None,
                   samples_vec=None):
    """Run the three covariance estimators over a grid and record their max error.

    For each dimension d and proportion p:
      1. Draw an arm set (two spheres, or the unit sphere, which ignores p).
      2. Set theta = all ones and Sigma diagonal: all 1.0 if homogeneous,
         otherwise alternating 0.1, 1.
      3. For each sample budget, run the three estimators.
      4. Clip each per-arm variance estimate x^T Sigma_hat x to the true
         [min, max] range and record the max absolute error.

    Args:
        arms: number of arms to draw.
        homogeneous: whether all diagonal entries of Sigma are 1.0.
        threshold, iterations: forwarded to the estimators / design routines.
        two_spheres: use two_spheres_samp instead of unit_sphere_samp.
        dimension_vec, p_vec: grid of dimensions and big-sphere proportions.
        sim: label stored in the "Sim" column.
        samples_vec: sample budgets to evaluate. Defaults to 5000, 10000, ..., 95000.

    Returns:
        DataFrame with columns Samples, ProportionLS, Dimension, Estimator,
        MaxError, Sim; one row per (d, p, budget, estimator).
    """
    # Reseed from OS entropy; presumably so parallel workers do not share a stream.
    np.random.seed()

    if samples_vec is None:
        samples_vec = np.multiply(np.array(list(range(1, 20))), 5000)
    columns = ['Samples', 'ProportionLS', 'Dimension', 'Estimator', "MaxError", "Sim"]
    Estimator_names = ["Independent Arms", "G-optimal", "White Estimator"]
    records = []

    start = time.time()
    for d in dimension_vec:
        print(d)
        dim = int(d)  # slicing/sizing needs an int even if the grid holds floats
        for p in p_vec:
            if two_spheres:
                X = two_spheres_samp(arms, dim, float(p))
            else:
                # unit_sphere_samp takes (arms, d); p does not apply to a single sphere.
                X = unit_sphere_samp(arms, dim)
            theta = [1] * dim
            if homogeneous:
                Sigma = np.diag([1.0] * dim)
            else:
                Sigma = np.diag([0.1 if i % 2 == 0 else 1 for i in range(dim)])

            # True per-arm variances x^T Sigma x depend only on (X, Sigma).
            # Compute them once, without forming the arms x arms matrix X Sigma X^T.
            true_var = (X @ Sigma * X).sum(-1)
            sigma_max = np.max(true_var)
            sigma_min = np.min(true_var)

            # Fix the arm subset and designs once per (d, p) and reuse them across budgets.
            dummy_samples = 1000
            Alt_current_choice, _ = Alt_est(X, Sigma, theta, threshold, iterations, dummy_samples)
            theta_lambda, Sigma_lambda = Sample_X(X, threshold, iterations, dummy_samples)

            for samples in samples_vec:
                for k, name in enumerate(Estimator_names):
                    if k == 0:
                        Sigma_hat = Alt_est_static(Alt_current_choice, X, Sigma, theta, threshold, iterations,
                                                   samples)
                    elif k == 1:
                        Sigma_hat = two_spaces_static(theta_lambda, Sigma_lambda, X, Sigma, theta, threshold,
                                                      iterations, samples)
                    else:
                        Sigma_hat = White_Estimator(X, Sigma, theta, threshold, iterations, samples)

                    raw_estimates = (X @ Sigma_hat * X).sum(-1)
                    estimates = np.clip(raw_estimates, sigma_min, sigma_max)

                    records.append({
                        'Samples': samples,
                        'ProportionLS': p,
                        'Dimension': d,
                        'Estimator': name,
                        'MaxError': np.max(np.absolute(estimates - true_var)),
                        'Sim': sim,
                    })
                    print(len(records))
    print(time.time() - start)
    return pd.DataFrame(records, columns=columns)


def two_spaces_static(theta_lambda, Sigma_lambda, X, Sigma, theta, threshold, iterations, samples):
    """Sample-split (HEAD) estimate of Sigma from precomputed designs.

    Half of `samples` is drawn from theta_lambda and used to fit theta by OLS.
    The other half is drawn from Sigma_lambda; its squared residuals are
    regressed on phi(x). threshold and iterations are accepted for interface
    parity and are not used.

    Returns:
        The symmetric d x d estimate.
    """
    n_arms = X.shape[0]
    dim = X.shape[1]
    lifted = make_Cov_space(X)
    half = samples // 2

    mean_picks = rand.choice(n_arms, size=(1, half), p=theta_lambda)[0]
    cov_picks = rand.choice(n_arms, size=(1, half), p=Sigma_lambda)[0]
    mean_arms = X[mean_picks]
    cov_arms = X[cov_picks]
    cov_features = lifted[cov_picks]

    # Stage 1: OLS for theta on the mean-design pulls.
    mean_noise_var = np.sum(mean_arms @ Sigma * mean_arms, axis=-1)
    mean_obs = mean_arms @ theta + np.multiply(np.sqrt(mean_noise_var), rand.randn(mean_arms.shape[0]))
    theta_hat = np.linalg.inv(mean_arms.T @ mean_arms) @ mean_arms.T @ mean_obs

    # Stage 2: regress squared residuals on phi(x) for the covariance-design pulls.
    cov_noise_var = np.sum(cov_arms @ Sigma * cov_arms, axis=-1)
    cov_obs = cov_arms @ theta + np.multiply(np.sqrt(cov_noise_var), rand.randn(cov_arms.shape[0]))
    squared_residuals = np.square(cov_obs - cov_arms @ theta_hat)
    vech = np.linalg.inv(cov_features.T @ cov_features) @ cov_features.T @ squared_residuals

    half_matrix = np.zeros((dim, dim))
    half_matrix[np.triu_indices(dim, 1)] = vech[0:-dim]
    estimate = half_matrix + half_matrix.T
    np.fill_diagonal(estimate, vech[-dim:len(vech)])
    return estimate


def White_Estimator_static(X_theta, Phi_theta, X, Sigma, theta, threshold, iterations, samples):
    """White-style estimate of Sigma from pre-drawn arms (no sample split).

    The rows of X_theta are pulled once each (`samples` must equal their count)
    to fit theta by OLS. The squared residuals are then regressed on the
    matching lifted features Phi_theta. Only d = X.shape[1] is read from X;
    threshold and iterations are not used. No projection onto the positive
    semidefinite cone is applied.

    Returns:
        The symmetric d x d estimate.
    """
    dim = X.shape[1]

    noise_var = np.sum(X_theta @ Sigma * X_theta, axis=-1)
    observations = X_theta @ theta + np.multiply(np.sqrt(noise_var), rand.randn(samples))
    theta_hat = np.linalg.inv(X_theta.T @ X_theta) @ X_theta.T @ observations

    squared_residuals = np.square(observations - X_theta @ theta_hat)
    vech = np.linalg.inv(Phi_theta.T @ Phi_theta) @ Phi_theta.T @ squared_residuals

    # vech layout: strict upper triangle (row-major), then the diagonal.
    off_diagonal, diagonal = vech[0:-dim], vech[-dim:len(vech)]
    half_matrix = np.zeros((dim, dim))
    half_matrix[np.triu_indices(dim, 1)] = off_diagonal
    estimate = half_matrix + half_matrix.T
    np.fill_diagonal(estimate, diagonal)
    return estimate

def Alt_est_static(current_choice, X, Sigma, theta, threshold, iterations, samples):
    """Separate Arm Estimator for a fixed set of phi_d arms.

    `samples` pulls are spread as evenly as possible over current_choice; the
    first `samples % phi_d` arms get one extra pull. Within each arm, the
    squared deviations from that arm's own sample mean are collected; no
    M/(M-1) correction is applied. These are then regressed on phi(x).
    threshold and iterations are not used.

    Returns:
        The symmetric d x d estimate.
    """
    dim = X.shape[1]
    n_params = int(dim * (dim + 1) / 2)
    lifted = make_Cov_space(X)

    base, extra = divmod(samples, n_params)
    pulls = [base + (1 if j < extra else 0) for j in range(n_params)]
    repeated_features = lifted[np.repeat(current_choice, pulls)]

    squared_devs = []
    for arm, m in zip(current_choice, pulls):
        x = X[arm]
        noise_sd = np.sqrt(x.T @ Sigma @ x)
        obs = X[[arm] * m] @ theta + np.multiply([noise_sd] * m, rand.randn(m))
        squared_devs.extend(np.square(obs - np.mean(obs)))

    vech = np.linalg.inv(repeated_features.T @ repeated_features) @ repeated_features.T @ squared_devs

    half_matrix = np.zeros((dim, dim))
    half_matrix[np.triu_indices(dim, 1)] = vech[0:-dim]
    estimate = half_matrix + half_matrix.T
    np.fill_diagonal(estimate, vech[-dim:len(vech)])
    return estimate