#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from scipy.optimize import minimize
import torch
from torch import nn
import torch.nn.functional as F
from copy import deepcopy
from backpack import backpack, extend
from backpack.extensions import BatchGrad


#  ################## Linear Dueling Bandit Algorithms ##################
# Linear Dueling Bandit Algorithms with confidence bounds
class LinearConfDB:
    """Linear dueling bandit with confidence-bound based selection of the second arm.

    The learner keeps a linear parameter estimate ``theta`` and a regularized
    Gram matrix ``V`` of the feature differences of the selected arm pairs.
    The first arm maximizes the estimated reward; the second arm is chosen by
    one of the strategies ``'ucb'``, ``'ts'`` or ``'tss'``.

    Args:
        input_dim: Dimension of the context-arm feature vectors.
        lamdba: Regularization parameter (initial scale of ``V``).
        nu: Multiplier of the confidence width.
        strategy: Arm-selection strategy, one of ``'ucb'``, ``'ts'``, ``'tss'``.
        learner_update: Number of samples between two fits of ``theta``
            (values below 1 are treated as 1).
        delta: Confidence level used inside the confidence width.
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after every fit; otherwise it stays constant.
    """

    def __init__(self, input_dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, delta=0.05, increasing_delay=False):
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy
        self.delta          = delta             # confidence parameter

        # Norm of learner's parameter
        self.S = np.sqrt(input_dim)

        # Increment applied to `next_update` after every fit (0 = fixed schedule)
        self.learner_update = learner_update if increasing_delay else 0

        # Remember the initial schedule so that reset() can restore it
        self._initial_next_update = max(1, learner_update)

        # Initial variables for storing information
        self.next_update = self._initial_next_update   # Next update of the learner
        self.samples    = 0                            # Number of samples
        self.Z          = []                           # Context-action feature differences
        self.ZY_sum     = np.zeros(self.input_dim)     # Sum of feature differences * feedback

        # Initialization of the Gram matrix
        self.V      = lamdba * np.identity(self.input_dim)
        self.kappa  = 1.0 / ((1 + np.exp(1.0)) * (1 + np.exp(-1.0)))    # Assuming l2-norm of theta = 1

        # Initializing model parameters
        self.theta  = np.ones(self.input_dim)/self.input_dim

    def select(self, context_arms):
        """Select a pair of arms and add their feature difference to ``V``.

        Args:
            context_arms: 2-D array with one feature vector per arm.

        Returns:
            Tuple ``(at_1, at_2)`` of arm indices.

        Raises:
            RuntimeError: If ``strategy`` is not ``'ucb'``, ``'ts'`` or ``'tss'``.
        """
        # Keeping context-arms for model update
        self.context_arms = context_arms

        if self.strategy not in ('ts', 'ucb', 'tss'):
            raise RuntimeError('Exploration strategy not set')

        # Current estimate of latent reward
        est_rt = context_arms.dot(self.theta)

        # Confidence width; it does not depend on the arm, so compute it once
        sigma       = 0.5                           # From paper
        alpha_fix   = self.nu * (sigma/self.kappa)  # Fixed factor of alpha
        conf_delta  = np.log(1.0/self.delta)        # Confidence term with delta
        log_term    = np.log(1 + ((2.0*(self.samples+1))/self.input_dim))
        alpha_t     = alpha_fix * np.sqrt(((0.5*self.input_dim) * log_term) + conf_delta) + np.sqrt(self.lamdba)*self.S

        # The inverse Gram matrix is shared by all arms of this round
        V_inv = np.linalg.inv(self.V)

        # Selecting the first arm
        at_1 = np.argmax(est_rt)
        at_2 = 1
        max_score = -np.inf

        if self.strategy == 'tss':
            # Following Linear TS (ICML 2023) paper approach
            theta_tilde = np.random.multivariate_normal(self.theta, alpha_t*alpha_t*V_inv)

        for j in range(len(est_rt)):
            # Context-action difference
            zt_j1 = context_arms[j] - context_arms[at_1]

            if self.strategy == 'tss':
                # Score follows Linear Contextual Thompson Sampling paper
                action_score = zt_j1.dot(theta_tilde)
            else:
                # Confidence term (clamped at 0 against round-off)
                zt_dot_V_zt = V_inv.dot(zt_j1).dot(zt_j1)
                conf_term = alpha_t * np.sqrt(max(zt_dot_V_zt, 0))

                if self.strategy == 'ts':
                    # Score based on Thompson Sampling like NeuralTS
                    action_score = np.random.normal(loc=est_rt[j], scale=conf_term)
                else:
                    # Score based on UCB
                    action_score = est_rt[j] + conf_term

            # Selecting the second best arm
            if action_score > max_score:
                max_score = action_score
                at_2 = j

        # Update the confidence matrix
        zt_12 = context_arms[at_1] - context_arms[at_2]
        self.Z.append(zt_12)
        self.V += np.outer(zt_12, zt_12)

        # Keeping the selected arms for model update
        self.at_1 = at_1
        self.at_2 = at_2

        return at_1, at_2

    def update(self, yt):
        """Record the feedback for the last selected pair and refit ``theta`` when due.

        Args:
            yt: Feedback for the pair returned by the last ``select`` call.
        """
        # Accumulating feature difference * feedback
        zt_12 = self.Z[-1]
        self.ZY_sum += zt_12*yt

        self.samples = len(self.Z)
        if self.samples % self.next_update != 0:
            return

        self.next_update += self.learner_update
        Z = np.array(self.Z)
        ZY_sum = self.ZY_sum

        # Negative log-likelihood of the logistic model.
        # logaddexp(0, x) = log(1 + exp(x)) without overflow for large x.
        def negloglik_glm(th):
            return -ZY_sum.dot(th) + np.sum(np.logaddexp(0.0, Z.dot(th)))

        # Gradient of the objective: -sum(z*y) + sum(z * sigmoid(z.th)), vectorized
        def negloglik_glm_grad(th):
            mu = np.exp(-np.logaddexp(0.0, -Z.dot(th)))     # numerically stable sigmoid
            return -ZY_sum + Z.T.dot(mu)

        # Solving the optimization problem
        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04}
        )
        theta = np.array(res['x']).flatten()

        # Normalize to unit norm; keep the previous estimate if the solution is degenerate
        norm = np.linalg.norm(theta)
        if np.isfinite(norm) and norm > 0:
            self.theta = theta/norm

    def reset(self):
        """Reset the learner to its initial state, including the update schedule."""
        self.samples        = 0
        self.Z              = []
        self.ZY_sum         = np.zeros(self.input_dim)
        self.V              = self.lamdba * np.identity(self.input_dim)
        self.theta          = np.ones(self.input_dim)/self.input_dim
        self.next_update    = self._initial_next_update


# Linear Dueling Bandit Algorithms with confidence bounds like Neural Bandits
class LinearDB:
    """Linear dueling bandit with a NeuralUCB/NeuralTS-style confidence width.

    The first arm maximizes the estimated reward. The second arm maximizes
    either an upper confidence bound (``'ucb'``) or a Gaussian sample around
    the estimate (``'ts'``), where the width is ``nu * sqrt(z^T V^-1 z)`` for
    the feature difference ``z`` to the first arm.

    Args:
        input_dim: Dimension of the context-arm feature vectors.
        lamdba: Regularization parameter (initial scale of ``V``).
        nu: Multiplier of the confidence width.
        strategy: Arm-selection strategy, ``'ucb'`` or ``'ts'``.
        learner_update: Number of samples between two fits of ``theta``
            (values below 1 are treated as 1).
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after every fit; otherwise it stays constant.
    """

    def __init__(self, input_dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False):
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy

        # Increment applied to `next_update` after every fit (0 = fixed schedule)
        self.learner_update = learner_update if increasing_delay else 0

        # Remember the initial schedule so that reset() can restore it
        self._initial_next_update = max(1, learner_update)

        # Initial variables for storing information
        self.next_update = self._initial_next_update   # Next update of the learner
        self.samples    = 0                            # Number of samples
        self.Z          = []                           # Context-action feature differences
        self.ZY_sum     = np.zeros(self.input_dim)     # Sum of feature differences * feedback

        # Initialization of the Gram matrix
        self.V      = lamdba * np.identity(self.input_dim)

        # Initializing model parameters
        self.theta  = np.ones(self.input_dim)/self.input_dim

    def select(self, context_arms):
        """Select a pair of arms and add their feature difference to ``V``.

        Args:
            context_arms: 2-D array with one feature vector per arm.

        Returns:
            Tuple ``(at_1, at_2)`` of arm indices.

        Raises:
            RuntimeError: If ``strategy`` is neither ``'ucb'`` nor ``'ts'``.
        """
        # Keeping context-arms for model update
        self.context_arms = context_arms

        if self.strategy not in ('ts', 'ucb'):
            raise RuntimeError('Exploration strategy not set')

        # Current estimate of latent reward
        est_rt = context_arms.dot(self.theta)

        # The inverse Gram matrix is shared by all arms of this round
        V_inv = np.linalg.inv(self.V)

        # Selecting the first arm
        at_1 = np.argmax(est_rt)
        at_2 = 1
        max_score = -np.inf

        for j in range(len(est_rt)):
            # Confidence width of arm j relative to the first arm (clamped against round-off)
            zt_j1 = context_arms[j] - context_arms[at_1]
            zt_dot_V_zt = V_inv.dot(zt_j1).dot(zt_j1)
            conf_term = self.nu * np.sqrt(max(zt_dot_V_zt, 0))

            # Scoring the arm based on the strategy
            if self.strategy == 'ts':
                action_score = np.random.normal(loc=est_rt[j], scale=conf_term)
            else:
                action_score = est_rt[j] + conf_term

            # Selecting the second best arm
            if action_score > max_score:
                max_score = action_score
                at_2 = j

        # Update the confidence matrix
        zt_12 = context_arms[at_1] - context_arms[at_2]
        self.Z.append(zt_12)
        self.V += np.outer(zt_12, zt_12)

        # Keeping the selected arms for model update
        self.at_1 = at_1
        self.at_2 = at_2

        return at_1, at_2

    def update(self, yt):
        """Record the feedback for the last selected pair and refit ``theta`` when due.

        Args:
            yt: Feedback for the pair returned by the last ``select`` call.
        """
        # Accumulating feature difference * feedback
        zt_12 = self.Z[-1]
        self.ZY_sum += zt_12*yt

        self.samples = len(self.Z)
        # NOTE(review): this class triggers on (samples + 1) while the sibling
        # learners in this file trigger on samples -- confirm which is intended.
        if (self.samples + 1) % self.next_update != 0:
            return

        self.next_update += self.learner_update
        Z = np.array(self.Z)
        ZY_sum = self.ZY_sum

        # Negative log-likelihood of the logistic model.
        # logaddexp(0, x) = log(1 + exp(x)) without overflow for large x.
        def negloglik_glm(th):
            return -ZY_sum.dot(th) + np.sum(np.logaddexp(0.0, Z.dot(th)))

        # Gradient of the objective: -sum(z*y) + sum(z * sigmoid(z.th)), vectorized
        def negloglik_glm_grad(th):
            mu = np.exp(-np.logaddexp(0.0, -Z.dot(th)))     # numerically stable sigmoid
            return -ZY_sum + Z.T.dot(mu)

        # Solving the optimization problem
        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04}
        )
        theta = np.array(res['x']).flatten()

        # Normalize to unit norm; keep the previous estimate if the solution is degenerate
        norm = np.linalg.norm(theta)
        if np.isfinite(norm) and norm > 0:
            self.theta = theta/norm

    def reset(self):
        """Reset the learner to its initial state, including the update schedule."""
        self.samples        = 0
        self.Z              = []
        self.ZY_sum         = np.zeros(self.input_dim)
        self.V              = self.lamdba * np.identity(self.input_dim)
        self.theta          = np.ones(self.input_dim)/self.input_dim
        self.next_update    = self._initial_next_update


# Linear Dueling Bandit Algorithm: CoLSTIM
class CoLSTIM:
    """Linear dueling bandit following the CoLSTIM selection rule.

    The first arm maximizes the estimated reward plus a clipped Gumbel
    perturbation scaled by the arm's confidence width. The second arm
    maximizes the estimated advantage over the first arm plus a confidence
    bonus on the feature difference.

    Args:
        input_dim: Dimension of the context-arm feature vectors.
        lamdba: Regularization parameter (initial scale of ``V``).
        nu: Stored on the instance; not used by the selection rule below.
        size: Total number of rounds; enters the clipping threshold and the
            perturbation probability.
        learner_update: Number of samples between two fits of ``theta``
            (values below 1 are treated as 1).
        delta: Stored on the instance; not used by the selection rule below.
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after every fit; otherwise it stays constant.
    """

    def __init__(self, input_dim, lamdba=1, nu=1, size=100, learner_update=20, delta=0.05, increasing_delay=False):
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.size           = size              # Total number of rounds
        self.delta          = delta             # confidence parameter

        # Increment applied to `next_update` after every fit (0 = fixed schedule)
        self.learner_update = learner_update if increasing_delay else 0

        # Remember the initial schedule so that reset() can restore it
        self._initial_next_update = max(1, learner_update)

        # Initial variables for storing information
        self.next_update = self._initial_next_update   # Next update of the learner
        self.samples    = 0                            # Number of samples
        self.Z          = []                           # Context-action feature differences
        self.ZY_sum     = np.zeros(self.input_dim)     # Sum of feature differences * feedback

        # Initialization of the Gram matrix
        self.V      = lamdba * np.identity(self.input_dim)

        # Initializing model parameters
        self.theta  = np.ones(self.input_dim)/self.input_dim

    def select(self, context_arms):
        """Select a pair of arms and add their feature difference to ``V``.

        Args:
            context_arms: 2-D array with one feature vector per arm.

        Returns:
            Tuple ``(at_1, at_2)`` of arm indices.
        """
        # Keeping context-arms for model update
        self.context_arms = context_arms

        # Current estimate of latent reward
        est_rt = context_arms.dot(self.theta)

        # Fixed terms in confidence bound
        C_thresh    = np.sqrt(self.input_dim * np.log(self.size))       # Clipping threshold of the perturbation
        c1          = C_thresh                                          # From paper
        fix_pt_term = self.input_dim*np.log(self.input_dim*self.size)   # Numerator of the perturbation probability

        # With probability pt every arm gets its own perturbation,
        # otherwise all arms share the perturbation of arm 0
        pt = min(1.0, fix_pt_term/(np.sqrt(self.samples+1)))
        Bt = np.random.binomial(1, pt)
        tilde_eps_t = np.random.gumbel(size=(len(context_arms),))

        if Bt == 0:
            tilde_eps_t[:] = tilde_eps_t[0]

        # The inverse Gram matrix is shared by both selection loops
        V_inv = np.linalg.inv(self.V)

        # Selecting the first arm
        max_ucb1 = -np.inf
        at_1 = 0
        for i in range(len(context_arms)):
            eps_ti = min(C_thresh, max(-C_thresh, tilde_eps_t[i]))
            # Quadratic form clamped at 0 against round-off
            width = np.sqrt(max(V_inv.dot(context_arms[i]).dot(context_arms[i]), 0))
            ucb1_val = est_rt[i] + eps_ti * width
            if ucb1_val > max_ucb1:
                max_ucb1 = ucb1_val
                at_1 = i

        # Selecting the second arm
        at_2 = 1
        max_score = -np.inf

        for j in range(len(est_rt)):
            # Context-action difference
            zt_j1 = context_arms[j] - context_arms[at_1]

            # Confidence term
            conf_term2 = c1 * np.sqrt(max(V_inv.dot(zt_j1).dot(zt_j1), 0))
            action_score = zt_j1.dot(self.theta) + conf_term2

            # Selecting the second best arm
            if action_score > max_score:
                max_score = action_score
                at_2 = j

        # Update the confidence matrix
        zt_12 = context_arms[at_1] - context_arms[at_2]
        self.Z.append(zt_12)
        self.V += np.outer(zt_12, zt_12)

        # Keeping the selected arms for model update
        self.at_1 = at_1
        self.at_2 = at_2

        return at_1, at_2

    def update(self, yt):
        """Record the feedback for the last selected pair and refit ``theta`` when due.

        Args:
            yt: Feedback for the pair returned by the last ``select`` call.
        """
        # Accumulating feature difference * feedback
        zt_12 = self.Z[-1]
        self.ZY_sum += zt_12*yt

        self.samples = len(self.Z)
        if self.samples % self.next_update != 0:
            return

        self.next_update += self.learner_update
        Z = np.array(self.Z)
        ZY_sum = self.ZY_sum

        # Negative log-likelihood of the logistic model.
        # logaddexp(0, x) = log(1 + exp(x)) without overflow for large x.
        def negloglik_glm(th):
            return -ZY_sum.dot(th) + np.sum(np.logaddexp(0.0, Z.dot(th)))

        # Gradient of the objective: -sum(z*y) + sum(z * sigmoid(z.th)), vectorized
        def negloglik_glm_grad(th):
            mu = np.exp(-np.logaddexp(0.0, -Z.dot(th)))     # numerically stable sigmoid
            return -ZY_sum + Z.T.dot(mu)

        # Solving the optimization problem
        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04}
        )
        theta = np.array(res['x']).flatten()

        # Normalize to unit norm; keep the previous estimate if the solution is degenerate
        norm = np.linalg.norm(theta)
        if np.isfinite(norm) and norm > 0:
            self.theta = theta/norm

    def reset(self):
        """Reset the learner to its initial state, including the update schedule."""
        self.samples        = 0
        self.Z              = []
        self.ZY_sum         = np.zeros(self.input_dim)
        self.V              = self.lamdba * np.identity(self.input_dim)
        self.theta          = np.ones(self.input_dim)/self.input_dim
        self.next_update    = self._initial_next_update


#  ################## Neural Dueling Bandit Algorithms ##################
# Pytorch keyword arguments
tkwargs = {
    # Use the first CUDA device when one is available and fall back to the CPU
    # otherwise, so that `.to(**tkwargs)` also works on machines without a GPU.
    # Other option: "mps" [For Apple M2 chips]
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.float32,
}
 
    
# Base Neural Network class
class Network(nn.Module):
    """Fully connected ReLU network with a single scalar output.

    The network consists of an input layer, ``depth - 1`` hidden-to-hidden
    layers and a linear output layer of width 1.

    Args:
        input_dim: Dimension of the input features.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden layers (values below 1 behave like 1).
        init_params: Optional flat sequence of tensors
            ``[w_0, b_0, w_1, b_1, ...]`` used as initial weights and biases.
            The tensors are copied, so training this network never modifies
            the sequence passed in. If ``None``, all weights and biases are
            drawn from a standard normal distribution.
    """

    def __init__(self, input_dim, hidden_size=32, depth=2, init_params=None):
        # Calling the parent class constructor
        super().__init__()

        # Activation function
        self.activate = nn.ReLU()

        # Neural network architecture
        self.layer_list = nn.ModuleList()
        self.layer_list.append(nn.Linear(input_dim, hidden_size))
        for _ in range(depth-1):
            self.layer_list.append(nn.Linear(hidden_size, hidden_size))
        self.layer_list.append(nn.Linear(hidden_size, 1))

        # Same NN initialization to maintain consistency across all runs
        if init_params is None:
            # Initialization using normal distribution
            for layer in self.layer_list:
                torch.nn.init.normal_(layer.weight, mean=0, std=1.0)
                torch.nn.init.normal_(layer.bias, mean=0, std=1.0)
        else:
            # Use given initialization vector. Clone the tensors: assigning them
            # directly would make the parameters share storage with the caller's
            # tensors, and in-place optimizer steps would then overwrite the
            # shared initialization.
            for i, layer in enumerate(self.layer_list):
                layer.weight.data = init_params[i*2].detach().clone()
                layer.bias.data = init_params[i*2+1].detach().clone()

    def forward(self, x):
        """Return the network output for input ``x`` (last dimension of size 1)."""
        y = x

        # Hidden layers with activation
        for layer in self.layer_list[:-1]:
            y = self.activate(layer(y))

        # Linear output layer
        return self.layer_list[-1](y)


# ### Neural Dueling Bandit with diagonalization ###
class NeuralInitDB:
    """Neural dueling bandit whose exploration features come from the initial network.

    A network ``func`` is trained on pairwise feedback to estimate rewards.
    The confidence width of an arm is computed from the gradient of a frozen
    copy of the initial network (``init_func``) with respect to its
    parameters, using either a full Gram matrix or only its diagonal.

    Args:
        input_dim: Dimension of the context-arm feature vectors.
        lamdba: Regularization parameter (initial scale of ``V`` and numerator
            of the weight decay used during training).
        nu: Multiplier of the confidence term (applied inside the square root).
        strategy: Arm-selection strategy, ``'ucb'`` or ``'ts'``.
        diagonalize: If true, keep only the diagonal of the Gram matrix.
        learner_update: Number of samples between two trainings of ``func``
            (values below 1 are treated as 1).
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after every training; otherwise it stays constant.
    """

    def __init__(self, input_dim, lamdba=1, nu=1, strategy='ucb', diagonalize=False, learner_update=20, increasing_delay=False):
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy
        self.diagonalize    = diagonalize       # diagonalization of confidence matrix if true

        # Increment applied to `next_update` after every training (0 = fixed schedule)
        self.learner_update = learner_update if increasing_delay else 0

        # Remember the initial schedule so that reset() can restore it
        self._initial_next_update = max(1, learner_update)

        # Initial variables for storing information
        self.next_update            = self._initial_next_update   # Next update of the learner
        self.samples                = 0       # number of samples
        self.context_actions_list   = None    # context-action pairs, shape (2, samples, dim)
        self.feedback_list          = None    # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.input_dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Initial NN model for feature extraction
        self.init_func = deepcopy(self.func)

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Initialization of the Gram matrix
        self.V = self._initial_gram()

    def _initial_gram(self):
        """Return the initial Gram matrix (or its diagonal when ``diagonalize`` is set)."""
        if self.diagonalize:
            return self.lamdba * torch.ones((self.total_param,))
        return self.lamdba * torch.diag(torch.ones((self.total_param,)))

    def _gradient_features(self, context_actions):
        """Return one flattened parameter gradient of ``init_func`` per arm.

        The gradients are moved to the device of ``V`` so that every operation
        involving the Gram matrix sees tensors on a single device.
        """
        grad_list = []
        for a in range(len(context_actions)):
            # Reward for using the initial NN model
            rt_init_a = self.init_func(context_actions[a])

            # Zeroing the gradients before the backpropagation step; each
            # forward pass builds a fresh graph, so it need not be retained
            self.init_func.zero_grad()
            rt_init_a.backward()

            grad_rt_init_a = torch.cat([p.grad.flatten().detach() for p in self.init_func.parameters()])
            grad_list.append(grad_rt_init_a.to(self.V.device))
        return grad_list

    def select(self, context_actions):
        """Select a pair of arms and add their feature difference to ``V``.

        Args:
            context_actions: 2-D numpy array with one feature vector per arm.

        Returns:
            Tuple ``(at_1, at_2)`` of arm indices.

        Raises:
            RuntimeError: If ``strategy`` is neither ``'ucb'`` nor ``'ts'``.
        """
        # Keeping context-arms for model update
        self.context_arms = context_actions

        if self.strategy not in ('ts', 'ucb'):
            raise RuntimeError('Exploration strategy not set')

        # Changing context-actions to tensor
        context_actions = torch.from_numpy(context_actions).float().to(**tkwargs)

        # Calculating the feature vectors for each context-action pair
        grad_list = self._gradient_features(context_actions)

        # Current estimate of latent reward (no graph needed for selection)
        with torch.no_grad():
            est_rt = self.func(context_actions).reshape(-1).to(self.V.device)

        # The inverse Gram matrix is shared by all arms of this round
        V_inv = None if self.diagonalize else torch.inverse(self.V)

        # Selecting the first arm
        at_1 = torch.argmax(est_rt).item()
        at_2 = 1
        max_score = -np.inf
        for j in range(len(est_rt)):
            # Difference of features of the selected arm and the current arm
            zt_j1 = grad_list[j] - grad_list[at_1]

            if self.diagonalize:
                ### diagonalization
                conf_term = torch.clamp(self.nu * zt_j1 * zt_j1 / self.V, min=0)
                sigma = torch.sqrt(torch.sum(conf_term))
            else:
                ### no diagonalization
                zt_dot_U_zt = torch.dot(torch.matmul(zt_j1, V_inv), zt_j1)
                conf_term = torch.clamp(self.nu * zt_dot_U_zt, min=0)
                sigma = torch.sqrt(conf_term)

            # Scoring the arm based on the strategy (always a Python float)
            if self.strategy == 'ts':
                action_score = torch.normal(est_rt[j].reshape(-1), sigma.reshape(-1)).item()
            else:
                action_score = est_rt[j].item() + sigma.item()

            # Selecting the second best arm
            if action_score > max_score:
                max_score = action_score
                at_2 = j

        # Update the confidence matrix
        zt_12 = grad_list[at_1] - grad_list[at_2]

        if self.diagonalize:
            ### diagonalization
            self.V += zt_12 * zt_12
        else:
            ### no diagonalization
            self.V += torch.outer(zt_12, zt_12)

        # Keeping the selected arms for model update
        self.at_1 = at_1
        self.at_2 = at_2

        return at_1, at_2

    def update(self, yt, local_training_iter=50):
        """Store the feedback for the last selected pair and retrain ``func`` when due.

        Training always restarts from the stored initial weights. Between two
        trainings the previously trained weights stay in place, so ``select``
        keeps using the latest trained model.

        Args:
            yt: Feedback for the pair returned by the last ``select`` call.
            local_training_iter: Number of gradient steps per training.
        """
        # Converting numpy variables to tensors of shape (1, 1, dim)
        xt_1 = self.context_arms[self.at_1]
        xt_2 = self.context_arms[self.at_2]
        xt_1_tensor = torch.from_numpy(xt_1).reshape(1, 1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(xt_2).reshape(1, 1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor, xt_2_tensor])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.context_actions_list is None:
            # Adding the first context-action pair
            self.context_actions_list = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.context_actions_list = torch.cat((self.context_actions_list, xt_pair), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        self.samples = self.context_actions_list.shape[1]
        if self.samples % self.next_update != 0:
            return

        self.next_update += self.learner_update

        # Ensuring the same initial state of the NN model for every training.
        # This is done only when training, otherwise the trained weights would
        # be discarded on every round without a scheduled update.
        if self.init_state_dict is not None:
            self.func.load_state_dict(deepcopy(self.init_state_dict))

        optimizer = torch.optim.Adam(self.func.parameters(), lr=1e-1, weight_decay=self.lamdba/(self.samples+50))
        self.func.train()

        # Training data is constant across the local iterations
        x_1 = self.context_actions_list[0].reshape(self.samples, -1)
        x_2 = self.context_actions_list[1].reshape(self.samples, -1)
        feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)

        for _ in range(local_training_iter):
            optimizer.zero_grad()
            logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
            loss = F.binary_cross_entropy_with_logits(logits, feedback)
            loss.backward()
            optimizer.step()

    def reset(self):
        """Reset the learner to its initial state, including the update schedule."""
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.context_actions_list = None
        self.feedback_list = None
        self.next_update = self._initial_next_update
        self.V = self._initial_gram()
            

