import sys
import numpy as np
from numpy import random as rand
import pandas as pd
from multiprocessing import Pool
import functools, multiprocessing
from datetime import datetime
import math
from collections import OrderedDict
import collections
import time
from scipy.stats import multivariate_normal
from scipy.stats import norm
from experimental_design import get_XYdesign, get_oracle
from HeteroVar_two import two_spaces        



def _allocate_samples(lambda_vec, sample_size):
    """Round a design ``lambda_vec`` into deterministic per-arm pull counts.

    Arm i receives ceil(sample_size * lambda_i / sum(lambda)) pulls, so the
    realised total may exceed ``sample_size``. Negative weights (numerical
    noise from the design solver) are clipped to zero pulls.
    """
    proportions = lambda_vec / np.sum(lambda_vec)
    counts = np.ceil(sample_size * proportions).astype(int)
    return np.maximum(counts, 0)


def _weighted_least_squares(X, outers, pull_counts, arm_choice_index, reward, estimated_sigmas):
    """Return the inverse-variance-weighted least-squares estimate of theta.

    Each observation of arm i is weighted by 1 / estimated_sigmas[i]. The
    precision matrix is built from per-arm counts and the precomputed outer
    products instead of summing one outer product per sample.
    """
    pulled = pull_counts > 0  # skip unpulled arms (avoids 0/0 if a variance is 0)
    weights = pull_counts[pulled] / estimated_sigmas[pulled]
    precision = np.einsum('i,ijk->jk', weights, outers[pulled])
    XtY = X[arm_choice_index].T @ (reward / estimated_sigmas[arm_choice_index])
    # pinv tolerates a singular precision matrix (e.g. arms not spanning R^d).
    return np.linalg.pinv(precision) @ XtY


def _oracle_confirms_best(Z, opt_arm_index, theta_hat):
    """Return True when theta_hat scores Z[opt_arm_index] strictly above every other row of Z."""
    margins = np.delete((Z[opt_arm_index] - Z) @ theta_hat, opt_arm_index)
    return not np.any(margins <= 0)


def AdaptiveXY(theta, X, Z, Sigma, heteroskedastic=False, oracle=False, delta = 0.05, log=False, seed=0, MVT = False):
    """Best-arm identification master function: runs H-RAGE, RAGE or the oracle allocation.

    Measurements are taken on rows of ``X`` (n x d) and the goal is to identify
    argmax_z z @ theta over rows of ``Z`` (z_n x d). Noise for arm x is drawn with
    variance x^T Sigma x.

    Args:
        theta: true parameter vector (length d) used to simulate rewards.
        X: measurement arms, shape (n, d).
        Z: candidate arms whose best element is sought, shape (z_n, d).
        Sigma: (d, d) matrix defining the per-arm noise variance x^T Sigma x.
        heteroskedastic: if True, first estimate Sigma via ``two_spaces`` and use
            per-arm variances (H-RAGE). Otherwise every arm uses the maximum
            variance (RAGE).
        oracle: if True, sample repeatedly from the oracle allocation until
            theta_hat ranks the true best arm first.
        delta: failure probability.
        log: print per-round progress.
        seed: currently unused; the global RNG is reseeded from OS entropy.
        MVT: currently unused; kept for API compatibility.

    Returns:
        (index into Z of the identified arm,
         2 * oracle_value * log(n / delta),
         total number of samples t, including the variance-estimation budget
         in heteroskedastic mode).
    """
    # Reseed from OS entropy so independent (e.g. multiprocessing) runs draw
    # different noise. NOTE(review): ``seed`` is ignored here; kept for compatibility.
    np.random.seed()
    print("Current Time =", datetime.now().strftime("%H:%M:%S"))

    n, d = X.shape
    z_n = Z.shape[0]
    iterations = 1000  # iterations for the Frank-Wolfe design solver

    # Precompute x x^T for every measurement arm: shape (n, d, d).
    outers = np.matmul(X[:, :, np.newaxis], X[:, np.newaxis, :])

    # Ground truth; the candidate set ranges over rows of Z, not X.
    opt_arm_index = np.argmax(Z @ theta)
    potential_opt_set = list(range(z_n))
    sample_counts = np.zeros(n)
    t = 0
    initial_t = 0

    lambda_vec = np.full(n, 1 / n)
    sigmas = np.array([x @ Sigma @ x for x in X])  # true per-arm noise variances

    if not heteroskedastic:
        # NOTE(review): this also replaces the *true* simulation noise with the
        # maximum variance, not only the algorithm's estimate - confirm intended.
        sigmas = np.full(n, np.max(sigmas))
        estimated_sigmas = sigmas.copy()
    else:
        sigma_max = np.max(sigmas)
        sigma_min = np.min(sigmas)
        kappa = sigma_max / sigma_min
        c = np.sqrt(128*sigma_max/(np.log(n)+1))+np.sqrt(2/d)+np.sqrt(10/(np.log(n+1)+1+np.log(12/delta)))
        # Budget spent on estimating Sigma before the elimination rounds.
        t = int(2*c**2*d**2*np.log(12*n/delta)*kappa**2)
        initial_t = t
        _, Sigma_hat = two_spaces(X, Sigma, theta, 0, 1000, t)
        estimated_sigmas = np.array([x @ Sigma_hat @ x for x in X])

    # Oracle allocation (also needed for the returned oracle complexity term).
    _, oracle_value, oracle_lambda = get_oracle(outers, X, Z, 10000, opt_arm_index, theta, 0, lambda_vec.copy(), estimated_sigmas)
    print(oracle_lambda)

    ell = 1
    while len(potential_opt_set) > 1:
        V = Z[potential_opt_set]  # surviving candidate arms
        eps_ell = 2**(-ell)
        if oracle:
            potential_opt_set = list(range(z_n))
            lambda_vec = oracle_lambda
            sample_size = int(np.ceil(2*ell*oracle_value*np.log(4*ell**2*z_n/delta)))
        else:
            _, design_value, lambda_vec = get_XYdesign(outers, V.copy(), X.copy(), iterations, 0.000001, lambda_vec.copy(), estimated_sigmas.copy())
            scale = 3 if heteroskedastic else 2
            sample_size = int(np.ceil(scale*eps_ell**(-2)*design_value*np.log(8*ell**2*z_n**2/delta)))

        # Deterministic rounding of the design into an ordered list of arm pulls.
        pull_counts = _allocate_samples(lambda_vec, sample_size)
        arm_choice_index = np.repeat(np.arange(n), pull_counts)
        sample_size = len(arm_choice_index)
        t += sample_size
        sample_counts += pull_counts

        # sigmas are variances; np.random.normal takes a standard deviation.
        noise = np.random.normal(0, np.sqrt(sigmas[arm_choice_index]), sample_size)
        reward = X[arm_choice_index] @ theta + noise

        theta_hat = _weighted_least_squares(X, outers, pull_counts, arm_choice_index, reward, estimated_sigmas)
        candidate_values = V @ theta_hat
        estimated_gaps = np.max(candidate_values) - candidate_values

        if oracle:
            if _oracle_confirms_best(Z, opt_arm_index, theta_hat):
                potential_opt_set = [opt_arm_index]
        else:
            # Gap positions align with positions in potential_opt_set (V was built from it).
            eliminated_arms = np.where(estimated_gaps > eps_ell)[0]
            potential_opt_set = np.delete(potential_opt_set, eliminated_arms)

        if log:
            stamp = datetime.now().strftime("%H:%M:%S")
            if oracle:
                print(stamp, 'sample size=', t)
            else:
                print(stamp, 'sample size=', t - initial_t, 'round ell',
                      ell, 'gaps eliminated', eps_ell*2, 'potentially optimal arms',  potential_opt_set,
                      'estimated gaps', estimated_gaps, 'allocation',  sample_counts/t)

        ell += 1
    return potential_opt_set[0], 2*oracle_value*np.log(n/delta), t


                              
