import numpy as np
import pandas as pd
from util import sigmoid, inv_dot_sigmoid 

class ContextualQueue:
    """Queue simulator whose jobs are rows of a labelled CSV dataset.

    Every job in the queue is an index into the loaded dataset.  Serving a
    job succeeds when the chosen server index equals the job's label;
    a job that is not served successfully goes back to the end of the queue.

    Attributes:
        rng: NumPy random generator seeded with ``seed``.
        lambda_: Default per-step arrival probability.
        T: Stored as given; not used by the methods in this class.
        X_data: Feature matrix, all CSV columns except the last, as floats.
        y_data: Labels from the last CSV column, clipped to
            ``[0, NUM_CLASSES - 1]``.
        N: Number of rows in the dataset.
        d: Number of feature columns.
        K: Number of servers, equal to ``NUM_CLASSES``.
        eps, kappa, L, S: Optional values read from ``kwargs`` (defaults
            0.1, 10, 3, 1); stored but not used by the methods in this class.
        queue: List of dataset indices currently waiting.
        theta_list: ``K`` zero vectors of length ``d``.
        queue_length_history: Queue length recorded after every ``step``.
        count: Number of ``generate_feature`` calls so far, minus one.
    """

    # Single source of truth for the number of servers and for the label
    # range, so the clip bound in _load_data cannot drift away from K.
    NUM_CLASSES = 5

    def __init__(self, data_file='',
                 lambda_=0.7, T=5000, seed=0, **kwargs):
        """Load the dataset and initialise an empty queue.

        Args:
            data_file: Path of the CSV file; the last column is the label.
            lambda_: Default arrival probability used by ``generate_arrival``.
            T: Stored on the instance as ``self.T``.
            seed: Seed for the internal random generator.
            **kwargs: Optional ``eps``, ``kappa``, ``L`` and ``S`` values.

        Raises:
            FileNotFoundError: If ``data_file`` does not exist.
            Exception: Any other error raised while reading the file is
                printed and re-raised unchanged.
        """
        self.rng = np.random.default_rng(seed)
        self.lambda_ = lambda_
        self.T = T

        self.X_data, self.y_data = self._load_data(data_file)
        self.N = len(self.X_data)
        self.d = self.X_data.shape[1]
        self.K = self.NUM_CLASSES

        self.eps = kwargs.get('eps', 0.1)
        self.kappa = kwargs.get('kappa', 10)
        self.L = kwargs.get('L', 3)
        self.S = kwargs.get('S', 1)

        self.queue, self.theta_list, self.queue_length_history = self.reset()
        self.count = -1

    def _load_data(self, data_file):
        """Read ``data_file`` and split it into features and labels.

        The last column is taken as the label and cast to ``int``; all other
        columns are cast to ``float``.  Labels outside
        ``[0, NUM_CLASSES - 1]`` are clipped into that range.

        Returns:
            A ``(X, y)`` tuple of NumPy arrays.

        Raises:
            FileNotFoundError: If the file does not exist.
            Exception: Any other read or conversion error, re-raised after
                being printed.
        """
        try:
            df = pd.read_csv(data_file)

            target_col = df.columns[-1]
            X = df.drop(columns=[target_col]).values.astype(float)
            y = df[target_col].values.astype(int)

            # Keep every label a valid server index.
            y = np.clip(y, 0, self.NUM_CLASSES - 1)

            print(f"feature size: {X.shape}, target size: {y.shape}")
            return X, y
        except FileNotFoundError:
            print(f"Error: data file '{data_file}' not found")
            raise
        except Exception as e:
            print(f"Error: {e}")
            raise

    def reset(self):
        """Build fresh, empty simulation containers.

        This method only *returns* the new containers; it does not assign
        them to the instance and it does not touch ``self.count``.

        Returns:
            A ``(queue, theta_list, queue_length_history)`` tuple: an empty
            list, ``K`` zero vectors of length ``d``, and an empty list.
        """
        queue = []
        theta_list = [np.zeros(self.d) for _ in range(self.K)]
        queue_length_history = []
        return queue, theta_list, queue_length_history

    def step(self, job_idx, server_idx):
        """Serve one queued job with the given server.

        Args:
            job_idx: Position of the job inside ``self.queue``; ``-1`` means
                that no job is served in this step.
            server_idx: Server index, compared against the job's label.

        Returns:
            A ``(reward, job_feature, y_true)`` tuple.  When no job is served
            (``job_idx == -1`` or the queue is empty) this is
            ``(-1, None, -1)``.  Otherwise ``reward`` is 1 when
            ``server_idx`` equals the label and 0 when it does not.
        """
        reward = -1
        job_feature = None
        y_true = -1

        if job_idx != -1 and self.queue:
            data_idx = self.queue.pop(job_idx)

            job_feature = self.X_data[data_idx]
            y_true = self.y_data[data_idx]

            success_prob = 1.0 if y_true == server_idx else 0.0

            # Kept as a Bernoulli draw so the generator is used exactly as
            # before and seeded runs stay reproducible.
            reward = self.rng.binomial(1, success_prob)

            if not reward:
                # Unsuccessful service: the job re-enters at the queue's end.
                self.queue.append(data_idx)

        self.queue_length_history.append(len(self.queue))

        return reward, job_feature, y_true

    def generate_feature(self):
        """Return the dataset index to use for the next arriving job.

        The first ``N`` calls return ``0, 1, ..., N - 1`` in order; later
        calls return an index drawn uniformly at random from ``range(N)``.
        """
        self.count += 1

        if self.count < self.N:
            return self.count
        return self.rng.choice(self.N)

    def generate_arrival(self, arrival_rate=None):
        """Sample whether a new job arrives in this step.

        Args:
            arrival_rate: Arrival probability for this call; ``self.lambda_``
                is used when it is ``None``.

        Returns:
            ``True`` when a uniform draw from ``[0, 1)`` is below the rate.
        """
        lambda_ = arrival_rate if arrival_rate is not None else self.lambda_
        return self.rng.uniform(0, 1) < lambda_