#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from scipy.optimize import minimize
import torch
from torch import nn
import torch.nn.functional as F
from copy import deepcopy
from backpack import backpack, extend
from backpack.extensions import BatchGrad


#  ################## Active Dueling Bandit Algorithms for Non-linear functions ##################
# PyTorch keyword arguments shared by every `.to(**tkwargs)` call in this file.
# The first CUDA device is used when one is available; otherwise the CPU is
# used so the module still runs on machines without a GPU.
# Other possible devices: "cpu", "mps" [For Apple M2 chips].
tkwargs = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.float32,
}
 
    
# Base Neural Network class
class Network(nn.Module):
    """Fully connected ReLU network that maps an input vector to one scalar.

    The network is ``Linear(input_dim, hidden_size)``, followed by
    ``depth - 1`` layers ``Linear(hidden_size, hidden_size)`` and a final
    ``Linear(hidden_size, 1)``. ReLU is applied after every layer except the
    last one.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_size: Width of every hidden layer.
        depth: Number of hidden layers (values below 1 behave like 1).
        init_params: Optional flat sequence of tensors
            ``[weight_0, bias_0, weight_1, bias_1, ...]`` used as the initial
            parameters. When ``None`` every weight and bias is drawn from a
            standard normal distribution.
    """

    def __init__(self, input_dim, hidden_size=50, depth=2, init_params=None):
        # Calling the parent class constructor
        super().__init__()

        # Activation function
        self.activate = nn.ReLU()

        # Neural network architecture (layers are created in input-to-output
        # order so the random-number stream is consumed in a fixed order)
        layers = [nn.Linear(input_dim, hidden_size)]
        layers.extend(nn.Linear(hidden_size, hidden_size) for _ in range(depth - 1))
        layers.append(nn.Linear(hidden_size, 1))
        self.layer_list = nn.ModuleList(layers)

        # Same NN initialization to maintain consistency across all runs
        if init_params is None:
            # Initialization using normal distribution
            for layer in self.layer_list:
                nn.init.normal_(layer.weight, mean=0, std=1.0)
                nn.init.normal_(layer.bias, mean=0, std=1.0)
        else:
            # Use given initialization vector. The tensors are cloned so that
            # training this network in place cannot modify `init_params`,
            # which would silently change the start point of later networks
            # built from the same vector.
            for idx, layer in enumerate(self.layer_list):
                layer.weight.data = init_params[2 * idx].detach().clone()
                layer.bias.data = init_params[2 * idx + 1].detach().clone()

    def forward(self, x):
        """Return the network output for ``x``; the last axis has size 1."""
        y = x

        # Hidden layers with ReLU activation
        for layer in self.layer_list[:-1]:
            y = self.activate(layer(y))

        # Linear output layer
        return self.layer_list[-1](y)


#  ################## Context selection uses Grad features ##################
# Active Dueling Bandit Algorithm: Neural-ADB using NDB (with both ucb and ts) for NN ablations
class NeuralADBGradAbl:
    """Neural-ADB active dueling bandit with a configurable network size.

    Context-arm vectors are scored by a ``Network``. Per-sample gradients of
    the network output, computed once in the constructor, serve as feature
    vectors for the confidence terms.

    Args:
        context_arms: NumPy array indexed as ``[context, arm]``; the remaining
            axes hold the context-arm vector.
        dim: Input dimension handed to ``Network``.
        lamdba: Regularisation weight; scales the initial Gram matrix and the
            Adam weight decay.
        nu: Multiplier of the confidence width used for the second arm.
        strategy: ``'ucb'`` or ``'ts'``. Any other value scores every
            candidate second arm as 0.
        learner_update: Number of samples between training rounds.
        layers: Number of hidden layers (``depth`` of ``Network``).
        hidden: Width of the hidden layers (``hidden_size`` of ``Network``).
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after each training round; otherwise it stays fixed.
        diagonalize: If true, only the diagonal of the Gram matrix is kept.
        grads_initial_params: Accepted for interface compatibility; unused.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, layers=2, hidden=32, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0

        # Initial variables for storing information
        self.xt_1           = None                      # First selected context-arm pair
        self.xt_2           = None                      # Second selected context-arm pair
        self.xt_1f          = None                      # First selected context-arm pair features
        self.xt_2f          = None                      # Second selected context-arm pair features
        self.next_update    = max(1, learner_update)    # Next update of the learner
        self.samples        = 0                         # Number of samples
        self.contx_actions  = None                      # All context-action pairs
        self.feedback_list  = None                      # feedback for context-action pairs
        self._initial_next_update = self.next_update    # Restored by reset()
        self._removed_rows  = {}                        # Pair-feature rows zeroed by select()

        # Initializing neural network model with pytorch. The sizes are passed
        # by keyword: positionally, `layers` would land in `hidden_size` and
        # `hidden` in `depth`.
        self.func = extend(Network(self.dim, hidden_size=hidden, depth=layers).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrix and its inverse
        self._reset_gram()

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Get all possible difference of context-arms pairs
        flatten_contx_arms_diff = np.array([self.context_arms[c,a] - self.context_arms[c,b] for c in range(self.C) for a in range(self.A) for b in range(self.A)])
        self.flatten_context_arms_diff = torch.from_numpy(flatten_contx_arms_diff).float().to(**tkwargs)
        self.flatten_context_arms_features_diff = self.all_feature_vectors(self.flatten_context_arms_diff)
        # NOTE(review): rows are ordered by (context, arm, arm) but reshaped to
        # (C, A, -1); kept as in the original since nothing in this class reads it.
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # (Re)initialise the gram matrix and its inverse
    def _reset_gram(self):
        """Set ``V`` to ``lamdba * I`` (or its diagonal) and refresh ``V_inv``."""
        if self.diagonalize:
            ### diagonalization: V and V_inv are vectors holding the diagonal
            self.V = self.lamdba * np.ones(self.total_param, dtype=np.float32)
            self.V_inv = 1.0 / self.V
        else:
            ### no diagonalization
            self.V = self.lamdba * np.eye(self.total_param, dtype=np.float32)
            self.V_inv = np.linalg.inv(self.V)


    # Quadratic form x^T V^{-1} x
    def _quad_form(self, x):
        """Return ``x^T V_inv x`` for a vector, or row-wise for a 2-D array."""
        if self.diagonalize:
            return np.sum(x * x * self.V_inv, axis=-1)
        if x.ndim == 1:
            return x.dot(self.V_inv).dot(x)
        return np.einsum('ij,ij->i', x, x @ self.V_inv)


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return the per-sample gradients of the network output.

        Args:
            context_actions: Tensor of network inputs, one per row.

        Returns:
            CPU tensor with one row per input; each row is the concatenation
            of the flattened gradients with respect to every parameter.
        """
        grad_list = []
        batch = 500
        # Slicing past the end is safe, so the final (possibly shorter) batch
        # is handled by the loop itself.
        for start in range(0, len(context_actions), batch):
            # Reward for using the NN model
            rt_init = self.func(context_actions[start:start + batch])
            sum_mu = torch.sum(rt_init)
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose a context and two arms to duel.

        The context is the one owning the arm pair with the largest
        uncertainty; that pair's feature row is then zeroed so it is not
        chosen again. The first arm maximises the estimated reward; the second
        maximises the UCB (or Thompson sample) relative to the first.

        Returns:
            Tuple ``(context index, first context-arm vector, second
            context-arm vector)``.
        """
        # Uncertainty of every (context, arm, arm) pair; clamp round-off
        # negatives so the square root cannot produce NaN
        diff_feats = self.flatten_context_arms_features_diff.numpy()
        uncertainty_term = np.sqrt(np.maximum(self._quad_form(diff_feats), 0))
        index_to_remove = int(np.argmax(uncertainty_term))
        max_index = np.unravel_index(index_to_remove, (self.C, self.A, self.A))

        # Replacing the selected row with the 0 vector to avoid selecting it
        # again; the original row is kept so reset() can restore it
        if index_to_remove not in self._removed_rows:
            self._removed_rows[index_to_remove] = self.flatten_context_arms_features_diff[index_to_remove].clone()
        self.flatten_context_arms_features_diff[index_to_remove] = 0

        # Select context
        xt_ind = max_index[0]

        # First arm: maximizes the estimated reward in the selected context
        xt = torch.from_numpy(self.context_arms[xt_ind]).float().to(**tkwargs)
        with torch.no_grad():
            est_rt = self.func(xt)
        at_1 = torch.argmax(est_rt).item()

        # Second arm: starts from a random other arm, then the loop below
        # picks the arm that maximizes the UCB wrt at_1
        max_ucb = -np.inf
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Compute the information gain for all arms
        for a in range(self.A):
            if a == at_1:
                continue

            xt_af = self.context_arms_features[xt_ind,a]
            rt_a = est_rt[a]
            ig = float(self._quad_form(self.xt_1f - xt_af))
            ct = self.nu * float(np.sqrt(max(ig, 0.0)))

            if self.strategy == 'ts':
                arm_ucb = torch.normal(rt_a, ct)
            elif self.strategy == 'ucb':
                arm_ucb = rt_a + ct
            else:
                arm_ucb = 0

            if arm_ucb > max_ucb:
                max_ucb, at_2, self.xt_2f = arm_ucb, a, xt_af

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record feedback for the last selected pair and maybe retrain.

        Args:
            yt: Feedback for the pair from the last ``select`` call; used as
                the binary-cross-entropy target for the logit
                ``f(xt_1) - f(xt_2)``.
            local_training_iter: Number of Adam steps per training round.
        """
        # Updating the gram matrix with the feature difference of the pair
        zt_f = self.xt_1f - self.xt_2f
        if self.diagonalize:
            self.V += zt_f * zt_f
            self.V_inv = 1.0 / self.V
        else:
            self.V += np.outer(zt_f, zt_f)
            self.V_inv = np.linalg.inv(self.V)      # TODO: Update the inverse of the gram matrix recursively

        # Pair tensor of shape (2, 1, d) and the feedback tensor
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, 1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, 1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor, xt_2_tensor])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        # Update the model with new context-action pair and feedback
        self.samples = self.contx_actions.shape[1]
        self.func.train()

        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            x_1 = self.contx_actions[0].reshape(self.samples, -1)
            x_2 = self.contx_actions[1].reshape(self.samples, -1)
            feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
            for _ in range(local_training_iter):
                optimizer.zero_grad()
                logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
                loss = F.binary_cross_entropy_with_logits(logits, feedback)
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return a NumPy array with the highest-scoring arm of each context."""
        # Compute the estimated rewards for each context-arm pair
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # Arms with maximum estimated reward for each context
        max_arms = torch.argmax(est_rt.reshape(self.C, self.A, -1), dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the learner to its state right after construction."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.func_updates = 0
        self.contx_actions = None
        self.feedback_list = None
        self.next_update = self._initial_next_update

        # Put back the pair-feature rows that select() zeroed
        for row, feature in self._removed_rows.items():
            self.flatten_context_arms_features_diff[row] = feature
        self._removed_rows = {}
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Fresh gram matrix together with its inverse
        self._reset_gram()


# Active Dueling Bandit Algorithm: AE-Borda (NN)
class AEBordaNNGrad:
    """AE-Borda active dueling bandit with a neural reward model.

    Context-arm vectors are scored by a ``Network``. Per-sample gradients of
    the network output, computed once in the constructor, serve as feature
    vectors for the confidence terms.

    Args:
        context_arms: NumPy array indexed as ``[context, arm]``; the remaining
            axes hold the context-arm vector.
        dim: Input dimension handed to ``Network``.
        lamdba: Regularisation weight; scales the initial Gram matrix and the
            Adam weight decay.
        nu: Multiplier of the confidence term.
        strategy: Stored on the instance; not read by this class.
        learner_update: Number of samples between training rounds.
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after each training round; otherwise it stays fixed.
        diagonalize: If true, only the diagonal of the Gram matrix is kept.
        grads_initial_params: Accepted for interface compatibility; unused.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0

        # Initial variables for storing information
        self.xt_1           = None                      # First selected context-arm pair
        self.xt_2           = None                      # Second selected context-arm pair
        self.xt_1f          = None                      # First selected context-arm pair features
        self.xt_2f          = None                      # Second selected context-arm pair features
        self.UCB            = None                      # Upper confidence bound
        self.LCB            = None                      # Lower confidence bound
        self.next_update    = max(1, learner_update)    # Next update of the learner
        self.samples        = 0                         # Number of samples
        self.contx_actions  = None                      # All context-action pairs
        self.feedback_list  = None                      # feedback for context-action pairs
        self._initial_next_update = self.next_update    # Restored by reset()

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrix and its inverse
        self._reset_gram()

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # (Re)initialise the gram matrix and its inverse
    def _reset_gram(self):
        """Set ``V`` to ``lamdba * I`` (or its diagonal) and refresh ``V_inv``."""
        if self.diagonalize:
            ### diagonalization: V and V_inv are vectors holding the diagonal
            self.V = self.lamdba * np.ones(self.total_param, dtype=np.float32)
            self.V_inv = 1.0 / self.V
        else:
            ### no diagonalization
            self.V = self.lamdba * np.eye(self.total_param, dtype=np.float32)
            self.V_inv = np.linalg.inv(self.V)


    # Quadratic form x^T V^{-1} x
    def _quad_form(self, x):
        """Return ``x^T V_inv x`` for a vector, or row-wise for a 2-D array."""
        if self.diagonalize:
            return np.sum(x * x * self.V_inv, axis=-1)
        if x.ndim == 1:
            return x.dot(self.V_inv).dot(x)
        return np.einsum('ij,ij->i', x, x @ self.V_inv)


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return the per-sample gradients of the network output.

        Args:
            context_actions: Tensor of network inputs, one per row.

        Returns:
            CPU tensor with one row per input; each row is the concatenation
            of the flattened gradients with respect to every parameter.
        """
        grad_list = []
        batch = 500
        # Slicing past the end is safe, so the final (possibly shorter) batch
        # is handled by the loop itself.
        for start in range(0, len(context_actions), batch):
            # Reward for using the NN model
            rt_init = self.func(context_actions[start:start + batch])
            sum_mu = torch.sum(rt_init)
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose a context and two arms to duel.

        The context maximises the gap between its largest upper and its
        largest lower confidence bound. The first arm has the largest upper
        bound in that context; the second is drawn uniformly from the rest.

        Returns:
            Tuple ``(context index, first context-arm vector, second
            context-arm vector)``.
        """
        # Estimated rewards for every context-arm pair
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # Confidence term; clamp round-off negatives so the square root
        # cannot produce NaN
        features = self.flatten_context_arms_features.numpy()
        conf_term = self.nu * np.sqrt(np.maximum(self._quad_form(features), 0))

        # Loading from CPU to specific device
        conf_term = torch.from_numpy(np.asarray(conf_term)).float().to(**tkwargs)

        # Reshape the estimated rewards and confidence bounds for all context-arm pairs
        est_rt = est_rt.reshape(self.C, self.A, -1)
        conf_term = conf_term.reshape(self.C, self.A, -1)

        # Compute the upper and lower confidence bounds for each context-arm pair
        self.UCB = est_rt + conf_term
        self.LCB = est_rt - conf_term

        # Gap between the best upper and the best lower bound of each context
        max_diff = torch.max(self.UCB, dim=1).values - torch.max(self.LCB, dim=1).values

        # Select the context with the maximum difference
        xt_ind = torch.argmax(max_diff).item()

        # Select the arms for the selected context: maximizes the UCB values
        at_1 = torch.argmax(self.UCB[xt_ind]).item()

        # Select the second arm: randomly among the remaining arms
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))

        # Selected contexts and arms features
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record feedback for the last selected pair and maybe retrain.

        Args:
            yt: Feedback for the pair from the last ``select`` call; used as
                the binary-cross-entropy target for the logit
                ``f(xt_1) - f(xt_2)``.
            local_training_iter: Number of Adam steps per training round.
        """
        # Updating the gram matrix with the feature difference of the pair
        zt_f = self.xt_1f - self.xt_2f
        if self.diagonalize:
            self.V += zt_f * zt_f
            self.V_inv = 1.0 / self.V
        else:
            self.V += np.outer(zt_f, zt_f)
            self.V_inv = np.linalg.inv(self.V)      # TODO: Update the inverse of the gram matrix recursively

        # Pair tensor of shape (2, 1, d) and the feedback tensor
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, 1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, 1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor, xt_2_tensor])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        # Update the model with new context-action pair and feedback
        self.samples = self.contx_actions.shape[1]
        self.func.train()

        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            x_1 = self.contx_actions[0].reshape(self.samples, -1)
            x_2 = self.contx_actions[1].reshape(self.samples, -1)
            feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
            for _ in range(local_training_iter):
                optimizer.zero_grad()
                logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
                loss = F.binary_cross_entropy_with_logits(logits, feedback)
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return a NumPy array with the highest-scoring arm of each context."""
        # Compute the estimated rewards for each context-arm pair
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # Arms with maximum estimated reward for each context
        max_arms = torch.argmax(est_rt.reshape(self.C, self.A, -1), dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the learner to its state right after construction."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.contx_actions = None
        self.feedback_list = None
        self.next_update = self._initial_next_update

        # Fresh gram matrix together with its inverse
        self._reset_gram()


# Active Dueling Bandit Algorithm: Neural-APO
class NeuralAPOGrad:
    """Neural-APO active dueling bandit.

    Context-arm vectors are scored by a ``Network``. Per-sample gradients of
    the network output, computed once in the constructor, serve as feature
    vectors for the uncertainty terms.

    Args:
        context_arms: NumPy array indexed as ``[context, arm]``; the remaining
            axes hold the context-arm vector.
        dim: Input dimension handed to ``Network``.
        lamdba: Regularisation weight; scales the initial Gram matrix and the
            Adam weight decay.
        nu: Stored on the instance; not read by this class.
        strategy: Stored on the instance; not read by this class.
        learner_update: Number of samples between training rounds.
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after each training round; otherwise it stays fixed.
        diagonalize: If true, only the diagonal of the Gram matrix is kept.
        grads_initial_params: Accepted for interface compatibility; unused.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0

        # Initial variables for storing information
        self.xt_1           = None                      # First selected context-arm pair
        self.xt_2           = None                      # Second selected context-arm pair
        self.xt_1f          = None                      # First selected context-arm pair features
        self.xt_2f          = None                      # Second selected context-arm pair features
        self.next_update    = max(1, learner_update)    # Next update of the learner
        self.samples        = 0                         # Number of samples
        self.contx_actions  = None                      # All context-action pairs
        self.feedback_list  = None                      # feedback for context-action pairs
        self._initial_next_update = self.next_update    # Restored by reset()
        self._removed_rows  = {}                        # Pair-feature rows zeroed by select()

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrix and its inverse
        self._reset_gram()

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Get all possible difference of context-arms pairs
        flatten_contx_arms_diff = np.array([self.context_arms[c,a] - self.context_arms[c,b] for c in range(self.C) for a in range(self.A) for b in range(self.A)])
        self.flatten_context_arms_diff = torch.from_numpy(flatten_contx_arms_diff).float().to(**tkwargs)
        self.flatten_context_arms_features_diff = self.all_feature_vectors(self.flatten_context_arms_diff)
        # NOTE(review): rows are ordered by (context, arm, arm) but reshaped to
        # (C, A, -1); kept as in the original since nothing in this class reads it.
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # (Re)initialise the gram matrix and its inverse
    def _reset_gram(self):
        """Set ``V`` to ``lamdba * I`` (or its diagonal) and refresh ``V_inv``."""
        if self.diagonalize:
            ### diagonalization: V and V_inv are vectors holding the diagonal
            self.V = self.lamdba * np.ones(self.total_param, dtype=np.float32)
            self.V_inv = 1.0 / self.V
        else:
            ### no diagonalization
            self.V = self.lamdba * np.eye(self.total_param, dtype=np.float32)
            self.V_inv = np.linalg.inv(self.V)


    # Quadratic form x^T V^{-1} x
    def _quad_form(self, x):
        """Return ``x^T V_inv x`` for a vector, or row-wise for a 2-D array."""
        if self.diagonalize:
            return np.sum(x * x * self.V_inv, axis=-1)
        if x.ndim == 1:
            return x.dot(self.V_inv).dot(x)
        return np.einsum('ij,ij->i', x, x @ self.V_inv)


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return the per-sample gradients of the network output.

        Args:
            context_actions: Tensor of network inputs, one per row.

        Returns:
            CPU tensor with one row per input; each row is the concatenation
            of the flattened gradients with respect to every parameter.
        """
        grad_list = []
        batch = 500
        # Slicing past the end is safe, so the final (possibly shorter) batch
        # is handled by the loop itself.
        for start in range(0, len(context_actions), batch):
            # Reward for using the NN model
            rt_init = self.func(context_actions[start:start + batch])
            sum_mu = torch.sum(rt_init)
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose the (context, arm, arm) triple with the largest uncertainty.

        The feature row of the chosen triple is zeroed afterwards so the same
        triple is not chosen again before ``reset``.

        Returns:
            Tuple ``(context index, first context-arm vector, second
            context-arm vector)``.
        """
        # Uncertainty of every (context, arm, arm) pair; clamp round-off
        # negatives so the square root cannot produce NaN
        diff_feats = self.flatten_context_arms_features_diff.numpy()
        uncertainty_term = np.sqrt(np.maximum(self._quad_form(diff_feats), 0))
        index_to_remove = int(np.argmax(uncertainty_term))
        xt_ind, at_1, at_2 = np.unravel_index(index_to_remove, (self.C, self.A, self.A))

        # Replacing the selected row with the 0 vector to avoid selecting it
        # again; the original row is kept so reset() can restore it
        if index_to_remove not in self._removed_rows:
            self._removed_rows[index_to_remove] = self.flatten_context_arms_features_diff[index_to_remove].clone()
        self.flatten_context_arms_features_diff[index_to_remove] = 0

        # Selected contexts and arms features
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record feedback for the last selected pair and maybe retrain.

        Args:
            yt: Feedback for the pair from the last ``select`` call; used as
                the binary-cross-entropy target for the logit
                ``f(xt_1) - f(xt_2)``.
            local_training_iter: Number of Adam steps per training round.
        """
        # Updating the gram matrix with the feature difference of the pair
        zt_f = self.xt_1f - self.xt_2f
        if self.diagonalize:
            self.V += zt_f * zt_f
            self.V_inv = 1.0 / self.V
        else:
            self.V += np.outer(zt_f, zt_f)
            self.V_inv = np.linalg.inv(self.V)      # TODO: Update the inverse of the gram matrix recursively

        # Pair tensor of shape (2, 1, d) and the feedback tensor
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, 1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, 1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor, xt_2_tensor])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        # Update the model with new context-action pair and feedback
        self.samples = self.contx_actions.shape[1]
        self.func.train()

        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            x_1 = self.contx_actions[0].reshape(self.samples, -1)
            x_2 = self.contx_actions[1].reshape(self.samples, -1)
            feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
            for _ in range(local_training_iter):
                optimizer.zero_grad()
                logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
                loss = F.binary_cross_entropy_with_logits(logits, feedback)
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return a NumPy array with the highest-scoring arm of each context."""
        # Compute the estimated rewards for each context-arm pair
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # Arms with maximum estimated reward for each context
        max_arms = torch.argmax(est_rt.reshape(self.C, self.A, -1), dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the learner to its state right after construction."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.func_updates = 0
        self.contx_actions = None
        self.feedback_list = None
        self.next_update = self._initial_next_update

        # Put back the pair-feature rows that select() zeroed
        for row, feature in self._removed_rows.items():
            self.flatten_context_arms_features_diff[row] = feature
        self._removed_rows = {}
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Fresh gram matrix together with its inverse
        self._reset_gram()


# Active Dueling Bandit Algorithm: Neural-ADB using NDB(with both ucb and ts)
class NeuralADBGrad:
    """Neural active dueling bandit using fixed network-gradient features.

    ``select`` picks the context containing the arm pair with the largest
    uncertainty under the gram matrix ``V``, takes the greedy arm for that
    context and pairs it with the arm that maximizes a UCB / Thompson score.
    ``update`` folds the chosen pair into ``V`` and periodically retrains the
    network on all preference feedback collected so far.

    Args:
        context_arms: NumPy array indexed as ``[context, arm]`` holding the
            input vector of every context-arm pair.
        dim: Input dimension of the network.
        lamdba: Regularization weight (initial scale of ``V``, weight decay).
        nu: Multiplier of the confidence width.
        strategy: ``'ucb'`` or ``'ts'``; any other value scores every arm 0.
        learner_update: Retraining happens when the sample count is a multiple
            of ``next_update``, which starts at ``max(1, learner_update)``.
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after each retraining; otherwise it stays constant.
        diagonalize: Keep only the diagonal of ``V`` instead of the full matrix.
        grads_initial_params: Accepted for interface compatibility; unused.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0
        self._first_update  = max(1, learner_update)    # Remembered so that reset() can restore the schedule

        # Initial variables for storing information
        self.xt_1           = None                      # First selected context-arm pair
        self.xt_2           = None                      # Second selected context-arm pair
        self.xt_1f          = None                      # First selected context-arm pair features
        self.xt_2f          = None                      # Second selected context-arm pair features
        self.next_update    = self._first_update        # Next update of the learner
        self.samples        = 0                         # Number of samples
        self.contx_actions  = None                      # All context-action pairs
        self.feedback_list  = None                      # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrix and its inverse
        self._reset_gram()

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Get all possible difference of context-arms pairs
        flatten_contx_arms_diff = np.array([self.context_arms[c,a] - self.context_arms[c,b] for c in range(self.C) for a in range(self.A) for b in range(self.A)])
        self.flatten_context_arms_diff = torch.from_numpy(flatten_contx_arms_diff).float().to(**tkwargs)
        self.flatten_context_arms_features_diff = self.all_feature_vectors(self.flatten_context_arms_diff)
        # select() zeroes rows of the matrix above in place; keep a pristine copy for reset()
        self._init_features_diff = self.flatten_context_arms_features_diff.clone()
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # (Re)create the regularized gram matrix together with its inverse
    def _reset_gram(self):
        scale = np.float32(self.lamdba)
        if self.diagonalize:
            ### diagonalization: V is stored as the vector of its diagonal
            self.V = scale * np.ones(self.total_param, dtype=np.float32)
        else:
            ### no diagonalization
            self.V = scale * np.eye(self.total_param, dtype=np.float32)
        self._refresh_gram_inverse()


    # Recompute V_inv from the current V
    def _refresh_gram_inverse(self):
        if self.diagonalize:
            self.V_inv = 1.0 / self.V
        else:
            self.V_inv = np.linalg.inv(self.V)      # TODO: Update the inverse of the gram matrix recursively


    # Row-wise quadratic form x^T V_inv x for a 2-D array of feature rows
    def _quad_form(self, X):
        if self.diagonalize:
            return np.einsum('ij,j,ij->i', X, self.V_inv, X)
        return np.einsum('ij,ij->i', X, X @ self.V_inv)


    # Quadratic form d^T V_inv d for a single feature vector
    def _width_sq(self, d):
        if self.diagonalize:
            return np.sum(d * d * self.V_inv)
        return (d.dot(self.V_inv)).dot(d)


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return per-sample gradients of the network output w.r.t. its parameters.

        Inputs are processed in batches of 500; slicing in the loop already
        covers a shorter final batch. The result is a CPU tensor with one
        flattened gradient row per input row.
        """
        grad_list = []
        batch = 500
        for start in range(0, len(context_actions), batch):
            # Summing the outputs lets one backward pass yield all per-sample gradients
            sum_mu = torch.sum(self.func(context_actions[start:start+batch]))
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose the next context and the two arms to duel.

        Returns:
            ``(xt_ind, xt_1, xt_2)``: the context index and the input vectors
            of the first (greedy) and second (exploratory) arm.
        """
        # Uncertainty of every (context, arm, arm) pair; clamping avoids NaNs from
        # tiny negative round-off, which np.argmax would otherwise pick
        results = self._quad_form(self.flatten_context_arms_features_diff.numpy())
        uncertainty_term = np.sqrt(np.maximum(results, 0.0))

        # Flat index of the most uncertain pair; rows are ordered as (context, arm, arm)
        index_to_remove = int(np.argmax(uncertainty_term))
        xt_ind = np.unravel_index(index_to_remove, (self.C, self.A, self.A))[0]

        # Replacing the selected row with the 0 vector to avoid selecting it again
        self.flatten_context_arms_features_diff[index_to_remove] = torch.zeros(self.total_param)

        # First arm: maximizes the estimated reward for the selected context
        xt = torch.from_numpy(self.context_arms[xt_ind]).float().to(**tkwargs)
        with torch.no_grad():
            est_rt = self.func(xt)
        at_1 = torch.argmax(est_rt).item()

        # Random fallback for the second arm, replaced by the best-scoring arm below
        max_ucb = -np.inf
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Score every other arm by estimated reward plus confidence width wrt at_1
        for a in range(self.A):
            if a == at_1:
                continue
            xt_af = self.context_arms_features[xt_ind,a]
            rt_a = est_rt[a]
            ig = max(self._width_sq(self.xt_1f - xt_af), 0.0)
            ct = float(self.nu * np.sqrt(ig))

            if self.strategy == 'ts':
                arm_ucb = torch.normal(rt_a, ct)
            elif self.strategy == 'ucb':
                arm_ucb = rt_a + ct
            else:
                arm_ucb = 0

            if arm_ucb > max_ucb:
                max_ucb, at_2, self.xt_2f = arm_ucb, a, xt_af

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record the feedback ``yt`` for the last selected pair and maybe retrain.

        Args:
            yt: Feedback for the pair returned by the last ``select`` call; used
                as the binary-cross-entropy target for ``score_1 - score_2``.
            local_training_iter: Number of Adam steps per retraining.
        """
        # Rank-one update of the gram matrix with the feature difference of the pair
        zt_f = self.xt_1f - self.xt_2f
        if self.diagonalize:
            self.V += zt_f * zt_f
        else:
            self.V += np.outer(zt_f, zt_f)
        self._refresh_gram_inverse()

        # Pair stored with shape (2, 1, dim): axis 0 is first/second arm, axis 1 is the sample
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, 1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, 1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor, xt_2_tensor])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        self.samples = self.contx_actions.shape[1]
        self.func.train()

        # Retrain only on scheduled rounds
        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            x_1 = self.contx_actions[0].reshape(self.samples, -1)
            x_2 = self.contx_actions[1].reshape(self.samples, -1)
            feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
            for _ in range(local_training_iter):
                optimizer.zero_grad()
                logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
                loss = F.binary_cross_entropy_with_logits(logits, feedback)
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return a NumPy array with the greedy arm index of every context."""
        # Pure inference: no autograd graph needed
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # One row of estimated rewards per context
        est_rt = est_rt.reshape(self.C, self.A, -1)

        # Compute the arms with maximum rewards for each context
        max_arms = torch.argmax(est_rt, dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the state the learner had right after construction."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.func_updates = 0
        self.next_update = self._first_update
        self.xt_1 = self.xt_2 = self.xt_1f = self.xt_2f = None
        self.contx_actions = None
        self.feedback_list = None

        # Bring back the feature-difference rows that select() zeroed out
        self.flatten_context_arms_features_diff = self._init_features_diff.clone()
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Fresh gram matrix AND inverse, so the next select() does not use a stale V_inv
        self._reset_gram()


# Active Dueling Bandit Algorithm: Neural-ADB using IG (with both ucb and ts)
class NeuralADBIGGrad:
    """Neural active dueling bandit whose second arm is chosen by width only.

    ``select`` picks the context containing the arm pair with the largest
    uncertainty under the gram matrix ``V`` and takes the greedy arm for that
    context. The second arm is scored with a zero reward estimate, i.e. only
    by its confidence width relative to the first arm (or a zero-mean normal
    draw with that width for ``'ts'``). ``update`` folds the chosen pair into
    ``V`` and periodically retrains the network on all collected feedback.

    Args:
        context_arms: NumPy array indexed as ``[context, arm]`` holding the
            input vector of every context-arm pair.
        dim: Input dimension of the network.
        lamdba: Regularization weight (initial scale of ``V``, weight decay).
        nu: Multiplier of the confidence width.
        strategy: ``'ucb'`` or ``'ts'``; any other value scores every arm 0.
        learner_update: Retraining happens when the sample count is a multiple
            of ``next_update``, which starts at ``max(1, learner_update)``.
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after each retraining; otherwise it stays constant.
        diagonalize: Keep only the diagonal of ``V`` instead of the full matrix.
        grads_initial_params: Accepted for interface compatibility; unused.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0
        self._first_update  = max(1, learner_update)    # Remembered so that reset() can restore the schedule

        # Initial variables for storing information
        self.xt_1           = None                      # First selected context-arm pair
        self.xt_2           = None                      # Second selected context-arm pair
        self.xt_1f          = None                      # First selected context-arm pair features
        self.xt_2f          = None                      # Second selected context-arm pair features
        self.next_update    = self._first_update        # Next update of the learner
        self.samples        = 0                         # Number of samples
        self.contx_actions  = None                      # All context-action pairs
        self.feedback_list  = None                      # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrix and its inverse
        self._reset_gram()

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Get all possible difference of context-arms pairs
        flatten_contx_arms_diff = np.array([self.context_arms[c,a] - self.context_arms[c,b] for c in range(self.C) for a in range(self.A) for b in range(self.A)])
        self.flatten_context_arms_diff = torch.from_numpy(flatten_contx_arms_diff).float().to(**tkwargs)
        self.flatten_context_arms_features_diff = self.all_feature_vectors(self.flatten_context_arms_diff)
        # select() zeroes rows of the matrix above in place; keep a pristine copy for reset()
        self._init_features_diff = self.flatten_context_arms_features_diff.clone()
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # (Re)create the regularized gram matrix together with its inverse
    def _reset_gram(self):
        scale = np.float32(self.lamdba)
        if self.diagonalize:
            ### diagonalization: V is stored as the vector of its diagonal
            self.V = scale * np.ones(self.total_param, dtype=np.float32)
        else:
            ### no diagonalization
            self.V = scale * np.eye(self.total_param, dtype=np.float32)
        self._refresh_gram_inverse()


    # Recompute V_inv from the current V
    def _refresh_gram_inverse(self):
        if self.diagonalize:
            self.V_inv = 1.0 / self.V
        else:
            self.V_inv = np.linalg.inv(self.V)      # TODO: Update the inverse of the gram matrix recursively


    # Row-wise quadratic form x^T V_inv x for a 2-D array of feature rows
    def _quad_form(self, X):
        if self.diagonalize:
            return np.einsum('ij,j,ij->i', X, self.V_inv, X)
        return np.einsum('ij,ij->i', X, X @ self.V_inv)


    # Quadratic form d^T V_inv d for a single feature vector
    def _width_sq(self, d):
        if self.diagonalize:
            return np.sum(d * d * self.V_inv)
        return (d.dot(self.V_inv)).dot(d)


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return per-sample gradients of the network output w.r.t. its parameters.

        Inputs are processed in batches of 500; slicing in the loop already
        covers a shorter final batch. The result is a CPU tensor with one
        flattened gradient row per input row.
        """
        grad_list = []
        batch = 500
        for start in range(0, len(context_actions), batch):
            # Summing the outputs lets one backward pass yield all per-sample gradients
            sum_mu = torch.sum(self.func(context_actions[start:start+batch]))
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose the next context and the two arms to duel.

        Returns:
            ``(xt_ind, xt_1, xt_2)``: the context index and the input vectors
            of the first (greedy) and second (exploratory) arm.
        """
        # Uncertainty of every (context, arm, arm) pair; clamping avoids NaNs from
        # tiny negative round-off, which np.argmax would otherwise pick
        results = self._quad_form(self.flatten_context_arms_features_diff.numpy())
        uncertainty_term = np.sqrt(np.maximum(results, 0.0))

        # Flat index of the most uncertain pair; rows are ordered as (context, arm, arm)
        index_to_remove = int(np.argmax(uncertainty_term))
        xt_ind = np.unravel_index(index_to_remove, (self.C, self.A, self.A))[0]

        # Replacing the selected row with the 0 vector to avoid selecting it again
        self.flatten_context_arms_features_diff[index_to_remove] = torch.zeros(self.total_param)

        # First arm: maximizes the estimated reward for the selected context
        xt = torch.from_numpy(self.context_arms[xt_ind]).float().to(**tkwargs)
        with torch.no_grad():
            est_rt = self.func(xt)
        at_1 = torch.argmax(est_rt).item()

        # Random fallback for the second arm, replaced by the best-scoring arm below
        max_ig = -np.inf
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Score every other arm by its confidence width wrt at_1 only
        for a in range(self.A):
            if a == at_1:
                continue
            xt_af = self.context_arms_features[xt_ind,a]
            rt_a = 0*est_rt[a]      # Reward estimate deliberately zeroed: pure exploration score
            ig = max(self._width_sq(self.xt_1f - xt_af), 0.0)
            ct = float(self.nu * np.sqrt(ig))

            if self.strategy == 'ts':
                arm_ig = torch.normal(rt_a, ct)
            elif self.strategy == 'ucb':
                arm_ig = rt_a + ct
            else:
                arm_ig = 0

            if arm_ig > max_ig:
                max_ig, at_2, self.xt_2f = arm_ig, a, xt_af

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record the feedback ``yt`` for the last selected pair and maybe retrain.

        Args:
            yt: Feedback for the pair returned by the last ``select`` call; used
                as the binary-cross-entropy target for ``score_1 - score_2``.
            local_training_iter: Number of Adam steps per retraining.
        """
        # Rank-one update of the gram matrix with the feature difference of the pair
        zt_f = self.xt_1f - self.xt_2f
        if self.diagonalize:
            self.V += zt_f * zt_f
        else:
            self.V += np.outer(zt_f, zt_f)
        self._refresh_gram_inverse()

        # Pair stored with shape (2, 1, dim): axis 0 is first/second arm, axis 1 is the sample
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, 1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, 1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor, xt_2_tensor])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        self.samples = self.contx_actions.shape[1]
        self.func.train()

        # Retrain only on scheduled rounds
        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            x_1 = self.contx_actions[0].reshape(self.samples, -1)
            x_2 = self.contx_actions[1].reshape(self.samples, -1)
            feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
            for _ in range(local_training_iter):
                optimizer.zero_grad()
                logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
                loss = F.binary_cross_entropy_with_logits(logits, feedback)
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return a NumPy array with the greedy arm index of every context."""
        # Pure inference: no autograd graph needed
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # One row of estimated rewards per context
        est_rt = est_rt.reshape(self.C, self.A, -1)

        # Compute the arms with maximum rewards for each context
        max_arms = torch.argmax(est_rt, dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the state the learner had right after construction."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.func_updates = 0
        self.next_update = self._first_update
        self.xt_1 = self.xt_2 = self.xt_1f = self.xt_2f = None
        self.contx_actions = None
        self.feedback_list = None

        # Bring back the feature-difference rows that select() zeroed out
        self.flatten_context_arms_features_diff = self._init_features_diff.clone()
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Fresh gram matrix AND inverse, so the next select() does not use a stale V_inv
        self._reset_gram()


# ################# Context selection uses actual features ##################
# Active Dueling Bandit Algorithm: Neural-APO
class NeuralAPO:
    """Neural dueling bandit that selects pairs using the raw input differences.

    ``select`` returns the (context, arm, arm) triple whose raw input
    difference has the largest uncertainty under the ``dim x dim`` matrix
    ``U``. ``update`` folds the chosen pair into ``U`` (raw inputs) and ``V``
    (network-gradient features) and periodically retrains the network on all
    collected preference feedback.

    Args:
        context_arms: NumPy array indexed as ``[context, arm]`` holding the
            input vector of every context-arm pair.
        dim: Input dimension of the network and size of ``U``.
        lamdba: Regularization weight (initial scale of ``U``/``V``, weight decay).
        nu: Stored as ``self.nu``; not used by ``select`` in this class.
        strategy: Stored as ``self.strategy``; not used by ``select`` in this class.
        learner_update: Retraining happens when the sample count is a multiple
            of ``next_update``, which starts at ``max(1, learner_update)``.
        increasing_delay: If true, ``next_update`` grows by ``learner_update``
            after each retraining; otherwise it stays constant.
        diagonalize: Keep only the diagonal of ``V`` instead of the full matrix.
        grads_initial_params: Accepted for interface compatibility; unused.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0
        self._first_update  = max(1, learner_update)    # Remembered so that reset() can restore the schedule

        # Initial variables for storing information
        self.xt_1           = None                      # First selected context-arm pair
        self.xt_2           = None                      # Second selected context-arm pair
        self.xt_1f          = None                      # First selected context-arm pair features
        self.xt_2f          = None                      # Second selected context-arm pair features
        self.next_update    = self._first_update        # Next update of the learner
        self.samples        = 0                         # Number of samples
        self.contx_actions  = None                      # All context-action pairs
        self.feedback_list  = None                      # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrices (V on gradient features, U on raw inputs) and their inverses
        self._reset_gram()

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Get all possible difference of context-arms pairs
        self.flatten_contx_arms_diff = np.array([self.context_arms[c,a] - self.context_arms[c,b] for c in range(self.C) for a in range(self.A) for b in range(self.A)])
        self.flatten_context_arms_diff = torch.from_numpy(self.flatten_contx_arms_diff).float().to(**tkwargs)
        self.flatten_context_arms_features_diff = self.all_feature_vectors(self.flatten_context_arms_diff)
        # select() zeroes rows of the matrix above in place; keep a pristine copy for reset()
        self._init_features_diff = self.flatten_context_arms_features_diff.clone()
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # (Re)create both regularized gram matrices together with their inverses
    def _reset_gram(self):
        scale = np.float32(self.lamdba)
        if self.diagonalize:
            ### diagonalization: V is stored as the vector of its diagonal
            self.V = scale * np.ones(self.total_param, dtype=np.float32)
        else:
            ### no diagonalization
            self.V = scale * np.eye(self.total_param, dtype=np.float32)
        self._refresh_gram_inverse()

        # Used for context selection
        self.U = self.lamdba * np.identity(self.dim)
        self.U_inv = np.linalg.inv(self.U)


    # Recompute V_inv from the current V
    def _refresh_gram_inverse(self):
        if self.diagonalize:
            self.V_inv = 1.0 / self.V
        else:
            self.V_inv = np.linalg.inv(self.V)      # TODO: Update the inverse of the gram matrix recursively


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return per-sample gradients of the network output w.r.t. its parameters.

        Inputs are processed in batches of 500; slicing in the loop already
        covers a shorter final batch. The result is a CPU tensor with one
        flattened gradient row per input row.
        """
        grad_list = []
        batch = 500
        for start in range(0, len(context_actions), batch):
            # Summing the outputs lets one backward pass yield all per-sample gradients
            sum_mu = torch.sum(self.func(context_actions[start:start+batch]))
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose the most uncertain (context, arm, arm) triple under ``U``.

        Returns:
            ``(xt_ind, xt_1, xt_2)``: the context index and the input vectors
            of the two selected arms.
        """
        # Uncertainty of every raw pair difference; clamping avoids NaNs from
        # tiny negative round-off, which np.argmax would otherwise pick
        X_U_inv = self.flatten_contx_arms_diff @ self.U_inv
        results = np.einsum('ij,ij->i', self.flatten_contx_arms_diff, X_U_inv)
        uncertainty_term = np.sqrt(np.maximum(results, 0.0))

        # Flat index of the most uncertain pair; rows are ordered as (context, arm, arm)
        index_to_remove = int(np.argmax(uncertainty_term))
        xt_ind, at_1, at_2 = np.unravel_index(index_to_remove, (self.C, self.A, self.A))

        # NOTE(review): this zeroes the gradient-feature row, but the scores above
        # are computed from the raw differences, so it does not by itself prevent
        # the same pair from being selected again - confirm the intended behaviour.
        self.flatten_context_arms_features_diff[index_to_remove] = torch.zeros(self.total_param)

        # Selected contexts and arms features
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record the feedback ``yt`` for the last selected pair and maybe retrain.

        Args:
            yt: Feedback for the pair returned by the last ``select`` call; used
                as the binary-cross-entropy target for ``score_1 - score_2``.
            local_training_iter: Number of Adam steps per retraining.
        """
        # Rank-one update of U with the raw input difference of the pair
        zt = self.xt_1 - self.xt_2
        self.U += np.outer(zt, zt)
        self.U_inv = np.linalg.inv(self.U)

        # Rank-one update of V with the gradient-feature difference of the pair
        zt_f = self.xt_1f - self.xt_2f
        if self.diagonalize:
            self.V += zt_f * zt_f
        else:
            self.V += np.outer(zt_f, zt_f)
        self._refresh_gram_inverse()

        # Pair stored with shape (2, 1, dim): axis 0 is first/second arm, axis 1 is the sample
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, 1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, 1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor, xt_2_tensor])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        self.samples = self.contx_actions.shape[1]
        self.func.train()

        # Retrain only on scheduled rounds
        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            x_1 = self.contx_actions[0].reshape(self.samples, -1)
            x_2 = self.contx_actions[1].reshape(self.samples, -1)
            feedback = self.feedback_list.reshape(-1).to(dtype=torch.float32)
            for _ in range(local_training_iter):
                optimizer.zero_grad()
                logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
                loss = F.binary_cross_entropy_with_logits(logits, feedback)
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return a NumPy array with the greedy arm index of every context."""
        # Pure inference: no autograd graph needed
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # One row of estimated rewards per context
        est_rt = est_rt.reshape(self.C, self.A, -1)

        # Compute the arms with maximum rewards for each context
        max_arms = torch.argmax(est_rt, dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the state the learner had right after construction."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.func_updates = 0
        self.next_update = self._first_update
        self.xt_1 = self.xt_2 = self.xt_1f = self.xt_2f = None
        self.contx_actions = None
        self.feedback_list = None

        # Bring back the feature-difference rows that select() zeroed out
        self.flatten_context_arms_features_diff = self._init_features_diff.clone()
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Fresh U, V AND their inverses, so the next select() does not use a stale U_inv
        self._reset_gram()


# Active Dueling Bandit Algorithm: Neural-ADB using NDB(with both ucb and ts)
class NeuralADB:
    """Neural active dueling bandit that chooses a context and a pair of arms to duel.

    The context is the one containing the raw arm difference with the largest
    uncertainty under ``U_inv``.  The first arm is the network's greedy arm
    for that context; the second arm maximises the network estimate plus a
    confidence width computed from per-sample gradient features (``'ucb'``),
    or a normal sample centred on the estimate with that width (``'ts'``).

    Args:
        context_arms: NumPy array indexed as ``context_arms[context, arm]``;
            it is reshaped to ``(C * A, -1)`` and multiplied with a
            ``dim x dim`` matrix, so each entry is a ``dim``-vector.
        dim: Dimension of a context-arm vector (network input size).
        lamdba: Regularisation weight used for both gram matrices and for the
            optimizer's weight decay.
        nu: Multiplier of the confidence width.
        strategy: ``'ucb'`` or ``'ts'``.  Any other value leaves every score
            at 0, so the first arm different from the greedy one is picked.
        learner_update: Number of samples between network trainings (at
            least 1).
        increasing_delay: When true, ``next_update`` grows by
            ``learner_update`` after every training.
        diagonalize: When true, only the diagonal of the feature gram matrix
            is kept (stored as a vector).
        grads_initial_params: Accepted for interface compatibility; not used
            by this class.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0

        # Remembered so that reset() can restart the training schedule
        self._initial_next_update = max(1, learner_update)

        # Initial variables for storing information
        self.xt_1           = None                          # First selected context-arm pair
        self.xt_2           = None                          # Second selected context-arm pair
        self.xt_1f          = None                          # First selected context-arm pair features
        self.xt_2f          = None                          # Second selected context-arm pair features
        self.next_update    = self._initial_next_update     # Next update of the learner
        self.samples        = 0                             # Number of samples
        self.func_updates   = 0                             # Kept in sync with reset()
        self.contx_actions  = None                          # All context-action pairs
        self.feedback_list  = None                          # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrix of the gradient features (vector when diagonalized)
        self.V = self._initial_gram()

        # Used for context selection
        self.U = lamdba * np.identity(self.dim)

        # Inverses of the gram matrices
        self.V_inv = self._invert_gram(self.V)
        self.U_inv = np.linalg.inv(self.U)

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Get all possible difference of context-arms pairs
        self.flatten_contx_arms_diff = np.array([self.context_arms[c,a] - self.context_arms[c,b] for c in range(self.C) for a in range(self.A) for b in range(self.A)])
        self.flatten_context_arms_diff = torch.from_numpy(self.flatten_contx_arms_diff).float().to(**tkwargs)
        self.flatten_context_arms_features_diff = self.all_feature_vectors(self.flatten_context_arms_diff)
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # Regularised gram matrix of the gradient features before any observation
    def _initial_gram(self):
        ones = torch.ones((self.total_param,))
        if self.diagonalize:
            ### diagonalization: only the diagonal is stored, as a vector
            return self.lamdba * ones
        ### no diagonalization: full matrix
        return self.lamdba * torch.diag(ones)


    # Inverse of the gram matrix in the representation selected by `diagonalize`
    def _invert_gram(self, gram):
        gram = np.asarray(gram)
        if self.diagonalize:
            # A diagonal matrix is inverted element-wise; np.linalg.inv would
            # reject the 1-D vector.
            return 1.0 / gram
        return np.linalg.inv(gram)


    # Quadratic form diff^T V^{-1} diff for a feature difference vector
    def _information_gain(self, diff):
        if self.diagonalize:
            return np.sum(diff * diff * self.V_inv)
        return (diff.dot(self.V_inv)).dot(diff)


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return the per-sample parameter gradients of the network output.

        Args:
            context_actions: Tensor of network inputs, one row per sample.

        Returns:
            CPU tensor with one row per input holding the concatenated,
            flattened gradients of the network output w.r.t. all parameters.
        """
        grad_list = []
        batch = 500
        # Slicing past the end is safe, so the final (shorter) batch is
        # handled by the same loop.
        for start in range(0, len(context_actions), batch):
            # Reward for using the NN model
            rt_init = self.func(context_actions[start:start + batch])
            sum_mu = torch.sum(rt_init)
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose the next context and the two arms to compare.

        Returns:
            Tuple ``(xt_ind, xt_1, xt_2)``: the selected context index and the
            context-arm vectors of the first and second arm.  The chosen
            vectors and their gradient features are also stored on the
            instance for the following ``update`` call.
        """
        # Uncertainty of every arm difference: sqrt(z^T U^{-1} z), computed row-wise
        X_U_inv = self.flatten_contx_arms_diff @ self.U_inv
        results = np.einsum('ij,ij->i', self.flatten_contx_arms_diff, X_U_inv)
        uncertainty_term = np.sqrt(results)

        # Group the uncertainties by (context, arm, arm) and locate the largest one
        uncertainty_term = uncertainty_term.reshape(self.C, self.A, self.A, -1)
        max_index = np.unravel_index(np.argmax(uncertainty_term), uncertainty_term.shape)

        # Zero the gradient-feature difference of the selected triple.
        # NOTE(review): the search above reads flatten_contx_arms_diff, not
        # this feature tensor, so zeroing the row does not by itself stop the
        # same index from being selected again -- confirm the intent.
        index_to_remove = max_index[0]*self.A*self.A + max_index[1]*self.A + max_index[2]
        self.flatten_context_arms_features_diff[index_to_remove] = torch.zeros(self.total_param)

        # Selected context
        xt_ind = max_index[0]

        # First arm: greedy arm of the network for the selected context.
        # Inference only, so no autograd graph is needed.
        xt = torch.from_numpy(self.context_arms[xt_ind]).float().to(**tkwargs)
        with torch.no_grad():
            est_rt = self.func(xt)
        at_1 = torch.argmax(est_rt).item()

        # Second arm: start from a random arm different from at_1, then keep
        # the arm with the largest score relative to at_1
        max_ucb = -np.inf
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Compute the score for all arms other than at_1
        for a in range(self.A):
            if a != at_1:
                arm_ucb = 0
                xt_af = self.context_arms_features[xt_ind,a]
                rt_a = est_rt[a]
                ig = self._information_gain(self.xt_1f - xt_af)
                ct = self.nu * np.sqrt(ig)

                if self.strategy == 'ts':
                    arm_ucb = torch.normal(rt_a, ct)

                elif self.strategy == 'ucb':
                    arm_ucb = rt_a + ct

                if arm_ucb > max_ucb:
                    max_ucb, at_2, self.xt_2f = arm_ucb, a, xt_af

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record the feedback for the last selected pair and train if due.

        Args:
            yt: Feedback for the duel between ``xt_1`` and ``xt_2``; used as
                the target of a binary cross-entropy on the score difference.
            local_training_iter: Number of optimizer steps per training round.
        """
        # Updating the gram matrices with the selected pair
        zt = self.xt_1 - self.xt_2
        self.U += np.outer(zt, zt)
        self.U_inv = np.linalg.inv(self.U)
        zt_f = self.xt_1f - self.xt_2f
        # V is a torch tensor, so the NumPy increment is converted explicitly
        if self.diagonalize:
            self.V += torch.as_tensor(zt_f * zt_f, dtype=self.V.dtype)
        else:
            self.V += torch.as_tensor(np.outer(zt_f, zt_f), dtype=self.V.dtype)
        self.V_inv = self._invert_gram(self.V)      # TODO: Update the inverse of the gram matrix recursively

        # Convert the context-arm pairs and feedback to tensors
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor.reshape(1, 1, -1), xt_2_tensor.reshape(1,1,-1)])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair.reshape(2, 1, -1)), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        # Update the model with new context-action pair and feedback
        self.samples = self.contx_actions.shape[1]
        self.func.train()

        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            # A fresh optimizer is used for every training round
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            for _ in range(local_training_iter):
                self.func.zero_grad()
                optimizer.zero_grad()
                x_1 = self.contx_actions[0].reshape(self.samples, -1)
                x_2 = self.contx_actions[1].reshape(self.samples, -1)
                score_1 = self.func(x_1)
                score_2 = self.func(x_2)
                logits = (score_1 - score_2).reshape(-1)    # Logits as difference of scores
                feedback = self.feedback_list.reshape(-1)
                loss = F.binary_cross_entropy_with_logits(logits, feedback.to(dtype=torch.float32))
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return the index of the network's highest-scoring arm for every context."""
        # Inference only: avoid building an autograd graph for the whole table
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # Group the estimated rewards by context and arm
        est_rt = est_rt.reshape(self.C, self.A, -1)

        # Compute the arms with maximum rewards for each context
        max_arms = torch.argmax(est_rt, dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the initial network weights and forget all observations."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.func_updates = 0
        self.contx_actions = None
        self.feedback_list = None
        # Restart the training schedule together with the sample counter
        self.next_update = self._initial_next_update
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()
        self.U = self.lamdba * np.identity(self.dim)
        self.V = self._initial_gram()
        # Keep the cached inverses consistent with the fresh gram matrices
        self.U_inv = np.linalg.inv(self.U)
        self.V_inv = self._invert_gram(self.V)


# Active Dueling Bandit Algorithm: Neural-ADB using IG (with both ucb and ts)
class NeuralADBIG:
    """Neural active dueling bandit whose second arm is chosen by confidence width only.

    The context is the one containing the raw arm difference with the largest
    uncertainty under ``U_inv``.  The first arm is the network's greedy arm
    for that context.  For the second arm the network estimate is multiplied
    by zero, so the score is the confidence width computed from per-sample
    gradient features (``'ucb'``) or a zero-mean normal sample with that
    width (``'ts'``).

    Args:
        context_arms: NumPy array indexed as ``context_arms[context, arm]``;
            it is reshaped to ``(C * A, -1)`` and multiplied with a
            ``dim x dim`` matrix, so each entry is a ``dim``-vector.
        dim: Dimension of a context-arm vector (network input size).
        lamdba: Regularisation weight used for both gram matrices and for the
            optimizer's weight decay.
        nu: Multiplier of the confidence width.
        strategy: ``'ucb'`` or ``'ts'``.  Any other value leaves every score
            at 0, so the first arm different from the greedy one is picked.
        learner_update: Number of samples between network trainings (at
            least 1).
        increasing_delay: When true, ``next_update`` grows by
            ``learner_update`` after every training.
        diagonalize: When true, only the diagonal of the feature gram matrix
            is kept (stored as a vector).
        grads_initial_params: Accepted for interface compatibility; not used
            by this class.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False, diagonalize=False, grads_initial_params=False):
        # Initial parameters
        self.context_arms   = context_arms          # All context-arms pairs
        self.dim            = dim                   # Dimension of context-arms
        self.lamdba         = lamdba                # regularization parameter
        self.nu             = nu                    # confidence parameter
        self.strategy       = strategy              # arm-selection strategy
        self.diagonalize    = diagonalize           # diagonalization of confidence matrix if true else full matrix
        self.C              = len(context_arms)     # Number of contexts
        self.A              = len(context_arms[0])  # Number of arms

        # Setting the increasing delay of the learner being updated
        self.learner_update = learner_update if increasing_delay else 0

        # Remembered so that reset() can restart the training schedule
        self._initial_next_update = max(1, learner_update)

        # Initial variables for storing information
        self.xt_1           = None                          # First selected context-arm pair
        self.xt_2           = None                          # Second selected context-arm pair
        self.xt_1f          = None                          # First selected context-arm pair features
        self.xt_2f          = None                          # Second selected context-arm pair features
        self.next_update    = self._initial_next_update     # Next update of the learner
        self.samples        = 0                             # Number of samples
        self.func_updates   = 0                             # Kept in sync with reset()
        self.contx_actions  = None                          # All context-action pairs
        self.feedback_list  = None                          # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Gram matrix of the gradient features (vector when diagonalized)
        self.V = self._initial_gram()

        # Used for context selection
        self.U = lamdba * np.identity(self.dim)

        # Inverses of the gram matrices
        self.V_inv = self._invert_gram(self.V)
        self.U_inv = np.linalg.inv(self.U)

        # Get the context-arm pairs
        flatten_contx_arms = self.context_arms.reshape(self.C * self.A, -1)
        self.flatten_context_arms = torch.from_numpy(flatten_contx_arms).float().to(**tkwargs)
        self.flatten_context_arms_features = self.all_feature_vectors(self.flatten_context_arms)
        self.context_arms_features = self.flatten_context_arms_features.reshape(self.C, self.A, -1).numpy()

        # Get all possible difference of context-arms pairs
        self.flatten_contx_arms_diff = np.array([self.context_arms[c,a] - self.context_arms[c,b] for c in range(self.C) for a in range(self.A) for b in range(self.A)])
        self.flatten_context_arms_diff = torch.from_numpy(self.flatten_contx_arms_diff).float().to(**tkwargs)
        self.flatten_context_arms_features_diff = self.all_feature_vectors(self.flatten_context_arms_diff)
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()

        # Return policy
        self.contx_max_est_rewards  = None      # Maximum rewards for all contexts
        self.contx_max_arm          = None      # Maximum arms for all contexts


    # Regularised gram matrix of the gradient features before any observation
    def _initial_gram(self):
        ones = torch.ones((self.total_param,))
        if self.diagonalize:
            ### diagonalization: only the diagonal is stored, as a vector
            return self.lamdba * ones
        ### no diagonalization: full matrix
        return self.lamdba * torch.diag(ones)


    # Inverse of the gram matrix in the representation selected by `diagonalize`
    def _invert_gram(self, gram):
        gram = np.asarray(gram)
        if self.diagonalize:
            # A diagonal matrix is inverted element-wise; np.linalg.inv would
            # reject the 1-D vector.
            return 1.0 / gram
        return np.linalg.inv(gram)


    # Quadratic form diff^T V^{-1} diff for a feature difference vector
    def _information_gain(self, diff):
        if self.diagonalize:
            return np.sum(diff * diff * self.V_inv)
        return (diff.dot(self.V_inv)).dot(diff)


    # Calculating the feature vectors for given context-action pairs
    def all_feature_vectors(self, context_actions):
        """Return the per-sample parameter gradients of the network output.

        Args:
            context_actions: Tensor of network inputs, one row per sample.

        Returns:
            CPU tensor with one row per input holding the concatenated,
            flattened gradients of the network output w.r.t. all parameters.
        """
        grad_list = []
        batch = 500
        # Slicing past the end is safe, so the final (shorter) batch is
        # handled by the same loop.
        for start in range(0, len(context_actions), batch):
            # Reward for using the NN model
            rt_init = self.func(context_actions[start:start + batch])
            sum_mu = torch.sum(rt_init)
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            grad_list.append(g_list_.cpu())

        return torch.vstack(grad_list)


    # Selecting context and pair of arms
    def select(self):
        """Choose the next context and the two arms to compare.

        Returns:
            Tuple ``(xt_ind, xt_1, xt_2)``: the selected context index and the
            context-arm vectors of the first and second arm.  The chosen
            vectors and their gradient features are also stored on the
            instance for the following ``update`` call.
        """
        # Uncertainty of every arm difference: sqrt(z^T U^{-1} z), computed row-wise
        X_U_inv = self.flatten_contx_arms_diff @ self.U_inv
        results = np.einsum('ij,ij->i', self.flatten_contx_arms_diff, X_U_inv)
        uncertainty_term = np.sqrt(results)

        # Group the uncertainties by (context, arm, arm) and locate the largest one
        uncertainty_term = uncertainty_term.reshape(self.C, self.A, self.A, -1)
        max_index = np.unravel_index(np.argmax(uncertainty_term), uncertainty_term.shape)

        # Zero the gradient-feature difference of the selected triple.
        # NOTE(review): the search above reads flatten_contx_arms_diff, not
        # this feature tensor, so zeroing the row does not by itself stop the
        # same index from being selected again -- confirm the intent.
        index_to_remove = max_index[0]*self.A*self.A + max_index[1]*self.A + max_index[2]
        self.flatten_context_arms_features_diff[index_to_remove] = torch.zeros(self.total_param)

        # Selected context
        xt_ind = max_index[0]

        # First arm: greedy arm of the network for the selected context.
        # Inference only, so no autograd graph is needed.
        xt = torch.from_numpy(self.context_arms[xt_ind]).float().to(**tkwargs)
        with torch.no_grad():
            est_rt = self.func(xt)
        at_1 = torch.argmax(est_rt).item()

        # Second arm: start from a random arm different from at_1, then keep
        # the arm with the largest information-gain score relative to at_1
        max_ig = -np.inf
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))
        self.xt_1f = self.context_arms_features[xt_ind,at_1]
        self.xt_2f = self.context_arms_features[xt_ind,at_2]

        # Compute the information gain for all arms other than at_1
        for a in range(self.A):
            if a != at_1:
                arm_ig = 0
                xt_af = self.context_arms_features[xt_ind,a]
                # The estimate is zeroed so only the confidence width drives
                # the choice; it stays a tensor for torch.normal below
                rt_a = 0*est_rt[a]
                ig = self._information_gain(self.xt_1f - xt_af)
                ct = self.nu * np.sqrt(ig)

                if self.strategy == 'ts':
                    arm_ig = torch.normal(rt_a, ct)

                elif self.strategy == 'ucb':
                    arm_ig = rt_a + ct

                if arm_ig > max_ig:
                    max_ig, at_2, self.xt_2f = arm_ig, a, xt_af

        # Selected contexts and arms
        self.xt_1 = self.context_arms[xt_ind,at_1]
        self.xt_2 = self.context_arms[xt_ind,at_2]

        return xt_ind, self.xt_1, self.xt_2


    # Update the model with new context-arm pair and feedback
    def update(self, yt, local_training_iter=50):
        """Record the feedback for the last selected pair and train if due.

        Args:
            yt: Feedback for the duel between ``xt_1`` and ``xt_2``; used as
                the target of a binary cross-entropy on the score difference.
            local_training_iter: Number of optimizer steps per training round.
        """
        # Updating the gram matrices with the selected pair
        zt = self.xt_1 - self.xt_2
        self.U += np.outer(zt, zt)
        self.U_inv = np.linalg.inv(self.U)
        zt_f = self.xt_1f - self.xt_2f
        # V is a torch tensor, so the NumPy increment is converted explicitly
        if self.diagonalize:
            self.V += torch.as_tensor(zt_f * zt_f, dtype=self.V.dtype)
        else:
            self.V += torch.as_tensor(np.outer(zt_f, zt_f), dtype=self.V.dtype)
        self.V_inv = self._invert_gram(self.V)      # TODO: Update the inverse of the gram matrix recursively

        # Convert the context-arm pairs and feedback to tensors
        xt_1_tensor = torch.from_numpy(self.xt_1).reshape(1, -1).to(**tkwargs)
        xt_2_tensor = torch.from_numpy(self.xt_2).reshape(1, -1).to(**tkwargs)
        xt_pair = torch.cat([xt_1_tensor.reshape(1, 1, -1), xt_2_tensor.reshape(1,1,-1)])
        yt_tensor = torch.tensor([yt]).to(**tkwargs)

        # Updating the context-action pairs and feedback tensors
        if self.contx_actions is None:
            # Adding the first context-action pair
            self.contx_actions = xt_pair
            self.feedback_list = yt_tensor
        else:
            self.contx_actions = torch.cat((self.contx_actions, xt_pair.reshape(2, 1, -1)), dim=1)
            self.feedback_list = torch.cat([self.feedback_list, yt_tensor])

        # Update the model with new context-action pair and feedback
        self.samples = self.contx_actions.shape[1]
        self.func.train()

        if self.samples % self.next_update == 0:
            self.next_update += self.learner_update
            # A fresh optimizer is used for every training round
            optimizer = torch.optim.Adam(self.func.parameters(),lr=1e-1,weight_decay=self.lamdba/(self.samples+50))
            for _ in range(local_training_iter):
                self.func.zero_grad()
                optimizer.zero_grad()
                x_1 = self.contx_actions[0].reshape(self.samples, -1)
                x_2 = self.contx_actions[1].reshape(self.samples, -1)
                score_1 = self.func(x_1)
                score_2 = self.func(x_2)
                logits = (score_1 - score_2).reshape(-1)    # Logits as difference of scores
                feedback = self.feedback_list.reshape(-1)
                loss = F.binary_cross_entropy_with_logits(logits, feedback.to(dtype=torch.float32))
                loss.backward()
                optimizer.step()


    # Get the estimated best arm for all contexts
    def get_policy(self):
        """Return the index of the network's highest-scoring arm for every context."""
        # Inference only: avoid building an autograd graph for the whole table
        with torch.no_grad():
            est_rt = self.func(self.flatten_context_arms)

        # Group the estimated rewards by context and arm
        est_rt = est_rt.reshape(self.C, self.A, -1)

        # Compute the arms with maximum rewards for each context
        max_arms = torch.argmax(est_rt, dim=1)

        # Return the estimated best arm for all contexts in flattened form
        return max_arms.flatten().cpu().numpy()


    # Reset the model
    def reset(self):
        """Restore the initial network weights and forget all observations."""
        # Reset the model to initial state
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.func_updates = 0
        self.contx_actions = None
        self.feedback_list = None
        # Restart the training schedule together with the sample counter
        self.next_update = self._initial_next_update
        self.context_arms_features_diff = self.flatten_context_arms_features_diff.reshape(self.C, self.A, -1).numpy()
        self.U = self.lamdba * np.identity(self.dim)
        self.V = self._initial_gram()
        # Keep the cached inverses consistent with the fresh gram matrices
        self.U_inv = np.linalg.inv(self.U)
        self.V_inv = self._invert_gram(self.V)


# ########### Context selection uses actual features with weights ###########
# Active Dueling Bandit Algorithm: Neural-APO ??