from server.toolkit import *
from models.gen_dis import Discriminator
from models.get_model import get_model
from clients.get_client import get_client
from server.aggregation import *
import numpy as np


class server_SRFDIL():
    """Server-side coordinator for one task of the SRFDIL federated method.

    The server owns a global model and a global discriminator, builds one
    client per entry of ``train_data``, and for ``args.rounds`` rounds:
    broadcasts the global model, trains a random subset of clients,
    aggregates their models/discriminators and evaluates the result.
    Per-round metrics are accumulated in the ``avg_*`` / ``all_*`` lists.
    """

    def __init__(self, args, train_data, test_data, max_class, method_name, task_id, previous_test=None,
                 theta_reg=None, **kwargs):
        """Build the global model, the clients and the discriminator.

        Args:
            args: Run configuration. This class reads ``device``,
                ``dataset_list``, ``rounds``, ``num_clients``,
                ``participated_clients`` and ``num_tasks`` from it.
            train_data: Iterable with one entry per client; each entry is
                handed to ``get_client`` unchanged.
            test_data: Test data forwarded to ``evaluate_model``.
            max_class: Forwarded to ``get_model`` and to the clients'
                ``update_train_date``.
            method_name: Forwarded to ``get_client``.
            task_id: Index of the current task; ``0`` marks the first task.
            previous_test: Forwarded to ``evaluate_model``.
            theta_reg: State dict loaded into the global model at the start
                of every task except the first.
            **kwargs: Optional ``previous_discriminator`` (object exposing
                ``state_dict()``) and ``previous_data`` (indexable by client
                position). Both are required for every task after the first.
        """
        self.args = args
        self.device = self.args.device
        self.method_name = method_name
        self.task_id = task_id
        self.first_round = task_id == 0
        self.max_class = max_class

        self.train_data = train_data
        self.test_data = test_data
        self.previous_test = previous_test

        # State carried over from the previous task (all None on the first task).
        self.theta_reg = theta_reg
        self.previous_discriminator = kwargs.get('previous_discriminator', None)
        self.previous_data = kwargs.get('previous_data', None)

        # Keep this construction order (global model -> client models ->
        # discriminator): reordering would change which random numbers each
        # network consumes during initialisation.
        self.model = get_model(self.args, max_class)
        self.clients = [
            get_client(get_model(self.args, max_class), pair, self.method_name, self.args)
            for pair in self.train_data
        ]
        self.clients_selected = []

        img_size = 32 if self.args.dataset_list == 'DIGIT10' else 224
        self.discriminator = Discriminator(img_channels=3, img_size=img_size).to(self.device)

        # record information (one entry appended per communication round)
        self.avg_fgt = []
        self.avg_train_acc = []
        self.avg_test_acc = []
        self.all_test_previous_acc = []
        self.all_test_acc = []

    def server_train(self):
        """Run ``args.rounds`` communication rounds for the current task.

        Each round performs client training, aggregation and evaluation,
        prints the metrics and appends them to the history lists.
        """
        self.initialize_weights()
        self.model = self.model.to(self.device)
        if self.task_id != 0:
            self.update_train_loader()

        for t in range(self.args.rounds):
            self.client_update()
            self.aggregate_model()
            avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = self.evaluate_global_model()

            if t % 5 == 0:
                print(f'round{t + 1}')
            # Per-round report, in printed order: avg_train_acc, avg_test_acc,
            # all_test_previous_acc and all_test_acc as percentages, followed
            # by avg_fgt printed as-is.
            print(
                f'{100 * avg_train_acc:.2f}%,  {100 * avg_test_acc:.2f}%, {100 * all_test_previous_acc:.2f}%,'
                f'{100 * all_test_acc:.2f}%',
                avg_fgt)

            # save the model performance
            self.avg_fgt.append(avg_fgt)
            self.avg_train_acc.append(avg_train_acc)
            self.avg_test_acc.append(avg_test_acc)
            self.all_test_previous_acc.append(all_test_previous_acc)
            self.all_test_acc.append(all_test_acc)
        print('task finishes')

    def initialize_weights(self):
        """Warm-start the model and discriminator from the previous task.

        Does nothing on the first task. On later tasks both ``theta_reg`` and
        ``previous_discriminator`` must have been supplied.

        Raises:
            ValueError: If either carried-over object is missing. The check
                runs before anything is loaded, so a failure never leaves the
                model updated while the discriminator is not.
        """
        if self.first_round:
            return

        required = (
            ('theta_reg', self.theta_reg),
            ('previous_discriminator', self.previous_discriminator),
        )
        missing = [name for name, value in required if value is None]
        if missing:
            raise ValueError(
                f"task {getattr(self, 'task_id', '?')} is not the first task, "
                f"but {', '.join(missing)} was not provided"
            )

        # set the model / discriminator of the last task as the initial point
        self.model.load_state_dict(self.theta_reg)
        self.discriminator.load_state_dict(self.previous_discriminator.state_dict())

    def update_train_loader(self):
        """Let every client rebuild its training data from the previous task.

        Client ``i`` receives ``previous_data[i]`` together with the current
        global model, the previous discriminator and ``max_class``.

        Raises:
            ValueError: If ``previous_data`` was not provided.
        """
        if self.previous_data is None:
            raise ValueError("update_train_loader requires 'previous_data', but it was not provided")

        for idx, client in enumerate(self.clients):
            client.update_train_date(self.model, self.previous_discriminator, self.previous_data[idx], self.max_class)

    def client_update(self):
        """Broadcast the global model and train a random subset of clients.

        ``args.participated_clients`` distinct indices are drawn from
        ``range(args.num_clients)``. Every client (selected or not) receives
        the global model. Selected clients run ``trainI`` on the last task,
        otherwise they receive the global discriminator and run ``trainII``.
        """
        # Passing the population size draws the same sample as passing
        # list(range(num_clients)), without building the list.
        chosen = np.random.choice(self.args.num_clients, self.args.participated_clients, replace=False)
        self.clients_selected = [self.clients[i] for i in chosen]

        for client in self.clients:
            client.update_model(self.model)

        is_last_task = self.task_id == self.args.num_tasks - 1
        for client in self.clients_selected:
            if is_last_task:
                client.trainI()
            else:
                client.update_dis(self.discriminator)
                client.trainII()

    def aggregate_model(self):
        """Aggregate the selected clients into the global model and discriminator.

        All selected clients are weighted equally.
        """
        num_selected = len(self.clients_selected)
        client_fraction = [1 / num_selected] * num_selected

        # update model
        model_params = aggregate_basic(self.clients_selected, client_fraction, self.model)
        self.model.load_state_dict(model_params)

        # update discriminator
        dis_params = aggregate_discriminator(self.clients_selected, client_fraction, self.discriminator)
        self.discriminator.load_state_dict(dis_params)

    def evaluate_global_model(self):
        """Evaluate the global model via ``evaluate_model``.

        Returns:
            The 5-tuple ``(avg_train_acc, avg_test_acc, avg_fgt,
            all_test_previous_acc, all_test_acc)`` produced by
            ``evaluate_model``.
        """
        # calculate the performance
        avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = evaluate_model(
            self.train_data, self.test_data, self.model, self.previous_test, self.args)
        return avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc
