# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import numpy as np
import utils
from ray.util.multiprocessing import Pool

            
class Algorithm:
    """Best-arm identification with 2SLS-style estimators driven by instruments.

    The problem ``instance`` is read for: ``W`` (arm matrix, one arm per row),
    ``Z`` (instrument matrix, one instrument per row), ``Gamma``, ``theta``,
    ``d``, ``K``, ``type``, ``subgaussian``, ``sigma_y``, ``opt`` and ``opt_w``
    (not every attribute is needed on every code path).
    """

    def __init__(self, instance, sampling_method, elimination_method, design_method,
                 delta, alg_type, burn_in_sampling_method, burn_in_length, reuse_gamma, horizon):
        self.instance = instance
        self.sampling_method = sampling_method
        self.elimination_method = elimination_method
        self.design_method = design_method
        self.delta = delta
        self.alg_type = alg_type
        self.burn_in_sampling_method = burn_in_sampling_method
        self.burn_in_length = burn_in_length
        self.horizon = horizon
        self.reuse_gamma = reuse_gamma
        self.verbose = True
        # Total number of samples drawn so far (burn-in samples included).
        self.T = 0

    def run_ucb(self):
        """Sample instruments with a UCB rule until ``self.horizon`` samples are used.

        Each round regresses y on z with ridge regularisation ``reg = 1``. The
        inverse Gram matrix is kept up to date with Sherman-Morrison rank-one
        updates, and its determinant with the matrix determinant lemma. The
        round then picks the row of ``instance.Z`` with the largest upper
        confidence bound. After the loop, the sample path is stored in
        ``Z_sample`` / ``X_sample`` / ``Y_sample`` and the per-step estimates
        are produced by :meth:`bulk_update_estimates`.
        """
        d = self.instance.d
        Z = self.instance.Z
        reg = 1
        m1 = np.linalg.norm(self.instance.theta, 2)

        # Active set: estimate_gamma() may need it (get_design / get_bias_term).
        self.A = self.instance.W.copy()

        Z_sample_history = []
        X_sample_history = []
        Y_sample_history = []
        ZtZ = reg*np.eye(d)
        det = float(np.linalg.det(ZtZ))
        ZtZ_inv = np.linalg.inv(ZtZ)
        Zty = np.zeros((d, 1))
        theta_hat = np.zeros((d, 1))
        self.ell = 1
        self.estimate_gamma()

        while self.T < self.horizon:
            # Select the instrument with the largest upper confidence bound.
            beta_root = m1*np.sqrt(reg) + np.sqrt(2*np.log(1/self.delta) + np.log(det/(reg**d)))
            # Row-wise quadratic form z_i^T (ZtZ)^{-1} z_i.
            conf = beta_root*np.sqrt(np.einsum('ij,ij->i', Z@ZtZ_inv, Z))
            ucbs = (Z@theta_hat).flatten() + conf
            z_idx = int(np.argmax(ucbs))

            # generate_data expects an index array and (n, d) rows of Z.
            self.Z_idx_sample = np.array([z_idx])
            self.Z_sample = Z[self.Z_idx_sample]
            z = self.Z_sample.T  # (d, 1) column for the rank-one updates

            X_sample, Y_sample = self.generate_data(1, return_observations=True)
            Zty += z*Y_sample
            quad = (z.T@ZtZ_inv@z).item()
            # Matrix determinant lemma; uses the inverse from *before* this update.
            det *= 1 + quad
            # Sherman-Morrison update of the inverse Gram matrix.
            ZtZ_inv = ZtZ_inv - (ZtZ_inv@z@z.T@ZtZ_inv)/(1 + quad)
            theta_hat = ZtZ_inv@Zty

            Z_sample_history.append(self.Z_sample)
            X_sample_history.append(X_sample)
            Y_sample_history.append(Y_sample)

            self.T += 1

        self.Z_sample = np.vstack(Z_sample_history)
        self.X_sample = np.vstack(X_sample_history)
        self.Y_sample = np.vstack(Y_sample_history)

        self.Gamma_hat_empirical = np.linalg.pinv(self.Z_sample.T@self.Z_sample)@self.Z_sample.T@self.X_sample

        self.bulk_update_estimates()

    def run_phased_alg(self):
        """Run phased elimination over the active arm set ``self.A``.

        In phase ``ell`` the method:

        1. re-estimates Gamma if needed;
        2. computes a sampling design over the instruments;
        3. draws enough samples for accuracy ``2**-ell`` (capped by the
           remaining budget for ``'fixed_budget'``);
        4. refits the estimators;
        5. eliminates arms according to ``elimination_method``.

        It stops when one arm remains or, for ``'fixed_budget'``, when the
        budget is spent. For ``'fixed_budget'`` all phase samples are
        concatenated and evaluated by :meth:`bulk_update_estimates`. For
        ``'fixed_conf'`` it records ``accuracy`` and ``stop_time``.
        """
        # This is the active set.
        self.A = self.instance.W.copy()
        Z_sample_history = []
        X_sample_history = []
        Y_sample_history = []

        self.ell = 1

        while len(self.A) > 1:

            self.estimate_gamma()

            l, rho = self.get_design(sampling_purpose='Theta')
            if (self.sampling_method == 'uniform' and self.elimination_method is not None
                    and 'stop' not in self.elimination_method):
                # Mix in 2.5% of the uniform distribution over the K instruments
                # (the design is a distribution over the rows of Z).
                l = .975*l + .025*np.ones(self.instance.K)/self.instance.K
                X_hat = self.instance.Z@self.instance.Gamma
                V = utils.get_diff(self.A)
                rho = utils.get_value_given_design(X_hat, V, l)

            self.epsilon_ell = 2**(-self.ell)
            if self.sampling_method == 'oracle':
                horizon = int(np.ceil(0.5*self.instance.subgaussian*self.epsilon_ell**(-2)*rho*np.log(4*self.ell**2*self.instance.K/self.delta)))
            else:
                horizon = int(np.ceil(2*self.instance.subgaussian*self.epsilon_ell**(-2)*rho*np.log(4*self.ell**2*self.instance.K**2/self.delta)))

            if self.alg_type == 'fixed_budget':
                horizon = min(horizon, self.horizon-self.T)

            if self.verbose:
                print(self.ell, horizon, l, rho)

            self.Z_idx_sample = np.random.choice(self.instance.K, p=l, size=horizon)
            self.Z_sample = self.instance.Z[self.Z_idx_sample]

            # generate data
            X_sample, Y_sample = self.generate_data(horizon, return_observations=True)
            self.Gamma_hat_empirical = np.linalg.pinv(self.Z_sample.T@self.Z_sample)@self.Z_sample.T@self.X_sample

            self.update_estimates()

            # eliminate; the '-stop' variants only accept an elimination that
            # leaves exactly one arm.
            if self.elimination_method == 'p-2sls':
                self.A = self.eliminate(self.A.copy())
            elif self.elimination_method == 'p-2sls-stop':
                A_new = self.eliminate(self.A.copy())
                if len(A_new) == 1:
                    self.A = A_new
            elif self.elimination_method == 'oracle':
                self.A = self.oracle_eliminate(self.A.copy())
            elif self.elimination_method == 'oracle-stop':
                A_new = self.oracle_eliminate(self.A.copy())
                if len(A_new) == 1:
                    self.A = A_new

            # increment
            self.ell += 1
            self.T += horizon

            if self.alg_type == 'fixed_budget':
                Z_sample_history.append(self.Z_sample)
                X_sample_history.append(X_sample)
                Y_sample_history.append(Y_sample)
                if self.T >= self.horizon:
                    break

        if self.alg_type == 'fixed_budget':
            self.Z_sample = np.vstack(Z_sample_history)
            self.X_sample = np.vstack(X_sample_history)
            self.Y_sample = np.vstack(Y_sample_history)
            self.bulk_update_estimates()
        elif self.alg_type == 'fixed_conf':
            self.accuracy = np.array_equal(self.A[0].flatten(), self.instance.opt_w)
            self.stop_time = self.T

    def _fit_gamma_hat(self):
        """Set ``Gamma_hat`` by least squares of ``X_sample_GH`` on ``Z_sample_GH``."""
        Z_gh = self.Z_sample_GH
        self.Gamma_hat = np.linalg.inv(Z_gh.T@Z_gh)@Z_gh.T@self.X_sample_GH

    def _gamma_condition_value(self, t_norm_2):
        """Return sqrt(t_norm_2 * max_pair_variance * log_bar) under the current ``Gamma_hat``."""
        max_var = np.max(self.get_diff_var(self.Gamma_hat, self.Z_sample_GH)[2])
        return np.sqrt(t_norm_2*max_var*self.get_log_bar())

    def estimate_gamma(self):
        """Set ``self.Gamma_hat``, drawing burn-in samples when required.

        Behaviour depends on ``self.burn_in_length``:

        * ``None``: use the true ``instance.Gamma``; no samples are drawn.
        * ``'adaptive'``: runs in every phase. It draws ``d**2 + d`` samples
          from an E-optimal design, then doubling rounds from the ``'Gamma'``
          design. Rounds continue until the bias width is at most ``2**-ell``,
          with at least one round in phase 1.
        * anything else (a sample count ``n``): runs in phase 1 only. It draws
          ``n`` uniformly random instruments, fits ``Gamma_hat`` and stores
          ``epsilon_threshold``.

        Every sample drawn here is added to ``self.T``.
        """
        if self.ell != 1 and self.burn_in_length != 'adaptive':
            return

        if self.burn_in_length is None:
            self.Gamma_hat = self.instance.Gamma
        elif self.burn_in_length == 'adaptive':
            if self.ell == 1:
                self.Gamma_hat = np.eye(self.instance.d)
            # E-optimal burn-in with a fixed size of d^2 + d samples.
            # NOTE(review): these samples are drawn (and counted in T) every
            # phase, even when the reuse check below needs no further rounds.
            l, _ = utils.get_Edesign(self.instance.Z)
            horizon = self.instance.d**2 + self.instance.d
            self.T += horizon
            self.Z_idx_sample = np.random.choice(self.instance.K, p=l, size=horizon)
            self.Z_sample = self.instance.Z[self.Z_idx_sample]
            Z_sample_0 = self.Z_sample.copy()
            X_sample_0, _ = self.generate_data(horizon, return_observations=True)

            # XY design.
            l, _ = self.get_design(sampling_purpose='Gamma')
            support_l = int(np.sum(np.round(l, 4) > 0))
            t_norm_2 = np.linalg.norm(self.instance.theta, 2)**2
            j = 1
            if self.ell == 1 or not self.reuse_gamma:
                condition_val = np.inf
            else:
                # Check whether the samples kept from earlier phases already suffice.
                self._fit_gamma_hat()
                condition_val = self._gamma_condition_value(t_norm_2)

            while (self.ell == 1 and j == 1) or condition_val > 2**(-self.ell):

                horizon = 2**j*support_l
                self.T += horizon

                self.Z_idx_sample = np.random.choice(self.instance.K, p=l, size=horizon)
                self.Z_sample = self.instance.Z[self.Z_idx_sample]
                Z_sample_1 = self.Z_sample.copy()
                X_sample_1, _ = self.generate_data(horizon, return_observations=True)

                if (self.ell == 1 and j == 1) or not self.reuse_gamma:
                    # Fit only on this phase's E-optimal samples plus this round.
                    self.Z_sample_GH = np.vstack([Z_sample_0, Z_sample_1])
                    self.X_sample_GH = np.vstack([X_sample_0, X_sample_1])
                elif j == 1:
                    # Reusing: add this phase's E-optimal samples once, with the first round.
                    self.Z_sample_GH = np.concatenate((self.Z_sample_GH, Z_sample_0, Z_sample_1))
                    self.X_sample_GH = np.concatenate((self.X_sample_GH, X_sample_0, X_sample_1))
                else:
                    self.Z_sample_GH = np.concatenate((self.Z_sample_GH, Z_sample_1))
                    self.X_sample_GH = np.concatenate((self.X_sample_GH, X_sample_1))
                self._fit_gamma_hat()
                condition_val = self._gamma_condition_value(t_norm_2)

                j += 1

        else:
            self.T += self.burn_in_length
            l = np.ones(self.instance.K)/self.instance.K
            self.Z_idx_sample = np.random.choice(self.instance.K, p=l, size=self.burn_in_length)
            self.Z_sample = self.instance.Z[self.Z_idx_sample]
            self.X_sample_GH, self.Y_sample_GH = self.generate_data(self.burn_in_length, return_observations=True)
            self.Z_sample_GH = self.Z_sample.copy()
            self._fit_gamma_hat()
            self.epsilon_threshold = 6*self.get_bias_term()

    def get_design(self, sampling_purpose='Theta'):
        """Return a sampling design ``l`` over the rows of ``instance.Z`` and its value ``rho``.

        Args:
            sampling_purpose: ``'Theta'`` uses ``self.sampling_method``;
                ``'Gamma'`` uses ``self.burn_in_sampling_method``.

        Raises:
            ValueError: for an unknown design method, sampling purpose or
                sampling method.
        """
        if self.design_method == 'oracle':
            X_hat = self.instance.Z@self.instance.Gamma
        elif self.design_method == 'empirical':
            X_hat = self.instance.Z@self.Gamma_hat
        elif self.design_method == 'regular':
            X_hat = self.instance.Z
        else:
            raise ValueError(f"unknown design_method: {self.design_method!r}")

        # get the sampling design.
        if sampling_purpose == 'Gamma':
            sampling_method = self.burn_in_sampling_method
        elif sampling_purpose == 'Theta':
            sampling_method = self.sampling_method
        else:
            raise ValueError(f"unknown sampling_purpose: {sampling_purpose!r}")

        if sampling_method == 'uniform':
            V = utils.get_diff(self.A)
            if self.elimination_method is None or 'stop' in self.elimination_method or self.burn_in_length == 'adaptive':
                _, rho, l = utils.get_uniform_design(X_hat, V)
            else:
                # Restrict to rows of Z that also appear as rows of the active set A.
                in_active = np.any(np.all(self.instance.Z[:, None, :] == self.A, axis=2), axis=1)
                subset = np.where(in_active)[0].tolist()
                _, rho, l = utils.get_uniform_subset_design(X_hat, V, subset)
        elif sampling_method == 'g-optimal':
            _, rho, l = utils.get_XYdesign(X_hat, self.A.copy())
            # Value is evaluated on pairwise differences, not on the arms themselves.
            rho = utils.get_value_given_design(X_hat, utils.get_diff(self.A), l)
        elif sampling_method == 'xy-optimal':
            V = utils.get_diff(self.A)
            _, rho, l = utils.get_XYdesign(X_hat, V)
        elif sampling_method == 'oracle':
            V = utils.get_diff_opt(self.instance.W, self.instance.theta)
            _, rho, l = utils.get_XYdesign(X_hat, V)
        else:
            raise ValueError(f"unknown sampling method: {sampling_method!r}")

        return l, rho

    def generate_data(self, horizon, return_observations=False):
        """Draw ``horizon`` observations for the instruments in ``Z_idx_sample`` / ``Z_sample``.

        Sets ``self.X_sample`` and ``self.Y_sample``. For ``'compliance'`` and
        ``'special_compliance'`` instances it also sets ``self.X_idx``. Any
        other ``instance.type`` uses ``X = Z_sample @ Gamma + N(0, I)`` and
        ``y = X @ theta + N(0, 1)``.

        Returns:
            ``(X_sample, Y_sample)`` if ``return_observations`` is true,
            otherwise ``None``.
        """
        # Get x's
        if self.instance.type == 'compliance':
            # Row z of Gamma is the probability distribution of the arm index x given z.
            prob_matrix = self.instance.Gamma[self.Z_idx_sample]
            # Get x by sampling from the probability of x given z.
            self.X_idx = np.array([np.random.choice(self.instance.K, size=1, p=prob_vector) for prob_vector in prob_matrix]).flatten()
            self.X_sample = self.instance.W[self.X_idx]
            # First-stage noise X - E[X|Z], where E[X|Z] = sum_k P(x=k|z) W[k].
            eta = self.X_sample - prob_matrix@self.instance.W

        elif self.instance.type == 'special_compliance':
            epsilon = np.random.normal(0, self.instance.sigma_y, size=horizon)
            # Arm index: the value in 0..K-1 closest to (instrument index + epsilon).
            dist = np.vstack([np.abs([self.Z_idx_sample+epsilon-i]) for i in range(self.instance.K)]).T
            self.X_idx = np.argmin(dist, axis=1)
            self.X_sample = self.instance.W[self.X_idx]
        else:
            noise_x = np.random.multivariate_normal(np.zeros(self.instance.d), np.eye(self.instance.d), size=(horizon))
            self.X_sample = self.Z_sample@self.instance.Gamma + noise_x

        # Get y's
        if self.instance.type == 'compliance':
            tau = self.instance.sigma_y
            # Outcome noise is a random unit projection of eta, so it is correlated with X.
            v = np.random.randn(self.instance.d)
            v = v/np.linalg.norm(v)
            epsilon = tau*eta@v
            self.Y_sample = self.X_sample@self.instance.theta + epsilon.reshape(-1, 1)
        elif self.instance.type == 'special_compliance':
            # The same epsilon drove the choice of x above.
            self.Y_sample = self.X_sample@self.instance.theta + epsilon.reshape(-1, 1)
        else:
            epsilon = np.random.normal(size=(len(self.X_sample), 1))
            self.Y_sample = self.X_sample@self.instance.theta + epsilon

        if return_observations:
            return self.X_sample, self.Y_sample

    def update_estimates(self):
        """Fit theta on the current ``Z_sample`` / ``X_sample`` / ``Y_sample`` with four estimators.

        Sets ``two_sls_theta``, ``pseudo_two_sls_theta``, ``oracle_theta`` and
        ``ols_theta``. Pseudo-inverses are used, so rank-deficient Gram
        matrices do not raise.
        """
        Z, X, Y = self.Z_sample, self.X_sample, self.Y_sample
        ZtZ = Z.T@Z
        Zty = Z.T@Y

        # 2sls: (Z^T X)^+ Z^T y
        self.two_sls_theta = np.linalg.pinv(Z.T@X)@Zty
        # pseudo-2sls: (Z^T Z Gamma_hat)^+ Z^T y
        self.pseudo_two_sls_theta = np.linalg.pinv(ZtZ@self.Gamma_hat)@Zty
        # oracle: (Z^T Z Gamma)^+ Z^T y
        self.oracle_theta = np.linalg.pinv(ZtZ@self.instance.Gamma)@Zty
        # ols: (X^T X)^+ X^T y
        self.ols_theta = np.linalg.pinv(X.T@X)@(X.T@Y)

    def bulk_update_estimates(self):
        """Compute every estimator after each prefix of the stored sample path.

        Row ``t`` of each ``*_theta`` array uses samples ``0..t``. The method
        also sets, per step:

        * predicted best arm (``*_prediction``);
        * correctness against ``instance.opt`` (``acc_*``);
        * simple regret (``simple_regret_*``);
        * the L2 error to ``instance.theta`` (``bias_*``).
        """
        Z, X, Y = self.Z_sample, self.X_sample, self.Y_sample

        # Running Gram matrices / cross terms (cumulative over samples).
        ZtX = np.cumsum(np.einsum('ni,nj->nij', Z, X), axis=0)
        ZtZ = np.cumsum(np.einsum('ni,nj->nij', Z, Z), axis=0)
        XtX = np.cumsum(np.einsum('ni,nj->nij', X, X), axis=0)
        Zty = np.cumsum(Z*Y, axis=0)
        Xty = np.cumsum(X*Y, axis=0)

        # 2sls
        self.two_sls_theta = np.einsum('ijk,ik->ij', np.linalg.pinv(ZtX), Zty)
        # pseudo-2sls / oracle: Gamma multiplies on the right, matching X = Z @ Gamma.
        self.pseudo_two_sls_theta = np.einsum('ijk,ik->ij', np.linalg.pinv(ZtZ@self.Gamma_hat), Zty)
        self.oracle_theta = np.einsum('ijk,ik->ij', np.linalg.pinv(ZtZ@self.instance.Gamma), Zty)
        # ols
        self.ols_theta = np.einsum('ijk,ik->ij', np.linalg.pinv(XtX), Xty)

        W = self.instance.W
        theta = self.instance.theta
        opt = self.instance.opt

        # get optimal arm estimate
        self.two_sls_prediction = np.argmax(self.two_sls_theta@W.T, axis=1)
        self.pseudo_two_sls_prediction = np.argmax(self.pseudo_two_sls_theta@W.T, axis=1)
        self.oracle_prediction = np.argmax(self.oracle_theta@W.T, axis=1)
        self.ols_prediction = np.argmax(self.ols_theta@W.T, axis=1)

        # get accuracy of prediction
        self.acc_two_sls = self.two_sls_prediction == opt
        self.acc_pseudo_two_sls = self.pseudo_two_sls_prediction == opt
        self.acc_oracle = self.oracle_prediction == opt
        self.acc_ols = self.ols_prediction == opt

        # get simple regret of prediction
        opt_value = W[opt]@theta
        self.simple_regret_two_sls = (opt_value - W[self.two_sls_prediction]@theta).flatten()
        self.simple_regret_pseudo_two_sls = (opt_value - W[self.pseudo_two_sls_prediction]@theta).flatten()
        self.simple_regret_oracle = (opt_value - W[self.oracle_prediction]@theta).flatten()
        self.simple_regret_ols = (opt_value - W[self.ols_prediction]@theta).flatten()

        self.bias_two_sls = np.linalg.norm(self.two_sls_theta.T-theta, 2, axis=0)
        self.bias_pseudo_two_sls = np.linalg.norm(self.pseudo_two_sls_theta.T-theta, 2, axis=0)
        self.bias_oracle = np.linalg.norm(self.oracle_theta.T-theta, 2, axis=0)
        self.bias_ols = np.linalg.norm(self.ols_theta.T-theta, 2, axis=0)

    def eliminate(self, A):
        """Drop arms beaten with confidence under the pseudo-2SLS estimate.

        The width has two terms: one from the phase samples, and one bias term
        from the Gamma_hat burn-in samples. Row indices come from
        :meth:`get_diff_var` on ``self.A``, so ``A`` is expected to be a copy
        of ``self.A``.
        """
        _, _, var1 = self.get_diff_var(self.Gamma_hat, self.Z_sample)
        idx, diff_matrix, var2 = self.get_diff_var(self.Gamma_hat, self.Z_sample_GH)
        log_term = self.get_log_bar()
        t_norm_2 = np.linalg.norm(self.instance.theta, 2)**2
        conf_1 = np.sqrt(2*var1*self.instance.subgaussian * np.log(self.instance.K*2*self.ell**2/self.delta))
        conf_2 = np.sqrt(t_norm_2*var2*log_term)
        conf = conf_1+conf_2
        elim = (diff_matrix@self.pseudo_two_sls_theta).flatten() - conf > 0
        # Pair (i, j) encodes A[i] - A[j]; a confidently positive gap eliminates arm j.
        elim_idx = np.unique(idx[1][elim])
        return np.delete(A, elim_idx, axis=0)

    def oracle_eliminate(self, A):
        """Drop arms beaten with confidence under the oracle (true-Gamma) estimate.

        Row indices come from :meth:`get_diff_var` on ``self.A``, so ``A`` is
        expected to be a copy of ``self.A``.
        """
        idx, diff_matrix, var = self.get_diff_var(self.instance.Gamma, self.Z_sample)
        conf = np.sqrt(2*var*self.instance.subgaussian * np.log(self.instance.K*2*self.ell**2/self.delta))
        elim = (diff_matrix@self.oracle_theta).flatten() - conf > 0
        # Pair (i, j) encodes A[i] - A[j]; a confidently positive gap eliminates arm j.
        elim_idx = np.unique(idx[1][elim])
        return np.delete(A, elim_idx, axis=0)

    def get_diff_var(self, Gamma, Z_sample):
        """Return ordered-pair indices, arm differences and their variance proxies.

        The pairs are every ordered ``(i, j)`` with ``i != j`` over the rows of
        ``self.A``. Returns ``(idx, diff_matrix, var)``, where:

        * ``idx = [i_array, j_array]``;
        * ``diff_matrix[p] = A[i_p] - A[j_p]``;
        * ``var[p] = diff^T (Gamma^T Z^T Z Gamma)^+ diff``.
        """
        upper = np.triu_indices(self.A.shape[0], 1)
        idx = [np.concatenate((upper[0], upper[1])), np.concatenate((upper[1], upper[0]))]
        diff_matrix = self.A[idx[0]] - self.A[idx[1]]
        design = np.linalg.pinv(Gamma.T@Z_sample.T@Z_sample@Gamma)
        var = np.einsum('ij,jk,ik->i', diff_matrix, design, diff_matrix)

        return idx, diff_matrix, var

    def get_log_bar(self):
        """Return the log factor used in the Gamma-estimation confidence widths.

        The simplified value ``4*d + log(1/delta)`` is used. It does not depend
        on the burn-in samples.
        """
        return 4*self.instance.d + np.log(1/self.delta)

    def get_bias_term(self):
        """Return per-pair widths ``sqrt(||theta||^2 * var * log_bar)`` from the Gamma_hat burn-in samples."""
        _, _, var = self.get_diff_var(self.Gamma_hat, self.Z_sample_GH)
        log_term = self.get_log_bar()
        t_norm_2 = np.linalg.norm(self.instance.theta, 2)**2
        return np.sqrt(t_norm_2*var*log_term)