# !/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np
from scipy.optimize import minimize


#  ################## Baselines: Active Dueling Bandit Algorithms ##################
# Active Dueling Bandit Algorithm: Random 
class Random:
    """Active dueling-bandit baseline that picks the context and arm pair uniformly at random.

    Feedback is modelled with a logistic GLM on feature differences
    ``z = x_1 - x_2`` (``yt`` is presumably 1 when the first arm is
    preferred -- confirm with the caller). ``theta`` is refit by maximum
    likelihood on all stored ``(z, yt)`` pairs whenever ``samples + 1`` is a
    multiple of ``next_update``, and is then rescaled to unit norm.

    Args:
        context_arms: Array of context-arm features indexed as
            ``context_arms[c, a]``; it is reshaped to ``(C * A, -1)``, so a
            shape of ``(C, A, dim)`` is expected.
        dim: Feature dimension.
        lamdba: Regularization used to initialise the Gram matrix ``V``.
        nu: Confidence parameter (stored, unused by this baseline).
        strategy: Arm-selection strategy label (stored, unused by this baseline).
        learner_update: Initial refit period, in samples.
        increasing_delay: If True the refit period grows by ``learner_update``
            after each refit; otherwise it stays fixed.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='none', learner_update=20, increasing_delay=False):
        # Problem description and hyper-parameters
        self.context_arms = context_arms            # features of every context-arm pair
        self.dim = dim                              # feature dimension
        self.lamdba = lamdba                        # regularization parameter
        self.nu = nu                                # confidence parameter
        self.strategy = strategy                    # arm-selection strategy
        self.C = len(context_arms)                  # number of contexts
        self.A = len(context_arms[0])               # number of arms per context

        # Amount the refit period grows after each refit (0 => fixed period)
        self.learner_update = learner_update if increasing_delay else 0

        # Sampling / learning state
        self.xt_1 = None                            # first arm's features in the last query
        self.xt_2 = None                            # second arm's features in the last query
        self.next_update = max(1, learner_update)   # current refit period
        self._initial_next_update = self.next_update  # restored by reset()
        self.samples = 0                            # number of observed duels
        self.Z = []                                 # feature differences x_1 - x_2 of observed duels
        self.ZY_sum = np.zeros(self.dim)            # running sum of z * yt

        # Regularized Gram matrix and its inverse
        self.V = lamdba * np.identity(self.dim)
        self.V_inv = np.linalg.inv(self.V)

        # Model parameters (uniform start)
        self.theta = np.ones(self.dim) / self.dim

        # One row per context-arm pair, in (context, arm) row-major order
        self.flatten_context_arms = self.context_arms.reshape(self.C * self.A, -1)

    def select(self):
        """Pick a context and two distinct arms uniformly at random.

        Returns:
            ``(xt_ind, xt_1, xt_2)``: the context index and the feature
            vectors of the two chosen arms (also stored on ``self``).
        """
        xt_ind = np.random.choice(np.arange(self.C))
        at_1 = np.random.choice(np.arange(self.A))
        # Second arm is drawn from the remaining arms so the duel is non-trivial
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))

        self.xt_1 = self.context_arms[xt_ind, at_1]
        self.xt_2 = self.context_arms[xt_ind, at_2]
        return xt_ind, self.xt_1, self.xt_2

    def update(self, yt):
        """Record feedback ``yt`` for the last selected pair and refit theta on schedule."""
        zt = self.xt_1 - self.xt_2
        self.ZY_sum += zt * yt
        self.Z.append(zt)
        self._update_gram(zt)

        self.samples = len(self.Z)
        if (self.samples + 1) % self.next_update == 0:
            self.next_update += self.learner_update
            self._fit_theta()

    def _update_gram(self, zt):
        """Add ``zt zt^T`` to ``V`` and update ``V_inv`` with Sherman-Morrison (O(d^2))."""
        self.V += np.outer(zt, zt)
        v_inv_z = self.V_inv @ zt
        self.V_inv = self.V_inv - np.outer(v_inv_z, v_inv_z) / (1.0 + zt @ v_inv_z)

    def _fit_theta(self):
        """Refit theta by logistic maximum likelihood (BFGS), then rescale to unit norm."""
        Z = np.asarray(self.Z, dtype=float)
        zy_sum = self.ZY_sum

        def negloglik_glm(th):
            # log(1 + exp(u)) written as logaddexp(0, u) so large margins cannot overflow
            return -zy_sum.dot(th) + np.sum(np.logaddexp(0.0, Z.dot(th)))

        def negloglik_glm_grad(th):
            # sigmoid(u) = exp(-log(1 + exp(-u))), overflow-free; one pass over Z
            mu = np.exp(-np.logaddexp(0.0, -Z.dot(th)))
            return -zy_sum + Z.T.dot(mu)

        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04},
        )
        theta = np.asarray(res.x).ravel()
        norm = np.linalg.norm(theta)
        # A zero solution has no direction; keep the previous estimate instead of producing NaNs
        if norm > 0:
            self.theta = theta / norm

    def get_policy(self):
        """Return, for every context, the arm with the highest estimated reward ``x . theta``."""
        est_rt = self.flatten_context_arms.dot(self.theta).reshape(self.C, self.A, -1)
        return np.argmax(est_rt, axis=1).ravel()

    def reset(self):
        """Restore the learner to its freshly constructed state (data, model and refit schedule)."""
        self.xt_1 = None
        self.xt_2 = None
        self.samples = 0
        self.next_update = self._initial_next_update
        self.Z = []
        self.ZY_sum = np.zeros(self.dim)
        self.V = self.lamdba * np.identity(self.dim)
        self.V_inv = np.linalg.inv(self.V)
        self.theta = np.ones(self.dim) / self.dim


# Active Dueling Bandit Algorithm: AE-Borda (Mehta et al., 2023) 
class AEBorda:
    """Active dueling-bandit learner AE-Borda (Mehta et al., 2023).

    For every context-arm pair ``x`` it keeps confidence bounds
    ``x . theta +/- nu * sqrt(x^T V^{-1} x)``. It queries the context where
    ``max UCB - max LCB`` is largest, duels the arm with the highest UCB
    against a uniformly random other arm, and refits ``theta`` with a
    logistic GLM on feature differences ``z = x_1 - x_2`` whenever
    ``samples + 1`` is a multiple of ``next_update``.

    Args:
        context_arms: Array of context-arm features indexed as
            ``context_arms[c, a]``; it is reshaped to ``(C * A, -1)``, so a
            shape of ``(C, A, dim)`` is expected.
        dim: Feature dimension.
        lamdba: Regularization used to initialise the Gram matrix ``V``.
        nu: Width multiplier of the confidence bounds.
        strategy: Arm-selection strategy label (stored, not read here).
        learner_update: Initial refit period, in samples.
        increasing_delay: If True the refit period grows by ``learner_update``
            after each refit; otherwise it stays fixed.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False):
        # Problem description and hyper-parameters
        self.context_arms = context_arms            # features of every context-arm pair
        self.dim = dim                              # feature dimension
        self.lamdba = lamdba                        # regularization parameter
        self.nu = nu                                # confidence parameter
        self.strategy = strategy                    # arm-selection strategy
        self.C = len(context_arms)                  # number of contexts
        self.A = len(context_arms[0])               # number of arms per context

        # Amount the refit period grows after each refit (0 => fixed period)
        self.learner_update = learner_update if increasing_delay else 0

        # Sampling / learning state
        self.xt_1 = None                            # first arm's features in the last query
        self.xt_2 = None                            # second arm's features in the last query
        self.UCB = None                             # upper confidence bounds, shape (C, A, 1)
        self.LCB = None                             # lower confidence bounds, shape (C, A, 1)
        self.next_update = max(1, learner_update)   # current refit period
        self._initial_next_update = self.next_update  # restored by reset()
        self.samples = 0                            # number of observed duels
        self.Z = []                                 # feature differences x_1 - x_2 of observed duels
        self.ZY_sum = np.zeros(self.dim)            # running sum of z * yt

        # Regularized Gram matrix and its inverse
        self.V = lamdba * np.identity(self.dim)
        self.V_inv = np.linalg.inv(self.V)

        # Model parameters (uniform start)
        self.theta = np.ones(self.dim) / self.dim

        # One row per context-arm pair, in (context, arm) row-major order
        self.flatten_context_arms = self.context_arms.reshape(self.C * self.A, -1)

        # Policy returned by get_policy(): best LCB seen so far and its arm, per context
        self.contx_max_est_rewards = None
        self.contx_max_arm = None

    def _confidence_bounds(self):
        """Return ``(UCB, LCB)``, each shaped ``(C, A, 1)``, under the current theta and V_inv."""
        X = self.flatten_context_arms
        est_rt = X.dot(self.theta)
        # Row-wise quadratic form x^T V^{-1} x for every context-arm pair
        widths_sq = np.einsum('ij,ij->i', X, X @ self.V_inv)
        # Clip round-off negatives so sqrt cannot yield NaN
        conf_term = self.nu * np.sqrt(np.maximum(widths_sq, 0.0))

        est_rt = est_rt.reshape(self.C, self.A, -1)
        conf_term = conf_term.reshape(self.C, self.A, -1)
        return est_rt + conf_term, est_rt - conf_term

    def select(self):
        """Choose the most uncertain context and a pair of arms to duel in it.

        Returns:
            ``(xt_ind, xt_1, xt_2)``: the context index and the feature
            vectors of the two chosen arms (also stored on ``self``).
        """
        self.UCB, self.LCB = self._confidence_bounds()

        # Context whose best-arm value is least certain
        max_diff = np.max(self.UCB, axis=1) - np.max(self.LCB, axis=1)
        xt_ind = np.argmax(max_diff)

        # Optimistic first arm; second arm uniformly among the others
        at_1 = np.argmax(self.UCB[xt_ind])
        at_2 = np.random.choice(np.delete(np.arange(self.A), at_1))

        self.xt_1 = self.context_arms[xt_ind, at_1]
        self.xt_2 = self.context_arms[xt_ind, at_2]
        return xt_ind, self.xt_1, self.xt_2

    def update(self, yt):
        """Record feedback ``yt`` for the last selected pair and refit theta on schedule."""
        zt = self.xt_1 - self.xt_2
        self.ZY_sum += zt * yt
        self.Z.append(zt)
        self._update_gram(zt)

        self.samples = len(self.Z)
        if (self.samples + 1) % self.next_update == 0:
            self.next_update += self.learner_update
            self._fit_theta()

    def _update_gram(self, zt):
        """Add ``zt zt^T`` to ``V`` and update ``V_inv`` with Sherman-Morrison (O(d^2))."""
        self.V += np.outer(zt, zt)
        v_inv_z = self.V_inv @ zt
        self.V_inv = self.V_inv - np.outer(v_inv_z, v_inv_z) / (1.0 + zt @ v_inv_z)

    def _fit_theta(self):
        """Refit theta by logistic maximum likelihood (BFGS), then rescale to unit norm."""
        Z = np.asarray(self.Z, dtype=float)
        zy_sum = self.ZY_sum

        def negloglik_glm(th):
            # log(1 + exp(u)) written as logaddexp(0, u) so large margins cannot overflow
            return -zy_sum.dot(th) + np.sum(np.logaddexp(0.0, Z.dot(th)))

        def negloglik_glm_grad(th):
            # sigmoid(u) = exp(-log(1 + exp(-u))), overflow-free; one pass over Z
            mu = np.exp(-np.logaddexp(0.0, -Z.dot(th)))
            return -zy_sum + Z.T.dot(mu)

        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04},
        )
        theta = np.asarray(res.x).ravel()
        norm = np.linalg.norm(theta)
        # A zero solution has no direction; keep the previous estimate instead of producing NaNs
        if norm > 0:
            self.theta = theta / norm

    def get_policy(self):
        """Return, per context, the arm whose LCB is the highest seen so far across calls.

        The stored arm for a context only changes when the current maximum
        LCB strictly exceeds the best LCB previously recorded for it.
        """
        if self.LCB is None:
            # Called before any select() (or right after reset): use current bounds
            self.UCB, self.LCB = self._confidence_bounds()

        max_rewards = np.max(self.LCB, axis=1)
        max_arm = np.argmax(self.LCB, axis=1)
        if self.contx_max_est_rewards is None:
            self.contx_max_est_rewards = max_rewards
            self.contx_max_arm = max_arm
        else:
            improved = max_rewards > self.contx_max_est_rewards
            self.contx_max_est_rewards[improved] = max_rewards[improved]
            self.contx_max_arm[improved] = max_arm[improved]

        return self.contx_max_arm.ravel()

    def reset(self):
        """Restore the learner to its freshly constructed state (data, model, bounds, policy, schedule)."""
        self.xt_1 = None
        self.xt_2 = None
        self.UCB = None
        self.LCB = None
        self.samples = 0
        self.next_update = self._initial_next_update
        self.Z = []
        self.ZY_sum = np.zeros(self.dim)
        self.V = self.lamdba * np.identity(self.dim)
        self.V_inv = np.linalg.inv(self.V)
        self.theta = np.ones(self.dim) / self.dim
        self.contx_max_est_rewards = None
        self.contx_max_arm = None