# ### Neural Dueling Bandit using updated function for features ###
class NeuralDB:
    """Neural dueling bandit that uses gradients of the current network as features.

    Per round: `select` picks a pair of arms, `update` stores that pair with
    its feedback and periodically retrains the network, and `reset` clears the
    collected data and restores the initial weights.

    `Network` and `tkwargs` are expected to be defined at module level.
    """

    def __init__(self, input_dim, lamdba=1, nu=1, strategy='ucb', diagonalize=False, learner_update=20, increasing_delay=False, hidden_size=50, local_training_iter=50):
        """Build the network and the empty learner state.

        Args:
            input_dim: dimension of one context-action vector.
            lamdba: regularization parameter (also scales the initial V).
            nu: multiplier of the confidence width.
            strategy: 'ucb' or 'ts'.
            diagonalize: keep only the diagonal of the confidence matrix.
            learner_update: number of samples before the first retraining.
            increasing_delay: if true, `next_update` grows by `learner_update`
                after every retraining; otherwise it stays fixed.
            hidden_size: forwarded to `Network`.
            local_training_iter: gradient steps per retraining.
        """
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy
        self.diagonalize    = diagonalize       # diagonalization of confidence matrix if true

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0
        self.local_training_iter = local_training_iter

        # Initial variables for storing information
        self._initial_next_update   = max(1, learner_update)        # kept so reset() can restore the schedule
        self.next_update            = self._initial_next_update     # next update of the learner
        self.samples                = 0       # number of samples
        self.context_actions_list   = None    # context-action pairs
        self.feedback_list          = None    # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.input_dim, hidden_size=hidden_size).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Initial NN model for feature extraction
        self.init_func = deepcopy(self.func)

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Initialization of the confidence matrix
        self.V = self._prior_confidence()

    def _prior_confidence(self):
        """Return lamdba * identity, stored as a vector when diagonalized."""
        ones = torch.ones((self.total_param,))
        if self.diagonalize:
            return self.lamdba * ones
        return self.lamdba * torch.diag(ones)

    def _gradient_features(self, inputs, batch=500):
        """Return the gradient of the network output for every leading index of `inputs`.

        Inputs are processed in slices of `batch`; each row of the result is
        the flattened, concatenated `grad_batch` of all parameters, on CPU.
        """
        chunks = []
        # Slicing already covers a shorter final slice, so no tail pass is needed
        for start in range(0, len(inputs), batch):
            total = torch.sum(self.func(inputs[start:start + batch]))
            with backpack(BatchGrad()):
                total.backward()
            grads = torch.cat(
                [p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()],
                dim=1,
            )
            chunks.append(grads.cpu())
        return torch.vstack(chunks)

    # Selecting a pair of arms
    def select(self, context_actions):
        """Choose two arms for the current round.

        Args:
            context_actions: numpy array with one context-action vector per arm.

        Returns:
            (at_1, at_2): `at_1` maximises the estimated reward; `at_2`
            maximises the exploration score computed from the gradient
            difference to `at_1`.

        Raises:
            RuntimeError: if `strategy` is neither 'ucb' nor 'ts'.
        """
        # Keeping context-arms for model update
        self.context_arms = context_actions
        self.func.train()

        # Rebuilding the confidence matrix from the stored observations
        if self.context_actions_list is not None:
            # NOTE(review): update() stacks the history as (2, samples, dim), so
            # the leading index here has length 2 and the network receives a
            # 3-D input -- confirm this is the intended feature construction.
            observed = self.context_actions_list.to(**tkwargs)
            hist_grads = self._gradient_features(observed)
            if self.diagonalize:
                # Diagonal of G^T G + lamdba * I, kept as a vector so that the
                # element-wise division below is well defined
                self.V = torch.sum(hist_grads * hist_grads, dim=0) + self.lamdba * torch.ones((self.total_param,))
            else:
                self.V = hist_grads.transpose(0, 1).matmul(hist_grads) + self.lamdba * torch.eye(self.total_param)
        else:
            self.V = self._prior_confidence()

        # Getting the feature vectors for the context-actions of the current round
        context_actions = torch.from_numpy(self.context_arms).float().to(**tkwargs)
        grad_list = self._gradient_features(context_actions)

        # ### Selecting the arms ###
        # Current estimate of latent reward
        est_rt = self.func(context_actions)

        # Selecting the first arm
        at_1 = torch.argmax(est_rt).item()
        at_2 = 1
        max_score = -np.inf

        # The inverse does not depend on the arm, so compute it once
        V_inv = None if self.diagonalize else torch.inverse(self.V)

        for j in range(len(est_rt)):
            # Difference of features of the current arm and the first arm
            zt_j1 = grad_list[j] - grad_list[at_1]

            if self.diagonalize:
                ### diagonalization
                conf_term = torch.clamp(zt_j1 * zt_j1 / self.V, min=0)
                sigma = self.nu * torch.sqrt(torch.sum(conf_term))

                # Fall back to a small width if sigma is nan (e.g. lamdba = 0)
                base_sigma = torch.tensor(0.1)
                sigma = base_sigma if torch.isnan(sigma) else sigma

            else:
                ### no diagonalization
                zt_j1 = zt_j1.to("cpu")
                zt_dot_U = torch.matmul(zt_j1, V_inv)
                zt_dot_U_zt = torch.matmul(zt_dot_U, zt_j1.t())
                conf_term = torch.clamp(zt_dot_U_zt, min=0)
                sigma = self.nu * torch.sqrt(conf_term)

            # Scoring the arm based on the strategy
            est_rt_j1 = est_rt[j]
            if self.strategy == 'ts':
                est_rt_j1 = est_rt_j1.detach().to("cpu")
                action_score = torch.normal(est_rt_j1.view(-1), sigma.view(-1)).item()

            elif self.strategy == 'ucb':
                action_score = est_rt_j1.item() + sigma.item()

            else:
                raise RuntimeError('Exploration strategy not set')

            # Keeping the best-scoring arm as the second arm
            if action_score > max_score:
                max_score = action_score
                at_2 = j

        # Update the confidence matrix with the selected pair
        zt_12 = grad_list[at_1] - grad_list[at_2]

        if self.diagonalize:
            ### diagonalization
            self.V += zt_12 * zt_12
        else:
            ### no diagonalization
            self.V += torch.outer(zt_12, zt_12)

        # Keeping the selected arms for model update
        self.at_1 = at_1
        self.at_2 = at_2

        return at_1, at_2

    # Update the model with new context-action pair and feedback
    def update(self, yt):
        """Store the last selected pair with its feedback and retrain when due.

        Args:
            yt: binary-cross-entropy target for the logit
                score(first arm) - score(second arm).
        """
        # Converting numpy variables to tensors
        xt_1 = self.context_arms[self.at_1]
        xt_2 = self.context_arms[self.at_2]
        xt_1_tensor = torch.from_numpy(xt_1).reshape(1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(xt_2).reshape(1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor.reshape(1, 1, -1), xt_2_tensor.reshape(1, 1, -1)])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.context_actions_list is None:
            # Adding the first context-action pair
            self.context_actions_list = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.context_actions_list = torch.cat((self.context_actions_list, xt_pair.reshape(2, 1, -1)), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        self.samples = self.context_actions_list.shape[1]
        self.func.train()

        # Nothing more to do unless a retraining is scheduled for this sample count
        if self.samples % self.next_update != 0:
            return
        self.next_update += self.learner_update

        # Retrain from the stored initial state. The reload happens only here so
        # that rounds without retraining keep the trained weights instead of
        # silently reverting to the initial network.
        if self.init_state_dict is not None:
            self.func.load_state_dict(deepcopy(self.init_state_dict))
        optimizer = torch.optim.Adam(self.func.parameters(), lr=1e-1, weight_decay=self.lamdba/(self.samples+50))

        x_1 = self.context_actions_list[0].reshape(self.samples, -1)
        x_2 = self.context_actions_list[1].reshape(self.samples, -1)
        feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
        for _ in range(self.local_training_iter):
            self.func.zero_grad()
            optimizer.zero_grad()
            score_1 = self.func(x_1)
            score_2 = self.func(x_2)
            logits = (score_1 - score_2).reshape(-1)    # Logits as difference of scores
            loss = F.binary_cross_entropy_with_logits(logits, feedback)
            loss.backward()
            optimizer.step()

    # Reset the learner
    def reset(self):
        """Restore the initial weights, clear the data and restart the update schedule."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.context_actions_list = None
        self.feedback_list = None

        # With an increasing delay next_update has grown; start over
        self.next_update = self._initial_next_update

        self.V = self._prior_confidence()
            


#  ################## GLM Bandit Algorithms ##################
# Linear GLM Bandit Algorithms with confidence bounds like Neural Bandits
class LinearGLM:
    """Linear logistic (GLM) bandit with UCB or Thompson-sampling exploration.

    `select` picks one arm and adds it to the gram matrix, `update` records
    the feedback for that arm and periodically refits `theta` by maximum
    likelihood, and `reset` returns the learner to its initial state.
    """

    def __init__(self, input_dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False):
        """Create the learner.

        Args:
            input_dim: dimension of one context-action vector.
            lamdba: regularization parameter (scales the initial gram matrix).
            nu: multiplier of the confidence width.
            strategy: 'ucb' or 'ts'.
            learner_update: period used to schedule the first refit.
            increasing_delay: if true, `next_update` grows by `learner_update`
                after every refit; otherwise it stays fixed.
        """
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0

        # Initial variables for storing information
        self._initial_next_update = max(1, learner_update)  # kept so reset() can restore the schedule
        self.next_update = self._initial_next_update        # Next update of the learner
        self.samples    = 0                         # Number of samples
        self.X          = []                        # Context-actions feature vectors
        self.XY_sum     = np.zeros(self.input_dim)  # Sum of Context-actions feature vectors * Y

        # Initialization of gram matrix
        self.V      = lamdba * np.identity(self.input_dim)

        # Initializing model parameters
        self.theta  = np.ones(self.input_dim)/self.input_dim

    # Selecting an arm
    def select(self, context_arms):
        """Return the index of the arm with the highest exploration score.

        Args:
            context_arms: numpy array with one feature vector per arm.

        Raises:
            RuntimeError: if `strategy` is neither 'ucb' nor 'ts'.
        """
        # Keeping context-arms for model update
        self.context_arms = context_arms

        # Current estimate of latent reward
        est_rt = context_arms.dot(self.theta)

        # The inverse does not depend on the arm, so compute it once
        V_inv = np.linalg.inv(self.V)

        a_t = 0
        max_score = -np.inf
        for a in range(len(est_rt)):
            xt_a = context_arms[a]
            xt_dot_V = np.inner(xt_a, V_inv)
            xt_dot_V_xt = np.inner(xt_dot_V, xt_a)
            # max(..., 0) guards against a slightly negative quadratic form
            conf_term = self.nu * np.sqrt(max(xt_dot_V_xt, 0))

            # Scoring the arm based on the strategy
            est_rt_a = est_rt[a]
            if self.strategy == 'ts':
                action_score = np.random.normal(loc=est_rt_a, scale=conf_term)

            elif self.strategy == 'ucb':
                action_score = est_rt_a + conf_term

            else:
                raise RuntimeError('Exploration strategy not set')

            # Keeping the best-scoring arm (first one wins ties)
            if action_score > max_score:
                max_score = action_score
                a_t = a

        # Update the confidence matrix with the selected arm
        xt_a = context_arms[a_t]
        self.X.append(xt_a)
        self.V += np.outer(xt_a, xt_a)

        # Keeping the selected arm for model update
        self.a_t = a_t

        return a_t

    # Update the model with new context-action pair and feedback
    def update(self, yt):
        """Record feedback `yt` for the last selected arm and refit `theta` when due."""
        # Accumulating feature * feedback for the most recently selected arm
        xt_a = self.X[-1]
        self.XY_sum += xt_a*yt

        self.samples = len(self.X)
        if (self.samples + 1) % self.next_update != 0:
            return
        self.next_update += self.learner_update
        X = np.array(self.X)

        # Negative log-likelihood of the logistic model; logaddexp(0, z) is an
        # overflow-safe log(1 + exp(z))
        def negloglik_glm(th):
            return -self.XY_sum.dot(th) + np.sum(np.logaddexp(0.0, X.dot(th)))

        # Gradient of the objective: X^T sigmoid(X th) - sum_t x_t y_t, with
        # sigmoid(z) = exp(-log(1 + exp(-z))) evaluated without overflow
        def negloglik_glm_grad(th):
            mu = np.exp(-np.logaddexp(0.0, -X.dot(th)))
            return -self.XY_sum + X.T.dot(mu)

        # Solving the optimization problem
        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04}
        )
        theta = np.array(res['x']).flatten()

        # Keep the previous estimate if the solution has zero norm, since
        # normalising it would produce NaNs
        norm = np.linalg.norm(theta)
        if norm > 0:
            self.theta = theta/norm

    # Reset the model
    def reset(self):
        """Clear the collected data and restore the initial parameters and schedule."""
        self.samples    = 0
        self.X          = []
        self.XY_sum     = np.zeros(self.input_dim)
        self.V          = self.lamdba * np.identity(self.input_dim)
        self.theta      = np.ones(self.input_dim)/self.input_dim
        # With an increasing delay next_update has grown; start over
        self.next_update = self._initial_next_update


# ### Neural GLM Bandit using updated function for features ###
class NeuralGLM:
    """Neural GLM bandit that uses gradients of the current network as features.

    Per round: `select` picks one arm, `update` stores that arm with its
    feedback and periodically retrains the network, and `reset` clears the
    collected data and restores the initial weights.

    `Network` and `tkwargs` are expected to be defined at module level.
    """

    def __init__(self, input_dim, lamdba=1, nu=1, strategy='ucb', diagonalize=False, learner_update=20, increasing_delay=False, hidden_size=50, local_training_iter=50):
        """Build the network and the empty learner state.

        Args:
            input_dim: dimension of one context-action vector.
            lamdba: regularization parameter (also scales the initial V).
            nu: multiplier of the confidence width.
            strategy: 'ucb' or 'ts'.
            diagonalize: keep only the diagonal of the confidence matrix.
            learner_update: number of samples before the first retraining.
            increasing_delay: if true, `next_update` grows by `learner_update`
                after every retraining; otherwise it stays fixed.
            hidden_size: forwarded to `Network`.
            local_training_iter: gradient steps per retraining.
        """
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy
        self.diagonalize    = diagonalize       # diagonalization of confidence matrix if true

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0
        self.local_training_iter = local_training_iter

        # Initial variables for storing information
        self._initial_next_update   = max(1, learner_update)        # kept so reset() can restore the schedule
        self.next_update            = self._initial_next_update     # next update of the learner
        self.samples                = 0       # number of samples
        self.context_actions_list   = None    # observed context-actions
        self.feedback_list          = None    # feedback for observed context-actions

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.input_dim, hidden_size=hidden_size).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Initial NN model for feature extraction
        self.init_func = deepcopy(self.func)

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Initialization of the confidence matrix
        self.V = self._prior_confidence()

    def _prior_confidence(self):
        """Return lamdba * identity, stored as a vector when diagonalized."""
        ones = torch.ones((self.total_param,))
        if self.diagonalize:
            return self.lamdba * ones
        return self.lamdba * torch.diag(ones)

    def _gradient_features(self, inputs, batch=500):
        """Return the gradient of the network output for every row of `inputs`.

        Inputs are processed in slices of `batch`; each row of the result is
        the flattened, concatenated `grad_batch` of all parameters, on CPU.
        """
        chunks = []
        # Slicing already covers a shorter final slice, so no tail pass is needed
        for start in range(0, len(inputs), batch):
            total = torch.sum(self.func(inputs[start:start + batch]))
            with backpack(BatchGrad()):
                total.backward()
            grads = torch.cat(
                [p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()],
                dim=1,
            )
            chunks.append(grads.cpu())
        return torch.vstack(chunks)

    # Selecting an arm
    def select(self, context_actions):
        """Return the index of the arm with the highest exploration score.

        Args:
            context_actions: numpy array with one context-action vector per arm.

        Raises:
            RuntimeError: if `strategy` is neither 'ucb' nor 'ts'.
        """
        # Keeping context-arms for model update
        self.context_arms = context_actions
        self.func.train()

        # Rebuilding the confidence matrix from the stored observations
        if self.context_actions_list is not None:
            observed = self.context_actions_list.to(**tkwargs)
            hist_grads = self._gradient_features(observed)
            if self.diagonalize:
                # Diagonal of G^T G + lamdba * I, kept as a vector so that the
                # element-wise division below is well defined
                self.V = torch.sum(hist_grads * hist_grads, dim=0) + self.lamdba * torch.ones((self.total_param,))
            else:
                self.V = hist_grads.transpose(0, 1).matmul(hist_grads) + self.lamdba * torch.eye(self.total_param)
        else:
            self.V = self._prior_confidence()

        # Getting the feature vectors for the context-actions of the current round
        context_actions = torch.from_numpy(self.context_arms).float().to(**tkwargs)
        grad_list = self._gradient_features(context_actions)

        # ### Selecting the arm ###
        # Current estimate of latent reward
        est_rt = self.func(context_actions)

        a_t = 0
        max_score = -np.inf

        # The inverse does not depend on the arm, so compute it once
        V_inv = None if self.diagonalize else torch.inverse(self.V)

        for a in range(len(est_rt)):
            # Gradient feature of the current arm
            xt_a = grad_list[a]

            if self.diagonalize:
                ### diagonalization
                conf_term = torch.clamp(xt_a * xt_a / self.V, min=0)
                sigma = self.nu * torch.sqrt(torch.sum(conf_term))

                # Fall back to a small width if sigma is nan (e.g. lamdba = 0)
                base_sigma = torch.tensor(0.1)
                sigma = base_sigma if torch.isnan(sigma) else sigma

            else:
                ### no diagonalization
                xt_a = xt_a.to("cpu")
                xt_dot_U = torch.matmul(xt_a, V_inv)
                xt_dot_U_xt = torch.matmul(xt_dot_U, xt_a.t())
                conf_term = torch.clamp(xt_dot_U_xt, min=0)
                sigma = self.nu * torch.sqrt(conf_term)

            # Scoring the arm based on the strategy
            est_rt_a = est_rt[a]
            if self.strategy == 'ts':
                est_rt_a = est_rt_a.detach().to("cpu")
                action_score = torch.normal(est_rt_a.view(-1), sigma.view(-1)).item()

            elif self.strategy == 'ucb':
                action_score = est_rt_a.item() + sigma.item()

            else:
                raise RuntimeError('Exploration strategy not set')

            # Keeping the best-scoring arm (first one wins ties)
            if action_score > max_score:
                max_score = action_score
                a_t = a

        # Update the confidence matrix with the selected arm
        xt_a = grad_list[a_t]

        if self.diagonalize:
            ### diagonalization
            self.V += xt_a * xt_a
        else:
            ### no diagonalization
            self.V += torch.outer(xt_a, xt_a)

        # Keeping the selected arm for model update
        self.a_t = a_t

        return a_t

    # Update the model with new context-action pair and feedback
    def update(self, yt):
        """Store the last selected arm with its feedback and retrain when due.

        Args:
            yt: binary-cross-entropy target for the network output on the
                selected arm.
        """
        # Converting numpy variables to tensors
        xt_a = self.context_arms[self.a_t]
        xt_a_tensor = torch.from_numpy(xt_a).reshape(1, -1).to(**tkwargs)
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action and feedback tensors
        if self.context_actions_list is None:
            # Adding the first context-action
            self.context_actions_list = xt_a_tensor
            self.feedback_list = yt_tensor
        else:
            self.context_actions_list = torch.cat((self.context_actions_list, xt_a_tensor))
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        self.samples = self.context_actions_list.shape[0]
        self.func.train()

        # Nothing more to do unless a retraining is scheduled for this sample count
        if self.samples % self.next_update != 0:
            return
        self.next_update += self.learner_update

        # Retrain from the stored initial state. The reload happens only here so
        # that rounds without retraining keep the trained weights instead of
        # silently reverting to the initial network.
        if self.init_state_dict is not None:
            self.func.load_state_dict(deepcopy(self.init_state_dict))
        optimizer = torch.optim.Adam(self.func.parameters(), lr=1e-1, weight_decay=self.lamdba/(self.samples+50))

        x = self.context_actions_list.reshape(self.samples, -1)
        feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
        for _ in range(self.local_training_iter):
            self.func.zero_grad()
            optimizer.zero_grad()
            logits = self.func(x).reshape(-1)
            loss = F.binary_cross_entropy_with_logits(logits, feedback)
            loss.backward()
            optimizer.step()

    # Reset the learner
    def reset(self):
        """Restore the initial weights, clear the data and restart the update schedule."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.context_actions_list = None
        self.feedback_list = None

        # With an increasing delay next_update has grown; start over
        self.next_update = self._initial_next_update

        self.V = self._prior_confidence()
            
