import logging
import uuid

from envs.constants import *
from envs.task_queue import TaskQueue

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class ServerNode:
    """A computing node that owns a task queue and serves it with a fixed capability.

    Tasks offloaded to this node are pushed into its ``TaskQueue``; each call to
    :meth:`do_tasks` serves that queue, passing ``computational_capability`` to
    ``TaskQueue.served``.

    Args:
        computational_capability: Forwarded to ``TaskQueue.served`` on every
            :meth:`do_tasks` call and used as the denominator of the CPU ratio
            in :meth:`_get_obs` (so it must be non-zero for that method).
        is_random_task_generating: Accepted for interface compatibility; it is
            not used by this class.
    """

    def __init__(self, computational_capability, is_random_task_generating=False):
        super().__init__()
        self.uuid = uuid.uuid4()
        # Links keyed by the linked node's id. Each value is a dict whose
        # 'channel' entry provides get_rate() (see sample_channel_rate).
        self.links_to_lower = {}
        self.links_to_higher = {}
        self.mobility = False
        self.computational_capability = computational_capability
        self.task_queue = TaskQueue()
        # Value recorded by the most recent do_tasks() call (0 if the queue was empty).
        self.cpu_used = 0

    def __del__(self):
        """Drop the reference to the task queue when the node is finalized.

        This is best effort only. The attribute may already be missing, for
        example because ``__init__`` raised before assigning it or because
        ``__del__`` ran more than once. An exception escaping a finalizer is
        never propagated; Python would only print an "Exception ignored"
        warning, so the ``AttributeError`` is suppressed deliberately.
        """
        try:
            del self.task_queue
        except AttributeError:
            pass

    def do_tasks(self):
        """Serve the task queue for one step.

        Returns:
            The first element of the tuple returned by ``TaskQueue.served``, or
            0 when the queue is empty. The same value is stored in
            ``self.cpu_used``.
        """
        if self.task_queue.get_length():
            # type=1 is forwarded unchanged; its meaning is defined by TaskQueue.served.
            self.cpu_used, _ = self.task_queue.served(self.computational_capability, type=1)
        else:
            self.cpu_used = 0
        return self.cpu_used

    def probed(self, app_type, bits_to_be_arrived):
        """Check whether ``bits_to_be_arrived`` more units fit into the task queue.

        ``app_type`` is accepted for interface compatibility but is not used.

        Returns:
            ``(bits_to_be_arrived, False)`` when the current queue length plus
            ``bits_to_be_arrived`` does not exceed ``task_queue.get_max()``,
            otherwise ``(0, True)``. The boolean flags an overflow.
        """
        capacity = self.task_queue.get_max()
        would_be_length = self.task_queue.get_length() + bits_to_be_arrived
        if capacity < would_be_length:
            return (0, True)
        return (bits_to_be_arrived, False)

    def offloaded_tasks(self, tasks, arrival_timestamp):
        """Accept tasks offloaded to this node and push them into the task queue.

        For each task, the previous ``server_index`` becomes ``client_index``,
        ``server_index`` is set to this node's id, and the arrival time is set
        before the task is handed to ``TaskQueue.arrived``.

        Args:
            tasks: Mapping whose values are task objects; its keys are not used.
            arrival_timestamp: Arrival time assigned to every task.

        Returns:
            Number of tasks for which ``TaskQueue.arrived`` returned a falsy value.
        """
        node_id = self.get_uuid()
        failed_to_offload = 0
        for task_ob in tasks.values():
            # Read the old server_index before overwriting it with this node's id.
            task_ob.client_index = task_ob.server_index
            task_ob.server_index = node_id
            task_ob.set_arrival_time(arrival_timestamp)
            if not self.task_queue.arrived(task_ob, arrival_timestamp):
                failed_to_offload += 1
        return failed_to_offload

    def get_higher_node_ids(self):
        """Return the ids of all nodes linked above this one, as a list."""
        return list(self.links_to_higher)

    def get_task_queue_length(self, scale=1):
        """Return the task queue length, as reported by ``TaskQueue.get_length(scale=scale)``."""
        return self.task_queue.get_length(scale=scale)

    def sample_channel_rate(self, linked_id):
        """Sample the channel rate of the link to ``linked_id``.

        Higher links are checked first and sampled with ``get_rate()``. Lower
        links are sampled with ``get_rate(False)``.

        Returns:
            The sampled rate, or ``None`` if ``linked_id`` is not linked to this node.
        """
        if linked_id in self.links_to_higher:
            return self.links_to_higher[linked_id]['channel'].get_rate()
        if linked_id in self.links_to_lower:
            return self.links_to_lower[linked_id]['channel'].get_rate(False)
        return None

    def get_uuid(self):
        """Return this node's id as a 32-character hex string."""
        return self.uuid.hex

    def _get_obs(self, time, estimate_interval=100, involve_capability=False, scale=1):
        """Build this node's observation vector at ``time``.

        Returns:
            A 4-element list:
            ``[task_queue.mean_arrival(time, estimate_interval, scale=scale),
            task_queue.last_arrival(time, scale=scale),
            task_queue.get_cpu_needs(scale=scale),
            cpu_used / computational_capability]``.

        ``involve_capability`` is accepted but currently unused.

        Raises:
            ZeroDivisionError: If ``computational_capability`` is 0.
        """
        queue = self.task_queue
        return [
            queue.mean_arrival(time, estimate_interval, scale=scale),
            queue.last_arrival(time, scale=scale),
            queue.get_cpu_needs(scale=scale),
            self.cpu_used / self.computational_capability,
        ]
