import numpy as np
from sklearn import linear_model
import time

def LassoWrapper(alpha, **kwargs):
    """Build an unfitted scikit-learn ``Lasso`` with this module's fast defaults.

    Every estimator uses ``selection='random'`` and a loose ``tol=5e-4``;
    both are chosen for faster run-time inside the bandit loops.

    Args:
        alpha: L1 penalty strength forwarded to ``Lasso``.
        **kwargs: additional ``Lasso`` keyword arguments (e.g. ``fit_intercept``).
            Supplying ``selection`` or ``tol`` here raises ``TypeError``
            (duplicate keyword).

    Returns:
        A new ``sklearn.linear_model.Lasso`` instance.
    """
    speed_defaults = {'selection': 'random', 'tol': 5e-4}
    return linear_model.Lasso(alpha=alpha, **speed_defaults, **kwargs)

# Lasso Bandit by Bastani and Bayati (2015). Online decision-making with high-dimensional covariates.
class LassoBandit:
    """Multi-parameter Lasso Bandit (Bastani & Bayati, 2015).

    Every arm ``i`` keeps two Lasso estimates over the ``N*d``-long context:
    ``beta_t[i]`` fitted on its forced-sampling rounds and ``beta_a[i]``
    fitted on every round in which it was played. Forced sampling happens in
    batches of ``q*N`` consecutive rounds that start at
    ``t = (2**n - 1) * N * q + 1``; inside a batch, arm ``i`` is played for
    ``q`` consecutive rounds (arm 0 first).

    Args:
        q: forced-sampling rounds per arm per batch.
        h: arms whose ``beta_t`` score is at least ``max - h/2`` are
            candidates; the candidate with the best ``beta_a`` score is played.
        lam1: Lasso penalty of the forced-sample estimators.
        lam2: scale of the decaying penalty of the all-sample estimators.
        d: together with ``N`` sets the context length ``N*d``.
        N: number of arms.
    """

    def __init__(self, q, h, lam1, lam2, d, N):
        # Per-arm histories: forced-sample (T*) and all-sample (S*) contexts/rewards.
        self.Tx = [[] for _ in range(N)]
        self.Sx = [[] for _ in range(N)]
        self.Tr = [[] for _ in range(N)]
        self.Sr = [[] for _ in range(N)]
        self.q = q
        self.h = h
        self.lam1 = lam1
        self.lam2 = lam2
        self.d = d
        self.N = N
        self.beta_t = np.zeros((N, N * d))
        self.beta_a = np.zeros((N, N * d))
        self.n = 0  # number of forced-sampling batches started so far
        # Rounds of the current forced-sampling batch. Starts empty so that
        # choose_a/update_beta work even when the first round is not a batch
        # start (previously this attribute did not exist until then).
        self.set = np.array([], dtype=int)
        self.lasso_t = LassoWrapper(alpha=self.lam1)  # for force-sample estimator

    def choose_a(self, t, x):
        """Return the arm to play at round ``t`` for context ``x`` (length ``N*d``)."""
        if t == (2 ** self.n - 1) * self.N * self.q + 1:
            self.set = np.arange(t, t + self.q * self.N)
            self.n += 1
        if t in self.set:
            # Batch rounds are handed out q at a time to arms 0, 1, ..., N-1.
            self.action = int(t - self.set[0]) // self.q
            self.Tx[self.action].append(x)
        else:
            est = np.dot(self.beta_t, x)  # one forced-sample score per arm
            max_est = np.amax(est)
            # ">=" (as in the paper) always keeps the best arm, so the
            # candidate set is never empty, even when h == 0.
            self.K = np.argwhere(est >= max_est - self.h / 2.)
            est2 = [np.dot(x, self.beta_a[k[0]]) for k in self.K]
            self.action = self.K[np.argmax(est2)][0]
        self.Sx[self.action].append(x)
        return self.action

    def update_beta(self, rwd, t):
        """Record reward ``rwd`` for the last chosen arm and refit its estimators."""
        if t in self.set:
            self.Tr[self.action].append(rwd)
            self.lasso_t.fit(self.Tx[self.action], self.Tr[self.action])
            self.beta_t[self.action] = self.lasso_t.coef_
        self.Sr[self.action].append(rwd)
        if t > 5:
            # Penalty decays like sqrt((log t + log(N d)) / t).
            lam2_t = self.lam2 * np.sqrt((np.log(t) + np.log(self.N * self.d)) / t)
            lasso_a = LassoWrapper(alpha=lam2_t)
            lasso_a.fit(self.Sx[self.action], self.Sr[self.action])
            self.beta_a[self.action] = lasso_a.coef_

