import numpy as np
from scipy.optimize import fmin_tnc
from numpy.linalg import inv
import random
from scipy.optimize import minimize, NonlinearConstraint
from itertools import combinations
from math import comb

def set_function(cat_info,omega,step):         #(1-r^(1/step))-strict-sub-modular
    """Return a normalized coverage score for a set of rows of ``cat_info``.

    With ``r = (1 - omega) ** step`` the score is
    ``(1 - r ** coverage) / (1 - r ** n_rows)`` where ``coverage`` is the sum
    of the column-wise maxima of ``cat_info`` and ``n_rows`` is
    ``cat_info.shape[0]``.

    :param cat_info: array with a ``shape`` attribute; presumably one row per
        selected item and one column per category -- TODO confirm with callers
    :param omega: base parameter, ``r = (1 - omega) ** step``
    :param step: exponent applied to ``1 - omega``
    :return: the score described above
    """
    decay = (1 - omega) ** step
    # Column-wise maximum, then summed over columns.
    column_peaks = np.max(cat_info, axis=0)
    covered = np.sum(column_peaks)
    n_rows = cat_info.shape[0]
    numerator = 1 - decay ** covered
    denominator = 1 - decay ** n_rows
    return numerator / denominator

# ofu_dmnl
class ofu_dmnl(object):
    """Optimistic MNL bandit whose item features carry an extra set-score coordinate.

    Every item of an assortment is represented by its context vector with one
    additional coordinate appended: the value of ``set_function`` evaluated on
    the whole assortment.  The parameter estimate ``W_t`` therefore has
    ``d + 1`` entries.
    """

    def __init__(self, N, K, d, omega, step, kappa = 0.5, beta = None, r_lambda=None,  vzero = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors
        :param omega: forwarded to ``set_function``; also scales the
            exploration threshold used in ``choose_S``
        :param step: forwarded to ``set_function``
        :param kappa: accepted for interface compatibility; not used
        :param beta: fixed confidence radius; when None the radius is
            recomputed from the round index on every ``choose_S`` call
        :param r_lambda: accepted for interface compatibility; the
            regularization is always set to ``84*sqrt(2)*eta*d``
        :param vzero: utility weight of the outside option
        """
        # immediate attributes from the constructor
        self.N = N
        self.d = d
        self.S = 1  # upper bound on the L2 norm of the parameter
        self.K = K
        self.eta = (self.S + 1) + np.log(self.K+1)/2  # step-size parameter
        self.r_lambda = 84 * np.sqrt(2) * self.eta * self.d
        # init parameter
        self.t = 1
        self.W_t = np.zeros(self.d+1)[:, None]  # estimated parameter
        self.W_t[-1]=0.25
        self.H = self.r_lambda * np.identity(self.d+1)  # hessian of loss matrix
        self.inv_H = 1/self.r_lambda * np.identity(self.d+1)  # inverse of H
        self.inputbeta=beta
        self.beta = beta  # confidence radius
        self.vzero = vzero  # utility for the outside option
        self.omega=omega
        self.step=step
        self.numinit=0  # number of rounds that used a random assortment

    def _confidence_radius(self, t):
        """Return the confidence radius for round ``t``."""
        log_term = np.log(2*np.sqrt(1+2*t))
        return np.sqrt(2 *self.eta *( (3*np.log(1 + (self.K +1)*t ) +3)
            *(17/16*self.r_lambda + 2* np.sqrt(self.r_lambda)*log_term + 16*log_term**2 )
            + 2 + np.sqrt(6)*7/6*self.eta*self.d*np.log(1 + (t+1)/(2*self.r_lambda))) + 4*self.r_lambda)

    def _greedy_assortment(self, x, cat):
        """Build an assortment of size K by greedily maximizing ``compute_rwd``."""
        chosen = []
        for _ in range(self.K):
            # -inf guarantees an already chosen item can never win the argmax.
            gains = np.full(self.N, -np.inf)
            for i in range(self.N):
                if i in chosen:
                    continue
                trial = chosen + [i]
                gains[i] = self.compute_rwd(x[trial], set_function(cat[trial], self.omega, self.step))
            chosen.append(int(np.argmax(gains)))
        return chosen

    def choose_S(self, t, x, cat):
        """
        choose the optimistic assortment

        :param t: round index, used to compute the confidence radius
        :param x: array of item contexts, indexed by item along axis 0
        :param cat: per-item array passed (row-indexed) to ``set_function``
        :return: integer array with the K chosen item indices
        :raises ValueError: if K exceeds N
        """
        if self.K > self.N:
            raise ValueError(f"K(={self.K}) cannot exceed N(={self.N}).")
        if self.inputbeta is None:
            self.beta = self._confidence_radius(t)

        # Probe vector: context of item 1 with a fixed set score of 0.5.
        # NOTE(review): the probe always uses x[1]; confirm this is intended.
        z01 = np.concatenate([x[1], np.array([0.5])])
        confi_radii = 2.0 * self.beta*np.sqrt(float(z01 @ self.inv_H @ z01))
        threshold = self.omega * self.W_t[-1]

        if confi_radii > threshold:
            # Uncertainty still too large: play a uniformly random assortment.
            St = np.random.choice(self.N, size=self.K, replace=False).tolist()
            self.numinit=self.numinit+1
        else:
            St = self._greedy_assortment(x, cat)

        self.assortment = np.array(St, dtype=int)
        set_score = set_function(cat[St], self.omega, self.step)
        # Remember the augmented features for the following update_state call.
        self.chosen_vectors = np.hstack([x[self.assortment,:], np.full((self.K,1), set_score)])
        return(self.assortment)

    def compute_rwd(self, cxts, setscore):
        """Return the optimistic probability that any item of the set is chosen.

        :param cxts: 2-D array with the contexts of the candidate set
        :param setscore: set score appended to every context as last coordinate
        """
        z=np.hstack([cxts, np.full((cxts.shape[0], 1), setscore)])
        means = np.squeeze(np.dot(z,self.W_t))
        zv = np.sqrt((np.matmul(z, self.inv_H) * z).sum(axis = 1))
        u = means + self.beta * zv
        eu = np.exp(u)
        uSum = self.vzero + eu.sum()
        rwd = eu.sum()/uSum
        return rwd

    def _mnl_hessian(self, probs, X):
        """Return sum_i p_i x_i x_i^T - (sum_i p_i x_i)(sum_i p_i x_i)^T."""
        weighted = X * probs[:, None]
        mean = weighted.sum(axis=0)
        return X.T @ weighted - np.outer(mean, mean)

    def update_state(self, y):
        """
        update state

        :param y: choice feedback for the last assortment; the last entry is
            dropped by ``update_estimator``
        """
        Z = self.chosen_vectors
        assert isinstance(Z, np.ndarray), 'np.array required'
        self.update_estimator(Z, y)
        # Hessian evaluated at the freshly updated estimate.
        probs = np.ravel(self.Sigma(np.ravel(self.W_t), Z))
        self.H += self._mnl_hessian(probs, Z)
        self.inv_H = inv(self.H)
        self.t += 1

    def update_estimator(self, X, y):
        """
        update parameter

        Performs one step ``W - M_t^{-1} g`` and, if the result leaves the ball
        of radius ``S``, projects it back.  ``W_t`` is stored as a 1-D array.

        :param X: 2-D array of augmented features of the chosen items
        :param y: feedback vector; its last entry is discarded
        """
        y = np.ravel(np.asarray(y)[:-1])
        w_prev = np.ravel(self.W_t)
        probs = np.ravel(self.Sigma(w_prev, X))
        g_Wt = X.T @ (probs - y)  # gradient of the loss at w_prev
        gg_Wt = self._mnl_hessian(probs, X)
        M_t = 1/(2*self.eta) * self.H + 1/2 * gg_Wt
        candidate = w_prev - np.linalg.solve(M_t, g_Wt)
        norm = np.linalg.norm(candidate)
        if norm > self.S:
            if self.K == 1:
                candidate = self.S * candidate / norm
            else:
                candidate = self.projection(candidate, M_t)
        self.W_t = candidate
        #print("OFU-DMNL W_t: ", self.W_t)

    def Sigma(self, W, X):
        """
        calculate MNL probability
        """
        z = np.matmul(X, W)
        sigma = np.exp(z)
        sigma = sigma / (sigma.sum(axis=0) + self.vzero)
        return sigma

    def proj_fun(self, W, un_projected, M):
        """Return the squared M-weighted distance between W and un_projected."""
        diff = W-un_projected
        fun = np.dot(diff, np.dot(M, diff))
        return fun

    def projection(self, unprojected, M):
        """Project ``unprojected`` onto the ball of radius ``S`` in the M-norm.

        :param unprojected: 1-D parameter vector
        :param M: square weighting matrix
        :return: 1-D array whose L2 norm does not exceed ``S``
        """
        target = np.ravel(np.asarray(unprojected, dtype=float))
        radius = self.S
        norm = np.linalg.norm(target)
        if norm <= radius:
            return target
        sym = M + M.T
        # Smooth constraint radius^2 - ||w||^2 >= 0 (differentiable at 0).
        constraints = [{'type': 'ineq',
                        'fun': lambda w: radius**2 - np.dot(w, w),
                        'jac': lambda w: -2.0 * w}]
        opt = minimize(lambda w: self.proj_fun(w, target, M),
                       x0=radius * target / norm,  # feasible starting point
                       jac=lambda w: np.dot(sym, w - target),
                       method='SLSQP', constraints=constraints)
        result = np.asarray(opt.x, dtype=float)
        result_norm = np.linalg.norm(result)
        if result_norm > radius:
            # Remove the solver's small constraint violation.
            result = radius * result / result_norm
        return result
    

# ofu_dmnl_full
class ofu_dmnl_full(object):
    """Variant of the set-score MNL bandit that searches all size-K assortments.

    Item features are the context vectors with one extra coordinate (the value
    of ``set_function`` on the assortment), so ``W_t`` has ``d + 1`` entries.
    """

    def __init__(self, N, K, d, omega, step, kappa = 0.5, beta = None, r_lambda=None,  vzero = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors
        :param omega: forwarded to ``set_function``
        :param step: forwarded to ``set_function``
        :param kappa: accepted for interface compatibility; not used
        :param beta: fixed confidence radius; when None the radius is
            recomputed from the round index on every ``choose_S`` call
        :param r_lambda: accepted for interface compatibility; the
            regularization is always set to ``84*sqrt(2)*eta*d``
        :param vzero: utility weight of the outside option
        """
        # immediate attributes from the constructor
        self.N = N
        self.d = d
        self.S = 1  # upper bound on the L2 norm of the parameter
        self.K = K
        self.eta = (self.S + 1) + np.log(self.K+1)/2  # step-size parameter
        self.r_lambda = 84 * np.sqrt(2) * self.eta * self.d
        # init parameter
        self.t = 1
        self.W_t = np.zeros(self.d+1)[:, None]  # estimated parameter
        self.W_t[-1]=0.25
        self.H = self.r_lambda * np.identity(self.d+1)  # hessian of loss matrix
        self.inv_H = 1/self.r_lambda * np.identity(self.d+1)  # inverse of H
        self.inputbeta=beta
        self.beta = beta  # confidence radius
        self.vzero = vzero  # utility for the outside option
        self.omega=omega
        self.step=step
        self.numinit=0

    def _confidence_radius(self, t):
        """Return the confidence radius for round ``t``."""
        log_term = np.log(2*np.sqrt(1+2*t))
        return np.sqrt(2 *self.eta *( (3*np.log(1 + (self.K +1)*t ) +3)
            *(17/16*self.r_lambda + 2* np.sqrt(self.r_lambda)*log_term + 16*log_term**2 )
            + 2 + np.sqrt(6)*7/6*self.eta*self.d*np.log(1 + (t+1)/(2*self.r_lambda))) + 4*self.r_lambda)

    def choose_S(self, t, x, cat):
        """
        choose the optimistic assortment

        Enumerates every combination of K items and keeps the one with the
        largest ``compute_rwd`` value.

        :param t: round index, used to compute the confidence radius
        :param x: array of item contexts, indexed by item along axis 0
        :param cat: per-item array passed (row-indexed) to ``set_function``
        :return: integer array with the K chosen item indices
        :raises ValueError: if K exceeds N or no assortment has a comparable
            (non-NaN) reward
        """
        if self.K > self.N:
            raise ValueError(f"K(={self.K}) cannot exceed N(={self.N}).")
        if self.inputbeta is None:
            self.beta = self._confidence_radius(t)

        best_rwd = -np.inf
        St = None
        for subset in combinations(range(self.N), self.K):
            candidate = list(subset)
            rwd = self.compute_rwd(x[candidate], set_function(cat[candidate], self.omega, self.step))
            if rwd > best_rwd:
                best_rwd = rwd
                St = candidate
        if St is None:
            raise ValueError("no assortment with a comparable optimistic reward was found")

        self.assortment = np.array(St, dtype=int)
        set_score = set_function(cat[St], self.omega, self.step)
        # Remember the augmented features for the following update_state call.
        self.chosen_vectors = np.hstack([x[self.assortment,:], np.full((self.K,1), set_score)])
        return(self.assortment)

    def compute_rwd(self, cxts, setscore):
        """Return the optimistic probability that any item of the set is chosen.

        :param cxts: 2-D array with the contexts of the candidate set
        :param setscore: set score appended to every context as last coordinate
        """
        z=np.hstack([cxts, np.full((cxts.shape[0], 1), setscore)])
        means = np.squeeze(np.dot(z,self.W_t))
        zv = np.sqrt((np.matmul(z, self.inv_H) * z).sum(axis = 1))
        u = means + self.beta * zv
        eu = np.exp(u)
        uSum = self.vzero + eu.sum()
        rwd = eu.sum()/uSum
        return rwd

    def _mnl_hessian(self, probs, X):
        """Return sum_i p_i x_i x_i^T - (sum_i p_i x_i)(sum_i p_i x_i)^T."""
        weighted = X * probs[:, None]
        mean = weighted.sum(axis=0)
        return X.T @ weighted - np.outer(mean, mean)

    def update_state(self, y):
        """
        update state

        :param y: choice feedback for the last assortment; the last entry is
            dropped by ``update_estimator``
        """
        Z = self.chosen_vectors
        assert isinstance(Z, np.ndarray), 'np.array required'
        self.update_estimator(Z, y)
        # Hessian evaluated at the freshly updated estimate.
        probs = np.ravel(self.Sigma(np.ravel(self.W_t), Z))
        self.H += self._mnl_hessian(probs, Z)
        self.inv_H = inv(self.H)
        self.t += 1

    def update_estimator(self, X, y):
        """
        update parameter

        Performs one step ``W - M_t^{-1} g`` and, if the result leaves the ball
        of radius ``S``, projects it back.  ``W_t`` is stored as a 1-D array.

        :param X: 2-D array of augmented features of the chosen items
        :param y: feedback vector; its last entry is discarded
        """
        y = np.ravel(np.asarray(y)[:-1])
        w_prev = np.ravel(self.W_t)
        probs = np.ravel(self.Sigma(w_prev, X))
        g_Wt = X.T @ (probs - y)  # gradient of the loss at w_prev
        gg_Wt = self._mnl_hessian(probs, X)
        M_t = 1/(2*self.eta) * self.H + 1/2 * gg_Wt
        candidate = w_prev - np.linalg.solve(M_t, g_Wt)
        norm = np.linalg.norm(candidate)
        if norm > self.S:
            if self.K == 1:
                candidate = self.S * candidate / norm
            else:
                candidate = self.projection(candidate, M_t)
        self.W_t = candidate
        #print("OFU-DMNL W_t: ", self.W_t)

    def Sigma(self, W, X):
        """
        calculate MNL probability
        """
        z = np.matmul(X, W)
        sigma = np.exp(z)
        sigma = sigma / (sigma.sum(axis=0) + self.vzero)
        return sigma

    def proj_fun(self, W, un_projected, M):
        """Return the squared M-weighted distance between W and un_projected."""
        diff = W-un_projected
        fun = np.dot(diff, np.dot(M, diff))
        return fun

    def projection(self, unprojected, M):
        """Project ``unprojected`` onto the ball of radius ``S`` in the M-norm.

        :param unprojected: 1-D parameter vector
        :param M: square weighting matrix
        :return: 1-D array whose L2 norm does not exceed ``S``
        """
        target = np.ravel(np.asarray(unprojected, dtype=float))
        radius = self.S
        norm = np.linalg.norm(target)
        if norm <= radius:
            return target
        sym = M + M.T
        # Smooth constraint radius^2 - ||w||^2 >= 0 (differentiable at 0).
        constraints = [{'type': 'ineq',
                        'fun': lambda w: radius**2 - np.dot(w, w),
                        'jac': lambda w: -2.0 * w}]
        opt = minimize(lambda w: self.proj_fun(w, target, M),
                       x0=radius * target / norm,  # feasible starting point
                       jac=lambda w: np.dot(sym, w - target),
                       method='SLSQP', constraints=constraints)
        result = np.asarray(opt.x, dtype=float)
        result_norm = np.linalg.norm(result)
        if result_norm > radius:
            # Remove the solver's small constraint violation.
            result = radius * result / result_norm
        return result

# ofu_mnl_dr
class ofu_mnl_dr(object):
    """Optimistic MNL bandit whose greedy objective adds a weighted set score.

    The assortment is built greedily to maximize the optimistic MNL choice
    probability plus ``LamD * set_function(...)``.  The parameter estimate has
    ``d`` entries; the set score is not part of the features.
    """

    def __init__(self, LamD, N, K, d, omega, step, kappa = 0.5, beta = None, r_lambda=None,  vzero = 1.0):
        """
        :param LamD: weight multiplying the set score in the greedy objective
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param omega: forwarded to ``set_function``
        :param step: forwarded to ``set_function``
        :param kappa: accepted for interface compatibility; not used
        :param beta: fixed confidence radius; when None the radius is
            recomputed from the round index on every ``choose_S`` call
        :param r_lambda: accepted for interface compatibility; the
            regularization is always set to ``84*sqrt(2)*eta*d``
        :param vzero: utility weight of the outside option
        """
        # immediate attributes from the constructor
        self.LamD=LamD
        self.N = N
        self.d = d
        self.S = 1  # upper bound on the L2 norm of the parameter
        self.K = K
        self.eta = (self.S + 1) + np.log(self.K+1)/2  # step-size parameter
        self.r_lambda = 84 * np.sqrt(2) * self.eta * self.d
        # init parameter
        self.t = 1
        self.W_t = np.zeros(self.d)[:, None]  # estimated parameter
        self.H = self.r_lambda * np.identity(self.d)  # hessian of loss matrix
        self.inv_H = 1/self.r_lambda * np.identity(self.d)  # inverse of H
        self.inputbeta=beta
        self.beta = beta  # confidence radius
        self.vzero = vzero  # utility for the outside option
        self.omega=omega
        self.step=step

    def _confidence_radius(self, t):
        """Return the confidence radius for round ``t``."""
        log_term = np.log(2*np.sqrt(1+2*t))
        return np.sqrt(2 *self.eta *( (3*np.log(1 + (self.K +1)*t ) +3)
            *(17/16*self.r_lambda + 2* np.sqrt(self.r_lambda)*log_term + 16*log_term**2 )
            + 2 + np.sqrt(6)*7/6*self.eta*self.d*np.log(1 + (t+1)/(2*self.r_lambda))) + 4*self.r_lambda)

    def choose_S(self, t, x, cat):
        """
        choose the optimistic assortment

        :param t: round index, used to compute the confidence radius
        :param x: 2-D array of item contexts, one row per item
        :param cat: per-item array passed (row-indexed) to ``set_function``
        :return: integer array with the K chosen item indices
        """
        if self.inputbeta is None:
            self.beta = self._confidence_radius(t)

        # Optimistic utility of every item.
        means = np.squeeze(np.dot(x,self.W_t))
        xv = np.sqrt((np.matmul(x, self.inv_H) * x).sum(axis = 1))
        u = means + self.beta * xv

        St=[]
        for _ in range(self.K):
            # -inf guarantees an already chosen item can never win the argmax,
            # even when the objective is zero or negative.
            gains = np.full(self.N, -np.inf)
            for i in range(self.N):
                if i in St:
                    continue
                trial = St+[i]
                gains[i] = self.compute_rwd(u[trial], self.LamD*set_function(cat[trial], self.omega, self.step))
            St.append(int(np.argmax(gains)))

        self.assortment = np.array(St, dtype=int)
        # Remember the contexts for the following update_state call.
        self.chosen_vectors = x[self.assortment,:]
        return(self.assortment)

    def compute_rwd(self, means, setrwd):
        """Return the MNL choice probability of the set plus ``setrwd``.

        :param means: utilities of the items in the candidate set
        :param setrwd: additive set reward
        """
        u = np.exp(means)
        uSum = self.vzero + u.sum()
        rwd = u.sum()/uSum+setrwd
        return rwd

    def _mnl_hessian(self, probs, X):
        """Return sum_i p_i x_i x_i^T - (sum_i p_i x_i)(sum_i p_i x_i)^T."""
        weighted = X * probs[:, None]
        mean = weighted.sum(axis=0)
        return X.T @ weighted - np.outer(mean, mean)

    def update_state(self, y):
        """
        update state

        :param y: choice feedback for the last assortment; the last entry is
            dropped by ``update_estimator``
        """
        X = self.chosen_vectors
        assert isinstance(X, np.ndarray), 'np.array required'
        self.update_estimator(X, y)
        # Hessian evaluated at the freshly updated estimate.
        probs = np.ravel(self.Sigma(np.ravel(self.W_t), X))
        self.H += self._mnl_hessian(probs, X)
        self.inv_H = inv(self.H)
        self.t += 1

    def update_estimator(self, X, y):
        """
        update parameter

        Performs one step ``W - M_t^{-1} g`` and, if the result leaves the ball
        of radius ``S``, projects it back.  ``W_t`` is stored as a 1-D array.

        :param X: 2-D array with the contexts of the chosen items
        :param y: feedback vector; its last entry is discarded
        """
        y = np.ravel(np.asarray(y)[:-1])
        w_prev = np.ravel(self.W_t)
        probs = np.ravel(self.Sigma(w_prev, X))
        g_Wt = X.T @ (probs - y)  # gradient of the loss at w_prev
        gg_Wt = self._mnl_hessian(probs, X)
        M_t = 1/(2*self.eta) * self.H + 1/2 * gg_Wt
        candidate = w_prev - np.linalg.solve(M_t, g_Wt)
        norm = np.linalg.norm(candidate)
        if norm > self.S:
            if self.K == 1:
                candidate = self.S * candidate / norm
            else:
                candidate = self.projection(candidate, M_t)
        self.W_t = candidate
        #print("OFU-MNL-DR W_t: ", self.W_t)

    def Sigma(self, W, X):
        """
        calculate MNL probability
        """
        z = np.matmul(X, W)
        sigma = np.exp(z)
        sigma = sigma / (sigma.sum(axis=0) + self.vzero)
        return sigma

    def proj_fun(self, W, un_projected, M):
        """Return the squared M-weighted distance between W and un_projected."""
        diff = W-un_projected
        fun = np.dot(diff, np.dot(M, diff))
        return fun

    def projection(self, unprojected, M):
        """Project ``unprojected`` onto the ball of radius ``S`` in the M-norm.

        :param unprojected: 1-D parameter vector
        :param M: square weighting matrix
        :return: 1-D array whose L2 norm does not exceed ``S``
        """
        target = np.ravel(np.asarray(unprojected, dtype=float))
        radius = self.S
        norm = np.linalg.norm(target)
        if norm <= radius:
            return target
        sym = M + M.T
        # Smooth constraint radius^2 - ||w||^2 >= 0 (differentiable at 0).
        constraints = [{'type': 'ineq',
                        'fun': lambda w: radius**2 - np.dot(w, w),
                        'jac': lambda w: -2.0 * w}]
        opt = minimize(lambda w: self.proj_fun(w, target, M),
                       x0=radius * target / norm,  # feasible starting point
                       jac=lambda w: np.dot(sym, w - target),
                       method='SLSQP', constraints=constraints)
        result = np.asarray(opt.x, dtype=float)
        result_norm = np.linalg.norm(result)
        if result_norm > radius:
            # Remove the solver's small constraint violation.
            result = radius * result / result_norm
        return result
    

# ofu_mnl_plus
class ofu_mnl_plus(object):
    """Optimistic MNL bandit that plays the K items with the largest upper confidence utility."""

    def __init__(self, N, K, d, kappa = 0.5, beta = None, r_lambda=None,  vzero = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param kappa: accepted for interface compatibility; not used
        :param beta: fixed confidence radius; when None the radius is
            recomputed from the round index on every ``choose_S`` call
        :param r_lambda: accepted for interface compatibility; the
            regularization is always set to ``84*sqrt(2)*eta*d``
        :param vzero: utility weight of the outside option
        """
        # immediate attributes from the constructor
        self.N = N
        self.d = d
        self.S = 1  # upper bound on the L2 norm of the parameter
        self.K = K
        self.eta = (self.S + 1) + np.log(self.K+1)/2  # step-size parameter
        self.r_lambda = 84 * np.sqrt(2) * self.eta * self.d
        # init parameter
        self.t = 1
        self.W_t = np.zeros(self.d)[:, None]  # estimated parameter
        self.H = self.r_lambda * np.identity(self.d)  # hessian of loss matrix
        self.inv_H = 1/self.r_lambda * np.identity(self.d)  # inverse of H
        self.inputbeta=beta
        self.beta = beta  # confidence radius
        self.vzero = vzero  # utility for the outside option

    def _confidence_radius(self, t):
        """Return the confidence radius for round ``t``."""
        log_term = np.log(2*np.sqrt(1+2*t))
        return np.sqrt(2 *self.eta *( (3*np.log(1 + (self.K +1)*t ) +3)
            *(17/16*self.r_lambda + 2* np.sqrt(self.r_lambda)*log_term + 16*log_term**2 )
            + 2 + np.sqrt(6)*7/6*self.eta*self.d*np.log(1 + (t+1)/(2*self.r_lambda))) + 4*self.r_lambda)

    def choose_S(self, t, x):
        """
        choose the optimistic assortment

        :param t: round index, used to compute the confidence radius
        :param x: 2-D array of item contexts, one row per item
        :return: indices of the K items with the largest optimistic utility,
            in decreasing order of utility
        """
        if self.inputbeta is None:
            self.beta = self._confidence_radius(t)

        means = np.squeeze(np.dot(x,self.W_t))
        xv = np.sqrt((np.matmul(x, self.inv_H) * x).sum(axis = 1))
        u = means + self.beta * xv
        self.assortment = np.argsort(u)[::-1][:self.K]
        # Remember the contexts for the following update_state call.
        self.chosen_vectors = x[self.assortment,:]
        return(self.assortment)

    def _mnl_hessian(self, probs, X):
        """Return sum_i p_i x_i x_i^T - (sum_i p_i x_i)(sum_i p_i x_i)^T."""
        weighted = X * probs[:, None]
        mean = weighted.sum(axis=0)
        return X.T @ weighted - np.outer(mean, mean)

    def update_state(self, y):
        """
        update state

        :param y: choice feedback for the last assortment; the last entry is
            dropped by ``update_estimator``
        """
        X = self.chosen_vectors
        assert isinstance(X, np.ndarray), 'np.array required'
        self.update_estimator(X, y)
        # Hessian evaluated at the freshly updated estimate.
        probs = np.ravel(self.Sigma(np.ravel(self.W_t), X))
        self.H += self._mnl_hessian(probs, X)
        self.inv_H = inv(self.H)
        self.t += 1

    def update_estimator(self, X, y):
        """
        update parameter

        Performs one step ``W - M_t^{-1} g`` and, if the result leaves the ball
        of radius ``S``, projects it back.  ``W_t`` is stored as a 1-D array.

        :param X: 2-D array with the contexts of the chosen items
        :param y: feedback vector; its last entry is discarded
        """
        y = np.ravel(np.asarray(y)[:-1])
        w_prev = np.ravel(self.W_t)
        probs = np.ravel(self.Sigma(w_prev, X))
        g_Wt = X.T @ (probs - y)  # gradient of the loss at w_prev
        gg_Wt = self._mnl_hessian(probs, X)
        M_t = 1/(2*self.eta) * self.H + 1/2 * gg_Wt
        candidate = w_prev - np.linalg.solve(M_t, g_Wt)
        norm = np.linalg.norm(candidate)
        if norm > self.S:
            if self.K == 1:
                candidate = self.S * candidate / norm
            else:
                candidate = self.projection(candidate, M_t)
        self.W_t = candidate
        #print("OFU-MNL W_t: ", self.W_t)

    def Sigma(self, W, X):
        """
        calculate MNL probability
        """
        z = np.matmul(X, W)
        sigma = np.exp(z)
        sigma = sigma / (sigma.sum(axis=0) + self.vzero)
        return sigma

    def proj_fun(self, W, un_projected, M):
        """Return the squared M-weighted distance between W and un_projected."""
        diff = W-un_projected
        fun = np.dot(diff, np.dot(M, diff))
        return fun

    def projection(self, unprojected, M):
        """Project ``unprojected`` onto the ball of radius ``S`` in the M-norm.

        :param unprojected: 1-D parameter vector
        :param M: square weighting matrix
        :return: 1-D array whose L2 norm does not exceed ``S``
        """
        target = np.ravel(np.asarray(unprojected, dtype=float))
        radius = self.S
        norm = np.linalg.norm(target)
        if norm <= radius:
            return target
        sym = M + M.T
        # Smooth constraint radius^2 - ||w||^2 >= 0 (differentiable at 0).
        constraints = [{'type': 'ineq',
                        'fun': lambda w: radius**2 - np.dot(w, w),
                        'jac': lambda w: -2.0 * w}]
        opt = minimize(lambda w: self.proj_fun(w, target, M),
                       x0=radius * target / norm,  # feasible starting point
                       jac=lambda w: np.dot(sym, w - target),
                       method='SLSQP', constraints=constraints)
        result = np.asarray(opt.x, dtype=float)
        result_norm = np.linalg.norm(result)
        if result_norm > radius:
            # Remove the solver's small constraint violation.
            result = radius * result / result_norm
        return result
    

#UCB-MNL-plus
class ucb_mnl_plus:
    """UCB-style MNL bandit: regularized MLE estimate with Hessian-based confidence widths."""

    def __init__(self, N, K, d, kappa = 0.5, beta=None, vzero=1.0, lam = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param kappa: degree of non-linearity (stored, not used in this class)
        :param beta: fixed confidence radius; when None the radius is
            recomputed from the round index on every ``choose_S`` call
        :param vzero: utility weight of the outside option
        :param lam: regularization parameter passed to the MLE fit
        """
        super(ucb_mnl_plus, self).__init__()
        self.N = N
        self.K = K
        self.d = d
        # History of played contexts / feedback.  Both start with an all-zero
        # placeholder entry that update_theta removes on the first update.
        self.X = np.zeros((K,d))[np.newaxis, ...]
        self.Y = np.zeros(K+1)[np.newaxis, ...]
        self.S = 1
        self.kappa = kappa
        self.lam = lam
        self.eta = (self.S + 1) + np.log(self.K+1)/2
        self.r_lambda = 84 * np.sqrt(2) * self.eta * self.d
        # init parameter
        self.t = 1
        #self.W_t = np.zeros(self.d)[:, None]  # estimated parameter
        self.H = self.r_lambda * np.identity(self.d)  # hessian of loss matrix
        self.inv_H = 1/self.r_lambda * np.identity(self.d)  # inverse of H
        # Keep the user-supplied value separately so that a missing beta is
        # recomputed every round instead of being frozen after the first call.
        self.inputbeta = beta
        self.beta = beta  # confidence radius
        self.vzero = vzero  # utility for the outside option
        self.theta = np.zeros(d)  # estimated parameter
        self.V = np.eye(d)*lam  # grammatrix
        self.mnl = RegularizedMNLRegression()  # MLE loss function
        #self.alpha = alpha  # confidence radius

    def _confidence_radius(self, t):
        """Return the confidence radius for round ``t``."""
        log_term = np.log(2*np.sqrt(1+2*t))
        return np.sqrt(2 *self.eta *( (3*np.log(1 + (self.K +1)*t ) +3)
            *(17/16*self.r_lambda + 2* np.sqrt(self.r_lambda)*log_term + 16*log_term**2 )
            + 2 + np.sqrt(6)*7/6*self.eta*self.d*np.log(1 + (t+1)/(2*self.r_lambda))) + 4*self.r_lambda)

    def choose_S(self,t,x):  # x is N*d matrix
        """
        choose the optimistic assortment

        :param t: round index, used to compute the confidence radius
        :param x: N*d matrix of item contexts
        :return: indices of the K items with the largest optimistic utility
        """
        if self.inputbeta is None:
            self.beta = self._confidence_radius(t)
        means = np.dot(x,self.theta)
        xv = np.sqrt((np.matmul(x, self.inv_H) * x).sum(axis = 1))
        u = means + self.beta * xv
        self.assortment = np.argsort(u)[::-1][:self.K]
        played = x[self.assortment,:]
        self.chosen_vectors = played
        # Append the played contexts to the history and update the gram matrix.
        self.X = np.concatenate((self.X, played[np.newaxis, ...]))
        self.V += np.matmul(played.T, played)
        return(self.assortment)

    def update_state(self, y):
        """
        update state

        Refits ``theta`` on the full history and adds the MNL Hessian of the
        last assortment, evaluated at the new ``theta``, to ``H``.

        :param y: feedback vector for the last assortment
        """
        St = self.chosen_vectors
        assert isinstance(St, np.ndarray), 'np.array required'
        self.update_theta(y,self.t+1)
        probs = np.ravel(self.Sigma(self.theta, St))
        weighted = St * probs[:, None]
        mean = weighted.sum(axis=0)
        # sum_i p_i x_i x_i^T - (sum_i p_i x_i)(sum_i p_i x_i)^T
        self.H += St.T @ weighted - np.outer(mean, mean)
        self.inv_H = inv(self.H)
        self.t += 1

    def Sigma(self, W, X):
        """
        calculate MNL probability
        """
        z = np.matmul(X, W)
        sigma = np.exp(z)
        sigma = sigma / (sigma.sum(axis=0) + self.vzero)
        return sigma

    def update_theta(self,Y,t):
        """
        update parameter

        :param Y: feedback vector appended to the history
        :param t: when equal to 2, the placeholder first entries of the
            context and feedback histories are dropped before fitting
        """
        self.Y = np.concatenate((self.Y, Y[np.newaxis, ...]))
        if t==2:
            self.X = np.delete(self.X, (0), axis=0)
            self.Y = np.delete(self.Y, (0), axis=0)
        self.mnl.fit(self.X, self.Y, self.theta, self.lam)
        self.theta = self.mnl.w



#UCB-MNL
class ucb_mnl:
    """UCB-style MNL bandit using gram-matrix confidence widths."""

    def __init__(self, N, K, d, kappa = 0.5, alpha=None, lam = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param kappa: degree of non-linearity, used in the default alpha
        :param alpha: fixed confidence radius; when None it is recomputed
            from the round index on every ``choose_S`` call
        :param lam: regularization parameter
        """
        super(ucb_mnl, self).__init__()
        self.N, self.K, self.d = N, K, d
        # Histories start with an all-zero placeholder entry; update_theta
        # removes it when called with t == 2.
        self.X = np.zeros((K,d))[np.newaxis, ...]
        self.Y = np.zeros(K+1)[np.newaxis, ...]
        self.S = 1  # overwritten with the chosen assortment by choose_S
        self.kappa = kappa
        self.lam = lam
        # init parameter
        self.theta = np.zeros(d)  # estimated parameter
        self.V = np.eye(d)*lam  # grammatrix
        self.mnl = RegularizedMNLRegression()  # MLE loss function
        self.inputalpha=alpha
        self.alpha = alpha  # confidence radius

    def choose_S(self,t,x):  # x is N*d matrix
        """
        choose the optimistic assortment

        :param t: round index, used for the default confidence radius
        :param x: N*d matrix of item contexts
        :return: indices of the K items with the largest upper confidence
            utility; the same array is stored in ``self.S``
        """
        if self.inputalpha is None:
            radius = (1/(2*self.kappa))*np.sqrt(2*self.d*np.log(1+t/self.d)+2*np.log(t))
            self.alpha = radius*0.01
        gram_inverse = inv(self.V)
        widths = np.sqrt((np.matmul(x, gram_inverse) * x).sum(axis = 1))
        ucb = np.dot(x,self.theta) + self.alpha * widths
        ranking = np.argsort(ucb)[::-1]
        self.S = ranking[:self.K]
        played = x[self.S,:]
        # Record the played contexts and update the gram matrix in place.
        self.X = np.concatenate((self.X, played[np.newaxis, ...]))
        self.V += np.matmul(played.T, played)
        return(self.S)

    def update_theta(self,Y,t):
        """
        update parameter

        :param Y: feedback vector appended to the history
        :param t: when equal to 2, the placeholder first entries of both
            histories are removed before fitting
        """
        self.Y = np.concatenate((self.Y, Y[np.newaxis, ...]))
        if t==2:
            self.X = np.delete(self.X, (0), axis=0)
            self.Y = np.delete(self.Y, (0), axis=0)
        fitted = self.mnl
        fitted.fit(self.X, self.Y, self.theta, self.lam)
        self.theta = fitted.w

# TS-MNL with Gaussian approximation
class ts_mnl:
    """Thompson-sampling MNL bandit with a Gaussian approximation of the posterior."""

    def __init__(self, N, K, d, kappa = 0.5, alpha=None, lam = 1.0):
        """
        :param N: number of items
        :param K: maximum assortment size
        :param d: dimension of the context vectors and unknown parameter w
        :param kappa: degree of non-linearity, used in the default alpha
        :param alpha: fixed scale of the sampling covariance; when None it is
            recomputed from the round index on every ``choose_S`` call
        :param lam: regularization parameter
        """
        super(ts_mnl, self).__init__()
        self.N, self.K, self.d = N, K, d
        # Histories start with an all-zero placeholder entry; update_theta
        # removes it when called with t == 2.
        self.X=np.zeros((K,d))[np.newaxis, ...]
        self.Y=np.zeros(K+1)[np.newaxis, ...]
        self.S = 1  # overwritten with the chosen assortment by choose_S
        self.kappa = kappa
        self.lam = lam

        # init parameter
        self.theta=np.zeros(d)  # estimated parameter
        self.V=np.eye(d)*lam  # grammatrix
        self.mnl=RegularizedMNLRegression()  # MLE loss function
        self.alpha=alpha  # confidence radius
        self.inputalpha=alpha

    def choose_S(self,t,x):  # x is N*d matrix
        """
        choose the assortment for a sampled parameter

        Draws a parameter from a normal distribution centred at ``theta``
        with covariance ``alpha**2 * inv(V)`` and returns the K items with
        the largest sampled utility.

        :param t: round index, used for the default alpha
        :param x: N*d matrix of item contexts
        :return: indices of the chosen items; also stored in ``self.S``
        """
        if self.inputalpha is None:
            radius = (1/(2*self.kappa))*np.sqrt(2*self.d*np.log(1+t/self.d)+2*np.log(t))
            self.alpha = radius*0.01
        covariance = np.square(self.alpha)*inv(self.V)
        theta_tilde = np.random.multivariate_normal(self.theta, covariance)
        sampled_utility = np.dot(x,theta_tilde)
        ranking = np.argsort(sampled_utility)[::-1]
        self.S = ranking[:self.K]
        played = x[self.S,:]
        # Record the played contexts and update the gram matrix in place.
        self.X = np.concatenate((self.X, played[np.newaxis, ...]))
        self.V += np.matmul(played.T, played)
        return(self.S)

    def update_theta(self,Y,t):
        """
        update parameter

        :param Y: feedback vector appended to the history
        :param t: when equal to 2, the placeholder first entries of both
            histories are removed before fitting
        """
        self.Y = np.concatenate((self.Y, Y[np.newaxis, ...]))
        if t==2:
            self.X = np.delete(self.X, (0), axis=0)
            self.Y = np.delete(self.Y, (0), axis=0)
        fitted = self.mnl
        fitted.fit(self.X, self.Y, self.theta, self.lam)
        self.theta = fitted.w

class RegularizedMNLRegression:
    """L2-regularized maximum-likelihood estimation for an MNL choice model.

    Each observation consists of the contexts of the offered items and a
    feedback row whose last column refers to the outside option, which has
    utility 0.
    """

    def _utilities(self, theta, x):
        """Return the utilities ``x @ theta`` with a zero column appended for the outside option."""
        means = np.dot(x, theta)
        return np.column_stack((means, np.zeros(means.shape[0])))

    def compute_prob(self, theta, x):
        """Return the MNL choice probabilities (outside option in the last column).

        :param theta: parameter vector
        :param x: array of contexts; the last axis is contracted with ``theta``
        """
        util = self._utilities(theta, x)
        # Subtracting the row maximum leaves the probabilities unchanged and
        # prevents overflow in exp for large utilities.
        util = util - util.max(axis=1, keepdims=True)
        weights = np.exp(util)
        return weights / weights.sum(axis=1, keepdims=True)

    def cost_function(self, theta, x, y, lam):
        """Return the mean negative log-likelihood plus ``lam/(2m) * ||theta||^2``.

        The penalty is the quadratic one whose derivative is the
        ``lam/m * theta`` term returned by ``gradient``.
        """
        m = x.shape[0]
        util = self._utilities(theta, x)
        shift = util.max(axis=1, keepdims=True)
        # log-softmax computed directly, so no log(0) can occur
        log_prob = (util - shift) - np.log(np.exp(util - shift).sum(axis=1, keepdims=True))
        return -(1/m)*np.sum(np.multiply(y, log_prob)) + (1/m)*0.5*lam*np.dot(theta, theta)

    def gradient(self, theta, x, y, lam):
        """Return the gradient of ``cost_function`` with respect to ``theta``.

        The likelihood part assumes every row of ``y`` sums to one.
        """
        m = x.shape[0]
        prob = self.compute_prob(theta, x)
        eps = (prob-y)[:,:-1]  # residuals of the offered items only
        grad = (1/m)*np.tensordot(eps,x,axes=([1,0],[1,0])) + (1/m)*lam*theta
        return grad

    def fit(self, x, y, theta, lam):
        """Minimize ``cost_function`` with ``fmin_tnc`` starting from ``theta``.

        The optimizer's solution is stored in ``self.w``.

        :return: self
        """
        opt_weights = fmin_tnc(func=self.cost_function, x0=theta, fprime=self.gradient, args=(x, y, lam), disp=False)
        self.w = opt_weights[0]
        return self