import numpy as np
from tqdm import tqdm
import datetime
import time
import os
from collections import defaultdict
from util import * 

class OptimalPolicy:
    """Oracle baseline that always serves the head-of-queue job.

    Each round the job at queue position 0 is dispatched with
    ``env.y_data[<its data index>]`` as the server index. When the queue is
    empty the round is an idle ``env.step(-1, -1)``.
    """

    def __init__(self, env, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored so every
        # policy in this module can be built with the same call.
        self.env = env
        self.theta_list = env.theta_list
        self.K = env.K
        self.rng = env.rng

    def run(self, T):
        """Simulate ``T`` rounds against ``self.env``."""
        env = self.env
        for _ in range(T):
            if env.generate_arrival():
                env.queue.append(env.generate_feature())

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            head_data_idx = env.queue[0]
            env.step(0, env.y_data[head_data_idx])


class RandomPolicy:
    """Baseline that serves a uniformly random queued job on a uniformly random server.

    Each round with a non-empty queue draws ``rng.choice(len(queue))`` for the
    job position, then ``rng.choice(K)`` for the server. An empty queue gives an
    idle ``env.step(-1, -1)``.
    """

    def __init__(self, env, *args, **kwargs):
        # Extra arguments are ignored; they exist for a uniform constructor.
        self.env = env
        self.rng = env.rng
        self.K = env.K

    def run(self, T):
        """Simulate ``T`` rounds against ``self.env``."""
        env = self.env
        for _ in range(T):
            if env.generate_arrival():
                env.queue.append(env.generate_feature())

            backlog = len(env.queue)
            if backlog == 0:
                env.step(-1, -1)
            else:
                # Draw order (job first, then server) matters for reproducibility.
                position = self.rng.choice(backlog)
                server = self.rng.choice(self.K)
                env.step(position, server)



class CQBEps:
    """Contextual queueing bandit with a round-robin warm-up, UCB selection and
    epsilon-style exploration of newly arrived jobs.

    One sigmoid-link model is kept per server ``k``:

    * ``theta_hats[k]``: parameter estimate, refitted by mini-batch gradient
      descent on every (feature, reward) pair observed on server ``k``.
    * ``Vs[k]`` / ``Vs_inv[k]``: regularised Gram matrix of those features
      (initialised to ``reg * I``) and its inverse.
    * ``betas[k]``: confidence-width multiplier recomputed by ``update_beta``.

    The ``d`` and ``K`` arguments are ignored; both are read from ``env``.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=5, kappa=10, L=1, S=1, reg=1, seed=0):
        # The environment is authoritative for the dimensions.
        d = env.d
        K = env.K

        # Private generator seeded from ``seed`` (this class does not use env.rng).
        self.rng = np.random.default_rng(seed)
        self.env = env
        self.lambda_ = lambda_
        self.eps = eps
        self.d = d
        self.K = K
        self.kappa = kappa
        self.L = L
        self.S = S
        self.reg = reg
        self.theta_hats = [np.zeros(d) for _ in range(K)]
        self.Vs = [np.eye(d) * reg for _ in range(K)]
        self.Vs_inv = [np.linalg.inv(v) for v in self.Vs]
        self.counts = [0 for _ in range(K)]
        self.betas = [0.0 for _ in range(K)]
        self.delta = 0.1  # confidence parameter used by update_beta
        self.rewards = [[] for _ in range(K)]
        self.Xs = [[] for _ in range(K)]
        self.window_size = 256  # NOTE(review): not read by any method in this class
        self.batch_size = 32
        self.epochs = 3
        self.lr = 1e-3

    def UCB(self, x, k):
        """Return the optimistic score of feature vector ``x`` on server ``k``.

        Score is ``sigmoid(x . theta_k) + 0.1 * beta_k * sqrt(x^T V_k^{-1} x)``;
        the 0.1 factor scales down the width given by ``betas[k]``.
        """
        theta_hat = self.theta_hats[k]
        V_inv = self.Vs_inv[k]
        beta = 0.1 * self.betas[k]

        ucb_value = sigmoid(np.dot(x, theta_hat)) + beta * np.sqrt(np.dot(np.dot(x, V_inv), x))
        return ucb_value

    def update_estimator(self, k):
        """Refit ``theta_hats[k]`` on all samples observed on server ``k``.

        Runs ``epochs`` passes of mini-batch gradient descent (batch size
        ``batch_size``, step size ``lr``) with gradient
        ``X^T (sigmoid(X theta) - y) / batch_len``. The result is then passed
        through ``project_l2_norm(..., threshold=S)``, presumably a projection
        onto the L2 ball of radius ``S`` (see util). This is a no-op while
        server ``k`` has no samples.
        """
        if not self.Xs[k]:
            return
        X = np.array(self.Xs[k])
        y = np.array(self.rewards[k])

        for _ in range(self.epochs):
            # range() guarantees that every slice below is non-empty.
            for start in range(0, len(X), self.batch_size):
                X_batch = X[start:start + self.batch_size]
                y_batch = y[start:start + self.batch_size]

                error = sigmoid(np.dot(X_batch, self.theta_hats[k])) - y_batch
                grad = np.dot(X_batch.T, error) / len(X_batch)
                self.theta_hats[k] -= self.lr * grad

        self.theta_hats[k] = project_l2_norm(self.theta_hats[k], threshold=self.S)

    def update_beta(self, k):
        """Recompute ``betas[k]`` from the number of samples on server ``k``.

        The computation is
        ``beta = kappa/2 * (sqrt(2d * log(1 + n/(kappa*reg*d)) + log(1/delta)) + S*sqrt(reg))``
        where ``n = counts[k]``.
        """
        d = self.d
        kappa = self.kappa
        reg = self.reg
        delta = self.delta
        S = self.S
        count_k = self.counts[k]

        lambda_0 = reg
        term1 = (d * 2) * np.log(1 + count_k / (kappa * lambda_0 * d))
        term2 = np.log(1 / delta)
        term3 = S * np.sqrt(lambda_0)
        self.betas[k] = kappa / 2 * (np.sqrt(term1 + term2) + term3)

    def _argmax_ucb(self, env):
        """Return ``(queue_position, server)`` with the highest UCB score (first max wins)."""
        max_ucb = -np.inf
        best = (-1, -1)
        for job_idx, data_idx in enumerate(env.queue):
            x_contextual = env.X_data[data_idx]
            for server_idx in range(self.K):
                ucb = self.UCB(x_contextual, server_idx)
                if ucb > max_ucb:
                    max_ucb = ucb
                    best = (job_idx, server_idx)
        return best

    def run(self, T):
        """Simulate ``T`` rounds against ``self.env``.

        * Rounds ``t < tau`` (``tau = int(0.1 * T)``) send each newly arrived
          job to the servers in round-robin order.
        * In later rounds, a newly arrived job goes to a uniformly random
          server with probability ``T ** -0.5``.
        * Otherwise the (queued job, server) pair with the highest ``UCB`` is
          served.

        Each observed reward other than ``-1`` updates that server's statistics
        and model. A non-positive ``T`` runs no rounds.
        """
        env = self.env

        # Guard T <= 0: ``0 ** -0.5`` raises ZeroDivisionError, so run(0) used to
        # crash. CQBEpsopt.run uses the same guard.
        eps_2 = T ** (-0.5) if T > 0 else 1.0

        # Warm-up length. A fixed 10% of the horizon is used instead of the
        # theoretical schedule
        # 0.0003 * (d*log(T) / (sigma_0^2 * (eps - 2*eps_2)^2)) * (K / lambda_)
        # with sigma_0 = 1/d.
        tau = int(0.1 * T)
        p = 0  # next server in the round-robin warm-up

        for t in range(T):
            is_new_job_arrival = env.generate_arrival()

            if is_new_job_arrival:
                data_idx = env.generate_feature()
                if data_idx is not None:
                    env.queue.append(data_idx)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            if t < tau and is_new_job_arrival:
                # Warm-up: the newest job goes to the next server in turn.
                opt_job_idx = len(env.queue) - 1
                opt_server_idx = p
                p = (p + 1) % self.K
            elif t >= tau and is_new_job_arrival and self.rng.uniform(0, 1) < eps_2:
                # Exploration: the newest job goes to a uniformly random server.
                opt_job_idx = len(env.queue) - 1
                opt_server_idx = self.rng.choice(self.K)
            else:
                opt_job_idx, opt_server_idx = self._argmax_ucb(env)

            if opt_job_idx == -1:
                # NOTE(review): only reachable if every UCB score is NaN; the
                # environment is not stepped in that round.
                continue

            # Read the features before stepping; step may remove the job.
            job_to_process_feature = env.X_data[env.queue[opt_job_idx]]

            reward, _, _ = env.step(opt_job_idx, opt_server_idx)

            if reward != -1:
                self.Xs[opt_server_idx].append(job_to_process_feature)
                self.rewards[opt_server_idx].append(reward)

                self.Vs[opt_server_idx] += np.outer(job_to_process_feature, job_to_process_feature)
                self.Vs_inv[opt_server_idx] = np.linalg.inv(self.Vs[opt_server_idx])
                self.counts[opt_server_idx] += 1

                self.update_beta(opt_server_idx)
                self.update_estimator(opt_server_idx)


class CQBOpt:
    """Contextual queueing bandit that always serves the highest-UCB (job, server) pair.

    The per-server state is the same as in the epsilon variant:

    * ``theta_hats[k]``: sigmoid-link parameter estimate.
    * ``Vs[k]`` / ``Vs_inv[k]``: regularised Gram matrix and its inverse.
    * ``betas[k]``: confidence-width multiplier.

    There is no warm-up and no random exploration. The random generator is
    ``env.rng``. The ``d`` and ``K`` arguments are ignored in favour of
    ``env.d`` / ``env.K``.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=5, kappa=10, L=1, S=1, reg=1, seed=0):
        # The environment is authoritative for the dimensions.
        d = env.d
        K = env.K

        self.rng = env.rng
        self.env = env
        self.lambda_ = lambda_
        self.eps = eps
        self.d = d
        self.K = K
        self.kappa = kappa
        self.L = L
        self.S = S
        self.reg = reg
        self.theta_hats = [np.zeros(d) for _ in range(K)]
        self.Vs = [np.eye(d) * reg for _ in range(K)]
        self.Vs_inv = [np.linalg.inv(v) for v in self.Vs]
        self.counts = [0 for _ in range(K)]
        self.betas = [0.0 for _ in range(K)]
        self.delta = 0.1  # confidence parameter used by update_beta
        self.rewards = [[] for _ in range(K)]
        self.Xs = [[] for _ in range(K)]
        self.window_size = 256  # NOTE(review): not read by any method in this class
        self.batch_size = 32
        self.epochs = 3
        self.lr = 1e-3

    def UCB(self, x, k):
        """Return ``sigmoid(x . theta_k) + 0.1 * beta_k * sqrt(x^T V_k^{-1} x)``."""
        theta_hat = self.theta_hats[k]
        V_inv = self.Vs_inv[k]
        beta = 0.1 * self.betas[k]
        ucb_value = sigmoid(np.dot(x, theta_hat)) + beta * np.sqrt(np.dot(np.dot(x, V_inv), x))
        return ucb_value

    def update_estimator(self, k):
        """Refit ``theta_hats[k]`` on all samples observed on server ``k``.

        Runs ``epochs`` passes of mini-batch gradient descent with gradient
        ``X^T (sigmoid(X theta) - y) / batch_len``. The result is then passed
        through ``project_l2_norm(..., threshold=S)``, presumably an L2-ball
        projection (see util). This is a no-op while server ``k`` has no
        samples.
        """
        if not self.Xs[k]:
            return
        X = np.array(self.Xs[k])
        y = np.array(self.rewards[k])

        for _ in range(self.epochs):
            # range() guarantees that every slice below is non-empty.
            for start in range(0, len(X), self.batch_size):
                X_batch = X[start:start + self.batch_size]
                y_batch = y[start:start + self.batch_size]

                error = sigmoid(np.dot(X_batch, self.theta_hats[k])) - y_batch
                grad = np.dot(X_batch.T, error) / len(X_batch)
                self.theta_hats[k] -= self.lr * grad

        self.theta_hats[k] = project_l2_norm(self.theta_hats[k], threshold=self.S)

    def update_beta(self, k):
        """Recompute ``betas[k]`` from the number of samples on server ``k``.

        The computation is
        ``beta = kappa/2 * (sqrt(2d * log(1 + n/(kappa*reg*d)) + log(1/delta)) + S*sqrt(reg))``
        where ``n = counts[k]``. This is the same formula as the epsilon
        variant's ``update_beta``; previously it was inlined in ``run``.
        """
        d = self.d
        kappa = self.kappa
        delta = self.delta
        S = self.S
        count_k = self.counts[k]

        lambda_0 = self.reg
        term1 = (d * 2) * np.log(1 + count_k / (kappa * lambda_0 * d))
        term2 = np.log(1 / delta)
        term3 = S * np.sqrt(lambda_0)
        self.betas[k] = kappa / 2 * (np.sqrt(term1 + term2) + term3)

    def run(self, T):
        """Simulate ``T`` rounds, serving the highest-UCB (queued job, server) pair.

        An empty queue gives an idle ``env.step(-1, -1)``. A reward other than
        ``-1`` updates the chosen server's statistics, beta and estimator.
        """
        env = self.env

        for _ in range(T):
            if env.generate_arrival():
                data_idx = env.generate_feature()
                if data_idx is not None:
                    env.queue.append(data_idx)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            opt_job_idx = -1
            opt_server_idx = -1
            max_ucb = -np.inf
            for job_idx, data_idx in enumerate(env.queue):
                x_contextual = env.X_data[data_idx]
                for server_idx in range(self.K):
                    ucb = self.UCB(x_contextual, server_idx)
                    if ucb > max_ucb:
                        max_ucb = ucb
                        opt_job_idx, opt_server_idx = job_idx, server_idx

            if opt_job_idx == -1:
                # NOTE(review): only reachable if every UCB score is NaN; the
                # environment is not stepped in that round.
                continue

            # Read the features before stepping; step may remove the job.
            job_to_process_feature = env.X_data[env.queue[opt_job_idx]]

            reward, _, _ = env.step(opt_job_idx, opt_server_idx)

            if reward != -1:
                self.Xs[opt_server_idx].append(job_to_process_feature)
                self.rewards[opt_server_idx].append(reward)

                self.Vs[opt_server_idx] += np.outer(job_to_process_feature, job_to_process_feature)
                self.Vs_inv[opt_server_idx] = np.linalg.inv(self.Vs[opt_server_idx])
                self.counts[opt_server_idx] += 1

                self.update_beta(opt_server_idx)
                self.update_estimator(opt_server_idx)


class CQBEpsopt(CQBEps):
    """Epsilon variant without the round-robin warm-up phase.

    In every round with a non-empty queue, a uniform draw from ``self.rng`` is
    compared against ``T ** -0.5`` (``1.0`` when ``T <= 0``). If the draw is
    below it and a job arrived this round, the newest queued job goes to a
    random server. Otherwise the highest-UCB (job, server) pair is served.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=5, kappa=10, L=1, S=1, reg=1, seed=0):
        super().__init__(
            env,
            lambda_=lambda_,
            eps=eps,
            d=env.d,
            K=env.K,
            kappa=kappa,
            L=L,
            S=S,
            reg=reg,
            seed=seed,
        )

    def _best_ucb_pair(self, env):
        """Return ``(queue_position, server)`` maximising ``UCB`` (first max wins)."""
        top_score = -np.inf
        choice = (-1, -1)
        for position, idx in enumerate(env.queue):
            features = env.X_data[idx]
            for server in range(self.K):
                score = self.UCB(features, server)
                if score > top_score:
                    top_score = score
                    choice = (position, server)
        return choice

    def run(self, T):
        """Simulate ``T`` rounds against ``self.env``."""
        env = self.env
        explore_prob = 1.0 if T <= 0 else T ** (-0.5)

        for _ in range(T):
            arrived = env.generate_arrival()
            if arrived:
                new_idx = env.generate_feature()
                if new_idx is not None:
                    env.queue.append(new_idx)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            # The coin is drawn every busy round, before the arrival check.
            coin = self.rng.uniform(0, 1)
            if coin < explore_prob and arrived:
                position = len(env.queue) - 1
                server = self.rng.choice(self.K)
            else:
                position, server = self._best_ucb_pair(env)

            if position == -1:
                continue

            # Read the features before stepping; step may remove the job.
            features = env.X_data[env.queue[position]]
            reward, _, _ = env.step(position, server)
            if reward == -1:
                continue

            self.Xs[server].append(features)
            self.rewards[server].append(reward)
            self.Vs[server] += np.outer(features, features)
            self.Vs_inv[server] = np.linalg.inv(self.Vs[server])
            self.counts[server] += 1
            self.update_beta(server)
            self.update_estimator(server)


class CQBts(CQBOpt):
    """Thompson-sampling counterpart of the UCB queueing bandit.

    Each (queued job, server) pair gets a score sampled from
    ``Normal(x . theta_k, sqrt(16 * (beta_k * beta_base)^2 * x^T V_k^{-1} x))``
    with the variance floored at ``1e-8``. The pair with the largest sample is
    served.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=5, kappa=10, L=1, S=1, reg=1, seed=0):
        super().__init__(
            env,
            lambda_=lambda_,
            eps=eps,
            d=env.d,
            K=env.K,
            kappa=kappa,
            L=L,
            S=S,
            reg=reg,
            seed=seed,
        )
        self.beta_base = 1.0  # extra multiplier on betas[k] inside get_TS_score

    def get_TS_score(self, x, k):
        """Draw one sampled score for feature vector ``x`` on server ``k``."""
        beta_k = self.betas[k] * self.beta_base
        center = np.dot(x, self.theta_hats[k])
        spread = 16 * (beta_k**2) * np.dot(np.dot(x, self.Vs_inv[k]), x)
        spread = 1e-8 if spread < 1e-8 else spread
        return self.rng.normal(center, np.sqrt(spread))

    def _refresh_beta(self, k):
        """Recompute ``betas[k]`` from ``counts[k]`` with the confidence-width formula."""
        n_k = self.counts[k]
        lam0 = self.reg
        log_growth = (self.d * 2) * np.log(1 + n_k / (self.kappa * lam0 * self.d))
        log_conf = np.log(1 / self.delta)
        norm_term = self.S * np.sqrt(lam0)
        self.betas[k] = self.kappa / 2 * (np.sqrt(log_growth + log_conf) + norm_term)

    def run(self, T):
        """Simulate ``T`` rounds, serving the pair with the largest sampled score."""
        env = self.env

        for _ in range(T):
            if env.generate_arrival():
                new_idx = env.generate_feature()
                if new_idx is not None:
                    env.queue.append(new_idx)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            top_sample = -np.inf
            position, server = -1, -1
            for pos, idx in enumerate(env.queue):
                features = env.X_data[idx]
                for srv in range(self.K):
                    sample = self.get_TS_score(features, srv)
                    if sample > top_sample:
                        top_sample = sample
                        position, server = pos, srv

            if position == -1:
                continue

            # Read the features before stepping; step may remove the job.
            chosen = env.X_data[env.queue[position]]
            reward, _, _ = env.step(position, server)
            if reward == -1:
                continue

            self.Xs[server].append(chosen)
            self.rewards[server].append(reward)
            self.Vs[server] += np.outer(chosen, chosen)
            self.Vs_inv[server] = np.linalg.inv(self.Vs[server])
            self.counts[server] += 1
            self._refresh_beta(server)
            self.update_estimator(server)


class UCB1:
    """Context-free UCB baseline with forced Bernoulli exploration.

    Per server, the number of observed rewards (``counts``) and their sum
    (``rewards_sum``) are tracked. The head-of-queue job is always served.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=5, kappa=10, L=1, S=1, reg=1, seed=0):
        self.rng = env.rng
        self.env = env
        self.K = env.K
        self.lambda_ = lambda_

        self.counts = [0] * self.K
        self.rewards_sum = [0.0] * self.K

    def get_ucb_value(self, t):
        """Return ``(0, argmax_k UCB_k)``, or ``(-1, -1)`` when ``t <= 0``.

        ``UCB_k = mean_k + sqrt(log(t + 1)^2 / (2 * n_k))``. An unplayed
        server scores ``inf``. Ties go to the lowest index.
        """
        if t <= 0:
            return -1, -1

        log_sq = np.log(t + 1) ** 2
        scores = []
        for arm in range(self.K):
            pulls = self.counts[arm]
            if pulls == 0:
                scores.append(np.inf)
            else:
                scores.append(self.rewards_sum[arm] / pulls + np.sqrt(log_sq / (2 * pulls)))

        return 0, np.argmax(scores)

    def get_bernoulli_explore_sample(self, t):
        """Return 1 (explore) with probability ``min(1, 3K log(t+1)^2 / (t+1))``, else 0.

        Always returns 1 for ``t <= 1`` without drawing from the generator.
        """
        if t <= 1:
            return 1

        horizon = t + 1
        p_explore = min(1.0, (3 * self.K * np.log(horizon) ** 2) / horizon)
        return int(self.rng.random() < p_explore)

    def run(self, T):
        """Simulate rounds ``t = 1..T`` against ``self.env``."""
        env = self.env

        for t in range(1, T + 1):
            if env.generate_arrival():
                job = env.generate_feature()
                if job is not None:
                    env.queue.append(job)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            if self.get_bernoulli_explore_sample(t) == 1:
                job_pos, server = 0, self.rng.choice(self.K)
            else:
                job_pos, server = self.get_ucb_value(t)

            if job_pos == -1:
                continue

            reward, _, _ = env.step(job_pos, server)
            if reward != -1:
                self.counts[server] += 1
                self.rewards_sum[server] += reward

class TS1:
    """Context-free Beta-Bernoulli Thompson-sampling baseline with forced exploration.

    Each server ``k`` uses the posterior ``Beta(S_k + 1, n_k - S_k + 1)``, where
    ``n_k = counts[k]`` and ``S_k = rewards_sum[k]``. The head-of-queue job is
    always served.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=5, kappa=10, L=1, S=1, reg=1, seed=0):
        self.rng = env.rng
        self.env = env
        self.K = env.K
        self.lambda_ = lambda_

        self.counts = [0] * self.K
        self.rewards_sum = [0.0] * self.K

    def get_thompson_sample_server(self, t):
        """Return ``(0, k)`` for the server whose posterior sample is largest.

        Servers are scanned in index order. The first server with no
        observations is returned immediately; servers scanned before it have
        already been sampled. Ties go to the lowest index. ``t`` is unused.
        """
        best_draw = -np.inf
        best_arm = -1

        for arm in range(self.K):
            pulls = self.counts[arm]
            wins = self.rewards_sum[arm]
            if pulls == 0:
                return 0, arm

            losses = pulls - wins
            draw = self.rng.beta(wins + 1, losses + 1)
            if draw > best_draw:
                best_draw, best_arm = draw, arm

        return 0, best_arm

    def get_bernoulli_explore_sample(self, t):
        """Return 1 (explore) with probability ``min(1, 3K log(t+1)^2 / (t+1))``, else 0.

        Always returns 1 for ``t <= 1`` without drawing from the generator.
        """
        if t <= 1:
            return 1
        horizon = t + 1
        p_explore = min(1.0, (3 * self.K * np.log(horizon) ** 2) / horizon)
        return int(self.rng.random() < p_explore)

    def run(self, T):
        """Simulate rounds ``t = 1..T`` against ``self.env``."""
        env = self.env

        for t in range(1, T + 1):
            if env.generate_arrival():
                job = env.generate_feature()
                if job is not None:
                    env.queue.append(job)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            if self.get_bernoulli_explore_sample(t) == 1:
                job_pos, server = 0, self.rng.choice(self.K)
            else:
                job_pos, server = self.get_thompson_sample_server(t)

            if job_pos == -1 or server == -1:
                continue

            reward, _, _ = env.step(job_pos, server)
            if reward != -1:
                self.counts[server] += 1
                self.rewards_sum[server] += reward