import numpy as np
import math
import random
import matplotlib.pyplot as plt


class LinUCB:
    """Linear bandit with one ridge-regression model shared by all arms.

    Arms differ only through their context vectors. Arm selection eliminates
    arms whose upper confidence bound falls below the best lower confidence
    bound, then draws uniformly among the survivors.
    """

    def __init__(self, n_arms, n_features, alpha):
        """
        Initialize the LinUCB algorithm.

        Parameters:
        n_arms: int
            Number of arms (actions).
        n_features: int
            Dimensionality of the context vector.
        alpha: float
            Confidence-width multiplier. Higher values widen the bounds and
            keep more arms in the candidate set.
        """
        self.n_arms = n_arms
        self.n_features = n_features
        self.alpha = alpha

        # Shared ridge statistics: A starts as the d x d identity and b as the
        # d-dimensional zero vector; update()/batch_update() add to both.
        self.A = np.eye(n_features)
        self.b = np.zeros(n_features)

    def select_arm(self, contexts):
        """
        Select an arm by confidence-bound elimination and a uniform draw.

        For each arm, mean = x^T theta_hat and width = alpha * sqrt(x^T A^-1 x).
        An arm survives when mean + width >= max over arms of (mean - width).
        One survivor is then drawn uniformly with Python's ``random`` module.

        Parameters:
        contexts: np.ndarray
            Context vectors for all arms (shape: n_arms x n_features).

        Returns:
        int
            Index of the selected arm.
        """
        # A and b do not depend on the arm, so invert once per call.
        A_inv = np.linalg.inv(self.A)
        theta_hat = A_inv @ self.b  # ridge estimate of the reward coefficients

        ucb = []
        lcb = []
        for arm in range(self.n_arms):
            context = contexts[arm]
            mean = context.T @ theta_hat
            width = self.alpha * np.sqrt(context.T @ A_inv @ context)
            ucb.append(mean + width)
            lcb.append(mean - width)

        best_lcb = np.max(lcb)
        survive = [arm for arm in range(self.n_arms) if ucb[arm] >= best_lcb]

        if not survive:
            # Only reachable with alpha < 0 or NaN scores; fall back to the
            # highest-UCB arm instead of failing on an empty random draw.
            return int(np.argmax(ucb))

        # random.choice consumes the same draw as random.randint(0, len - 1).
        return random.choice(survive)

    def update(self, arm, context, reward):
        """
        Update the model parameters after receiving a reward.

        Parameters:
        arm: int
            Index of the selected arm (unused: the model is shared by all arms).
        context: np.ndarray
            Context vector of the selected arm.
        reward: float
            Observed reward.
        """
        self.A += np.outer(context, context)
        self.b += reward * context

    def batch_update(self, X, y):
        """Add accumulated statistics: ``A += X`` and ``b += y`` (in place)."""
        self.A += X
        self.b += y





class Batch_g:
    """Batched linear bandit with cross-batch arm elimination.

    Stores one ridge model (``A[i]``, ``b[i]``) per call to ``batch_update``,
    plus a list of information matrices ``InfM``. An arm is kept only if no
    stored model eliminates it. Among the survivors, the arm with the largest
    ``x^T M^-1 x`` is chosen, where ``M`` is drawn at random from ``InfM``.
    """

    def __init__(self, n_arms, n_features, alpha):
        self.n_arms = n_arms
        self.n_features = n_features
        self.alpha = alpha

        # Separate identity arrays, so A[0] and InfM[0] are never the same
        # object; an in-place change to one cannot leak into the other.
        self.A = [np.eye(self.n_features)]
        self.b = [np.zeros(n_features)]
        self.InfM = [np.eye(self.n_features)]

    def survived_arms(self, contexts):
        """Return the arms not eliminated by any stored model.

        For model i, arm k is eliminated when
        ``x_k^T theta_i + alpha * ||x_k||`` is below the maximum over arms of
        ``x^T theta_i - 2 * alpha * ||x||``, where ``||x|| = sqrt(x^T A_i^-1 x)``.

        Parameters:
        contexts: np.ndarray
            Context vectors for all arms (shape: n_arms x n_features).

        Returns:
        list
            Surviving arm indices in increasing order, or ``[0]`` if every
            arm was eliminated.
        """
        eliminated = np.zeros(self.n_arms, dtype=bool)

        for A_i, b_i in zip(self.A, self.b):
            # A_i does not depend on the arm, so invert once per model.
            A_inv = np.linalg.inv(A_i)
            theta_hat = A_inv @ b_i

            ucb = []
            lcb = []
            for arm in range(self.n_arms):
                context = contexts[arm]
                mean = context.T @ theta_hat
                norm = np.sqrt(context.T @ A_inv @ context)
                ucb.append(mean + self.alpha * norm)
                # NOTE(review): the lower bound uses 2 * alpha but the upper
                # bound uses alpha; confirm the asymmetry is intended.
                lcb.append(mean - 2 * self.alpha * norm)

            best_lcb = np.max(lcb)
            for arm in range(self.n_arms):
                if ucb[arm] < best_lcb:
                    eliminated[arm] = True

        survive = [arm for arm in range(self.n_arms) if not eliminated[arm]]
        if not survive:
            survive.append(0)
        return survive

    def choose_arm(self, contexts):
        """Return the surviving arm with the largest ``x^T M^-1 x``.

        ``M`` is drawn uniformly from ``self.InfM`` with Python's ``random``
        module.

        Parameters:
        contexts: np.ndarray
            Context vectors for all arms (shape: n_arms x n_features).

        Returns:
        int
            Index of the chosen arm.
        """
        survive = self.survived_arms(contexts)

        # random.choice consumes the same draw as random.randint(0, len - 1).
        random_matrix = random.choice(self.InfM)
        M_inv = np.linalg.inv(random_matrix)

        p_list = [contexts[arm].T @ M_inv @ contexts[arm] for arm in survive]

        # argmax is a position within `survive`; map it back to the arm index.
        return survive[int(np.argmax(p_list))]

    def update_Infmatrix(self, InfM):
        """Replace the list of information matrices (stored by reference)."""
        self.InfM = InfM

    def batch_update(self, X, y):
        """Append a new ridge model built from a batch's statistics X, y."""
        self.A.append(X)
        self.b.append(y)
        
        
        
    



def optimal_design(context, n_arms, n_features, n_iters=100):
    """Sample an arm from a greedy variance-maximizing design.

    Starting from Lambda = I, each of ``n_iters`` rounds picks the arm whose
    context x maximizes x^T Lambda^-1 x (ties go to the first index) and adds
    x x^T to Lambda. One recorded pick is then returned uniformly at random
    with Python's ``random`` module. Arms are therefore sampled in proportion
    to how often the greedy procedure chose them.

    Parameters:
    context: np.ndarray
        Context vectors indexed by arm (row k is arm k).
    n_arms: int
        Number of arms to consider (rows 0 .. n_arms - 1).
    n_features: int
        Dimensionality of each context vector.
    n_iters: int
        Number of greedy rounds (default 100).

    Returns:
    int
        Index of the sampled arm.

    Raises:
    ValueError
        If ``n_iters`` is less than 1 (there would be nothing to sample).
    """
    if n_iters < 1:
        raise ValueError("n_iters must be at least 1")

    Lambda = np.eye(n_features)
    Arm_record = []

    for _ in range(n_iters):
        Lam_inv = np.linalg.inv(Lambda)
        scores = [context[j].T @ Lam_inv @ context[j] for j in range(n_arms)]
        best_arm = np.argmax(scores)
        Arm_record.append(best_arm)

        best_context = context[best_arm]
        Lambda = Lambda + np.outer(best_context, best_context)

    # random.choice consumes the same draw as random.randint(0, len - 1).
    return random.choice(Arm_record)