# Lasso Bandit by Bastani and Bayati (2015). Online decision-making with high-dimensional covariates.
# Adapted by us to the single-parameter setting (one parameter vector shared by all arms)
class FSLassoBandit:
    """Forced-sampling Lasso bandit with one parameter vector shared by all arms.

    Adaptation of Bastani & Bayati (2015): the score of arm ``i`` is
    ``x[i] @ beta``. Two estimates are kept: ``beta_t`` fitted on
    forced-exploration rounds only and ``beta_a`` fitted on all rounds.

    Args:
        q: scale of the forced-exploration budget.
        h: if more than one arm scores above ``max - h`` under ``beta_t``,
            the arm is chosen with ``beta_a`` instead.
        lam1: Lasso penalty of the forced-sample estimator.
        lam2: scale of the decaying penalty of the all-sample estimator.
        d: context dimension.
        N: number of arms (stored; ``choose_a`` reads the arm count from ``x``).
    """

    def __init__(self, q, h, lam1, lam2, d, N):
        # Forced-exploration (T*) and all-round (S*) histories.
        self.Tx, self.Tr = [], []
        self.Sx, self.Sr = [], []
        self.q = q
        self.h = h
        self.lam1 = lam1
        self.lam2 = lam2
        self.d = d
        self.N = N
        self.beta_t = np.zeros(d)
        self.beta_a = np.zeros(d)
        self.n_g = 0  # greedy rounds so far
        self.n_e = 0  # forced-exploration rounds so far
        self.exploration = True  # whether the latest choose_a call explored
        self.lasso_t = LassoWrapper(alpha=self.lam1)  # reused forced-sample estimator

    def choose_a(self, t, x):
        """Pick an arm given ``x`` (one context row per arm); ``t`` is unused."""
        n_arms = x.shape[0]
        # Exploration budget grows logarithmically in the number of greedy rounds.
        budget = self.q * (np.log(2 * self.d) + 3 * np.log(self.n_g + 1))
        explore = bool(self.n_e <= budget)
        if explore:
            self.action = np.random.choice(n_arms)
            self.Tx.append(x[self.action])
            self.n_e += 1
            self.exploration = True
        else:
            forced_scores = np.dot(x, self.beta_t)
            cutoff = np.amax(forced_scores) - self.h
            contenders = np.count_nonzero(forced_scores > cutoff)
            if contenders > 1:
                self.action = np.argmax(np.dot(x, self.beta_a))
            else:
                self.action = np.argmax(forced_scores)
            self.n_g += 1
            self.exploration = False
        self.Sx.append(x[self.action])
        return self.action

    def update_beta(self, rwd, t):
        """Store reward ``rwd`` of round ``t`` and refresh the Lasso estimates."""
        self.Sr.append(rwd)
        if self.exploration:
            self.Tr.append(rwd)
            # Wait for a handful of forced samples before fitting.
            if self.n_e > 5:
                self.lasso_t.fit(self.Tx, self.Tr)
                self.beta_t = self.lasso_t.coef_
        if not (t > 5 and self.n_g > 1):
            return
        penalty = self.lam2 * np.sqrt(2 * (np.log(self.n_g) + np.log(self.d)) / t)
        all_sample_model = LassoWrapper(alpha=penalty)
        all_sample_model.fit(self.Sx, self.Sr)
        self.beta_a = all_sample_model.coef_
        

# DR Lasso Bandit by Kim and Paik (2019). Doubly-Robust Lasso Bandit.
class DRLassoBandit:
    def __init__(self,lam1,lam2,d,N,tc,tr,zt):
        self.x=[]
        self.r=[]
        self.lam1=lam1
        self.lam2=lam2
        self.d=d
        self.N=N
        self.beta=np.zeros(d)
        self.tc=tc
        self.tr=tr
        self.zt=zt
        # self.time = 0.0
        # self.cnt = 0
        
    def choose_a(self,t,x):  # x is N*d matrix
        if t<self.zt:
            self.action=np.random.choice(range(self.N))
            self.pi=1./self.N
        else:
            uniformp=self.lam1*np.sqrt((np.log(t)+np.log(self.d))/t)
            uniformp=np.minimum(1.0,np.maximum(0.,uniformp))
            choice=np.random.choice([0,1],p=[1.-uniformp,uniformp])
            est=np.dot(x,self.beta)
            if choice==1:
                self.action=np.random.choice(range(self.N))
                if self.action==np.argmax(est):
                    self.pi=uniformp/self.N+(1.-uniformp)
                else:
                    self.pi=uniformp/self.N
            else:
                self.action=np.argmax(est)
                self.pi=uniformp/self.N+(1.-uniformp)
        self.x.append(np.mean(x,axis=0))
        self.rhat=np.dot(x,self.beta)
        return(self.action)            
             
     
    def update_beta(self,rwd,t):
        pseudo_r=np.mean(self.rhat)+(rwd-self.rhat[self.action])/self.pi/self.N
        if self.tr==True:
            pseudo_r=np.minimum(3.,np.maximum(-3.,pseudo_r))
        self.r.append(pseudo_r)
        if t>5:
            if t>self.tc:
                lam2_t=self.lam2*np.sqrt(2 * (np.log(t)+np.log(self.d))/t) 
            lasso=LassoWrapper(alpha=lam2_t)
            # start_time = time.process_time()
            lasso.fit(self.x,self.r)
            # self.time += time.process_time() - start_time
            # self.cnt+=1

            self.beta=lasso.coef_
            

# Sparsity-Agnostic Lasso Bandit by Oh, Iyengar, Zeevi (2021)
class SALassoBandit:
    """Sparsity-Agnostic Lasso bandit (Oh, Iyengar & Zeevi, 2021).

    Always plays the arm with the highest estimated reward and refits a single
    Lasso estimate on the played contexts once ``t > 5``.

    Args:
        lam0: scale of the penalty ``lam0 * sqrt((2 log t + log d) / t)``.
        d: context dimension.
        N: number of arms (stored only).
    """

    def __init__(self, lam0, d, N):
        self.x, self.r = [], []
        self.lam0 = lam0
        self.d = d
        self.N = N
        self.beta = np.zeros(d)

    def choose_a(self, t, x):
        """Greedy arm for the context rows ``x``; ``t`` is unused."""
        scores = np.dot(x, self.beta)
        self.action = np.argmax(scores)
        self.x.append(x[self.action])
        return self.action

    def update_beta(self, rwd, t):
        """Store ``rwd`` and refit the Lasso estimate after round 5."""
        self.r.append(rwd)
        if not t > 5:
            return
        penalty = self.lam0 * np.sqrt((2 * np.log(t) + np.log(self.d)) / t)
        model = LassoWrapper(alpha=penalty)
        model.fit(self.x, self.r)
        self.beta = model.coef_

