import numpy as np

class RECON:
    """Epoch-based linear contextual bandit with inverse-gap-weighted exploration.

    Phase 1 (rounds ``t < T_bic``): arms are pulled round-robin and each arm's
    linear model is refit (ridge, or Gaussian-prior posterior mean when
    ``normal_para`` is given) on every sample of that arm seen so far.

    Phase 2 (rounds ``t >= T_bic``): time is split into doubling epochs; epoch
    ``m`` ends at round ``T_bic * 2**m - 1``. Predictions inside an epoch use
    the parameters fitted at the end of the previous epoch. Every non-greedy arm
    is sampled with probability ``1 / (K + gamma * gap)`` and the greedy arm
    gets the remaining mass. At the end of each epoch the models are refit on
    that epoch's data only, and the epoch buffers are cleared.

    NOTE(review): the caller is expected to pass the epoch index ``m`` to
    ``Select`` consistently with the schedule above (``m == 1`` at
    ``t == T_bic``) — confirm against the driver loop.
    """

    # Maps a decay name to the factor that inflates the prior covariance as a
    # function of the number of steady-phase rounds so far.
    _DECAY_FUNCS = {
        'linear': lambda n: n,
        'sqrt': np.sqrt,
        'log': np.log,
    }

    def __init__(self, env, lam, eta, c, seed, decay, normal_para=None):
        """Copy problem data from ``env`` and allocate all bookkeeping arrays.

        Args:
            env: environment object providing ``K``, ``T``, ``delta``, ``d``,
                ``reg_sq``, ``sigma``, ``n_bic`` (length of the initial
                round-robin phase), ``min_eig``, ``theta``, ``xt``, ``yt``
                (rewards indexed ``[t, arm]``), ``yt_mean`` (mean rewards
                indexed ``[t, arm]``) and ``true_arm`` (indexed by ``t``).
            lam: ridge strength. Used as a scalar when ``normal_para`` is
                None; indexed per arm (``lam[i]``) in ``gamma_determination``
                when ``normal_para`` is given.
            eta: accepted for interface compatibility but not stored; step
                sizes are recomputed inside ``update_sgd``.
            c: scale factor for the exploration parameter gamma.
            seed: stored as ``self.seed``; not used inside this class.
            decay: one of ``'linear'``, ``'sqrt'``, ``'log'`` — how the prior
                covariance is inflated in steady-phase refits (only used when
                ``normal_para`` is given).
            normal_para: optional per-arm Gaussian prior; ``normal_para[i]``
                must provide ``'mean'`` (length ``d``) and ``'cov'``
                (``d x d``).
        """
        self.K = env.K
        self.T = env.T
        self.delta = env.delta
        self.d = env.d
        self.reg_sq = env.reg_sq
        self.c = c
        self.p = np.zeros((self.T, self.K))
        self.opt_arms = np.zeros(self.T)
        self.selected_arms = [0] * self.T
        self.lam = lam
        # the `eta` argument is ignored; update_sgd fills this array in place
        self.eta = np.zeros(self.K)
        self.reward = np.zeros(self.T)
        self.seed = seed
        self.sigma = env.sigma
        self.T_bic = env.n_bic
        self.normal_para = normal_para
        self.min_eig = env.min_eig
        # one gamma per round, or one per (round, arm) when per-arm priors exist
        if self.normal_para is None:
            self.gamma = np.zeros(self.T)
        else:
            self.gamma = np.zeros((self.T, self.K))

        self.decay_method = decay

        # true parameters
        self.theta = env.theta
        self.theta_hat = np.zeros((self.K, self.d))
        self.theta_hat_traj = np.zeros((self.K, self.T, self.d))
        self.count = np.zeros(self.K)  # number of pulls for each arm (phase 1)
        self.epoch_count = {}  # epoch index -> per-arm pull counts in that epoch
        # data matrix
        self.xt = env.xt
        # reward matrix
        self.yt = env.yt
        self.y_hat = np.zeros((self.T, self.K))
        self.yt_mean = env.yt_mean
        self.true_arm = env.true_arm

        # initial estimator
        self._init_all_sample()
        self._init_epoch_determination()

    def _init_all_sample(self):
        """Allocate per-arm buffers for every phase-1 sample."""
        self.all_x = np.zeros((self.K, self.T, self.d))
        self.all_y = np.zeros((self.K, self.T))

    def _init_epoch_determination(self):
        """Create ``ceil(log2(T + 1))`` epochs of zero-initialised per-arm parameters."""
        self.total_epoch = int(np.ceil(np.log2(self.T + 1)))
        self.epoch_param = {
            epoch: {i: np.zeros(self.d) for i in range(self.K)}
            for epoch in range(self.total_epoch)
        }

    def _refresh_all(self, m):
        """Start epoch ``m + 1``: reset its pull counts, data buffers and ``theta_hat``."""
        capacity = self.T_bic * np.power(2, m + 2)
        self.epoch_count[m + 1] = np.zeros(self.K)
        self.epoch_x = np.zeros((self.K, capacity, self.d))
        self.epoch_y = np.zeros((self.K, capacity))
        self.theta_hat = np.zeros((self.K, self.d))

    def gamma_determination(self, m, t):
        """Compute (and record in ``self.gamma[t]``) the exploration parameter for epoch ``m``.

        Returns the scalar gamma; in the per-arm-prior case one gamma is
        computed per arm and the maximum is returned.
        """
        c_1 = 1
        c_2 = 4
        c_3 = (c_1 * c_2 + c_2 * np.power(self.sigma, 2) * self.d) / (np.power(self.sigma, 2) * self.d)
        # number of rounds in the previous epoch (the one the model was fit on)
        epoch_sample = self.T_bic * np.power(2, m - 1)
        if self.normal_para is None:
            error_b = c_3 * np.power(self.sigma, 2) * self.d / (self.min_eig * epoch_sample)
            self.gamma[t] = self.c * np.sqrt(self.K / error_b)
            return self.gamma[t]

        for i in range(self.K):
            error_b = c_3 * np.power(self.sigma, 2) * self.d / (self.lam[i] * epoch_sample)
            self.gamma[t, i] = self.c * np.sqrt(self.K / error_b)
        return np.max(self.gamma[t])

    def Select(self, xt, t, m):
        """Choose an arm for round ``t`` with context ``xt``, observe its reward, update models.

        Args:
            xt: context vector (dotted with length-``d`` arm parameters).
            t: round index, ``0 <= t < T``.
            m: epoch index (ignored while ``t < T_bic``).
        """
        if t < self.T_bic:
            # phase 1: deterministic round-robin exploration
            arm = t % self.K
            self.selected_arms[t] = arm
            self.reward[t] = self.yt[t, arm]
            self.update_lr_start(xt, self.reward[t], arm)
            self.count[arm] += 1
            return

        # step 0: on the first steady round, freeze the phase-1 fit as epoch m-1
        if t == self.T_bic:
            for i in range(self.K):
                self.epoch_param[m - 1][i] = self.theta_hat[i]
            self._refresh_all(m - 1)

        # step 1: predict every arm's reward with the previous epoch's parameters
        self.y_hat[t] = [self.predict(xt, i, m) for i in range(self.K)]

        # step 2: greedy arm
        opt_arm = np.argmax(self.y_hat[t])
        self.opt_arms[t] = opt_arm

        # step 3: inverse-gap weights for non-greedy arms; greedy arm takes the rest
        gamma_m = self.gamma_determination(m, t)
        gaps = self.y_hat[t, opt_arm] - self.y_hat[t]
        probs = 1.0 / (self.K + gamma_m * gaps)
        probs[opt_arm] = 0.0
        probs[opt_arm] = 1.0 - np.sum(probs)
        self.p[t] = probs

        # step 4: sample an arm and observe its reward
        chosen = int(np.random.choice(self.K, p=self.p[t, :]))
        self.selected_arms[t] = chosen
        # the count is incremented BEFORE update_lr_steady, which relies on it
        self.epoch_count[m][chosen] += 1
        self.reward[t] = self.yt[t, chosen]

        # step 5: buffer the sample; models are refit only at the end of the epoch
        self.update_lr_steady(xt, self.reward[t], chosen, t, m)

    def predict(self, xt, arm_id, m):
        """Predicted reward of ``arm_id`` using the parameters fitted for epoch ``m - 1``."""
        return self.epoch_param[m - 1][arm_id].dot(xt)

    def update_lr_start(self, xt, yt, id):
        """Phase-1 update: store ``(xt, yt)`` for arm ``id`` and refit on all its samples.

        ``self.count[id]`` is the number of samples stored *before* this call
        (``Select`` increments it afterwards), so the new sample goes into slot
        ``count[id]`` and the fit uses slots ``0 .. count[id]`` inclusive.
        """
        slot = int(self.count[id])
        self.all_x[id, slot, :] = xt
        self.all_y[id, slot] = yt

        # include the sample just stored
        x_selected = self.all_x[id, :slot + 1, :]
        y_selected = self.all_y[id, :slot + 1]

        if self.normal_para is None:
            # ridge regression
            gram = x_selected.T.dot(x_selected) + self.lam * np.identity(self.d)
            self.theta_hat[id] = np.linalg.inv(gram).dot(x_selected.T).dot(y_selected)
        else:
            # posterior mean under the Gaussian prior N(mean, cov)
            prior_prec = np.linalg.inv(self.normal_para[id]['cov'])
            self.theta_hat[id] = np.linalg.inv(x_selected.T.dot(x_selected) + prior_prec).dot(
                x_selected.T.dot(y_selected) + prior_prec.dot(self.normal_para[id]['mean'])
            )

        self.theta_hat_traj[id, slot, :] = self.theta_hat[id]

    def _decayed_prior_precision(self, arm, t):
        """Inverse of arm's prior covariance inflated by the configured decay of ``t - T_bic + 1``.

        Raises:
            ValueError: if ``self.decay_method`` is not a known decay name.
        """
        scale_fn = self._DECAY_FUNCS.get(self.decay_method)
        if scale_fn is None:
            raise ValueError(
                f"unknown decay method {self.decay_method!r}; "
                f"expected one of {sorted(self._DECAY_FUNCS)}"
            )
        n_steady = t - self.T_bic + 1
        return np.linalg.inv(self.normal_para[arm]['cov'] * scale_fn(n_steady))

    def update_lr_steady(self, xt, yt, id, t, m):
        """Phase-2 update: buffer ``(xt, yt)`` for arm ``id``; refit all arms at the end of epoch ``m``.

        ``self.epoch_count[m][id]`` has already been incremented by ``Select``,
        so the new sample goes into slot ``count - 1``. When ``t`` is the last
        round of epoch ``m`` every arm is refit on this epoch's data, the
        result is stored in ``epoch_param[m]``, and the buffers are reset.
        """
        slot = int(self.epoch_count[m][id]) - 1
        self.epoch_x[id, slot, :] = xt
        self.epoch_y[id, slot] = yt

        if (t + 1) != self.T_bic * np.power(2, m):
            return

        for arm in range(self.K):
            n_arm = int(self.epoch_count[m][arm])
            x_epoch = self.epoch_x[arm, :n_arm, :]
            y_epoch = self.epoch_y[arm, :n_arm]
            if self.normal_para is None:
                # NOTE(review): divides by zero when T_bic == 1 at the end of
                # epoch 1; the prior branch uses t - T_bic + 1 — confirm the offset.
                gram = x_epoch.T.dot(x_epoch) + (self.lam / (t - self.T_bic)) * np.identity(self.d)
                self.theta_hat[arm] = np.linalg.inv(gram).dot(x_epoch.T).dot(y_epoch)
            else:
                prior_prec = self._decayed_prior_precision(arm, t)
                x_inv = np.linalg.inv(x_epoch.T.dot(x_epoch) + prior_prec)
                y_bias = x_epoch.T.dot(y_epoch) + prior_prec.dot(self.normal_para[arm]['mean'])
                self.theta_hat[arm] = x_inv.dot(y_bias)
            self.epoch_param[m][arm] = self.theta_hat[arm]

        # forget this epoch's data
        self._refresh_all(m)

    def update_sgd(self, xt, yt, arm_id):
        """One SGD step on arm ``arm_id``'s squared loss (alternative to the regression updates).

        Expects ``self.count[arm_id]`` to already include this sample. The
        L2 term ``2 * lam * theta`` is added only when the arm has exactly one pull.
        """
        # step 0: step size 1/sqrt(n) for every arm pulled at least once
        for i in range(self.K):
            if self.count[i] != 0:
                self.eta[i] = 1.0 / (np.sqrt(self.count[i]))

        # step 1: gradient step
        eta = self.eta[arm_id]
        theta = self.theta_hat[arm_id]
        residual = theta.dot(xt) - yt
        if self.count[arm_id] == 1:
            self.theta_hat[arm_id] = theta - eta * (residual * xt + 2 * self.lam * theta)
        else:
            self.theta_hat[arm_id] = theta - eta * residual * xt

        # step 2: store the trajectory of theta_hat
        self.theta_hat_traj[arm_id, int(self.count[arm_id]) - 1, :] = self.theta_hat[arm_id]

    def regret(self):
        """Return ``(cumulative regret over all T rounds, running error rate after T_bic)``.

        Regret per round is ``max(yt_mean[t]) - yt_mean[t, selected]``. The
        error rate is the running fraction of steady-phase rounds whose
        selected arm differs from ``true_arm[t]``.
        """
        yt_mean = np.asarray(self.yt_mean)
        chosen = np.asarray(self.selected_arms, dtype=int)
        rounds = np.arange(self.T)
        regret = (np.max(yt_mean[:self.T], axis=1) - yt_mean[rounds, chosen]).astype(float)
        accum_regret = np.cumsum(regret)

        true_steady = np.asarray(self.true_arm)[self.T_bic:self.T]
        wrong = (true_steady != chosen[self.T_bic:]).astype(float)
        accum_incorrect = np.cumsum(wrong)
        accum_incorrect_rate = accum_incorrect / np.arange(1, self.T - self.T_bic + 1)

        return accum_regret, accum_incorrect_rate

    def BIC_gain(self):
        """Running average (over steady rounds) of a gain computed from predicted rewards ``y_hat``.

        If the selected arm equals ``true_arm[t]`` the gain is the best minus
        the second-best prediction; otherwise it is the selected arm's
        prediction minus the best prediction (non-positive).
        """
        self.gain = np.zeros(self.T)
        for t in range(self.T_bic, self.T):
            row = self.y_hat[t, :]
            best = np.max(row)
            if self.true_arm[t] == self.selected_arms[t]:
                self.gain[t] = best - np.sort(row)[-2]
            else:
                self.gain[t] = row[self.selected_arms[t]] - best
        # running average over rounds T_bic .. T-1
        self.accum_gain = np.cumsum(self.gain[self.T_bic:self.T]) / np.arange(1, self.T + 1 - self.T_bic)
        return self.accum_gain

    def Correct_Table(self):
        """Build a ``K x K`` confusion table over steady-phase rounds.

        Row = true arm, column = selected arm; each row is normalised by how
        often that arm was the true arm. Rows for arms that were never the
        true arm are NaN.

        Returns:
            ``(correct_table, true_arm_ratio)`` where ``true_arm_ratio[i]`` is
            the fraction of steady rounds whose true arm is ``i``.
        """
        self.correct_table = np.zeros((self.K, self.K))
        self.true_arm_num = np.zeros(self.K)

        for t in range(self.T_bic, self.T):
            true_i = int(self.true_arm[t])
            sel_i = int(self.selected_arms[t])
            self.true_arm_num[true_i] += 1
            self.correct_table[true_i, sel_i] += 1

        # normalise each row; 0/0 for never-true arms yields NaN deliberately
        with np.errstate(invalid='ignore', divide='ignore'):
            self.correct_table = self.correct_table / self.true_arm_num[:, None]

        self.true_arm_ratio = self.true_arm_num / (self.T - self.T_bic)
        return self.correct_table, self.true_arm_ratio

                
