import os
import argparse
import json
import logging
from platform import node
import random
from collections import OrderedDict, defaultdict
from pathlib import Path
import copy
import traceback
from functools import partial
from typing import Tuple, Union

import numpy as np
import torch
import torch.utils.data
import torch.nn.functional as F
from tqdm.auto import trange, tqdm
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, from_networkx
from torch_sparse import SparseTensor
import torchvision

from models import CNNTarget, GAEHyper, GNNHyper, GRUTarget, MLPTarget, MLPHyper, MLPInnerProductDecoder
from node import BaseNodes
from utils import get_device, set_logger, set_seed, freeze, unfreeze, detach_clone, CustomCosineAnnealingLR, write_results, eval_model
from deform import DeformWrapper


def _forward_target(net, batch):
    """Run ``net`` on a device-resident batch; a ``GRUTarget`` consumes the whole batch tuple."""
    if isinstance(net, GRUTarget):
        return net(batch)
    return net(batch[0])


def _eval_loader(net, loader, criteria, device, eval_batch_count, track_acc):
    """Return ``(avg_loss, correct, samples)`` over at most ``eval_batch_count`` batches of ``loader``.

    ``avg_loss`` is NaN when no batch was evaluated. ``correct``/``samples`` are
    only accumulated when ``track_acc`` is true (they stay ``0.`` otherwise).
    """
    running_loss, running_correct, running_samples = 0., 0., 0.
    num_batches = 0
    with torch.no_grad():
        for batch in loader:
            if num_batches >= eval_batch_count:
                break
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            label = batch[1]
            pred = _forward_target(net, batch)
            running_loss += criteria(pred, label)
            if track_acc:
                running_correct += pred.argmax(1).eq(label).sum()
                running_samples += len(label)
            num_batches += 1
    # Average over the batches actually processed: hitting the batch limit must
    # not count the batch that triggered the break.
    avg_loss = running_loss / num_batches if num_batches else float('nan')
    return avg_loss, running_correct, running_samples


def evaluate(nodes: BaseNodes, net, criteria, device, eval_batch_count, finetune_epochs, finetune_lr, finetune_wd, node_list, split):
    """Evaluate the global model on each node, optionally after local fine-tuning.

    For every node in ``node_list`` the network's parameters and buffers are
    reset to clones of ``nodes.global_weights`` / ``nodes.global_buffers``. The
    node's ``split`` loader ('test', 'val', anything else -> train) is then
    evaluated. If ``finetune_epochs > 0``, the network is fine-tuned with SGD
    (momentum 0.9) on the node's train loader and evaluated again on the same split.

    Args:
        nodes: node container exposing ``global_weights``, ``global_buffers``
            and per-node ``train_loaders`` / ``val_loaders`` / ``test_loaders``.
        net: target network; its parameters/buffers are overwritten in place.
        criteria: loss module; accuracy is only tracked for ``CrossEntropyLoss``.
        device: device batches are moved to.
        eval_batch_count: maximum number of batches evaluated per node.
        finetune_epochs: local fine-tuning epochs (0 disables fine-tuning).
        finetune_lr: SGD learning rate used for fine-tuning.
        finetune_wd: SGD weight decay used for fine-tuning.
        node_list: node ids to evaluate.
        split: 'test', 'val', or any other value for the train split.

    Returns:
        ``defaultdict`` mapping node id -> dict with 'loss' (plus 'correct',
        'total', 'acc' for classification) and the matching 'finetune_*'
        entries when fine-tuning is enabled.
    """
    net.eval()
    results = defaultdict(lambda: defaultdict(list))
    track_acc = isinstance(criteria, torch.nn.CrossEntropyLoss)

    node_iter = tqdm(node_list, position=-1, leave=False)
    for node_id in node_iter:  # iterating over nodes
        # Every node starts from (a copy of) the current global model.
        for n, p in net.named_parameters():
            p.data = nodes.global_weights[n].detach().clone()
        for n, p in net.named_buffers():
            p.data = nodes.global_buffers[n].detach().clone()

        if split == 'test':
            curr_data = nodes.test_loaders[node_id]
        elif split == 'val':
            curr_data = nodes.val_loaders[node_id]
        else:
            curr_data = nodes.train_loaders[node_id]

        loss, correct, total = _eval_loader(net, curr_data, criteria, device, eval_batch_count, track_acc)
        results[node_id]['loss'] = loss
        tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {results[node_id]['loss']:.4f}"
        if track_acc:
            results[node_id]['correct'] = correct
            results[node_id]['total'] = total
            results[node_id]['acc'] = correct / total
            tqdm_str += f", AVG Acc: {results[node_id]['acc']}"

        if finetune_epochs > 0:
            net.train()
            inner_optim = torch.optim.SGD(net.parameters(), lr=finetune_lr, momentum=0.9, weight_decay=finetune_wd)
            for _ in range(finetune_epochs):
                for batch in nodes.train_loaders[node_id]:
                    inner_optim.zero_grad()
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    ft_loss = criteria(_forward_target(net, batch), batch[1])
                    ft_loss.backward()
                    inner_optim.step()
            net.eval()

            loss, correct, total = _eval_loader(net, curr_data, criteria, device, eval_batch_count, track_acc)
            results[node_id]['finetune_loss'] = loss
            tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {results[node_id]['loss']:.4f} -> {results[node_id]['finetune_loss']:.4f}"
            if track_acc:
                results[node_id]['finetune_correct'] = correct
                results[node_id]['finetune_total'] = total
                results[node_id]['finetune_acc'] = correct / total
                tqdm_str += f", AVG Acc: {results[node_id]['acc']:.4f} -> {results[node_id]['finetune_acc']:.4f}"
        node_iter.set_description(tqdm_str)
    return results

