import pdb
import random
import sys
import threading
import time

import numpy as np
import tensorflow as tf

from misc.utils import *
from .client import Client
from modules.federated import ServerModule

class Server(ServerModule):
    """ FedWeIT Server
    Performing fedweit server algorithms
    Created by:
        Wonyong Jeong (wyjeong@kaist.ac.kr)

    Each round a random subset of client ids is sampled. Every group in
    ``self.parallel_clients`` is then trained concurrently: one thread per
    entry, running worker ``self.clients[gid]`` under GPU device ``gid``.
    The updates returned by the workers are aggregated with
    ``self.train.aggregate`` and loaded back as the server weights.

    Every ``num_rounds``-th round (presumably the last round of a task), the
    clients' adaptive parameters are collected into ``self.client_adapts``.
    In the following round, ``get_adapts`` stacks them into a per-layer
    knowledge base that is passed to every client.
    """

    def __init__(self, args):
        super(Server, self).__init__(args, Client)
        # Adaptive parameters gathered from clients in the most recent
        # round where ``curr_round % num_rounds == 0``.
        self.client_adapts = []

    def _num_selection(self):
        """Return how many client ids are sampled per round."""
        return int(round(self.args.num_clients * self.args.frac_clients))

    def train_clients(self):
        """Run ``self.args.total_rounds`` federated rounds, then exit.

        Per round:
          1. sample ``_num_selection()`` distinct client ids at random;
          2. for each group in ``self.parallel_clients``, start one thread
             per entry (selected or not) and wait for the group to finish;
          3. aggregate ``self.updates`` and set the result as server weights.

        Ends with ``sys.exit()``, so this method never returns normally.
        """
        cids = np.arange(self.args.num_clients).tolist()
        num_selection = self._num_selection()
        for curr_round in range(self.args.total_rounds):
            self.updates = []
            self.curr_round = curr_round + 1
            self.is_last_round = self.curr_round % self.args.num_rounds == 0
            if self.is_last_round:
                # Fresh collection; workers append to it in invoke_client.
                self.client_adapts = []
            selected_ids = random.sample(cids, num_selection)  # pick clients
            self.logger.print('server', 'round:{} train clients (selected_ids: {})'.format(curr_round, selected_ids))
            # train selected clients in parallel
            for clients in self.parallel_clients:
                self.threads = []
                for gid, cid in enumerate(clients):
                    client = self.clients[gid]
                    selected = cid in selected_ids
                    with tf.device('/device:GPU:{}'.format(gid)):
                        # Weights/adapts are fetched per client so each
                        # worker receives its own objects.
                        thrd = threading.Thread(
                            target=self.invoke_client,
                            args=(client, cid, curr_round, selected,
                                  self.get_weights(), self.get_adapts()))
                        self.threads.append(thrd)
                        thrd.start()
                # wait for the whole group before starting the next one
                for thrd in self.threads:
                    thrd.join()
            # update
            aggr = self.train.aggregate(self.updates)
            self.set_weights(aggr)
        self.logger.print('server', 'done. ({}s)'.format(time.time() - self.start_time))
        sys.exit()

    def invoke_client(self, client, cid, curr_round, selected, weights, adapts):
        """Thread target: train ``client`` for one round and record the result.

        A non-None update is appended to ``self.updates``. When
        ``self.is_last_round`` is set, ``client.get_adaptives()`` is also
        appended to ``self.client_adapts``. The concurrent appends rely on
        ``list.append`` being atomic under CPython's GIL. Their order follows
        thread completion and is therefore arbitrary.
        """
        update = client.train_one_round(cid, curr_round, selected, weights, adapts)
        if update is not None:
            self.updates.append(update)
            if self.is_last_round:
                self.client_adapts.append(client.get_adaptives())

    def get_adapts(self):
        """Stack the collected client adaptives into a per-layer knowledge base.

        Only active when ``curr_round % num_rounds == 1`` and
        ``curr_round != 1``; otherwise returns None.
        NOTE(review): with ``num_rounds == 1`` that condition is never true,
        so no knowledge base is ever sent. Confirm this is intended.

        Returns:
            None, or a list with one ``np.ndarray`` per entry of
            ``self.nets.shapes``. Each array has shape
            ``self.nets.shapes[lid] + (num_selection,)``. Slot ``[..., k]``
            holds layer ``lid`` of the k-th collected adaptive set; unused
            slots stay zero.

        Raises:
            ValueError: if more adaptive sets were collected than there are
                slots, or a client's layer cannot be stored in its slot.
        """
        if self.curr_round % self.args.num_rounds != 1 or self.curr_round == 1:
            return None
        num_slots = self._num_selection()
        if len(self.client_adapts) > num_slots:
            raise ValueError(
                'collected {} adaptive sets but knowledge base has only {} slots'
                .format(len(self.client_adapts), num_slots))
        from_kb = []
        for lid, layer_shape in enumerate(self.nets.shapes):
            # Layer shape with a trailing axis indexing the contributing client.
            shape = np.concatenate([layer_shape, [num_slots]], axis=0)
            from_kb_l = np.zeros(shape)
            for k, ca in enumerate(self.client_adapts):
                try:
                    # The client axis is always last, whatever the layer rank.
                    from_kb_l[..., k] = ca[lid]
                except ValueError as err:
                    raise ValueError(
                        'adaptive params in slot {} do not fit layer {} '
                        '(expected {}, got {})'.format(
                            k, lid, shape[:-1].tolist(), np.shape(ca[lid]))) from err
            from_kb.append(from_kb_l)
        return from_kb
        