from clients.get_client import get_client
from server.toolkit import *
from models.generator.wgan import WGAN
from server.aggregation import *
from models.get_model import get_model
import torch
import copy


class server_FedCIL():
    """Federated server for the FedCIL method, handling one incremental task.

    The server owns a global ``WGAN`` (generator + critic). Each round it
    samples half of the clients, lets them train locally against the global
    model, and aggregates their updates with ``aggregate_FedCIL``. The critic
    of the global model is the network scored by ``evaluate_model``.
    """

    def __init__(self, args, train_data, test_data, max_class, method_name, task_id, previous_test=None, theta_reg=None):
        """Build the global model for one task.

        Args:
            args: experiment configuration; this class reads ``device``,
                ``dataset_list``, ``rounds`` and ``num_clients``.
            train_data: iterable of per-client training data; one client is
                created per entry in ``set_clients``.
            test_data: test data of the current task (forwarded to
                ``evaluate_model``).
            max_class: number of classes the WGAN and the clients are built for.
            method_name: forwarded to ``get_client``.
            task_id: index of the current task; ``0`` marks the first task.
            previous_test: forwarded to ``evaluate_model`` (presumably the test
                data of earlier tasks — confirm against evaluate_model).
            theta_reg: state dict loaded into the global model at the start of
                ``server_train`` when ``task_id != 0``.
        """
        self.args = args
        self.device = self.args.device
        self.method_name = method_name
        self.train_data = train_data
        self.test_data = test_data
        self.previous_test = previous_test
        self.task_id = task_id
        self.first_round = task_id == 0
        self.theta_reg = theta_reg
        self.max_class = max_class
        self.clients = None
        self.clients_selected = []

        # Model structure depends on the dataset.
        if args.dataset_list == 'DIGIT10':
            image_size, g_channel_size, self.size = 32, 64, 32
        else:
            # NOTE(review): self.size (256) differs from the WGAN image_size
            # (224); self.size is not read in this class — confirm intended.
            image_size, g_channel_size, self.size = 224, 32, 256
        self.model = WGAN(
            z_size=100,
            image_size=image_size,
            image_channel_size=3,
            c_channel_size=64,
            g_channel_size=g_channel_size,
            num_classes=self.max_class,
        ).to(device=self.device)
        self.initialize_AC_GAN()

        # Per-round metric history; server_train appends one entry per round.
        self.avg_fgt = []
        self.avg_train_acc = []
        self.avg_test_acc = []
        self.all_test_previous_acc = []
        self.all_test_acc = []

    def server_train(self):
        """Run ``args.rounds`` federated rounds for the current task.

        Each round trains a client sample, aggregates into the global model,
        evaluates it, prints the metrics and appends them to the history lists.
        """
        self.initialize_weights()
        self.set_clients()
        all_labels = list(range(self.max_class))
        for u in self.clients:
            # Each client gets its own list objects, as before.
            u.available_labels = list(all_labels)
            u.available_labels_current = list(all_labels)
        for t in range(self.args.rounds):
            self.client_update(local_round=t)
            self.aggregate_model()

            avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = self.evaluate_global_model()
            if t % 5 == 0:
                print(f'round{t + 1}')
            # Printed model performance:
            # (1) acc on the train dataset of the current task
            # (2) acc on the test dataset of the current task
            # (3) acc on the combination of the learned tasks
            # (4) average acc over each learned task
            # (5) list of the test acc on each task
            print(
                f'{100 * avg_train_acc:.2f}%,  {100 * avg_test_acc:.2f}%, {100 * all_test_previous_acc:.2f}%,'
                f'{100 * all_test_acc:.2f}%',
                avg_fgt)

            # Save the model performance.
            self.avg_fgt.append(avg_fgt)
            self.avg_train_acc.append(avg_train_acc)
            self.avg_test_acc.append(avg_test_acc)
            self.all_test_previous_acc.append(all_test_previous_acc)
            self.all_test_acc.append(all_test_acc)
        print('task finishes')

    def initialize_weights(self):
        """For tasks after the first, start from the previous task's weights (``theta_reg``)."""
        if not self.first_round:
            self.model.load_state_dict(self.theta_reg)

    def set_clients(self):
        """Create one client per entry of ``train_data``.

        For tasks after the first, every client also receives ``g``: a single
        deep copy of the global model taken at call time (from server_train,
        that is the model just loaded from ``theta_reg``).
        """
        g = None if self.first_round else copy.deepcopy(self.model)
        self.clients = [
            get_client(get_model(self.args, self.max_class), pair, self.method_name, self.args,
                       g=g, generator=self.model)
            for pair in self.train_data
        ]

    def client_update(self, local_round):
        """Sync all clients to the global model and train a random half of them.

        ``max(1, int(num_clients / 2))`` distinct clients are drawn without
        replacement from the first ``args.num_clients`` clients.

        Raises:
            ValueError: if ``args.num_clients`` is smaller than 1.
        """
        num_clients = self.args.num_clients
        if num_clients < 1:
            raise ValueError(f'num_clients must be >= 1, got {num_clients}')
        # Always select at least one client; int(1 / 2) == 0 would leave
        # aggregate_model with nothing to average.
        num_selected = max(1, int(num_clients / 2))
        clients_index = np.random.choice(num_clients, num_selected, replace=False)
        self.clients_selected = [self.clients[i] for i in clients_index]
        for client in self.clients:
            client.update_model(self.model)
        for client in self.clients_selected:
            # Re-sync right before training (kept from the original flow; it
            # matters if a previous client's train() mutates self.model).
            client.update_model(self.model)
            client.train(local_round, self.model, list(range(self.max_class)))

    def aggregate_model(self):
        """Average the selected clients' updates into the global model with equal weights.

        Raises:
            RuntimeError: if no client has been selected (``client_update`` not run).
        """
        if not self.clients_selected:
            raise RuntimeError('aggregate_model called with no selected clients; run client_update first')
        n_selected = len(self.clients_selected)
        client_fraction = [1 / n_selected] * n_selected
        generator = aggregate_FedCIL(self.clients_selected, client_fraction, self.model)
        self.model.load_state_dict(generator.state_dict())

    def evaluate_global_model(self):
        """Evaluate the global critic and return the five metrics printed by ``server_train``."""
        avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = evaluate_model(
            self.train_data, self.test_data, self.model.critic, self.previous_test, self.args)
        return avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc

    ####################################

    def initialize_AC_GAN(self):
        """Attach optimizers and the gradient-penalty lambda, then re-initialize weights.

        Parameters of modules whose name contains ``'conv'`` or ``'fc'`` are
        set to N(0, 0.02) when they have 2+ dims and to 0 otherwise. Modules
        whose name contains ``'bn'`` get weight 1 and bias 0 where those exist.
        """
        beta1 = 0.5
        beta2 = 0.999
        lr = 1e-4
        weight_decay = 1e-5

        generator_g_optimizer = torch.optim.Adam(self.model.generator.parameters(),
                                                 lr=lr, weight_decay=weight_decay, betas=(beta1, beta2))
        generator_c_optimizer = torch.optim.Adam(self.model.critic.parameters(),
                                                 lr=lr, weight_decay=weight_decay, betas=(beta1, beta2))

        self.model.set_lambda(10.)
        self.model.set_generator_optimizer(generator_g_optimizer)
        self.model.set_critic_optimizer(generator_c_optimizer)

        # The init below modifies tensors in place, so the optimizers created
        # above still track the same Parameter objects.
        model = self.model
        std = 0.02
        # A parent and its child may both match by name; their parameters are
        # then initialized twice. Kept as-is so seeded runs stay reproducible.
        modules = [m for n, m in model.named_modules() if 'conv' in n or 'fc' in n]
        parameters = [p for m in modules for p in m.parameters()]

        for p in parameters:
            if p.dim() >= 2:
                torch.nn.init.normal_(p, mean=0, std=std)
            else:
                torch.nn.init.constant_(p, 0)

        bn_modules = [m for n, m in model.named_modules() if 'bn' in n]
        for m in bn_modules:
            # Non-affine BatchNorm layers (or containers merely named 'bn...')
            # have no learnable weight/bias; skip them instead of crashing.
            weight = getattr(m, 'weight', None)
            bias = getattr(m, 'bias', None)
            if weight is not None:
                torch.nn.init.constant_(weight, 1)
            if bias is not None:
                torch.nn.init.constant_(bias, 0)