class BatchIter(object):
    """Endless ``(x, y)`` batch source that restarts ``dataloader`` whenever it is exhausted.

    Each call to :meth:`get_data_batch` returns the next pair moved to the
    ``device`` given at construction time.
    """

    def __init__(self, dataloader, device) -> None:
        self.dataloader = dataloader
        self.device = device
        self.iter_dataloader = iter(dataloader)

    def get_data_batch(self):
        """Return the next ``(x, y)`` batch on ``self.device``, restarting the loader when it runs out.

        Raises:
            StopIteration: if ``dataloader`` yields no batches at all.
        """
        try:
            x, y = next(self.iter_dataloader)
        except StopIteration:
            # Current pass is exhausted: start a fresh one.
            self.iter_dataloader = iter(self.dataloader)
            x, y = next(self.iter_dataloader)
        # Use this iterator's own device rather than a module-level global.
        return x.to(self.device), y.to(self.device)

def compute_grad(
    model: torch.nn.Module,
    data_batch: Tuple[torch.Tensor, torch.Tensor],
    criterion,
    v: Union[Tuple[torch.Tensor, ...], None] = None,
    second_order_grads=False,
    delta: float = 1e-3,
):
    """Compute loss gradients (or a Hessian-vector product) w.r.t. ``model``'s parameters.

    With ``second_order_grads=False`` this returns the gradient of
    ``criterion(model(x), y)``. With ``second_order_grads=True`` it returns the
    central finite-difference approximation of the Hessian-vector product::

        H v  ~=  (grad L(theta + delta * v) - grad L(theta - delta * v)) / (2 * delta)

    In that case the model's full ``state_dict`` (parameters and buffers) is
    snapshotted and restored afterwards, even if a pass fails.

    Args:
        model: network to differentiate.
        data_batch: ``(x, y)`` pair; a ``GRUTarget`` receives the whole pair as input.
        criterion: loss callable ``criterion(pred, y)``.
        v: direction for the Hessian-vector product, one tensor per parameter in
            ``model.parameters()`` order. Required when ``second_order_grads``.
        second_order_grads: return the HVP approximation instead of the gradient.
        delta: finite-difference step size.

    Returns:
        Tuple of gradients (first order) or list of HVP tensors (second order),
        aligned with ``model.parameters()``.

    Raises:
        ValueError: if ``second_order_grads`` is true but ``v`` is None.
    """
    x, y = data_batch
    if isinstance(model, GRUTarget):
        x = data_batch

    if not second_order_grads:
        logit = model(x)
        loss = criterion(logit, y)
        return torch.autograd.grad(loss, model.parameters())

    if v is None:
        raise ValueError("`v` must be provided when second_order_grads=True.")

    # Snapshot parameters AND buffers so the two perturbed forward passes leave
    # no trace (e.g. on BatchNorm running statistics).
    frz_model_params = copy.deepcopy(model.state_dict())
    params_plus = OrderedDict()
    params_minus = OrderedDict()
    with torch.no_grad():
        for (layer_name, param), direction in zip(model.named_parameters(), v):
            params_plus[layer_name] = param + delta * direction
            params_minus[layer_name] = param - delta * direction

    try:
        model.load_state_dict(params_plus, strict=False)
        loss_plus = criterion(model(x), y)
        grads_plus = torch.autograd.grad(loss_plus, model.parameters())

        model.load_state_dict(params_minus, strict=False)
        loss_minus = criterion(model(x), y)
        grads_minus = torch.autograd.grad(loss_minus, model.parameters())
    finally:
        # Never leave the caller's model at a perturbed point.
        model.load_state_dict(frz_model_params)

    with torch.no_grad():
        return [(g_plus - g_minus) / (2 * delta) for g_plus, g_minus in zip(grads_plus, grads_minus)]

def train(hparam_dict: dict) -> None:
    """Run Per-FedAvg training on the seen nodes, then optional adaptation on unseen nodes.

    Outline:
      1. Build the federated ``BaseNodes``, the target network and the loss.
      2. Optionally pre-train the target network on the first training node
         (``pretrain_targetnet_steps``).
      3. For ``num_steps`` rounds, do the following. Sample ``num_clients`` training
         nodes and run the Per-FedAvg local update on each, starting from the global
         weights. Then average the clients' weights/buffers into the new global
         state. Training nodes are evaluated every ``eval_every`` rounds and once
         more at the end, after which ``final.ckpt`` is written to ``run_path``.
      4. Evaluate the global model on the unseen nodes. For ``eval_unseen_steps``
         rounds, adapt it on sampled unseen nodes. ``unseen.ckpt`` is written at
         the end.

    Metrics go to a TensorBoard ``SummaryWriter`` in ``hparam_dict['run_path']``.
    Uses the module-level ``device`` assigned in the ``__main__`` section.

    Args:
        hparam_dict: parsed command-line arguments plus ``run_path`` and ``algorithm``.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset, or if
            ``eval_unseen_steps > 0`` while ``inner_epochs`` is None.
    """

    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.tensorboard.summary import hparams
    writer = SummaryWriter(hparam_dict['run_path'])
    ###############################
    # init nodes, hnet, local net #
    ###############################
    nodes = BaseNodes(hparam_dict['data_name'], batch_size=hparam_dict['batch_size'], eval_batch_size=hparam_dict['eval_batch_size'], num_workers=hparam_dict['num_workers'], device=device, data_dir=hparam_dict['data_dir'])
    num_nodes = nodes.n_nodes
    all_nodes = np.arange(num_nodes)
    client_sample_rng = np.random.default_rng(hparam_dict['seed'])
    if 'METR-LA-' in hparam_dict['data_name'] or 'PEMS-BAY-' in hparam_dict['data_name']:
        # These dataset variants provide their own train-node split via ``nodes.train_nodes``.
        new_order = all_nodes
        train_nodes = nodes.train_nodes
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = list(set(all_nodes) - set(train_nodes))
    else:
        # Shuffle the nodes, train on the first 80% and hold out the rest as unseen.
        new_order = nodes.shuffle()
        num_train = int(0.8*num_nodes)
        train_nodes = all_nodes[:num_train]
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = all_nodes[num_train:]
    logging.info(f"Num of nodes: {len(all_nodes)}")
    logging.info(f"Num of train nodes: {len(train_nodes)}")
    logging.info(f"adv_node_list: {adv_node_list}")
    if 'METR-LA' in hparam_dict['data_name'] or 'PEMS-BAY' in hparam_dict['data_name']:
        net = GRUTarget(nodes.data_dim, out_dim=nodes.label_dim)
    elif hparam_dict['data_name'] == 'CompCars':
        # ImageNet-pretrained torchvision model with its last layer replaced by a fresh classifier.
        CNN = getattr(torchvision.models, hparam_dict['target_arch'])
        net = CNN(pretrained=True)
        final_layer_name, final_layer = list(net.named_children())[-1]
        setattr(net, final_layer_name, torch.nn.Linear(final_layer.in_features, nodes.label_dim))
    else:
        net = MLPTarget(nodes.data_dim, out_dim=nodes.label_dim)
    net.to(device)

    # NOTE(review): ``deform`` is constructed but not used anywhere below in this function.
    deform = DeformWrapper(device, hparam_dict['aug_method'], nodes.label_dim, hparam_dict['sigma'])

    ##################
    # init optimizer #
    ##################
    if hparam_dict['data_name'] in ['DG15', 'DG60', 'DG60_TRUNCNORM', 'CompCars']:
        criteria = torch.nn.CrossEntropyLoss()
    elif hparam_dict['data_name'] in ['TPT48', 'TPT48_TRUNCNORM', 'METR-LA', 'METR-LA-0.25', 'METR-LA-0.5', 'METR-LA-0.75', 'METR-LA-0.05', 'METR-LA-0.9', 'PEMS-BAY', 'PEMS-BAY-0.25', 'PEMS-BAY-0.5', 'PEMS-BAY-0.75', 'PEMS-BAY-0.05', 'PEMS-BAY-0.9']:
        criteria = torch.nn.MSELoss()
    else:
        # Raising a plain string is a TypeError in Python 3; raise a real exception.
        raise ValueError(f"Unknown dataset {hparam_dict['data_name']}.")

    metric_dict = {
        'Train/best_step': np.nan,
        'Train/best_test_avg_loss': np.nan,
        'Train/best_test_avg_acc': np.nan,
        'Unseen/best_step': np.nan,
        'Unseen/best_test_avg_loss': np.nan,
        'Unseen/best_test_avg_acc': np.nan,
        'Unseen_0/val_avg_loss': np.nan,
        'Unseen_0/test_avg_loss': np.nan,
        'Unseen_0/test_max_loss': np.nan,
        'Unseen_0/test_min_loss': np.nan,
        'Unseen_0/test_std_loss': np.nan,
        'Unseen_0/val_avg_acc': np.nan,
        'Unseen_0/test_avg_acc': np.nan,
        'Unseen_0/test_max_acc': np.nan,
        'Unseen_0/test_min_acc': np.nan,
        'Unseen_0/test_std_acc': np.nan,
    }

    if hparam_dict['pretrain_targetnet_steps'] > 0:
        net.train()
        step_iter = trange(hparam_dict['pretrain_targetnet_steps'])
        # The cosine LR scheduler is only created further below, so use the base inner LR here.
        inner_optim = torch.optim.SGD(
                net.parameters(), lr=hparam_dict['inner_lr'], momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )
        node_id = train_nodes[0]
        for step in step_iter:
            inner_optim.zero_grad(set_to_none=True)

            batch = next(iter(nodes.train_loaders[node_id]))
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            label = batch[1]
            if isinstance(net, GRUTarget):
                pred = net(batch)
            else:
                pred = net(batch[0])

            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            inner_optim.step()
            with torch.no_grad():
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    acc = torch.tensor(0)
            step_iter.set_description(
                f"Pre-train Target NN Step: {step}, Node ID: {node_id}, Loss: {loss.item():.4f}, Acc: {acc.item():.4f}"
            )
            writer.add_scalar('Pre-Train/loss_target', loss, step)
            writer.add_scalar('Pre-Train/acc_target', acc, step)
        net.eval()

    ################
    # init metrics #
    ################
    nodes.global_weights = detach_clone(OrderedDict(net.named_parameters()))
    nodes.global_buffers = detach_clone(OrderedDict(net.named_buffers()))
    num_steps = hparam_dict['num_steps']
    eval_every = hparam_dict['eval_every']
    step_iter = trange(num_steps)
    num_clients = min(len(train_nodes), hparam_dict['num_clients'])
    inner_lr_scheduler = CustomCosineAnnealingLR(hparam_dict['inner_lr'], T_max=hparam_dict['num_steps'], eta_min=hparam_dict['inner_lr_min'])
    inner_lr_scheduler.step()
    inner_beta_scheduler = CustomCosineAnnealingLR(hparam_dict['beta'], T_max=hparam_dict['num_steps'], eta_min=hparam_dict['beta']*0.01)
    inner_beta_scheduler.step()
    evaluator = partial(evaluate, nodes, net, criteria, device, hparam_dict['eval_batch_count'], hparam_dict['eval_finetune_epochs'], hparam_dict['eval_finetune_lr'], hparam_dict['eval_finetune_wd'])
    for step in step_iter:
        if step % eval_every == 0:
            results = eval_model(evaluator, train_nodes, adv_node_list)
            write_results(writer, 'Train', results, step)

        # select client at random
        node_ids = client_sample_rng.choice(train_nodes, num_clients, replace=False)
        logging.debug(f"node_ids: {node_ids}")
        aggr_weights = OrderedDict()
        aggr_buffers = OrderedDict()
        alpha = inner_lr_scheduler.get_lr()
        beta = inner_beta_scheduler.get_lr()
        node_results = []
        # produce & load local network weights
        for node_id in node_ids:
            nodes.weights[node_id] = detach_clone(nodes.global_weights)
            nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
            # Alias the net's tensors to this client's copies, so the in-place updates
            # below are stored directly in nodes.weights[node_id] / nodes.buffers[node_id].
            for n,p in net.named_parameters():
                p.data = nodes.weights[node_id][n]
            for n,p in net.named_buffers():
                p.data = nodes.buffers[node_id][n]

            # NOTE: evaluation on sent model
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                prvs_loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    prvs_acc = (pred.argmax(1).eq(label).sum() / len(label)).item()
                else:
                    # Plain float, matching the classification branch and ``acc`` below.
                    prvs_acc = 0.

            net.train()
            batch_iter = BatchIter(nodes.train_loaders[node_id], device)
            # Each Per-FedAvg iteration consumes three batches, hence the ``// 3``.
            if hparam_dict['inner_epochs'] is None:
                local_epoch = hparam_dict['inner_steps'] // 3
            else:
                local_epoch = (hparam_dict['inner_epochs'] * len(nodes.train_loaders[node_id])) // 3
            logging.debug(f"Start local training on client {node_id}.")
            for e in range(local_epoch):
                # Batch 1: one SGD step of size alpha on a throw-away copy of the model.
                temp_model = copy.deepcopy(net)
                data_batch_1 = batch_iter.get_data_batch()
                grads = compute_grad(temp_model, data_batch_1, criteria)
                for param, grad in zip(temp_model.parameters(), grads):
                    param.data.sub_(alpha * grad)

                # Batch 2: gradient at the adapted weights.
                data_batch_2 = batch_iter.get_data_batch()
                grads_1st = compute_grad(temp_model, data_batch_2, criteria)

                # Batch 3: Hessian-vector product at the current weights, in the direction grads_1st.
                data_batch_3 = batch_iter.get_data_batch()

                grads_2nd = compute_grad(
                    net, data_batch_3, criteria, v=grads_1st, second_order_grads=True
                )
                # NOTE: Go check https://github.com/KarhouTam/Per-FedAvg/issues/2 if you confuse about the model update.
                for param, grad1, grad2 in zip(
                    net.parameters(), grads_1st, grads_2nd
                ):
                    param.data.sub_(beta * grad1 - beta * alpha * grad2)

            # NOTE: evaluation on sent model
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    acc = (pred.argmax(1).eq(label).sum() / len(label)).item()
                else:
                    acc = 0
            node_results += [[prvs_loss.item(), loss.item(), prvs_acc, acc]]
        # FedAvg: unweighted mean of the sampled clients' weights and buffers.
        for n,p in net.named_parameters():
            aggr_weights[n] = torch.stack([nodes.weights[node_id][n] for node_id in node_ids]).mean(dim=0)
        for n,p in net.named_buffers():
            # Average in float, then cast back (buffers may be integer, e.g. counters).
            aggr_buffers[n] = torch.stack([nodes.buffers[node_id][n] for node_id in node_ids]).float().mean(dim=0).type(nodes.global_buffers[n].dtype)
        nodes.global_weights.update(aggr_weights)
        nodes.global_buffers.update(aggr_buffers)
        if hparam_dict['inner_lr_scheduler']:
            inner_lr_scheduler.step()
            inner_beta_scheduler.step()
        prvs_loss, node_loss, prvs_acc, acc = np.mean(node_results, axis=0)
        step_iter.set_description(
            f"Step: {step}, Loss: {prvs_loss:.4f} -> {node_loss:.4f},  Acc: {prvs_acc:.4f} -> {acc:.4f}"
        )

    # Log Final performance
    # Equals the old ``step += 1`` after a non-empty loop, and still works when num_steps == 0.
    step = num_steps
    results = eval_model(evaluator, train_nodes, adv_node_list)
    write_results(writer, 'Train', results, step)
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'global_weights': nodes.global_weights,
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'],'final.ckpt'))


    # Evaluate on unseen nodes
    step = 0
    results = eval_model(evaluator, unseen_nodes)
    write_results(writer, 'Unseen_0', results, step)
    best_metric_dict = {
        'Unseen_0/test_avg_loss': results['test_avg_loss'],
        'Unseen_0/test_max_loss': results['test_max_loss'],
        'Unseen_0/test_min_loss': results['test_min_loss'],
        'Unseen_0/test_std_loss': results['test_std_loss']
    }
    if isinstance(criteria, torch.nn.CrossEntropyLoss):
        best_metric_dict.update({
            'Unseen_0/test_avg_acc': results['test_avg_acc'],
            'Unseen_0/test_max_acc': results['test_max_acc'],
            'Unseen_0/test_min_acc': results['test_min_acc'],
            'Unseen_0/test_std_acc': results['test_std_acc']
        })
    metric_dict.update(best_metric_dict)
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k,v in best_metric_dict.items():
        writer.add_scalar(k, v, step)

    # The unseen-node adaptation below iterates ``range(inner_epochs)``; fail clearly
    # instead of with ``range(None)`` when it was not configured.
    if hparam_dict['eval_unseen_steps'] > 0 and hparam_dict['inner_epochs'] is None:
        raise ValueError("--inner-epochs must be set when --eval-unseen-steps > 0.")

    step_iter = trange(hparam_dict['eval_unseen_steps'], position=0)
    num_clients = min(len(unseen_nodes), hparam_dict['num_clients'])
    for step in step_iter:
        # select client at random
        # NOTE(review): this samples from the global NumPy RNG, unlike the seeded
        # ``client_sample_rng`` used during training.
        node_ids = np.random.choice(unseen_nodes, num_clients, replace=False)
        aggr_weights = OrderedDict([[n, torch.zeros_like(p.data)] for n, p in net.named_parameters()])
        node_iter = tqdm(node_ids, position=1, leave=False)
        node_results = []
        # produce & load local network weights
        for node_id in node_iter:
            for n,p in net.named_parameters():
                p.data = nodes.global_weights[n].detach().clone()
            # init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )
            inner_beta_optim = torch.optim.SGD(
                net.parameters(), lr=hparam_dict['beta']
            )

            # NOTE: evaluation on sent model
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                prvs_loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    prvs_acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    prvs_acc = torch.tensor(0)

            net.train()
            for e in range(hparam_dict['inner_epochs']):
                # inner updates -> obtaining theta_tilda
                logging.debug(f"Start local training on client {node_id}.")
                curr_iter = iter(nodes.train_loaders[node_id])
                # Each iteration consumes two batches. First, an inner SGD step.
                # Then the gradient at the adapted weights on the second batch is
                # applied (lr=beta) to the restored pre-step weights.
                for i in range(len(nodes.train_loaders[node_id])//2):
                    temp_model = copy.deepcopy(list(net.parameters()))
                    inner_optim.zero_grad(set_to_none=True)
                    batch = next(curr_iter)
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    label = batch[1]
                    if isinstance(net, GRUTarget):
                        pred = net(batch)
                    else:
                        pred = net(batch[0])
                    loss = criteria(pred, label)
                    loss.backward()
                    inner_optim.step()
                    inner_beta_optim.zero_grad(set_to_none=True)
                    batch = next(curr_iter)
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    label = batch[1]
                    if isinstance(net, GRUTarget):
                        pred = net(batch)
                    else:
                        pred = net(batch[0])
                    loss = criteria(pred, label)
                    loss.backward()
                    for old_p, new_p in zip(net.parameters(), temp_model):
                        old_p.data = new_p.detach().clone()
                    inner_beta_optim.step()

            # Average over the clients actually sampled this round, which can be fewer
            # than hparam_dict['num_clients'] when there are few unseen nodes.
            for n,p in net.named_parameters():
                aggr_weights[n].data += p.detach().clone() / num_clients
            with torch.no_grad():
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    acc = torch.tensor(0)
            node_iter.set_description(
                f"Node ID: {node_id}, Loss: {prvs_loss:.4f} -> {loss:.4f},  Acc: {prvs_acc:.4f} -> {acc:.4f}"
            )
            node_results += [[prvs_loss.item(), loss.item(), prvs_acc.item(), acc.item()]]
        for n,p in net.named_parameters():
            nodes.global_weights[n].data = aggr_weights[n].data

        results = eval_model(evaluator, unseen_nodes)
        write_results(writer, 'Unseen', results, step+1)
        # NOTE(review): accuracy keys appear to be produced only for classification
        # (see the CrossEntropyLoss guard above); fall back to NaN instead of KeyError.
        # Confirm the 'test_all_acc' key name against eval_model.
        unseen_acc = results.get('test_all_acc', float('nan'))
        step_iter.set_description(f"Unseen Step: {step}, AVG Loss: {results['test_avg_loss']:.4f},  AVG Acc: {unseen_acc:.4f}")
    writer.close()
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'global_weights': nodes.global_weights,
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'],'unseen.ckpt'))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Federated Hypernetwork with Lookahead experiment"
    )

    #############################
    #       Dataset Args        #
    #############################

    parser.add_argument("--data-name", type=str, default="DG15")
    parser.add_argument("--data-dir", type=str, default="data", help="dir name of dataset")

    ##################################
    #       Optimization args        #
    ##################################

    parser.add_argument("--pretrain-targetnet-steps", type=int, default=0)
    parser.add_argument("--num-steps", type=int, default=800)
    parser.add_argument("--eval-unseen-steps", type=int, default=0)
    parser.add_argument("--optim", type=str, default='sgd', choices=['adam', 'sgd'], help="optimizer")
    parser.add_argument("--num-clients", type=int, default=5)
    parser.add_argument("--last-layer-only", action='store_true', default=False, help="Only personalize the last layer.")
    parser.add_argument('--aug-method', type=str, default='label_flip', help='type of augmentation for training')
    parser.add_argument("--adv-ratio", type=float, default=.0, help="Ratio of adversarial clients participated FL training.")
    parser.add_argument("--sigma", type=float, default=.0, help="noise hyperparameter")

    ################################
    #       Model Prop args        #
    ################################
    parser.add_argument("--n-hidden", type=int, default=3, help="num. hidden layers")
    parser.add_argument("--inner-lr", type=float, default=1e-2, help="learning rate for inner optimizer")
    parser.add_argument("--inner-lr-min", type=float, default=1e-3, help="Min learning rate for inner optimizer")
    parser.add_argument("--beta", type=float, default=0.001)
    parser.add_argument("--inner-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for inner optimizer")
    parser.add_argument("--inner-wd", type=float, default=5e-5, help="inner weight decay")
    parser.add_argument("--inner-epochs", type=int, default=None)
    parser.add_argument("--inner-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--target-arch", type=str, default=None, help="architecture of target network")

    #############################
    #       General args        #
    #############################
    parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
    parser.add_argument("--eval-every", type=int, default=30, help="eval every X selected epochs")
    parser.add_argument("--eval-batch-size", type=int, default=64, help="eval batch size")
    parser.add_argument("--eval-batch-count", type=int, default=999999, help="eval num batch")
    parser.add_argument("--eval-finetune-epochs", type=int, default=0)
    parser.add_argument("--eval-finetune-lr", type=float, default=1e-3, help="learning rate for eval finetune optimizer")
    parser.add_argument("--eval-finetune-wd", type=float, default=5e-5, help="weight decay for eval finetune optimizer")
    parser.add_argument("--num-workers", type=int, default=0, help="num workers for dataloader")
    parser.add_argument("--run-name", type=str, default=None, help="name for output run file")
    parser.add_argument("--run-dir", type=str, default='runs', help="dir path for output run file")
    parser.add_argument("--task-name", type=str, default=None, help="ClearML task name")
    parser.add_argument("--log-level", type=str, default='INFO', help="log level")
    parser.add_argument("--seed", type=int, default=42, help="seed value")

    args = parser.parse_args()
    if args.target_arch is None:
        if 'METR-LA' in args.data_name or 'PEMS-BAY' in args.data_name:
            args.target_arch = 'GRUSeq2Seq'
        elif args.data_name == 'CompCars':
            args.target_arch = 'resnet18'
        else:
            args.target_arch = 'mlp'

    # Valid CUDA indices are 0..device_count-1, so ``gpu == device_count`` is rejected.
    # Index 0 and negative values are accepted even without CUDA devices, as before.
    num_gpus = torch.cuda.device_count()
    if not (args.gpu < num_gpus or args.gpu == 0):
        parser.error(f"--gpu flag should be in range [0,{num_gpus - 1}]")

    # Adversarial clients only make sense with label flipping or a non-nominal noisy
    # augmentation (sigma > 0); otherwise disable them. This is done independently of
    # run naming, so a user-supplied --run-name is sanitised too.
    adv_label_flip = args.aug_method == 'label_flip'
    adv_noise = args.aug_method != 'nominal' and args.sigma > 0
    if args.adv_ratio > 0 and not (adv_label_flip or adv_noise):
        args.adv_ratio = 0

    if args.run_name is None:
        import time
        args.run_name = f"{time.strftime('%Y%m%d_%H-%M-%S')}_{args.data_name}_per-fedavg"
        if args.adv_ratio > 0:
            if adv_label_flip:
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}"
            else:
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}_{args.sigma}"
    args.run_name += f'_seed_{args.seed}'
    hparam_dict = vars(args)
    hparam_dict['algorithm'] = 'per-fedavg'
    hparam_dict['run_path'] = os.path.join(args.run_dir, args.run_name)
    os.makedirs(hparam_dict['run_path'], exist_ok=True)
    set_logger(os.path.join(hparam_dict['run_path'],'output.log'), level=args.log_level)
    set_seed(args.seed)
    # Module-level global; ``train`` reads it directly.
    device = get_device(gpus=args.gpu)
    logging.info(json.dumps(hparam_dict, sort_keys=True, indent=4))
    try:
        from clearml import Task
        if args.task_name is None:
            hparam_dict['task_name'] = args.run_name
        task = Task.init(project_name='Panacea NeurIPS 2023', task_name=hparam_dict['task_name'])
    except Exception as err:
        # ClearML tracking is optional (package missing or init failure). Unlike a bare
        # ``except:``, this no longer swallows KeyboardInterrupt/SystemExit.
        logging.info(f"ClearML task not created: {err!r}")
        task = None

    try:
        train(hparam_dict)
    except Exception:
        print(traceback.format_exc())
    finally:
        # Close the tracking task even if training was interrupted.
        if task is not None:
            task.close()
