import torch
from models.get_model import get_model
from clients.get_client import get_client
from server.aggregation import *
import copy
import numpy as np
from server.toolkit import *

class server_FLwF:
    """Server side of one task in a federated continual-learning run.

    The server owns a global model and one client per entry of ``train_data``.
    Each communication round it pushes the global model to the clients, lets a
    random subset train, aggregates their results back into the global model
    and records the evaluation metrics.

    For every task after the first (``task_id != 0``) the global model starts
    from ``theta_reg`` and the same state dict is loaded into a "previous"
    model that is handed to each training client.
    NOTE(review): "FLwF" presumably refers to a federated
    Learning-without-Forgetting scheme where that previous model acts as a
    teacher -- confirm against the client implementation.
    """

    def __init__(self, args, train_data, test_data, max_class, method_name, task_id, previous_test=None, theta_reg=None):
        """Build the global model and one client per training split.

        Args:
            args: Run configuration. This class reads ``device``, ``rounds``,
                ``num_clients``, ``participated_clients`` and ``lr_global``;
                the object is also forwarded to the project helpers.
            train_data: Iterable of per-client training data; one client is
                created for each element.
            test_data: Test data forwarded to ``evaluate_model``.
            max_class: Forwarded to ``get_model`` when building models.
            method_name: Forwarded to ``get_client``.
            task_id: Index of the current task; ``0`` marks the first task.
            previous_test: Forwarded to ``evaluate_model``.
            theta_reg: State dict of the model from the previous task.
                Required whenever ``task_id`` is not ``0``.
        """
        self.args = args
        self.device = self.args.device
        self.method_name = method_name
        self.train_data = train_data
        self.test_data = test_data
        self.previous_test = previous_test
        self.task_id = task_id
        self.first_round = task_id == 0
        self.theta_reg = theta_reg
        self.max_class = max_class

        # The global model is built before the client models, as in the
        # original code, so any random initialisation happens in that order.
        self.model = get_model(self.args, max_class)
        self.clients = [
            get_client(get_model(self.args, max_class), pair, self.method_name, self.args)
            for pair in self.train_data
        ]
        self.clients_selected = []

        # Per-round history of the metrics returned by evaluate_global_model.
        self.avg_fgt = []
        self.avg_train_acc = []
        self.avg_test_acc = []
        self.all_test_previous_acc = []
        self.all_test_acc = []

    def server_train(self):
        """Run ``args.rounds`` communication rounds for the current task.

        Every round performs client training, aggregation and evaluation,
        prints the metrics and appends them to the history lists.
        """
        # Step 1: use the parameters of the last task as the initial point.
        self.initialize_weights()
        self.model = self.model.to(self.device)
        for t in range(self.args.rounds):
            self.client_update()
            self.aggregate_model()
            (avg_train_acc, avg_test_acc, avg_fgt,
             all_test_previous_acc, all_test_acc) = self.evaluate_global_model()
            if t % 5 == 0:
                print(f'round{t + 1}')
            # Print the model performance, in this order:
            # (1) the acc on the train dataset of the current task
            # (2) the acc on the test dataset of the current task
            # (3) the acc on the combination of the learned tasks
            # (4) the average acc over each learned task
            # (5) the list of the test acc on each task
            print(
                f'{100 * avg_train_acc:.2f}%,  {100 * avg_test_acc:.2f}%, {100 * all_test_previous_acc:.2f}%,'
                f'{100 * all_test_acc:.2f}%',
                avg_fgt)
            self.avg_fgt.append(avg_fgt)
            self.avg_train_acc.append(avg_train_acc)
            self.avg_test_acc.append(avg_test_acc)
            self.all_test_previous_acc.append(all_test_previous_acc)
            self.all_test_acc.append(all_test_acc)
        print('task finishes')

    def _require_theta_reg(self):
        """Return ``theta_reg``, failing clearly if a later task has none.

        Raises:
            ValueError: If ``theta_reg`` is ``None``.
        """
        if self.theta_reg is None:
            raise ValueError(
                'theta_reg (state dict of the previous task) is required when task_id != 0')
        return self.theta_reg

    def initialize_weights(self):
        """Start the global model from the previous task's parameters.

        Does nothing for the first task.

        Raises:
            ValueError: If this is not the first task and ``theta_reg`` is
                ``None``.
        """
        if not self.first_round:
            # Set the model of the last task as the initial point.
            self.model.load_state_dict(self._require_theta_reg())

    def client_update(self):
        """Broadcast the global model and train a random subset of clients.

        ``args.participated_clients`` distinct indices are drawn from
        ``range(args.num_clients)``; the matching clients are stored in
        ``self.clients_selected``. Every client (selected or not) receives the
        current global model through ``update_model``; only the selected ones
        are trained.

        Raises:
            ValueError: If this is not the first task and ``theta_reg`` is
                ``None``.
        """
        # An integer population samples as if from np.arange(num_clients).
        clients_index = np.random.choice(
            self.args.num_clients, self.args.participated_clients, replace=False)
        self.clients_selected = [self.clients[i] for i in clients_index]

        for client in self.clients:
            client.update_model(self.model)

        for client in self.clients_selected:
            if self.first_round:
                client.train(self.model)
                continue
            # A fresh previous-task model is built for every client, so no
            # client shares that object with another one.
            previous_model = get_model(self.args, self.max_class)
            previous_model.load_state_dict(self._require_theta_reg())
            previous_model.to(self.device)
            client.train(self.model, previous_model)

    def aggregate_model(self):
        """Merge the selected clients' results into the global model.

        Every selected client gets the same weight ``1 / len(clients_selected)``
        (dataset sizes are not taken into account). ``aggregate_FDCL2`` returns
        the state dict that is then loaded into the global model.
        """
        num_selected = len(self.clients_selected)
        client_fraction = [1 / num_selected] * num_selected
        params = aggregate_FDCL2(self.clients_selected, client_fraction, self.model, self.args.lr_global)

        self.model.load_state_dict(params)

    def evaluate_global_model(self):
        """Evaluate the global model via ``evaluate_model``.

        Returns:
            The tuple ``(avg_train_acc, avg_test_acc, avg_fgt,
            all_test_previous_acc, all_test_acc)`` produced by
            ``evaluate_model``.
        """
        avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = evaluate_model(
            self.train_data, self.test_data, self.model, self.previous_test, self.args)
        return avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc