import numpy as np
import pandas as pd

from util import sigmoid, inv_dot_sigmoid 
from tensorflow.keras.datasets import mnist

class ContextualQueue:
    """Single-queue routing environment whose jobs carry feature vectors.

    Every job is an index into a labelled dataset ``(X_data, y_data)``.
    Waiting jobs live in ``self.queue`` (a list of dataset indices).
    Serving a job with server ``k`` succeeds (reward 1) iff ``k`` equals the
    job's label; a job that fails is re-appended to the tail of the queue.

    Args:
        data_file: Path or file-like object handed to ``pandas.read_csv``
            when ``data_load_option == 'file'``; ignored for ``'keras'``.
        lambda_: Default arrival probability used by ``generate_arrival``.
        T: Stored as ``self.T``; not used inside this class (presumably the
            simulation horizon consumed by the caller -- confirm).
        seed: Seed for the internal ``numpy.random.Generator``.
        data_load_option: ``'keras'`` (downsampled MNIST training split) or
            ``'file'`` (CSV whose last column is the integer label).
        **kwargs: Optional ``eps`` (default 0.1), ``kappa`` (10), ``L`` (3)
            and ``S`` (1). They are stored as attributes but not used inside
            this class (presumably read by a policy -- confirm).

    Raises:
        ValueError: If ``data_load_option`` is neither ``'keras'`` nor
            ``'file'``.
    """

    def __init__(self, data_file=None,
                 lambda_=0.7, T=60000, seed=0, data_load_option='keras', **kwargs):
        self.rng = np.random.default_rng(seed)
        self.lambda_ = lambda_
        self.T = T

        # Number of servers / label classes. This must be set BEFORE loading
        # the data, because the 'file' loader clips labels to [0, K - 1]
        # using self.K.
        self.K = 10

        self.X_data, self.y_data = self._load_data(data_file, data_load_option)
        self.N = len(self.X_data)       # number of samples in the dataset
        self.d = self.X_data.shape[1]   # feature dimension

        self.eps = kwargs.get('eps', 0.1)
        self.kappa = kwargs.get('kappa', 10)
        self.L = kwargs.get('L', 3)
        self.S = kwargs.get('S', 1)

        self.queue, self.theta_list, self.queue_length_history = self.reset()
        # Cursor used by generate_feature. It starts at -1 so the first call
        # yields dataset index 0.
        self.count = -1

    def _downsample_28x28_to_7x7(self, images):
        """Average-pool images over non-overlapping 4x4 windows and flatten.

        Args:
            images: Array of shape ``(N, H, W)``. H and W must be divisible
                by 4, otherwise ``reshape`` raises ``ValueError``.

        Returns:
            Array of shape ``(N, (H // 4) * (W // 4))``; ``(N, 49)`` for
            28x28 inputs.
        """
        n_images, height, width = images.shape
        pool_size = 4

        # Split each spatial axis into (blocks, pool_size), then average over
        # the two in-block axes.
        blocks = images.reshape(n_images, height // pool_size, pool_size,
                                width // pool_size, pool_size)
        pooled = blocks.mean(axis=(2, 4))

        return pooled.reshape(n_images, -1)

    def _load_data(self, data_file, option):
        """Load the feature matrix and integer labels.

        Args:
            data_file: Path or file-like object for ``pandas.read_csv``
                (used only when ``option == 'file'``).
            option: ``'keras'`` or ``'file'``.

        Returns:
            Tuple ``(X, y)`` with X of shape ``(num_samples, num_features)``
            and y an integer label array.

        Raises:
            ValueError: For an unknown ``option``.
            Any exception from loading or parsing the data is re-raised.
        """
        if option == 'keras':
            (X_train, y_train), (_, _) = mnist.load_data()
            # Scale pixel values by 1/255 (MNIST pixels are 0..255).
            X_train = X_train.astype('float32') / 255.0
            X = self._downsample_28x28_to_7x7(X_train)
            y = y_train.astype(int)
            return X, y

        if option == 'file':
            try:
                df = pd.read_csv(data_file)
                # Convention: the last column holds the label and all other
                # columns are features.
                target_col = df.columns[-1]
                X = df.drop(columns=[target_col]).values.astype(float)

                y = df[target_col].values.astype(int)
                # Out-of-range labels are clipped into the valid server range.
                y = np.clip(y, 0, self.K - 1)
                print(f"feature size {X.shape}, target size: {y.shape}")
                return X, y
            except FileNotFoundError:
                print(f"Error: data file not found: '{data_file}'")
                raise
            except Exception as e:
                print(f"Error: {e}")
                raise

        raise ValueError(
            f"Unknown data_load_option {option!r}; expected 'keras' or 'file'"
        )

    def reset(self):
        """Build fresh episode state without assigning it to ``self``.

        Note that ``self.count`` (the generate_feature cursor) is NOT reset
        here.

        Returns:
            Tuple ``(queue, theta_list, queue_length_history)``:
            - an empty job queue;
            - ``K`` zero vectors of length ``d`` (not used inside this class;
              presumably per-server parameters for a learner -- confirm);
            - an empty queue-length history.
        """
        queue = []
        theta_list = [np.zeros(self.d) for _ in range(self.K)]
        queue_length_history = []
        return queue, theta_list, queue_length_history

    def step(self, job_idx, server_idx):
        """Try to serve one queued job with the given server.

        Args:
            job_idx: Position in ``self.queue`` of the job to serve, or -1
                to serve nothing this step.
            server_idx: Server (class) the job is routed to.

        Returns:
            Tuple ``(reward, job_feature, y_true)``.
            - reward: 1 if ``server_idx`` equals the job's label, 0 if it
              does not, and -1 when no job was served (``job_idx == -1`` or
              the queue is empty).
            - job_feature: The served job's feature row, or ``None`` when no
              job was served.
            - y_true: The served job's label, or -1 when no job was served.

        Side effects:
            Pops the served job, re-appends it to the tail of the queue on
            failure, and always records the current queue length in
            ``self.queue_length_history``.

        Raises:
            IndexError: If ``job_idx`` is out of range for a non-empty queue.
        """
        reward = -1
        job_feature = None
        y_true = -1

        if job_idx != -1 and len(self.queue) > 0:
            data_idx = self.queue.pop(job_idx)
            job_feature = self.X_data[data_idx]
            y_true = self.y_data[data_idx]

            success_prob = 1.0 if y_true == server_idx else 0.0
            # Reward is a Bernoulli(success_prob) draw. With p in {0, 1} the
            # outcome is deterministic, but the draw still advances
            # self.rng. It is kept so random streams stay reproducible.
            reward = self.rng.binomial(1, success_prob)

            if not reward:
                # A failed job goes back to the tail of the queue.
                self.queue.append(data_idx)

        self.queue_length_history.append(len(self.queue))

        return reward, job_feature, y_true

    def generate_feature(self):
        """Return the dataset index of the next arriving job.

        The first ``N`` calls walk the dataset in order (0, 1, ..., N-1).
        After that, each call draws an index uniformly at random with
        replacement from ``self.rng``.
        """
        self.count += 1

        if self.count < self.N:
            return self.count
        return self.rng.choice(self.N)

    def generate_arrival(self, arrival_rate=None):
        """Sample whether a new job arrives in this time step.

        Args:
            arrival_rate: Arrival probability for this call. If ``None``,
                ``self.lambda_`` is used instead.

        Returns:
            True with probability ``arrival_rate`` (or ``self.lambda_``).
        """
        rate = arrival_rate if arrival_rate is not None else self.lambda_
        return self.rng.uniform(0, 1) < rate