import numpy as np

from util import *

class ContextualQueue:
    """Simulated queue of jobs described by feature vectors, served by K servers.

    Each server ``k`` owns a parameter vector ``theta_list[k]``.  Serving a job
    with feature vector ``x`` on server ``k`` succeeds with probability
    ``sigmoid(x . theta_list[k])``; a job that is not served successfully is put
    back at the end of the queue.

    ``sigmoid`` and ``inv_dot_sigmoid`` are provided by ``util`` (presumably the
    logistic function and the reciprocal of its derivative -- confirm against
    ``util``).

    Parameters
    ----------
    lambda_, eps : float
        ``lambda_ + eps`` is the minimum success probability that at least one
        server must reach for a job to enter the finite job set.
        (``lambda_`` is presumably the arrival rate -- TODO confirm with caller.)
    d : int
        Dimension of the feature and parameter vectors.
    K : int
        Number of servers (parameter vectors).
    N : int
        Number of jobs in the finite job set.
    T : int
        When the finite job set holds exactly ``T`` jobs, ``generate_feature``
        walks through it in order instead of sampling from it.
    kappa : float
        Upper bound on ``inv_dot_sigmoid(x . theta)`` over all servers for a job
        to enter the finite job set.
    L : float
        Feature coordinates are drawn uniformly from ``[-L, L)``.
    S : float
        Parameter coordinates are drawn uniformly from ``[-S, S)``.
    seed : int
        Seed for the internal ``numpy`` random generator.
    """

    def __init__(self, lambda_=0.2, eps=0.1, d=5, K=3, N=5, T=1000, kappa=10, L=1, S=1, seed=0):
        self.rng = np.random.default_rng(seed)
        self.lambda_ = lambda_
        self.eps = eps
        self.d = d
        self.K = K
        self.N = N
        self.T = T
        self.kappa = kappa
        self.L = L
        self.S = S

        # The server parameters must exist before the job set is generated,
        # because job admissibility is evaluated against them.
        (
            self.queue,
            self.queue_optimal,
            self.theta_list,
            self.queue_length_history,
        ) = self.reset()
        self.finite_job_set = self.generate_finite_set()
        # Cursor used by generate_feature() when the job set is consumed in order.
        self.count = -1

    def reset(self):
        """Build fresh state and return it; attributes are NOT assigned here.

        Returns
        -------
        tuple
            ``(queue, queue_optimal, theta_list, queue_length_history)`` where
            the three lists are empty and ``theta_list`` holds ``K`` newly drawn
            parameter vectors of dimension ``d``.
        """
        theta_list = [self.rng.uniform(-self.S, self.S, self.d) for _ in range(self.K)]
        return [], [], theta_list, []

    def step(self, job_idx, server_idx):
        """Try to serve the job at ``job_idx`` on server ``server_idx``.

        Returns
        -------
        int
            ``1`` if the job was served, ``0`` if service failed (the job is
            re-appended to the end of the queue), or ``-1`` if the queue was
            empty and nothing was attempted.

        The queue length after the attempt is always recorded in
        ``queue_length_history``.
        """
        if self.queue:
            job = self.queue.pop(job_idx)
            theta = self.theta_list[server_idx]
            reward = self.rng.binomial(1, sigmoid(np.dot(job, theta)))
            if not reward:
                # Failed service: the job goes back to the tail of the queue.
                self.queue.append(job)
        else:
            reward = -1

        self.queue_length_history.append(len(self.queue))
        return reward

    def _is_admissible(self, success_probs, inv_slopes):
        """Return True if a candidate job passes both acceptance constraints.

        ``success_probs`` and ``inv_slopes`` hold, per server, the values of
        ``sigmoid`` and ``inv_dot_sigmoid`` at the job's linear score.
        """
        threshold = self.lambda_ + self.eps
        # At least one server must reach the required success probability.
        if all(p < threshold for p in success_probs):
            return False
        # Every server must respect the kappa bound.  The comparison has to be
        # applied per element: ``any(values) > kappa`` would compare a bool
        # with kappa and never reject anything.
        if any(v > self.kappa for v in inv_slopes):
            return False
        return True

    def generate_finite_set(self):
        """Draw ``N`` admissible job feature vectors by rejection sampling.

        Candidates are drawn uniformly from ``[-L, L)^d`` and kept only when
        ``_is_admissible`` accepts them.  Note that the sampling loop does not
        terminate if no candidate can satisfy the constraints.

        Returns
        -------
        list
            ``N`` feature vectors (``numpy`` arrays of length ``d``).
        """
        job_set = []
        while len(job_set) < self.N:
            x = self.rng.uniform(-self.L, self.L, self.d)
            scores = [np.dot(x, theta) for theta in self.theta_list]
            success_probs = [sigmoid(z) for z in scores]
            inv_slopes = [inv_dot_sigmoid(z) for z in scores]
            if self._is_admissible(success_probs, inv_slopes):
                job_set.append(x)
        return job_set

    def generate_feature(self):
        """Return the feature vector of the next arriving job.

        * If the finite job set has exactly ``T`` entries, they are returned
          one after another (an ``IndexError`` follows once all were used).
        * Otherwise, if the set is non-empty, one entry is picked uniformly at
          random.
        * If the set is empty, a fresh vector is drawn uniformly from
          ``[-L, L)^d``.
        """
        n_jobs = len(self.finite_job_set)
        if n_jobs == self.T:
            self.count += 1
            return self.finite_job_set[self.count]
        if n_jobs > 0:
            return self.finite_job_set[self.rng.choice(n_jobs)]
        return self.rng.uniform(-self.L, self.L, self.d)
