import pdb
import os
import sys
import random
import logging
import math
import time
import atexit
import threading
import numpy as np

from flask_socketio import *
from flask import Flask, request
from threading import Lock

from common.utils import *
# from models.cnn.cnn_global import GlobalCNN
# from models.ewc.ewc_global import GlobalEWC
# from models.apd.apd_global import GlobalAPD
from models.fedapc.fedapc_global import GlobalFedAPC

class Server:
    """Socket.IO parameter server that coordinates federated training rounds.

    Clients connect through Flask-SocketIO, are handed integer ids, and are
    then driven round by round from a background listener thread. Each
    round, a random subset of ``num_waitings`` ready clients is asked to
    train and upload (``req['comm'] = True``). The remaining clients are
    asked to train only (``req['comm'] = False``). Uploaded responses are
    aggregated into ``self.global_model``.
    """

    def __init__(self, opt):
        """Set up server state, the Flask/Socket.IO app and the listener thread.

        Args:
            opt: configuration namespace. Attributes read by this class:
                ``num_clients``, ``frac_clients``, ``log_dir``, ``model``,
                ``federated``, ``socket_test``, ``sparse_comm``,
                ``num_rounds``, ``save_weights``, ``weights_dir``,
                ``host_ip`` and ``host_port``.

        Raises:
            ValueError: if ``opt.model`` selects an unsupported model.
        """
        self.opt = opt
        self.sid2cid = {}            # Socket.IO session id -> integer client id
        self.responses = []          # 'client-update' payloads of the current round
        self.ready_clients = set()   # sids of clients that sent 'client-ready'
        self.trained_clients = []    # sids of clients that sent 'client-train-done'
        self.selected_clients = []   # sids asked to upload ("comm") this round

        self.server_id = -2          # sender id passed to syslog() for server messages
        self.client_id = -1          # last id handed out; the first client gets 0
        self.current_round = -1      # incremented before each round, so round 0 comes first
        self.has_begun = False
        # Clients that must upload each round. Clamped to [1, num_clients]:
        # 0 would let a round "complete" with no responses, and more than
        # num_clients would make random.sample() raise.
        self.num_waitings = min(
            opt.num_clients, max(1, round(opt.num_clients * opt.frac_clients)))

        self.lock = threading.Lock()        # guards responses, ready_clients, client ids
        self.lock_train = threading.Lock()  # guards trained_clients
        self.server = Flask(__name__)
        self.socketio = SocketIO(self.server)
        self.server.logger.setLevel(logging.ERROR)
        logging.getLogger('werkzeug').setLevel(logging.ERROR)

        os.makedirs(self.opt.log_dir, exist_ok=True)

        self.global_model = self.get_global_model()
        self.listener = threading.Thread(target=self.listen)
        self.listener.daemon = True
        self.listener.start()

        self.add_handlers()

    def run(self):
        """Serve the Socket.IO app on ``opt.host_ip:opt.host_port`` (blocks)."""
        self.socketio.run(self.server, host=self.opt.host_ip, port=self.opt.host_port)

    def get_global_model(self):
        """Instantiate the server-side model selected by ``opt.model``.

        Only ``opt.model == 3`` (FedAPC) is enabled. The CNN/EWC/APD
        variants are disabled; their imports are commented out at the top
        of the file.

        Raises:
            ValueError: for any other ``opt.model``. Failing here is better
                than storing ``None`` and crashing on first use of
                ``global_model``.
        """
        if self.opt.model == 3:
            return GlobalFedAPC(self.opt)
        raise ValueError('unsupported model: %r (only 3 / FedAPC is enabled)'
                         % (self.opt.model,))

    def listen(self):
        """Body of the daemon listener thread started in ``__init__``.

        In non-federated mode, first run the one-shot 'train_all' loop. Then
        (and directly, in federated mode) run the round loop.
        """
        syslog(self.server_id, 'started to listen')
        if not self.opt.federated:
            self._listen_local()
        self._listen_federated()

    def _listen_local(self):
        """Non-federated mode: send 'train_all' once to every client, stop when all leave.

        Returns only if more than ``opt.num_clients`` clients became ready.
        """
        is_train = False
        while len(self.ready_clients) <= self.opt.num_clients:
            # Emit only once. Without the is_train guard, every poll would
            # re-send 'train_all' while all clients stay connected.
            if not is_train and len(self.ready_clients) == self.opt.num_clients:
                # Iterate a snapshot: Socket.IO handlers add/remove sids concurrently.
                for sid in list(self.ready_clients):
                    self.socketio.emit('train_all', {}, room=sid)
                    self.socketio.sleep(1)
                is_train = True
            if is_train and not self.ready_clients:
                self.stop()
            time.sleep(3)

    def _listen_federated(self):
        """Round loop: start once all clients are ready, then aggregate and advance."""
        aggregated_round = None  # round whose responses were last folded into the model
        while True:
            if not self.has_begun:
                if len(self.ready_clients) == self.opt.num_clients:
                    self.has_begun = True
                    self.train_next_round()
            elif len(self.responses) >= self.num_waitings:
                # Aggregate each round's responses exactly once, even if the
                # advance condition below takes several polls to be met.
                if aggregated_round != self.current_round:
                    self._aggregate_responses()
                    aggregated_round = self.current_round
                # Advance once every client has either uploaded (comm) or
                # reported that training is done (train-only).
                if len(self.responses) + len(self.trained_clients) >= self.opt.num_clients:
                    with self.lock:
                        self.responses = []
                    with self.lock_train:
                        self.trained_clients = []
                    # Everyone left: do not start a round that random.sample()
                    # could not even populate.
                    if not self.ready_clients:
                        self.stop()
                    self.train_next_round()
                if not self.ready_clients:
                    self.stop()
            time.sleep(1)

    def _aggregate_responses(self):
        """Fold the collected client responses into the global model (federated runs only)."""
        if self.opt.federated and not self.opt.socket_test:
            syslog(self.server_id, 'update global weights')
            with self.lock:
                self.global_model.update_weights(self.responses)
            if self.opt.sparse_comm:
                self.global_model.write_current_status()

    @staticmethod
    def _pack_weights_and_adapts(weights, adapts):
        """Return a length-2 object ndarray whose items are ``weights`` and ``adapts``.

        ``np.array([weights, adapts])`` on two ragged lists of arrays raises
        on NumPy >= 1.24. Even with ``dtype=object`` it can try to broadcast
        equal-length inner arrays. So the container is filled element by
        element instead.
        """
        packed = np.empty(2, dtype=object)
        packed[0] = weights
        packed[1] = adapts
        return packed

    def _dispatch_in_batches(self, req, sids, count_done, target, batch_size=5):
        """Emit 'request-train' to ``sids`` in batches until ``count_done() >= target``.

        A new batch goes out only once ``count_done()`` has caught up with
        the number of requests already sent. At most ``batch_size``
        requests are therefore unanswered at a time.
        """
        sent = 0
        while count_done() < target:
            if count_done() >= sent:
                for sid in sids[sent:sent + batch_size]:
                    self.socketio.emit('request-train', req, room=sid)  # to the specific client
                    self.socketio.sleep(0.5)
                sent += batch_size
            time.sleep(0.5)

    def train_next_round(self):
        """Advance ``current_round`` and request training from every ready client.

        A random subset of ``num_waitings`` ready clients is sent
        ``comm=True``. In federated (non socket-test) runs, that request
        carries the pickled global weights: from round 1 onward, weights and
        adapts together.

        The other ready clients are sent ``comm=False``. Their request keeps
        the weights only on rounds that are a positive multiple of
        ``opt.num_rounds``.

        Blocks until ``num_waitings`` responses and
        ``num_clients - num_waitings`` train-done reports have arrived.
        """
        self.current_round += 1
        self.selected_clients = random.sample(list(self.ready_clients), self.num_waitings)
        syslog(self.server_id, "round:%d, request train one round (%s)"
               % (self.current_round, self.get_clients(self.selected_clients)))

        req = {'server_round': self.current_round}
        if self.opt.federated and not self.opt.socket_test:
            if self.current_round > 0:
                payload = self._pack_weights_and_adapts(
                    self.global_model.get_weights(), self.global_model.get_adapts())
                req['server_weights'] = obj_to_pickle_string(payload)
            else:
                req['server_weights'] = obj_to_pickle_string(self.global_model.get_weights())
        elif not self.opt.federated:
            req['server_weights'] = None

        req['comm'] = True
        self._dispatch_in_batches(req, list(self.selected_clients),
                                  lambda: len(self.responses), self.num_waitings)

        req['comm'] = False
        if self.current_round == 0 or self.current_round % self.opt.num_rounds != 0:
            req['server_weights'] = None

        syslog(self.server_id, "Transmit training request to not comm. clients")
        only_train_clients = list(self.ready_clients - set(self.selected_clients))
        self._dispatch_in_batches(req, only_train_clients,
                                  lambda: len(self.trained_clients),
                                  self.opt.num_clients - self.num_waitings)

    def stop(self):
        """Optionally save the aggregated weights, then hard-exit the process.

        ``os._exit`` is used because this is usually called from the
        listener thread, where ``sys.exit`` would only end that thread.
        """
        syslog(self.server_id, 'all clients are done.')
        if self.opt.save_weights:
            save_weights(self.opt.weights_dir,
                         'round-{}-aggr-weights.npy'.format(self.current_round),
                         self.global_model.get_weights())
            syslog(self.server_id, 'final aggregated weights has been saved.')
        syslog(self.server_id, 'done.')
        os._exit(0)  # thread

    def get_clients(self, sids):
        """Return the comma-separated integer client ids for the given session ids."""
        return ','.join(str(self.sid2cid[sid]) for sid in sids)

    def add_handlers(self):
        """Register the Socket.IO event handlers that clients talk to."""

        @self.socketio.on('client-start')
        def handle_start(resp):
            # Assign the next integer id and send the init payload to this client only.
            with self.lock:
                self.client_id += 1
                cid = self.client_id
                self.sid2cid[request.sid] = cid
                self.socketio.emit('request-init', {
                    'client_id': cid,
                    'log_dir': self.opt.log_dir,
                    'model_info': self.global_model.get_info(),
                }, room=request.sid)  # to the specific client
                syslog(self.server_id, 'request initialize client:%d' % (cid))

        @self.socketio.on('client-ready')
        def handle_ready(resp):
            syslog(self.server_id, 'client:%d is ready' % (resp['client_id']))
            with self.lock:
                self.ready_clients.add(request.sid)

        @self.socketio.on('client-update')
        def handle_update(resp):
            # Updates tagged with an earlier round are stale and dropped.
            if resp['client_round'] < self.current_round:
                syslog(self.server_id, 'round:%d, receive outdated updates from client:%d (client-round:%d). Ignored.'
                       % (self.current_round, resp['client_id'], resp['client_round']))
            else:
                syslog(self.server_id, 'round:%d, receive trained weights from client:%d.'
                       % (self.current_round, resp['client_id']))
                with self.lock:
                    self.responses.append(resp)

        @self.socketio.on('client-train-done')
        def handle_train_done(resp):
            syslog(self.server_id, 'client:%d training done!' % (resp['client_id']))
            with self.lock_train:
                self.trained_clients.append(request.sid)

        @self.socketio.on('client-stop')
        def handle_stop(resp):
            syslog(self.server_id, 'round:%d, client:%d has been stopped ' % (self.current_round, resp['client_id']))
            with self.lock:
                # discard(): a repeated 'client-stop', or one from a client that
                # never sent 'client-ready', must not raise KeyError here.
                self.ready_clients.discard(request.sid)
