import os
import argparse
import json
import logging
from platform import node
import random
from collections import OrderedDict, defaultdict
from pathlib import Path
import copy
import traceback
from functools import partial

import numpy as np
import torch
import torch.utils.data
import torch.nn.functional as F
from tqdm.auto import trange, tqdm
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, from_networkx
from torch_sparse import SparseTensor
import torchvision

from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams

from models import CNNTarget, GAEHyper, GNNHyper, GRUTarget, MLPTarget, MLPHyper, MLPInnerProductDecoder
from node import BaseNodes
from utils import get_device, set_logger, set_seed, freeze, unfreeze, detach_clone, CustomCosineAnnealingLR, write_results, eval_model
from deform import DeformWrapper


def _forward_target(net, batch):
    """Run ``net`` on ``batch``: GRUTarget receives the whole batch tuple, other nets only ``batch[0]``."""
    if isinstance(net, GRUTarget):
        return net(batch)
    return net(batch[0])


def _run_eval_pass(net, criteria, device, loader, eval_batch_count):
    """Evaluate ``net`` on at most ``eval_batch_count`` batches of ``loader`` without gradients.

    Returns:
        A tuple ``(avg_loss, correct, samples)``. ``avg_loss`` is the summed batch loss divided by
        the number of batches that were actually processed. ``correct`` and ``samples`` are only
        accumulated when ``criteria`` is a ``CrossEntropyLoss``; otherwise they stay at ``0.``.
    """
    is_classification = isinstance(criteria, torch.nn.CrossEntropyLoss)
    running_loss, running_correct, running_samples = 0., 0., 0.
    num_batches = 0
    with torch.no_grad():
        for batch_count, batch in enumerate(loader):
            if batch_count >= eval_batch_count:
                break
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            label = batch[1]
            pred = _forward_target(net, batch)
            running_loss += criteria(pred, label)
            if is_classification:
                running_correct += pred.argmax(1).eq(label).sum()
                running_samples += len(label)
            num_batches += 1
    # Divide by the number of batches summed above. Using the enumerate index + 1 is off by one
    # whenever the loop exits through the ``eval_batch_count`` break, and is undefined for an
    # empty loader. ``max(..., 1)`` keeps an empty pass from dividing by zero.
    avg_loss = running_loss / max(num_batches, 1)
    return avg_loss, running_correct, running_samples


def evaluate(nodes: BaseNodes, net, criteria, device, eval_batch_count, finetune_epochs, finetune_lr, finetune_wd, node_list, split):
    """Evaluate the per-node models on one data split, optionally after local fine-tuning.

    For every node in ``node_list`` the node's stored weights (falling back to
    ``nodes.global_weights`` for missing nodes or missing parameter names) and the global buffers
    are copied into ``net`` before evaluation.

    Args:
        nodes: Container providing ``train_loaders``/``val_loaders``/``test_loaders``, ``weights``,
            ``global_weights`` and ``global_buffers``.
        net: Target network; its parameters and buffers are overwritten for every node.
        criteria: Loss module. Accuracy statistics are only produced for ``CrossEntropyLoss``.
        device: Device the batches are moved to.
        eval_batch_count: Maximum number of batches evaluated per node.
        finetune_epochs: If ``> 0``, fine-tune on the node's train loader for this many epochs and
            evaluate a second time.
        finetune_lr: Learning rate of the fine-tuning SGD optimizer.
        finetune_wd: Weight decay of the fine-tuning SGD optimizer.
        node_list: Iterable of node ids to evaluate.
        split: ``'test'`` or ``'val'``; any other value selects the train loaders.

    Returns:
        A nested ``defaultdict`` keyed by node id with ``'loss'`` and, for classification,
        ``'correct'``, ``'total'`` and ``'acc'``; the ``'finetune_*'`` counterparts are added when
        fine-tuning is enabled.
    """
    net.eval()
    results = defaultdict(lambda: defaultdict(list))
    is_classification = isinstance(criteria, torch.nn.CrossEntropyLoss)

    node_iter = tqdm(node_list, position=-1, leave=False)
    for node_id in node_iter:  # iterating over nodes
        if split == 'test':
            curr_data = nodes.test_loaders[node_id]
        elif split == 'val':
            curr_data = nodes.val_loaders[node_id]
        else:
            curr_data = nodes.train_loaders[node_id]

        # Load this node's model; cloned so fine-tuning below never touches the stored tensors.
        weights = nodes.weights.get(node_id, nodes.global_weights)
        for n,p in net.named_parameters():
            p.data = weights.get(n, nodes.global_weights[n]).detach().clone()
        for n,p in net.named_buffers():
            p.data = nodes.global_buffers[n].detach().clone()

        loss, correct, samples = _run_eval_pass(net, criteria, device, curr_data, eval_batch_count)
        results[node_id]['loss'] = loss
        tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {results[node_id]['loss']:.4f}"
        if is_classification:
            results[node_id]['correct'] = correct
            results[node_id]['total'] = samples
            # Guard against a pass that saw no samples (e.g. empty loader).
            results[node_id]['acc'] = correct / samples if samples else 0.
            tqdm_str += f", AVG Acc: {results[node_id]['acc']}"

        if finetune_epochs > 0:
            net.train()
            inner_optim = torch.optim.SGD(net.parameters(), lr=finetune_lr, momentum=0.9, weight_decay=finetune_wd)
            for e in range(finetune_epochs):
                for batch in nodes.train_loaders[node_id]:
                    inner_optim.zero_grad()
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    label = batch[1]
                    pred = _forward_target(net, batch)
                    loss = criteria(pred, label)
                    loss.backward()
                    inner_optim.step()
            net.eval()

            ft_loss, ft_correct, ft_samples = _run_eval_pass(net, criteria, device, curr_data, eval_batch_count)
            results[node_id]['finetune_loss'] = ft_loss
            tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {results[node_id]['loss']:.4f} -> {results[node_id]['finetune_loss']:.4f}"
            if is_classification:
                results[node_id]['finetune_correct'] = ft_correct
                results[node_id]['finetune_total'] = ft_samples
                results[node_id]['finetune_acc'] = ft_correct / ft_samples if ft_samples else 0.
                tqdm_str += f", AVG Acc: {results[node_id]['acc']:.4f} -> {results[node_id]['finetune_acc']:.4f}"
        node_iter.set_description(tqdm_str)
    return results


def train(hparam_dict) -> None:
    """Run personalized federated training (pFedMe-style updates) and evaluate on unseen nodes.

    The function runs these phases in order:
      1. builds the nodes, the target network, the loss and the LR schedulers;
      2. optionally pre-trains the target network on the first train node;
      3. runs ``num_steps`` federated rounds over the train nodes and saves ``final.ckpt``;
      4. evaluates the global model on unseen nodes, optionally runs ``eval_unseen_steps``
         further rounds on those nodes, and saves ``unseen.ckpt``.

    Args:
        hparam_dict: Dict of hyper-parameters (the parsed CLI arguments plus ``algorithm``, ``R``
            and ``run_path``). Results are written to TensorBoard and to checkpoints under
            ``hparam_dict['run_path']``.

    Raises:
        ValueError: If ``hparam_dict['data_name']`` is not one of the known datasets.
    """
    # Read the seed from hparam_dict so the function does not depend on the module-level `args`.
    set_seed(hparam_dict['seed'])
    device = get_device(gpus=hparam_dict['gpu'])
    ###############################
    # init nodes, hnet, local net #
    ###############################
    nodes = BaseNodes(hparam_dict['data_name'], batch_size=hparam_dict['batch_size'], eval_batch_size=hparam_dict['eval_batch_size'], num_workers=hparam_dict['num_workers'], device=device, data_dir=hparam_dict['data_dir'])
    num_nodes = nodes.n_nodes
    all_nodes = np.arange(num_nodes)
    client_sample_rng = np.random.default_rng(hparam_dict['seed'])
    if 'METR-LA-' in hparam_dict['data_name'] or 'PEMS-BAY-' in hparam_dict['data_name']:
        # These dataset variants provide their own train-node split.
        new_order = all_nodes
        train_nodes = nodes.train_nodes
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = list(set(all_nodes) - set(train_nodes))
    else:
        # Otherwise shuffle the nodes and hold out the last 20% as unseen.
        new_order = nodes.shuffle()
        num_train = int(0.8*num_nodes)
        train_nodes = all_nodes[:num_train]
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = all_nodes[num_train:]
    logging.info(f"Num of nodes: {len(all_nodes)}")
    logging.info(f"Num of train nodes: {len(train_nodes)}")
    logging.info(f"adv_node_list: {adv_node_list}")
    if 'METR-LA' in hparam_dict['data_name'] or 'PEMS-BAY' in hparam_dict['data_name']:
        net = GRUTarget(nodes.data_dim, out_dim=nodes.label_dim)
    elif hparam_dict['data_name'] == 'CompCars':
        CNN = getattr(torchvision.models, hparam_dict['target_arch'])
        net = CNN(pretrained=True)
        # Replace the last child module with a fresh linear head sized for this dataset.
        final_layer_name, final_layer = list(net.named_children())[-1]
        setattr(net, final_layer_name, torch.nn.Linear(final_layer.in_features, nodes.label_dim))
    else:
        net = MLPTarget(nodes.data_dim, out_dim=nodes.label_dim)
    net.to(device)

    # Applied to the training batches of the nodes in adv_node_list.
    deform = DeformWrapper(device, hparam_dict['aug_method'], nodes.label_dim, hparam_dict['sigma'])

    ##################
    # init optimizer #
    ##################
    if hparam_dict['data_name'] in ['DG15', 'DG60', 'DG60_TRUNCNORM', 'CompCars']:
        criteria = torch.nn.CrossEntropyLoss()
    elif hparam_dict['data_name'] in ['TPT48', 'TPT48_TRUNCNORM', 'METR-LA', 'METR-LA-0.25', 'METR-LA-0.5', 'METR-LA-0.75', 'METR-LA-0.05', 'METR-LA-0.9', 'PEMS-BAY', 'PEMS-BAY-0.25', 'PEMS-BAY-0.5', 'PEMS-BAY-0.75', 'PEMS-BAY-0.05', 'PEMS-BAY-0.9']:
        criteria = torch.nn.MSELoss()
    else:
        # Raising a plain string is itself a TypeError; raise a real exception instead.
        raise ValueError(f"Unknown dataset {hparam_dict['data_name']}.")

    # The schedulers are created before pre-training because the pre-training optimizer reads
    # inner_lr_scheduler.get_lr(); creating them later left the name unbound at that point.
    inner_lr_scheduler = CustomCosineAnnealingLR(hparam_dict['inner_lr'], T_max=hparam_dict['num_steps'], eta_min=hparam_dict['inner_lr_min'])
    inner_lr_scheduler.step()
    client_lr_scheduler = CustomCosineAnnealingLR(hparam_dict['client_lr'], T_max=hparam_dict['num_steps'], eta_min=hparam_dict['client_lr']*0.01)
    client_lr_scheduler.step()

    metric_dict = {
        'Train/best_step': np.nan,
        'Train/best_test_avg_loss': np.nan,
        'Train/best_test_avg_acc': np.nan,
        'Unseen/best_step': np.nan,
        'Unseen/best_test_avg_loss': np.nan,
        'Unseen/best_test_avg_acc': np.nan,
        'Unseen_0/test_avg_loss': np.nan,
        'Unseen_0/test_avg_acc': np.nan,
    }

    writer = SummaryWriter(hparam_dict['run_path'])
    if hparam_dict['pretrain_targetnet_steps'] > 0:
        # Optional warm-up of the target network on the first train node only.
        net.train()
        step_iter = trange(hparam_dict['pretrain_targetnet_steps'])
        inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )
        node_id = train_nodes[0]
        for step in step_iter:
            inner_optim.zero_grad(set_to_none=True)

            batch = next(iter(nodes.train_loaders[node_id]))
            batch = tuple(t.to(device, non_blocking=True) for t in batch)
            label = batch[1]
            if isinstance(net, GRUTarget):
                pred = net(batch)
            else:
                pred = net(batch[0])

            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            inner_optim.step()
            with torch.no_grad():
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    acc = torch.tensor(0)
            step_iter.set_description(
                f"Pre-train Target NN Step: {step}, Node ID: {node_id}, Loss: {loss.item():.4f}, Acc: {acc.item():.4f}"
            )
            writer.add_scalar('Pre-Train/loss_target', loss, step)
            writer.add_scalar('Pre-Train/acc_target', acc, step)
        net.eval()

    ################
    # init metrics #
    ################
    nodes.global_weights = detach_clone(OrderedDict(net.named_parameters()))
    nodes.global_buffers = detach_clone(OrderedDict(net.named_buffers()))
    steps = hparam_dict['num_steps']
    step_iter = trange(steps)
    eval_every = hparam_dict['eval_every']
    num_clients = min(len(train_nodes), hparam_dict['num_clients'])
    evaluator = partial(evaluate, nodes, net, criteria, device, hparam_dict['eval_batch_count'], hparam_dict['eval_finetune_epochs'], hparam_dict['eval_finetune_lr'], hparam_dict['eval_finetune_wd'])
    # Keeps `step` defined for the final logging below even when num_steps == 0.
    step = -1
    for step in step_iter:
        if step % eval_every == 0:
            results = eval_model(evaluator, train_nodes, adv_node_list)
            write_results(writer, 'Train', results, step)

        # select client at random
        node_ids = client_sample_rng.choice(train_nodes, num_clients, replace=False)
        logging.debug(f"node_ids: {node_ids}")
        if hparam_dict['update_all']:
            node_iter = train_nodes
        else:
            node_iter = node_ids
        node_results = []
        # produce & load local network weights
        for node_id in node_iter:
            # Nodes not marked as joined start from a copy of the current global model.
            if not nodes.joined[node_id]:
                nodes.weights[node_id] = detach_clone(nodes.global_weights)
                nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
            for n,p in net.named_parameters():
                p.data = nodes.weights[node_id][n]
            for n,p in net.named_buffers():
                p.data = nodes.buffers[node_id][n]
            # init inner optimizer (only used for zero_grad; the update below is applied manually)
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )

            # NOTE: evaluation on sent model
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                prvs_loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    prvs_acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    prvs_acc = torch.tensor(0)

            net.train()
            batch_count = 0
            if hparam_dict['inner_epochs'] is None:
                # Budget given in steps (R rounds of K updates); derive enough epochs to cover it.
                inner_steps = hparam_dict['R'] * hparam_dict['K']
                inner_epochs = inner_steps // len(nodes.train_loaders[node_id]) + 1
            else:
                inner_epochs = hparam_dict['inner_epochs']
                inner_steps = inner_epochs * len(nodes.train_loaders[node_id])
            for e in range(inner_epochs):
                 # inner updates -> obtaining theta_tilda
                for batch in nodes.train_loaders[node_id]:
                    batch_count += 1
                    if batch_count > inner_steps:
                        break
                    inner_optim.zero_grad(set_to_none=True)
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    if node_id in adv_node_list:
                        batch = deform(batch)
                    label = batch[1]
                    if isinstance(net, GRUTarget):
                        pred = net(batch)
                    else:
                        pred = net(batch[0])

                    loss = criteria(pred, label)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
                    for n,p in net.named_parameters():
                        # p.grad is None for parameters that received no gradient; check it
                        # directly, since `p.grad.data` would raise in that case.
                        if p.grad is not None:
                            # Gradient step with a proximal pull towards the node's stored weights
                            # (client_lambda) and an L2 term on the current parameter (client_mu).
                            p.data = p.data - inner_lr_scheduler.get_lr() * (p.grad.data + hparam_dict['client_lambda'] * (p.data - nodes.weights[node_id][n].data) + hparam_dict['client_mu'] * p.data)
                    if batch_count % hparam_dict['K'] == 0:
                        # Every K inner updates, move the node's stored weights towards the inner model.
                        for n,p in net.named_parameters():
                            nodes.weights[node_id][n].data -= hparam_dict['client_lambda'] * client_lr_scheduler.get_lr() * (nodes.weights[node_id][n].data - p.data)

            with torch.no_grad():
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    acc = torch.tensor(0)
            node_results += [[prvs_loss.item(), loss.item(), prvs_acc.item(), acc.item()]]

        # Server update: blend the old global weights with the mean of the sampled nodes' weights.
        for n,p in net.named_parameters():
            nodes.global_weights[n] = (1-hparam_dict['beta']) * nodes.global_weights[n].data + hparam_dict['beta'] * torch.stack([nodes.weights[node_id][n] for node_id in node_ids]).mean(dim=0).data
        for n,p in net.named_buffers():
            nodes.global_buffers[n] = torch.stack([nodes.buffers[node_id][n] for node_id in node_ids]).float().mean(dim=0).type(nodes.global_buffers[n].dtype)
        if hparam_dict['inner_lr_scheduler']:
            inner_lr_scheduler.step()
            client_lr_scheduler.step()
        prvs_loss, node_loss, prvs_acc, acc = np.mean(node_results, axis=0)
        step_iter.set_description(
            f"Step: {step}, Loss: {prvs_loss:.4f} -> {node_loss:.4f},  Acc: {prvs_acc:.4f} -> {acc:.4f}"
        )
    # Log Final performance
    step += 1
    results = eval_model(evaluator, train_nodes, adv_node_list)
    write_results(writer, 'Train', results, step)
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'global_weights': nodes.global_weights,
        'global_buffers': nodes.global_buffers,
        'client_weights': nodes.weights,
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'],'final.ckpt'))


    # Evaluate on unseen nodes
    step = 0
    for node_id in unseen_nodes:
        nodes.weights[node_id] = detach_clone(nodes.global_weights)
    results = eval_model(evaluator, unseen_nodes)
    write_results(writer, 'Unseen_0', results, step)
    best_metric_dict = {
        'Unseen_0/test_avg_loss': results['test_avg_loss'],
        'Unseen_0/test_max_loss': results['test_max_loss'],
        'Unseen_0/test_min_loss': results['test_min_loss'],
        'Unseen_0/test_std_loss': results['test_std_loss']
    }
    if isinstance(criteria, torch.nn.CrossEntropyLoss):
        best_metric_dict.update({
            'Unseen_0/test_avg_acc': results['test_avg_acc'],
            'Unseen_0/test_max_acc': results['test_max_acc'],
            'Unseen_0/test_min_acc': results['test_min_acc'],
            'Unseen_0/test_std_acc': results['test_std_acc']
        })
    metric_dict.update(best_metric_dict)
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k,v in best_metric_dict.items():
        writer.add_scalar(k, v, step)

    step_iter = trange(hparam_dict['eval_unseen_steps'])
    num_clients = min(len(unseen_nodes), hparam_dict['num_clients'])
    for step in step_iter:
        # select client at random
        node_ids = np.random.choice(unseen_nodes, num_clients, replace=False)
        aggr_weights = OrderedDict([[n, torch.zeros_like(p.data)] for n, p in net.named_parameters()])
        # produce & load local network weights
        for node_id in node_ids:
            for n,p in net.named_parameters():
                p.data = nodes.global_weights[n].data.clone()
            # init inner optimizer (only used for zero_grad; the update below is applied manually)
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )

            # NOTE: evaluation on sent model
            with torch.no_grad():
                net.eval()
                batch = next(iter(nodes.test_loaders[node_id]))
                batch = tuple(t.to(device, non_blocking=True) for t in batch)
                label = batch[1]
                if isinstance(net, GRUTarget):
                    pred = net(batch)
                else:
                    pred = net(batch[0])
                prvs_loss = criteria(pred, label)
                if isinstance(criteria, torch.nn.CrossEntropyLoss):
                    prvs_acc = pred.argmax(1).eq(label).sum() / len(label)
                else:
                    prvs_acc = torch.tensor(0)

            net.train()
            batch_count = 0
            # Same epoch/step budget as the training phase; `inner_epochs` defaults to None, which
            # previously made range() fail here.
            if hparam_dict['inner_epochs'] is None:
                max_inner_steps = hparam_dict['R'] * hparam_dict['K']
                inner_epochs = max_inner_steps // len(nodes.train_loaders[node_id]) + 1
            else:
                inner_epochs = hparam_dict['inner_epochs']
                max_inner_steps = inner_epochs * len(nodes.train_loaders[node_id])
            for e in range(inner_epochs):
                 # inner updates -> obtaining theta_tilda
                logging.debug(f"Start local training on client {node_id}.")
                # The loop's own batch is used; it used to be overwritten by a fresh
                # next(iter(loader)) on every step.
                for batch in nodes.train_loaders[node_id]:
                    batch_count += 1
                    if batch_count > max_inner_steps:
                        break
                    inner_optim.zero_grad(set_to_none=True)
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    label = batch[1]
                    if isinstance(net, GRUTarget):
                        pred = net(batch)
                    else:
                        pred = net(batch[0])

                    loss = criteria(pred, label)
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
                    for n,p in net.named_parameters():
                        if p.grad is not None:
                            p.data = p.data -  hparam_dict['inner_lr'] * (p.grad.data + hparam_dict['client_lambda'] * (p.data - nodes.weights[node_id][n].data) + hparam_dict['client_mu'] * p.data)
                    # NOTE(review): the training phase syncs every K updates, this phase every
                    # `inner_steps` updates — confirm the difference is intended.
                    if batch_count % hparam_dict['inner_steps'] == 0: 
                        for n,p in net.named_parameters():
                            nodes.weights[node_id][n].data -= hparam_dict['client_lambda'] * hparam_dict['client_lr'] * (nodes.weights[node_id][n].data - p.data)
            # Average over the clients actually sampled; this may be fewer than the configured
            # --num-clients when there are few unseen nodes.
            for n,p in net.named_parameters():
                aggr_weights[n].data += nodes.weights[node_id][n].data.detach().clone() / num_clients
            
            step_iter.set_description(
                f"Step: {step}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}"
            )
        for n,p in net.named_parameters():
            nodes.global_weights[n].data = (1-hparam_dict['beta']) * nodes.global_weights[n].data + hparam_dict['beta'] * aggr_weights[n].data

        results = eval_model(evaluator, unseen_nodes)
        write_results(writer, 'Unseen', results, step+1)
        # NOTE(review): 'test_all_acc' differs from the 'test_avg_acc' key used above — verify
        # against eval_model's return value.
        step_iter.set_description(f"Unseen Step: {step}, AVG Loss: {results['test_avg_loss']:.4f},  AVG Acc: {results['test_all_acc']:.4f}")
    writer.close()
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'global_weights': nodes.global_weights,
        'global_buffers': nodes.global_buffers,
        'client_weights': nodes.weights,
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'],'unseen.ckpt'))


if __name__ == '__main__':
    # Script entry point: parse the CLI, derive the run name/path, set up logging and optional
    # ClearML tracking, then run training.
    parser = argparse.ArgumentParser(
        description="Federated Hypernetwork with Lookahead experiment"
    )

    #############################
    #       Dataset Args        #
    #############################

    parser.add_argument("--data-name", type=str, default="DG15")
    parser.add_argument("--data-dir", type=str, default="data", help="dir name of dataset")

    ##################################
    #       Optimization args        #
    ##################################

    parser.add_argument("--pretrain-targetnet-steps", type=int, default=0)
    parser.add_argument("--num-steps", type=int, default=800)
    parser.add_argument("--eval-unseen-steps", type=int, default=0)
    parser.add_argument("--optim", type=str, default='sgd', choices=['adam', 'sgd'], help="optimizer")
    parser.add_argument("--num-clients", type=int, default=5)
    parser.add_argument("--last-layer-only", action='store_true', default=False, help="Only personalize the last layer.")
    # NOTE(review): this help text looks copy-pasted from --last-layer-only; the flag is not read
    # anywhere in this file, so its real meaning could not be confirmed here.
    parser.add_argument("--ray-tune", action='store_true', default=False, help="Only personalize the last layer.")
    parser.add_argument("--num-trials", type=int, default=10)
    parser.add_argument('--aug-method', type=str, default='label_flip', help='type of augmentation for training')
    parser.add_argument("--adv-ratio", type=float, default=.0, help="Ratio of adversarial clients participated FL training.")
    parser.add_argument("--sigma", type=float, default=.0, help="noise hyperparameter")

    ################################
    #       Model Prop args        #
    ################################
    parser.add_argument("--n-hidden", type=int, default=3, help="num. hidden layers")
    parser.add_argument("--inner-lr", type=float, default=1e-2, help="learning rate for inner optimizer")
    parser.add_argument("--inner-lr-min", type=float, default=1e-3, help="Min learning rate for inner optimizer")
    parser.add_argument("--inner-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for inner optimizer")
    parser.add_argument("--inner-wd", type=float, default=5e-5, help="inner weight decay")
    parser.add_argument("--inner-epochs", type=int, default=None)
    parser.add_argument("--inner-steps", type=int, default=50)
    parser.add_argument("--K", type=int, default=5, help="number of updates in each local round")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--target-arch", type=str, default=None, help="architecture of target network")
    parser.add_argument("--client-lr", type=float, default=0.001, help="learning rate for the per-client weight update")
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--client-lambda", type=float, default=15, help="coefficient of the proximal term between inner model and client weights")
    parser.add_argument("--client-mu", type=float, default=0.001, help="coefficient of the L2 term in the inner update")
    parser.add_argument("--update-all", action='store_true', default=False, help="Update all client's weights in each round")

    #############################
    #       General args        #
    #############################
    parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
    parser.add_argument("--eval-every", type=int, default=30, help="eval every X selected epochs")
    parser.add_argument("--eval-batch-size", type=int, default=64, help="eval batch size")
    parser.add_argument("--eval-batch-count", type=int, default=999999, help="eval num batch")
    parser.add_argument("--eval-finetune-epochs", type=int, default=0)
    parser.add_argument("--eval-finetune-lr", type=float, default=1e-3, help="learning rate for eval finetune optimizer")
    parser.add_argument("--eval-finetune-wd", type=float, default=5e-5, help="weight decay for eval finetune optimizer")
    parser.add_argument("--num-workers", type=int, default=0, help="num workers for dataloader")
    parser.add_argument("--run-name", type=str, default=None, help="name for output run file")
    parser.add_argument("--run-dir", type=str, default='runs', help="dir path for output run file")
    parser.add_argument("--task-name", type=str, default=None, help="ClearML task name")
    parser.add_argument("--log-level", type=str, default='INFO', help="log level")
    parser.add_argument("--seed", type=int, default=42, help="seed value")

    args = parser.parse_args()
    # Pick a default target architecture from the dataset name when none was given.
    if args.target_arch is None:
        if 'METR-LA' in args.data_name or 'PEMS-BAY' in args.data_name:
            args.target_arch = 'GRUSeq2Seq'
        elif args.data_name == 'CompCars':
            args.target_arch = 'resnet18'
        else:
            args.target_arch = 'mlp'
    # NOTE(review): the check uses `<=` while the message advertises an exclusive upper bound.
    # `<=` keeps the default --gpu 0 valid when no CUDA device is present — confirm the intent.
    assert args.gpu <= torch.cuda.device_count(), f"--gpu flag should be in range [0,{torch.cuda.device_count() - 1}]"

    if args.run_name is None:
        import time
        args.run_name = f"{time.strftime('%Y%m%d_%H-%M-%S')}_{args.data_name}_pfedme"
        if args.adv_ratio > 0:
            if args.aug_method == 'label_flip':
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}"
            elif args.aug_method != 'nominal' and args.sigma > 0:
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}_{args.sigma}"
            else:
                # No effective perturbation configured: disable adversarial clients.
                args.adv_ratio = 0
    args.run_name += f'_seed_{args.seed}'
    # vars() returns the namespace's own dict, so the entries added below are visible on `args` too.
    hparam_dict = vars(args)
    hparam_dict['algorithm'] = 'pfedme'
    hparam_dict['R'] = hparam_dict['inner_steps'] // hparam_dict['K']
    hparam_dict['run_path'] = os.path.join(args.run_dir, args.run_name)
    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(hparam_dict['run_path'], exist_ok=True)
    set_logger(os.path.join(hparam_dict['run_path'],'output.log'), level=args.log_level)
    logging.info(json.dumps(hparam_dict, sort_keys=True, indent=4))
    # ClearML tracking is optional: any failure to import or initialise it only disables tracking.
    # `except Exception` (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
    try:
        from clearml import Task
        if args.task_name is None:
            hparam_dict['task_name'] = args.run_name
        task = Task.init(project_name='Panacea NeurIPS 2023', task_name=hparam_dict['task_name'])
    except Exception as err:
        logging.warning(f"ClearML tracking disabled: {err!r}")
        task = None

    try:
        train(hparam_dict)
    except Exception:
        print(traceback.format_exc())
    finally:
        # Close the ClearML task even when training is interrupted.
        if task is not None:
            task.close()
