import torch
from clients.get_client import get_client
from server.aggregation import *
import copy
import numpy as np
from server.toolkit import *
from models.ResNet18_MFCL import ResNet18
from models.network_MFCL import network
from models.generator_MFCL import CIFAR_GEN, IMNET_GEN
from models.Teacher_MFCL import Teacher


class server_MFCL():
    """Server-side driver for one task of the MFCL federated training loop.

    The server owns the global model, one client object per entry of
    ``train_data``, and a generator.  ``server_train`` runs ``args.rounds``
    communication rounds (client training, aggregation, evaluation) and, on
    the final round of every task except the last one, trains a
    generator-backed teacher via ``train_gen``.
    """

    def __init__(self, args, train_data, test_data, max_class, method_name, task_id, previous_test=None, theta_reg=None,
                 **kwargs):
        """Build the global model, the generator and the client objects.

        Args:
            args: Run configuration.  This class reads ``device``,
                ``dataset_list``, ``rounds``, ``num_clients``,
                ``participated_clients``, ``num_tasks`` and ``lr_global``.
            train_data: Iterable with one entry per client; each entry is
                handed to ``get_client``.
            test_data: Forwarded to ``evaluate_model``.
            max_class: Passed to ``ResNet18``/``network`` and used as the
                clients' ``valid_dim``/``last_valid_dim``.
            method_name: Forwarded to ``get_client``.
            task_id: Index of the current task; ``0`` marks the first task.
            previous_test: Forwarded to ``evaluate_model``.
            theta_reg: State dict loaded into the global model when
                ``task_id`` is not 0.
            **kwargs: Only ``teacher`` is read (defaults to ``None``); when
                set, clients train with ``trainII(teacher)`` instead of
                ``trainI()``.
        """
        self.args = args
        self.device = self.args.device
        self.method_name = method_name
        self.train_data = train_data
        self.test_data = test_data
        self.previous_test = previous_test
        self.task_id = task_id
        self.first_round = task_id == 0
        self.theta_reg = theta_reg
        self.max_class = max_class
        self.teacher = kwargs.get('teacher', None)

        # DIGIT10 uses the CIFAR-style backbone and generator; every other
        # dataset uses the ImageNet-style pair.
        if self.args.dataset_list == 'DIGIT10':
            feature_extractor = ResNet18(max_class, cifar=True)
            self.generator = CIFAR_GEN()
        else:
            feature_extractor = ResNet18(max_class, cifar=False)
            self.generator = IMNET_GEN()
        self.model = network(max_class, feature_extractor)

        # Each client gets its own copy of the backbone.  Passing the same
        # module object to every `network` would make the server and all
        # clients alias one set of backbone parameters unless `get_client`
        # copies internally (not visible from here).  The copies start with
        # identical weights, so the initial state is unchanged.
        self.clients = [
            get_client(network(max_class, copy.deepcopy(feature_extractor)), pair, self.method_name, self.args)
            for pair in self.train_data
        ]
        self.clients_selected = []

        # record information
        self.avg_fgt = []
        self.avg_train_acc = []
        self.avg_test_acc = []
        self.all_test_previous_acc = []
        self.all_test_acc = []

    def server_train(self):
        """Run all communication rounds of the current task.

        Each round trains the sampled clients, aggregates their models into
        the global model, evaluates it, prints the metrics and appends them
        to the ``avg_*`` / ``all_*`` history lists.
        """
        self.initialize_weights()  # step1: use the params of the last task to be the initial point
        self.model = self.model.to(self.device)
        if self.task_id != 0:
            for client in self.clients:
                client.update_model(self.model)
                client.last_valid_dim = self.max_class
                client.valid_dim = self.max_class
        last_round = self.args.rounds - 1
        for t in range(self.args.rounds):
            self.client_update()
            self.aggregate_model()
            if t == last_round:
                # Runs before the final evaluation, so the last metrics are
                # computed after `update_teacher` has touched the model.
                self.update_teacher()
            self.model = self.model.to(self.device)
            avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = self.evaluate_global_model()
            if t % 5 == 0:
                print(f'round{t + 1}')
            # print the model performance:
            # (1) the acc on the train dataset of the current task
            # (2) the acc on the test dataset of the current task
            # (3) the acc on the combination of the learned tasks
            # (4) the average acc over each learned tasks
            # (5) the list of the test acc on each task
            print(
                f'{100 * avg_train_acc:.2f}%,  {100 * avg_test_acc:.2f}%, {100 * all_test_previous_acc:.2f}%,'
                f'{100 * all_test_acc:.2f}%',
                avg_fgt)

            self.avg_fgt.append(avg_fgt)
            self.avg_train_acc.append(avg_train_acc)
            self.avg_test_acc.append(avg_test_acc)
            self.all_test_previous_acc.append(all_test_previous_acc)
            self.all_test_acc.append(all_test_acc)
        print('task finishes')

    def initialize_weights(self):
        """Load ``theta_reg`` into the global model unless this is the first task."""
        if not self.first_round:
            self.model.load_state_dict(self.theta_reg)  # set the model of the last task as the initial point

    def client_update(self):
        """Sample the participating clients and run their local training.

        ``args.participated_clients`` distinct indices are drawn from
        ``range(args.num_clients)``.  Every client (selected or not) first
        receives the current global model; the selected ones then call
        ``trainI()`` when no teacher is set and ``trainII(teacher)``
        otherwise.
        """
        # An int first argument samples from np.arange(num_clients).
        clients_index = list(
            np.random.choice(self.args.num_clients, self.args.participated_clients, replace=False))
        self.clients_selected = [self.clients[i] for i in clients_index]
        for client in self.clients:
            client.update_model(self.model)
        for client in self.clients_selected:
            if self.teacher is None:
                client.trainI()
            else:
                client.trainII(self.teacher)

    def update_teacher(self):
        """Train the teacher after a task, except after the last task.

        When ``task_id`` is not ``args.num_tasks - 1`` this stores the
        ``(teacher, fc_copy)`` tuple returned by ``train_gen`` (run on a deep
        copy of the global model) in ``self.teacher``, calls
        ``self.model.Incremental_learning(self.max_class)`` and sets every
        client's ``valid_dim`` to ``self.max_class``.
        """
        if self.task_id == (self.args.num_tasks - 1):
            return
        self.teacher = self.train_gen(copy.deepcopy(self.model), self.max_class, self.generator)
        self.model.Incremental_learning(self.max_class)
        for client in self.clients:
            client.valid_dim = self.max_class

    def aggregate_model(self):
        """Aggregate the selected clients' models into the global model.

        All selected clients are weighted equally.  Raises
        ``ZeroDivisionError`` if ``self.clients_selected`` is empty.
        """
        # aggregate the model after collecting the updates
        num_selected = len(self.clients_selected)
        client_fraction = [1 / num_selected] * num_selected
        params = aggregate_FDCL2(self.clients_selected, client_fraction, self.model, self.args.lr_global)

        self.model.load_state_dict(params)

    def evaluate_global_model(self):
        """Evaluate the global model through ``evaluate_model``.

        Returns:
            The tuple ``(avg_train_acc, avg_test_acc, avg_fgt,
            all_test_previous_acc, all_test_acc)``.
        """
        avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc = evaluate_model(
            self.train_data,
            self.test_data,
            self.model,
            self.previous_test,
            self.args,
        )
        return avg_train_acc, avg_test_acc, avg_fgt, all_test_previous_acc, all_test_acc

    def train_gen(self, model, valid_out_dim, generator):
        """Build a ``Teacher`` around ``model`` and ``generator`` and sample from it once.

        Args:
            model: Network used as the teacher's solver; it is moved to CUDA
                in place.
            valid_out_dim: Number of classes; the teacher's ``class_idx`` is
                ``np.arange(valid_out_dim)``.
            generator: Generator module, optimised with Adam (lr=0.001).

        Returns:
            A tuple ``(teacher, fc_copy)`` where ``fc_copy`` is a deep copy
            of ``model.fc``.
        """
        is_digit10 = self.args.dataset_list == 'DIGIT10'
        side = 32 if is_digit10 else 256
        dataset_size = (-1, 3, side, side)
        # NOTE(review): the device is hard-coded instead of using self.device;
        # kept because Teacher/the generator may assume CUDA internally --
        # confirm before changing.
        model.to('cuda')
        generator_optimizer = torch.optim.Adam(params=generator.parameters(), lr=0.001)
        teacher = Teacher(solver=model, generator=generator, gen_opt=generator_optimizer,
                          img_shape=dataset_size, iters=50,
                          deep_inv_params=[1e-3, 5e1, 1e-3, 1e3, 1],
                          class_idx=np.arange(valid_out_dim), train=True)
        # The sampled batch is discarded; presumably this first call is what
        # triggers generator training -- TODO confirm in Teacher.sample.
        teacher.sample(128 if is_digit10 else 16, return_scores=False)
        return teacher, copy.deepcopy(model.fc)