import numpy as np

class resource_allocation:
    """Per-client resource allocation driven by a named scheduling policy.

    ``allocation`` stores the incoming per-client channel rates and dispatches
    on ``indicator`` to one of the scheduling policies below. Each policy
    returns ``(adjusted_batch_size, channel_allocation, time_per_round,
    samplereduceratio)`` and also caches those values as instance attributes.

    Args:
        num_agent: Number of clients; length of the per-client arrays built
            by the policies.
        max_channel: Upper channel limit. NOTE(review): stored but not
            applied by either policy -- confirm whether clipping is intended.
        batch_size: Batch size assigned to every client.
        base_size: Stored for importance indicators; not read in this class.
        orig_channel: Reference channel value used by the 'dropdata' policy.
        indicator: Policy name, either 'useBL' (default) or 'dropdata'.
    """

    def __init__(self,num_agent,max_channel,batch_size,base_size,orig_channel,indicator='useBL'):
        # constants for importance indicator
        self.base_size = base_size
        self.time_tick = {}
        self.orig_channel = orig_channel

        self.indicator = indicator
        self.num_agent = num_agent
        self.max_channel = max_channel
        self.batch_size = batch_size
        self.all_samples = 50000

    def dropdata_scheduling_policy(self):
        """Give every client the fixed ``orig_channel`` plus a per-client rate ratio.

        Reads ``self.fading_channels`` (set by ``allocation``).

        Returns:
            Tuple ``(adjusted_batch_size, channel_allocation, time_per_round,
            samplereduceratio)``:
            - ``adjusted_batch_size``: ``batch_size`` repeated ``num_agent`` times.
            - ``channel_allocation``: ``int(orig_channel)`` repeated ``num_agent`` times.
            - ``time_per_round``: always 1.
            - ``samplereduceratio``: ``fading_channels[i] / orig_channel`` for each
              entry of ``fading_channels``. Presumably the caller uses it to drop
              a proportional share of each client's samples; verify with callers.
        """
        predefined_channel = int(self.orig_channel)
        # Every client is assigned the same reference channel size.
        self.channel_allocation = np.full(self.num_agent, predefined_channel)
        self.adjusted_batch_size = np.full(self.num_agent, self.batch_size)
        self.time_per_round = 1
        # Vectorized equivalent of [fc / orig_channel for fc in fading_channels].
        self.samplereduceratio = np.asarray(self.fading_channels) / self.orig_channel
        return self.adjusted_batch_size, self.channel_allocation, self.time_per_round, self.samplereduceratio

    def useBL_scheduling_policy(self):
        """Allocate each client its own fading-channel rate and a full batch.

        Reads ``self.fading_channels`` (set by ``allocation``). No sample
        reduction is applied (every ratio is 1). NOTE(review): ``max_channel``
        is not applied to the allocated channels -- confirm this is intended.

        Returns:
            Tuple ``(adjusted_batch_size, channel_allocation, time_per_round,
            samplereduceratio)``, where ``channel_allocation`` is a copy of
            ``fading_channels`` as an array.
        """
        # np.array copies, so later mutation of the caller's list doesn't leak in.
        self.channel_allocation = np.array(self.fading_channels)
        self.adjusted_batch_size = np.full(self.num_agent, self.batch_size)
        self.time_per_round = 1
        self.samplereduceratio = np.ones(self.num_agent, dtype=int)
        return self.adjusted_batch_size, self.channel_allocation, self.time_per_round, self.samplereduceratio

    def allocation(self,fading_channels,predefined=None):
        """Run the scheduling policy selected by ``self.indicator``.

        Args:
            fading_channels: Per-client channel values; the original comment
                describes them as the communication rate in Mbps.
            predefined: Accepted for backward compatibility; currently ignored.

        Returns:
            The tuple returned by the selected policy method.

        Raises:
            ValueError: If ``self.indicator`` is neither 'dropdata' nor 'useBL'.
        """
        self.fading_channels = fading_channels
        if self.indicator == 'dropdata':
            return self.dropdata_scheduling_policy()
        if self.indicator == 'useBL':
            return self.useBL_scheduling_policy()
        raise ValueError("The indicator is not supported.")


