import torch
from models.get_model import get_model
from clients.get_client import get_client
from server.aggregation import *
import copy
import numpy as np
from server.toolkit import *

class server_FedProx():
    """Federated server that trains one global model on a single task.

    Each round, every client receives the current global model and trains
    locally. The server then averages the client models back into the global
    model with uniform weights. Per-round metrics returned by
    ``evaluate_model`` are appended to the ``avg_*`` / ``all_*`` history lists.

    NOTE(review): the FedProx proximal term is presumably applied inside
    ``client.train`` using the ``global_model`` snapshot passed to it -- confirm
    in the client implementation.
    """

    def __init__(self, args, train_data, test_data, max_class, method_name,
                 task_id, previous_test=None, theta_reg=None):
        """Build the global model and one client per entry of ``train_data``.

        Args:
            args: run configuration. This class reads ``args.device`` and
                ``args.rounds`` and forwards ``args`` to ``get_model``,
                ``get_client`` and ``evaluate_model``.
            train_data: per-client training data; one client is created per
                element.
            test_data: test data forwarded to ``evaluate_model``.
            max_class: class count forwarded to ``get_model``.
            method_name: method identifier forwarded to ``get_client``.
            task_id: index of the current task; ``0`` marks the first task.
            previous_test: forwarded to ``evaluate_model`` (presumably test
                data of earlier tasks -- confirm).
            theta_reg: state dict loaded into the global model before training
                when ``task_id != 0``; also passed to every ``client.train``.
        """
        self.args = args
        self.device = self.args.device
        self.method_name = method_name
        self.train_data = train_data
        self.model = get_model(self.args, max_class)
        # Each client owns its own freshly built model instance.
        self.clients = [get_client(get_model(self.args, max_class), pair, self.method_name, self.args)
                        for pair in self.train_data]

        self.test_data = test_data
        self.previous_test = previous_test
        self.task_id = task_id
        self.first_round = task_id == 0
        self.theta_reg = theta_reg
        self.max_class = max_class

        # Per-round metric histories, appended to by ``server_train``.
        self.avg_fgt = []  # average forgetting on previous validation data
        self.avg_train_acc = []
        self.avg_test_acc = []
        self.all_test_previous_acc = []
        self.all_test_acc = []

    def server_train(self):
        """Run ``args.rounds`` communication rounds, printing and recording metrics."""
        self.initialize_weights()
        self.model = self.model.to(self.device)
        for t in range(self.args.rounds):
            self.client_update()
            self.aggregate_model()
            avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = self.evaluate_global_model()
            if t % 5 == 0:
                print(f'round{t + 1}')
            print(
                f'{100 * avg_train_acc:.2f}%,  {100 * avg_test_acc:.2f}%, {100 * all_test_previous_acc:.2f}%,'
                f'{100 * all_test_acc:.2f}%',
                avg_fgt)
            self.avg_fgt.append(avg_fgt)
            self.avg_train_acc.append(avg_train_acc)
            self.avg_test_acc.append(avg_test_acc)
            self.all_test_previous_acc.append(all_test_previous_acc)
            self.all_test_acc.append(all_test_acc)
        print('task finishes')

    def initialize_weights(self):
        """Load ``theta_reg`` into the global model unless this is the first task."""
        if not self.first_round:
            self.model.load_state_dict(self.theta_reg)

    def client_update(self):
        """Broadcast the global model to every client and run local training."""
        # Detached copy of the global parameters taken before any client
        # trains, so every client receives the same round-start snapshot.
        global_model = [p.detach().clone() for p in self.model.parameters()]
        for client in self.clients:
            client.update_model(self.model)
            client.train(self.theta_reg, self.first_round, global_model)

    def aggregate_model(self):
        """Average the client models into the global model with uniform weights.

        Does nothing when there are no clients.
        """
        num_participants = len(self.clients)
        if num_participants == 0:
            # Nothing to aggregate; keep the current global model.
            return
        # Weight by the number of clients actually trained (not
        # ``args.num_clients``) so the uniform weights always sum to 1.
        client_fraction = [1 / num_participants] * num_participants
        params = aggregate_basic(self.clients, client_fraction, self.model)

        self.model.load_state_dict(params)

    def evaluate_global_model(self):
        """Evaluate the global model via ``evaluate_model``.

        Returns:
            Tuple ``(avg_train_acc, avg_test_acc, avg_fgt,
            all_test_previous_acc, all_test_acc)`` exactly as produced by
            ``evaluate_model``.
        """
        avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = evaluate_model(
            self.train_data, self.test_data, self.model, self.previous_test, self.args)
        return avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc