import os
import argparse
import json
import logging
import random
from collections import OrderedDict, defaultdict, namedtuple
from pathlib import Path
import copy
import traceback
from functools import partial

import numpy as np
import torch
import torch.utils.data
import torch.nn.functional as F
from tqdm.auto import trange, tqdm
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, from_networkx
from torch_sparse import SparseTensor
import torchvision

from models import CNNTarget, GAEHyper, GNNHyper, GRUTarget, MLPTarget, MLPHyper, MLPInnerProductDecoder
from node import BaseNodes
from utils import get_device, set_logger, set_seed, freeze, unfreeze, detach_clone, CustomCosineAnnealingLR, sd_matrixing
from deform import DeformWrapper


def write_results(writer, prefix, results, step, split='test'):
    """Send evaluation statistics to the summary writer and log a one-line summary.

    Args:
        writer: Object exposing ``add_histogram(tag, values, step)`` and
            ``add_scalar(tag, value, step)`` (the training code passes a
            TensorBoard ``SummaryWriter``).
        prefix: Tag namespace; every entry is written as ``{prefix}/{key}``.
        results: Mapping from metric name to tensor, keyed the way
            ``eval_model`` builds it (``{split}_avg_loss``,
            ``clean_{split}_avg_loss``, ``finetune_{split}_avg_loss`` and the
            matching ``*_acc`` keys).  ``{split}_avg_loss`` is required.
        step: Step number attached to every record and to the log line.
        split: Split name embedded in the metric keys.  Defaults to
            ``'test'``, which reproduces the previously hard-coded keys.

    Raises:
        KeyError: If ``results`` has no ``{split}_avg_loss`` entry.
    """
    for key, value in results.items():
        tag = f'{prefix}/{key}'
        # Tensors holding more than one element (the per-node "all" entries)
        # are written as histograms, everything else as a scalar.
        if value.shape.numel() > 1:
            writer.add_histogram(tag, value, step)
        else:
            writer.add_scalar(tag, value, step)

    def _summary(metric):
        # Renders "avg", optionally followed by "(clean avg)" and
        # " -> finetuned avg" when those entries are present.
        avg_key = f'{split}_avg_{metric}'
        clean_key = f'clean_{split}_avg_{metric}'
        finetune_key = f'finetune_{split}_avg_{metric}'
        text = f"{results[avg_key]:.4f}"
        if clean_key in results:
            text += f"({results[clean_key]:.4f})"
        if finetune_key in results:
            text += f" -> {results[finetune_key]:.4f}"
        return text

    log_str = f"Step: {step}, AVG Loss: {_summary('loss')}"
    if f'{split}_avg_acc' in results:
        log_str += f", AVG Acc: {_summary('acc')}"
    logging.info(log_str)

def _summarize(results, prefix, metric, values):
    """Stack ``values`` and store them plus their mean/max/min/std in ``results``.

    Keys are written in the order ``{prefix}_all_{metric}``, ``_avg_``,
    ``_max_``, ``_min_``, ``_std_``.
    """
    stacked = torch.stack(values)
    results[f'{prefix}_all_{metric}'] = stacked
    results[f'{prefix}_avg_{metric}'] = torch.mean(stacked)
    results[f'{prefix}_max_{metric}'] = torch.max(stacked)
    results[f'{prefix}_min_{metric}'] = torch.min(stacked)
    results[f'{prefix}_std_{metric}'] = torch.std(stacked)


def eval_model(evaluator, node_list, adv_node_list=None, split='test'):
    """Run ``evaluator`` on ``node_list`` and aggregate the per-node metrics.

    Args:
        evaluator: Callable ``evaluator(node_list, split)`` returning a mapping
            ``node_id -> {'loss': tensor, ...}``.  Entries may also carry
            ``'correct'``/``'acc'`` and ``'finetune_loss'``/``'finetune_acc'``.
        node_list: Non-empty indexable sequence of node ids; the entry of
            ``node_list[0]`` decides which optional metric groups exist.
        adv_node_list: Node ids excluded from the ``clean_*`` and
            ``finetune_*`` statistics.  ``None`` (default) or an empty
            sequence means no node is excluded and no ``clean_*`` keys are
            produced.
        split: Split name forwarded to ``evaluator`` and embedded in the keys.

    Returns:
        dict mapping ``{split}_{all,avg,max,min,std}_{loss,acc}`` (and the
        ``clean_``/``finetune_`` prefixed variants when applicable) to tensors.
    """
    # Avoid a shared mutable default; None stands for "no adversarial nodes".
    if adv_node_list is None:
        adv_node_list = []

    results = {}
    curr_results = evaluator(node_list, split)

    _summarize(results, split, 'loss', [val['loss'] for val in curr_results.values()])
    first_entry = curr_results[node_list[0]]
    has_acc = 'correct' in first_entry
    if has_acc:
        _summarize(results, split, 'acc', [val['acc'] for val in curr_results.values()])

    has_adv = len(adv_node_list) > 0
    has_finetune = 'finetune_loss' in first_entry
    if not (has_adv or has_finetune):
        return results

    # Both the clean and the fine-tune statistics skip adversarial nodes, so
    # the filtered list is built once and shared.
    clean_vals = [v for k, v in curr_results.items() if k not in adv_node_list]

    if has_adv:
        _summarize(results, f'clean_{split}', 'loss', [val['loss'] for val in clean_vals])
        if has_acc:
            _summarize(results, f'clean_{split}', 'acc', [val['acc'] for val in clean_vals])

    if has_finetune:
        _summarize(results, f'finetune_{split}', 'loss', [val['finetune_loss'] for val in clean_vals])
        if has_acc:
            _summarize(results, f'finetune_{split}', 'acc', [val['finetune_acc'] for val in clean_vals])
    return results

def _forward(net, batch):
    """Run ``net`` on a batch: ``GRUTarget`` receives the whole tuple, other models only ``batch[0]``."""
    if isinstance(net, GRUTarget):
        return net(batch)
    return net(batch[0])


def _score_loader(net, criteria, loader, device, max_batches):
    """Accumulate loss (and accuracy counts) over at most ``max_batches`` batches.

    The caller is responsible for ``net.eval()`` and ``torch.no_grad()``.

    Returns:
        ``(avg_loss, correct, samples)``.  ``avg_loss`` is the summed batch loss
        divided by the number of batches actually scored; ``correct`` and
        ``samples`` stay at ``0.`` unless ``criteria`` is a ``CrossEntropyLoss``.
    """
    count_correct = isinstance(criteria, torch.nn.CrossEntropyLoss)
    running_loss, running_correct, running_samples = 0., 0., 0.
    scored = 0
    for batch in loader:
        if scored >= max_batches:
            break
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        label = batch[1]
        pred = _forward(net, batch)
        running_loss += criteria(pred, label)
        if count_correct:
            running_correct += pred.argmax(1).eq(label).sum()
            running_samples += len(label)
        scored += 1
    # Divide by the batches that were really scored.  Using the last
    # enumerate index + 1 over-counts by one whenever the cap stops the loop,
    # and is undefined for an empty loader.
    return running_loss / max(scored, 1), running_correct, running_samples


def evaluate(nodes: BaseNodes, hnet, net, criteria, device, eval_batch_count, finetune_epochs, finetune_lr, finetune_wd, node_list, split):
    """Evaluate the target network on every node, optionally after local fine-tuning.

    For each node the buffers of ``net`` are reset from ``nodes.global_buffers``
    and its parameters are replaced by the hypernetwork output for that node;
    parameters the hypernetwork does not produce fall back to
    ``nodes.global_weights``.  ``net`` is modified in place and keeps the state
    of the last evaluated node.

    Args:
        nodes: Provides ``train_loaders``/``val_loaders``/``test_loaders``
            (indexed by node id) plus ``global_weights`` and ``global_buffers``
            (indexed by parameter/buffer name).
        hnet: Hypernetwork called with a one-element long tensor holding the
            node id; its result is used as a mapping from parameter name to
            tensor.  If ``None``, every parameter comes from the global weights.
        net: Target network to evaluate.
        criteria: Loss module.  Accuracy is tracked only when it is a
            ``torch.nn.CrossEntropyLoss``.
        device: Device the batches and the node-id tensor are moved to.
        eval_batch_count: Maximum number of batches scored per node.
        finetune_epochs: Epochs of SGD on the node's training loader before a
            second evaluation; ``0`` disables fine-tuning.
        finetune_lr: Learning rate of the fine-tuning optimizer.
        finetune_wd: Weight decay of the fine-tuning optimizer (``None`` is
            treated as ``0``).
        node_list: Iterable of node ids to evaluate.
        split: ``'test'``, ``'val'``, or anything else for the training loader.

    Returns:
        ``defaultdict`` mapping node id to a dict with ``'loss'`` and, for
        cross-entropy, ``'correct'``, ``'total'``, ``'acc'``; the same keys
        prefixed with ``finetune_`` are added when fine-tuning ran.
    """
    if hnet is not None:
        hnet.eval()
    net.eval()
    results = defaultdict(lambda: defaultdict(list))
    is_classification = isinstance(criteria, torch.nn.CrossEntropyLoss)
    if split == 'test':
        loaders = nodes.test_loaders
    elif split == 'val':
        loaders = nodes.val_loaders
    else:
        loaders = nodes.train_loaders

    node_iter = tqdm(node_list, position=-1, leave=False)
    for node_id in node_iter:  # iterating over nodes
        # Start every node from the shared buffers.
        for n, p in net.named_buffers():
            p.data = nodes.global_buffers[n].detach().clone()
        curr_data = loaders[node_id]
        node_result = results[node_id]

        with torch.no_grad():
            if hnet is not None:
                weights = hnet(torch.tensor([node_id], device=device, dtype=torch.long))
            else:
                # No hypernetwork: fall back to the global weights for everything.
                weights = {}
            for n, p in net.named_parameters():
                if n in weights:
                    p.data = weights[n]
                else:
                    p.data = nodes.global_weights[n].detach().clone()
            loss, correct, total = _score_loader(net, criteria, curr_data, device, eval_batch_count)

        node_result['loss'] = loss
        tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {node_result['loss']:.4f}"
        if is_classification:
            node_result['correct'] = correct
            node_result['total'] = total
            # max(..., 1) only matters when no batch was scored.
            node_result['acc'] = correct / max(total, 1)
            tqdm_str += f", AVG Acc: {node_result['acc']}"

        if finetune_epochs > 0:
            net.train()
            inner_optim = torch.optim.SGD(
                net.parameters(),
                lr=finetune_lr,
                # finetune_wd used to be accepted but silently ignored.
                weight_decay=0.0 if finetune_wd is None else finetune_wd,
            )
            for _ in range(finetune_epochs):
                for batch in nodes.train_loaders[node_id]:
                    inner_optim.zero_grad(set_to_none=True)
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    label = batch[1]
                    pred = _forward(net, batch)
                    ft_loss = criteria(pred, label)
                    ft_loss.backward()
                    inner_optim.step()
            inner_optim.zero_grad(set_to_none=True)
            del inner_optim
            net.eval()
            with torch.no_grad():
                loss, correct, total = _score_loader(net, criteria, curr_data, device, eval_batch_count)

            node_result['finetune_loss'] = loss
            tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {node_result['loss']:.4f} -> {node_result['finetune_loss']:.4f}"
            if is_classification:
                node_result['finetune_correct'] = correct
                node_result['finetune_total'] = total
                node_result['finetune_acc'] = correct / max(total, 1)
                tqdm_str += f", AVG Acc: {node_result['acc']:.4f} -> {node_result['finetune_acc']:.4f}"
        node_iter.set_description(tqdm_str)
    return results


def train(hparam_dict: dict) -> None:

    hparam_dict['algorithm'] = f"pfedhn_{hparam_dict['hyper_arch']}"
    hparam_dict['run_path'] = os.path.join(hparam_dict['run_dir'], hparam_dict['run_name'])
    if not os.path.exists(hparam_dict['run_path']):
        os.makedirs(hparam_dict['run_path'], exist_ok=True)
    set_logger(os.path.join(hparam_dict['run_path'],'output.log'), level=hparam_dict['log_level'])
    set_seed(hparam_dict['seed'])
    device = get_device(gpus=hparam_dict['gpu'])
    logging.info(json.dumps(hparam_dict, sort_keys=True, indent=4))
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.tensorboard.summary import hparams
    writer = SummaryWriter(hparam_dict['run_path'])
    ###############################
    # init nodes, hnet, local net #
    ###############################
    nodes = BaseNodes(hparam_dict['data_name'], batch_size=hparam_dict['batch_size'], eval_batch_size=hparam_dict['eval_batch_size'], num_workers=hparam_dict['num_workers'], device=device, data_dir=hparam_dict['data_dir'])
    num_nodes = nodes.n_nodes
    all_nodes = np.arange(num_nodes)
    client_sample_rng = np.random.default_rng(hparam_dict['seed'])
    if 'METR-LA-' in hparam_dict['data_name'] or 'PEMS-BAY-' in hparam_dict['data_name']:
        new_order = all_nodes
        train_nodes = nodes.train_nodes
        num_train = len(train_nodes)
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = list(set(all_nodes) - set(train_nodes))
        train_edge_index = nodes.train_edge_index
        data = Data(x=torch.tensor(all_nodes, device=device, dtype=torch.long), edge_index=nodes.eval_edge_index.to(device))
    else:
        new_order = nodes.shuffle()
        num_train = int(0.8*num_nodes)
        train_nodes = all_nodes[:num_train]
        adv_node_list = client_sample_rng.choice(train_nodes, int(num_train*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = all_nodes[num_train:]
        G = nx.from_numpy_array(nodes.A)
        data = from_networkx(G).to(device)
        train_edge_index, train_edge_attr = subgraph(torch.tensor(train_nodes, dtype=torch.long, device=device), data.edge_index, num_nodes=num_nodes)
    num_clients = min(len(train_nodes), hparam_dict['num_clients'])
    logging.info(f"Num of nodes: {len(all_nodes)}")
    logging.info(f"Num of train nodes: {len(train_nodes)}")
    logging.info(f"adv_node_list: {adv_node_list}")
    coef_hyper = hparam_dict['coef_hyper']
    coef_reg = hparam_dict['coef_reg']
    train_data = Data(x=torch.tensor(all_nodes, device=device, dtype=torch.long), edge_index=train_edge_index)
    sparse_adj = SparseTensor(row=train_edge_index[0], col=train_edge_index[1], sparse_sizes=(num_nodes, num_nodes)).t()
    if hparam_dict['embed_dim'] == -1:
        embed_dim = int(1 + num_nodes / 4)
        logging.info(f"Auto embedding size: {embed_dim}")
        hparam_dict['embed_dim'] = embed_dim
    else:
        embed_dim = hparam_dict['embed_dim']
    if 'METR-LA' in hparam_dict['data_name'] or 'PEMS-BAY' in hparam_dict['data_name']:
        net = GRUTarget(nodes.data_dim, out_dim=nodes.label_dim)
    elif hparam_dict['data_name'] == 'CompCars':
        CNN = getattr(torchvision.models, hparam_dict['target_arch'])
        net = CNN(pretrained=True)
        final_layer_name, final_layer = list(net.named_children())[-1]
        setattr(net, final_layer_name, torch.nn.Linear(final_layer.in_features, nodes.label_dim))
    else:
        net = MLPTarget(nodes.data_dim, out_dim=nodes.label_dim)
    net.to(device)
    final_layer_name, final_layer = list(net.named_children())[-1]

    if hparam_dict['last_layer_only']:
        net_struct = dict([(f'{final_layer_name}.{k}', v.shape) for k, v in final_layer.named_parameters() if 'bn' not in k])
    else:
        net_struct = dict([(k, v.shape) for k, v in net.named_parameters() if 'bn' not in k])

    if hparam_dict['hyper_arch'] == 'gnn':
        hnet = GNNHyper(train_data, num_nodes, embed_dim, hparam_dict['hyper_hid'], net_struct).to(device)
    elif hparam_dict['hyper_arch'] == 'gae':
        # hnet = GAEHyper(train_data, num_nodes, embed_dim, hparam_dict['hyper_hid'], net_struct, decoder=MLPInnerProductDecoder(hparam_dict['hyper_hid'], 3)).to(device)
        hnet = GAEHyper(train_data, num_nodes, embed_dim, hparam_dict['hyper_hid'], net_struct).to(device)
        coef_recon = hparam_dict['coef_recon']
    elif hparam_dict['hyper_arch'] == 'mlp':
        hnet = MLPHyper(num_nodes, embed_dim, hparam_dict['hyper_hid'], net_struct, n_hidden=3).to(device)
    else:
        raise(f"Unknown hypernetwork architecture {hparam_dict['hyper_arch']}.")
    if hparam_dict['freeze_embed']:
        freeze(hnet.embeddings)
    
    deform = DeformWrapper(device, hparam_dict['aug_method'], nodes.label_dim, hparam_dict['sigma'])
    if hparam_dict['data_name'] in ['DG15', 'DG60', 'DG60_TRUNCNORM', 'CompCars']:
        criteria = torch.nn.CrossEntropyLoss()
    elif hparam_dict['data_name'] in ['TPT48', 'TPT48_TRUNCNORM', 'METR-LA', 'METR-LA-0.25', 'METR-LA-0.5', 'METR-LA-0.75', 'METR-LA-0.05', 'METR-LA-0.9', 'PEMS-BAY', 'PEMS-BAY-0.25', 'PEMS-BAY-0.5', 'PEMS-BAY-0.75', 'PEMS-BAY-0.05', 'PEMS-BAY-0.9']:
        criteria = torch.nn.MSELoss()
    else:
        raise f"Unknown dataset {hparam_dict['data_name']}."
    
    metric_dict = {
        'Unseen_0/recon_loss': np.nan,
        'Unseen_0/test_avg_loss': np.nan,
        'Unseen_0/test_max_loss': np.nan,
        'Unseen_0/test_min_loss': np.nan,
        'Unseen_0/test_std_loss': np.nan,
        'Unseen_0/test_avg_acc': np.nan,
        'Unseen_0/test_max_acc': np.nan,
        'Unseen_0/test_min_acc': np.nan,
        'Unseen_0/test_std_acc': np.nan,
    }

    ##################
    # init optimizer #
    ##################
    optimizers = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam
    }
    init_optimizer = lambda x,y: optimizers[x](
        [
            {'params': [p for n, p in y.named_parameters() if 'mlps' in n]},
            {'params': [p for n, p in y.named_parameters() if 'mlps' not in n], 'lr': hparam_dict['embed_lr']}
        ], lr=hparam_dict['lr']
    )

    if hparam_dict['pretrain_targetnet_steps'] > 0:
        net.train()
        step_iter = trange(hparam_dict['pretrain_targetnet_steps'])
        inner_optim = torch.optim.SGD(
                net.parameters(), lr=hparam_dict['inner_lr'], momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )
        # node_id = train_nodes[0]
        node_id = random.choice(train_nodes)
        for step in step_iter:
            inner_optim.zero_grad(set_to_none=True)

            batch = next(iter(nodes.train_loaders[node_id]))
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            label = batch[1]
            if isinstance(net, GRUTarget):
                pred = net(batch)
            else:
                pred = net(batch[0])

            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), hparam_dict['clip_grad'])
            inner_optim.step()
            
            with torch.no_grad():
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    acc = torch.tensor(0)
            step_iter.set_description(
                f"Pre-train Target NN Step: {step}, Node ID: {node_id}, Loss: {loss.item():.4f}, Acc: {acc:.4f}"
            )
            # writer.add_scalar('Pre-Train/loss_target', loss, step)
            # writer.add_scalar('Pre-Train/acc_target', acc, step)
        net.eval()

    nodes.global_weights = detach_clone(OrderedDict(net.named_parameters()))
    nodes.global_buffers = detach_clone(OrderedDict(net.named_buffers()))
    if hparam_dict['pretrain_hypernet_steps'] > 0:
        net.eval()
        hnet.train()
        freeze(hnet)
        unfreeze(hnet.mlps)
        pretrain_optim = init_optimizer('adam', hnet.mlps)
        pretrain_step_iter = trange(hparam_dict['pretrain_hypernet_steps'])
        recon_loss = 0
        global_weights = sd_matrixing(nodes.global_weights, net_struct.keys())
        for pretrain_step in pretrain_step_iter:
            pretrain_optim.zero_grad(set_to_none=True)
            loss = 0
            # node_id = random.choice(train_nodes)
            node_ids = np.random.choice(train_nodes, num_clients, replace=False)
            # calculating phi gradient
            z = hnet.encode()
            if isinstance(hnet, GAEHyper):
                if hparam_dict['random_walk_length'] > 0:
                    node_id = random.choice(node_ids)
                    sample_nodes = sparse_adj.random_walk(torch.tensor([node_id], device=device, dtype=torch.long), hparam_dict['random_walk_length']).unique()
                    sample_edge_index, _ = subgraph(sample_nodes, train_edge_index)
                    recon_loss = hnet.recon_loss(z, sample_edge_index)
                else:
                    recon_loss = hnet.recon_loss(z)
                # writer.add_scalar('Pre-Train/recon_loss', recon_loss, pretrain_step)
                loss += recon_loss * coef_recon
            weights = hnet.mlps(z[node_ids])
            # calculating phi gradient
            loss_reg = F.mse_loss(weights, global_weights.expand_as(weights), reduction='sum') / num_clients
            # writer.add_scalar('Pre-Train/loss_reg', loss_reg, pretrain_step)
            loss += loss_reg * coef_hyper
            loss.backward()
            torch.nn.utils.clip_grad_norm_(hnet.parameters(), hparam_dict['clip_grad'])
            pretrain_optim.step()
            pretrain_step_iter.set_description(
                f"Pre-train HN Step: {pretrain_step}, Loss_reg: {loss_reg:.6f}, Loss_recon: {recon_loss:.4f}"
            )
            if loss_reg.item() < hparam_dict['pretrain_hypernet_threshold']:
                break
        pretrain_optim.zero_grad(set_to_none=True)
        del pretrain_optim
        logging.info(f"Pre-train HN Step: {pretrain_step}, Loss_reg: {loss_reg:.6f}, Loss_recon: {recon_loss:.4f}")
        unfreeze(hnet)
    ################
    # init metrics #
    ################
    
    if isinstance(hnet, GAEHyper):
        with torch.no_grad():
            hnet.eval()
            z = hnet.encode()
            recon_loss = hnet.recon_loss(z)
            hnet.train()
    else:
        recon_loss = np.nan
    loss_GH = np.nan
    loss_norm = 1
    # for node_id in train_nodes:
    #     nodes.weights[node_id] = detach_clone(nodes.global_weights)
    #     nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
    steps = hparam_dict['num_steps']
    step_iter = trange(steps, position=0)
    optimizer = init_optimizer(hparam_dict['optim'], hnet)
    inner_lr_scheduler = CustomCosineAnnealingLR(hparam_dict['inner_lr'], T_max=hparam_dict['num_steps'], eta_min=hparam_dict['inner_lr_min'])
    inner_lr_scheduler.step()
    evaluator = partial(evaluate, nodes, hnet, net, criteria, device, hparam_dict['eval_batch_count'], hparam_dict['eval_finetune_epochs'], hparam_dict['eval_finetune_lr'], hparam_dict['eval_finetune_wd'])
    for step in step_iter:
        if step % hparam_dict['eval_every'] == 0:
            results = eval_model(evaluator, train_nodes, adv_node_list)
            write_results(writer, 'Train', results, step)
        hnet.eval()
        # select client at random
        node_ids = client_sample_rng.choice(train_nodes, num_clients, replace=False)
        logging.debug(f"node_ids: {node_ids}")
        # produce & load local network weights
        with torch.no_grad():
            node_zs = hnet.encode(torch.from_numpy(node_ids).long().to(device))
        # node_iter = tqdm(zip(node_ids, node_zs), total=num_clients, position=1, leave=False)
        node_iter = zip(node_ids, node_zs)
        node_results = []
        for node_id, node_z in node_iter:
            logging.debug(f"Generate weights for client {node_id}.")
            # NOTE: Local buffer
            if node_id not in nodes.buffers or not hparam_dict['local_buffer']:
                buffers = detach_clone(nodes.global_buffers)

            if not nodes.joined[node_id]:
                weights = {}
            else:
                with torch.no_grad():
                    weights = hnet.generate_weights(node_z)
            nodes.joined[node_id] = True
            # NOTE: change the pointer of net params to nodes.weights, so the optimizer will directly update nodes.weights and there is no need to detach and clone.  Use global weights for the weights not generated by hypernetworks
            for n,p in net.named_parameters():
                if n not in weights:
                    weights[n] = nodes.global_weights[n].detach().clone()
                p.data = weights[n]
            for n,p in net.named_buffers():
                if n not in buffers:
                    buffers[n] = nodes.global_buffers[n].detach().clone()
                p.data = buffers[n]

            # NOTE: evaluation on sent model
            logging.debug(f"Eval test set on client {node_id}.")
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                prvs_loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    prvs_acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    prvs_acc = torch.tensor(0)

            # NOTE: init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )

            net.train()
            # NOTE: inner updates -> obtaining theta_tilda
            batch_count = 0
            if hparam_dict['inner_epochs'] is None:
                inner_steps = hparam_dict['inner_steps']
                inner_epochs = inner_steps // len(nodes.train_loaders[node_id]) + 1
            else:
                inner_epochs = hparam_dict['inner_epochs']
                inner_steps = inner_epochs * len(nodes.train_loaders[node_id])
            for e in range(inner_epochs):
                for batch in nodes.train_loaders[node_id]:
                    batch_count += 1
                    if batch_count > inner_steps:
                        break
                    inner_optim.zero_grad(set_to_none=True)
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    if node_id in adv_node_list:
                        batch = deform(batch)
                    label = batch[1]
                    if isinstance(net, GRUTarget):
                        pred = net(batch)
                    else:
                        pred = net(batch[0])

                    loss = criteria(pred, label)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), hparam_dict['clip_grad'])
                    inner_optim.step()
            inner_optim.zero_grad(set_to_none=True)
            del inner_optim
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    acc = torch.tensor(0)
            nodes.weights[node_id] = weights
            nodes.buffers[node_id] = buffers
            node_results += [[prvs_loss.item(), loss.item(), prvs_acc.item(), acc.item()]]
            # node_iter.set_description(
            #     f"Node ID: {node_id}, Loss: {prvs_loss:.4f} -> {loss:.4f},  Acc: {prvs_acc:.4f} -> {acc:.4f}"
            # )
        prvs_loss, node_loss, prvs_acc, acc = np.mean(node_results, axis=0)
        with torch.no_grad():
            for n,p in net.named_parameters():
                nodes.global_weights[n] = torch.stack([nodes.weights[node_id][n] for node_id in node_ids]).mean(dim=0)

            # NOTE: do not update global_buffer when using local buffer.
            if not hparam_dict['local_buffer']:
                for n,p in net.named_buffers():
                    nodes.global_buffers[n] = torch.stack([nodes.buffers[node_id][n] for node_id in node_ids]).float().mean(dim=0).type(nodes.global_buffers[n].dtype)
                nodes.buffers = {}

        # update hnet weights
        logging.debug(f"Update HyperNetwork with gradient received from client {node_id}.")
        hnet.train()
        node_weights = torch.stack([sd_matrixing(nodes.weights[node_id], net_struct.keys()) for node_id in node_ids])
        global_weights = sd_matrixing(nodes.global_weights, net_struct.keys())
        with torch.no_grad():
            loss_norm = F.mse_loss(node_weights, global_weights.expand_as(node_weights), reduction='sum') / num_clients
        for i in range(hparam_dict['server_steps']):
            optimizer.zero_grad(set_to_none=True)
            loss = 0
            z = hnet.encode()
            if isinstance(hnet, GAEHyper):
                if hparam_dict['random_walk_length'] > 0:
                    sample_nodes = sparse_adj.random_walk(torch.tensor([node_id], device=device, dtype=torch.long), hparam_dict['random_walk_length']).unique()
                    sample_edge_index, _ = subgraph(sample_nodes, train_edge_index)
                    recon_loss = hnet.recon_loss(z, sample_edge_index)
                else:
                    recon_loss = hnet.recon_loss(z)
                loss += recon_loss * coef_recon
            
            weights = hnet.mlps(z)
            loss_GH = F.mse_loss(weights[node_ids], node_weights, reduction='sum') / num_clients
            loss_reg = F.mse_loss(weights, global_weights.expand_as(weights), reduction='sum') / num_train
            loss += (loss_GH * coef_hyper + loss_reg * coef_reg)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(hnet.parameters(), hparam_dict['clip_grad'])
            optimizer.step()
            if i == 0:
                org_recon_loss = recon_loss
                org_loss_GH = loss_GH
                org_loss_reg = loss_reg
            step_iter.set_description(
                f"Step: {step}, Loss: {prvs_loss:.4f} -> {node_loss:.4f},  Acc: {prvs_acc:.4f} -> {acc:.4f}, Loss_GH: {org_loss_GH:.6f} -> {loss_GH:.6f}, Loss_reg: {org_loss_reg:.6f} -> {loss_reg:.6f}, Loss_norm: {loss_norm:.6f}, Loss_recon: {org_recon_loss:.4f} -> {recon_loss:.4f}, Server_Step: {i}"
            )
        nodes.weights = {}
        hnet.eval()
        writer.add_scalar('Train/org_recon_loss', org_recon_loss, step)
        writer.add_scalar('Train/org_loss_GH', org_loss_GH, step)
        writer.add_scalar('Train/org_loss_reg', org_loss_reg, step)
        writer.add_scalar('Train/recon_loss', recon_loss, step)
        writer.add_scalar('Train/loss_GH', loss_GH, step)
        writer.add_scalar('Train/loss_reg', loss_reg, step)
        writer.add_scalar('Train/loss_norm', loss_norm, step)
        writer.add_scalar('Train/prvs_loss', prvs_loss, step)
        writer.add_scalar('Train/node_loss', node_loss, step)
        writer.add_scalar('Train/prvs_acc', prvs_acc, step)
        writer.add_scalar('Train/node_acc', acc, step)
        if step % hparam_dict['eval_every'] == 0:
            logging.info(
                f"Step: {step}, Loss: {prvs_loss:.4f} -> {node_loss:.4f},  Acc: {prvs_acc:.4f} -> {acc:.4f}, Loss_GH: {org_loss_GH:.6f} -> {loss_GH:.6f}, Loss_reg: {org_loss_reg:.6f} -> {loss_reg:.6f}, Loss_norm: {loss_norm:.6f}, Loss_recon: {org_recon_loss:.4f} -> {recon_loss:.4f}, Server_Step: {i}"
            )
        # NOTE: only update inner learning rate when inner_lr_scheduler is enabled.
        if hparam_dict['inner_lr_scheduler']:
            inner_lr_scheduler.step()
        
    step += 1
    logging.info(f"Step: {step}, Loss_GH: {loss_GH:.6f}, Loss_reg: {loss_reg:.6f}, Loss_recon: {recon_loss:.4f}")
    results = eval_model(evaluator, train_nodes, adv_node_list)
    write_results(writer, 'Train', results, step)
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'hnet': hnet.state_dict(),
        'net': net.state_dict(),
        'optim': optimizer.state_dict(),
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'],'final.ckpt'))
    # Average neighbors' embeddings
    if hparam_dict['avg_emb']:
        with torch.no_grad():
            for node in unseen_nodes:
                neighbors = nx.neighbors(G, node)
                neighbors = list(set(neighbors)&set(train_nodes))
                hnet.embeddings.weight.data[node] = hnet.embeddings.weight.data[neighbors].mean(0)

    # Change graph data in hypernetwork
    hnet.x = torch.tensor(all_nodes, device=device, dtype=torch.long)
    hnet.edge_index = data.edge_index.to(device)
    sparse_adj = SparseTensor(row=data.edge_index[0], col=data.edge_index[1], sparse_sizes=(num_nodes,num_nodes)).t()
    
    # If freeze, the optimizer will only update the embedding layer.
    if hparam_dict['freeze_hypernet']:
        freeze(hnet.mlps)
    
    # Log Zero-Shot performance
    step = 0
    results = eval_model(evaluator, unseen_nodes)
    write_results(writer, 'Unseen_0', results, step)
    best_metric_dict = {
        'Unseen_0/test_avg_loss': results['test_avg_loss'],
        'Unseen_0/test_max_loss': results['test_max_loss'],
        'Unseen_0/test_min_loss': results['test_min_loss'],
        'Unseen_0/test_std_loss': results['test_std_loss']
    }
    if isinstance(criteria, torch.nn.CrossEntropyLoss):
        best_metric_dict.update({
            'Unseen_0/test_avg_acc': results['test_avg_acc'],
            'Unseen_0/test_max_acc': results['test_max_acc'],
            'Unseen_0/test_min_acc': results['test_min_acc'],
            'Unseen_0/test_std_acc': results['test_std_acc']
        })
    metric_dict.update(best_metric_dict)
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k,v in best_metric_dict.items():
        writer.add_scalar(k, v, 0)

    num_clients = min(len(unseen_nodes), hparam_dict['num_clients'])
    step_iter = trange(hparam_dict['eval_unseen_steps'])
    for step in step_iter:
        hnet.eval()
        if hparam_dict['eval_unseen_steps'] == 1:
            node_ids = unseen_nodes
        else:
            node_ids = client_sample_rng.choice(unseen_nodes, num_clients, replace=False)
        node_zs = hnet.encode(torch.from_numpy(node_ids).long().to(device))
        for node_id, node_z in zip(node_ids, node_zs):
            logging.debug(f"Generate weights for client {node_id}.")
            nodes.weights[node_id] = hnet.generate_weights(node_z)
            for n,p in net.named_parameters():
                if n in nodes.weights[node_id]:
                    p.data = nodes.weights[node_id][n].data

            # init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )

            # NOTE: evaluation on sent model
            logging.debug(f"Eval test set on client {node_id}.")
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                prvs_loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    prvs_acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    prvs_acc = torch.tensor(0)
                net.train()

            # inner updates -> obtaining theta_tilda
            logging.debug(f"Start local training on client {node_id}.")
            for e in range(hparam_dict['inner_epochs']):
                for batch_count, batch in enumerate(nodes.train_loaders[node_id]):
                    net.train()
                    inner_optim.zero_grad(set_to_none=True)

                    # batch = next(iter(nodes.train_loaders[node_id]))
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    label = batch[1]
                    if isinstance(net, GRUTarget):
                        pred = net(batch)
                    else:
                        pred = net(batch[0])

                    loss = criteria(pred, label)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), hparam_dict['clip_grad'])
                    inner_optim.step()
            inner_optim.zero_grad(set_to_none=True)
            del inner_optim
            nodes.weights[node_id] = detach_clone(net.state_dict())

        # update hnet weights
        logging.debug(f"Update HyperNetwork with gradient received from client {node_id}.")
        hnet.train()
        node_weights = torch.stack([sd_matrixing(nodes.weights[node_id], net_struct.keys()) for node_id in node_ids])
        global_weights = sd_matrixing(nodes.global_weights, net_struct.keys())
        for i in range(hparam_dict['server_steps']):
            optimizer.zero_grad(set_to_none=True)
            loss = 0
            z = hnet.encode()
            if isinstance(hnet, GAEHyper):
                if hparam_dict['random_walk_length'] > 0:
                    sample_nodes = sparse_adj.random_walk(torch.tensor([node_id], device=device, dtype=torch.long), hparam_dict['random_walk_length']).unique()
                    sample_edge_index, _ = subgraph(sample_nodes, hnet.edge_index)
                    recon_loss = hnet.recon_loss(z, sample_edge_index)
                else:
                    recon_loss = hnet.recon_loss(z)
                writer.add_scalar('Unseen/recon_loss', recon_loss, step)
                loss += recon_loss * coef_recon
            
            weights = hnet.mlps(z)
            loss_GH = F.mse_loss(weights[node_ids], node_weights, reduction='sum') / num_clients
            loss_reg = F.mse_loss(weights, global_weights.expand_as(weights), reduction='sum') / num_train
            loss += loss_GH * coef_hyper + loss_reg * coef_reg
            loss.backward()
            torch.nn.utils.clip_grad_norm_(hnet.parameters(), hparam_dict['clip_grad'])
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        hnet.eval()
        writer.add_scalar('Unseen/loss_GH', loss_GH, step)
        writer.add_scalar('Unseen/loss_reg', loss_reg, step)
        logging.debug(f"Gradient applied.")
        step_iter.set_description(
            f"Unseen Step: {step+1}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}, Loss_GH: {loss_GH:.6f}, Loss_reg: {loss_reg:.6f}, Loss_recon: {recon_loss:.4f}"
        )

        results = eval_model(evaluator, unseen_nodes)
        write_results(writer, 'Unseen', results, step+1)
    writer.close()