# Active Dueling Bandit Algorithm: APO (Das et al., 2024) -- Practical version 
class APO:
    """Active dueling-bandit learner APO (Das et al., 2024), practical version.

    It keeps a pool of all within-context feature differences
    ``x_{c,a} - x_{c,b}``. Each step it duels the pair whose difference has
    the largest ``sqrt(z^T V^{-1} z)``, then zeroes that pool entry so it is
    not picked again. ``theta`` is refit with a logistic GLM on the observed
    differences whenever ``samples + 1`` is a multiple of ``next_update``.

    Args:
        context_arms: Array of context-arm features indexed as
            ``context_arms[c, a]``; it is reshaped to ``(C * A, -1)``, so a
            shape of ``(C, A, dim)`` is expected.
        dim: Feature dimension.
        lamdba: Regularization used to initialise the Gram matrix ``V``.
        nu: Confidence parameter (stored, not read here).
        strategy: Arm-selection strategy label (stored, not read here).
        learner_update: Initial refit period, in samples.
        increasing_delay: If True the refit period grows by ``learner_update``
            after each refit; otherwise it stays fixed.
    """

    def __init__(self, context_arms, dim, lamdba=1, nu=1, strategy='ucb', learner_update=20, increasing_delay=False):
        # Problem description and hyper-parameters
        self.context_arms = context_arms            # features of every context-arm pair
        self.dim = dim                              # feature dimension
        self.lamdba = lamdba                        # regularization parameter
        self.nu = nu                                # confidence parameter
        self.strategy = strategy                    # arm-selection strategy
        self.C = len(context_arms)                  # number of contexts
        self.A = len(context_arms[0])               # number of arms per context

        # Amount the refit period grows after each refit (0 => fixed period)
        self.learner_update = learner_update if increasing_delay else 0

        # Sampling / learning state
        self.xt_1 = None                            # first arm's features in the last query
        self.xt_2 = None                            # second arm's features in the last query
        self.next_update = max(1, learner_update)   # current refit period
        self._initial_next_update = self.next_update  # restored by reset()
        self.samples = 0                            # number of observed duels
        self.Z = []                                 # feature differences x_1 - x_2 of observed duels
        self.ZY_sum = np.zeros(self.dim)            # running sum of z * yt

        # Regularized Gram matrix and its inverse
        self.V = lamdba * np.identity(self.dim)
        self.V_inv = np.linalg.inv(self.V)

        # Model parameters (uniform start)
        self.theta = np.ones(self.dim) / self.dim

        # One row per context-arm pair, in (context, arm) row-major order
        self.flatten_context_arms = self.context_arms.reshape(self.C * self.A, -1)

        # Candidate pool: row c*A*A + a*A + b holds x_{c,a} - x_{c,b}
        self.flatten_context_arms_diff = self._all_pair_differences()

        # Kept for interface parity with AEBorda (not used by this class)
        self.contx_max_arm = None

    def _all_pair_differences(self):
        """Return every within-context difference ``x_{c,a} - x_{c,b}``, rows in (c, a, b) order."""
        arms = np.asarray(self.context_arms)
        # Broadcast to (C, A, A, d): entry [c, a, b] = arms[c, a] - arms[c, b]
        diffs = arms[:, :, None, :] - arms[:, None, :, :]
        return diffs.reshape(self.C * self.A * self.A, -1)

    def select(self):
        """Duel the pair of arms whose feature difference is currently most uncertain.

        Returns:
            ``(xt_ind, xt_1, xt_2)``: the context index and the feature
            vectors of the two chosen arms (also stored on ``self``).
        """
        diffs = self.flatten_context_arms_diff
        # Row-wise quadratic form z^T V^{-1} z; clipped so sqrt cannot yield NaN
        # (argmax would otherwise select a NaN entry)
        widths_sq = np.einsum('ij,ij->i', diffs, diffs @ self.V_inv)
        uncertainty = np.sqrt(np.maximum(widths_sq, 0.0))

        flat_idx = int(np.argmax(uncertainty))
        xt_ind, at_1, at_2 = np.unravel_index(flat_idx, (self.C, self.A, self.A))

        # Zero the chosen difference so it is not selected again.
        # NOTE(review): once every entry is zero, argmax returns row 0, i.e. a
        # duel of arm 0 against itself in context 0 -- confirm this is intended.
        diffs[flat_idx] = 0

        self.xt_1 = self.context_arms[xt_ind, at_1]
        self.xt_2 = self.context_arms[xt_ind, at_2]
        return xt_ind, self.xt_1, self.xt_2

    def update(self, yt):
        """Record feedback ``yt`` for the last selected pair and refit theta on schedule."""
        zt = self.xt_1 - self.xt_2
        self.ZY_sum += zt * yt
        self.Z.append(zt)
        self._update_gram(zt)

        self.samples = len(self.Z)
        if (self.samples + 1) % self.next_update == 0:
            self.next_update += self.learner_update
            self._fit_theta()

    def _update_gram(self, zt):
        """Add ``zt zt^T`` to ``V`` and update ``V_inv`` with Sherman-Morrison (O(d^2))."""
        self.V += np.outer(zt, zt)
        v_inv_z = self.V_inv @ zt
        self.V_inv = self.V_inv - np.outer(v_inv_z, v_inv_z) / (1.0 + zt @ v_inv_z)

    def _fit_theta(self):
        """Refit theta by logistic maximum likelihood (BFGS), then rescale to unit norm."""
        Z = np.asarray(self.Z, dtype=float)
        zy_sum = self.ZY_sum

        def negloglik_glm(th):
            # log(1 + exp(u)) written as logaddexp(0, u) so large margins cannot overflow
            return -zy_sum.dot(th) + np.sum(np.logaddexp(0.0, Z.dot(th)))

        def negloglik_glm_grad(th):
            # sigmoid(u) = exp(-log(1 + exp(-u))), overflow-free; one pass over Z
            mu = np.exp(-np.logaddexp(0.0, -Z.dot(th)))
            return -zy_sum + Z.T.dot(mu)

        res = minimize(
            negloglik_glm,
            self.theta,
            jac=negloglik_glm_grad,
            method='BFGS',
            options={"disp": False, "gtol": 1e-04},
        )
        theta = np.asarray(res.x).ravel()
        norm = np.linalg.norm(theta)
        # A zero solution has no direction; keep the previous estimate instead of producing NaNs
        if norm > 0:
            self.theta = theta / norm

    def get_policy(self):
        """Return, for every context, the arm with the highest estimated reward ``x . theta``."""
        est_rt = self.flatten_context_arms.dot(self.theta).reshape(self.C, self.A, -1)
        return np.argmax(est_rt, axis=1).ravel()

    def reset(self):
        """Restore the learner to its freshly constructed state, including the candidate pool."""
        self.xt_1 = None
        self.xt_2 = None
        self.samples = 0
        self.next_update = self._initial_next_update
        self.Z = []
        self.ZY_sum = np.zeros(self.dim)
        self.V = self.lamdba * np.identity(self.dim)
        self.V_inv = np.linalg.inv(self.V)
        self.theta = np.ones(self.dim) / self.dim
        self.flatten_context_arms_diff = self._all_pair_differences()
        self.contx_max_arm = None

