import os
import argparse
import json
import logging
from platform import node
import random
from collections import OrderedDict, defaultdict
from pathlib import Path
import copy
import traceback
from functools import partial

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from tqdm.auto import trange, tqdm
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, from_networkx, to_dense_adj
from torch_sparse import SparseTensor
import torchvision

from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams

from models import CNNTarget, GAEHyper, GNNHyper, GRUTarget, MLPTarget, MLPHyper, MLPInnerProductDecoder
from node import BaseNodes
from utils import get_device, set_logger, set_seed, freeze, unfreeze, detach_clone, CustomCosineAnnealingLR, PiecewiseLinearLR, write_results, eval_model
from deform import DeformWrapper

class GraphConstructor(nn.Module):
    """Learn a directed, non-negative node-to-node adjacency from node embeddings.

    Every node gets two embedding vectors (learned ``nn.Embedding`` tables, or linear
    projections of ``static_feat``). The score matrix ``M1 @ M2.T - M2 @ M1.T`` is
    antisymmetric, so after ``relu(tanh(alpha * .))`` at most one of ``adj[i, j]`` and
    ``adj[j, i]`` is non-zero and the diagonal is zero.

    Args:
        nnodes: number of nodes (size of the embedding tables when ``static_feat`` is None).
        k: neighbours kept per row by ``eval(idx)`` (clamped to the number of nodes in ``idx``).
        dim: embedding / projection dimension.
        device: device the module is meant to run on (stored for callers).
        alpha: saturation factor applied inside the ``tanh`` calls.
        static_feat: optional 2-D feature tensor indexed by node id; when given, its rows
            are projected by ``lin1``/``lin2`` instead of using learned embeddings.
    """

    def __init__(self, nnodes, k, dim, device, alpha=3, static_feat=None):
        super(GraphConstructor, self).__init__()
        self.nnodes = nnodes
        if static_feat is not None:
            xd = static_feat.shape[1]
            self.lin1 = nn.Linear(xd, dim)
            self.lin2 = nn.Linear(xd, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim)
            self.emb2 = nn.Embedding(nnodes, dim)
            self.lin1 = nn.Linear(dim, dim)
            self.lin2 = nn.Linear(dim, dim)

        self.device = device
        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def _dense_adj(self, idx):
        """Return the unmasked (len(idx), len(idx)) adjacency for the nodes in ``idx``."""
        if self.static_feat is None:
            nodevec1 = self.emb1(idx)
            nodevec2 = self.emb2(idx)
        else:
            nodevec1 = self.static_feat[idx, :]
            nodevec2 = nodevec1

        nodevec1 = torch.tanh(self.alpha * self.lin1(nodevec1))
        nodevec2 = torch.tanh(self.alpha * self.lin2(nodevec2))

        a = torch.mm(nodevec1, nodevec2.transpose(1, 0)) - torch.mm(nodevec2, nodevec1.transpose(1, 0))
        return F.relu(torch.tanh(self.alpha * a))

    def forward(self, idx):
        """Return the dense (differentiable) adjacency among the nodes in ``idx``."""
        return self._dense_adj(idx)

    def eval(self, idx=None, full=False):
        """Return the (optionally top-k sparsified) adjacency, or switch to eval mode.

        This name overrides ``nn.Module.eval``; calling it without ``idx`` keeps the
        standard behaviour (set eval mode and return ``self``) instead of raising TypeError.

        Args:
            idx: node indices; when None, defer to ``nn.Module.eval()``.
            full: if True, return the dense adjacency without masking.

        Returns:
            ``self`` when ``idx`` is None; otherwise the adjacency in which each row keeps
            only its ``k`` largest entries (all entries when ``full``).
        """
        if idx is None:
            return super().eval()

        adj = self._dense_adj(idx)
        if full:
            return adj

        # topk raises when k exceeds the row length, so cap it at the number of nodes.
        k = min(self.k, adj.size(1))
        values, indices = adj.topk(k, 1)
        mask = torch.zeros_like(adj)
        mask.scatter_(1, indices, torch.ones_like(values))
        return adj * mask

def generate_adj(param_metrix, args, subgraph_size, dim=40, lr=0.01, weight_decay=0.0001):
    """Learn a sparse client graph from the pairwise distances of client parameters.

    A ``GraphConstructor`` is trained for ``args.gc_epoch`` SGD steps to minimise the MSE
    between its row-L2-normalised dense adjacency and the row-L2-normalised matrix of
    pairwise L2 distances between the rows of ``param_metrix``. The returned adjacency
    keeps the ``subgraph_size`` largest entries of every row.

    Args:
        param_metrix: 2-D tensor with one flattened parameter vector per client (row).
        args: namespace providing ``adjalpha`` (GraphConstructor alpha) and ``gc_epoch``.
        subgraph_size: neighbours kept per row; clamped to the number of clients.
        dim: GraphConstructor embedding size (previously hard-coded to 40).
        lr: SGD learning rate for fitting the graph (previously hard-coded).
        weight_decay: SGD weight decay for fitting the graph (previously hard-coded).

    Returns:
        (num_clients, num_clients) adjacency tensor on the CPU, detached from autograd.
    """
    device = param_metrix.device
    num_clients = len(param_metrix)
    flat = param_metrix.reshape(num_clients, -1)
    with torch.no_grad():
        # One vectorised call per row instead of one per pair: same pairwise_distance
        # (including its eps), O(num_clients) Python calls, O(num_clients * D) extra memory.
        dist_metrix = torch.stack([
            F.pairwise_distance(flat[i].reshape(1, -1).expand_as(flat), flat, p=2)
            for i in range(num_clients)
        ])
    dist_metrix = F.normalize(dist_metrix.to(torch.get_default_dtype())).to(device)

    # topk inside GraphConstructor.eval fails if k exceeds the number of clients.
    k = min(subgraph_size, num_clients)
    gc = GraphConstructor(num_clients, k, dim, device, args.adjalpha).to(device)
    idx = torch.arange(num_clients, device=device)
    optimizer = torch.optim.SGD(gc.parameters(), lr=lr, weight_decay=weight_decay)

    for _ in range(args.gc_epoch):
        optimizer.zero_grad()
        adj = F.normalize(gc(idx))
        loss = F.mse_loss(adj, dist_metrix)
        loss.backward()
        optimizer.step()

    return gc.eval(idx).detach().to("cpu")

def normalize_adj(mx):
    """Row-normalize an adjacency matrix: ``diag(1 / rowsum) @ mx``.

    Rows summing to zero stay all-zero instead of producing inf/NaN, and no
    divide-by-zero warning is emitted for them.

    Args:
        mx: 2-D adjacency -- a NumPy array, a SciPy sparse matrix, or a CPU tensor
            (anything with ``.sum(1)`` that SciPy sparse ``dot`` accepts).

    Returns:
        The row-normalized matrix (dense ``ndarray`` for dense input, sparse for sparse
        input). Floating inputs keep their dtype; integer/bool inputs become float64.
    """
    rowsum = np.asarray(mx.sum(1)).ravel()
    if not np.issubdtype(rowsum.dtype, np.floating):
        # np.power(int_array, -1) raises "Integers to negative integer powers are not allowed".
        rowsum = rowsum.astype(np.float64)
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    return sp.diags(r_inv).dot(mx)

@torch.no_grad()
def sd_matrixing(state_dic):
    """Flatten all tensors of a state dict into a single 1-D vector.

    Tensors are concatenated in the dict's iteration order. 0-dim entries (e.g. integer
    counters) are cast to float32 before concatenation; other tensors are flattened
    as-is and ``torch.cat`` applies its usual dtype promotion. Runs under
    ``torch.no_grad()``, so the result never tracks gradients.

    :param state_dic: mapping name -> tensor (e.g. ``OrderedDict(net.named_parameters())``).
    :return: 1-D tensor, or ``None`` when ``state_dic`` is empty.
    """
    # Collect first and concatenate once: growing the vector with repeated torch.cat
    # copied it on every step (quadratic in the number of entries).
    pieces = [
        param.view(1).type(torch.float32) if param.dim() == 0 else param.flatten()
        for param in state_dic.values()
    ]
    if not pieces:
        return None
    return torch.cat(pieces, 0)

def SFL_graph_aggr(nodes, pre_A, args, node_ids):
    """Smooth the selected clients' weights over a client graph and write them back.

    Each client's parameters are flattened into one row ``P[r]``. The rows are propagated
    through the row-normalized adjacency ``A`` (``max(args.layers, 1)`` times) and mixed
    with the originals: ``alpha * A^L @ P + (1 - alpha) * P``. The resulting rows are
    reshaped back into ``nodes.weights[node_id]`` for every ``node_id`` in ``node_ids``.

    With ``args.full_adj`` the graph spans all ``nodes.n_nodes`` nodes: rows are indexed by
    node id and nodes not in ``node_ids`` contribute the global weights. Otherwise only the
    sub-graph induced by ``node_ids`` is used and rows follow ``node_ids`` order.
    For ``args.algorithm`` "SFL_v2" the adjacency is replaced by one learned with
    ``generate_adj``; for "SFL_v3" it is blended with it using ``args.adjbeta``.

    Args:
        nodes: object exposing ``global_weights``, ``weights`` (node_id -> state dict)
            and ``n_nodes``.
        pre_A: prior adjacency over all nodes (NumPy array, SciPy matrix or CPU tensor).
        args: namespace with ``gpu``, ``full_adj``, ``algorithm``, ``subgraph_size``,
            ``num_clients``, ``adjbeta``, ``layers``, ``alpha`` (and what generate_adj reads).
        node_ids: ids of the clients whose weights are updated.

    Returns:
        ``nodes.weights``, updated in place.

    Raises:
        ValueError: if the adjacency size does not match the number of parameter rows.
    """
    device = get_device(gpus=args.gpu)
    keys = list(nodes.global_weights.keys())
    key_shapes = [param.data.shape for param in nodes.global_weights.values()]

    if args.full_adj:
        A = torch.from_numpy(normalize_adj(pre_A))
        global_vector = sd_matrixing(nodes.global_weights).to(device)
        # repeat() gives every node its own row. The previous expand() made all rows views
        # of one storage, so writing one client's row overwrote every row.
        param_metrix = global_vector.repeat(nodes.n_nodes, 1)
        for node_id in node_ids:
            if node_id in nodes.weights:
                param_metrix[node_id] = sd_matrixing(nodes.weights[node_id]).to(device)
    else:
        A = torch.from_numpy(normalize_adj(pre_A[node_ids][:, node_ids]))
        param_metrix = torch.stack([sd_matrixing(nodes.weights[node_id]) for node_id in node_ids]).to(device)
    # A float64 prior (e.g. a NumPy adjacency) would otherwise break torch.mm below.
    A = A.to(device=device, dtype=param_metrix.dtype)

    if args.algorithm in ["SFL_v2", "SFL_v3"]:
        subgraph_size = min(args.subgraph_size, args.num_clients)
        new_A = generate_adj(param_metrix, args, subgraph_size).cpu().detach().numpy()
        new_A = torch.from_numpy(normalize_adj(new_A)).to(device=device, dtype=param_metrix.dtype)
        if args.algorithm == "SFL_v3":
            A = (1 - args.adjbeta) * A + args.adjbeta * new_A
        else:
            A = new_A
    if len(A) != len(param_metrix):
        raise ValueError(f"Adjacency has {len(A)} rows but there are {len(param_metrix)} parameter rows.")

    # Aggregating: at least one propagation step, as before.
    aggregated_param = param_metrix
    for _ in range(max(args.layers, 1)):
        aggregated_param = torch.mm(A, aggregated_param)
    new_param_matrix = (args.alpha * aggregated_param) + ((1 - args.alpha) * param_metrix)

    # Reconstruct per-client state dicts. In full_adj mode rows are indexed by node id,
    # not by position in node_ids.
    for i, node_id in enumerate(node_ids):
        row = new_param_matrix[node_id if args.full_adj else i]
        start = 0
        for key, shape in zip(keys, key_shapes):
            end = start + shape.numel()
            nodes.weights[node_id][key] = row[start:end].reshape(shape)
            start = end

    return nodes.weights

def _eval_forward(net, batch):
    """Forward a device-resident ``batch``: GRUTarget takes the whole tuple, other nets ``batch[0]``."""
    if isinstance(net, GRUTarget):
        return net(batch)
    return net(batch[0])


@torch.no_grad()
def _eval_loader(net, loader, criteria, device, eval_batch_count):
    """Evaluate ``net`` on at most ``eval_batch_count`` batches of ``loader``.

    Returns:
        ``(avg_loss, correct, samples)``: the mean of the per-batch losses over the batches
        actually processed (NaN if there were none), plus the counts of correct
        predictions and of labels seen (only accumulated for CrossEntropy criteria).
    """
    running_loss, running_correct, running_samples = 0., 0., 0.
    processed = 0
    for batch in loader:
        if processed >= eval_batch_count:
            break
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        label = batch[1]
        pred = _eval_forward(net, batch)
        running_loss += criteria(pred, label)
        if isinstance(criteria, torch.nn.CrossEntropyLoss):
            running_correct += pred.argmax(1).eq(label).sum()
            running_samples += len(label)
        processed += 1
    # Divide by the batches actually evaluated: the old ``batch_count + 1`` over-counted by
    # one whenever eval_batch_count truncated the loader, and an empty loader reused a
    # stale count from the previous node (or raised NameError).
    avg_loss = running_loss / processed if processed else float('nan')
    return avg_loss, running_correct, running_samples


def _finetune(net, loader, criteria, device, epochs, lr, wd):
    """Fine-tune ``net`` in place for ``epochs`` passes over ``loader`` with SGD (momentum 0.9)."""
    net.train()
    inner_optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    for _ in range(epochs):
        for batch in loader:
            inner_optim.zero_grad()
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            label = batch[1]
            pred = _eval_forward(net, batch)
            loss = criteria(pred, label)
            loss.backward()
            inner_optim.step()
    net.eval()


def evaluate(nodes: BaseNodes, net, criteria, device, eval_batch_count, finetune_epochs, finetune_lr, finetune_wd, node_list, split):
    """Evaluate the global model on every node, optionally again after local fine-tuning.

    For each node, ``net`` is loaded with copies of ``nodes.global_weights`` /
    ``nodes.global_buffers`` and evaluated on the node's loader for ``split``.

    Args:
        nodes: provides the global weights/buffers and per-node ``train_loaders``,
            ``val_loaders`` and ``test_loaders``.
        net: target network (its parameters and buffers are overwritten).
        criteria: loss function; accuracy is tracked only for ``torch.nn.CrossEntropyLoss``.
        device: device batches are moved to.
        eval_batch_count: maximum number of batches evaluated per node.
        finetune_epochs: epochs of SGD fine-tuning on the node's train loader before a
            second evaluation; skipped when ``<= 0``.
        finetune_lr: fine-tuning learning rate.
        finetune_wd: fine-tuning weight decay.
        node_list: nodes to evaluate.
        split: ``'test'`` or ``'val'``; any other value uses the train loaders.

    Returns:
        ``{node_id: {'loss', ['correct', 'total', 'acc'], ['finetune_loss',
        'finetune_correct', 'finetune_total', 'finetune_acc']}}``.
    """
    net.eval()
    results = defaultdict(lambda: defaultdict(list))
    is_classification = isinstance(criteria, torch.nn.CrossEntropyLoss)

    node_iter = tqdm(node_list, position=-1, leave=False)
    for node_id in node_iter:  # iterating over nodes
        if split == 'test':
            curr_data = nodes.test_loaders[node_id]
        elif split == 'val':
            curr_data = nodes.val_loaders[node_id]
        else:
            curr_data = nodes.train_loaders[node_id]
        for n, p in net.named_parameters():
            p.data = nodes.global_weights[n].detach().clone()
        for n, p in net.named_buffers():
            p.data = nodes.global_buffers[n].detach().clone()

        loss, correct, samples = _eval_loader(net, curr_data, criteria, device, eval_batch_count)
        results[node_id]['loss'] = loss
        tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {results[node_id]['loss']:.4f}"
        if is_classification:
            results[node_id]['correct'] = correct
            results[node_id]['total'] = samples
            results[node_id]['acc'] = correct / samples if samples else float('nan')
            tqdm_str += f", AVG Acc: {results[node_id]['acc']}"

        if finetune_epochs > 0:
            _finetune(net, nodes.train_loaders[node_id], criteria, device, finetune_epochs, finetune_lr, finetune_wd)
            ft_loss, ft_correct, ft_samples = _eval_loader(net, curr_data, criteria, device, eval_batch_count)
            results[node_id]['finetune_loss'] = ft_loss
            tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {results[node_id]['loss']:.4f} -> {results[node_id]['finetune_loss']:.4f}"
            if is_classification:
                results[node_id]['finetune_correct'] = ft_correct
                results[node_id]['finetune_total'] = ft_samples
                results[node_id]['finetune_acc'] = ft_correct / ft_samples if ft_samples else float('nan')
                tqdm_str += f", AVG Acc: {results[node_id]['acc']:.4f} -> {results[node_id]['finetune_acc']:.4f}"
        node_iter.set_description(tqdm_str)
    return results


def _forward_target(net, batch):
    """Forward a device-resident ``batch``: GRUTarget takes the whole tuple, other nets ``batch[0]``."""
    if isinstance(net, GRUTarget):
        return net(batch)
    return net(batch[0])


def _batch_accuracy(pred, label, criteria):
    """Return the batch accuracy for CrossEntropy criteria, ``tensor(0)`` otherwise."""
    with torch.no_grad():
        if isinstance(criteria, torch.nn.CrossEntropyLoss):
            return pred.argmax(1).eq(label).sum() / len(label)
        return torch.tensor(0)


def _resolve_inner_schedule(hparam_dict, num_batches):
    """Return ``(inner_epochs, inner_steps)`` for a client whose loader has ``num_batches``.

    Without ``inner_epochs``, training is capped at ``inner_steps`` batches and enough
    epochs are scheduled to reach that cap; otherwise every batch of every epoch is used.
    """
    if hparam_dict['inner_epochs'] is None:
        inner_steps = hparam_dict['inner_steps']
        return inner_steps // num_batches + 1, inner_steps
    inner_epochs = hparam_dict['inner_epochs']
    return inner_epochs, inner_epochs * num_batches


def _eval_sent_model(net, loader, criteria, device):
    """Return ``(loss, acc)`` of the model just sent to a client, on the first batch of ``loader``."""
    with torch.no_grad():
        net.eval()
        batch = next(iter(loader))
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        label = batch[1]
        pred = _forward_target(net, batch)
        return criteria(pred, label), _batch_accuracy(pred, label, criteria)


def _local_update(net, inner_optim, loader, criteria, device, reg, global_vec, node_weights,
                  inner_epochs, inner_steps, augment=None, scheduler=None):
    """Train ``net`` locally and return ``(loss, pred, label)`` of the last processed batch.

    Processes at most ``inner_steps`` batches over ``inner_epochs`` passes of ``loader``.
    ``augment`` (if given) transforms each batch (e.g. the adversarial deformation), and
    ``scheduler`` (if given) is stepped after every optimizer step.
    """
    net.train()
    batch_count = 0
    loss = pred = label = None
    for _ in range(inner_epochs):
        for batch in loader:
            batch_count += 1
            if batch_count > inner_steps:
                break
            inner_optim.zero_grad()
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            if augment is not None:
                batch = augment(batch)
            label = batch[1]
            pred = _forward_target(net, batch)
            loss = criteria(pred, label)

            # SFL regularizer: distance of the current weights to the global / client weights.
            # NOTE(review): sd_matrixing runs under torch.no_grad(), so reg1/reg2 carry no
            # gradient and currently do not influence the update -- confirm the intent.
            local_vec = sd_matrixing(OrderedDict(net.named_parameters())).reshape(1, -1).to(device)
            node_vec = sd_matrixing(node_weights).reshape(1, -1).to(device)
            reg1 = F.pairwise_distance(local_vec, global_vec, p=2)
            reg2 = F.pairwise_distance(local_vec, node_vec, p=2)
            (loss + reg * reg1 + reg * reg2).backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
            inner_optim.step()
            if scheduler is not None:
                scheduler.step()
    return loss, pred, label


def train(config: dict, hparam_dict: dict = None) -> None:
    """Federated training on the seen nodes, then local adaptation on the unseen nodes.

    Phase 1 runs ``num_steps`` rounds. Each round:
    - samples ``num_clients`` training nodes;
    - trains each one locally, starting from the global weights;
    - smooths the client weights over ``train_A`` with ``SFL_graph_aggr``;
    - averages them into the new global weights.
    Phase 2 evaluates the global model on the unseen nodes. It then repeats local
    training plus graph aggregation over ``full_A`` for ``eval_unseen_steps`` rounds.

    Args:
        config: hyper-parameters (the CLI arguments plus ``run_path``); merged into
            ``hparam_dict``.
        hparam_dict: optional dict that ``config`` is merged into (mutated in place).
            Defaults to a fresh dict; the former ``{}`` default was shared across calls.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset.

    Side effects: writes TensorBoard logs and ``final.ckpt`` / ``unseen.ckpt`` under
    ``run_path``.
    """
    if hparam_dict is None:
        hparam_dict = {}
    hparam_dict.update(config)
    set_seed(hparam_dict['seed'])
    device = get_device(gpus=hparam_dict['gpu'])
    # SFL_graph_aggr / generate_adj read settings as attributes; build them from
    # hparam_dict instead of the module-level ``args`` that only exists in script mode.
    aggr_args = argparse.Namespace(**hparam_dict)
    ###############################
    # init nodes, hnet, local net #
    ###############################
    nodes = BaseNodes(hparam_dict['data_name'], batch_size=hparam_dict['batch_size'], eval_batch_size=hparam_dict['eval_batch_size'], num_workers=hparam_dict['num_workers'], device=device, data_dir=hparam_dict['data_dir'])
    num_nodes = nodes.n_nodes
    all_nodes = np.arange(num_nodes)
    client_sample_rng = np.random.default_rng(hparam_dict['seed'])
    if 'METR-LA-' in hparam_dict['data_name'] or 'PEMS-BAY-' in hparam_dict['data_name']:
        new_order = all_nodes
        train_nodes = nodes.train_nodes
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = list(set(all_nodes) - set(train_nodes))
        # Build (num_nodes, num_nodes) matrices: without max_num_nodes/squeeze(0),
        # to_dense_adj returns a (1, N', N') batch that normalize_adj cannot process.
        train_A = to_dense_adj(nodes.train_edge_index, max_num_nodes=num_nodes).squeeze(0)
        full_A = to_dense_adj(nodes.eval_edge_index, max_num_nodes=num_nodes).squeeze(0)
    else:
        new_order = nodes.shuffle()
        num_train = int(0.8*num_nodes)
        train_nodes = all_nodes[:num_train]
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = all_nodes[num_train:]
        G = nx.from_numpy_array(nodes.A)
        data = from_networkx(G)
        train_edge_index, _ = subgraph(torch.LongTensor(train_nodes), data.edge_index, num_nodes=num_nodes)
        train_A = to_dense_adj(train_edge_index, max_num_nodes=num_nodes).squeeze(0)
        full_A = nodes.A
    logging.info(f"Num of nodes: {len(all_nodes)}")
    logging.info(f"Num of train nodes: {len(train_nodes)}")
    logging.info(f"adv_node_list: {adv_node_list}")
    if 'METR-LA' in hparam_dict['data_name'] or 'PEMS-BAY' in hparam_dict['data_name']:
        net = GRUTarget(nodes.data_dim, out_dim=nodes.label_dim)
    elif hparam_dict['data_name'] == 'CompCars':
        CNN = getattr(torchvision.models, hparam_dict['target_arch'])
        net = CNN(pretrained=True)
        final_layer_name, final_layer = list(net.named_children())[-1]
        setattr(net, final_layer_name, torch.nn.Linear(final_layer.in_features, nodes.label_dim))
    else:
        net = MLPTarget(nodes.data_dim, out_dim=nodes.label_dim)
    net.to(device)

    deform = DeformWrapper(device, hparam_dict['aug_method'], nodes.label_dim, hparam_dict['sigma'])

    ##################
    # init optimizer #
    ##################
    acc_str = ""
    if hparam_dict['data_name'] in ['DG15', 'DG60', 'DG60_TRUNCNORM', 'CompCars']:
        criteria = torch.nn.CrossEntropyLoss()
    elif hparam_dict['data_name'] in ['TPT48', 'TPT48_TRUNCNORM', 'METR-LA', 'METR-LA-0.25', 'METR-LA-0.5', 'METR-LA-0.75', 'METR-LA-0.05', 'METR-LA-0.9', 'PEMS-BAY', 'PEMS-BAY-0.25', 'PEMS-BAY-0.5', 'PEMS-BAY-0.75', 'PEMS-BAY-0.05', 'PEMS-BAY-0.9']:
        criteria = torch.nn.MSELoss()
    else:
        # Raising a plain string is a TypeError in Python 3.
        raise ValueError(f"Unknown dataset {hparam_dict['data_name']}.")
    is_classification = isinstance(criteria, torch.nn.CrossEntropyLoss)

    metric_dict = {
        'Train/best_step': np.nan,
        'Train/best_test_avg_loss': np.nan,
        'Train/best_test_avg_acc': np.nan,
        'Unseen/best_step': np.nan,
        'Unseen/best_test_avg_loss': np.nan,
        'Unseen/best_test_avg_acc': np.nan,
        'Unseen_0/test_avg_loss': np.nan,
        'Unseen_0/test_avg_acc': np.nan,
    }

    writer = SummaryWriter(hparam_dict['run_path'])
    if hparam_dict['pretrain_targetnet_steps'] > 0:
        net.train()
        step_iter = trange(hparam_dict['pretrain_targetnet_steps'])
        # The inner-LR scheduler is only created later, so start from the base inner LR.
        # Adam has no ``momentum`` argument (passing one raised TypeError).
        inner_optim = torch.optim.Adam(
            net.parameters(), lr=hparam_dict['inner_lr'], weight_decay=hparam_dict['inner_wd']
        )
        node_id = train_nodes[0]
        for step in step_iter:
            inner_optim.zero_grad()

            batch = next(iter(nodes.train_loaders[node_id]))
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            label = batch[1]
            pred = _forward_target(net, batch)

            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            inner_optim.step()
            acc = _batch_accuracy(pred, label, criteria)
            if is_classification:
                acc_str = f",  Acc: {acc:.4f}"
            step_iter.set_description(
                f"Pre-train Target NN Step: {step}, Node ID: {node_id}, Loss: {loss.item():.4f}" + acc_str
            )
            writer.add_scalar('Pre-Train/loss_target', loss, step)
            writer.add_scalar('Pre-Train/acc_target', acc, step)
        net.eval()

    ################
    # init metrics #
    ################
    nodes.global_weights = detach_clone(OrderedDict(net.named_parameters()))
    nodes.global_buffers = detach_clone(OrderedDict(net.named_buffers()))
    steps = hparam_dict['num_steps']
    step_iter = trange(steps)
    eval_every = hparam_dict['eval_every']
    nodes.weights = {}
    nodes.buffers = {}
    for node_id in all_nodes:
        nodes.weights[node_id] = detach_clone(nodes.global_weights)
        nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
    if not hasattr(nodes, 'joined'):
        # Both training phases read/write nodes.joined; create it if BaseNodes does not.
        nodes.joined = np.zeros(num_nodes, dtype=bool)
    num_clients = min(len(train_nodes), hparam_dict['num_clients'])
    inner_lr_scheduler = CustomCosineAnnealingLR(hparam_dict['inner_lr'], T_max=hparam_dict['num_steps'], eta_min=hparam_dict['inner_lr_min'])
    inner_lr_scheduler.step()
    evaluator = partial(evaluate, nodes, net, criteria, device, hparam_dict['eval_batch_count'], hparam_dict['eval_finetune_epochs'], hparam_dict['eval_finetune_lr'], hparam_dict['eval_finetune_wd'])
    for step in step_iter:
        if step % eval_every == 0:
            results = eval_model(evaluator, train_nodes, adv_node_list)
            write_results(writer, 'Train', results, step)

        # select client at random
        node_ids = client_sample_rng.choice(train_nodes, num_clients, replace=False)
        logging.debug(f"node_ids: {node_ids}")
        node_results = []
        # The global weights stay fixed during the local updates of a round; flatten them once.
        global_vec = sd_matrixing(nodes.global_weights).reshape(1, -1).to(device)
        # produce & load local network weights
        for node_id in node_ids:
            nodes.weights[node_id] = detach_clone(nodes.global_weights)
            nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
            nodes.joined[node_id] = True
            # net shares storage with the client's copies, so local training updates them in place.
            for n, p in net.named_parameters():
                p.data = nodes.weights[node_id][n]
            for n, p in net.named_buffers():
                p.data = nodes.buffers[node_id][n]
            # init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )
            piecewise_lr_scheduler = None
            if hparam_dict['piecewise_lr_scheduler']:
                piecewise_lr_scheduler = PiecewiseLinearLR(inner_optim, max_epoch=hparam_dict['inner_epochs'], batch_size=hparam_dict['batch_size'], batch_count=len(nodes.train_loaders[node_id]))

            # NOTE: evaluation on sent model
            prvs_loss, prvs_acc = _eval_sent_model(net, nodes.test_loaders[node_id], criteria, device)

            inner_epochs, inner_steps = _resolve_inner_schedule(hparam_dict, len(nodes.train_loaders[node_id]))
            loss, pred, label = _local_update(
                net, inner_optim, nodes.train_loaders[node_id], criteria, device, hparam_dict['reg'],
                global_vec, nodes.weights[node_id], inner_epochs, inner_steps,
                augment=deform if node_id in adv_node_list else None,
                scheduler=piecewise_lr_scheduler,
            )
            acc = _batch_accuracy(pred, label, criteria)
            if is_classification:
                acc_str = f",  Acc: {prvs_acc:.4f} -> {acc:.4f}"
            node_results += [[prvs_loss.item(), loss.item(), prvs_acc.item(), acc.item()]]
            logging.debug(
                f"Step: {step}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}"
            )
        prvs_loss, node_loss, prvs_acc, acc = np.mean(node_results, axis=0)
        step_iter.set_description(
            f"Step: {step}, Loss: {prvs_loss:.4f} -> {node_loss:.4f}" + acc_str
        )
        nodes.weights = SFL_graph_aggr(nodes, train_A, aggr_args, node_ids)
        for n, p in net.named_parameters():
            nodes.global_weights[n] = torch.stack([nodes.weights[node_id][n] for node_id in node_ids]).mean(dim=0)
        nodes.weights = {}
        for n, p in net.named_buffers():
            nodes.global_buffers[n] = torch.stack([nodes.buffers[node_id][n] for node_id in node_ids]).float().mean(dim=0).type(nodes.global_buffers[n].dtype)
        nodes.buffers = {}
        if hparam_dict['inner_lr_scheduler']:
            inner_lr_scheduler.step()
    # Log Final performance (same value as the old ``step += 1``, but also defined when num_steps == 0)
    step = steps
    results = eval_model(evaluator, train_nodes, adv_node_list)
    write_results(writer, 'Train', results, step)
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'global_weights': nodes.global_weights,
        'global_buffers': nodes.global_buffers,
        'client_weights': nodes.weights,
        'client_buffers': nodes.buffers,
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'], 'final.ckpt'))

    # Evaluate on unseen nodes
    step = 0
    for node_id in unseen_nodes:
        nodes.weights[node_id] = detach_clone(nodes.global_weights)
        # nodes.buffers was reset after the last round, but buffer aggregation below reads it.
        nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
    results = eval_model(evaluator, unseen_nodes)
    write_results(writer, 'Unseen_0', results, step)
    best_metric_dict = {
        'Unseen_0/test_avg_loss': results['test_avg_loss'],
        'Unseen_0/test_max_loss': results['test_max_loss'],
        'Unseen_0/test_min_loss': results['test_min_loss'],
        'Unseen_0/test_std_loss': results['test_std_loss']
    }
    if is_classification:
        best_metric_dict.update({
            'Unseen_0/test_avg_acc': results['test_avg_acc'],
            'Unseen_0/test_max_acc': results['test_max_acc'],
            'Unseen_0/test_min_acc': results['test_min_acc'],
            'Unseen_0/test_std_acc': results['test_std_acc']
        })
    metric_dict.update(best_metric_dict)
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k, v in best_metric_dict.items():
        writer.add_scalar(k, v, step)

    step_iter = trange(hparam_dict['eval_unseen_steps'], position=0)
    num_clients = min(len(unseen_nodes), hparam_dict['num_clients'])
    for step in step_iter:
        # select client at random
        node_ids = np.random.choice(unseen_nodes, num_clients, replace=False)
        node_iter = tqdm(node_ids, position=1, leave=False)
        node_results = []
        global_vec = sd_matrixing(nodes.global_weights).reshape(1, -1).to(device)
        # produce & load local network weights
        for node_id in node_iter:
            if not nodes.joined[node_id]:
                nodes.weights[node_id] = detach_clone(nodes.global_weights)
                nodes.joined[node_id] = True
            for n, p in net.named_parameters():
                p.data = nodes.weights[node_id][n].data.clone()
            # init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )

            # NOTE: evaluation on sent model
            prvs_loss, prvs_acc = _eval_sent_model(net, nodes.test_loaders[node_id], criteria, device)

            logging.debug(f"Start local training on client {node_id}.")
            # Same epoch/step resolution as phase 1 (range(None) crashed when inner_epochs was unset).
            inner_epochs, inner_steps = _resolve_inner_schedule(hparam_dict, len(nodes.train_loaders[node_id]))
            loss, pred, label = _local_update(
                net, inner_optim, nodes.train_loaders[node_id], criteria, device, hparam_dict['reg'],
                global_vec, nodes.weights[node_id], inner_epochs, inner_steps,
            )
            for n, p in net.named_parameters():
                nodes.weights[node_id][n].data = p.data.detach().clone()
            acc = _batch_accuracy(pred, label, criteria)
            if is_classification:
                acc_str = f",  Acc: {prvs_acc:.4f} -> {acc:.4f}"
            node_iter.set_description(
                f"Node ID: {node_id}, Loss: {prvs_loss:.4f} -> {loss:.4f}" + acc_str
            )
            node_results += [[prvs_loss.item(), loss.item(), prvs_acc.item(), acc.item()]]
        # Previously ``node.weights = SFL_graph_aggr(nodes.weights, ...)``: wrong argument
        # (a dict, not the nodes object) and assignment onto platform.node.
        nodes.weights = SFL_graph_aggr(nodes, full_A, aggr_args, node_ids)
        aggr_weights = OrderedDict()
        aggr_buffers = OrderedDict()
        for n, p in net.named_parameters():
            aggr_weights[n] = torch.stack([nodes.weights[node_id][n] for node_id in node_ids]).mean(dim=0)
        for n, p in net.named_buffers():
            aggr_buffers[n] = torch.stack([nodes.buffers[node_id][n] for node_id in node_ids]).float().mean(dim=0).type(nodes.global_buffers[n].dtype)
        nodes.global_weights.update(aggr_weights)
        nodes.global_buffers.update(aggr_buffers)

        results = eval_model(evaluator, unseen_nodes)
        write_results(writer, 'Unseen', results, step+1)
        # NOTE(review): accuracy keys are only produced for classification; confirm whether
        # eval_model reports 'test_all_acc' or 'test_avg_acc'.
        desc = f"Unseen Step: {step}, AVG Loss: {results['test_avg_loss']:.4f}"
        if 'test_all_acc' in results:
            desc += f",  AVG Acc: {results['test_all_acc']:.4f}"
        step_iter.set_description(desc)
    writer.close()
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'global_weights': nodes.global_weights,
        'global_buffers': nodes.global_buffers,
        'client_weights': nodes.weights,
        'client_buffers': nodes.buffers,
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'], 'unseen.ckpt'))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Federated Hypernetwork with Lookahead experiment"
    )

    #############################
    #       Dataset Args        #
    #############################

    parser.add_argument("--data-name", type=str, default="DG15")
    parser.add_argument("--data-dir", type=str, default="data", help="dir name of dataset")

    ##################################
    #       Optimization args        #
    ##################################

    parser.add_argument("--pretrain-targetnet-steps", type=int, default=0)
    parser.add_argument("--num-steps", type=int, default=800)
    parser.add_argument("--eval-unseen-steps", type=int, default=0)
    parser.add_argument("--optim", type=str, default='sgd', choices=['adam', 'sgd'], help="optimizer")
    parser.add_argument("--num-clients", type=int, default=5)
    parser.add_argument("--last-layer-only", action='store_true', default=False, help="Only personalize the last layer.")
    parser.add_argument('--aug-method', type=str, default='label_flip', help='type of augmentation for training')
    parser.add_argument("--adv-ratio", type=float, default=.0, help="Ratio of adversarial clients participated FL training.")
    parser.add_argument("--sigma", type=float, default=.0, help="noise hyperparameter")

    ################################
    #       Model Prop args        #
    ################################
    parser.add_argument("--layers", type=int, default=3, help="num. hidden layers")
    parser.add_argument("--inner-lr", type=float, default=1e-2, help="learning rate for inner optimizer")
    parser.add_argument("--inner-lr-min", type=float, default=1e-3, help="Min learning rate for inner optimizer")
    parser.add_argument("--inner-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for inner optimizer")
    parser.add_argument("--piecewise-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for inner optimizer")
    parser.add_argument("--inner-wd", type=float, default=5e-5, help="inner weight decay")
    parser.add_argument("--inner-epochs", type=int, default=None)
    parser.add_argument("--inner-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--target-arch", type=str, default=None, help="architecture of target network")
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--reg", type=float, default=0.3)

    # Graph Learning
    parser.add_argument("--algorithm", type=str, default='SFL', choices=['SFL', 'SFL_v2', 'SFL_v3'], help="Use structure learning")
    parser.add_argument('--subgraph_size', type=int, default=30, help='k')
    parser.add_argument('--adjalpha', type=float, default=3, help='adj alpha')
    parser.add_argument('--gc_epoch', type=int, default=10, help='')
    parser.add_argument('--adjbeta', type=float, default=0.05, help='update ratio')
    parser.add_argument("--full-adj", type=bool, default=True, help="Use full adj to aggregate weights.")

    #############################
    #       General args        #
    #############################
    parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
    parser.add_argument("--eval-every", type=int, default=30, help="eval every X selected epochs")
    parser.add_argument("--eval-batch-size", type=int, default=64, help="eval batch size")
    parser.add_argument("--eval-batch-count", type=int, default=999999, help="eval num batch")
    parser.add_argument("--eval-finetune-epochs", type=int, default=0)
    parser.add_argument("--eval-finetune-lr", type=float, default=1e-3, help="learning rate for eval finetune optimizer")
    parser.add_argument("--eval-finetune-wd", type=float, default=5e-5, help="weight decay for eval finetune optimizer")
    parser.add_argument("--num-workers", type=int, default=0, help="num workers for dataloader")
    parser.add_argument("--run-name", type=str, default=None, help="name for output run file")
    parser.add_argument("--run-dir", type=str, default='runs', help="dir path for output run file")
    parser.add_argument("--task-name", type=str, default=None, help="ClearML task name")
    parser.add_argument("--log-level", type=str, default='INFO', help="log level")
    parser.add_argument("--seed", type=int, default=42, help="seed value")

    args = parser.parse_args()
    if args.target_arch is None:
        if 'METR-LA' in args.data_name or 'PEMS-BAY' in args.data_name:
            args.target_arch = 'GRUSeq2Seq'
        elif args.data_name == 'CompCars':
            args.target_arch = 'resnet18'
        else:
            args.target_arch = 'mlp'
    assert args.gpu <= torch.cuda.device_count(), f"--gpu flag should be in range [0,{torch.cuda.device_count() - 1}]"

    if args.run_name is None:
        import time
        args.run_name = f"{time.strftime('%Y%m%d_%H-%M-%S')}_{args.data_name}_{args.algorithm}"
        if args.adv_ratio > 0:
            if args.aug_method == 'label_flip':
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}"
            elif args.aug_method != 'nominal' and args.sigma > 0:
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}_{args.sigma}"
            else:
                args.adv_ratio = 0
    args.run_name += f'_seed_{args.seed}'
    hparam_dict = vars(args)
    hparam_dict['run_path'] = os.path.join(args.run_dir, args.run_name)
    if not os.path.exists(hparam_dict['run_path']):
        os.makedirs(hparam_dict['run_path'], exist_ok=True)
    set_logger(os.path.join(hparam_dict['run_path'],'output.log'), level=args.log_level)
    logging.info(json.dumps(hparam_dict, sort_keys=True, indent=4))
    try:
        from clearml import Task
        if args.task_name is None:
            hparam_dict['task_name'] = args.run_name
        task = Task.init(project_name='Panacea NeurIPS 2023', task_name=hparam_dict['task_name'])
    except:
        task = None

    try:
        train(hparam_dict)
    except Exception:
        print(traceback.format_exc())
    if task is not None:
        task.close()
