import numpy as np


class passive_wireless_model:
    """Deterministic per-client data-rate model built from device-group fractions.

    Clients are split into device groups according to ``comm_env``. Every client
    in a group is assigned that group's data rate taken from ``channel_dist``.
    Nothing here is random: the rates are fixed at construction time.

    Args:
        comm_env: ``'device_<f1>_<f2>_...'``; each ``<fi>`` is the fraction of
            the ``num_agent`` clients belonging to device group ``i``.
        num_agent: Total number of clients.
        channel_dist: ``'<tag>_<r1>_<r2>_...'``; the leading tag is ignored and
            each ``<ri>`` is multiplied by ``fast`` to give group ``i``'s rate.
        fast: Scale factor applied to every relative rate in ``channel_dist``.
        seed: Unused; accepted only for interface compatibility.

    Raises:
        ValueError: If ``comm_env`` is not in ``device`` mode, if the rounded
            group sizes do not sum to ``num_agent``, or if the device groups and
            channel rates are missing or differ in count.
    """

    def __init__(self, comm_env, num_agent, channel_dist, fast, seed):
        self.num_agent = num_agent

        if comm_env.split('_')[0] != 'device':
            raise ValueError('Only support device mode!')

        # Clients per device group. Python's round() is round-half-to-even, so
        # fractions landing exactly on .5 can make the sizes miss num_agent.
        self.client_number = [int(round(float(x) * self.num_agent))
                              for x in comm_env.split('_')[1:]]
        # Absolute rate per device group (the leading tag of channel_dist is skipped).
        self.data_rate = [int(round(float(x) * fast))
                          for x in channel_dist.split('_')[1:]]

        # Validate before deriving anything from data_rate, so an empty or
        # mismatched rate list produces the explicit message below instead of
        # min()'s "empty sequence" error.
        if sum(self.client_number) != self.num_agent:
            raise ValueError('Please reset the number of clients using different devices.')
        if not self.data_rate or len(self.client_number) != len(self.data_rate):
            raise ValueError('Please check the device and channel setting to be consistent.')

        # Slowest group rate.
        self.slow_channel = min(self.data_rate)

        # One entry per client, clients ordered by device group.
        self.full_channel = np.array([rate
                                      for count, rate in zip(self.client_number, self.data_rate)
                                      for _ in range(count)])

    def __call__(self):
        """Return ``(full_channel, slow_channel)``.

        ``full_channel`` is the per-client rate array and ``slow_channel`` is the
        minimum group rate.
        """
        return self.full_channel, self.slow_channel


class passive_wireless_model_3sets:
    """Deterministic per-client data-rate model exposing slowest and median rates.

    Clients are split into device groups according to ``comm_env``. Every client
    in a group is assigned that group's data rate taken from ``channel_dist``.
    Nothing here is random: the rates are fixed at construction time.

    Args:
        comm_env: ``'device_<f1>_<f2>_...'``; each ``<fi>`` is the fraction of
            the ``num_agent`` clients belonging to device group ``i``.
        num_agent: Total number of clients.
        channel_dist: ``'<tag>_<r1>_<r2>_...'``; the leading tag is ignored and
            each ``<ri>`` is multiplied by ``fast`` to give group ``i``'s rate.
        fast: Scale factor applied to every relative rate in ``channel_dist``.
        seed: Unused; accepted only for interface compatibility.

    Raises:
        ValueError: If ``comm_env`` is not in ``device`` mode, if the rounded
            group sizes do not sum to ``num_agent``, or if the device groups and
            channel rates are missing or differ in count.
    """

    def __init__(self, comm_env, num_agent, channel_dist, fast, seed):
        self.num_agent = num_agent

        if comm_env.split('_')[0] != 'device':
            raise ValueError('Only support device mode!')

        # Clients per device group. Python's round() is round-half-to-even, so
        # fractions landing exactly on .5 can make the sizes miss num_agent.
        self.client_number = [int(round(float(x) * self.num_agent))
                              for x in comm_env.split('_')[1:]]
        # Absolute rate per device group (the leading tag of channel_dist is skipped).
        self.data_rate = [int(round(float(x) * fast))
                          for x in channel_dist.split('_')[1:]]

        # Validate before deriving min/median, so an empty rate list yields the
        # explicit message below rather than min()'s error or int(nan) from
        # np.median([]).
        if sum(self.client_number) != self.num_agent:
            raise ValueError('Please reset the number of clients using different devices.')
        if not self.data_rate or len(self.client_number) != len(self.data_rate):
            raise ValueError('Please check the device and channel setting to be consistent.')

        # Slowest group rate.
        self.slow_channel = min(self.data_rate)
        # Median over device groups (not over clients). With an even number of
        # groups np.median averages the two middle rates and int() truncates.
        self.median_channel = int(np.median(self.data_rate))

        # One entry per client, clients ordered by device group.
        self.full_channel = np.array([rate
                                      for count, rate in zip(self.client_number, self.data_rate)
                                      for _ in range(count)])

    def __call__(self):
        """Return ``(full_channel, slow_channel, median_channel)``.

        These are the per-client rate array, the minimum group rate, and the
        median group rate.
        """
        return self.full_channel, self.slow_channel, self.median_channel


class passive_wireless_model_multiset:
    """Deterministic per-client data-rate model that also exposes slower rate tiers.

    Clients are split into device groups according to ``comm_env``. Every client
    in a group is assigned that group's data rate taken from ``channel_dist``.
    The rates are fixed at construction; no randomness is involved.

    Args:
        comm_env: ``'device_<f1>_<f2>_...'``; each ``<fi>`` is the fraction of
            the ``num_agent`` clients belonging to device group ``i``.
        num_agent: Total number of clients.
        channel_dist: ``'<tag>_<r1>_<r2>_...'``; the leading tag is ignored and
            each ``<ri>`` is multiplied by ``fast`` to give group ``i``'s rate.
        fast: Scale factor applied to every relative rate in ``channel_dist``.
        seed: Unused; accepted only for interface compatibility.

    Raises:
        ValueError: If ``comm_env`` is not in ``device`` mode, if the rounded
            group sizes do not sum to ``num_agent``, or if the number of device
            groups and channel rates differ.
    """

    def __init__(self, comm_env, num_agent, channel_dist, fast, seed):
        self.num_agent = num_agent

        env_tokens = comm_env.split('_')
        if env_tokens[0] != 'device':
            raise ValueError('Only support device mode!')

        # Clients per device group, then the absolute rate for each group.
        self.client_number = [int(round(float(frac) * self.num_agent)) for frac in env_tokens[1:]]
        self.data_rate = [int(round(float(rel) * fast)) for rel in channel_dist.split('_')[1:]]

        # Group rates in ascending order; __call__ drops the largest one.
        self.channels_except_full = sorted(self.data_rate)

        if sum(self.client_number) != self.num_agent:
            raise ValueError('Please reset the number of clients using different devices.')
        if len(self.client_number) != len(self.data_rate):
            raise ValueError('Please check the device and channel setting to be consistent.')

        # One entry per client, clients ordered by device group.
        per_client = []
        for count, rate in zip(self.client_number, self.data_rate):
            per_client.extend([rate] * count)
        self.full_channel = np.array(per_client)

    def __call__(self):
        """Return ``(full_channel, slower_rates)``.

        ``slower_rates`` is the ascending list of group rates with its last
        (largest) entry removed. If the maximum rate occurs in several groups,
        only one copy of it is dropped.
        """
        slower_rates = self.channels_except_full[:-1]
        return self.full_channel, slower_rates

