from clients.get_client import get_client
from server.toolkit import *
from server.aggregation import *
from models.get_model import get_model


class server_pFedDIL():
    """Federated server that trains a shared global model for one task of a task sequence.

    Each round, a random half of the clients (at least one) is selected, initialized,
    trained locally, and combined into ``self.model`` via ``aggregate_basic``. For
    tasks after the first, a selected client may start from one of its own models
    saved in an earlier task, chosen by the scores from ``client.calculate_aux_score``.

    Args:
        args: Run configuration. Attributes read here: ``device``, ``rounds``,
            ``num_clients``, ``num_tasks``, ``dataset_list`` (plus whatever
            ``get_model`` / ``get_client`` / ``evaluate_model`` read).
        train_data: One entry per client; each entry is passed to ``get_client``.
        test_data: Forwarded to ``evaluate_model``.
        max_class: Forwarded to ``get_model``.
        method_name: Forwarded to ``get_client``.
        task_id: 0-based index of the current task.
        previous_test: Forwarded to ``evaluate_model`` (presumably the test data of
            earlier tasks -- confirm against the caller).
        theta_reg: State dict loaded into the global model when ``task_id > 0``.
        **kwargs: Optional ``previous_model`` and ``auxiliary_classifier``, both
            indexed by client position, each entry a list that
            ``save_model_classifier`` appends to at the end of a task; and
            ``future_data``, passed through to the clients' training calls.
    """

    def __init__(self, args, train_data, test_data, max_class, method_name, task_id, previous_test=None,
                 theta_reg=None, **kwargs):
        self.args = args
        self.device = self.args.device
        self.method_name = method_name
        self.train_data = train_data
        self.model = get_model(self.args, max_class)
        # one client (with its own model instance) per entry of train_data
        self.clients = [get_client(get_model(self.args, max_class), pair, self.method_name, self.args) for pair in self.train_data]
        self.clients_selected = []
        self.test_data = test_data
        self.previous_test = previous_test
        self.task_id = task_id
        self.first_round = task_id == 0
        self.theta_reg = theta_reg
        self.max_class = max_class
        self.previous_model = kwargs.get('previous_model', None)
        self.future_data = kwargs.get('future_data', None)
        self.auxiliary_classifier = kwargs.get('auxiliary_classifier', None)

        # record information (one entry per communication round)
        self.avg_fgt = []
        self.avg_train_acc = []
        self.avg_test_acc = []
        self.all_test_previous_acc = []
        self.all_test_acc = []

    def server_train(self):
        """Run ``args.rounds`` communication rounds for the current task.

        Each round trains the selected clients, aggregates them into the global
        model, evaluates it, prints the metrics, and records them. At the end, every
        client's model and current-task auxiliary classifier are saved.
        """
        self.initialize_weights()
        self.model = self.model.to(self.device)
        self.update_aux_classifier()
        for t in range(self.args.rounds):
            self.client_update()
            self.aggregate_model()
            avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = self.evaluate_global_model()
            if t % 5 == 0:
                print(f'round{t + 1}')
            # print the model performance:
            # (1) the acc on the train dataset of the current task
            # (2) the acc on the test dataset of the current task
            # (3) the acc on the combination of the learned tasks
            # (4) the average acc over each learned tasks
            # (5) the list of the test acc on each task
            print(
                f'{100 * avg_train_acc:.2f}%,  {100 * avg_test_acc:.2f}%, {100 * all_test_previous_acc:.2f}%,'
                f'{100 * all_test_acc:.2f}%',
                avg_fgt)

            # save the model performance
            self.avg_fgt.append(avg_fgt)
            self.avg_train_acc.append(avg_train_acc)
            self.avg_test_acc.append(avg_test_acc)
            self.all_test_previous_acc.append(all_test_previous_acc)
            self.all_test_acc.append(all_test_acc)
        # at the end of each task, every client's model is saved so it can be selected in later tasks
        self.save_model_classifier()
        print('task finishes')

    def initialize_weights(self):
        """For tasks after the first, start the global model from ``theta_reg``."""
        if not self.first_round:
            self.model.load_state_dict(self.theta_reg)

    def update_aux_classifier(self):
        """For tasks after the first, let each client score its earlier auxiliary classifiers."""
        if self.task_id >= 1:
            for client_id, client in enumerate(self.clients):
                client.calculate_aux_score(self.auxiliary_classifier[client_id])

    def _init_from_previous(self, client, client_history):
        """Load the client's best-scoring earlier model, falling back to the global model.

        ``client.aux_classifier_score`` is assumed to hold one score per entry of
        ``client_history`` (TODO confirm). If the highest score reaches the
        threshold (0.5 when ``args.dataset_list == 'DIGIT10'``, otherwise 0.8), the
        matching earlier model is loaded; otherwise the current global model is.
        """
        scores = client.aux_classifier_score
        best_model_value = max(scores)
        best_idx = scores.index(best_model_value)
        lambda_ = 0.5 if self.args.dataset_list == 'DIGIT10' else 0.8
        if best_model_value >= lambda_:
            client.update_model(client_history[best_idx])
        else:
            client.update_model(self.model)

    def client_update(self):
        """Select half of the clients (at least one) at random and train them locally.

        Every client first receives the current global model. Each selected client
        then trains according to the task position:

        * first task: ``train_first_task(self.future_data)``;
        * last task: initialize via ``_init_from_previous``, then
          ``train_without_aux`` with the client's own saved models;
        * any other task: same initialization, then ``train_with_aux``.
        """
        num_clients = self.args.num_clients
        # Select at least one client; an empty selection would make aggregate_model
        # divide by zero.
        num_selected = max(1, num_clients // 2)
        clients_index = list(np.random.choice(list(range(num_clients)), num_selected, replace=False))
        self.clients_selected = [self.clients[i] for i in clients_index]
        for client in self.clients:
            client.update_model(self.model)
        for client_idx, client in zip(clients_index, self.clients_selected):
            if self.task_id == 0:
                client.update_model(self.model)
                client.train_first_task(self.future_data)
                continue
            # previous_model is indexed by the client's position in self.clients
            # (see save_model_classifier), not by its order within the selection.
            client_history = self.previous_model[client_idx]
            self._init_from_previous(client, client_history)
            if self.task_id == self.args.num_tasks - 1:
                client.train_without_aux(client_history)
            else:
                client.train_with_aux(self.future_data, client_history)

    def save_model_classifier(self):
        """Append each client's model and current-task auxiliary classifier to its history.

        Histories are indexed by the client's position in ``self.clients``;
        presumably the caller passes these same lists to the next task's server via
        the ``previous_model`` / ``auxiliary_classifier`` kwargs.
        """
        for idx, client in enumerate(self.clients):
            self.previous_model[idx].append(client.model)
            self.auxiliary_classifier[idx].append(client.auxiliary_classifier_current_task)

    def aggregate_model(self):
        """Combine the selected clients' models into the global model with uniform weights."""
        num_selected = len(self.clients_selected)
        client_fraction = [1 / num_selected] * num_selected
        params = aggregate_basic(self.clients_selected, client_fraction, self.model)
        self.model.load_state_dict(params)

    def evaluate_global_model(self):
        """Evaluate the global model with ``evaluate_model``.

        Returns:
            tuple: ``(avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc,
            all_test_acc)`` as produced by ``evaluate_model``.
        """
        avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = evaluate_model(
            self.train_data, self.test_data, self.model, self.previous_test, self.args)
        return avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc

