import numpy as np

from util import sigmoid, project_l2_norm


class OptimalPolicy:
    """Dispatcher that scores pairs with the parameters stored on the environment.

    The policy reads ``env.theta_list`` directly (presumably the environment's
    true per-server parameters -- confirm against the environment class) and,
    in every slot, serves the queued job / server pair with the largest
    ``sigmoid(job . theta_list[server])``.
    """

    def __init__(self, env):
        self.env = env
        self.theta_list = env.theta_list
        self.K = env.K

    def _best_pair(self, queue):
        """Return the ``(job_idx, server_idx)`` pair with the largest probability.

        Pairs are scanned job by job, server by server; only a strictly larger
        probability replaces the incumbent, so ties keep the earliest pair.
        ``(-1, -1)`` is returned when no pair beats ``-inf``.
        """
        best_job, best_server = -1, -1
        best_prob = -np.inf
        for job_idx, job in enumerate(queue):
            for server_idx in range(self.K):
                prob = sigmoid(np.dot(job, self.theta_list[server_idx]))
                if prob > best_prob:
                    best_prob = prob
                    best_job, best_server = job_idx, server_idx
        return best_job, best_server

    def run(self, T):
        """Drive the environment for ``T`` slots.

        Each slot draws an arrival from ``env.rng`` with probability
        ``env.lambda_``, queues the generated feature, and then calls
        ``env.step`` with the best pair, or with ``(-1, -1)`` when the queue
        is empty.
        """
        env = self.env

        for _ in range(T):
            arrived = env.rng.uniform(0, 1) < env.lambda_
            if arrived:
                feature = env.generate_feature()
                env.queue.append(feature)

            # Idle slot: nothing to dispatch.
            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            job_idx, server_idx = self._best_pair(env.queue)
            env.step(job_idx, server_idx)


class RandomPolicy:
    """Baseline that serves a uniformly random queued job on a uniformly random server."""

    def __init__(self, env):
        self.env = env
        self.rng = env.rng
        self.K = env.K

    def run(self, T):
        """Drive the environment for ``T`` slots.

        Each slot draws an arrival with probability ``env.lambda_`` and queues
        the generated feature.  If the queue is non-empty a job index and then
        a server index are drawn from ``self.rng`` and passed to ``env.step``;
        otherwise ``env.step(-1, -1)`` is called.
        """
        env = self.env

        for _ in range(T):
            arrived = self.rng.uniform(0, 1) < env.lambda_
            if arrived:
                feature = env.generate_feature()
                env.queue.append(feature)

            # Idle slot: nothing to dispatch.
            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            # Job first, then server: the draw order fixes the random stream.
            picked_job = self.rng.choice(len(env.queue))
            picked_server = self.rng.choice(self.K)
            env.step(picked_job, picked_server)


