import logging
import uuid
import numpy as np

from envs.constants import *
from envs import applications
from envs.task import Task
from envs.task_queue import TaskQueue

# Module-level logger; NOTE(review): not referenced by the code visible in this file.
logger = logging.getLogger(__name__)


class ServerNode:
    """A computing node holding one task queue per application type.

    The node can process its queued tasks locally (``do_tasks``), transmit
    queued bits to a higher node it is linked to (``offload_tasks``), accept
    tasks offloaded to it (``offloaded_tasks``) and generate random task
    arrivals (``random_task_generation``).

    Application types are used as 1-based positions in the numpy vectors
    this class returns (``arr[app_type - 1]``), while the per-queue action
    vectors ``alpha``/``beta`` follow the insertion order of ``queue_list``.

    Links are dicts keyed by a peer id (presumably the peer's ``get_uuid()``
    hex -- confirm against the caller that builds them). Each entry holds a
    ``'channel'`` exposing ``get_rate()``; higher links also hold ``'node'``.
    """

    def __init__(self, computational_capability, is_random_task_generating=False):
        super().__init__()
        self.uuid = uuid.uuid4()
        # peer id -> {'channel': ...} for nodes below this one.
        self.links_to_lower = {}
        # peer id -> {'node': ..., 'channel': ...} for offload targets.
        self.links_to_higher = {}
        self.mobility = False
        # CPU budget that do_tasks scales alpha by; _get_obs divides it by GHZ,
        # so presumably expressed in Hz / cycles per step (confirm).
        self.computational_capability = computational_capability
        self.number_of_applications = 0
        # app_type -> TaskQueue
        self.queue_list = {}
        self.is_random_task_generating = is_random_task_generating
        # app_type -> CPU actually consumed in the most recent do_tasks call.
        self.cpu_used = {}

    def __del__(self):
        """Drop the references to the task queues.

        Garbage collection would release them anyway; the method is kept for
        any explicit caller. ``getattr`` avoids an AttributeError inside the
        finalizer when ``__init__`` never assigned ``queue_list``.
        """
        queues = getattr(self, "queue_list", None)
        if queues is not None:
            queues.clear()

    def make_application_queues(self, *application_types):
        """Create a fresh TaskQueue for each given application type.

        Re-registering an existing type replaces its queue and resets its
        ``cpu_used`` entry. ``number_of_applications`` always equals the
        number of distinct queues, so re-registration is no longer
        double-counted.
        """
        for application_type in application_types:
            self.queue_list[application_type] = TaskQueue(application_type)
            self.cpu_used[application_type] = 0
        self.number_of_applications = len(self.queue_list)

    def do_tasks(self, alpha):
        """Process queued tasks locally.

        Args:
            alpha: one CPU fraction per queue, in ``queue_list`` insertion
                order. Each fraction is multiplied by
                ``computational_capability``.

        Returns:
            The total CPU consumed, i.e. the sum of the first values returned
            by each queue's ``served(..., type=1)``. The per-app amounts are
            stored in ``self.cpu_used``.
        """
        app_type_list = list(self.queue_list)
        cpu_allocs = dict(zip(app_type_list, alpha))

        for app_type in app_type_list:
            if cpu_allocs[app_type] == 0:
                continue
            my_task_queue = self.queue_list[app_type]
            if my_task_queue.get_length():
                cpu_allocs[app_type], _ = my_task_queue.served(
                    cpu_allocs[app_type] * self.computational_capability, type=1)
            else:
                # Nothing queued, so nothing is consumed.
                cpu_allocs[app_type] = 0

        self.cpu_used = cpu_allocs
        return sum(cpu_allocs.values())

    def _probe(self, bits_to_be_arrived, id_to_offload):
        """Ask the higher node ``id_to_offload`` whether it accepts the bits.

        ``bits_to_be_arrived`` (app_type -> bits) is updated IN PLACE with the
        amount the target accepts (0 when refused). Only app types this node
        has a queue for are probed.

        Returns:
            ``(bits_to_be_arrived, failed)``, where ``failed`` maps each probed
            app_type to True if the target refused it.
        """
        node_to_offload = self.links_to_higher[id_to_offload]['node']
        failed = {}
        for app_type, bits in bits_to_be_arrived.items():
            if app_type in self.queue_list:
                bits_to_be_arrived[app_type], failed[app_type] = node_to_offload.probed(app_type, bits)
        return bits_to_be_arrived, failed

    def probed(self, app_type, bits_to_be_arrived):
        """Report whether this node can enqueue ``bits_to_be_arrived`` bits.

        Returns:
            ``(bits_to_be_arrived, False)`` when the queue for ``app_type``
            stays within ``get_max()``. Otherwise, or when this node has no
            queue for ``app_type``, it returns ``(0, True)``.
        """
        # Test for presence explicitly: indexing raised KeyError for unknown
        # types, and truthiness would refuse an *empty* queue whenever
        # TaskQueue defines __len__.
        queue = self.queue_list.get(app_type)
        if queue is None:
            return (0, True)
        if queue.get_max() < queue.get_length() + bits_to_be_arrived:
            return (0, True)
        return (bits_to_be_arrived, False)

    def offload_tasks(self, beta, id_to_offload):
        """Transmit queued bits to the higher node ``id_to_offload``.

        Args:
            beta: one bandwidth fraction per queue, in ``queue_list``
                insertion order (the same convention ``do_tasks`` uses for
                ``alpha``).
            id_to_offload: key into ``links_to_higher``.

        Each queue sends at most ``min(queue length, beta * channel rate)``
        bits, truncated to int, and only if the target accepts them.

        Returns:
            ``(total bits sent, merged dict returned by the queues'
            served(..., type=0), failed)``. ``failed`` maps app_type to True
            when the target refused it.
        """
        channel_rate = self.sample_channel_rate(id_to_offload)
        app_type_list = list(self.queue_list)
        # get_queue_lengths() is indexed by app_type - 1. Re-order it to match
        # app_type_list so each length is paired with its own beta entry;
        # zipping the raw array mis-paired them unless the queues were
        # created in the order 1, 2, 3, ...
        lengths_by_type = self.get_queue_lengths()
        lengths = np.array([lengths_by_type[app_type - 1] for app_type in app_type_list])
        caps = np.minimum(lengths, np.array(beta) * channel_rate).astype(int)
        tx_allocs = dict(zip(app_type_list, caps))
        tx_allocs, failed = self._probe(tx_allocs, id_to_offload)

        task_to_be_offloaded = {}
        for app_type in app_type_list:
            if tx_allocs[app_type] == 0:
                continue
            my_task_queue = self.queue_list[app_type]
            if my_task_queue.get_length():
                tx_allocs[app_type], new_to_be_offloaded = my_task_queue.served(
                    tx_allocs[app_type], type=0)
                task_to_be_offloaded.update(new_to_be_offloaded)
            else:
                tx_allocs[app_type] = 0
        return sum(tx_allocs.values()), task_to_be_offloaded, failed

    def offloaded_tasks(self, tasks, arrival_timestamp):
        """Enqueue tasks offloaded to this node.

        The task's previous server becomes its client, and this node becomes
        its server, before the task is pushed into the queue of its
        application type.

        Returns:
            The number of tasks whose queue's ``arrived`` returned falsy.
        """
        failed_to_offload = 0
        for task_ob in tasks.values():
            task_ob.client_index = task_ob.server_index
            task_ob.server_index = self.get_uuid()
            task_ob.set_arrival_time(arrival_timestamp)
            failed_to_offload += (not self.queue_list[task_ob.application_type].arrived(task_ob, arrival_timestamp))
        return failed_to_offload

    def random_task_generation(self, task_rate, arrival_timestamp, *app_types):
        """Generate Poisson-distributed task arrivals for this node's queues.

        For every ``(app_type, population)`` in ``applications.app_type_pop()``
        that this node has a queue for, the data size is
        ``Poisson(task_rate * population) * arrival_bits(app_type)``. A
        non-zero size is enqueued as one Task.

        Returns:
            ``(arrival_size, failed_to_generate)``. ``arrival_size[app_type - 1]``
            holds the generated size. The array has ``len(app_types)`` entries,
            widened when needed to fit the largest local app type (this
            previously raised IndexError).
        """
        app_type_pop = applications.app_type_pop()
        # One synthetic client id is shared by all tasks generated in this call.
        random_id = uuid.uuid4()
        size = max(len(app_types), max(self.queue_list, default=0))
        arrival_size = np.zeros(size)
        failed_to_generate = 0
        for app_type, population in app_type_pop:
            if app_type not in self.queue_list:
                continue
            data_size = np.random.poisson(task_rate * population) * applications.arrival_bits(app_type)
            if data_size > 0:
                task = Task(app_type, data_size, client_index=random_id.hex,
                            server_index=self.get_uuid(), arrival_timestamp=arrival_timestamp)
                failed_to_generate += (not self.queue_list[app_type].arrived(task, arrival_timestamp))
                arrival_size[app_type - 1] = data_size
        return arrival_size, failed_to_generate

    def get_higher_node_ids(self):
        """Return the ids of all higher nodes this node is linked to."""
        return list(self.links_to_higher.keys())

    def get_queue_list(self):
        """Return a live ``(app_type, queue)`` items view of the queues."""
        return self.queue_list.items()

    def get_queue_lengths(self, scale=1):
        """Return queue lengths as an array indexed by ``app_type - 1``.

        The array has at least one slot per queue and is widened to the
        largest app type, so non-contiguous types such as {1, 3} no longer
        raise IndexError. Missing types read as 0.
        """
        size = max(len(self.queue_list), max(self.queue_list, default=0))
        lengths = np.zeros(size)
        for app_type, queue in self.queue_list.items():
            lengths[app_type - 1] = queue.get_length(scale)
        return lengths

    def sample_channel_rate(self, linked_id):
        """Sample the rate of the channel to ``linked_id``.

        Higher links use ``get_rate()`` and lower links ``get_rate(False)``.
        Returns None when ``linked_id`` is not linked at all.
        """
        if linked_id in self.links_to_higher:
            return self.links_to_higher[linked_id]['channel'].get_rate()
        if linked_id in self.links_to_lower:
            return self.links_to_lower[linked_id]['channel'].get_rate(False)
        return None

    def get_applications(self):
        """Return the application types this node has queues for."""
        return list(self.queue_list.keys())

    def get_uuid(self):
        """Return this node's id as a hex string."""
        return self.uuid.hex

    def _get_obs(self, time, estimate_interval=100, involve_capability=False, scale=1):
        """Build this node's observation vector.

        With ``involve_capability`` the result is
        ``queue lengths + [computational_capability / GHZ]``. Otherwise it
        concatenates: estimated mean arrivals, last arrivals, queue lengths,
        CPU-usage fractions (``cpu_used / computational_capability``) and the
        per-app workload divided by ``KB``. Each section is indexed by
        ``app_type - 1``.

        Each section has at least 3 entries (the original fixed width) and
        grows to the largest app type instead of raising IndexError.
        """
        size = max(3, max(self.queue_list, default=0))
        queue_estimated_arrivals = np.zeros(size)
        queue_arrivals = np.zeros(size)
        queue_lengths = np.zeros(size)
        app_info = np.zeros(size)
        cpu_used = np.zeros(size)
        for app_type, queue in self.queue_list.items():
            idx = app_type - 1
            queue_estimated_arrivals[idx] = queue.mean_arrival(time, estimate_interval, scale=scale)
            queue_arrivals[idx] = queue.last_arrival(time, scale=scale)
            queue_lengths[idx] = queue.get_length(scale=scale)
            app_info[idx] = applications.get_info(app_type, "workload") / KB
            cpu_used[idx] = self.cpu_used[app_type] / self.computational_capability
        if involve_capability:
            return list(queue_lengths) + [self.computational_capability / GHZ]
        return (list(queue_estimated_arrivals) + list(queue_arrivals) + list(queue_lengths)
                + list(cpu_used) + list(app_info))