def grid(T, M, d):
    """Return the M batch grid points for horizon T and dimension d.

    With s = 2 * 2**-M, define
    R = T**(1 / (2 - s)) * d**((1 - s) / (2 - s)),
    g_0 = R and g_{k+1} = sqrt(g_k) * R / sqrt(d).
    Unrolling the recursion gives log g_k = (2 - 2**-k) log R - (1 - 2**-k) log d.
    Hence g_{M-1} == T exactly, and the last point is pinned to T so float
    round-off cannot truncate it to T - 1. All points are then truncated
    toward zero to int64.

    Raises:
    ValueError
        If M < 1.
    """
    if M < 1:
        raise ValueError("M must be at least 1")

    shrink = 2 * pow(2, -M)
    rate_1 = 1 / (2 - shrink)
    rate_2 = (1 - shrink) / (2 - shrink)
    R = pow(T, rate_1) * pow(d, rate_2)

    points = np.zeros(M)
    points[0] = R
    for k in range(M - 1):
        points[k + 1] = math.sqrt(points[k]) * R / math.sqrt(d)

    # Mathematically equal to T already; avoid e.g. 499.999... -> 499.
    points[-1] = T

    return np.array(points, dtype=np.int64)


def max_reward(context, theta, n_arms):
    """Return the best expected reward ``context[k] @ theta`` over arms 0..n_arms-1."""
    expected = [context[arm].T @ theta for arm in range(n_arms)]
    return np.max(expected)
        

def choose_j(T, M, t):
    """Map step t (0 <= t < T) of a length-T batch to one of M slots.

    Slots have width B = max(T // M, 1). The slot index t // B is clamped to
    M - 1, so the trailing steps left over when T is not a multiple of M all
    go to the last slot. The width floor of 1 keeps T < M from dividing by zero.

    Returns:
    int
        Slot index in [0, M - 1].
    """
    width = max(int(T // M), 1)
    return min(int(t // width), M - 1)







# parameter configuration

T = 500  # horizon passed to grid()
M = 7    # number of batches
d = 10   # context dimension

# NOTE(review): grid()'s recursion makes its last element equal T, which reads
# like cumulative batch end-points. The loops below instead use each G[i] as a
# batch *length*, so the total number of rounds is sum(G), not T. Confirm intent.
G = grid(T, M, d)

# Simulated example
np.random.seed(42)
# LinUCB.select_arm, Batch_g.choose_arm and optimal_design draw from Python's
# `random` module, which np.random.seed does not affect. Seed it too so runs
# are reproducible.
random.seed(42)

n_arms = 40
n_features = d
alpha = 1

# Create LinUCB instance
linucb = LinUCB(n_arms, n_features, alpha)

n_rounds = T
true_theta = np.random.rand(n_features)  # true reward coefficients, uniform in [0, 1)

batch_gucb = Batch_g(n_arms, n_features, alpha)




# "This Work": batch 0 samples arms from optimal_design. Later batches use
# batch_gucb (elimination followed by information-matrix-based choice).
Regret_list_g = []

acc_reg = 0

for i in range(M):
    contexts = np.random.rand(G[i], n_arms, n_features)
    # Per-batch ridge statistics, handed to batch_gucb at the end of the batch.
    A = np.eye(n_features)
    b = np.zeros(n_features)

    # Snapshots of the growing information matrix. batch_gucb.choose_arm
    # samples one of them during the next batch.
    Inf_matrix = np.eye(n_features)
    InfM = [Inf_matrix]

    for t in range(G[i]):
        current_contexts = contexts[t]
        if i == 0:
            chosen_arm = optimal_design(current_contexts, n_arms, n_features)
        else:
            chosen_arm = batch_gucb.choose_arm(current_contexts)

        reward = current_contexts[chosen_arm].dot(true_theta) + np.random.normal(0, 0.1)

        max_r = max_reward(current_contexts, true_theta, n_arms)

        current_reg = max_r - reward
        acc_reg = acc_reg + current_reg

        Regret_list_g.append(acc_reg)

        arm_context = current_contexts[chosen_arm]

        A = A + np.outer(arm_context, arm_context)
        b = b + reward * arm_context

        # Update the policy for the next batch: grow the information matrix
        # along the candidate arm with the largest x^T Inf_matrix^-1 x.
        if i == 0:
            candidates = list(range(n_arms))
        else:
            candidates = batch_gucb.survived_arms(current_contexts)

        Inf_inv = np.linalg.inv(Inf_matrix)  # same for every candidate
        p_list = [current_contexts[k].T @ Inf_inv @ current_contexts[k] for k in candidates]
        best_index = candidates[np.argmax(p_list)]

        # Rebind rather than update in place, so each InfM entry stays a
        # distinct snapshot. `+=` would mutate the array already in the list.
        Inf_matrix = Inf_matrix + np.outer(current_contexts[best_index], current_contexts[best_index])

        InfM.append(Inf_matrix)

    batch_gucb.batch_update(A, b)

    batch_gucb.update_Infmatrix(InfM)

            
            
    











# BatchLinUCB baseline: linucb is frozen within a batch and absorbs the batch's
# accumulated statistics at its end.
Regret_list = []

acc_reg = 0

for i in range(M):
    contexts = np.random.rand(G[i], n_arms, n_features)
    # Start the per-batch increments at zero. linucb.A already holds the
    # identity prior from its constructor; starting from np.eye here would add
    # another identity to it after every batch.
    A = np.zeros((n_features, n_features))
    b = np.zeros(n_features)

    for t in range(G[i]):
        current_contexts = contexts[t]
        chosen_arm = linucb.select_arm(current_contexts)
        reward = current_contexts[chosen_arm].dot(true_theta) + np.random.normal(0, 0.1)

        max_r = max_reward(current_contexts, true_theta, n_arms)

        current_reg = max_r - reward
        acc_reg = acc_reg + current_reg

        Regret_list.append(acc_reg)

        arm_context = current_contexts[chosen_arm]

        A = A + np.outer(arm_context, arm_context)
        b = b + reward * arm_context

    linucb.batch_update(A, b)

    
# BatchPureExp baseline: M separate LinUCB learners. During batch `batch`, only
# split[batch] selects arms. Each round's data is routed to slot
# choose_j(...) by its position in the batch, and every learner absorbs its
# slot's statistics at the end of the batch.
# NOTE(review): the acting learner (split[batch]) and the learner receiving the
# data (split[slot]) generally differ; confirm this routing is intended.
split = [LinUCB(n_arms, n_features, alpha) for _ in range(M)]
A = [np.zeros((n_features, n_features)) for _ in range(M)]
b = [np.zeros(n_features) for _ in range(M)]

Regret_sp_list = []

acc_reg = 0

for batch in range(M):
    batch_len = G[batch]
    contexts = np.random.rand(batch_len, n_arms, n_features)

    # Fresh per-slot accumulators for this batch.
    A[:] = [np.zeros((n_features, n_features)) for _ in range(M)]
    b[:] = [np.zeros(n_features) for _ in range(M)]

    for step in range(batch_len):
        slot = choose_j(batch_len, M, step)
        round_contexts = contexts[step]

        arm = split[batch].select_arm(round_contexts)
        x = round_contexts[arm]

        reward = x.dot(true_theta) + np.random.normal(0, 0.1)
        acc_reg = acc_reg + (max_reward(round_contexts, true_theta, n_arms) - reward)
        Regret_sp_list.append(acc_reg)

        A[slot] = A[slot] + np.outer(x, x)
        b[slot] = b[slot] + reward * x

    for learner, A_slot, b_slot in zip(split, A, b):
        learner.batch_update(A_slot, b_slot)


Regret = np.array(Regret_list, dtype=np.float64)

Regret_sp = np.array(Regret_sp_list, dtype=np.float64)

plt.plot(range(1, len(Regret) + 1), Regret, label="BatchLinUCB")

plt.plot(range(1, len(Regret_sp) + 1), Regret_sp, label="BatchPureExp")

plt.plot(range(1, len(Regret_list_g) + 1), Regret_list_g, label="This Work")

plt.ylabel("Regret")
plt.xlabel("Steps")

# Batch boundaries are the running total of the per-batch round counts G[i].
# np.cumsum returns a new array and leaves G itself unmodified.
W = np.cumsum(G)

for boundary in W:
    plt.axvline(x=boundary, color='red', linestyle='--')

plt.legend()
plt.show()