if __name__ == '__main__':
    # Script entry point: parse the command line, derive dependent defaults,
    # build the run name, optionally attach a ClearML task, then run train().
    import sys

    parser = argparse.ArgumentParser(
        description="Federated Hypernetwork with Lookahead experiment"
    )

    #############################
    #       Dataset Args        #
    #############################

    parser.add_argument("--data-name", type=str, default="DG15")
    parser.add_argument("--data-dir", type=str, default="data", help="dir name of dataset")

    ##################################
    #       Optimization args        #
    ##################################

    parser.add_argument("--pretrain-targetnet-steps", type=int, default=0)
    parser.add_argument("--pretrain-hypernet-steps", type=int, default=500)
    parser.add_argument("--pretrain-hypernet-threshold", type=float, default=0.001)
    parser.add_argument("--num-steps", type=int, default=800)
    parser.add_argument("--eval-unseen-steps", type=int, default=0)
    parser.add_argument("--optim", type=str, default='adam', choices=['adam', 'sgd'], help="optimizer")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--inner-epochs", type=int, default=None, help="number of inner epochs")
    parser.add_argument("--inner-steps", type=int, default=50, help="number of inner steps")
    parser.add_argument("--server-steps", type=int, default=5, help="number of server steps")
    parser.add_argument("--last-layer-only", action='store_true', default=False, help="Only personalize the last layer.")
    parser.add_argument("--num-clients", type=int, default=5)
    parser.add_argument("--local-buffer", action='store_true', default=False, help="Clients will use its own buffer (of batch norm layers).")
    parser.add_argument('--aug-method', type=str, default='label_flip', help='type of augmentation for training')
    parser.add_argument("--adv-ratio", type=float, default=.0, help="Ratio of adversarial clients participated FL training.")
    parser.add_argument("--sigma", type=float, default=.0, help="noise hyperparameter")


    ################################
    #       Model Prop args        #
    ################################
    parser.add_argument("--n-hidden", type=int, default=3, help="num. hidden layers")
    parser.add_argument("--inner-lr", type=float, default=1e-2, help="learning rate for inner optimizer")
    parser.add_argument("--inner-lr-min", type=float, default=1e-3, help="Min learning rate for inner optimizer")
    parser.add_argument("--inner-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for inner optimizer")
    parser.add_argument("--inner-wd", type=float, default=5e-5, help="inner weight decay")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument("--lr-min", type=float, default=1e-4, help="Min learning rate")
    parser.add_argument("--wd", type=float, default=5e-5, help="weight decay")
    parser.add_argument("--embed-dim", type=int, default=100, help="embedding dim")
    parser.add_argument("--embed-lr", type=float, default=None, help="embedding learning rate")
    parser.add_argument("--hyper-arch", type=str, default='mlp', choices=['mlp', 'gnn','gae'], help="architecture of hypernetwork")
    parser.add_argument("--hyper-hid", type=int, default=100, help="hypernet hidden dim")
    parser.add_argument("--hyper-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for hyper optimizer")
    parser.add_argument("--server-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for hyper optimizer at each round")
    parser.add_argument("--coef-recon", type=float, default=1.0, help="coefficient of reconstruction loss")
    parser.add_argument("--coef-hyper", type=float, default=1.0, help="coefficient of hypernetwork gradient")
    parser.add_argument("--coef-reg", type=float, default=1.0, help="coefficient of hypernetwork reguralizer")
    parser.add_argument("--target-arch", type=str, default=None, help="architecture of target network")
    parser.add_argument("--clip-grad", type=float, default=1, help="clip gradiant before apply")

    ################################
    #     Generalization args      #
    ################################
    parser.add_argument("--avg-emb", action='store_true', default=False, help="Initialize new client's embedding by averaging its neighbors' embeddings.")
    parser.add_argument("--pretrain-emb-steps", type=int, default=0, help="Pre-train client's embedding with GAE reconstruction loss.")
    parser.add_argument("--unseen-pretrain-emb-steps", type=int, default=0, help="Pre-train unseen client's embedding with GAE reconstruction loss.")
    parser.add_argument("--freeze-hypernet", action='store_true', default=False, help="Freeze hypernetwork's parameters and only train client's embedding.")
    parser.add_argument("--random-walk-length", type=int, default=0, help="Sample subgraph by random walk.")
    parser.add_argument("--freeze-embed", action='store_true', default=False, help="Freeze client's embedding and only train hypernetwork's parameters.")

    #############################
    #       General args        #
    #############################
    parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
    parser.add_argument("--eval-every", type=int, default=30, help="eval every X selected epochs")
    parser.add_argument("--eval-batch-size", type=int, default=64, help="eval batch size")
    parser.add_argument("--eval-batch-count", type=int, default=999999, help="eval num batch")
    parser.add_argument("--eval-finetune-epochs", type=int, default=0)
    parser.add_argument("--eval-finetune-lr", type=float, default=1e-3, help="learning rate for eval finetune optimizer")
    parser.add_argument("--eval-finetune-wd", type=float, default=5e-5, help="weight decay for eval finetune optimizer")
    parser.add_argument("--num-workers", type=int, default=0, help="num workers for dataloader")
    parser.add_argument("--run-name", type=str, default=None, help="name for output run file")
    parser.add_argument("--run-dir", type=str, default='runs', help="dir path for output run file")
    parser.add_argument("--task-name", type=str, default=None, help="ClearML task name")
    parser.add_argument("--log-level", type=str, default='INFO', help="log level")
    parser.add_argument("--seed", type=int, default=42, help="seed value")

    args = parser.parse_args()

    # The embedding learning rate falls back to the main learning rate.
    args.embed_lr = args.embed_lr if args.embed_lr is not None else args.lr
    # The reconstruction coefficient is forced to 0 unless the GAE hypernetwork is used.
    args.coef_recon = args.coef_recon if args.hyper_arch == 'gae' else 0
    # Pick a target architecture from the dataset name when none is given.
    if args.target_arch is None:
        if 'METR-LA' in args.data_name or 'PEMS-BAY' in args.data_name:
            args.target_arch = 'GRUSeq2Seq'
        elif args.data_name == 'CompCars':
            args.target_arch = 'resnet18'
        else:
            args.target_arch = 'mlp'

    # Validate --gpu with an explicit check (an assert disappears under -O).
    # The highest valid index is device_count - 1; on a host without CUDA
    # devices only the default index 0 is accepted, as before.
    max_gpu = max(torch.cuda.device_count() - 1, 0)
    if args.gpu > max_gpu:
        parser.error(f"--gpu flag should be in range [0,{max_gpu}]")

    if args.run_name is None:
        import time
        args.run_name = f"{time.strftime('%Y%m%d_%H-%M-%S')}_{args.data_name}_pfedhn_{args.hyper_arch}"
        if args.adv_ratio > 0:
            if args.aug_method == 'label_flip':
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}"
            elif args.aug_method != 'nominal' and args.sigma > 0:
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}_{args.sigma}"
            else:
                # NOTE(review): adv_ratio is reset only on this auto-naming
                # path; passing --run-name explicitly skips the reset.
                # Confirm whether that asymmetry is intended.
                args.adv_ratio = 0
    # The seed suffix is appended to user-supplied run names as well.
    args.run_name += f'_seed_{args.seed}'

    # vars() returns the namespace's own dict, so writes to hparam_dict are
    # visible through args and vice versa.
    hparam_dict = vars(args)

    # ClearML tracking is best-effort: any failure to import or initialise it
    # disables tracking instead of aborting the run.
    try:
        from clearml import Task
        if args.task_name is None:
            hparam_dict['task_name'] = args.run_name
        task = Task.init(project_name='Panacea NeurIPS 2023', task_name=hparam_dict['task_name'])
    except Exception as exc:
        # Not a bare except, so Ctrl-C / SystemExit still propagate.
        print(f"ClearML task tracking disabled: {exc!r}", file=sys.stderr)
        task = None

    failed = False
    try:
        train(hparam_dict)
    except Exception:
        print(traceback.format_exc())
        failed = True
    finally:
        # Close the task even when training is interrupted.
        if task is not None:
            task.close()
    if failed:
        # Report the failure to the caller instead of exiting with status 0.
        sys.exit(1)
