#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np


# Baseline learner: selects an arm uniformly at random
class Random:
    """Baseline learner that ignores the contexts and draws an arm at random.

    The constructor accepts the same arguments as the other learners in this
    file; ``strategy``, ``lamdba``, ``delta`` and ``max_sigma`` are only stored
    and never read by this class.
    """

    def __init__(self, arms, dim, strategy=None, lamdba=0.01, delta=0.05, max_sigma=0.1):
        # Problem size
        self.arms = arms                    # Number of arms
        self.dim = dim                      # Dimension of context-arms

        # Hyper-parameters (stored only)
        self.strategy = strategy            # arm-selection strategy
        self.lamdba = lamdba                # regularization parameter
        self.delta = delta                  # confidence parameter
        self.max_sigma = max_sigma          # noise parameter

        # Every arm index 0..arms-1 is active
        self.active_arms = list(range(self.arms))

    def select(self, context_arms):
        """Return ``(arm, active_arms)`` with ``arm`` drawn uniformly from ``range(self.arms)``.

        ``context_arms`` is accepted for interface compatibility and is not used.
        """
        chosen = np.random.choice(self.arms)
        return chosen, self.active_arms

    def update(self, yt):
        """No-op: there is no model to update with the feedback ``yt``."""
        return None

    def reset(self):
        """No-op: there is no state to reset."""
        return None

# Linear contextual bandit learner: LinUCB and LinTS
class Linear:
    """Linear contextual bandit learner with one shared parameter vector.

    The learner keeps a regularised least-squares estimate
    ``theta = V^{-1} @ sum(x_t * y_t)`` where ``V`` starts at ``lamdba * I``
    and grows by the outer product of every selected context-arm vector.
    The arm is chosen according to ``strategy``: ``'ucb'``, ``'ts'`` or
    ``'greedy'``; any other value makes ``select`` raise ``RuntimeError``.
    """

    def __init__(self, arms, dim, strategy=None, lamdba=0.01, delta=0.05, max_sigma=0.1):
        # Initializing parameters
        self.arms       = arms              # Number of arms
        self.dim        = dim               # Dimension of context-arms
        self.strategy   = strategy          # arm-selection strategy
        self.lamdba     = lamdba            # regularization parameter
        self.delta      = delta             # confidence parameter
        self.max_sigma  = max_sigma         # noise parameter

        # Active arms (never changed by this learner)
        self.active_arms = list(range(self.arms))

        # Norm bounds used in the confidence width
        self.S          = np.sqrt(dim)          # Value of S, i.e., max ||\theta||
        self.L          = np.sqrt(dim)          # Value of L, i.e., max ||x||

        # Fixed terms in confidence bound used in UCB strategy
        self.alpha_fix   = self.S*np.sqrt(self.lamdba)   # Fix term for alpha
        self.conf_ratio  = (self.L*self.L)/self.lamdba   # Ratio term in confidence bound

        # Learned state (shared with reset so both always agree)
        self._init_state()

    def _init_state(self):
        """Set every learned quantity to its initial value."""
        self.x_t        = None                  # Selected context-arm pair
        self.arm        = None                  # Selected arm
        self.t          = 0                     # Number of samples
        self.XY_sum     = np.zeros(self.dim)    # Sum of context-arm feature vectors * Y

        # Gram matrix and its inverse
        self.V      = self.lamdba * np.identity(self.dim)
        self.V_inv  = np.linalg.inv(self.V)

        # Model parameters
        self.theta  = np.ones(self.dim)/self.dim

    # Selecting the arm
    def select(self, context_arms):
        """Choose an arm for the given contexts.

        Args:
            context_arms: array with one feature vector of length ``dim`` per
                row; row ``i`` is the feature vector of arm ``i``.

        Returns:
            ``(arm, active_arms)`` - the selected row index and the list of
            active arms.

        Raises:
            RuntimeError: if ``strategy`` is not ``'ucb'``, ``'ts'`` or ``'greedy'``.
        """
        # Keeping context-arms for model update
        self.context_arms = context_arms

        # Compute the estimated rewards
        est_rt = self.context_arms.dot(self.theta)

        if self.strategy == 'ucb':
            # Row-wise x^T V^{-1} x; only the UCB rule needs this width,
            # so it is not computed for the other strategies.
            X_V_inv = self.context_arms @ self.V_inv
            matrix_norm = np.sqrt(np.einsum('ij,ij->i', self.context_arms, X_V_inv))

            # Compute the confidence term
            log_cnfterm = self.dim*np.log((1.0 + ((self.t+1)*self.conf_ratio))/self.delta)
            alpha_t     = self.alpha_fix + (self.max_sigma*log_cnfterm)

            # Select the arm with maximum upper confidence bound
            arm = np.argmax(est_rt + alpha_t * matrix_norm)

        elif self.strategy == 'ts':
            # Following Linear Contextual TS (ICML 2023) paper approach
            alpha_t = max(self.max_sigma*np.sqrt(9.0*self.dim*np.log((self.t+1)/self.delta)), 0)
            theta_tilde = np.random.multivariate_normal(self.theta, alpha_t*alpha_t*self.V_inv)

            # Select the arm with maximum sample value
            arm = np.argmax(self.context_arms.dot(theta_tilde))

        elif self.strategy == 'greedy':
            # Best estimated reward with probability 0.9, else a uniformly random arm
            arm = np.argmax(est_rt) if np.random.uniform(0, 1) >= 0.1 else np.random.choice(self.arms)

        else:
            raise RuntimeError('Exploration strategy not set')

        # Remember the choice; update() reads x_t
        self.x_t = self.context_arms[arm]
        self.arm = arm

        # Returning selected arm and active arms
        return arm, self.active_arms

    # Update the model with new context-arm pair and feedback
    def update(self, yt):
        """Fold the feedback ``yt`` for the last selected arm into the model.

        Reads ``self.x_t`` set by the preceding ``select`` call.
        """
        # Updating the context-arm pairs and feedback variables
        self.XY_sum += self.x_t*yt
        self.V += np.outer(self.x_t, self.x_t)
        self.V_inv = np.linalg.inv(self.V)
        self.t += 1

        # Update the model parameters
        self.theta = self.V_inv.dot(self.XY_sum)

    # Reset the model
    def reset(self):
        """Reset the learned state, including the remembered selection."""
        self._init_state()