# Lasso-UCB by Li et al. (2021)
class LassoUCBBandit:
    """Lasso-UCB bandit (Li et al., 2021).

    Plays the arm maximizing ``x @ beta + tau_t * ||x||_inf`` and refits a
    Lasso estimate on the played contexts once ``t > 5``.

    Args:
        lam0: scale of the Lasso penalty ``lam0 * sqrt(2 log(t d) / t)``.
        d: context dimension.
        N: number of arms (stored only).
        tau: scale of the exploration width ``tau * sqrt(log(d t) / t)``.
    """

    def __init__(self, lam0, d, N, tau):
        self.x = []
        self.r = []
        self.lam0 = lam0
        self.tau = tau
        self.d = d
        self.N = N
        self.beta = np.zeros(d)

    def choose_a(self, t, x):
        """Return the arm with the highest upper confidence bound at round ``t``."""
        tau_t = self.tau * np.sqrt(np.log(self.d * t) / t)
        # Hoelder: |x^T (beta_hat - beta)| <= ||x||_inf * ||beta_hat - beta||_1,
        # so the bonus must use the sup-norm max_j |x_j|. Without abs,
        # contexts with negative entries received a negative/understated bonus.
        bonus = tau_t * np.amax(np.abs(x), axis=1)
        est = np.dot(x, self.beta) + bonus
        self.action = np.argmax(est)
        self.x.append(x[self.action])
        return self.action

    def update_beta(self, rwd, t):
        """Store ``rwd`` and refit the Lasso estimate after round 5."""
        self.r.append(rwd)
        if t > 5:
            lam_t = self.lam0 * np.sqrt(2 * np.log(t * self.d) / t)
            lasso = LassoWrapper(alpha=lam_t)
            lasso.fit(self.x, self.r)
            self.beta = lasso.coef_


# ETC - LASSO Bandit (Ours)
class ETCLassoBandit:
    """Explore-then-commit Lasso bandit with re-weighted exploration samples.

    Rounds ``t <= M_0`` play a uniformly random arm and store its context and
    reward scaled by ``sqrt(w)``; later rounds play greedily. From round
    ``M_0`` on, ``beta`` is refitted by Lasso on all stored samples every
    round, with a penalty made of an exploration term and a greedy-phase term.

    Args:
        M_0: number of exploration rounds.
        w: weight of exploration samples (stored scaled by ``sqrt(w)``).
        sigma: multiplies both penalty terms (presumably the noise level).
        d: context dimension.
        delta: confidence parameter entering the penalty.
    """

    def __init__(self, M_0, w, sigma, d, delta):
        self.x, self.r = [], []
        self.beta = np.zeros(d)
        self.M_0 = M_0
        self.w = w
        self.sqrtw = np.sqrt(w)
        self.d = d
        self.delta = delta
        # Constants of the penalty schedule.
        self.lam_e = 4 * sigma * np.sqrt(np.log(d / delta)) * w
        self.lam_g = 7 * sigma
        self.lam_g_add = np.log(7 * d / delta)

    def _penalty(self, t):
        """Lasso penalty for round ``t`` (called only when ``t >= M_0``)."""
        lam = self.lam_e * np.sqrt(self.M_0)
        n_greedy = t - self.M_0
        if n_greedy > 0:
            lam = lam + self.lam_g * np.sqrt(
                n_greedy * (2 * np.log(np.log(2 * n_greedy)) + self.lam_g_add))
        return lam / (2 * t)

    def choose_a(self, t, x):
        """Random arm while exploring, greedy arm afterwards."""
        if t > self.M_0:
            self.action = np.argmax(np.dot(x, self.beta))
            self.x.append(x[self.action])
        else:
            self.action = np.random.choice(x.shape[0])
            self.x.append(x[self.action] * self.sqrtw)
        return self.action

    def update_beta(self, rwd, t):
        """Store ``rwd`` (scaled while exploring) and refit from round ``M_0`` on."""
        exploring = t <= self.M_0
        self.r.append(rwd * self.sqrtw if exploring else rwd)
        if t < self.M_0:
            return
        model = LassoWrapper(alpha=self._penalty(t))
        model.fit(self.x, self.r)
        self.beta = model.coef_