class CQBEps:
    """UCB-index dispatcher with a round-robin warm-up and rare random dispatch.

    One model is kept per server ``k``: a parameter estimate
    ``theta_hats[k]``, a regularised Gram matrix ``Vs[k]`` (inverse cached in
    ``Vs_inv[k]``), a confidence width ``betas[k]``, a play counter
    ``counts[k]`` and the observed history in ``Xs[k]`` / ``rewards[k]``.
    Arrivals are sampled from the policy's own generator (seeded with
    ``seed``) using the policy's ``lambda_``.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=3, kappa=10, L=1, S=1, reg=1, seed=0):
        self._init_params(env, lambda_, eps, d, K, kappa, L, S, reg, seed)

    def _init_params(self, env, lambda_, eps, d, K, kappa, L, S, reg, seed):
        """Store the hyper-parameters and reset every per-server statistic."""
        self.rng = np.random.default_rng(seed)
        self.env = env
        self.lambda_ = lambda_
        self.eps = eps
        self.d = d
        self.K = K
        self.kappa = kappa
        # NOTE(review): L is stored but not read by any method of this class.
        self.L = L
        self.S = S
        self.reg = reg
        self.theta_hats = [np.zeros(d) for k in range(K)]
        # Gram matrices start at reg * I so they are invertible from slot 0.
        self.Vs = [np.eye(d) * reg for k in range(K)]
        self.Vs_inv = [np.linalg.inv(v) for v in self.Vs]
        self.counts = [0 for k in range(K)]
        self.betas = [0.0 for k in range(K)]
        self.delta = 0.1
        self.rewards = [[] for k in range(K)]
        self.Xs = [[] for _ in range(K)]
        # NOTE(review): window_size is stored but not read by any method of
        # this class; update_estimator trains on the full history.
        self.window_size = 256
        self.batch_size = 32
        self.epochs = 3
        self.lr = 1e-3

    def UCB(self, x, k):
        """Return the optimistic index of feature ``x`` on server ``k``.

        The index is ``sigmoid(x . theta_hats[k])`` plus
        ``0.1 * betas[k] * sqrt(x^T Vs_inv[k] x)``.
        """
        theta_hat = self.theta_hats[k]
        V_inv = self.Vs_inv[k]
        beta = 0.1 * self.betas[k]

        ucb_value = sigmoid(np.dot(x, theta_hat)) + beta * np.sqrt(np.dot(np.dot(x, V_inv), x))
        return ucb_value

    def update_estimator(self, k):
        """Refine ``theta_hats[k]`` in place with mini-batch gradient steps.

        Runs ``epochs`` passes over the complete history of server ``k`` in
        batches of ``batch_size``, stepping along
        ``X^T (sigmoid(X theta) - y) / batch_len`` with step size ``lr``.
        Does nothing when server ``k`` has no observations.
        """
        if not self.Xs[k]:
            return
        X = np.array(self.Xs[k])
        y = np.array(self.rewards[k])

        for _ in range(self.epochs):
            for i in range(0, len(X), self.batch_size):
                X_batch = X[i:i+self.batch_size]
                y_batch = y[i:i+self.batch_size]

                if len(X_batch) == 0:
                    continue

                y_pred = sigmoid(np.dot(X_batch, self.theta_hats[k]))
                error = y_pred - y_batch

                grad = np.dot(X_batch.T, error) / len(X_batch)
                self.theta_hats[k] -= self.lr * grad

    def update_beta(self, k):
        """Recompute the confidence width ``betas[k]`` from ``counts[k]``.

        ``betas[k] = kappa / 2 * (sqrt(2 d log(1 + n_k / (kappa reg d))
        + log(1 / delta)) + S sqrt(reg))``.
        """
        d = self.d
        kappa = self.kappa
        reg = self.reg
        delta = self.delta
        S = self.S
        count_k = self.counts[k]

        lambda_0 = reg
        term1 = (d * 2) * np.log(1 + count_k / (kappa * lambda_0 * d))
        term2 = np.log(1 / delta)
        term3 = S * np.sqrt(lambda_0)
        self.betas[k] = kappa / 2 * (np.sqrt(term1 + term2) + term3)

    def _warmup_length(self, T, eps_2):
        """Return the number of initial round-robin slots, capped at ``T``.

        The uncapped value is
        ``0.0003 * d log(T) / (sigma_0^2 (eps - 2 eps_2)^2) * K / lambda_``
        with ``sigma_0 = 1 / d``.  Any value of at least ``T`` behaves the
        same inside :meth:`run`, so the result is clamped to ``T``.
        """
        sigma_0_squared = (1.0 / self.d)**2
        effective_epsilon_squared = (self.eps - 2 * eps_2)**2
        denominator = sigma_0_squared * effective_epsilon_squared

        if denominator == 0 or self.lambda_ == 0:
            # eps == 2 / sqrt(T) (e.g. eps=0.1 with T=400) makes the formula
            # diverge; treat it as "warm up for the whole horizon" instead of
            # letting int(inf) raise OverflowError.  With lambda_ == 0 there
            # are no arrivals, so the warm-up length is immaterial.
            return T

        raw_tau = 0.0003 * ((self.d * np.log(T)) / denominator) * (self.K / self.lambda_)
        # Clamp before int(): a tiny denominator can overflow raw_tau to inf.
        return int(min(raw_tau, T))

    def run(self, T):
        """Simulate ``T`` slots on ``self.env``.

        Per slot, with probability ``lambda_`` a feature is requested from the
        environment and queued (unless it is ``None``).  If the queue is then
        non-empty, one (job, server) pair is chosen and passed to
        ``env.step``:

        * during the first ``tau`` slots, a slot with an arrival sends the
          newest queued job to the servers in round-robin order;
        * afterwards, a slot with an arrival sends the newest queued job to a
          uniformly random server with probability ``T ** -0.5``;
        * in every other case the pair with the largest :meth:`UCB` is used.

        A reward other than ``-1`` is recorded and refreshes the chosen
        server's Gram matrix, width and estimate.  An empty queue results in
        ``env.step(-1, -1)``.  ``T <= 0`` is a no-op.
        """
        if T <= 0:
            # No slots to simulate; also avoids evaluating 0 ** -0.5 below.
            return

        env = self.env

        eps_2 = T**(-0.5)
        tau = self._warmup_length(T, eps_2)
        p = 0  # next server in the round-robin warm-up

        for t in range(T):
            is_new_job_arrival = (self.rng.uniform(0, 1) < self.lambda_)

            if is_new_job_arrival:
                x = env.generate_feature()
                if x is not None:
                    env.queue.append(x)
                # NOTE(review): when x is None the flag stays True, so the
                # branches below dispatch the newest *existing* job -- confirm
                # that this is intended.

            opt_job_idx = -1
            opt_server_idx = -1

            if len(env.queue) > 0:

                if t < tau and is_new_job_arrival:
                    opt_job_idx = len(env.queue) - 1
                    opt_server_idx = p
                    p = (p + 1) % self.K

                # The exploration coin is only drawn after the warm-up and
                # only on arrival slots (short-circuit keeps the RNG stream).
                elif t >= tau and is_new_job_arrival and self.rng.uniform(0, 1) < eps_2:
                    opt_job_idx = len(env.queue) - 1
                    opt_server_idx = self.rng.choice(self.K)

                else:
                    max_ucb = -np.inf
                    for job_idx, job in enumerate(env.queue):
                        for server_idx in range(self.K):
                            ucb = self.UCB(job, server_idx)
                            if ucb > max_ucb:
                                max_ucb = ucb
                                opt_job_idx, opt_server_idx = job_idx, server_idx

                if opt_job_idx != -1:
                    # Grab the feature before stepping: the environment may
                    # remove the job from the queue.
                    job_to_process = env.queue[opt_job_idx]
                    reward = env.step(opt_job_idx, opt_server_idx)

                    if reward != -1:
                        self.Xs[opt_server_idx].append(job_to_process)
                        self.rewards[opt_server_idx].append(reward)

                        self.Vs[opt_server_idx] += np.outer(job_to_process, job_to_process)
                        self.Vs_inv[opt_server_idx] = np.linalg.inv(self.Vs[opt_server_idx])
                        self.counts[opt_server_idx] += 1

                        self.update_beta(opt_server_idx)
                        self.update_estimator(opt_server_idx)
            else:
                env.step(-1, -1)


class CQBOpt:
    """Purely optimistic dispatcher: every busy slot serves the arg-max UCB pair.

    Per server ``k`` the policy keeps an estimate ``theta_hats[k]``, a
    regularised Gram matrix ``Vs[k]`` with cached inverse ``Vs_inv[k]``, a
    confidence width ``betas[k]``, a play counter and the observed history.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=3, kappa=10, L=1, S=1, reg=1, seed=0):
        self._init_params(env, lambda_, eps, d, K, kappa, L, S, reg, seed)

    def _init_params(self, env, lambda_, eps, d, K, kappa, L, S, reg, seed):
        """Store the hyper-parameters and reset every per-server statistic."""
        # Environment handle and private random stream.
        self.env = env
        self.rng = np.random.default_rng(seed)

        # Problem and confidence-set hyper-parameters.
        self.lambda_, self.eps = lambda_, eps
        self.d, self.K = d, K
        self.kappa, self.L, self.S, self.reg = kappa, L, S, reg
        self.delta = 0.1

        # Optimiser settings used by update_estimator.
        self.window_size = 256
        self.batch_size = 32
        self.epochs = 3
        self.lr = 1e-3

        # Per-server learning state.
        self.theta_hats = [np.zeros(d) for _ in range(K)]
        self.Vs = [np.eye(d) * reg for _ in range(K)]
        self.Vs_inv = [np.linalg.inv(gram) for gram in self.Vs]
        self.counts = [0 for _ in range(K)]
        self.betas = [0.0 for _ in range(K)]
        self.rewards = [[] for _ in range(K)]
        self.Xs = [[] for _ in range(K)]

    def UCB(self, x, k):
        """Return ``sigmoid(x . theta_k) + 0.1 * beta_k * sqrt(x^T V_k^{-1} x)``."""
        predicted = sigmoid(np.dot(x, self.theta_hats[k]))
        scaled_beta = 0.1 * self.betas[k]
        width = np.sqrt(np.dot(np.dot(x, self.Vs_inv[k]), x))
        return predicted + scaled_beta * width

    def update_estimator(self, k):
        """Run ``epochs`` passes of mini-batch gradient steps on server ``k``'s history.

        The step direction is ``X^T (sigmoid(X theta) - y) / batch_len``; the
        estimate is modified in place.  No-op when there is no history.
        """
        history = self.Xs[k]
        if not history:
            return

        features = np.array(history)
        targets = np.array(self.rewards[k])
        total = len(features)
        theta = self.theta_hats[k]  # alias: the steps below mutate it in place

        for _ in range(self.epochs):
            # range() never yields a start >= total, so no batch is empty.
            for start in range(0, total, self.batch_size):
                stop = start + self.batch_size
                batch_x = features[start:stop]
                batch_y = targets[start:stop]

                residual = sigmoid(np.dot(batch_x, theta)) - batch_y
                gradient = np.dot(batch_x.T, residual) / len(batch_x)
                theta -= self.lr * gradient

    def update_beta(self, k):
        """Recompute the confidence width ``betas[k]`` from ``counts[k]``."""
        plays = self.counts[k]
        lambda_0 = self.reg

        growth = (self.d * 2) * np.log(1 + plays / (self.kappa * lambda_0 * self.d))
        confidence = np.log(1 / self.delta)
        offset = self.S * np.sqrt(lambda_0)

        self.betas[k] = self.kappa / 2 * (np.sqrt(growth + confidence) + offset)

    def _argmax_ucb(self, queue):
        """Return the first ``(job_idx, server_idx)`` attaining the largest UCB.

        ``(-1, -1)`` is returned when no index beats ``-inf``.
        """
        winner = (-1, -1)
        top = -np.inf
        for job_idx, job in enumerate(queue):
            for server_idx in range(self.K):
                value = self.UCB(job, server_idx)
                if value > top:
                    top = value
                    winner = (job_idx, server_idx)
        return winner

    def _learn(self, k, feature, reward):
        """Record one observation for server ``k`` and refresh its statistics."""
        self.Xs[k].append(feature)
        self.rewards[k].append(reward)

        self.Vs[k] += np.outer(feature, feature)
        self.Vs_inv[k] = np.linalg.inv(self.Vs[k])
        self.counts[k] += 1

        self.update_beta(k)
        self.update_estimator(k)

    def run(self, T):
        """Simulate ``T`` slots on ``self.env``.

        Each slot draws an arrival with probability ``lambda_`` from the
        policy's own generator, queues the feature unless it is ``None``, and
        then serves the arg-max UCB pair.  Rewards other than ``-1`` feed the
        learner.  An empty queue results in ``env.step(-1, -1)``.
        """
        env = self.env

        for _ in range(T):
            if self.rng.uniform(0, 1) < self.lambda_:
                feature = env.generate_feature()
                if feature is not None:
                    env.queue.append(feature)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            job_idx, server_idx = self._argmax_ucb(env.queue)
            if job_idx == -1:
                continue

            # Read the feature before stepping; the environment may drop it.
            served = env.queue[job_idx]
            reward = env.step(job_idx, server_idx)
            if reward != -1:
                self._learn(server_idx, served, reward)



class CQBEpsopt(CQBEps):
    """Variant of ``CQBEps`` without a warm-up phase.

    On slots with an arrival the newest queued job is sent to a uniformly
    random server with probability ``T ** -0.5``; in every other busy slot the
    arg-max UCB pair is served.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=3, kappa=10, L=1, S=1, reg=1, seed=0):
        self._init_params(env, lambda_, eps, d, K, kappa, L, S, reg, seed)

    def run(self, T):
        """Simulate ``T`` slots on ``self.env``.

        Rewards other than ``-1`` update the chosen server's history, Gram
        matrix, confidence width and estimate.  An empty queue results in
        ``env.step(-1, -1)``.
        """
        env = self.env
        explore_prob = 1.0 if not T > 0 else T**(-0.5)

        for _ in range(T):
            arrived = self.rng.uniform(0, 1) < self.lambda_
            if arrived:
                feature = env.generate_feature()
                if feature is not None:
                    env.queue.append(feature)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            chosen_job, chosen_server = -1, -1

            # The coin is drawn on every busy slot, before the arrival flag is
            # consulted, so the random stream advances even without arrivals.
            coin = self.rng.uniform(0, 1)
            if coin < explore_prob and arrived:
                chosen_job = len(env.queue) - 1
                chosen_server = self.rng.choice(self.K)
            else:
                top = -np.inf
                for job_idx, candidate in enumerate(env.queue):
                    for server_idx in range(self.K):
                        value = self.UCB(candidate, server_idx)
                        if value > top:
                            top = value
                            chosen_job, chosen_server = job_idx, server_idx

            if chosen_job == -1:
                continue

            # Read the feature before stepping; the environment may drop it.
            served = env.queue[chosen_job]
            reward = env.step(chosen_job, chosen_server)
            if reward != -1:
                self.Xs[chosen_server].append(served)
                self.rewards[chosen_server].append(reward)

                self.Vs[chosen_server] += np.outer(served, served)
                self.Vs_inv[chosen_server] = np.linalg.inv(self.Vs[chosen_server])
                self.counts[chosen_server] += 1

                self.update_beta(chosen_server)
                self.update_estimator(chosen_server)


class CQBts(CQBOpt):
    """Sampling-based variant of ``CQBOpt``.

    Instead of an optimistic index, each (job, server) pair is scored with a
    Gaussian draw centred on the linear predictor ``x . theta_hats[k]``.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=3, kappa=10, L=1, S=1, reg=1, seed=0):
        self._init_params(env, lambda_, eps, d, K, kappa, L, S, reg, seed)
        self.beta_base = 1.0

    def get_TS_score(self, x, k):
        """Draw one Gaussian score for feature ``x`` on server ``k``.

        The mean is ``x . theta_hats[k]`` and the variance is
        ``16 * (betas[k] * beta_base)**2 * x^T Vs_inv[k] x``, floored at
        ``1e-8``.
        """
        scale = self.betas[k] * self.beta_base
        center = np.dot(x, self.theta_hats[k])
        quad_form = np.dot(np.dot(x, self.Vs_inv[k]), x)

        variance = 16 * (scale**2) * quad_form
        # Floor the variance so the sampler never receives a zero spread.
        variance = 1e-8 if variance < 1e-8 else variance

        return self.rng.normal(center, np.sqrt(variance))

    def run(self, T):
        """Simulate ``T`` slots, serving the pair with the largest sampled score.

        Rewards other than ``-1`` update the chosen server's history, Gram
        matrix, confidence width and estimate.  An empty queue results in
        ``env.step(-1, -1)``.
        """
        env = self.env

        for _ in range(T):
            if self.rng.uniform(0, 1) < self.lambda_:
                feature = env.generate_feature()
                if feature is not None:
                    env.queue.append(feature)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            chosen_job, chosen_server = -1, -1
            top = -np.inf
            # One fresh draw per pair, in job-major order.
            for job_idx, candidate in enumerate(env.queue):
                for server_idx in range(self.K):
                    score = self.get_TS_score(candidate, server_idx)
                    if score > top:
                        top = score
                        chosen_job, chosen_server = job_idx, server_idx

            if chosen_job == -1:
                continue

            # Read the feature before stepping; the environment may drop it.
            served = env.queue[chosen_job]
            reward = env.step(chosen_job, chosen_server)
            if reward != -1:
                self.Xs[chosen_server].append(served)
                self.rewards[chosen_server].append(reward)

                self.Vs[chosen_server] += np.outer(served, served)
                self.Vs_inv[chosen_server] = np.linalg.inv(self.Vs[chosen_server])
                self.counts[chosen_server] += 1

                self.update_beta(chosen_server)
                self.update_estimator(chosen_server)



