import numpy as np
import random
import math
from sklearn.linear_model import LogisticRegression
from algorithms.AutoTuning import *

class UCB_GLM:
    """UCB-GLM bandit simulations with several hyper-parameter tuning schemes.

    Every public ``glmucb_*`` method plays ``self.T`` rounds on ``self.data`` and
    returns the cumulative regret as a length-``T`` numpy array.  Rounds 0 and 1
    pull a uniformly random arm; from round 2 on the arm maximizing
    ``x.dot(theta) + explore * sqrt(x' B_inv x)`` is pulled, where ``theta`` is an
    l2-penalized logistic-regression fit (no intercept) on the history so far.

    ``self.data`` is read through the attributes ``d``, ``fv``, ``optimal``,
    ``reward`` and the method ``random_sample(t, arm)``; the theoretical
    exploration rate additionally reads ``sigma`` and ``max_norm``.
    """

    def __init__(self, class_context, T):
        """Store the bandit environment ``class_context`` and the horizon ``T``."""
        self.data = class_context
        self.T = T
        self.d = self.data.d

    # ------------------------------------------------------------------
    # private helpers shared by all algorithms
    # ------------------------------------------------------------------
    def _record_regret(self, regret, t, pull):
        """Write the cumulative regret of round ``t`` into ``regret[t]``."""
        # The running sum starts from zero in round 0 (the original code relied on
        # regret[-1] still being 0 at that point).
        previous = regret[t - 1] if t > 0 else 0.0
        regret[t] = previous + self.data.optimal[t] - self.data.reward[t][pull]

    def _warm_start(self, regret, gram):
        """Play rounds 0 and 1 with uniformly random arms.

        ``regret[0]`` and ``regret[1]`` are filled, and ``gram`` is incremented
        IN PLACE by the outer product of each pulled feature vector.

        Returns:
            (y, X): observed rewards and pulled feature vectors of the two rounds.
        """
        y = np.array([]).astype('int')
        X = np.empty([0, self.data.d])
        for t in range(2):
            feature = self.data.fv[t]
            pull = np.random.choice(len(feature))
            observe_r = self.data.random_sample(t, pull)
            y = np.concatenate((y, [observe_r]), axis=0)
            X = np.concatenate((X, [feature[pull]]), axis=0)
            self._record_regret(regret, t, pull)
            gram += np.outer(feature[pull], feature[pull])
        # Force y[0] != y[1]: sklearn's logistic regression cannot be fitted on a
        # single class, so the second label is flipped when both coincide.
        if y[0] == y[1]:
            y[1] = 1 - y[0]
        return y, X

    @staticmethod
    def _fit_theta(X, y, lamda):
        """Fit the logistic model on ``(X, y)`` and return its coefficient vector."""
        # NOTE(review): sklearn's ``C`` is the INVERSE regularization strength, while
        # ``lamda`` is used as the ridge weight of the confidence matrix elsewhere.
        # Kept as in the original (C = lamda); confirm whether 1/lamda was intended.
        clf = LogisticRegression(penalty='l2', C=lamda, fit_intercept=False, solver='lbfgs').fit(X, y)
        return clf.coef_[0]

    @staticmethod
    def _ucb_arm(feature, theta, B_inv, explore):
        """Return the index of the arm with the largest upper confidence bound."""
        ucb_idx = [
            feature[arm].dot(theta)
            + explore * math.sqrt(feature[arm].T.dot(B_inv).dot(feature[arm]))
            for arm in range(len(feature))
        ]
        return np.argmax(ucb_idx)

    @staticmethod
    def _rank_one_inverse_update(B_inv, x):
        """Update ``B_inv`` in place to the inverse of ``B + x x'`` (Sherman-Morrison)."""
        tmp = B_inv.dot(x)
        B_inv -= np.outer(tmp, tmp) / (1 + x.dot(tmp))

    @staticmethod
    def _append_observation(y, X, reward, x):
        """Return new ``(y, X)`` with ``reward`` / ``x`` appended as the last sample."""
        return (
            np.concatenate((y, [reward]), axis=0),
            np.concatenate((X, [x]), axis=0),
        )

    @staticmethod
    def _exp3_gamma(n_candidates, T):
        """Mixing rate handed to ``auto_tuning`` for ``n_candidates`` choices over ``T`` rounds."""
        return min(1, math.sqrt(n_candidates * math.log(n_candidates) / ((np.exp(1) - 1) * T)))

    # ------------------------------------------------------------------
    # algorithms
    # ------------------------------------------------------------------
    def glmucb_theoretical_explore(self, lamda=1, delta=0.1, explore = -1):
        """UCB-GLM with the theoretical or a fixed exploration rate.

        Args:
            lamda: ridge weight of the confidence matrix, also passed as ``C`` to
                the logistic regression.
            delta: confidence level used in the theoretical exploration rate.
            explore: ``-1`` (the default, not a valid rate) selects the theoretical
                rate, recomputed every round; any other value is used as a fixed
                rate, which is meant for grid search.

        Returns:
            Cumulative regret, one entry per round.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        B = np.identity(d) * lamda
        y, X = self._warm_start(regret, B)
        B_inv = np.linalg.inv(B)

        for t in range(2, T):
            if explore == -1:
                rate = (
                    self.data.sigma
                    * math.sqrt(d * math.log((t * self.data.max_norm**2 / lamda + 1) / delta))
                    + math.sqrt(lamda)
                )
            else:
                rate = explore
            feature = self.data.fv[t]
            theta = self._fit_theta(X, y, lamda)
            pull = self._ucb_arm(feature, theta, B_inv, rate)
            observe_r = self.data.random_sample(t, pull)
            self._rank_one_inverse_update(B_inv, feature[pull])
            y, X = self._append_observation(y, X, observe_r, feature[pull])
            self._record_regret(regret, t, pull)
        return regret

    def glmucb_tl(self, explore_rates, lamda=1):
        """UCB-GLM whose exploration rate is re-selected every round by ``auto_tuning``.

        Args:
            explore_rates: candidate exploration rates.
            lamda: ridge weight of the confidence matrix, also passed as ``C`` to
                the logistic regression.

        Returns:
            Cumulative regret, one entry per round.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        B = np.identity(d) * lamda
        y, X = self._warm_start(regret, B)
        B_inv = np.linalg.inv(B)

        # state of the tuner (``auto_tuning`` is defined outside this class)
        Kexp = len(explore_rates)
        logw = np.zeros(Kexp)
        p = np.ones(Kexp) / Kexp
        gamma = self._exp3_gamma(Kexp, T)
        # random initial exploration rate
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        for t in range(2, T):
            feature = self.data.fv[t]
            theta = self._fit_theta(X, y, lamda)
            pull = self._ucb_arm(feature, theta, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)
            self._rank_one_inverse_update(B_inv, feature[pull])
            y, X = self._append_observation(y, X, observe_r, feature[pull])
            self._record_regret(regret, t, pull)

            # let the tuner pick the exploration rate of the next round
            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore = explore_rates[index]
        return regret

    def glmucb_op(self, explore_rates, lamda=1):
        """UCB-GLM whose exploration rate is re-selected every round by ``op_tuning``.

        Args:
            explore_rates: candidate exploration rates.
            lamda: ridge weight of the confidence matrix, also passed as ``C`` to
                the logistic regression.

        Returns:
            Cumulative regret, one entry per round.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        B = np.identity(d) * lamda
        y, X = self._warm_start(regret, B)
        B_inv = np.linalg.inv(B)

        # state of the tuner (``op_tuning`` is defined outside this class)
        Kexp = len(explore_rates)
        s = np.ones(Kexp)
        f = np.ones(Kexp)
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        for t in range(2, T):
            feature = self.data.fv[t]
            theta = self._fit_theta(X, y, lamda)
            pull = self._ucb_arm(feature, theta, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)
            self._rank_one_inverse_update(B_inv, feature[pull])
            y, X = self._append_observation(y, X, observe_r, feature[pull])
            self._record_regret(regret, t, pull)

            # let the tuner pick the exploration rate of the next round
            s, f, index = op_tuning(s, f, observe_r, index)
            explore = explore_rates[index]
        return regret

    def glmucb_syndicated(self, explore_rates, lamdas):
        """UCB-GLM with two separate ``auto_tuning`` tuners: one for the
        exploration rate and one for ``lamda``.

        Args:
            explore_rates: candidate exploration rates.
            lamdas: candidate values for ``lamda``.

        Returns:
            Cumulative regret, one entry per round.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xxt = np.zeros((d, d))
        y, X = self._warm_start(regret, xxt)

        # tuner state for the exploration rate
        Kexp = len(explore_rates)
        logw = np.zeros(Kexp)
        p = np.ones(Kexp) / Kexp
        gamma = self._exp3_gamma(Kexp, T)
        index = np.random.choice(Kexp)
        explore = explore_rates[index]

        # tuner state for lamda
        Klam = len(lamdas)
        loglamw = np.zeros(Klam)
        plam = np.ones(Klam) / Klam
        gamma_lam = self._exp3_gamma(Klam, T)
        index_lam = np.random.choice(Klam)
        lamda = lamdas[index_lam]

        identity = np.identity(d)
        B_inv = np.linalg.inv(xxt + lamda * identity)
        for t in range(2, T):
            feature = self.data.fv[t]
            theta = self._fit_theta(X, y, lamda)
            pull = self._ucb_arm(feature, theta, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)

            # both tuners are updated with the same observed reward
            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore = explore_rates[index]
            loglamw, plam, index_lam = auto_tuning(loglamw, plam, observe_r, index_lam, gamma_lam)
            lamda = lamdas[index_lam]

            # lamda may change between rounds, so the inverse is recomputed from
            # the un-regularized Gram matrix instead of being updated rank-one.
            xxt += np.outer(feature[pull], feature[pull])
            B_inv = np.linalg.inv(xxt + lamda * identity)

            y, X = self._append_observation(y, X, observe_r, feature[pull])
            self._record_regret(regret, t, pull)
        return regret

    def glmucb_tl_combined(self, explore_rates, lamdas):
        """UCB-GLM with ONE ``auto_tuning`` tuner over all (explore, lamda) pairs.

        Args:
            explore_rates: candidate exploration rates.
            lamdas: candidate values for ``lamda``.

        Returns:
            Cumulative regret, one entry per round.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        xxt = np.zeros((d, d))
        y, X = self._warm_start(regret, xxt)

        # every row of explore_lamda is one (explore, lamda) combination
        explore_lamda = np.array(np.meshgrid(explore_rates, lamdas)).T.reshape(-1,2)
        Kexp = len(explore_lamda)
        logw = np.zeros(Kexp)
        p = np.ones(Kexp) / Kexp
        gamma = self._exp3_gamma(Kexp, T)
        index = np.random.choice(Kexp)
        explore, lamda = explore_lamda[index]

        identity = np.identity(d)
        B_inv = np.linalg.inv(xxt + lamda * identity)
        for t in range(2, T):
            feature = self.data.fv[t]
            theta = self._fit_theta(X, y, lamda)
            pull = self._ucb_arm(feature, theta, B_inv, explore)
            observe_r = self.data.random_sample(t, pull)

            logw, p, index = auto_tuning(logw, p, observe_r, index, gamma)
            explore, lamda = explore_lamda[index]

            xxt += np.outer(feature[pull], feature[pull])
            B_inv = np.linalg.inv(xxt + lamda * identity)
            y, X = self._append_observation(y, X, observe_r, feature[pull])
            self._record_regret(regret, t, pull)
        return regret

    def _corral(self, explores, lamdas):
        """Shared CORRAL driver: base learner ``i`` is a UCB-GLM instance using
        ``explores[i]`` and ``lamdas[i]``; ``log_barrier`` (defined outside this
        class) updates the distribution over base learners.

        Returns:
            Cumulative regret, one entry per round.
        """
        T = self.T
        d = self.data.d
        regret = np.zeros(T)
        warm_gram = np.zeros((d, d))
        warm_y, warm_X = self._warm_start(regret, warm_gram)

        K = len(self.data.fv[0])
        eta0 = 1/math.sqrt(K*T*math.log(K))
        M = len(explores)
        p = np.ones(M) / M
        pbar = np.ones(M) / M
        gamma = 1/T
        beta = np.exp(1/math.log(T))
        rho = [2*M] * M
        eta = [eta0] * M
        identity = np.identity(d)
        B_inv = [np.linalg.inv(warm_gram + lamdas[base] * identity) for base in range(M)]
        # Each base learner gets its OWN copy of the Gram matrix: it is updated in
        # place with ``+=`` below, so sharing one array (as the original did) made
        # every learner accumulate the pulls of all learners.
        xxt = [warm_gram.copy() for _ in range(M)]
        # y and X are only ever rebound via np.concatenate, so sharing the initial
        # arrays between learners is safe.
        y = [warm_y for _ in range(M)]
        X = [warm_X for _ in range(M)]

        for t in range(2, T):
            feature = self.data.fv[t]

            # every base learner proposes an arm
            pull = []
            for base in range(M):
                theta = self._fit_theta(X[base], y[base], lamdas[base])
                pull += [self._ucb_arm(feature, theta, B_inv[base], explores[base])]

            # only the sampled base learner's proposal is actually played
            chosen_base = np.random.choice(M, p=pbar)
            observe_r = self.data.random_sample(t, pull[chosen_base])

            # every learner is updated with its own proposal; learners that were
            # not played receive reward 0
            for base in range(M):
                rew = observe_r if base == chosen_base else 0
                proposed = feature[pull[base]]
                xxt[base] += np.outer(proposed, proposed)
                B_inv[base] = np.linalg.inv(xxt[base] + lamdas[base] * identity)
                y[base], X[base] = self._append_observation(y[base], X[base], rew, proposed)

            passl = np.zeros(M)
            # 0.5 - reward keeps the loss non-zero for both reward values 0 and 1
            passl[chosen_base] = 0.5-observe_r
            p = log_barrier(p, passl, eta)
            pbar = (1-gamma) * p + gamma/M
            for base in range(M):
                if 1/pbar[base] >= rho[base]:
                    rho[base] = 2/pbar[base]
                    eta[base] *= beta

            self._record_regret(regret, t, pull[chosen_base])
        return regret

    def glmucb_corral(self, explore_rates, lamda=1):
        """CORRAL over one UCB-GLM base learner per exploration rate.

        Args:
            explore_rates: exploration rate of each base learner.
            lamda: ridge weight shared by all base learners.

        Returns:
            Cumulative regret, one entry per round.
        """
        return self._corral(explore_rates, [lamda] * len(explore_rates))

    def glmucb_corral_combined(self, explore_rates, lamdas):
        """CORRAL over one UCB-GLM base learner per (explore, lamda) combination.

        Args:
            explore_rates: candidate exploration rates.
            lamdas: candidate values for ``lamda``.

        Returns:
            Cumulative regret, one entry per round.
        """
        # every row of explore_lamda is one (explore, lamda) combination
        explore_lamda = np.array(np.meshgrid(explore_rates, lamdas)).T.reshape(-1,2)
        return self._corral(explore_lamda[:, 0], explore_lamda[:, 1])