#!/usr/bin/env python
# coding: utf-8

# Importing python packages
import numpy as np


# COBRA with LOOM
class COBRA:
    """Linear contextual bandit with a leave-one-out mechanism (LOOM).

    The learner maintains a ridge-regression estimate ``theta`` of the reward
    parameter together with the regularized Gram matrix ``V``. Alongside this
    global model it keeps, for every arm ``a``, a leave-one-out model that is
    updated only in rounds where another arm was played.

    Until an arm has been flagged, every call to :meth:`select` runs the LOOM
    pass: for each active arm played more than twice, a lower confidence bound
    on its summed leave-one-out predictions is compared with an upper bound on
    its summed observed feedback. The first arm whose lower bound exceeds the
    upper bound is flagged as strategic, removed from ``active_arms``, and the
    global model is refitted on the data of the remaining arms. At most one
    arm is ever flagged, because the LOOM pass is skipped afterwards.

    Parameters
    ----------
    arms : int
        Number of arms.
    dim : int
        Dimension of the context-arm feature vectors.
    strategy : {'ucb', 'ts', 'greedy'} or None
        Arm-selection strategy. Any other value makes :meth:`select` raise
        ``RuntimeError``.
    lamdba : float
        Ridge regularization parameter (the misspelling is part of the public
        signature and is kept for backward compatibility).
    delta : float
        Confidence parameter.
    max_sigma : float
        Noise parameter used to scale the confidence width.
    """

    def __init__(self, arms, dim, strategy=None, lamdba=0.01, delta=0.05, max_sigma=0.1):
        # Hyper-parameters
        self.arms       = arms              # Number of arms
        self.dim        = dim               # Dimension of context-arms
        self.strategy   = strategy          # arm-selection strategy
        self.lamdba     = lamdba            # regularization parameter
        self.delta      = delta             # confidence parameter
        self.max_sigma  = max_sigma         # noise parameter

        # Fixed terms in confidence bound used in UCB strategy
        self.S          = np.sqrt(dim)                  # Value of S, i.e., max ||\theta||
        self.L          = np.sqrt(dim)                  # Value of L, i.e., max ||x||
        self.alpha_fix  = self.S*np.sqrt(self.lamdba)   # Fix term for alpha
        self.conf_ratio = (self.L*self.L)/self.lamdba   # Ratio term in confidence bound

        # All learned / per-run state is created in one place so that
        # construction and reset() can never drift apart.
        self.reset()

    # Confidence width shared by the global and the leave-one-out models
    def _confidence_width(self, t):
        """Return the confidence multiplier for a model fitted on ``t`` samples."""
        log_cnfterm = self.dim*np.log((1.0 + ((t+1)*self.conf_ratio))/self.delta)
        return self.alpha_fix + (self.max_sigma*log_cnfterm)

    # Scoring + strategy-specific choice on the current global model
    def _choose_arm(self):
        """Pick an arm for ``self.context_arms`` using the current global model.

        Returns
        -------
        tuple
            ``(arm, est_rt, conf_term)`` where ``est_rt`` and ``conf_term`` are
            the per-arm estimated rewards and confidence terms, with the
            entries of strategic arms set to ``-inf``.

        Raises
        ------
        RuntimeError
            If ``self.strategy`` is not one of ``'ucb'``, ``'ts'``, ``'greedy'``.
        """
        contexts = self.context_arms

        # Estimated rewards and the V^{-1}-weighted norm of every context row
        est_rt = contexts.dot(self.theta)
        X_V_inv = contexts @ self.V_inv
        matrix_norm = np.sqrt(np.einsum('ij,ij->i', contexts, X_V_inv))
        conf_term = self._confidence_width(self.t) * matrix_norm

        # Strategic arms get -inf scores so that argmax can never pick them
        for a in self.strategic_arms:
            est_rt[a] = -np.inf
            conf_term[a] = -np.inf

        if self.strategy == 'ucb':
            # Arm with the maximum upper confidence bound
            arm = np.argmax(est_rt + conf_term)

        elif self.strategy == 'ts':
            # Following Linear Contextual TS (ICML 2023) paper approach
            alpha_t = max(self.max_sigma*np.sqrt(9.0*self.dim*np.log((self.t+1)/self.delta)), 0)
            theta_tilde = np.random.multivariate_normal(self.theta, alpha_t*alpha_t*self.V_inv)

            # Sampled rewards, again with strategic arms masked out
            sampled_rewards = contexts.dot(theta_tilde)
            for a in self.strategic_arms:
                sampled_rewards[a] = -np.inf

            arm = np.argmax(sampled_rewards)

        elif self.strategy == 'greedy':
            # Best estimated arm with probability 0.9, otherwise a uniformly random active arm
            arm = np.argmax(est_rt) if np.random.uniform(0, 1) >= 0.1 else np.random.choice(self.active_arms)

        else:
            raise RuntimeError('Exploration strategy not set')

        return arm, est_rt, conf_term

    # LOOM test (read-only: does not modify any state)
    def _find_strategic_arm(self):
        """Return the first active arm that fails the LOOM test, or ``None``."""
        for a in self.active_arms:
            # Only arms that were played more than twice are tested
            if self.t_arm[a] <= 2:
                continue

            arm_contexts = self.arms_contexts[a]

            # Leave-one-out predictions for the contexts this arm was played with
            est_rt_a = arm_contexts.dot(self.theta_loo[a])
            X_V_inv_a = arm_contexts @ self.V_inv_loo[a]
            matrix_norm = np.sqrt(np.einsum('ij,ij->i', arm_contexts, X_V_inv_a))
            conf_term = self._confidence_width(self.t_loo[a]) * matrix_norm

            # Lower bound on the summed predicted rewards
            LCB_x = (est_rt_a - conf_term).sum()

            # Upper bound on the summed observed rewards
            sum_y = self.env_feedback[a].sum()
            sum_y_upper_bound = np.sqrt(2*(self.t_arm[a] + 1)*np.log(1.0/self.delta))
            UCB_y = sum_y + sum_y_upper_bound

            if LCB_x > UCB_y:
                return a

            # NOTE(review): debug trace hard-wired to arm 3; kept so the
            # console output of existing experiments does not change.
            if self.t%1000 == 0 and a == 3:
                print('No strategic arm found yet', self.t, a, LCB_x, UCB_y, self.t_arm[a])

        return None

    # Refit of the global model after an arm has been removed
    def _refit_on_active_arms(self):
        """Rebuild ``V``, ``XY_sum``, ``t``, ``V_inv`` and ``theta`` from the
        stored data of the arms that are still active."""
        self.V = self.lamdba * np.identity(self.dim)
        self.XY_sum = np.zeros(self.dim)
        self.t = 0

        for a in self.active_arms:
            arm_contexts = self.arms_contexts[a]
            if arm_contexts is None:
                continue    # arm was never played, nothing to add
            feedback = self.env_feedback[a].ravel()
            self.V += arm_contexts.T @ arm_contexts
            self.XY_sum += arm_contexts.T @ feedback
            self.t += arm_contexts.shape[0]

        self.V_inv = np.linalg.inv(self.V)
        self.theta = self.V_inv.dot(self.XY_sum)

    # Selecting the arm
    def select(self, context_arms):
        """Choose an arm for the current round.

        Parameters
        ----------
        context_arms : numpy.ndarray
            One feature row per arm; the code indexes rows by arm id and
            multiplies by ``theta`` / ``V_inv`` (so presumably shape
            ``(arms, dim)``).

        Returns
        -------
        tuple
            ``(arm, active_arms)``: the selected arm index and the learner's
            own list of arms not flagged as strategic (the same list object,
            not a copy).

        Raises
        ------
        RuntimeError
            If the strategy is not set, or if an already-flagged strategic
            arm ends up being selected.
        """
        # Keeping context-arms for model update
        self.context_arms = context_arms

        arm, est_rt, conf_term = self._choose_arm()

        # Sanity check: a previously flagged arm must never be chosen
        if arm in self.strategic_arms and self.strategic_arm_found:
            print(est_rt, conf_term)
            print('Selected arm is strategic')
            raise RuntimeError('Selected arm is strategic')

        # ### LOOM: finding the strategic arm (skipped once one was found) ###
        if not self.strategic_arm_found:
            flagged = self._find_strategic_arm()
            if flagged is not None:
                self.strategic_arms.append(flagged)
                self.active_arms.remove(flagged)
                self.strategic_arm_found = True

                # Recompute the model parameters using only the active arms data
                self._refit_on_active_arms()

                # The arm above was chosen before the LOOM pass ran. If it is
                # the arm that was just flagged, choose again on the refitted
                # model so that the returned arm is consistent with the
                # returned active set and its feedback does not leak back
                # into the model.
                if arm == flagged and self.active_arms:
                    arm, _, _ = self._choose_arm()

        # Feature vector for the selected context-arm pair and arm
        self.x_t = self.context_arms[arm]
        self.arm = arm

        return arm, self.active_arms

    # Update the model with new context-arm pair and feedback
    def update(self, yt):
        """Incorporate the feedback ``yt`` observed for the last selected arm.

        Raises
        ------
        RuntimeError
            If called before any call to :meth:`select`.
        """
        if self.x_t is None:
            raise RuntimeError('update() called before select()')

        # Rank-one statistics of this round, shared by all models below
        xy_t = self.x_t*yt
        xx_t = np.outer(self.x_t, self.x_t)

        # Global model
        self.XY_sum += xy_t
        self.V += xx_t
        self.V_inv = np.linalg.pinv(self.V)
        self.t += 1
        self.theta = self.V_inv.dot(self.XY_sum)

        # Per-arm data and LOO models are only needed while LOOM is still running
        if self.strategic_arm_found:
            return

        # Store the sample under the arm that produced it
        if self.arms_contexts[self.arm] is None:
            self.arms_contexts[self.arm] = self.x_t.reshape(1, -1)
            self.env_feedback[self.arm] = np.array(yt).reshape(1, -1)
        else:
            self.arms_contexts[self.arm] = np.vstack([self.arms_contexts[self.arm], self.x_t])
            self.env_feedback[self.arm] = np.vstack([self.env_feedback[self.arm], np.array(yt)])

        self.t_arm[self.arm] += 1

        # Update the leave-one-out model of every other active arm
        for a in self.active_arms:
            if a == self.arm:
                continue
            self.XY_sum_loo[a] += xy_t
            self.V_loo[a] += xx_t
            self.V_inv_loo[a] = np.linalg.inv(self.V_loo[a])
            self.t_loo[a] += 1
            self.theta_loo[a] = self.V_inv_loo[a].dot(self.XY_sum_loo[a])

    # Reset the model
    def reset(self):
        """Restore all learned state to its initial value (hyper-parameters are kept)."""
        # Active arms
        self.active_arms = list(range(self.arms))

        # Initial variables for storing information
        self.context_arms = None                # Contexts of the current round
        self.x_t        = None                  # Selected context-arm pair
        self.arm        = None                  # Selected arm
        self.t          = 0                     # Number of samples
        self.XY_sum     = np.zeros(self.dim)    # Sum of Context-arms feature vectors * Y

        # Initialization of the gram matrix
        self.V      = self.lamdba * np.identity(self.dim)
        self.V_inv  = np.linalg.inv(self.V)

        # Initializing model parameters
        self.theta  = np.ones(self.dim)/self.dim

        # To store the data of each arm (object arrays, one slot per arm)
        self.arms_contexts = np.array([None for _ in range(self.arms)])
        self.env_feedback  = np.array([None for _ in range(self.arms)])

        # Initializing model parameters for LOO
        self.strategic_arms         = []
        self.theta_loo              = np.ones((self.arms, self.dim))/self.dim
        self.XY_sum_loo             = np.zeros((self.arms, self.dim))
        self.V_loo                  = np.array([self.V for _ in range(self.arms)])
        self.V_inv_loo              = np.array([self.V_inv for _ in range(self.arms)])
        self.t_loo                  = np.zeros(self.arms)
        self.t_arm                  = np.zeros(self.arms)
        self.strategic_arm_found    = False


# CobraLoo: leaving the arm out for its own arm selection
class CobraLoo:
    def __init__(self, dim, strategy='ucb',  mechanism='loom', lamdba=0.01, delta=0.05, sigma=0.05):
        # Initialial parameters
        # self.arms     = arms                # Number of arms
        self.dim        = dim                   # Dimension of context-arms
        self.strategy   = strategy              # arm-selection strategy
        self.lamdba     = lamdba                # regularization parameter
        self.delta      = delta                 # confidence parameter
        self.sigma      = sigma                 # noise parameter