# OptGTM: Kleine et al. (2024)
class OptGTM:
    """Per-arm linear bandit learner with a check that removes a strategic arm.

    Every arm ``a`` owns its own regularised least-squares model
    (``theta[a]``, ``V[a]``, ``V_inv[a]``, ``XY_sum[a]``, ``t[a]``).  ``select``
    scores the active arms with the configured ``strategy`` (``'ucb'``,
    ``'ts'`` or ``'greedy'``) and, for the selected arm, accumulates the lower
    confidence bound of its estimated reward in ``LCB_sum``.  ``update``
    accumulates the observed feedback in ``Y_sum``.  An arm whose accumulated
    lower bounds exceed its accumulated feedback by more than
    ``2*sqrt((t+1)*log(1/delta))`` is moved from ``active_arms`` to
    ``strategic_arms``; at most one arm is ever removed.
    """

    def __init__(self, arms, dim, strategy=None, lamdba=0.01, delta=0.05, max_sigma=0.1):
        # Initializing parameters
        self.arms       = arms              # Number of arms
        self.dim        = dim               # Dimension of context-arms
        self.strategy   = strategy          # arm-selection strategy
        self.lamdba     = lamdba            # regularization parameter
        self.delta      = delta             # confidence parameter
        self.max_sigma  = max_sigma         # noise parameter

        # Norm bounds used in the confidence width
        self.S          = np.sqrt(dim)          # Value of S, i.e., max ||\theta||
        self.L          = np.sqrt(dim)          # Value of L, i.e., max ||x||

        # Fixed terms in confidence bound used in UCB strategy
        self.alpha_fix   = self.S*np.sqrt(self.lamdba)   # Fix term for alpha
        self.conf_ratio  = (self.L*self.L)/self.lamdba   # Ratio term in confidence bound

        # To store the data of each arm (no method of this class writes these)
        self.arms_contexts = np.array([None for _ in range(self.arms)])
        self.env_feedback = np.array([None for _ in range(self.arms)])

        # Learned state (shared with reset so both always agree)
        self._init_state()

    def _init_state(self):
        """Set every learned quantity to its initial value."""
        # Arm bookkeeping
        self.active_arms            = list(range(self.arms))
        self.strategic_arms         = []
        self.found_strategic_arm    = False
        self.strategic_arm_index    = None
        self.x_t                    = None      # Selected context-arm pair
        self.arm                    = None      # Selected arm

        # Per-arm regression state
        self.theta                  = np.ones((self.arms, self.dim))/self.dim
        self.XY_sum                 = np.zeros((self.arms, self.dim))
        self.V_0                    = self.lamdba * np.identity(self.dim)
        self.V                      = np.array([self.V_0 for _ in range(self.arms)])
        self.V_inv                  = np.linalg.inv(self.V)
        self.t                      = np.zeros(self.arms)

        # Variables for GGTM: accumulated lower bounds and accumulated feedback
        self.LCB_sum                = np.zeros(self.arms)
        self.Y_sum                  = np.zeros(self.arms)

    # Selecting the arm
    def select(self, context_arms):
        """Choose an arm among the active arms for the given contexts.

        Args:
            context_arms: array indexed by arm; ``context_arms[a]`` is the
                length-``dim`` feature vector of arm ``a``.

        Returns:
            ``(arm, active_arms)`` - the selected arm and the (possibly just
            shrunk) list of active arms.

        Raises:
            RuntimeError: if ``strategy`` is not ``'ucb'``, ``'ts'`` or
                ``'greedy'``, or if no active arm is left.
        """
        # Keeping context-arms for model update
        self.context_arms = context_arms

        if not self.active_arms:
            raise RuntimeError('No active arms available')

        # Start from the first active arm with score -inf so that negative
        # scores still pick an active arm instead of defaulting to arm 0.
        arm = self.active_arms[0]
        max_score = -np.inf

        # Lower confidence bound of every scored arm, so that the bound of
        # the arm that is finally selected can be accumulated below.
        lower_bounds = np.zeros(self.arms)

        for a in self.active_arms:
            x_a = self.context_arms[a]

            # Compute the estimated reward
            est_rt = x_a.dot(self.theta[a])

            # Compute the matrix norm sqrt(x^T V^{-1} x)
            matrix_norm = np.sqrt(x_a @ self.V_inv[a] @ x_a)

            # Compute the confidence term
            log_cnfterm = self.dim*np.log((1.0 + ((self.t[a]+1)*self.conf_ratio))/self.delta)
            alpha_t     = self.alpha_fix + (self.max_sigma*log_cnfterm)
            conf_term   = alpha_t * matrix_norm
            lower_bounds[a] = est_rt - conf_term

            # Score the arm based on the strategy
            if self.strategy == 'ucb':
                # Upper confidence bound
                arm_score = est_rt + conf_term

            elif self.strategy == 'ts':
                # Following Linear Contextual TS (ICML 2023) paper approach
                ts_alpha = max(self.max_sigma*np.sqrt(9.0*self.dim*np.log((self.t[a]+1)/self.delta)), 0)
                theta_tilde = np.random.multivariate_normal(self.theta[a], ts_alpha*ts_alpha*self.V_inv[a])
                arm_score = x_a.dot(theta_tilde)

            elif self.strategy == 'greedy':
                # Estimated reward only
                arm_score = est_rt

            else:
                raise RuntimeError('Exploration strategy not set')

            # Strict comparison keeps the first arm on ties
            if arm_score > max_score:
                max_score = arm_score
                arm = a

        # Greedy explores a uniformly random active arm with probability 0.1
        if self.strategy == 'greedy':
            arm = arm if np.random.uniform(0, 1) >= 0.1 else np.random.choice(self.active_arms)

        # GGTM: finding the strategic arm (at most one is ever removed).
        # NOTE(review): the check runs after selection, so the arm returned in
        # this call can be the arm removed here - confirm this is intended.
        if not self.found_strategic_arm:
            # Iterate over a copy because the list is modified in place
            for a in list(self.active_arms):
                if self.t[a] > 2:
                    # Calculate upper bound on observed rewards
                    sum_y_upper_bound = 2*np.sqrt((self.t[a] + 1)*np.log(1.0/self.delta))

                    # Check if arm is strategic
                    if self.LCB_sum[a] > self.Y_sum[a] + sum_y_upper_bound:
                        self.strategic_arms.append(a)
                        self.active_arms.remove(a)
                        self.found_strategic_arm = True
                        break

        # Feature vector for the selected context-arm pair and arm
        self.x_t = self.context_arms[arm]
        self.arm = arm

        # Accumulate the lower bound of the selected arm's estimated reward
        self.LCB_sum[arm] += lower_bounds[arm]

        return arm, self.active_arms

    # Update the model with new context-arm pair and feedback
    def update(self, yt):
        """Fold the feedback ``yt`` into the model of the last selected arm.

        Reads ``self.x_t`` and ``self.arm`` set by the preceding ``select`` call.
        """
        # Updating the context-arm pairs and feedback variables
        xy_t = self.x_t*yt
        self.XY_sum[self.arm] += xy_t
        self.V[self.arm] += np.outer(self.x_t, self.x_t)
        self.V_inv[self.arm] = np.linalg.inv(self.V[self.arm])
        self.t[self.arm] += 1

        # Update the model parameters
        self.theta[self.arm] = self.V_inv[self.arm].dot(self.XY_sum[self.arm])

        # Update for GGTM
        self.Y_sum[self.arm] += yt

    # Reset the model
    def reset(self):
        """Reset the learned state, including the GGTM sums ``LCB_sum`` and ``Y_sum``."""
        self._init_state()