# ESTC (Hao et al. (2020))
class ESTCBandit:
    """Explore-Sparse-Then-Commit bandit (Hao et al., 2020).

    Plays uniformly random arms for ``t <= M_0``, fits one Lasso estimate at
    ``t == M_0`` and plays greedily with it afterwards.

    Args:
        M_0: number of exploration rounds.
        lam0: scale of the penalty ``lam0 * sqrt(log d / M_0)``.
        d: context dimension.
    """

    def __init__(self, M_0, lam0, d):
        self.x, self.r = [], []
        self.beta = np.zeros(d)
        self.M_0 = M_0
        penalty = lam0 * np.sqrt(np.log(d) / M_0)
        self.lasso = LassoWrapper(alpha=penalty)
        self.d = d

    def choose_a(self, t, x):
        """Random arm during exploration (context recorded), greedy arm afterwards."""
        if t > self.M_0:
            self.action = np.argmax(np.dot(x, self.beta))
            return self.action
        self.action = np.random.choice(x.shape[0])
        self.x.append(x[self.action])
        return self.action

    def update_beta(self, rwd, t):
        """Record exploration rewards and fit the estimate at the last one."""
        if t > self.M_0:
            return
        self.r.append(rwd)
        if t == self.M_0:
            self.lasso.fit(self.x, self.r)
            self.beta = self.lasso.coef_


def linear_regression(x, y):
    """Least-squares coefficients of ``y`` on ``x`` without an intercept.

    Solves the scaled normal equations ``(x^T x / n) b = x^T y / n``. If
    ``np.linalg.solve`` raises ``LinAlgError`` (e.g. singular Gram matrix),
    falls back to scikit-learn's ``LinearRegression(fit_intercept=False)``.

    Args:
        x: 2-D array of shape (n_samples, n_features).
        y: targets with ``n_samples`` rows (list or array).

    Returns:
        The coefficient array. For 2-D ``y`` the solve result is transposed to
        (n_targets, n_features).
    """
    n_samples = x.shape[0]
    gram = np.dot(x.T, x) / n_samples
    cross = np.dot(x.T, y) / n_samples
    try:
        solution = np.linalg.solve(gram, cross)
    except np.linalg.LinAlgError:
        fallback = linear_model.LinearRegression(fit_intercept=False)
        return fallback.fit(x, y).coef_
    return solution.T


class THLassoBandit:
    """Thresholded Lasso bandit.

    After round 5, a Lasso fit (no intercept) gives a first support
    ``|beta_j| > 4 lam``. That support is thresholded again at
    ``4 lam sqrt(|S|)``, and ordinary least squares is refitted on the
    surviving coordinates. Arms are chosen greedily, or uniformly at random
    when the support is empty.

    Args:
        K: number of arms.
        d: context dimension.
        lam0: scale of ``lam = lam0 * sqrt(2 log t log d / t)``.
    """

    def __init__(self, K, d, lam0):
        self.K = K
        self.d = d
        self.lam0 = lam0
        self.beta = np.zeros(d)
        self.xs = []
        self.rs = []
        self.S = np.arange(d)  # current estimated support

    def choose_a(self, t, x):
        """Greedy arm when the support is non-empty, random arm otherwise."""
        if len(self.S) > 0:
            arm = np.argmax(np.dot(x, self.beta))
        else:
            arm = np.random.randint(self.K, dtype=np.int64)
        self.xs.append(x[arm].copy())
        return arm

    def update_beta(self, r, t):
        """Store reward ``r``; after round 5, re-estimate support and coefficients."""
        self.rs.append(r)
        if not t > 5:
            return
        lam = self.lam0 * np.sqrt(2 * np.log(t) * np.log(self.d) / t)
        coef = LassoWrapper(alpha=lam, fit_intercept=False).fit(self.xs, self.rs).coef_

        # First threshold on the raw Lasso coefficients.
        self.S = np.flatnonzero(np.abs(coef) > 4 * lam)
        if len(self.S) == 0:
            self.beta = np.zeros(self.d)
            return

        # Second, support-size-dependent threshold restricted to the first support.
        masked = np.zeros(self.d)
        masked[self.S] = coef[self.S]
        self.S = np.flatnonzero(np.abs(masked) > 4 * lam * np.sqrt(len(self.S)))
        if len(self.S) == 0:
            self.beta = np.zeros(self.d)
            return

        # Unpenalized least squares on the surviving coordinates.
        coef[self.S] = linear_regression(np.array(self.xs)[:, self.S], self.rs)
        self.beta = np.zeros(self.d)
        self.beta[self.S] = coef[self.S]