class UCB1:
    """Context-free UCB dispatcher that always serves the head-of-line job.

    Only ``env`` and ``lambda_`` are used from the constructor arguments; the
    number of servers and the random generator are taken from ``env``.  The
    remaining parameters are accepted and ignored.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=3, kappa=10, L=1, S=1, reg=1, seed=0):
        self.env = env
        self.rng = env.rng
        self.K = env.K
        self.lambda_ = lambda_

        self.counts = [0 for _ in range(self.K)]
        self.rewards_sum = [0.0 for _ in range(self.K)]

    def _server_index(self, k, log_sq):
        """Return server ``k``'s index: empirical mean plus ``sqrt(log_sq / (2 n_k))``.

        A server that has never been played gets ``inf``.
        """
        plays = self.counts[k]
        if plays == 0:
            return np.inf
        return self.rewards_sum[k] / plays + np.sqrt(log_sq / (2 * plays))

    def get_ucb_value(self, t):
        """Return ``(0, best_server)`` for slot ``t``, or ``(-1, -1)`` when ``t <= 0``.

        The bonus uses ``log(t + 1) ** 2``; ties go to the lowest server index.
        """
        if t <= 0:
            return -1, -1

        log_sq = np.log(t + 1)**2
        indices = [self._server_index(k, log_sq) for k in range(self.K)]
        return 0, np.argmax(indices)

    def get_bernoulli_explore_sample(self, t):
        """Return 1 with probability ``min(1, 3 K log(t + 1)^2 / (t + 1))``, else 0.

        For ``t <= 1`` the result is 1 and no random number is drawn.
        """
        if t <= 1:
            return 1

        shifted = t + 1
        p_explore = min(1.0, (3 * self.K * np.log(shifted)**2) / shifted)
        return int(self.rng.random() < p_explore)

    def run(self, T):
        """Simulate slots ``1..T`` on ``self.env``.

        Arrivals use ``env.rng`` and ``env.lambda_``.  On a busy slot the
        head-of-line job goes either to a uniformly random server (explore)
        or to the arg-max UCB server; rewards other than ``-1`` update the
        server's counters.  An empty queue results in ``env.step(-1, -1)``.
        """
        env = self.env

        for t in range(1, T + 1):
            if env.rng.uniform(0, 1) < env.lambda_:
                feature = env.generate_feature()
                if feature is not None:
                    env.queue.append(feature)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            if self.get_bernoulli_explore_sample(t) == 1:
                job_idx, server_idx = 0, self.rng.choice(self.K)
            else:
                job_idx, server_idx = self.get_ucb_value(t)

            if job_idx == -1:
                continue

            reward = env.step(job_idx, server_idx)
            if reward != -1:
                self.counts[server_idx] += 1
                self.rewards_sum[server_idx] += reward



class TS1:
    """Context-free Beta-sampling dispatcher that always serves the head-of-line job.

    Only ``env`` and ``lambda_`` are used from the constructor arguments; the
    number of servers and the random generator are taken from ``env``.  The
    remaining parameters are accepted and ignored.
    """

    def __init__(self, env, lambda_=0.2, eps=0.1, d=5, K=3, kappa=10, L=1, S=1, reg=1, seed=0):
        self.env = env
        self.rng = env.rng
        self.K = env.K
        self.lambda_ = lambda_

        self.counts = [0 for _ in range(self.K)]
        self.rewards_sum = [0.0 for _ in range(self.K)]

    def get_thompson_sample_server(self, t):
        """Return ``(0, server)`` chosen by one Beta draw per server.

        Servers are scanned in index order.  The first server that has never
        been played is returned immediately (draws already made for earlier
        servers are discarded).  Otherwise server ``k`` draws from
        ``Beta(S_k + 1, n_k - S_k + 1)`` and the largest draw wins.
        """
        best_server = -1
        best_draw = -np.inf

        for k in range(self.K):
            plays = self.counts[k]
            if plays == 0:
                return 0, k

            successes = self.rewards_sum[k]
            failures = plays - successes
            draw = self.rng.beta(successes + 1, failures + 1)

            if draw > best_draw:
                best_draw = draw
                best_server = k

        return 0, best_server

    def get_bernoulli_explore_sample(self, t):
        """Return 1 with probability ``min(1, 3 K log(t + 1)^2 / (t + 1))``, else 0.

        For ``t <= 1`` the result is 1 and no random number is drawn.
        """
        if t <= 1:
            return 1

        shifted = t + 1
        p_explore = min(1.0, (3 * self.K * np.log(shifted)**2) / shifted)
        return int(self.rng.random() < p_explore)

    def run(self, T):
        """Simulate slots ``1..T`` on ``self.env``.

        Arrivals use ``env.rng`` and ``env.lambda_``.  On a busy slot the
        head-of-line job goes either to a uniformly random server (explore)
        or to the sampled server; rewards other than ``-1`` update the
        server's counters.  An empty queue results in ``env.step(-1, -1)``.
        """
        env = self.env

        for t in range(1, T + 1):
            if env.rng.uniform(0, 1) < env.lambda_:
                feature = env.generate_feature()
                if feature is not None:
                    env.queue.append(feature)

            if len(env.queue) == 0:
                env.step(-1, -1)
                continue

            if self.get_bernoulli_explore_sample(t) == 1:
                job_idx, server_idx = 0, self.rng.choice(self.K)
            else:
                job_idx, server_idx = self.get_thompson_sample_server(t)

            if job_idx == -1 or server_idx == -1:
                continue

            reward = env.step(job_idx, server_idx)
            if reward != -1:
                self.counts[server_idx] += 1
                self.rewards_sum[server_idx] += reward