from tqdm import tqdm
import pickle
from utils.datasavor import Savor
from server.get_server import get_server
from args import args_parser
from utils.data_manager import DataManager
from utils.toolkit import set_seed
import os
import shutil
from torch.utils.data import ConcatDataset



def main():
    """Run the task-by-task experiment and pickle the performance recorder.

    Parses command-line arguments, seeds via ``set_seed``, then for each task id
    in ``range(args.num_tasks)``:

    1. loads that task's data through ``DataManager``;
    2. for ``pFedDIL`` (except on the last task) also loads the next task's
       training units and concatenates them into ``future_data``;
    3. builds the method-specific server with ``get_server`` and trains it;
    4. snapshots the trained server weights and carries the state the chosen
       method needs into the next task;
    5. records the results with ``Savor.data_save``.

    After the last task the ``Savor`` recorder is pickled to ``args.savor_name``.
    """
    args = args_parser()
    set_seed(args.seed)

    is_pfeddil = args.method == 'pFedDIL'
    is_srfdil = args.method == 'SRFDIL'
    is_mfcl = args.method == 'MFCL'
    last_task = args.num_tasks - 1

    previous_test = []  # test sets of every task seen so far (current task included)
    data_savor = Savor()  # model performance recorder

    # State handed from one task to the next. Every slot is bound up front so
    # the get_server call never depends on a name that may not exist; slots a
    # method does not use simply stay None (exactly what was passed before).
    prev_server_model = None  # cloned server weights after the previous task (passed as theta_reg)
    previous_model = None  # pFedDIL only
    auxiliary_classifier = None  # pFedDIL only
    previous_discriminator = None  # SRFDIL only; None on task 0
    previous_data = None  # SRFDIL only; None on task 0
    teacher = None  # MFCL only; None on task 0

    if is_pfeddil:
        # Starts as one empty list per client; each finished task appends the
        # server snapshot below. NOTE(review): confirm get_server expects this mixed layout.
        previous_model = [[] for _ in range(args.num_clients)]
        auxiliary_classifier = [[] for _ in range(args.num_clients)]

    for task_id in tqdm(range(args.num_tasks), desc='UPDATE TASK'):
        # Split train units, test set, and class count for the current task.
        data_manager = DataManager(args, task_id)
        train_data = data_manager.train_units
        test_data = data_manager.data_test
        max_class = data_manager.num_classes

        # pFedDIL additionally receives the next task's training data; there is
        # no next task after the last one, so it stays None there. Reset every
        # iteration so no stale dataset from an earlier task can leak through.
        future_data = None
        if is_pfeddil and task_id != last_task:
            # NOTE(review): the next iteration builds DataManager(args, task_id + 1)
            # again; reusing this one may be possible if DataManager is deterministic.
            future_manager = DataManager(args, task_id + 1)
            future_data = ConcatDataset(future_manager.train_units)

        previous_test.append(test_data)

        # Global learning rate: 1.0 on task 0, 0.5 on task 1, then 1 again from
        # task 2 onward. NOTE(review): non-monotonic schedule — confirm intended.
        args.lr_global = 1 / (task_id + 1) if task_id <= 1 else 1

        server_generator = get_server(args, train_data,
                                      test_data, max_class, method_name=args.method, task_id=task_id,
                                      previous_test=previous_test,
                                      theta_reg=prev_server_model,
                                      previous_model=previous_model,
                                      future_data=future_data,
                                      auxiliary_classifier=auxiliary_classifier,
                                      previous_discriminator=previous_discriminator,
                                      previous_data=previous_data,
                                      teacher=teacher)

        print(f'server is training on task {task_id} ')
        server_generator.server_train()
        # Detached clones so the snapshot is independent of further training;
        # it becomes the initial point (theta_reg) for the next task.
        prev_server_model = {k: v.clone().detach() for k, v in server_generator.model.state_dict().items()}

        # Method-specific state needed by the next task.
        if is_pfeddil:
            previous_model.append(prev_server_model)
            auxiliary_classifier = server_generator.auxiliary_classifier
        elif is_srfdil:
            previous_discriminator = server_generator.discriminator
            previous_data = server_generator.train_data
        elif is_mfcl:
            teacher = server_generator.teacher

        # Update the performance recorder.
        data_savor.data_save(server_generator)

    # Persist the recorder.
    with open(args.savor_name, 'wb') as f:
        pickle.dump(data_savor, f)

if __name__ == '__main__':
    # Script entry point: launch the experiment only when executed directly,
    # never as a side effect of importing this module.
    main()
