import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import time
import torch
import argparse
import random
import copy
import numpy as np
import prepare_dataset as pdata
import distributed as dist
from train_worker import pretrain_MAE, models_KD_agg, pretrain_CNN 
from evaluation import finetune_VIT, finetune_CNN
import util.misc as misc
import baseline_models as bm
import util.debug as debug


def get_argparser():
    """Build the command-line parser for distributed masked learning.

    Returns:
        argparse.ArgumentParser: parser for every option consumed by ``main``.
    """
    parser = argparse.ArgumentParser(description="Arguments of distributed masked learning")

    # Training schedule, network size and model shape.
    parser.add_argument('-rd', '--rounds', default=200, type=int,
                        help='The number of rounds for federated learning')
    parser.add_argument('-tt', '--training_times', default=3, type=int,
                        help='The bounded number of times of training on each worker')
    parser.add_argument('-sb', '--stale_bound', default=5, type=int,
                        help='The bounded number of train round gap between worker stored local model and visitor model')
    parser.add_argument('-nw', '--num_total_workers', default=100, type=int,
                        help='The total number of workers in the network')
    parser.add_argument('-nm', '--num_models', default=1, type=int,
                        help='The number of autoencoders being trained')
    parser.add_argument('-le', '--num_of_local_epochs', default=5, type=int,
                        help='The number of epochs in local training')
    parser.add_argument('-c', '--connectivity', default=0.15, type=float,
                        help='The connectivity of network')
    parser.add_argument('-hcp', '--highest_cp', default=5, type=int,
                        help='The highest computing power for workers')
    parser.add_argument('-bd', '--base_depth', default=1, type=int,
                        help='The depth of base transformer')

    # Output directory plus the persisted network/data artefacts.  The four
    # artefact files are loaded by main() when they exist and created otherwise,
    # so repeated runs share the same graph, data split, powers and start points.
    parser.add_argument('-sp', '--save_path', default='./checkpoint/', type=str, help='checkpoint save path')
    parser.add_argument('-gp', '--graph_path', default='./graph/network_G.adjlist', type=str, help='network graph path')
    parser.add_argument('-dsp', '--data_split_path', default='./graph/data_split.pkl', type=str, help='network data split path')
    parser.add_argument('-cpp', '--computing_power_path', default='./graph/computing_power.pkl', type=str, help='network computing powers of workers path')
    parser.add_argument('-swp', '--starting_workers_path', default='./graph/starting_workers.pkl', type=str, help='network starting workers path')

    # Dataset selection and how it is distributed over workers.
    parser.add_argument('-d', '--dataset', default=4, type=int,
                        help='The id of dataset to use: 0 - CIFAR10; 1 - CIFAR100; 2 - Food101; 3 - ImageNet; 4 - Mini-ImageNet; 5 - Road-Sign, 6 - Mini-INAT;')
    parser.add_argument('-ra', '--ratio', default=1, type=float,
                        help='The ratio of labelled images')
    parser.add_argument('-samp', '--sampling', default="iid", type=str,
                        help='The way of sampling, iid or dir')
    parser.add_argument('--alpha', default=1e-1, type=float,
                        help='The required parameter for dir sampling, which decides the statistical heterogeneity')
    parser.add_argument('-ri', '--record_interval', default=100, type=int,
                        help='The interval of saving checkpoint')

    # Algorithm variants.
    parser.add_argument('-m', '--mode', default=2, type=int,
                        help='The mode of next worker finding algorithm: 0 - the beginning algorithm; 1 - random; 2 - new one')
    parser.add_argument('-agg', '--agg', default=3, type=int,
                        help='The mode of aggregation: 0 - average weights; 1 - data volume weights; 2 - round weights; 3 - our weights')
    parser.add_argument('-bl', '--baseline', default=0, type=int,
                        help='The baselines: 0 - our algorithm; 1 - FedMAE, 2 - GossipMAE, 3 - DecenCNN')

    # Run phase and fine-tuning options.
    parser.add_argument('-p', '--phase', default="pretrain", type=str,
                        help='specify the codes to: pretrain, finetune')
    parser.add_argument('-ftd', '--ft_depth', default=5, type=int,
                        help='The depth of model in finetuning')
    parser.add_argument('--seed', default=0, type=int,
                        help='The seed used for reproducibility')
    return parser
    
def _load_or_create_data_split(args, train_dataset, workerIDlist, dataset, num_classes):
    """Return the per-worker unlabelled training subsets used in pre-training.

    The result is indexed by worker ID. A split saved at
    ``args.data_split_path`` is reused when that file exists. Otherwise a new
    split is drawn with ``pdata.get_varied_unsuper_datasets``, and its
    per-client data ids are saved to that path.

    Note: when a new split is drawn with ``args.sampling == 'iid'``, this
    overwrites ``args.alpha`` with 1e5. Presumably the large concentration
    makes the split close to iid.
    """
    if os.path.exists(args.data_split_path):
        return pdata.load_data_split(train_dataset, workerIDlist, dataset, args.data_split_path)
    if args.sampling == 'iid':
        args.alpha = 1e5
    unsupervised_datasets, client_data_ids = pdata.get_varied_unsuper_datasets(
        train_dataset, workerIDlist, dataset=dataset, num_classes=num_classes, alpha=args.alpha)
    pdata.save_data_split(client_data_ids, args.data_split_path)
    return unsupervised_datasets


def _select_labelled_indices(args, train_dataset, dataset):
    """Return the indices of the labelled training subset used for fine-tuning.

    Raises:
        ValueError: if ``args.ratio`` is not in (0, 1].
    """
    # Explicit check rather than ``assert`` so it also runs under ``python -O``.
    if not 0. < args.ratio <= 1.0:
        raise ValueError("divide ratio should be between 0 and 1")
    if args.ratio < 1.0:
        super_train_idxs, _ = pdata.divide_dataset(train_dataset, args.ratio, dataset)
        return super_train_idxs
    return np.arange(len(train_dataset))


def _load_or_create_network(args, workerIDlist):
    """Load the workers' computing powers and the network graph, creating and saving either one when missing.

    Returns:
        (cp_dict, network_G): computing power per worker ID, and the graph object.
    """
    if os.path.exists(args.computing_power_path):
        cp_dict = dist.load_workers_cp(workerIDlist, args.highest_cp, args.computing_power_path)
    else:
        cp_ratio = [(args.highest_cp - i) for i in range(args.highest_cp)]
        cp_dict = dist.init_workers_cp(workerIDlist, args.highest_cp, cp_ratio)
        dist.save_workers_cp(cp_dict, args.computing_power_path)
    if os.path.exists(args.graph_path):
        network_G = dist.load_network_graph(args.graph_path)
    else:
        network_G = dist.init_network_graph(workerIDlist, args.connectivity)
        dist.save_network_graph(network_G, args.graph_path)
    print("The number of edges in the network is: %s" % (network_G.number_of_edges()))
    return cp_dict, network_G


def _announce_strategies(args):
    """Print the chosen next-worker strategy and aggregation scheme."""
    mode_messages = {
        0: "Use the beginning strategy for next worker decision!",
        1: "Use the random strategy for next worker decision!",
    }
    print(mode_messages.get(args.mode, "Use our strategy for next worker decision!"))
    agg_messages = {
        0: "Use the average weights for model aggregation!",
        1: "Use the data volume weights for model aggregation!",
        2: "Use the round weights for model aggregation!",
    }
    print(agg_messages.get(args.agg, "Use staleness-noticed weights for model aggregation!"))


def _scaled_training_time(elapsed, computing_power, highest_cp):
    """Scale measured wall time by a worker's computing power.

    The divisor is 1 + (cp - 1) / (highest_cp - 1). Assuming cp lies in
    [1, highest_cp], it ranges from 1 to 2. With ``highest_cp <= 1`` there
    is no spread to scale by, so the time is returned unchanged. The original
    formula raised ZeroDivisionError in that case.
    """
    if highest_cp <= 1:
        return elapsed
    return elapsed / (1 + (computing_power - 1) / (highest_cp - 1))


def _average_worker_checkpoints(args, workerIDlist, save_path, ckpt_name='local.pth'):
    """Average, with equal weights, every worker whose ``worker/<id>/<ckpt_name>`` exists.

    Aggregation is delegated to ``dist.neighbours_avg_agg``, which writes to
    ``save_path``. That call's return value is passed through.
    """
    trained_workers = [
        wid for wid in workerIDlist
        if os.path.exists('%sworker/%s/%s' % (args.save_path, wid, ckpt_name))
    ]
    data_len_ratios = {wid: 1.0 / len(trained_workers) for wid in trained_workers}
    return dist.neighbours_avg_agg(args, trained_workers, data_len_ratios, save_path)


def _print_model_times(model_time, num_models):
    """Print each model's accumulated training time and the total."""
    total_used_time = 0
    for mid, used_time in model_time.items():
        print('Learning of Model %s completes in %.0fm %.0fs' % (mid, used_time // 60, used_time % 60))
        total_used_time += used_time
    print('Learning of %s Models totally consumes %.0fm %.0fs' % (num_models, total_used_time // 60, total_used_time % 60))


def _load_or_sample_starting_workers(args, workerIDlist):
    """Return the starting worker for each model.

    The saved list is reused when it exists; otherwise a random sample is
    drawn and saved.
    """
    if os.path.exists(args.starting_workers_path):
        return dist.load_starting_workers(workerIDlist, args.starting_workers_path)
    sampled_workers = random.sample(workerIDlist, args.num_models)
    dist.save_starting_workers(sampled_workers, args.starting_workers_path)
    return sampled_workers


def _run_walking_models(args, workerIDlist, unsupervised_datasets, network_G, cp_dict,
                        init_model, train_model, round_timer, show_status=False):
    """Decentralized training shared by baseline 0 (MAE) and the CNN baseline.

    Each round, every model is trained on its current worker and then handed
    to the worker picked by ``dist.next_worker``. Before training, a worker
    that has hosted models before merges them via the staleness-aware
    ``dist.get_agg_ratios`` / ``dist.new_local_models_agg``.

    Args:
        init_model: zero-arg callable returning the initial checkpoint.
        train_model: ``(worker_ID, worker_data, ckpt_path) -> ckpt``. It trains
            from and saves to ``ckpt_path``.
        round_timer: ``time.time()`` value from which the first round's
            overhead is measured.
        show_status: whether to call ``dist.show_cur_status`` after each round.

    Returns:
        (model_IDs, model_time): the model IDs, and the accumulated
        (computing-power-scaled) time per model.
    """
    model_infos = {}        # per-model state, maintained by dist.*
    local_model_infos = {}  # worker_ID -> {model_ID: snapshot of model info when it left}

    sampled_workers = _load_or_sample_starting_workers(args, workerIDlist)
    model_IDs = list(range(args.num_models))

    # All models start from the same parameters.
    ckpt = init_model()
    model_time = {}
    for model_ID in model_IDs:
        misc.save_on_master(ckpt, '%sglobal/model_%s.pth' % (args.save_path, model_ID))
        dist.init_model_info(model_infos, model_ID, sampled_workers, workerIDlist)
        model_time[model_ID] = 0

    print("Distributed Learning will train %s models for %s rounds" % (args.num_models, args.rounds))
    for current_round in range(1, args.rounds + 1):
        print("Distributed Learning Round %s" % current_round)

        # Group the models by the worker that trains them this round.
        training_workers = {}
        for model_ID in model_IDs:
            training_workers.setdefault(model_infos[model_ID]['training_worker'], []).append(model_ID)

        # Time spent outside training since the last reset is charged to every
        # model: backups and status output, plus setup time in round 1.
        time_elapsed = time.time() - round_timer
        for model_ID in model_IDs:
            model_time[model_ID] += time_elapsed
        round_timer = time.time()

        for worker_ID, received_model_IDs in training_workers.items():
            for model_ID in received_model_IDs:
                round_timer = time.time()

                local_ckpt_path = '%sworker/%s/local.pth' % (args.save_path, worker_ID)
                ckpt_path = '%sglobal/model_%s.pth' % (args.save_path, model_ID)

                # Staleness-aware aggregation with models this worker has hosted before.
                agg_model_IDS = [model_ID]
                if worker_ID in local_model_infos:
                    agg_model_IDS, agg_ratios = dist.get_agg_ratios(
                        args, current_round, model_infos, local_model_infos, worker_ID, model_ID, agg_model_IDS)
                    dist.new_local_models_agg(args, model_ID, worker_ID, agg_model_IDS, agg_ratios)

                ckpt = train_model(worker_ID, unsupervised_datasets[worker_ID], ckpt_path)
                misc.save_on_master(ckpt, local_ckpt_path)
                for mid in agg_model_IDS:
                    misc.save_on_master(ckpt, '%sworker/%s/local_%s.pth' % (args.save_path, worker_ID, mid))

                training_time_elapsed = _scaled_training_time(
                    time.time() - round_timer, cp_dict[worker_ID], args.highest_cp)
                model_time[model_ID] += training_time_elapsed

                print("Update Model %s information" % model_ID)
                dist.update_info(model_infos, model_ID, worker_ID, len(unsupervised_datasets[worker_ID]),
                                 ckpt['loss'], training_time_elapsed, current_round)

                print("Find next worker for Model %s" % model_ID)
                dist.next_worker(model_infos, model_ID, worker_ID, network_G, unsupervised_datasets,
                                 args, current_round, mode=args.mode)

                local_model_infos.setdefault(worker_ID, {})[model_ID] = copy.deepcopy(model_infos[model_ID])
                round_timer = time.time()

        # Back up an average of all trained workers at each interval and at the end.
        if (current_round % args.record_interval == 0) or (current_round == args.rounds):
            global_backup_path = '%sglobal/neighbour_%srd.pth' % (args.save_path, current_round)
            if not os.path.exists(global_backup_path):
                _average_worker_checkpoints(args, workerIDlist, global_backup_path)

        if show_status:
            dist.show_cur_status(model_infos, network_G)

    return model_IDs, model_time


def _run_federated(args, workerIDlist, unsupervised_datasets, client_depths, r_eps):
    """Federated baseline.

    Each round, ``args.num_models`` random clients train from the shared
    checkpoint. Their results are averaged, weighted by each client's data size.
    """
    print("The chosen baseline is Federated DeNAV")

    fed_ckpt_path = '%sglobal/fed_checkpoint.pth' % args.save_path
    # Seeds the final save too, so that rounds == 0 no longer hits an unbound name.
    new_ckpt = pretrain_MAE(args, None, None, client_depths, r_eps=r_eps, load_path="", save_path=fed_ckpt_path)

    print("Federated Learning will run for %s rounds" % args.rounds)
    for cur_rd in range(1, args.rounds + 1):
        print("#"*30)
        print("Federated Learning Round %s/%s" % (cur_rd, args.rounds))
        print("#"*30)
        randomClientIDs = random.sample(workerIDlist, args.num_models)

        # The server knows each selected client's share of the data.
        data_len_total = sum(len(unsupervised_datasets[cid]) for cid in randomClientIDs)
        data_len_ratios = {cid: len(unsupervised_datasets[cid]) / data_len_total for cid in randomClientIDs}

        ckpts = {}
        for clientID in randomClientIDs:
            ckpts[clientID] = pretrain_MAE(args, clientID, unsupervised_datasets[clientID], client_depths,
                                           r_eps=r_eps, load_path=fed_ckpt_path)

        new_ckpt = dist.ckpts_avg_agg(ckpts, data_len_ratios, load_path=fed_ckpt_path, save_path=fed_ckpt_path)

        # Use the configured interval, consistent with the other baselines (was a hard-coded 100).
        if cur_rd % args.record_interval == 0:
            misc.save_on_master(new_ckpt, '%sglobal/fed_checkpoint_%srd.pth' % (args.save_path, cur_rd))

    misc.save_on_master(new_ckpt, '%sglobal/neighbour.pth' % args.save_path)


def _gossip_neighbours(network_G, worker_ID, nodes_are_str):
    """Return worker IDs for a worker's graph neighbours, plus the worker itself.

    The neighbours come first. Worker ``workerID_k`` maps to graph node
    ``k - 1``. Nodes are strings when ``nodes_are_str`` is true, integers
    otherwise.
    """
    cur_node_idx = int(worker_ID.split('_')[-1]) - 1
    if nodes_are_str:
        cur_node_idx = str(cur_node_idx)
    neighbours = list(network_G.neighbors(cur_node_idx))
    neighbours.append(cur_node_idx)
    return [f'workerID_{int(node_idx) + 1}' for node_idx in neighbours]


def _run_gossip(args, workerIDlist, unsupervised_datasets, network_G, client_depths, r_eps):
    """Gossip-learning baseline: every worker keeps its own model.

    Each round, every worker merges with one worker that sent to it (weighted
    by trained data amount) and trains locally. It then offers its model to
    its graph neighbours.
    """
    print("The chosen baseline is Gossip Learning")

    model_infos = {}  # worker_ID -> {'trained_data_amount', 'training_worker'}

    ckpt = pretrain_MAE(args, None, None, client_depths, r_eps=r_eps, load_path="", save_path="")
    for worker_ID in workerIDlist:
        misc.save_on_master(ckpt, '%sworker/%s/local.pth' % (args.save_path, worker_ID))
        model_infos[worker_ID] = {'trained_data_amount': 0, 'training_worker': [worker_ID]}

    # The graph is not modified here, so its node type is checked once.
    nodes_are_str = isinstance(next(iter(network_G.nodes())), str)

    print("Gossip Learning will train %s models for %s rounds" % (args.num_total_workers, args.rounds))
    for current_round in range(1, args.rounds + 1):
        print("Gossip Learning Round %s" % current_round)

        # Invert "model -> workers it is offered to" into "worker -> models received".
        training_workers = {}
        for worker_ID in workerIDlist:
            for tw_ID in model_infos[worker_ID]['training_worker']:
                receivers = training_workers.setdefault(tw_ID, [])
                if worker_ID not in receivers:
                    receivers.append(worker_ID)

        is_backup_round = current_round % args.record_interval == 0

        for worker_ID, received_model_Ids in training_workers.items():
            local_ckpt_path = '%sworker/%s/local.pth' % (args.save_path, worker_ID)

            # Randomly select one received model to merge with.
            next_worker = random.sample(received_model_Ids, 1)[0]
            agg_model_IDS = [worker_ID, next_worker] if next_worker != worker_ID else [worker_ID]

            if current_round > 1:
                total_value = sum(model_infos[mid]['trained_data_amount'] for mid in agg_model_IDS)
                agg_ratios = {mid: model_infos[mid]['trained_data_amount'] / total_value for mid in agg_model_IDS}
                dist.neighbours_avg_agg(args, agg_model_IDS, agg_ratios, save_path=local_ckpt_path)

            ckpt = pretrain_MAE(args, worker_ID, unsupervised_datasets[worker_ID], client_depths,
                                r_eps=r_eps, load_path=local_ckpt_path, save_path=local_ckpt_path)

            print("Update Model %s information" % worker_ID)
            model_infos[worker_ID]['trained_data_amount'] += len(unsupervised_datasets[worker_ID])

            print("Find next worker for Model %s" % worker_ID)
            model_infos[worker_ID]['training_worker'] = _gossip_neighbours(network_G, worker_ID, nodes_are_str)

            if is_backup_round:
                # additionally back up local models at each interval
                backup_path = '%sworker/%s/local_%srd.pth' % (args.save_path, worker_ID, current_round)
                misc.save_on_master(ckpt, backup_path)
                print("Backup %s's checkpoint at %s rounds" % (worker_ID, current_round))

        # Aggregate only after every worker has trained and saved its backup.
        # Previously this ran inside the worker loop, so the backup held just the first worker.
        if current_round > 1 and is_backup_round:
            global_backup_path = '%sglobal/neighbour_%srd.pth' % (args.save_path, current_round)
            if not os.path.exists(global_backup_path):
                _average_worker_checkpoints(args, workerIDlist, global_backup_path,
                                            ckpt_name='local_%srd.pth' % current_round)

    _average_worker_checkpoints(args, workerIDlist, '%sglobal/neighbour.pth' % args.save_path)


def _run_finetune(args, dataset, super_train_idxs, mean, std, r_eps):
    """Fine-tune the aggregated pre-trained checkpoint ``<save_path>/global/neighbour.pth``.

    For transformer runs (baseline != 3), the pre-trained block is replicated
    ``args.ft_depth`` times into a deeper backbone before ``finetune_VIT``.
    The CNN baseline goes straight to ``finetune_CNN``.
    """
    ckpt_path = '%sglobal/neighbour.pth' % args.save_path

    if args.baseline == 3:
        # Transfer learning for the CNN baseline.
        finetune_CNN(args, super_train_idxs, mean, std, dataset, args.ratio, load_path=ckpt_path)
        return

    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    ckpt_avg = torch.load(ckpt_path, map_location='cpu')
    ckpts = [copy.deepcopy(ckpt_avg) for _ in range(args.ft_depth)]
    print("Load pre-trained weight from: %s" % ckpt_path)

    # Copy i occupies block indices [i * base_depth, (i + 1) * base_depth) of the cascade.
    cascade_indexs = {
        i: list(range(i * args.base_depth, (i + 1) * args.base_depth))
        for i in range(args.ft_depth)
    }

    ckpt_path_cascade = '%sglobal/model_cascade.pth' % args.save_path
    ckpt_path_final = '%sfinetune/model_final.pth' % args.save_path

    # Create a deeper model skeleton, then copy the pre-trained blocks into it.
    pretrain_MAE(args, None, None, [args.ft_depth, 1], r_eps=r_eps, load_path="", save_path=ckpt_path_cascade)
    dist.cascade_models(ckpts, cascade_indexs, load_path=ckpt_path_cascade)

    finetune_VIT(args, super_train_idxs, mean, std, dataset, args.ratio, args.ft_depth,
                 load_path=ckpt_path_cascade, save_path=ckpt_path_final)


def main(args):
    """Run pre-training or fine-tuning, as selected by the parsed CLI ``args``.

    Pre-training (``args.phase == "pretrain"``) dispatches on ``args.baseline``:
    0 - our decentralized algorithm, 1 - federated, 2 - gossip learning, and
    any other value - decentralized CNN. Any other phase fine-tunes the
    checkpoint ``<save_path>/global/neighbour.pth``.

    Raises:
        ValueError: on an unknown dataset id, or a labelled ratio outside (0, 1].
    """
    print("Run distributed masked autoencoder learning")

    since = time.time()
    since1 = time.time()  # start of the first round's overhead measurement

    # Seed every RNG in use; random.sample picks starting workers, federated clients and gossip partners.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Always create the output layout. These were previously created only when
    # save_path already existed, so a fresh save_path broke later global/ saves.
    for sub_dir in ('worker', 'global', 'finetune'):
        os.makedirs('%s%s/' % (args.save_path, sub_dir), exist_ok=True)

    dataset_names = ["cifar10", "cifar100", "food101", "imagenet", "mini", "road_sign", "inat"]
    if not 0 <= args.dataset < len(dataset_names):
        raise ValueError("Unknown dataset id %s; expected 0-%s" % (args.dataset, len(dataset_names) - 1))
    dataset = dataset_names[args.dataset]

    mean, std = pdata.find_dataset_mean_std(dataset)
    print("Chosen dataset is %s" % dataset)

    if args.baseline != 3:
        train_transform, val_transform = pdata.init_pretrain_transform(mean, std)
    else:
        train_transform, val_transform = bm.create_transform(224, mean, std)

    train_dataset, _, num_classes = pdata.init_dataset(train_transform, val_transform, dataset=dataset)

    # Worker/client IDs are workerID_1 .. workerID_N, each with its own checkpoint folder.
    workerIDlist = [f'workerID_{i+1}' for i in range(args.num_total_workers)]
    for wid in workerIDlist:
        os.makedirs('%sworker/%s/' % (args.save_path, wid), exist_ok=True)

    if args.phase == "pretrain":
        unsupervised_datasets = _load_or_create_data_split(args, train_dataset, workerIDlist, dataset, num_classes)
    else:
        super_train_idxs = _select_labelled_indices(args, train_dataset, dataset)

    client_depths = [args.base_depth, args.base_depth]
    r_eps = args.num_of_local_epochs

    torch.cuda.empty_cache()

    if args.phase == "pretrain":
        print("*"*20)
        cp_dict, network_G = _load_or_create_network(args, workerIDlist)
        _announce_strategies(args)
        print("*"*20)

        if args.baseline == 0:
            # DeNAV pre-training
            model_IDs, model_time = _run_walking_models(
                args, workerIDlist, unsupervised_datasets, network_G, cp_dict,
                init_model=lambda: pretrain_MAE(args, None, None, client_depths, r_eps=r_eps,
                                                load_path="", save_path=""),
                train_model=lambda wid, data, path: pretrain_MAE(args, wid, data, client_depths, r_eps=r_eps,
                                                                 load_path=path, save_path=path),
                round_timer=since1, show_status=True)

            # Merge the latest copies of the models into a single model.
            model_ratios = {mid: 1.0 / len(model_IDs) for mid in model_IDs}
            dist.models_avg_agg(args, model_IDs, model_ratios, '%sglobal/single.pth' % args.save_path)

            _average_worker_checkpoints(args, workerIDlist, '%sglobal/neighbour.pth' % args.save_path)
            _print_model_times(model_time, len(model_IDs))

        elif args.baseline == 1:
            _run_federated(args, workerIDlist, unsupervised_datasets, client_depths, r_eps)

        elif args.baseline == 2:
            _run_gossip(args, workerIDlist, unsupervised_datasets, network_G, client_depths, r_eps)

        else:
            print("The chosen baseline is Decentralized CNN")
            model_IDs, model_time = _run_walking_models(
                args, workerIDlist, unsupervised_datasets, network_G, cp_dict,
                init_model=lambda: pretrain_CNN(args, None, None, r_eps=r_eps, load_path="", save_path=""),
                train_model=lambda wid, data, path: pretrain_CNN(args, wid, data, r_eps=r_eps,
                                                                 load_path=path, save_path=path),
                round_timer=since1, show_status=False)

            _average_worker_checkpoints(args, workerIDlist, '%sglobal/neighbour.pth' % args.save_path)
            _print_model_times(model_time, len(model_IDs))

    else:
        _run_finetune(args, dataset, super_train_idxs, mean, std, r_eps)

    time_elapsed1 = time.time() - since
    print('Distributed Masked Learning completes in {:.0f}m {:.0f}s'.format(time_elapsed1 // 60, time_elapsed1 % 60))
    



if __name__ == '__main__':
    # Script entry point: parse the command line once and pass the namespace to main().
    cli_args = get_argparser().parse_args()
    main(cli_args)