 # !/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
import random
from itertools import combinations

 # from pygame.sndarray import samples
from scipy.optimize import minimize
from scipy.optimize import linear_sum_assignment
import torch
from torch import nn
import torch.nn.functional as F
from copy import deepcopy
from backpack import backpack, extend
from backpack.extensions import BatchGrad

# ################## Linear Dueling Bandit Algorithms ##################
# Linear Dueling Bandit Algorithms with confidence bounds

class LinearConfDB:
    """Linear dueling bandit over super-arms with confidence-based exploration.

    Every round ``select`` builds two super-arms of ``each_round_arms`` arms
    each.  The first one is the greedy top-k under the current linear reward
    estimate; the second one is obtained by solving an assignment problem
    whose scores add an exploration term (UCB) or sample around the estimate
    (Thompson sampling).  ``update`` consumes one binary preference per pair
    and periodically refits ``theta`` by logistic maximum likelihood.

    Parameters
    ----------
    input_dim : int
        Dimension of every arm's feature vector.
    each_round_arms : int
        Number of arms in each super-arm (``k``).
    lamdba : float
        Regularisation weight; the gram matrix starts as ``lamdba * I``.
    nu : float
        Scale of the exploration term.
    strategy : {'ucb', 'ts'}
        Exploration strategy used to score candidates for the second super-arm.
    reward_function : str
        Reward function of a super-arm.  Only ``'add'`` is implemented.
    learner_update : int
        ``theta`` is refitted when the number of completed rounds is a
        multiple of ``next_update``, which starts at ``max(1, learner_update)``.
    delta : float
        Confidence parameter (stored, not used by the methods of this class).
    increasing_delay : bool
        If true, ``next_update`` grows by ``learner_update`` after every refit.
    """

    def __init__(self, input_dim, each_round_arms, lamdba=1, nu=1, strategy='ucb', reward_function='add', learner_update=20, delta=0.05,
                 increasing_delay=False):
        # Initial parameters
        self.input_dim = input_dim  # dimension of input
        self.each_round_arms = each_round_arms  # size of each super-arm
        self.lamdba = lamdba  # regularization parameter
        self.nu = nu  # confidence parameter
        self.strategy = strategy  # arm-selection strategy
        self.reward_function = reward_function  # reward function of super-arm
        self.delta = delta  # confidence parameter

        # Norm of learner's parameter
        self.S = np.sqrt(input_dim)

        # Increment applied to the refit period after each refit (0 = fixed period)
        self.learner_update = learner_update if increasing_delay else 0

        # Remember the starting refit period so that reset() can restore it
        self._initial_next_update = max(1, learner_update)
        self.next_update = self._initial_next_update  # Next update of the learner

        # Variables for storing information
        self.samples = 0  # Number of stored feature differences
        self.Z = []  # Feature differences (first arm - second arm) of every pair
        self.ZY_sum = np.zeros(self.input_dim)  # Sum of feature difference * feedback

        # Gram matrix and model parameters
        self.V = lamdba * np.identity(self.input_dim)
        self.kappa = 1.0 / ((1 + np.exp(1.0)) * (1 + np.exp(-1.0)))  # Assuming l2-norm of theta = 1
        self.theta = np.ones(self.input_dim) / self.input_dim

    def select(self, context_arms):
        """Choose the two super-arms to duel in this round.

        Parameters
        ----------
        context_arms : numpy.ndarray
            Array of shape ``(N, input_dim)`` with one row per available arm.

        Returns
        -------
        tuple[list[int], list[int]]
            Row indices of the first (greedy) and of the second (exploratory)
            super-arm; position ``i`` of both lists forms the i-th duel.

        Raises
        ------
        ValueError
            If ``reward_function`` is not ``'add'`` or ``strategy`` is neither
            ``'ucb'`` nor ``'ts'``.
        """
        if self.reward_function != 'add':
            raise ValueError(
                f"Unsupported reward function {self.reward_function!r}; only 'add' is implemented"
            )
        if self.strategy not in ('ucb', 'ts'):
            raise ValueError("Unknown strategy")

        k = self.each_round_arms
        self.context_arms = context_arms

        est_r = context_arms @ self.theta
        V_inv = np.linalg.inv(self.V)

        # Step 1: greedy first super-arm (top-k by estimated reward, best first)
        arm_set1 = np.argsort(est_r)[-k:][::-1]
        X1 = context_arms[arm_set1]

        # Step 2: score every (selected arm, candidate arm) combination.
        # quad[i, j] = (x1_i - x_j)^T V^{-1} (x1_i - x_j); it is clamped at zero
        # *before* the single square root so round-off cannot produce NaN.
        diffs = X1[:, None, :] - context_arms[None, :, :]
        quad = np.einsum('knd,de,kne->kn', diffs, V_inv, diffs)
        conf = self.nu * np.sqrt(np.maximum(quad, 0.0))
        if self.strategy == 'ucb':
            scores = est_r[None, :] + conf
        else:  # 'ts'
            scores = np.random.normal(loc=est_r[None, :], scale=conf)

        # Step 3: assignment problem; scores are negated to turn the
        # maximisation into the minimisation that scipy solves
        _, col_ind = linear_sum_assignment(-scores)
        arm_set2 = col_ind.tolist()

        # Step 4: add the selected feature differences to the gram matrix and
        # keep them for the next call to update()
        Z_round = context_arms[arm_set1] - context_arms[arm_set2]
        self.V += Z_round.T @ Z_round
        self.Z.extend(Z_round)

        return arm_set1.tolist(), arm_set2

    # Update the model with new feedback
    def update(self, yt_list):
        """Record the feedback of the latest round and refit ``theta`` when due.

        Parameters
        ----------
        yt_list : sequence of float
            One feedback value per pair returned by the last ``select`` call,
            in the same order.
        """
        # Feature differences stored by the latest select() call
        zt_12_list = self.Z[-self.each_round_arms:]

        # Accumulate the whole round first so that a refit never sees a
        # half-updated ZY_sum
        received = 0
        for yt, zt_12 in zip(yt_list, zt_12_list):
            self.ZY_sum += zt_12 * yt
            received += 1
        if received == 0:
            return

        self.samples = len(self.Z)
        if (self.samples / self.each_round_arms) % self.next_update != 0:
            return

        self.next_update += self.learner_update
        Z = np.array(self.Z)
        ZY_sum = self.ZY_sum

        # Negative log-likelihood of the logistic model (logaddexp avoids overflow)
        def negloglik_glm(th):
            return -ZY_sum.dot(th) + np.sum(np.logaddexp(0.0, Z.dot(th)))

        # Gradient of the objective: -sum(z * y) + sum(z * sigmoid(z . th))
        def negloglik_glm_grad(th):
            mu = np.exp(-np.logaddexp(0.0, -Z.dot(th)))  # numerically stable sigmoid
            return -ZY_sum + Z.T.dot(mu)

        # Solving the optimization problem, warm-started at the current theta
        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04}
        )
        theta = np.array(res['x']).flatten()

        # Keep the previous estimate if the solution cannot be normalised
        norm = np.linalg.norm(theta)
        if np.isfinite(norm) and norm > 0:
            self.theta = theta / norm

    # reset the model
    def reset(self):
        """Restore the learner to the state it had right after construction."""
        self.samples = 0
        self.Z = []
        self.ZY_sum = np.zeros(self.input_dim)
        self.V = self.lamdba * np.identity(self.input_dim)
        self.theta = np.ones(self.input_dim) / self.input_dim
        self.next_update = self._initial_next_update


#  ################## Neural Dueling Bandit Algorithms ##################
# PyTorch keyword arguments passed to the `.to(**tkwargs)` calls in this file.
# The first CUDA GPU is used when one is available; otherwise everything runs
# on the CPU. Other valid device strings: "cpu", "mps" (Apple-silicon GPUs).
tkwargs = {
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "dtype": torch.float32,
}

# Base Neural Network class
class Network(nn.Module):
    """Fully connected ReLU network that maps an input vector to one scalar.

    The stack is ``Linear(input_dim, hidden_size)``, then ``depth - 1``
    ``Linear(hidden_size, hidden_size)`` layers, then ``Linear(hidden_size, 1)``.
    ReLU follows every layer except the last one.

    Parameters
    ----------
    input_dim : int
        Size of the input feature vector.
    hidden_size : int
        Width of the hidden layers.
    depth : int
        Number of hidden layers.
    init_params : sequence of tensors, optional
        Explicit initial values laid out as ``[w0, b0, w1, b1, ...]``.  When
        omitted, every weight and bias is drawn from a standard normal.
    """

    def __init__(self, input_dim, hidden_size=32, depth=2, init_params=None):
        super().__init__()

        # Non-linearity shared by all hidden layers
        self.activate = nn.ReLU()

        # Layer widths from input to the scalar output
        widths = [input_dim, hidden_size] + [hidden_size] * (depth - 1) + [1]
        self.layer_list = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        # Same NN initialization to maintain consistency across all runs
        if init_params is None:
            # Standard-normal weights and biases
            for layer in self.layer_list:
                nn.init.normal_(layer.weight, mean=0, std=1.0)
                nn.init.normal_(layer.bias, mean=0, std=1.0)
        else:
            # Caller-provided values: weight at even slots, bias at odd slots
            for idx, layer in enumerate(self.layer_list):
                layer.weight.data = init_params[2 * idx]
                layer.bias.data = init_params[2 * idx + 1]

    def forward(self, x):
        """Return the network output for ``x`` (last dimension has size 1)."""
        *hidden_layers, output_layer = self.layer_list

        out = x
        for layer in hidden_layers:
            out = self.activate(layer(out))

        return output_layer(out)



# ### Neural Dueling Bandit with diagonalization ###

class NeuralInitDB:
    """Neural dueling bandit whose exploration features come from the initial network.

    The reward estimate is given by ``self.func``, which is retrained from its
    initial weights on all stored preference feedback.  The exploration term
    uses the gradient of the *initial* network (``self.init_func``) with
    respect to its parameters as a fixed feature map, together with the
    confidence matrix ``self.V`` (a full matrix, or only its diagonal when
    ``diagonalize`` is true).

    Parameters
    ----------
    input_dim : int
        Dimension of every arm's feature vector.
    each_round_arms : int
        Number of arms in each super-arm (``k``).
    sample_superarm : int
        Number of sampled super-arms (stored, not used by the current methods).
    lamdba : float
        Regularisation weight: initial scale of ``V`` and weight-decay factor.
    nu : float
        Scale of the exploration term.
    strategy : str
        ``'ucb'`` adds the exploration term to the estimate; any other value
        samples the score from a normal distribution (Thompson sampling).
    reward_function : str
        Reward function of a super-arm (stored, not used by the current methods).
    diagonalize : bool
        Keep only the diagonal of the confidence matrix.
    learner_update : int
        The network is retrained when the number of stored pairs is a
        multiple of ``next_update``, which starts at ``max(1, learner_update)``.
    increasing_delay : bool
        If true, ``next_update`` grows by ``learner_update`` after every retraining.
    """

    def __init__(self, input_dim, each_round_arms, sample_superarm = 10, lamdba=1, nu=1, strategy='ucb', reward_function='add', diagonalize=False, learner_update=20, increasing_delay=False):
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.each_round_arms = each_round_arms  # size of each super-arm
        self.sample_superarm = sample_superarm  # samples of superarm
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy
        self.reward_function= reward_function   # reward function of super-arm
        self.diagonalize    = diagonalize       # diagonalization of confidence matrix if true

        # Increment applied to the retraining period after each retraining (0 = fixed period)
        self.learner_update = learner_update if increasing_delay else 0

        # Remember the starting retraining period so that reset() can restore it
        self._initial_next_update   = max(1, learner_update)

        # Initial variables for storing information
        self.next_update            = self._initial_next_update   # Next update of the learner
        self.samples                = 0       # number of samples
        self.context_actions_list   = None    # context-action pairs, shape (2, samples, input_dim)
        self.feedback_list          = None    # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.input_dim).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Initial NN model for feature extraction
        self.init_func = deepcopy(self.func)

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Confidence matrix
        self.V = self._initial_V()

    def _initial_V(self):
        """Return a fresh confidence matrix: a vector if diagonalized, else a full matrix."""
        if self.diagonalize:
            return self.lamdba * torch.ones((self.total_param,))
        return self.lamdba * torch.diag(torch.ones((self.total_param,)))

    # Selecting two super-arms by solving an assignment problem
    def select(self, context_actions):
        """Choose the two super-arms to duel in this round.

        Parameters
        ----------
        context_actions : numpy.ndarray
            Array with one row of features per available arm.

        Returns
        -------
        tuple[list[int], list[int]]
            Row indices of the first (greedy top-k) and of the second
            (exploratory) super-arm; position ``i`` of both lists forms the
            i-th duel.
        """
        # Keeping context-arms for model update
        self.context_arms = context_actions
        contexts = torch.from_numpy(context_actions).float().to(**tkwargs)

        # Feature of an arm = gradient of the initial network's output w.r.t. its parameters
        features = []
        for x in contexts:
            self.init_func.zero_grad()
            self.init_func(x).backward()
            features.append(torch.cat([p.grad.flatten().detach() for p in self.init_func.parameters()]))
        # All confidence computations run on the device that holds V
        features = torch.stack(features).to(self.V.device)

        # First super-arm: top-k arms under the current reward estimate
        with torch.no_grad():
            est_rt = self.func(contexts).view(-1)
        _, topk_indices = torch.topk(est_rt, k=self.each_round_arms)
        arm_set1 = topk_indices.tolist()
        est = est_rt.cpu().numpy().astype(np.float64)

        k = self.each_round_arms
        n = len(est)
        cost_matrix = np.zeros((k, n))

        # The inverse does not depend on the pair being scored: compute it once
        V_inv = None if self.diagonalize else torch.linalg.inv(self.V)

        for i, a1 in enumerate(arm_set1):
            z = features - features[a1]   # row a2 holds feature(a2) - feature(a1)
            if self.diagonalize:
                quad = torch.clamp(self.nu * z * z / self.V, min=0).sum(dim=1)
            else:
                quad = torch.clamp(self.nu * (torch.matmul(z, V_inv) * z).sum(dim=1), min=0)
            sigma = torch.sqrt(quad).cpu().numpy().astype(np.float64)

            if self.strategy == 'ucb':
                scores = est + sigma
            else:
                # Every strategy other than 'ucb' is treated as Thompson sampling
                scores = np.random.normal(loc=est, scale=sigma)
            cost_matrix[i] = -scores  # minimize cost = maximize score

        _, col_ind = linear_sum_assignment(cost_matrix)
        arm_set2 = col_ind.tolist()

        # Update the confidence matrix with the selected feature differences
        for i1, i2 in zip(arm_set1, arm_set2):
            z_update = features[i1] - features[i2]
            if self.diagonalize:
                # V stores only the diagonal of the outer product
                self.V += z_update * z_update
            else:
                self.V += torch.outer(z_update, z_update)

        # Keeping the selected arms for model update
        self.arm_set1 = arm_set1
        self.arm_set2 = arm_set2
        return arm_set1, arm_set2

    # Update the model with new context-action pair and feedback
    def update(self, yt_list, local_training_iter=50):
        """Store the feedback of the latest round and retrain the network when due.

        Parameters
        ----------
        yt_list : sequence of float
            One feedback value per pair returned by the last ``select`` call,
            in the same order.
        local_training_iter : int
            Number of gradient steps taken when the network is retrained.
        """
        # Converting the selected pairs and their feedback to tensors
        new_pairs = []
        new_feedback = []
        for at_1, at_2, yt in zip(self.arm_set1, self.arm_set2, yt_list):
            xt_1 = torch.from_numpy(self.context_arms[at_1]).reshape(1, 1, -1).to(**tkwargs)
            xt_2 = torch.from_numpy(self.context_arms[at_2]).reshape(1, 1, -1).to(**tkwargs)
            new_pairs.append(torch.cat([xt_1, xt_2]))             # shape (2, 1, input_dim)
            new_feedback.append(torch.tensor([yt]).to(**tkwargs))

        # Appending them to the stored history (pairs are stacked along dim 1)
        if new_pairs:
            pairs = torch.cat(new_pairs, dim=1)
            feedback = torch.cat(new_feedback)
            if self.context_actions_list is None:
                self.context_actions_list = pairs
                self.feedback_list = feedback
            else:
                self.context_actions_list = torch.cat((self.context_actions_list, pairs), dim=1)
                self.feedback_list = torch.cat([self.feedback_list, feedback])

        # Nothing has been observed yet
        if self.context_actions_list is None:
            return

        self.samples = self.context_actions_list.shape[1]
        if self.samples % self.next_update != 0:
            # No retraining this round: keep the weights learned so far
            return
        self.next_update += self.learner_update

        # Retraining always starts from the same initial state of the NN model
        if self.init_state_dict is not None:
            self.func.load_state_dict(deepcopy(self.init_state_dict))

        optimizer = torch.optim.Adam(self.func.parameters(), lr=1e-2, weight_decay=self.lamdba / (self.samples + 50))
        self.func.train()

        x_1 = self.context_actions_list[0].reshape(self.samples, -1)
        x_2 = self.context_actions_list[1].reshape(self.samples, -1)
        targets = self.feedback_list.reshape(-1).to(dtype=torch.float32)
        for _ in range(local_training_iter):
            optimizer.zero_grad()
            logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
            loss = F.binary_cross_entropy_with_logits(logits, targets)
            loss.backward()
            optimizer.step()

    # Reset the learner
    def reset(self):
        """Restore the learner to the state it had right after construction."""
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.next_update = self._initial_next_update
        self.context_actions_list = None
        self.feedback_list = None
        self.V = self._initial_V()



# ### Neural Dueling Bandit using updated function for features ###
class NeuralDB:
    """Neural dueling bandit whose exploration features come from the current network.

    The reward estimate is given by ``self.func``, which is retrained from its
    initial weights on all stored preference feedback.  Unlike a fixed
    feature map, the exploration features here are per-input gradients of the
    *current* network (computed with BackPACK's ``BatchGrad``), and the
    confidence matrix ``self.V`` is rebuilt from the stored history at the
    start of every ``select`` call.

    Parameters
    ----------
    input_dim : int
        Dimension of every arm's feature vector.
    each_round_arms : int
        Number of arms in each super-arm (``k``).
    sample_superarm : int
        Number of sampled super-arms (stored, not used by the current methods).
    lamdba : float
        Regularisation weight: initial scale of ``V`` and weight-decay factor.
    nu : float
        Scale of the exploration term.
    strategy : {'ucb', 'ts'}
        Exploration strategy used to score candidates for the second super-arm.
    reward_function : str
        Reward function of a super-arm.  Only ``'add'`` is implemented.
    diagonalize : bool
        Keep only the diagonal of the confidence matrix.
    learner_update : int
        The network is retrained when the number of stored pairs is a
        multiple of ``next_update``, which starts at ``max(1, learner_update)``.
    increasing_delay : bool
        If true, ``next_update`` grows by ``learner_update`` after every retraining.
    hidden_size : int
        Width of the hidden layers of the network.
    local_training_iter : int
        Number of gradient steps taken when the network is retrained.
    """

    def __init__(self, input_dim, each_round_arms, sample_superarm = 10, lamdba=1, nu=1, strategy='ucb', reward_function='add', diagonalize=False, learner_update=20, increasing_delay=False, hidden_size = 50, local_training_iter=50):
        # Initial parameters
        self.input_dim      = input_dim         # dimension of input
        self.each_round_arms = each_round_arms  # size of each super-arm
        self.sample_superarm = sample_superarm  # sample number
        self.lamdba         = lamdba            # regularization parameter
        self.nu             = nu                # confidence parameter
        self.strategy       = strategy          # arm-selection strategy
        self.reward_function= reward_function   # reward function of super-arm
        self.diagonalize    = diagonalize       # diagonalization of confidence matrix if true

        # Increment applied to the retraining period after each retraining (0 = fixed period)
        self.learner_update = learner_update if increasing_delay else 0
        self.local_training_iter = local_training_iter

        # Remember the starting retraining period so that reset() can restore it
        self._initial_next_update   = max(1, learner_update)

        # Initial variables for storing information
        self.next_update            = self._initial_next_update   # Next update of the learner
        self.samples                = 0       # number of samples
        self.context_actions_list   = None    # context-action pairs, shape (2, samples, input_dim)
        self.feedback_list          = None    # feedback for context-action pairs

        # Initializing neural network model with pytorch
        self.func = extend(Network(self.input_dim, hidden_size=hidden_size).to(**tkwargs))

        # Storing the initial state of the NN
        self.init_state_dict = deepcopy(self.func.state_dict())

        # Initial NN model for feature extraction
        self.init_func = deepcopy(self.func)

        # Total number of trainable parameters of NN as it will be used as input feature dimension
        self.total_param = sum(p.numel() for p in self.func.parameters() if p.requires_grad)

        # Confidence matrix
        self.V = self._initial_V()

    def _initial_V(self):
        """Return a fresh confidence matrix: a vector if diagonalized, else a full matrix."""
        if self.diagonalize:
            return self.lamdba * torch.ones((self.total_param,))
        return self.lamdba * torch.diag(torch.ones((self.total_param,)))

    def _batch_gradients(self, inputs, batch=500):
        """Return BackPACK's per-entry gradients of the current network, one row per leading entry of ``inputs``."""
        chunks = []
        # Slicing already covers a shorter final chunk, so no special case is needed
        for start in range(0, len(inputs), batch):
            sum_mu = torch.sum(self.func(inputs[start:start + batch]))
            with backpack(BatchGrad()):
                sum_mu.backward()
            g_list_ = torch.cat([p.grad_batch.flatten(start_dim=1).detach() for p in self.func.parameters()], dim=1)
            chunks.append(g_list_.cpu())
        return torch.vstack(chunks)

    # Selecting two super-arms by solving an assignment problem
    def select(self, context_actions):
        """Choose the two super-arms to duel in this round.

        Parameters
        ----------
        context_actions : numpy.ndarray
            Array with one row of features per available arm.

        Returns
        -------
        tuple[list[int], list[int]]
            Row indices of the first (greedy top-k) and of the second
            (exploratory) super-arm; position ``i`` of both lists forms the
            i-th duel.

        Raises
        ------
        ValueError
            If ``reward_function`` is not ``'add'``.
        RuntimeError
            If ``strategy`` is neither ``'ucb'`` nor ``'ts'``.
        """
        if self.reward_function != 'add':
            raise ValueError(
                f"Unsupported reward function {self.reward_function!r}; only 'add' is implemented"
            )
        if self.strategy not in ('ucb', 'ts'):
            raise RuntimeError('Exploration strategy not set')

        # Keeping context-arms for model update
        self.context_arms = context_actions
        self.func.train()

        # Rebuild the confidence matrix from the stored history
        self.V = self._initial_V()
        if self.context_actions_list is not None:
            # update() stores the history with shape (2, samples, input_dim), so
            # the leading axis handed to BatchGrad is the slot within the pair.
            # NOTE(review): confirm that these per-slot gradients are the
            # intended features for V (rather than per-pair differences).
            history = self.context_actions_list.to(**tkwargs)
            hist_grads = self._batch_gradients(history)
            if self.diagonalize:
                self.V = self.V + (hist_grads * hist_grads).sum(dim=0)
            else:
                self.V = self.V + hist_grads.transpose(0, 1).matmul(hist_grads)

        # Feature vectors (network gradients) of the arms of the current round
        contexts = torch.from_numpy(self.context_arms).float().to(**tkwargs)
        features = self._batch_gradients(contexts).to(self.V.device)

        # First super-arm: top-k arms under the current reward estimate
        with torch.no_grad():
            est_rt = self.func(contexts).view(-1)
        _, topk_indices = torch.topk(est_rt, k=self.each_round_arms)
        arm_set1 = topk_indices.tolist()
        est = est_rt.detach().cpu()

        cost_matrix = np.zeros((self.each_round_arms, len(est)), dtype=np.float32)

        # The inverse does not depend on the pair being scored: compute it once
        V_inv = None if self.diagonalize else torch.inverse(self.V)

        for i, arm1 in enumerate(arm_set1):
            zt = features - features[arm1]   # row arm2 holds feature(arm2) - feature(arm1)
            if self.diagonalize:
                conf_term = (zt * zt / self.V).sum(dim=1)
            else:
                conf_term = (torch.matmul(zt, V_inv) * zt).sum(dim=1)
            # Clamp before the square root so round-off cannot produce NaN
            sigma = self.nu * torch.sqrt(torch.clamp(conf_term, min=0))

            if self.strategy == 'ts':
                action_score = torch.normal(est, sigma)
            else:  # 'ucb'
                action_score = est + sigma

            cost_matrix[i] = -action_score.numpy()  # minimize cost = maximize score

        # Solve the assignment problem for the best matching
        _, col_ind = linear_sum_assignment(cost_matrix)
        arm_set2 = col_ind.tolist()

        # Update the confidence matrix with the selected feature differences
        for i1, i2 in zip(arm_set1, arm_set2):
            zt_12 = features[int(i1)] - features[int(i2)]
            if self.diagonalize:
                # V stores only the diagonal of the outer product
                self.V += zt_12 * zt_12
            else:
                self.V += torch.outer(zt_12, zt_12)

        # Keeping the selected arms for model update
        self.arm_set1 = arm_set1
        self.arm_set2 = arm_set2

        return self.arm_set1, self.arm_set2

    # Update the model with new context-action pair and feedback
    def update(self, yt_list):
        """Store the feedback of the latest round and retrain the network when due.

        Parameters
        ----------
        yt_list : sequence of float
            One feedback value per pair returned by the last ``select`` call,
            in the same order.
        """
        # Converting the selected pairs and their feedback to tensors
        new_pairs = []
        new_feedback = []
        for at_1, at_2, yt in zip(self.arm_set1, self.arm_set2, yt_list):
            xt_1 = torch.from_numpy(self.context_arms[at_1]).reshape(1, 1, -1).to(**tkwargs)
            xt_2 = torch.from_numpy(self.context_arms[at_2]).reshape(1, 1, -1).to(**tkwargs)
            new_pairs.append(torch.cat([xt_1, xt_2]))             # shape (2, 1, input_dim)
            new_feedback.append(torch.tensor([yt]).to(**tkwargs))

        # Appending them to the stored history (pairs are stacked along dim 1)
        if new_pairs:
            pairs = torch.cat(new_pairs, dim=1)
            feedback = torch.cat(new_feedback)
            if self.context_actions_list is None:
                self.context_actions_list = pairs
                self.feedback_list = feedback
            else:
                self.context_actions_list = torch.cat((self.context_actions_list, pairs), dim=1)
                self.feedback_list = torch.cat([self.feedback_list, feedback])

        # Nothing has been observed yet
        if self.context_actions_list is None:
            return

        self.samples = self.context_actions_list.shape[1]
        if self.samples % self.next_update != 0:
            # No retraining this round: keep the weights learned so far
            return
        self.next_update += self.learner_update

        # Retraining always starts from the same initial state of the NN model
        if self.init_state_dict is not None:
            self.func.load_state_dict(deepcopy(self.init_state_dict))

        optimizer = torch.optim.Adam(self.func.parameters(), lr=1e-1, weight_decay=self.lamdba / (self.samples + 50))
        self.func.train()

        x_1 = self.context_actions_list[0].reshape(self.samples, -1)
        x_2 = self.context_actions_list[1].reshape(self.samples, -1)
        targets = self.feedback_list.reshape(-1).to(dtype=torch.float32)
        for _ in range(self.local_training_iter):
            optimizer.zero_grad()
            logits = (self.func(x_1) - self.func(x_2)).reshape(-1)    # Logits as difference of scores
            loss = F.binary_cross_entropy_with_logits(logits, targets)
            loss.backward()
            optimizer.step()

    # Reset the learner
    def reset(self):
        """Restore the learner to the state it had right after construction."""
        self.func.load_state_dict(deepcopy(self.init_state_dict))
        self.samples = 0
        self.next_update = self._initial_next_update
        self.context_actions_list = None
        self.feedback_list = None
        self.V = self._initial_V()



class RandomSearch:
    """Baseline learner that draws both super-arms uniformly at random.

    It exposes the same ``select`` / ``update`` / ``reset`` methods as the
    other learners in this file but never learns from feedback.
    """

    def __init__(self, input_dim, each_round_arms, reward_function='add',
                 increasing_delay=False):
        # Problem description (increasing_delay is accepted but not used)
        self.input_dim = input_dim                # dimension of input
        self.each_round_arms = each_round_arms    # size of each super-arm
        self.reward_function = reward_function    # reward function of super-arm

        # Norm of learner's parameter
        self.S = np.sqrt(input_dim)

    # Selecting a pair of super-arms
    def select(self, context_arms):
        """Return two independently sampled index lists of ``each_round_arms`` distinct arms."""
        self.context_arms = context_arms
        candidates = range(len(context_arms))
        size = self.each_round_arms

        # Two separate draws: an arm may appear in both super-arms
        first_set, second_set = (random.sample(candidates, size) for _ in range(2))
        return first_set, second_set

    def update(self, yt_list):
        """Ignore the feedback; random search has nothing to learn."""
        return None

    def reset(self):
        """Nothing to reset; the learner keeps no learned state."""
        return None

