import pdb
import sys
import time
import flask
import random
import datetime
import tensorflow as tf
from socketIO_client import SocketIO, LoggingNamespace
import os

from common.utils import *
from data.generator import DataGenerator
# from models.cnn.cnn_local import LocalCNN
# from models.ewc.ewc_local import LocalEWC
# from models.apd.apd_local import LocalAPD
from models.fedapc.fedapc_local import LocalFedAPC

class Client:
    """Socket.IO client for federated continual learning.

    The client trains a local model on a sequence of tasks produced by
    ``DataGenerator`` and, when a training request asks for communication
    (``req['comm']``), sends its weights back to the server.  Constructing a
    ``Client`` connects to ``opt.host_ip:opt.host_port``, registers the event
    handlers and then blocks inside ``SocketIO.wait()``; the process is
    terminated by :meth:`stop` (via ``os._exit``) after the last task.
    """

    def __init__(self, opt):
        """Connect to the server and start serving requests (blocks).

        Args:
            opt: option namespace.  Attributes read by this class include
                host_ip, host_port, model, num_rounds, socket_test, federated,
                sparse_comm, save_weights and weights_dir.
        """
        self.opt = opt
        self.current_task = -1  # get_next_task() pre-increments, so the first task is 0
        self.count_rounds = 0  # rounds trained on the current task
        self.recent_s_rounds = -1  # NOTE(review): not read anywhere in this class
        self.socketio_client = SocketIO(opt.host_ip, opt.host_port, LoggingNamespace)
        self.add_handlers()

    def get_local_model(self):
        """Build the local model selected by ``opt.model``.

        Only id 3 (``LocalFedAPC``) is currently enabled; the CNN/EWC/APD
        backends are commented out in this file's imports.

        Raises:
            ValueError: if ``opt.model`` is not a supported model id.  Without
                this, ``None`` would be returned and the failure would only
                surface later as an AttributeError.
        """
        if self.opt.model == 3:
            return LocalFedAPC(self.client_id, self.data_info, self.opt)
        raise ValueError('unsupported model id %r: only 3 (FedAPC) is enabled'
                         % (self.opt.model,))

    def get_next_task(self):
        """Advance to the next task, reset the round counter and load its data."""
        self.current_task += 1
        self.count_rounds = 0
        data = self.data_generator.get_task(self.current_task)
        self.current_task_name = data['name']
        self.local_model.set_task(self.current_task, data)
        syslog(self.client_id, 'the next task (%s) has been loaded' % (self.current_task_name))

    def stop(self):
        """Notify the server that this client is finished and exit the process.

        ``os._exit`` terminates immediately, without running cleanup handlers
        or flushing stdio buffers; this method never returns.
        """
        syslog(self.client_id, 'learning all tasks has been done.')
        self.socketio_client.emit('client-stop', {'client_id': self.client_id})
        syslog(self.client_id, 'done.')
        os._exit(0)

    def _round_flags(self):
        """Return ``(is_last_task, is_last_round, is_last)`` for the current round.

        ``is_last_round`` is true when ``count_rounds`` is a non-zero multiple
        of ``opt.num_rounds``; ``is_last`` is true only for the final round of
        the final task.
        """
        is_last_task = self.current_task == self.data_info['num_tasks'] - 1
        is_last_round = self.count_rounds % self.opt.num_rounds == 0 and self.count_rounds != 0
        return is_last_task, is_last_round, is_last_task and is_last_round

    def _record_round(self):
        """Update capacity statistics (all models except id 0) and log performance."""
        if self.opt.model != 0:
            self.local_model.calculate_capacity()
        self.local_model.write_current_performances()

    def _finish_task_if_done(self, is_last_task, is_last_round):
        """Move on once the current task has used its rounds or early-stopped.

        On the last task this runs a final evaluation (only if early stopping
        fired) and calls :meth:`stop`.  Otherwise it optionally saves the
        task's weights and loads the next task.

        Returns:
            True if the final task was finished and :meth:`stop` was called,
            False otherwise.
        """
        if not (is_last_round or self.local_model.early_stop):
            return False
        if is_last_task:
            if self.local_model.early_stop:
                self.local_model.evaluation(is_last=True)
            self.stop()  # exits the process in normal operation
            return True
        if self.opt.save_weights:
            save_weights(self.opt.weights_dir, 'client%d-task%d-weights.npy'
                         % (self.client_id, self.current_task), self.local_model.variables)
            syslog(self.client_id, '%dth weights has been saved.' % self.current_task)
        # NOTE(review): assumes set_task() clears local_model.early_stop;
        # otherwise the next round immediately advances again — confirm.
        self.get_next_task()
        return False

    def _apply_server_weights(self, req):
        """Load the server's weights into the local model before training.

        Round 0 carries plain weights.  Later rounds carry a ``[weights,
        adapts]`` pair, and the adaptive part is only taken on this task's
        first round (``update_ta``).
        NOTE(review): ``pickle_string_to_obj`` presumably unpickles — only safe
        if the server is trusted.
        """
        payload = pickle_string_to_obj(req['server_weights'])
        if req['server_round'] == 0:
            self.local_model.set_weights(payload)
        else:
            self.local_model.set_both(payload, update_ta=(self.count_rounds == 1))

    def _attach_local_update(self, resp):
        """Add the serialized local weights/adapts and training-set sizes to ``resp``."""
        if self.opt.sparse_comm:
            weights, masks = self.local_model.get_weights()
            resp['client_masks'] = masks
        else:
            weights = self.local_model.get_weights()
        adapts = self.local_model.get_adapts()
        resp['client_both'] = obj_to_pickle_string([weights, adapts])
        resp['train_size_per_class'] = self.local_model.train_size_per_class
        resp['train_size'] = len(self.local_model.x_train)

    def add_handlers(self):
        """Register the Socket.IO event handlers, announce the client and block.

        The handled events are 'request-init', 'request-train' and 'train_all'.
        After registration the client emits 'client-start' and blocks in
        ``wait()``.
        """

        def on_request_init(*args):
            # The server assigns our id and logging settings; build data and
            # model, then load task 0.
            req = args[0]
            self.client_id = req['client_id']
            self.data_generator = DataGenerator(self.client_id, self.opt)
            self.data_info = self.data_generator.get_info()
            self.local_model = self.get_local_model()

            self.get_next_task()
            self.local_model.log_dir = req['log_dir']
            self.local_model.initialize(req['model_info'])
            self.socketio_client.emit('client-ready', {'client_id': self.client_id})

        def on_request_train(*args):
            # Train one round.  If req['comm'] is set, reply with a weight
            # update; otherwise reply with a bare 'train done' acknowledgement.
            req = args[0]
            self.count_rounds += 1
            syslog(self.client_id, 'task:%d, round:%d (cnt:%d), train one round'
                   % (self.current_task, req['server_round'], self.count_rounds))

            is_last_task, is_last_round, is_last = self._round_flags()
            resp = None  # only sent (and always set) when req['comm'] is true
            if self.opt.socket_test:
                # Connectivity test: skip training and just echo the round.
                resp = {'client_id': self.client_id, 'client_round': req['server_round']}
            elif not req['comm']:
                if self.count_rounds == 1 and req['server_weights'] is not None:
                    # NOTE(review): unpickles server data — trust boundary.
                    weights_both = pickle_string_to_obj(req['server_weights'])
                    self.local_model.set_adapts(weights_both[1])
                self.local_model.train_one_round(req['server_round'], self.count_rounds, is_last)
                syslog(self.client_id, 'training only')
            else:
                if self.opt.federated:
                    self._apply_server_weights(req)
                self.local_model.train_one_round(req['server_round'], self.count_rounds, is_last)
                resp = {'client_id': self.client_id, 'client_round': req['server_round']}
                if self.opt.federated:
                    self._attach_local_update(resp)
                    self.resp = resp  # last update kept on the instance

            self._record_round()
            if req['comm']:
                self.socketio_client.emit('client-update', resp)  # to the server only
            else:
                self.socketio_client.emit('client-train-done', {'client_id': self.client_id})
            self._finish_task_if_done(is_last_task, is_last_round)

        def on_request_train_all(*args):
            # Train every remaining task locally without server exchange.  The
            # loop is driven by task/round state rather than fixed counters,
            # because early stopping shortens tasks.
            while self.current_task < self.data_info['num_tasks']:
                self.count_rounds += 1
                syslog(self.client_id, 'task:%d, round:%d, train one round'
                       % (self.current_task, self.count_rounds))

                is_last_task, is_last_round, is_last = self._round_flags()
                local_round = (self.count_rounds - 1) % self.opt.num_rounds
                self.local_model.train_one_round(local_round, self.count_rounds, is_last)

                self._record_round()
                if self._finish_task_if_done(is_last_task, is_last_round):
                    return

        self.socketio_client.on('request-init', on_request_init)
        self.socketio_client.on('request-train', on_request_train)
        self.socketio_client.on('train_all', on_request_train_all)
        self.socketio_client.emit('client-start', {})
        self.socketio_client.wait()
