import numpy as np
import pandas as pd
from util import sigmoid, inv_dot_sigmoid 

class ContextualQueue:
    """Single-queue, two-server contextual queueing environment backed by a CSV dataset.

    Each job in ``self.queue`` is a row index into the loaded feature matrix
    ``X_data``. Serving a job on server ``k`` succeeds only when ``k`` equals
    the row's binary label in ``y_data``; failed jobs go to the back of the queue.
    """

    def __init__(self, data_file='coupon.csv',
                 lambda_=0.4, T=4000, seed=0, **kwargs):
        """Load the dataset and initialise the environment state.

        Args:
            data_file: path (or file-like object) passed to ``pandas.read_csv``;
                its last column is the target label.
            lambda_: default per-step arrival probability used by
                ``generate_arrival``.
            T: stored as ``self.T``; not read anywhere inside this class.
            seed: seed for ``numpy.random.default_rng``.
            **kwargs: optional ``eps`` (default 0.1), ``kappa`` (10), ``L`` (3)
                and ``S`` (1), stored as attributes; other keys are ignored.
        """
        self.rng = np.random.default_rng(seed)
        self.lambda_ = lambda_
        self.T = T

        self.X_data, self.y_data = self._load_data(data_file)
        self.N = len(self.X_data)       # number of dataset rows
        self.d = self.X_data.shape[1]   # feature dimension after one-hot encoding
        self.K = 2                      # one server per binary label value

        # Not read inside this class; presumably consumed by the learning
        # algorithm driving the environment — confirm with callers.
        self.eps = kwargs.get('eps', 0.1)
        self.kappa = kwargs.get('kappa', 10)
        self.L = kwargs.get('L', 3)
        self.S = kwargs.get('S', 1)

        self.queue, self.theta_list, self.queue_length_history = self.reset()
        # Cursor for generate_feature: -1 so the first call yields row 0.
        self.count = -1

    def _load_data(self, data_file):
        """Read ``data_file`` and return ``(X, y)`` as NumPy arrays.

        The last column is the target. A label is positive (1) when, after
        stripping surrounding whitespace and upper-casing, it is one of
        ``'1'``, ``'1.0'``, ``'Y'`` or ``'YES'``; anything else (including
        missing values, which become ``'NAN'``) maps to 0. The remaining
        columns are one-hot encoded with ``pandas.get_dummies(drop_first=True)``
        and cast to float.

        Raises:
            FileNotFoundError: if ``data_file`` does not exist.
            Exception: any other read/parse error is printed and re-raised.
        """
        positive_labels = {'1', '1.0', 'Y', 'YES'}
        try:
            df = pd.read_csv(data_file, low_memory=False)
            target_col = df.columns[-1]

            # strip() so whitespace-padded labels such as ' Y' (read_csv keeps
            # leading spaces by default) are not silently read as negatives.
            labels = df[target_col].astype(str).str.strip().str.upper()
            y = labels.isin(positive_labels).to_numpy().astype(int)

            X = df.drop(columns=[target_col])
            # drop_first=True keeps k-1 dummy columns per categorical column.
            X = pd.get_dummies(X, drop_first=True).to_numpy().astype(float)

            print(f"feature size: {X.shape}, Target size: {y.shape}")
            return X, y
        except FileNotFoundError:
            print(f"Error: '{data_file}'")
            raise
        except Exception as e:
            print(f"Error: {e}")
            raise

    def reset(self):
        """Install and return fresh ``(queue, theta_list, queue_length_history)``.

        The queue and the history start empty, and each of the ``K`` parameter
        vectors starts as zeros of length ``d``. The new objects are also
        assigned to the instance, so ``reset()`` actually clears a live
        environment. The feature cursor (``self.count``) and the RNG are left
        untouched.
        """
        self.queue = []
        self.theta_list = [np.zeros(self.d) for _ in range(self.K)]
        self.queue_length_history = []
        return self.queue, self.theta_list, self.queue_length_history

    def step(self, job_idx, server_idx):
        """Serve the job at queue position ``job_idx`` on server ``server_idx``.

        Pass ``job_idx=-1`` to idle. The job succeeds only when ``server_idx``
        equals its label; a failed job is re-appended to the back of the
        queue. The queue length after the step is always appended to
        ``queue_length_history``.

        Returns:
            ``(reward, job_feature, y_true)``. ``reward`` is 1/0 for a served
            job. When no job is served (idle or empty queue), the result is
            ``(-1, None, -1)``.

        Raises:
            IndexError: if ``job_idx`` is out of range for a non-empty queue.
        """
        reward = -1
        job_feature = None
        y_true = -1

        if job_idx != -1 and self.queue:
            data_idx = self.queue.pop(job_idx)
            job_feature = self.X_data[data_idx]
            y_true = self.y_data[data_idx]

            # The success probability is 0/1 here. The Bernoulli draw is kept
            # so a soft probability can be substituted. Replacing it with
            # int(is_correct) could also shift the RNG stream used for arrivals.
            p = 1.0 if y_true == server_idx else 0.0
            reward = self.rng.binomial(1, p)

            if not reward:
                self.queue.append(data_idx)  # failed jobs rejoin at the back

        self.queue_length_history.append(len(self.queue))
        return reward, job_feature, y_true

    def generate_feature(self):
        """Return the dataset row index for the next arriving job.

        The first ``N`` calls walk the rows in order (0, 1, ..., N-1). After
        that, rows are drawn uniformly at random with replacement.
        """
        self.count += 1
        if self.count < self.N:
            return self.count
        return self.rng.choice(self.N)

    def generate_arrival(self, arrival_rate=None):
        """Return whether a new job arrives this step.

        The arrival probability is ``arrival_rate``, falling back to
        ``self.lambda_`` when it is ``None``.
        """
        lambda_ = arrival_rate if arrival_rate is not None else self.lambda_
        return self.rng.uniform(0, 1) < lambda_
        