import numpy as np
import time

from agent import GossipAgent, AgentLian, AgentAPS, PushSumAgent_random
from event import Event, TimeLine, DelayGenerator
from my_utils import Logit

class SimulatedSystem(object):
    '''
    Our proposed AD-OGP which realizes instantaneous model averaging and
    asymmetric communication.

    Every learner ``i`` holds a push-sum pair ``(x[i], y[i])``; the model it
    predicts with is the de-biased ratio ``x[i] / y[i]``. The system is run as
    a discrete-event simulation driven by a ``TimeLine`` of ``'predict'``,
    ``'update'`` and ``'message'`` events.
    '''

    def __init__(self, network, **kwargs):
        '''
        Args:
            network: communication graph; ``network.M`` is the number of
                learners and ``network.A[i]`` lists the learners that ``i``
                sends its model copies to.
            ref_j (keyword, optional): index of the learner whose model is
                evaluated to accumulate the reference loss. Defaults to 1.
        '''
        self.network = network
        self.M = network.M
        self.ref_j = kwargs.pop('ref_j', 1)

    def _set_processing_params(self, **kwargs):
        '''Read the learning hyper-parameters from ``kwargs``; ``lr`` and ``K`` are required.'''
        self.loss = kwargs.pop('loss', Logit) # loss type; must provide .loss(w, feature, label) and .gradient(w, feature, label)
        self.classes = kwargs.pop('classes', 2) # class number in classification
        self.lr = kwargs.pop('lr') # learning rate
        self.K = kwargs.pop('K') # RADIUS of the L2-norm ball used as the decision set (see _project_onto_ball)
        self.record_round = kwargs.pop('record_round', 10000) # instances between two history records
        self.report_round = kwargs.pop('report_round', 10000) # instances between two console reports

    def process(self, stream, **kwargs):
        '''
        Run the simulation over ``stream`` and return the average reference loss.

        Args:
            stream: data source exposing ``reset()``, ``next_instance()``,
                ``N`` (number of instances) and ``D`` (feature dimension).
            **kwargs: processing parameters read by ``_set_processing_params``,
                an optional cap ``N`` on the number of processed instances.
                The whole dict (including ``N``) is also forwarded to
                ``DelayGenerator``.

        Returns:
            Cumulative reference loss divided by the number of processed
            instances. Loss/time histories are stored in ``self.RECORD`` under
            the keys ``'actual'``, ``'ref'``, ``'t'`` and ``'n'``.
        '''
        self._set_processing_params(**kwargs)

        stream.reset()
        N, M, D = stream.N, self.M, stream.D
        self.stream = stream
        self.D = D

        self.x = np.zeros([M, self.classes, D]) # push-sum numerators
        self.y = np.ones(M) # push-sum weights

        # each learner maintains a buffer to collect messages received while it is occupied
        self.x_buffer = np.zeros([M, self.classes, D])
        self.y_buffer = np.zeros(M)

        self.is_occupied = [False] * M # whether each learner is between a prediction and its local update

        self.delay_generator = DelayGenerator(self.network, **kwargs)
        self.timeline = TimeLine()

        self.t = .0
        self.instance_count = 0
        instance_max = min(N, kwargs.pop('N', 1e20))

        print('Number of instances %i'%(instance_max))

        record_round = self.record_round
        report_round = max(self.report_round, record_round) # never report more often than we record

        self.actual_loss, self.ref_loss = .0, .0
        self.actual_loss_record, self.ref_loss_record, self.realtime_record, self.instance_record = [], [], [], []
        for agent in range(M):
            self.timeline.add_event(Event(self.t, agent, 'predict'))

        cur_record = 0
        while self.instance_count < instance_max:
            self._execute(self.timeline.next_event())

            # instance_count grows by at most one per event, so every count is observed here
            if cur_record < self.instance_count:
                cur_record = self.instance_count
                self._log_progress(instance_max, record_round, report_round)

        self.RECORD = {'actual':self.actual_loss_record, 'ref':self.ref_loss_record,
                       't':self.realtime_record, 'n':self.instance_record}
        return self.ref_loss / self.instance_count

    def _log_progress(self, instance_max, record_round, report_round):
        '''Append a history record every ``record_round`` instances and print a report every ``report_round``.'''
        n = self.instance_count
        if not n % record_round:
            self.actual_loss_record.append(self.actual_loss / n)
            self.ref_loss_record.append(self.ref_loss / n)
            self.realtime_record.append(self.t)
            self.instance_record.append(n)
        if not n % report_round:
            print('%i instances processed out of %i'%(n, instance_max))
            # print the current averages rather than the last record, which is stale
            # whenever report_round is not a multiple of record_round
            print('Actual Loss %f, Ref Loss %f, Running Time %f'%(self.actual_loss / n, self.ref_loss / n, self.t))

    def _execute(self, event):
        '''Dispatch ``event`` to its handler according to ``event.tag``; unknown tags are ignored.'''
        if event.tag == 'predict':
            self._execute_prediction(event)

        elif event.tag == 'update':
            self._execute_local_update(event)

        elif event.tag == 'message':
            self._execute_message_reception(event)

    def _execute_prediction(self, event): # make prediction
        '''
        Draw the next instance, accumulate the reference and actual losses at the
        de-biased models, and schedule an ``'update'`` event carrying the gradient.
        '''
        self.t = event.t
        index = event.index

        feature, label = self.stream.next_instance()
        self.ref_loss += self.loss.loss(self.x[self.ref_j]/self.y[self.ref_j], feature, label)
        model = self.x[index]/self.y[index]
        self.actual_loss += self.loss.loss(model, feature, label)
        gradient = self.loss.gradient(model, feature, label)

        lag = self.delay_generator.generate_computation_delay(index) # equals to ZERO by default. The positive lag is ONLY used in the evaluation of instantaneous model averaging.
        self.timeline.add_event(Event(self.t+lag, index, 'update', gradient))
        self.is_occupied[index] = True # until the next local update, this learner is occupied

        self.instance_count += 1

    def _project_onto_ball(self, index):
        '''
        Weighted projection: rescale ``x[index]`` so that the de-biased model
        ``x[index] / y[index]`` lies in the L2 ball of radius ``K``.
        '''
        norm = np.linalg.norm(self.x[index]/self.y[index])
        if norm > self.K:
            # y is left unchanged, so scaling x by K/norm scales x/y by the same
            # factor and gives ||x/y|| == K exactly.
            self.x[index] *= self.K/norm

    def _execute_local_update(self, event): # perform local update
        '''
        Apply the gradient carried by ``event``, project, push model shares to
        out-neighbours, merge any buffered messages and schedule the next prediction.
        '''
        self.t = event.t
        index = event.index

        # gradient step on x/y, expressed on the numerator x
        self.x[index] -= self.y[index] * self.lr * event.message
        self._project_onto_ball(index)

        self.is_occupied[index] = False # after processing, the learner becomes unoccupied

        if self.y[index] > 1e-4: # The threshold is set for numerical stability. This is commonly used in other push-sum implementations [Assran 2020; Assran 2020].

            N_t = len(self.network.A[index])+1
            self.x[index] /= N_t # keep one share, send the other N_t - 1 shares
            self.y[index] /= N_t

            for target in self.network.A[index]: # send out model copy to each of its neighbors
                message_delay = self.delay_generator.generate_message_delay((index, target))
                self.timeline.add_event(Event(self.t+message_delay, target, 'message', (np.copy(self.x[index]), np.copy(self.y[index]))))

        if self.y_buffer[index] > 0: # if there are unprocessed messages, use these messages for model averaging
            self.x[index] += self.x_buffer[index]
            self.y[index] += self.y_buffer[index]
            self.x_buffer[index] = 0.0 # clear the buffer in place
            self.y_buffer[index] = 0

        pred_delay = self.delay_generator.generate_prediction_interval(index)
        self.timeline.add_event(Event(self.t+pred_delay, index, 'predict'))

    def _execute_message_reception(self, event): # receive message
        '''
        Handle an incoming ``(x, y)`` push-sum share: buffer it while the receiver
        is occupied, otherwise add it to the receiver's model immediately.
        '''
        self.t = event.t
        index = event.index

        x, y = event.message

        if self.is_occupied[index]:
            self.x_buffer[index] += x
            self.y_buffer[index] += y
        else: # if this learner is unoccupied, then it performs model averaging INSTANTANEOUSLY
            self.x[index] += x
            self.y[index] += y


class SimulatedSystemWithoutInstantMA(SimulatedSystem):
    '''
    Ablation of AD-OGP with the instantaneous model averaging mechanism removed.

    Incoming push-sum shares are ALWAYS parked in the receiver's buffer, whether
    or not the receiver is occupied, and are only merged into the model when the
    receiver next performs a local update. This reduced algorithm is used to
    verify the effectiveness of instantaneous model averaging.
    '''

    def _execute_message_reception(self, event): # receive message
        '''Store the received ``(x, y)`` share in the receiver's buffer until its next local update.'''
        self.t = event.t
        receiver = event.index
        share_x, share_y = event.message

        self.x_buffer[receiver] += share_x
        self.y_buffer[receiver] += share_y


class SimulatedSystemWithSymmetricGossiping(SimulatedSystem):
    '''
    We replace push-sum with symmetric gossiping (using A-PSGD [Lian et al, 2018]'s implementation).
    The reduced algorithm is used in the verification of the effectiveness of asymmetric communication.

    Only ACTIVE learners (``network.active[i]`` truthy) initiate gossip: each one
    sends its model to a random neighbour drawn from ``network.A_bi[i]`` (the
    bipartite subgraph, so the receiver is expected to be PASSIVE). The passive
    learner replies with its own model and both sides average. No push-sum
    weight ``y`` is used.
    '''

    def process(self, stream, **kwargs):
        '''
        Run the symmetric-gossiping simulation over ``stream`` and return the
        average reference loss.

        Args:
            stream: data source exposing ``reset()``, ``next_instance()``,
                ``N`` (number of instances) and ``D`` (feature dimension).
            **kwargs: processing parameters read by ``_set_processing_params``,
                an optional cap ``N`` on the number of processed instances.
                The whole dict (including ``N``) is also forwarded to
                ``DelayGenerator``.

        Returns:
            Cumulative reference loss divided by the number of processed
            instances. Histories are stored in ``self.RECORD`` under the keys
            ``'actual'``, ``'ref'``, ``'t'`` and ``'n'``.
        '''
        print('ACTIVE', self.network.active[self.ref_j])

        self._set_processing_params(**kwargs)

        stream.reset()
        N, M, D = stream.N, self.M, stream.D
        self.stream = stream
        self.D = D

        self.x = np.zeros([M, self.classes, D]) # unlike push-sum, symmetric gossiping does not need to introduce a scalar weight y

        self.x_buffer = np.zeros([M, self.classes, D]) # not read by any method of this class; kept for attribute compatibility
        self.g_buffer = np.zeros([M, self.classes, D]) # gradients not yet applied by an active learner that is waiting for a reply

        self.is_occupied = [False] * M # whether each learner is between a prediction and its local update
        self.waiting = [False] * M # whether an ACTIVE learner is waiting for the response message from a PASSIVE learner

        self.delay_generator = DelayGenerator(self.network, **kwargs)
        self.timeline = TimeLine()

        self.t = .0
        self.instance_count = 0
        instance_max = min(N, kwargs.pop('N', 1e20))

        print('Number of instances %i'%(instance_max))

        record_round = self.record_round
        report_round = max(self.report_round, record_round) # never report more often than we record

        self.actual_loss, self.ref_loss = .0, .0
        self.actual_loss_record, self.ref_loss_record, self.realtime_record, self.instance_record = [], [], [], []
        for agent in range(M):
            self.timeline.add_event(Event(self.t, agent, 'predict'))

        for agent in range(M):
            if self.network.active[agent]:
                self._send_gossip(agent) # active learners trigger gossiping

        cur_record = 0
        while self.instance_count < instance_max:
            self._execute(self.timeline.next_event())

            # instance_count grows by at most one per event, so every count is observed here
            if cur_record < self.instance_count:
                cur_record = n = self.instance_count

                if not n % record_round:
                    self.actual_loss_record.append(self.actual_loss / n)
                    self.ref_loss_record.append(self.ref_loss / n)
                    self.realtime_record.append(self.t)
                    self.instance_record.append(n)
                if not n % report_round:
                    print('%i instances processed out of %i'%(n, instance_max))
                    # current averages, not the last record (stale when report_round is not a multiple of record_round)
                    print('Actual Loss %f, Ref Loss %f, Running Time %f'%(self.actual_loss / n, self.ref_loss / n, self.t))

        self.RECORD = {'actual':self.actual_loss_record, 'ref':self.ref_loss_record,
                       't':self.realtime_record, 'n':self.instance_record}
        return self.ref_loss / self.instance_count

    def _execute_prediction(self, event):
        '''
        In Lian 2018's implementation of symmetric gossiping, when any active learner makes a prediction,
        the gradients stored in the gradient buffer is TEMPORARILY added to the local model to make a more accurate prediction.
        '''
        self.t = event.t
        index = event.index

        feature, label = self.stream.next_instance()
        # g_buffer is only ever filled for active learners, so this is exact for a passive ref_j as well
        self.ref_loss += self.loss.loss((self.x[self.ref_j]-self.lr*self.g_buffer[self.ref_j]), feature, label)
        if self.network.active[index]:
            x_pred = self.x[index] - self.lr * self.g_buffer[index]
        else:
            x_pred = self.x[index]
        self.actual_loss += self.loss.loss(x_pred, feature, label)
        gradient = self.loss.gradient(x_pred, feature, label)

        self.timeline.add_event(Event(self.t, index, 'update', gradient))
        self.is_occupied[index] = True # until the next local update, this learner is occupied

        self.instance_count += 1

    def _project_local_model(self, index):
        '''Rescale ``x[index]`` onto the L2 ball of radius ``K`` if it lies outside.'''
        norm = np.linalg.norm(self.x[index])
        if norm > self.K:
            self.x[index] *= self.K/norm

    def _execute_local_update(self, event): # perform local update
        '''
        An active learner that is waiting for a gossip reply only stores the new
        gradient in its gradient buffer. Any other learner applies the gradient
        (plus, for active learners, the buffered gradients) and projects. In both
        cases the next prediction is scheduled after the learner's prediction interval.
        '''
        self.t = event.t
        index = event.index

        if self.network.active[index] and self.waiting[index]:
            self.g_buffer[index] += event.message # applied when the response message arrives
        else:
            self.x[index] -= self.lr * event.message # directly perform local update with that gradient
            if self.network.active[index]:
                self.x[index] -= self.lr * self.g_buffer[index]
                self.g_buffer[index] = 0.0
            self._project_local_model(index)

        self.is_occupied[index] = False # after processing, the learner becomes unoccupied

        # Use the same prediction interval as the push-sum system; scheduling at
        # self.t with zero delay would keep simulated time frozen so that gossip
        # messages (delivered at t + delay) would never be processed.
        pred_delay = self.delay_generator.generate_prediction_interval(index)
        self.timeline.add_event(Event(self.t+pred_delay, index, 'predict'))

    def _execute_message_reception(self, event):
        '''
        In Lian 2018's implementation of symmetric gossiping, the gossip sender must be an active learner.
        As soon as any active learner receives the response message from one of its neighbor (which is a passive learner),
        it uses the response message for model averaging, and then uses the gradients stored in the gradient buffer for local update.
        A passive learner first replies with its CURRENT model and then averages with the received one.
        '''
        self.t = event.t
        index = event.index

        x, sender = event.message

        if self.network.active[index]:
            self.waiting[index] = False
            self.x[index] = (self.x[index]+x)/2 # perform model averaging
            self.x[index] -= self.lr * self.g_buffer[index] # apply the buffered gradients
            self.g_buffer[index] = 0.0
            self._project_local_model(index)
            self._send_gossip(index) # the learner triggers another gossip operation
        else:
            self._response_gossip(index, sender) # reply before averaging so the sender receives the pre-average model
            self.x[index] = (self.x[index]+x)/2 # perform model averaging

    def _send_gossip(self, index, sender=None): # an active learner sends out a gossip message
        '''Send a copy of ``x[index]`` to a random neighbour in ``network.A_bi[index]`` and mark ``index`` as waiting. ``sender`` is unused.'''
        self.waiting[index] = True
        target = np.random.choice(self.network.A_bi[index]) # randomly select a neighbor (based on the BIPARTITE subgraph of G)
        message_delay = self.delay_generator.generate_message_delay((index, target))
        self.timeline.add_event(Event(self.t+message_delay, target, 'message', (np.copy(self.x[index]), index)))

    def _response_gossip(self, index, sender):
        '''Send a copy of ``x[index]`` back to the active learner ``sender``.'''
        message_delay = self.delay_generator.generate_message_delay((index, sender))
        self.timeline.add_event(Event(self.t+message_delay, sender, 'message', (np.copy(self.x[index]), index)))
            