import networkx as nx
import util.misc as misc
import torch
import random
import math
import copy
import os
import matplotlib.pyplot as plt
import pickle


def init_network_graph(workerIDlist, connectivity):
    """Sample a connected Erdos-Renyi graph with one node per worker.

    Nodes are the integers ``0 .. len(workerIDlist) - 1``; the rest of this
    module maps node ``i`` to the worker ID ``f'workerID_{i+1}'``.

    Args:
        workerIDlist: the workers; only its length is used.
        connectivity: edge probability passed to ``nx.erdos_renyi_graph``.

    Returns:
        A connected ``networkx.Graph`` (the empty / single-node graph is
        returned as-is).

    Raises:
        ValueError: if there is more than one worker and ``connectivity <= 0``;
            no edge could ever be sampled, so rejection sampling would never end.
    """
    num_nodes = len(workerIDlist)
    if num_nodes > 1 and connectivity <= 0:
        raise ValueError(
            "connectivity must be > 0 to sample a connected graph of %d workers" % num_nodes)
    G = nx.erdos_renyi_graph(num_nodes, connectivity)
    if num_nodes <= 1:
        # nx.is_connected() raises on the null graph; 0/1 nodes need no resampling
        return G
    # Rejection-sample until every worker is reachable from every other one.
    while not nx.is_connected(G):
        G = nx.erdos_renyi_graph(num_nodes, connectivity)
    return G

def save_network_graph(graph, path):
    """Write ``graph`` to ``path`` in networkx adjacency-list format.

    The file is opened in binary mode (``nx.write_adjlist`` writes encoded
    bytes to a file handle). It is closed on exit, so the data is flushed
    before the function returns.
    """
    with open(path, 'wb') as fh:
        nx.write_adjlist(graph, fh)

def load_network_graph(path):
    """Read a graph written by ``save_network_graph`` from ``path``.

    ``nx.read_adjlist`` returns node labels as ``str`` by default. This is why
    other helpers in this module check ``isinstance(node, str)``.
    The file handle is closed as soon as the graph has been parsed.
    """
    with open(path, 'rb') as fh:
        network_G = nx.read_adjlist(fh)
    print("Succesefully loads network graph from %s" % path)
    return network_G

def save_starting_workers(sampled_workers, save_path):
    """Pickle ``sampled_workers`` (the initial worker of each model) to ``save_path``."""
    payload = sampled_workers
    with open(save_path, mode='wb') as out_file:
        pickle.dump(payload, file=out_file)
        message = 'Selected workers list saved successfully to %s' % save_path
        print(message)

def load_starting_workers(workerIDlist, load_path):
    """Load the pickled starting workers from ``load_path`` and validate them.

    Every loaded entry must be a member of ``workerIDlist``; otherwise an
    AssertionError is raised.

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the file.
    Only load files that this code base wrote itself.
    """
    with open(load_path, 'rb') as in_file:
        loaded = pickle.load(in_file)

    assert all(worker in workerIDlist for worker in loaded), "The loaded workers should be in the range of the specified worker list"

    return loaded

def init_workers_cp(workerIDlist, highest_cp, cp_ratio):
    """Assign a computing-power level ``1 .. highest_cp`` to every worker.

    Workers are taken in ``workerIDlist`` order and split into consecutive
    groups. Level ``k`` receives
    ``int(len(workerIDlist) * cp_ratio[k-1] / sum(cp_ratio))`` workers.
    The highest level takes all the remaining workers, so rounding down never
    leaves a worker without a level.

    Returns:
        dict ``worker_ID -> level``.
    """
    num_of_workers = len(workerIDlist)
    assert highest_cp == len(cp_ratio)
    ratio_sum = sum(cp_ratio)
    cp_dict = {}
    start = 0
    for level, ratio in enumerate(cp_ratio, start=1):
        group_size = int(num_of_workers * ratio / ratio_sum)
        # the last level swallows whatever rounding left over
        stop = num_of_workers if level == highest_cp else start + group_size
        for pos in range(start, stop):
            cp_dict[workerIDlist[pos]] = level
        start += group_size
    return cp_dict

def save_workers_cp(cp_dict, save_path):
    """Pickle the ``worker_ID -> computing power`` mapping to ``save_path``."""
    payload = cp_dict
    with open(save_path, mode='wb') as out_file:
        pickle.dump(payload, file=out_file)
        message = 'Computing power of workers saved successfully to %s' % save_path
        print(message)

def load_workers_cp(workerIDlist, highest_cp, load_path):
    """Load a pickled ``worker_ID -> computing power`` dict and sanity-check it.

    Asserts that the dict has one entry per worker in ``workerIDlist`` and that
    its largest level (never counted below 0) equals ``highest_cp``.
    NOTE(review): unpickling trusts the file; only load files this code wrote.
    """
    with open(load_path, 'rb') as in_file:
        cp_dict = pickle.load(in_file)

    assert len(cp_dict) == len(workerIDlist), "The length of loaded computing power dict should match with the length of workers"

    # 0 is the floor, matching a running maximum that starts at 0
    strongest = max([0, *cp_dict.values()])
    assert strongest == highest_cp, "The highest computing power in the loaded dict should match with the specificed value"

    print("Succesefully loads workers' computing power dict from %s" % load_path)

    return cp_dict

# initialize the state log for each model being pre-trained
def init_model_info(model_infos, model_ID, sampled_workers, workerIDlist):
    """Create a fresh state log for ``model_ID`` in ``model_infos`` (in place).

    The model starts on ``sampled_workers[model_ID]`` with no data trained and
    no worker explored. The per-worker dicts start at 0 visits, infinite
    loss / train time (meaning "never trained there"), and last visit round 0.
    """
    model_infos[model_ID] = entry = {'this_model_ID': model_ID, 'explored': []}
    entry['training_worker'] = sampled_workers[model_ID]
    entry.update({
        'trained_data_amount': 0,
        'prev_loss': -1.0,
        'num_passed': 0,
        'explored_dict': dict.fromkeys(workerIDlist, 0),
        'explored_loss': dict.fromkeys(workerIDlist, float('inf')),
        'explored_train_time': dict.fromkeys(workerIDlist, float('inf')),
        'last_visit_round': dict.fromkeys(workerIDlist, 0),
    })


# compute the model-aggregation ratios from the amount of data each model has been trained on
def get_data_len_ratios(model_infos, model_IDs):
    """Return ``{model_ID: share}`` where share = prev_data_amount / total.

    The total is summed over every entry of ``model_IDs``, so a repeated ID
    counts more than once. Raises ZeroDivisionError when the amounts sum to 0
    and ``model_IDs`` is non-empty.

    NOTE(review): ``init_model_info`` in this file initialises
    ``trained_data_amount`` but not ``prev_data_amount``. Presumably the latter
    is set elsewhere; confirm.
    """
    amounts = [(mid, model_infos[mid]['prev_data_amount']) for mid in model_IDs]
    total = sum(amount for _, amount in amounts)
    per_model = dict(amounts)
    return {mid: amount / total for mid, amount in per_model.items()}

# check whether aggregation is valid (every model has trained on data and visited every worker more than min_visit_times times)
def check_agg_valid(model_infos, min_visit_times):
    """Return True when every model is ready for aggregation.

    A model is ready when it has trained on a non-zero amount of data and has
    visited every worker strictly more than ``min_visit_times`` times.
    """
    def _ready(info):
        if info['trained_data_amount'] == 0:
            return False
        return not any(times <= min_visit_times for times in info['explored_dict'].values())

    return all(_ready(info) for info in model_infos.values())

# aggregate all local models on a client into one model
def local_models_agg(args, worker_ID, model_IDs, agg_ratios, save_path=None):
    """Weighted-average several checkpoints on one worker into a single model.

    Args:
        args: namespace with ``save_path``, used verbatim as a path prefix.
        worker_ID: worker whose ``worker/<worker_ID>/local.pth`` is involved.
        model_IDs: checkpoints to combine. ``-1`` denotes the worker's local
            checkpoint; any other ID denotes ``global/model_<id>.pth``.
        agg_ratios: weight per ID; each parameter becomes ``sum(ratio * param)``.
        save_path: if given, the result is written only there. Otherwise it
            overwrites the local checkpoint and every participating global one.

    Raises:
        ValueError: if ``model_IDs`` is empty. This previously surfaced as an
            UnboundLocalError.
    """
    if not model_IDs:
        raise ValueError("local_models_agg needs at least one model ID to aggregate")
    print("Start model aggregation on worker %s" % worker_ID)
    local_ckpt_path = '%sworker/%s/local.pth' % (args.save_path, worker_ID)

    def _ckpt_path(model_ID):
        # -1 is the sentinel for this worker's own local checkpoint
        if model_ID == -1:
            return local_ckpt_path
        return '%sglobal/model_%s.pth' % (args.save_path, model_ID)

    ckpts = {model_ID: torch.load(_ckpt_path(model_ID), map_location='cpu') for model_ID in model_IDs}
    # Deep-copy the first checkpoint as the output template. Its 'model' dict is
    # overwritten below while the input checkpoints are still being read.
    res_ckpt = copy.deepcopy(next(iter(ckpts.values())))
    output_model = res_ckpt['model']
    for param_name in output_model:
        params = [ckpts[model_ID]['model'][param_name] * agg_ratios[model_ID] for model_ID in model_IDs]
        output_model[param_name] = torch.sum(torch.stack(params), dim=0)
    if save_path:
        misc.save_on_master(res_ckpt, save_path)
    else:
        # Write the result back to every participant. The local checkpoint is
        # written once even when -1 appears in model_IDs.
        misc.save_on_master(res_ckpt, local_ckpt_path)
        for model_ID in model_IDs:
            if model_ID != -1:
                misc.save_on_master(res_ckpt, _ckpt_path(model_ID))
    print("Model aggregation completes")

######### Staleness-aware model aggregation #########
def get_agg_ratios(args, current_round, model_infos, local_model_infos, worker_ID, model_ID, agg_model_IDS):
    """Staleness-aware aggregation weights for a model visiting ``worker_ID``.

    The visiting model ``model_ID`` is treated as being at round
    ``current_round - 1``. Every other model cached on the worker
    (``local_model_infos[worker_ID]``) joins the aggregation if its last visit
    to this worker is at most ``args.stale_bound`` rounds older. Those model
    IDs are appended **in place** to ``agg_model_IDS``, which is also returned.

    Weighting by ``args.agg``:
        0          -> uniform
        1          -> proportional to ``trained_data_amount``
        2          -> proportional to epochs (round * ``args.num_of_local_epochs``)
        otherwise  -> proportional to (epoch share * data share)

    If a weight family sums to zero, that family falls back to uniform weights
    instead of raising ZeroDivisionError. This happens in round 1, where every
    epoch count is 0, or when no model has trained on any data yet. Only the
    family the chosen mode needs is computed.

    Returns:
        ``(agg_model_IDS, agg_ratios)`` with a ratio for each ID in ``agg_model_IDS``.
    """
    visitor_round = current_round - 1
    epochs = {model_ID: visitor_round * args.num_of_local_epochs}
    data = {model_ID: model_infos[model_ID]['trained_data_amount']}

    for local_mid, local_minfo in local_model_infos[worker_ID].items():
        local_round = local_minfo['last_visit_round'][worker_ID]
        if model_ID != local_mid and visitor_round - local_round <= args.stale_bound:
            agg_model_IDS.append(local_mid)
            epochs[local_mid] = local_round * args.num_of_local_epochs
            data[local_mid] = local_minfo['trained_data_amount']

    def _normalize(weights):
        # each aggregated model's share of the total; uniform when the total is 0
        total = sum(weights[mid] for mid in agg_model_IDS)
        if total == 0:
            return {mid: 1 / len(agg_model_IDS) for mid in agg_model_IDS}
        return {mid: weights[mid] / total for mid in agg_model_IDS}

    if args.agg == 0:
        agg_ratios = {mid: 1 / len(agg_model_IDS) for mid in agg_model_IDS}
    elif args.agg == 1:
        agg_ratios = _normalize(data)
    elif args.agg == 2:
        agg_ratios = _normalize(epochs)
    else:
        epoch_share = _normalize(epochs)
        data_share = _normalize(data)
        agg_ratios = _normalize({mid: epoch_share[mid] * data_share[mid] for mid in agg_model_IDS})
    return agg_model_IDS, agg_ratios

# new code for local model aggregation
def new_local_models_agg(args, model_ID, worker_ID, agg_IDs, agg_ratios, save_path=None):
    """Aggregate the visiting model with this worker's cached local models.

    Checkpoint locations (``args.save_path`` is used as a raw prefix):
      * ``model_ID`` (the visitor)  -> ``global/model_<id>.pth``
      * every other ID in agg_IDs   -> ``worker/<worker_ID>/local_<id>.pth``
    ``worker/<worker_ID>/local.pth`` is loaded as the template. Its 'model'
    tensors are replaced by ``sum(agg_ratios[id] * param)`` over ``agg_IDs``.

    If ``save_path`` is given, the result is written only there. Otherwise it
    overwrites ``local.pth`` and every checkpoint listed above.
    """
    print("Start model aggregation on worker %s" % worker_ID)
    prefix = args.save_path
    local_ckpt_path = '%sworker/%s/local.pth' % (prefix, worker_ID)

    # (id, path) pairs in agg_IDs order, reused for loading and for writing back
    sources = [
        (mid, '%sglobal/model_%s.pth' % (prefix, mid) if mid == model_ID
         else '%sworker/%s/local_%s.pth' % (prefix, worker_ID, mid))
        for mid in agg_IDs
    ]
    ckpts = {}
    for mid, path in sources:
        ckpts[mid] = torch.load(path, map_location='cpu')

    merged = torch.load(local_ckpt_path, map_location='cpu')
    state = merged['model']
    for name in state:
        terms = [ckpts[mid]['model'][name] * agg_ratios[mid] for mid in agg_IDs]
        state[name] = torch.sum(torch.stack(terms), dim=0)
    merged['model'] = state

    if save_path:
        misc.save_on_master(merged, save_path)
    else:
        misc.save_on_master(merged, local_ckpt_path)
        for _, path in sources:
            misc.save_on_master(merged, path)
    print("Model aggregation completes")
    
# aggregate the local model on a client with local models of all neighbors into a model
def neighbours_avg_agg(args, worker_IDs, data_len_ratios, save_path=None):
    """Weighted average of the local checkpoints of ``worker_IDs``.

    Reads ``<args.save_path>worker/<id>/local.pth`` for every ID. Each
    parameter is combined as ``sum(data_len_ratios[id] * param)``. The result is
    built on the first worker's checkpoint, so its non-'model' entries are
    kept.

    Each checkpoint is loaded once, instead of once per parameter tensor. Only
    the 'model' state dicts of the inputs are held in memory at the same time.

    Args:
        save_path: optional destination for the aggregated checkpoint.

    Returns:
        The aggregated checkpoint.
    """
    print("Start model aggregation for models on workers %s" % (str(worker_IDs)))
    paths = {worker_ID: '%sworker/%s/local.pth' % (args.save_path, worker_ID) for worker_ID in worker_IDs}
    res_ckpt = torch.load(paths[worker_IDs[0]], map_location='cpu')
    models = {worker_ID: torch.load(path, map_location='cpu')['model'] for worker_ID, path in paths.items()}
    output_model = res_ckpt['model']
    for param_name in output_model:
        params = [models[worker_ID][param_name] * data_len_ratios[worker_ID] for worker_ID in worker_IDs]
        output_model[param_name] = torch.sum(torch.stack(params), dim=0)
    if save_path:
        misc.save_on_master(res_ckpt, save_path)
    print("Model aggregation completes")
    return res_ckpt

# aggregate the latest pre-training models into one model
def models_avg_agg(args, model_IDs, data_len_ratios, save_path=None):
    """Weighted average of the global checkpoints of ``model_IDs``.

    Reads ``<args.save_path>global/model_<id>.pth`` for every ID. Each parameter
    is combined as ``sum(data_len_ratios[id] * param)``. The result is built on
    the first model's checkpoint, so its non-'model' entries are kept.

    Each checkpoint is loaded once, instead of once per parameter tensor. Only
    the 'model' state dicts of the inputs are held in memory at the same time.

    Args:
        save_path: optional destination for the aggregated checkpoint.

    Returns:
        The aggregated checkpoint.
    """
    print("Start model aggregation for models %s" % (str(model_IDs)))
    paths = {mid: '%sglobal/model_%s.pth' % (args.save_path, mid) for mid in model_IDs}
    res_ckpt = torch.load(paths[model_IDs[0]], map_location='cpu')
    models = {mid: torch.load(path, map_location='cpu')['model'] for mid, path in paths.items()}
    output_model = res_ckpt['model']
    for param_name in output_model:
        params = [models[mid][param_name] * data_len_ratios[mid] for mid in model_IDs]
        output_model[param_name] = torch.sum(torch.stack(params), dim=0)
    if save_path:
        misc.save_on_master(res_ckpt, save_path)
    print("Model aggregation completes")
    return res_ckpt

# model aggregation for the given list of checkpoints
def ckpts_avg_agg(ckpts, data_len_ratios, load_path, save_path=None):
    """Weighted-average checkpoints that are already in memory.

    Args:
        ckpts: mapping ``client_id -> checkpoint`` (dict with a 'model' entry).
        data_len_ratios: mapping ``client_id -> weight``.
        load_path: template checkpoint. Its 'model' tensors are replaced by the
            weighted sums; other entries are kept.
        save_path: optional destination for the result.

    Returns:
        The aggregated checkpoint.
    """
    print("Start model aggregation for ckpts")
    template = torch.load(load_path, map_location='cpu')
    state = template['model']
    for key in state:
        stacked = torch.stack([client_ckpt['model'][key] * data_len_ratios[client_id]
                               for client_id, client_ckpt in ckpts.items()])
        state[key] = stacked.sum(dim=0)
    template['model'] = state
    if save_path:
        misc.save_on_master(template, save_path)
    print("Model aggregation completes")
    return template

# Update the state-log after local training
def update_info(model_infos, model_ID, worker_ID, data_amount, loss, train_time, current_round):
    """Record one finished local training of ``model_ID`` on ``worker_ID`` (in place).

    Adds ``data_amount`` to the model's trained data and bumps its pass count
    and its visit count for this worker. Marks the worker as explored, and
    stores this worker's latest loss, train time and visit round.
    """
    print("in_update_info")
    info = model_infos[model_ID]
    info['trained_data_amount'] += data_amount
    info['prev_loss'] = loss
    info['num_passed'] += 1
    explored = info['explored']
    if worker_ID not in explored:
        explored.append(worker_ID)
    info['explored_dict'][worker_ID] += 1
    latest = (('explored_loss', loss), ('explored_train_time', train_time), ('last_visit_round', current_round))
    for field, value in latest:
        info[field][worker_ID] = value

def get_model_scores(model_infos):
    """Score models with a softmax over "data share x exploration share".

    For model i with trained data amount t_i and explore points p_i
    (``get_explore_points`` of its ``explored_dict``), and with
    T = sum_j t_j and P = sum_j p_j:

        x_i = (t_i / T) * (p_i / P)
        s_i = exp(x_i) / sum_j exp(x_j)

    The returned scores sum to 1. If T (or P) is 0, e.g. before any model has
    trained, that share is taken as 0 for every model. The scores then come
    out uniform instead of raising ZeroDivisionError.

    Returns:
        dict ``model_ID -> score``.
    """
    data_amounts = {mid: info['trained_data_amount'] for mid, info in model_infos.items()}
    # explore points computed once per model (they were computed twice before)
    explore_points = {mid: get_explore_points(info['explored_dict']) for mid, info in model_infos.items()}
    total_data = sum(data_amounts.values())
    total_points = sum(explore_points.values())

    model_scores = {}
    for mid in model_infos:
        data_share = data_amounts[mid] / total_data if total_data else 0.0
        explore_share = explore_points[mid] / total_points if total_points else 0.0
        model_scores[mid] = math.exp(data_share * explore_share)
    total_model_score = sum(model_scores.values())
    return {mid: score / total_model_score for mid, score in model_scores.items()}


def get_explore_points(explored_dict):
    """Return the explore points of a model from its per-worker visit counts.

    The score adds, level by level, the fraction of workers visited more than
    ``k`` times (k = 0, 1, 2, ...). It stops after the first level at which not
    every worker qualifies. For example, visits ``{a: 2, b: 1}`` give
    ``1.0 + 0.5 = 1.5``. ``explored_dict`` is not modified. An empty dict
    raises ZeroDivisionError.
    """
    n_workers = len(explored_dict)
    points = 0.0
    level = 0
    while True:
        reached = sum(1 for visits in explored_dict.values() if visits > level)
        points += float(reached / n_workers)
        if reached != n_workers:
            return points
        level += 1

def get_node_value(explored_list, cur_node_idx, graph, datasets):
    """Rank every other node of ``graph`` as a candidate to explore next.

    A node at hop distance d from ``cur_node_idx``, whose worker holds t
    samples, scores ``t ** (1 / d)`` if its worker is not in ``explored_list``
    and 1 otherwise. Node ``i`` maps to worker ``f'workerID_{i+1}'``; string
    labels (as produced by ``nx.read_adjlist``) are converted with ``int``.

    Returns:
        list of ``(node_idx, value)`` sorted by value, highest first
        (an empty list for an empty graph).
    """
    nodes = list(graph.nodes())
    if not nodes:
        return []
    if isinstance(nodes[0], str):
        cur_node_idx = str(cur_node_idx)
    # Only distances from the current node are needed, so one BFS is enough
    # rather than an all-pairs computation (one BFS per node).
    dist = nx.single_source_shortest_path_length(graph, cur_node_idx)

    node_values = []
    for node_idx in nodes:
        if node_idx == cur_node_idx:
            continue
        if isinstance(node_idx, str):
            worker_ID = f'workerID_{int(node_idx)+1}'
        else:
            worker_ID = f'workerID_{node_idx+1}'
        d = dist[node_idx]
        t = len(datasets[worker_ID])
        i = 0 if (worker_ID in explored_list) else 1
        node_values.append((node_idx, (t ** i) ** (1/d)))

    node_values.sort(key=lambda nv: nv[1], reverse=True)
    return node_values

######### Client selection score formulation ######### 
def get_new_node_value(model_info, cur_node_idx, graph, datasets, args, current_round):
    """Score the current node and its direct neighbours as the next worker.

    Candidates are ``graph.neighbors(cur_node_idx)`` followed by the current
    node itself; node ``i`` is worker ``f'workerID_{i+1}'``. For a candidate w
    that this model has trained on fewer than ``args.training_times`` times:

        A (overlook control)  = (current_round - last_visit_round[w]) / args.rounds,
                                or 1 if the model never visited w
        B (data efficiency)   = len(datasets[w]) / largest dataset among candidates,
                                or 0 if every candidate dataset is empty
        C (device efficiency) = fastest recorded train time among candidates /
                                w's recorded train time, or 1 if never trained on w
        value = B * (A + C) + 1

    Candidates already trained on ``args.training_times`` times get value 1.

    Returns:
        list of ``(worker_ID, value)`` sorted by value, highest first. Ties keep
        candidate order.
    """
    candidates = list(graph.neighbors(cur_node_idx))
    candidates.append(cur_node_idx)
    candidate_IDs = [
        f'workerID_{int(idx)+1}' if isinstance(idx, str) else f'workerID_{idx+1}'
        for idx in candidates
    ]

    last_visit_round = model_info['last_visit_round']
    explored_dict = model_info['explored_dict']
    explored_train_time = model_info['explored_train_time']
    max_data_amount = max((len(datasets[wid]) for wid in candidate_IDs), default=0)
    min_training_time = min((explored_train_time[wid] for wid in candidate_IDs), default=float('inf'))

    node_values = []
    for node_id in candidate_IDs:
        if explored_dict[node_id] < args.training_times:
            assert last_visit_round[node_id] <= current_round, "error finding L=%s > Rc=%s" % (last_visit_round[node_id], current_round)
            # Overlook control factor: favour workers this model has not seen recently
            if last_visit_round[node_id] > 0:
                A = (current_round - last_visit_round[node_id]) / args.rounds
            else:
                A = 1
            # Data efficiency; guard against all candidate datasets being empty
            B = len(datasets[node_id]) / max_data_amount if max_data_amount > 0 else 0
            # Device efficiency: only defined once the model has trained on the worker
            if explored_train_time[node_id] != float('inf'):
                C = min_training_time / explored_train_time[node_id]
            else:
                C = 1
            node_value = B * (A + C) + 1
        else:
            node_value = 1
        node_values.append((node_id, node_value))
    node_values.sort(key=lambda nv: nv[1], reverse=True)
    return node_values

######### Next client selection ######### 
def next_worker(model_infos, model_ID, worker_ID, graph, datasets, args, current_round, mode=0, wait=False):
    """Choose the worker that ``model_ID`` moves to next.

    The choice is stored in ``model_infos[model_ID]['training_worker']``;
    nothing is returned. Worker ``'workerID_k'`` is graph node ``k-1``. String
    node labels (as loaded by ``nx.read_adjlist``) are supported.

    Modes:
        0: Walk towards the highest-valued node from ``get_node_value`` along a
           shortest path. Hops whose worker already reached
           ``args.training_times`` visits are skipped while later hops remain.
           Once every worker has been explored, the least-visited ones are
           removed from the model's ``explored`` list first.
        1: Pick uniformly at random among the neighbours and the current node.
        other: Take the best candidate from ``get_new_node_value``. If several
           candidates tie at a value <= 1 (nothing worth visiting nearby), search
           rings of increasing hop distance for a better target and take the
           first hop towards it. If ``wait`` is truthy, stay on ``worker_ID``
           when another model currently trains on the chosen worker (this check
           applies to this mode only).
    """
    def _to_worker_id(node_idx):
        # graph node (int, or str label) -> worker ID
        return f'workerID_{int(node_idx)+1}'

    cur_node_idx = int(worker_ID.split('_')[-1])-1
    if isinstance(list(graph.nodes())[0], str):
        cur_node_idx = str(cur_node_idx)

    if mode == 0:
        explored_list = model_infos[model_ID]['explored']
        if len(explored_list) == len(list(graph.nodes())):
            # Every worker has been explored: forget the least-visited workers
            # so that get_node_value treats them as targets again.
            explored_dict = model_infos[model_ID]['explored_dict']
            smallest_times = min(explored_dict.values())
            for wid, times in explored_dict.items():
                if times == smallest_times:
                    explored_list.remove(wid)
            model_infos[model_ID]['explored'] = explored_list
        node_values = get_node_value(explored_list, cur_node_idx, graph, datasets)
        target_node_idx = node_values[0][0]
        path = nx.single_source_shortest_path(graph, cur_node_idx)[target_node_idx]
        visit_counts = model_infos[model_ID]['explored_dict']
        # Skip hops whose worker has already been trained on args.training_times
        # times, but never step beyond the target. The hop's worker ID must be
        # built before the lookup; the local ``next_worker`` is not assigned yet
        # here (reading it raised UnboundLocalError).
        i = 1
        next_worker_idx = path[i]
        while visit_counts[_to_worker_id(next_worker_idx)] >= args.training_times and i + 1 < len(path):
            i += 1
            next_worker_idx = path[i]
        next_worker = _to_worker_id(next_worker_idx)
    elif mode == 1:
        # randomly select a neighbour (or stay) as the next worker
        candidates = list(graph.neighbors(cur_node_idx))
        candidates.append(cur_node_idx)
        next_worker = _to_worker_id(random.sample(candidates, 1)[0])
    else:
        node_values = get_new_node_value(model_infos[model_ID], cur_node_idx, graph, datasets, args, current_round)
        max_node_value = node_values[0][1]
        node_candidates = [wid for wid, value in node_values if value == max_node_value]
        if len(node_candidates) > 1 and max_node_value <= 1:
            dist = nx.single_source_shortest_path_length(graph, cur_node_idx)
            d = 1
            neigh_nvs_dicts = {}
            new_max_node_value = max_node_value
            search_fail = False
            while new_max_node_value <= max_node_value:
                # widen the ring until some node there sees a better (non-self) target
                ring = [node_idx for node_idx in graph.nodes() if dist[node_idx] == d]
                if len(ring) == 0:
                    next_worker = random.sample(node_candidates, 1)[0]
                    search_fail = True
                    print("Unable to find neighbours at distance %s so that stop searching and randomly return a good neighbour" % d)
                    break
                neigh_nvs_dict = {}
                for neigh_idx in ring:
                    neigh_node_values = get_new_node_value(model_infos[model_ID], neigh_idx, graph, datasets, args, current_round)
                    neigh_nvs_dict[neigh_idx] = neigh_node_values
                    if neigh_node_values[0][0] != worker_ID and neigh_node_values[0][1] > new_max_node_value:
                        new_max_node_value = neigh_node_values[0][1]
                neigh_nvs_dicts[d] = neigh_nvs_dict
                print("new_max_node_value = %s; max_node_value = %s" % (new_max_node_value, max_node_value))
                d += 1
            if not search_fail:
                neigh_nvs_dict = neigh_nvs_dicts[d-1]
                node_candidates = []
                for nvs in neigh_nvs_dict.values():
                    for wid, value in nvs:
                        if value == new_max_node_value and wid != worker_ID:
                            node_candidates.append(wid)
                            break
                target_node_idx = int(node_candidates[0].split('_')[-1])-1
                if isinstance(list(graph.nodes())[0], str):
                    target_node_idx = str(target_node_idx)
                path = nx.single_source_shortest_path(graph, cur_node_idx)[target_node_idx]
                next_worker = _to_worker_id(path[1])
        else:
            next_worker = node_values[0][0]
        # When the switch is on and the chosen worker is occupied by another
        # model, stay on the current worker for another round of training.
        if wait:
            for mid, info in model_infos.items():
                if mid != model_ID and info['training_worker'] == next_worker:
                    next_worker = worker_ID
                    break
    model_infos[model_ID]['training_worker'] = next_worker

def max_agg_model_infos(received_model_infos):
    """Merge received model infos by taking element-wise maxima.

    Returns a new dict with:
        'explored_dict': per worker (keys of the first info's dict), the
            largest visit count across all infos;
        'trained_data_amount': the largest trained data amount.
    The inputs are not modified.
    """
    explore_dicts = []
    data_amounts = []
    for mid in received_model_infos:
        info = received_model_infos[mid]
        explore_dicts.append(info['explored_dict'])
        data_amounts.append(info['trained_data_amount'])
    merged = {'explored_dict': copy.deepcopy(explore_dicts[0])}
    merged['trained_data_amount'] = max(data_amounts)
    for wid in explore_dicts[0]:
        merged['explored_dict'][wid] = max(ed[wid] for ed in explore_dicts)
    return merged

def sum_agg_model_infos(received_model_infos):
    """Merge received model infos by summing.

    Returns a new dict with:
        'explored_dict': per worker (keys of the first info's dict), the total
            visit count across all infos;
        'trained_data_amount': the total trained data amount.
    The inputs are not modified.
    """
    explore_dicts = []
    data_amounts = []
    for mid in received_model_infos:
        info = received_model_infos[mid]
        explore_dicts.append(info['explored_dict'])
        data_amounts.append(info['trained_data_amount'])
    merged = {'explored_dict': copy.deepcopy(explore_dicts[0])}
    merged['trained_data_amount'] = sum(data_amounts)
    for wid in explore_dicts[0]:
        merged['explored_dict'][wid] = sum([ed[wid] for ed in explore_dicts])
    return merged
    

def show_cur_status(model_infos, graph):
    """Visualise the current training status.

    Draws which workers currently host a model (``show_model_pos``). The
    per-model listing of unexplored workers (``print_remained_workers``) is
    currently disabled.
    """
    show_model_pos(model_infos, graph)

def print_remained_workers(model_infos, graph):
    """Print, for each model, how many and which workers it has not explored yet."""
    nodes = list(graph.nodes())
    print("Below are the number of workers that each model has not explored:")
    for model_ID in model_infos.keys():
        explored_list = model_infos[model_ID]['explored']
        all_workers = (f'workerID_{int(node)+1}' for node in nodes)
        remained_workers = [wid for wid in all_workers if wid not in explored_list]
        remaining_count = len(nodes) - len(explored_list)
        print("Model %s: %s workers remain, which are %s" % (model_ID, remaining_count, remained_workers))


def show_model_pos(model_infos, graph):
    """Draw ``graph`` with the nodes that currently host a model highlighted.

    A model's ``training_worker`` is ``'workerID_<k>'``. If it contains
    ``'wait'``, the worker number is the second-to-last ``_`` field instead.
    Worker k is drawn as node k-1. Occupied nodes are green (#58D68D), the rest
    blue (#5DADE2); the figure is then shown with ``plt.show()``.
    """
    occupied = []
    for model_ID in model_infos.keys():
        worker = model_infos[model_ID]['training_worker']
        field = -2 if 'wait' in worker else -1
        occupied.append(int(worker.split('_')[field]) - 1)
    colors = ['#58D68D' if int(node) in occupied else '#5DADE2' for node in list(graph.nodes())]
    nx.draw(graph, node_color=colors, with_labels=True)
    plt.show()
    
