import numpy as np

class Event(object):
    '''
    A single entry on the simulated timeline.

    An event is a prediction ('predict'), a local update ('update') or the
    reception of a message ('message') on agent ``index`` at time ``t``.
    '''
    def __init__(self, t, index, tag='predict', msg=None):
        '''
        Parameters
        ----------
        t : event time.
        index : id of the agent the event happens on.
        tag : one of 'predict', 'update', 'message' (default 'predict').
        msg : optional payload, stored as ``self.message`` (default None).

        Raises NameError if ``tag`` is not one of the three known tags.
        '''
        if tag not in ('predict', 'update', 'message'):
            raise NameError('Invalid event tag')
        self.t = t
        self.index = index
        self.tag = tag
        self.message = msg

    def info_string(self):
        '''Return a readable description, e.g. "update on agent 2 at time 0.35".'''
        fields = (self.tag, self.index, self.t)
        return '%s on agent %i at time %.2f' % fields

class TimeLine(object):
    '''
    A simulated timeline for the asynchronous decentralized online learning process.

    Pending events live in ``self.timeline``; ``add_event`` keeps that list
    ordered by ascending ``event.t``. ``self.t`` is the time of the most
    recently dispatched event.
    '''
    def __init__(self):
        self.timeline = []  # pending events, ascending by ``t`` when filled via add_event
        self.t = .0         # time of the last event returned by next_event()

    def add_event(self, event):
        '''
        Insert ``event`` so the timeline stays ordered by ``event.t``.

        Ties keep insertion order: the new event goes after every pending
        event whose time is <= ``event.t`` (i.e. before the first strictly
        later one), or at the end if there is none.
        '''
        t = event.t
        for pos, pending in enumerate(self.timeline):
            if pending.t > t:
                self.timeline.insert(pos, event)
                return
        self.timeline.append(event)

    def next_event(self):
        '''
        Remove and return the earliest pending event and advance ``self.t`` to it.

        Raises IndexError if the timeline is empty.
        '''
        event = self.timeline.pop(0)
        self.t = event.t
        return event

    def print_timeline(self):
        '''
        Return (not print) a one-line summary of the pending events.

        The result starts with a single space, followed by one
        `` <tag>-<agent>-<time>,`` item per pending event, in timeline order.
        '''
        # Events expose their tag as ``tag`` and their agent id as ``index``.
        return ' ' + ''.join(
            ' %s-%i-%.2f,' % (event.tag, event.index, event.t)
            for event in self.timeline
        )

class DelayGenerator(object):
    '''
    Simulate the different kinds of time delay used in our experiments.

    Three delays are produced:
      * computation delay per feedback, scaled by D^p on the slow learner;
      * message transmission delay, scaled by D^m on the slow edge / node;
      * an optional gap between a local update and the next prediction (D^l).

    The constructor seeds NumPy's global RNG with 123. The computation- and
    message-delay methods reseed it with an incrementing counter, so this
    class both depends on and modifies ``np.random``'s global state.
    '''
    def __init__(self, network, **kwargs):
        '''
        Parameters
        ----------
        network : object with ``M`` (number of agents) and ``A_bi``, where
            ``A_bi[i]`` yields the candidates passed to ``np.random.choice``
            (presumably agent ``i``'s neighbours - confirm against caller).
        proc_delay : float, optional (default 1.0)
            D^p, multiplier on the slow learner's computation delay.
        msg_delay : float, optional (default 1.0)
            D^m, multiplier on the slowed message delays.
        msg_type : 'edge' or 'node', optional (default 'edge')
            'edge' slows a single edge incident to the slow learner (main
            experiments); 'node' slows every edge touching the slow learner
            (only used when evaluating asymmetric communication).
        pred_lag : float or None, optional (default None)
            D^l, scale of the update-to-prediction gap; None disables it
            (used when evaluating instantaneous model averaging).

        Unrecognised keyword arguments are silently ignored.
        '''
        np.random.seed(123)
        n_agents = network.M
        # Per-agent base processing time: |N(0, 1)| + 2.0, drawn under seed 123.
        self.mu_proc = np.abs(np.random.randn(n_agents)) + 2.0
        self.mu_msg = .6
        self.seed_count = 0

        # Agent 1 is the designated slow learner; one of its A_bi entries is
        # picked at random as the far end of the slowed link.
        self.slow_learner = 1
        self.slow_proc = kwargs.pop('proc_delay', 1.0)
        self.slow_link = np.random.choice(network.A_bi[self.slow_learner])

        self.slow_msg = kwargs.pop('msg_delay', 1.0)
        self.msg_type = kwargs.pop('msg_type', 'edge')

        self.pred_interval = kwargs.pop('pred_lag', None)
        print('PRED', self.pred_interval)

    def generate_computation_delay(self, worker):
        '''
        Return the processing time of one feedback on ``worker``.

        The value is deterministic: ``mu_proc[worker] / 100``, multiplied by
        ``slow_proc`` when ``worker`` is the slow learner. The global RNG is
        still reseeded with the incremented counter, which influences any
        later unseeded draw.
        '''
        self.seed_count += 1
        np.random.seed(self.seed_count)
        base = self.mu_proc[worker] / 100
        return base * self.slow_proc if worker == self.slow_learner else base

    def generate_message_delay(self, edge=None):
        '''
        Return the communication time of one message.

        The base delay is an exponential draw with scale ``mu_msg``, clipped
        to [0.1, 10] and divided by 100. When ``edge`` (a 2-tuple of agent
        ids, presumably (sender, receiver)) is given, the delay is multiplied
        by ``slow_msg`` if
          * ``msg_type == 'edge'`` and ``edge`` is the slow link in either
            direction, or
          * ``msg_type == 'node'`` and ``edge`` contains the slow learner.
        '''
        self.seed_count += 1
        np.random.seed(self.seed_count)
        base = np.clip(np.random.exponential(self.mu_msg), .1, 10) / 100
        if edge is None:
            return base
        hub, link = self.slow_learner, self.slow_link
        if self.msg_type == 'edge':
            slowed = edge == (hub, link) or edge == (link, hub)
        elif self.msg_type == 'node':
            slowed = hub in edge
        else:
            slowed = False
        return base * self.slow_msg if slowed else base

    def generate_prediction_interval(self, worker):
        '''
        Return the time gap between a local update and the next prediction.

        This is 0 by default. A positive gap is only produced when
        ``pred_lag`` was given, and then only for agent 1: an exponential draw
        (scale 0.6) divided by 100 and scaled by ``pred_interval`` (D^l).
        It is used only when examining the effectiveness of instantaneous
        model averaging.
        '''
        self.seed_count += 1
        # NOTE(review): unlike the other generators this does not reseed with
        # seed_count, so the draw continues from whatever RNG state the
        # previous call left behind. Left unchanged so existing random
        # sequences are preserved.
        if self.pred_interval is None or worker != 1:
            return 0
        draw = np.random.exponential(0.6)
        return draw / 100 * self.pred_interval
        