def is_training_completed(model_infos, training_times):
    """Return True once every model has visited every worker ``training_times`` times."""
    return not any(
        times < training_times
        for model_ID in model_infos
        for times in model_infos[model_ID]['explored_dict'].values()
    )

def find_uncompleted_models(model_infos, training_times):
    """List the models that still have some worker visited fewer than ``training_times`` times."""
    return [
        model_ID
        for model_ID in model_infos
        if any(times < training_times for times in model_infos[model_ID]['explored_dict'].values())
    ]

    

def cascade_models(ckpts, cascade_indexs, load_path):
    """Build a deeper model by stacking encoder blocks taken from client checkpoints.

    Intended layout, for inputs [id1, id2, id3]:
        [prev base blocks] x n
        [id1_blocks_0] x n
        [id2_blocks_0] x n
        [id3_blocks_0] x n
        [after base blocks] x n

    Args:
        ckpts: list of client checkpoints (dicts with a 'model' state dict).
        cascade_indexs: cascade_indexs[i] lists the block positions of the
            cascaded model that ckpts[i] fills.
        load_path: template checkpoint of the cascaded model. It must exist; it
            is loaded, overwritten and saved back to the same path.

    Steps:
        1. For each target block index, collect the checkpoints that provide it.
           If several do, their whole 'model' state dicts are averaged element-wise.
        2. For every template key containing "blocks" but not "decoder", copy the
           tensor from the providing checkpoint, renaming block j to
           ``j % len(cascade_index)`` (see NOTE below).
        3. Every other key (including decoder keys) becomes the plain mean
           over all ``ckpts``.

    Returns:
        The new checkpoint (also saved to ``load_path``).
    """
    print("Start aggregating Vit blocks from workers")
    block_ckpt_mapping = {}
    block_index_mapping = {}

    # map every target block index to the checkpoint(s) providing it
    for i in range(len(ckpts)):
        ckpt = ckpts[i]
        cascade_index = cascade_indexs[i]
        for idx in cascade_index:
            if idx in block_ckpt_mapping:
                block_ckpt_mapping[idx].append(ckpt)
            else:
                block_ckpt_mapping[idx] = [ckpt]
            # if several checkpoints provide idx, the length of the last one wins
            block_index_mapping[idx] = len(cascade_index)

    # collapse each list to one checkpoint, averaging when several provide the same block
    for k, v in block_ckpt_mapping.items():
        if len(v) > 1:
            new_cp = copy.deepcopy(v[0])
            for kk in new_cp['model']:
                param_list = []
                for cp in v:
                    param_list.append(cp['model'][kk])
                # NOTE(review): torch.mean needs floating-point tensors; integer
                # buffers in the state dict would raise here — confirm none exist.
                mean_param = torch.mean(torch.stack(param_list), dim=0)
                new_cp['model'][kk] = mean_param
            block_ckpt_mapping[k] = new_cp
        else:
            block_ckpt_mapping[k] = v[0]

    assert os.path.exists(load_path), "fed_checkpoint.pth must exist for model aggregation"
    new_checkpoint = torch.load(load_path, map_location='cpu')

    # count the template's encoder blocks. Keys are assumed to look like
    # 'blocks.<j>.…' (block number in the second dot-field). IDs are kept as
    # strings here; only the count is compared below.
    contained_block_ids = set()
    for k in new_checkpoint['model']:
        if not "decoder" in k and "blocks" in k:
            name_split = k.split(".")
            block_id = name_split[1]
            contained_block_ids.add(block_id)
    assert len(contained_block_ids) == len(block_ckpt_mapping.keys()), "the depth in the cascaded model %s does not match with the sum of depth of the input models %s" % (len(contained_block_ids), len(block_ckpt_mapping.keys()))

    for k in new_checkpoint['model']:
        if not "decoder" in k and "blocks" in k:
            name_split = k.split(".")
            # int() here means cascade_indexs must hold ints — TODO confirm
            block_id = int(name_split[1])
            client_checkpoint = block_ckpt_mapping[block_id]
            # NOTE(review): block_id % len(cascade_index) equals the block's position
            # inside its client's cascade_index only if that run is contiguous and
            # starts at a multiple of its length (e.g. [3, 4, 5]). For [1, 2, 3] the
            # blocks would read client blocks 1, 2, 0 instead of 0, 1, 2. Confirm
            # callers always pass such aligned runs.
            respond_block_id = block_id % block_index_mapping[block_id]
            name_split[1] = str(respond_block_id)
            new_name = ".".join(name_split)
            new_checkpoint['model'][k] = client_checkpoint['model'][new_name]
        else:
            # shared (non-block) parameters: plain mean over all client checkpoints
            params = []
            for ckpt in ckpts:
                params.append(ckpt['model'][k])
            mean_param = torch.mean(torch.stack(params), dim=0)
            new_checkpoint['model'][k] = mean_param

    misc.save_on_master(new_checkpoint, load_path)

    print('Block aggregation is finished!')

    return new_checkpoint
        