import os
import argparse
import json
import logging
from platform import node
import random
from collections import OrderedDict, defaultdict
from pathlib import Path
import copy
import traceback
from functools import partial

import numpy as np
import torch
import torch.utils.data
import torch.nn.functional as F
from tqdm.auto import trange, tqdm
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import subgraph, from_networkx
from torch_sparse import SparseTensor
import torchvision

from models import CNNTarget, GAEHyper, GNNHyper, GRUTarget, MLPTarget, MLPHyper, MLPInnerProductDecoder
from node import BaseNodes
from utils import get_device, set_logger, set_seed, freeze, unfreeze, detach_clone, CustomCosineAnnealingLR, write_results, eval_model
from deform import DeformWrapper


def _eval_forward(net, batch):
    """Run ``net`` on ``batch``: a GRUTarget receives the whole batch tuple, any other net only ``batch[0]``."""
    if isinstance(net, GRUTarget):
        return net(batch)
    return net(batch[0])


def _eval_pass(net, criteria, device, loader, eval_batch_count):
    """Evaluate ``net`` on at most ``eval_batch_count`` batches of ``loader``.

    Returns:
        ``(avg_loss, correct, samples, acc)``. ``avg_loss`` is the sum of the
        per-batch criterion values divided by the number of batches that were
        actually processed. ``correct``/``samples``/``acc`` are only
        meaningful when ``criteria`` is a ``CrossEntropyLoss`` (``acc`` is
        ``None`` otherwise). When no batch was processed, ``avg_loss`` and
        ``acc`` are NaN tensors.
    """
    track_acc = isinstance(criteria, torch.nn.CrossEntropyLoss)
    running_loss, running_correct, running_samples = 0., 0., 0.
    batches_seen = 0
    for batch in loader:
        if batches_seen >= eval_batch_count:
            break
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        label = batch[1]
        pred = _eval_forward(net, batch)
        running_loss += criteria(pred, label)
        if track_acc:
            running_correct += pred.argmax(1).eq(label).sum()
            running_samples += len(label)
        batches_seen += 1

    if batches_seen == 0:
        # Nothing was evaluated (empty loader or eval_batch_count <= 0):
        # report NaN instead of dividing by zero or by a stale counter.
        nan = torch.tensor(float('nan'))
        return nan, running_correct, running_samples, nan

    # Divide by the number of processed batches. Using the enumerate index + 1
    # would over-count by one whenever the eval_batch_count cap ends the loop.
    avg_loss = running_loss / batches_seen
    acc = running_correct / running_samples if track_acc else None
    return avg_loss, running_correct, running_samples, acc


def evaluate(nodes: BaseNodes, net, criteria, device, eval_batch_count, finetune_epochs, finetune_lr, finetune_wd, node_list, split):
    """Evaluate the global model on every node in ``node_list``.

    For each node the network is reset to ``nodes.global_weights`` /
    ``nodes.global_buffers`` and scored on that node's loader. If
    ``finetune_epochs > 0`` the network is then fine-tuned with SGD on the
    node's train loader and scored again on the same loader.

    Args:
        nodes: provides ``global_weights``, ``global_buffers`` and the
            ``train_loaders`` / ``val_loaders`` / ``test_loaders`` indexed by node id.
        net: target network; its parameters and buffers are overwritten per node.
        criteria: loss module. Accuracy is tracked only for ``CrossEntropyLoss``.
        device: device the batches are moved to.
        eval_batch_count: maximum number of batches scored per node.
        finetune_epochs: number of fine-tuning epochs per node (0 disables it).
        finetune_lr: learning rate of the fine-tuning SGD optimizer.
        finetune_wd: weight decay of the fine-tuning SGD optimizer.
        node_list: iterable of node ids to evaluate.
        split: ``'test'`` or ``'val'``; any other value selects the train loader.

    Returns:
        Nested ``defaultdict`` keyed by node id with ``'loss'`` (plus
        ``'correct'``, ``'total'``, ``'acc'`` for classification) and, when
        fine-tuning is enabled, the matching ``'finetune_*'`` entries.
    """
    net.eval()
    results = defaultdict(lambda: defaultdict(list))
    is_classification = isinstance(criteria, torch.nn.CrossEntropyLoss)

    node_iter = tqdm(node_list, position=-1, leave=False)
    for node_id in node_iter:  # iterating over nodes
        # Start every node from a fresh copy of the global model.
        for n, p in net.named_buffers():
            p.data = nodes.global_buffers[n].detach().clone()
        for n, p in net.named_parameters():
            p.data = nodes.global_weights[n].detach().clone()

        if split == 'test':
            curr_data = nodes.test_loaders[node_id]
        elif split == 'val':
            curr_data = nodes.val_loaders[node_id]
        else:
            curr_data = nodes.train_loaders[node_id]

        node_result = results[node_id]
        with torch.no_grad():
            avg_loss, correct, samples, acc = _eval_pass(net, criteria, device, curr_data, eval_batch_count)
        node_result['loss'] = avg_loss
        tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {node_result['loss']:.4f}"
        if is_classification:
            node_result['correct'] = correct
            node_result['total'] = samples
            node_result['acc'] = acc
            tqdm_str += f", AVG Acc: {node_result['acc']}"

        if finetune_epochs > 0:
            net.train()
            inner_optim = torch.optim.SGD(net.parameters(), lr=finetune_lr, momentum=0.9, weight_decay=finetune_wd)
            for _ in range(finetune_epochs):
                for batch in nodes.train_loaders[node_id]:
                    inner_optim.zero_grad()
                    batch = tuple(t.to(device, non_blocking=True) for t in batch)
                    label = batch[1]
                    pred = _eval_forward(net, batch)
                    loss = criteria(pred, label)
                    loss.backward()
                    inner_optim.step()
            net.eval()
            with torch.no_grad():
                avg_loss, correct, samples, acc = _eval_pass(net, criteria, device, curr_data, eval_batch_count)
            node_result['finetune_loss'] = avg_loss
            tqdm_str = f"Eval {split}set Node ID: {node_id}, AVG Loss: {node_result['loss']:.4f} -> {node_result['finetune_loss']:.4f}"
            if is_classification:
                node_result['finetune_correct'] = correct
                node_result['finetune_total'] = samples
                node_result['finetune_acc'] = acc
                tqdm_str += f", AVG Acc: {node_result['acc']:.4f} -> {node_result['finetune_acc']:.4f}"
        node_iter.set_description(tqdm_str)
    return results


def _batch_to_device(batch, device):
    """Move every tensor of ``batch`` to ``device`` and return them as a tuple."""
    return tuple(t.to(device, non_blocking=True) for t in batch)


def _train_forward(net, batch):
    """Run ``net`` on ``batch``: a GRUTarget receives the whole batch tuple, any other net only ``batch[0]``."""
    if isinstance(net, GRUTarget):
        return net(batch)
    return net(batch[0])


def _batch_accuracy(criteria, pred, label):
    """Return the batch accuracy for a CrossEntropyLoss criterion, else a zero tensor placeholder."""
    if isinstance(criteria, torch.nn.CrossEntropyLoss):
        return pred.argmax(1).eq(label).sum() / len(label)
    return torch.tensor(0)


def _probe_first_test_batch(nodes, net, criteria, device, node_id):
    """Put ``net`` in eval mode and return ``(loss, acc)`` on one batch drawn from the node's test loader."""
    with torch.no_grad():
        net.eval()
        batch = _batch_to_device(next(iter(nodes.test_loaders[node_id])), device)
        label = batch[1]
        pred = _train_forward(net, batch)
        loss = criteria(pred, label)
        acc = _batch_accuracy(criteria, pred, label)
    return loss, acc


def _resolve_inner_schedule(hparam_dict, batches_per_epoch):
    """Return ``(inner_epochs, inner_steps)`` for local training.

    When ``hparam_dict['inner_epochs']`` is None the step budget
    ``hparam_dict['inner_steps']`` drives training and enough epochs are
    scheduled to cover it; otherwise the epoch count drives training and the
    step budget is the total number of batches in those epochs.
    """
    if hparam_dict['inner_epochs'] is None:
        inner_steps = hparam_dict['inner_steps']
        inner_epochs = inner_steps // batches_per_epoch + 1
    else:
        inner_epochs = hparam_dict['inner_epochs']
        inner_steps = inner_epochs * batches_per_epoch
    return inner_epochs, inner_steps


def train(hparam_dict: dict) -> None:
    """Run federated-averaging training, then evaluate/adapt on unseen nodes.

    Phases:
      1. Build the nodes, the target network and the loss for
         ``hparam_dict['data_name']``.
      2. Optionally pre-train the target network on the first train node for
         ``hparam_dict['pretrain_targetnet_steps']`` steps.
      3. For ``hparam_dict['num_steps']`` rounds: sample clients, train a copy
         of the global model locally on each (adversarial clients get their
         batches passed through ``DeformWrapper``), and average the resulting
         parameters/buffers into the new global model.
      4. Save ``final.ckpt``, evaluate on the unseen nodes, and run
         ``hparam_dict['eval_unseen_steps']`` further averaging rounds over
         the unseen nodes, saving ``unseen.ckpt`` if any were run.

    Metrics are written to a TensorBoard ``SummaryWriter`` at
    ``hparam_dict['run_path']``.

    NOTE: this function reads the module-level name ``device``, which this
    file only assigns inside the ``__main__`` block.

    Args:
        hparam_dict: hyper-parameters, as produced from the CLI arguments plus
            ``'run_path'``.

    Raises:
        ValueError: if ``hparam_dict['data_name']`` has no associated loss.
    """
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.tensorboard.summary import hparams
    writer = SummaryWriter(hparam_dict['run_path'])
    ###############################
    # init nodes, hnet, local net #
    ###############################
    nodes = BaseNodes(hparam_dict['data_name'], batch_size=hparam_dict['batch_size'], eval_batch_size=hparam_dict['eval_batch_size'], num_workers=hparam_dict['num_workers'], device=device, data_dir=hparam_dict['data_dir'])
    num_nodes = nodes.n_nodes
    all_nodes = np.arange(num_nodes)
    client_sample_rng = np.random.default_rng(hparam_dict['seed'])
    if 'METR-LA-' in hparam_dict['data_name'] or 'PEMS-BAY-' in hparam_dict['data_name']:
        # These dataset variants provide their own train-node split.
        new_order = all_nodes
        train_nodes = nodes.train_nodes
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = list(set(all_nodes) - set(train_nodes))
    else:
        # Otherwise shuffle the nodes and hold out the last 20% as unseen.
        new_order = nodes.shuffle()
        num_train = int(0.8*num_nodes)
        train_nodes = all_nodes[:num_train]
        adv_node_list = client_sample_rng.choice(train_nodes, int(len(train_nodes)*hparam_dict['adv_ratio']), replace=False)
        unseen_nodes = all_nodes[num_train:]
    logging.info(f"Num of nodes: {len(all_nodes)}")
    logging.info(f"Num of train nodes: {len(train_nodes)}")
    logging.info(f"adv_node_list: {adv_node_list}")
    if 'METR-LA' in hparam_dict['data_name'] or 'PEMS-BAY' in hparam_dict['data_name']:
        net = GRUTarget(nodes.data_dim, out_dim=nodes.label_dim)
    elif hparam_dict['data_name'] == 'CompCars':
        CNN = getattr(torchvision.models, hparam_dict['target_arch'])
        net = CNN(pretrained=True)
        # Replace the last child module with a fresh classifier head of the right size.
        final_layer_name, final_layer = list(net.named_children())[-1]
        setattr(net, final_layer_name, torch.nn.Linear(final_layer.in_features, nodes.label_dim))
    else:
        net = MLPTarget(nodes.data_dim, out_dim=nodes.label_dim)
    net.to(device)

    deform = DeformWrapper(device, hparam_dict['aug_method'], nodes.label_dim, hparam_dict['sigma'])

    ##################
    # init optimizer #
    ##################
    if hparam_dict['data_name'] in ['DG15', 'DG60', 'DG60_TRUNCNORM', 'CompCars']:
        criteria = torch.nn.CrossEntropyLoss()
    elif hparam_dict['data_name'] in ['TPT48', 'TPT48_TRUNCNORM', 'METR-LA', 'METR-LA-0.25', 'METR-LA-0.5', 'METR-LA-0.75', 'METR-LA-0.05', 'METR-LA-0.9', 'PEMS-BAY', 'PEMS-BAY-0.25', 'PEMS-BAY-0.5', 'PEMS-BAY-0.75', 'PEMS-BAY-0.05', 'PEMS-BAY-0.9']:
        criteria = torch.nn.MSELoss()
    else:
        # Raising a plain string is itself a TypeError; raise a real exception.
        raise ValueError(f"Unknown dataset {hparam_dict['data_name']}.")

    metric_dict = {
        'Train/best_step': np.nan,
        'Train/best_test_avg_loss': np.nan,
        'Train/best_test_avg_acc': np.nan,
        'Unseen/best_step': np.nan,
        'Unseen/best_test_avg_loss': np.nan,
        'Unseen/best_test_avg_acc': np.nan,
        'Unseen_0/val_avg_loss': np.nan,
        'Unseen_0/test_avg_loss': np.nan,
        'Unseen_0/test_max_loss': np.nan,
        'Unseen_0/test_min_loss': np.nan,
        'Unseen_0/test_std_loss': np.nan,
        'Unseen_0/val_avg_acc': np.nan,
        'Unseen_0/test_avg_acc': np.nan,
        'Unseen_0/test_max_acc': np.nan,
        'Unseen_0/test_min_acc': np.nan,
        'Unseen_0/test_std_acc': np.nan,
    }

    # The scheduler must exist before pre-training, which reads its learning
    # rate; it is used unchanged by the federated rounds below.
    inner_lr_scheduler = CustomCosineAnnealingLR(hparam_dict['inner_lr'], T_max=hparam_dict['num_steps'], eta_min=hparam_dict['inner_lr_min'])
    inner_lr_scheduler.step()

    if hparam_dict['pretrain_targetnet_steps'] > 0:
        net.train()
        step_iter = trange(hparam_dict['pretrain_targetnet_steps'])
        inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )
        node_id = train_nodes[0]
        for step in step_iter:
            inner_optim.zero_grad()

            batch = _batch_to_device(next(iter(nodes.train_loaders[node_id])), device)
            label = batch[1]
            pred = _train_forward(net, batch)

            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 1)
            inner_optim.step()
            with torch.no_grad():
                acc = _batch_accuracy(criteria, pred, label)
            step_iter.set_description(
                f"Pre-train Target NN Step: {step}, Node ID: {node_id}, Loss: {loss.item():.4f}, Acc: {acc.item():.4f}"
            )
            writer.add_scalar('Pre-Train/loss_target', loss, step)
            writer.add_scalar('Pre-Train/acc_target', acc, step)
        net.eval()

    ################
    # init metrics #
    ################
    nodes.global_weights = detach_clone(OrderedDict(net.named_parameters()))
    nodes.global_buffers = detach_clone(OrderedDict(net.named_buffers()))
    num_steps = hparam_dict['num_steps']
    eval_every = hparam_dict['eval_every']
    step_iter = trange(num_steps, position=0)
    num_clients = min(len(train_nodes), hparam_dict['num_clients'])
    evaluator = partial(evaluate, nodes, net, criteria, device, hparam_dict['eval_batch_count'], hparam_dict['eval_finetune_epochs'], hparam_dict['eval_finetune_lr'], hparam_dict['eval_finetune_wd'])
    for step in step_iter:
        if step % eval_every == 0:
            results = eval_model(evaluator, train_nodes, adv_node_list)
            write_results(writer, 'Train', results, step)
        # select client at random
        node_ids = client_sample_rng.choice(train_nodes, num_clients, replace=False)
        logging.debug(f"node_ids: {node_ids}")
        node_results = []
        # produce & load local network weights
        for node_id in node_ids:
            # The net's tensors are pointed at the per-node copies, so the
            # optimizer's updates land in nodes.weights / nodes.buffers and
            # can be averaged after the client loop.
            nodes.weights[node_id] = detach_clone(nodes.global_weights)
            nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
            for n, p in net.named_parameters():
                p.data = nodes.weights[node_id][n]
            for n, p in net.named_buffers():
                p.data = nodes.buffers[node_id][n]
            # init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )

            # NOTE: evaluation on sent model
            prvs_loss, prvs_acc = _probe_first_test_batch(nodes, net, criteria, device, node_id)

            net.train()
            train_loader = nodes.train_loaders[node_id]
            inner_epochs, inner_steps = _resolve_inner_schedule(hparam_dict, len(train_loader))
            is_adversarial = node_id in adv_node_list
            batch_count = 0
            for _ in range(inner_epochs):
                # inner updates -> obtaining theta_tilda
                for batch in train_loader:
                    batch_count += 1
                    if batch_count > inner_steps:
                        break
                    inner_optim.zero_grad()
                    batch = _batch_to_device(batch, device)
                    if is_adversarial:
                        batch = deform(batch)
                    label = batch[1]
                    pred = _train_forward(net, batch)
                    loss = criteria(pred, label)
                    loss.backward()
                    inner_optim.step()

            # evaluation on the locally trained model
            loss, acc = _probe_first_test_batch(nodes, net, criteria, device, node_id)
            node_results += [[prvs_loss.item(), loss.item(), prvs_acc.item(), acc.item()]]

        # FedAvg aggregation: unweighted mean over the sampled clients.
        for n, p in net.named_parameters():
            nodes.global_weights[n] = torch.stack([nodes.weights[node_id][n] for node_id in node_ids]).mean(dim=0)
        for n, p in net.named_buffers():
            # Buffers may be integer-typed; average in float and cast back.
            nodes.global_buffers[n] = torch.stack([nodes.buffers[node_id][n] for node_id in node_ids]).float().mean(dim=0).type(nodes.global_buffers[n].dtype)
        prvs_loss, loss, prvs_acc, acc = np.mean(node_results, axis=0)
        step_iter.set_description(
            f"Step: {step}, Train Loss: {prvs_loss:.4f} -> {loss:.4f},  Acc: {prvs_acc:.4f} -> {acc:.4f}"
        )
        if hparam_dict['inner_lr_scheduler']:
            inner_lr_scheduler.step()

    # Log Final performance
    # Equals "last loop index + 1" and stays defined when num_steps is 0.
    step = num_steps
    results = eval_model(evaluator, train_nodes, adv_node_list)
    write_results(writer, 'Train', results, step)
    torch.save({
        'hparam_dict': hparam_dict,
        'metric_dict': metric_dict,
        'nodes.global_weights': nodes.global_weights,
        'nodes.global_buffers': nodes.global_buffers,
        'step': step,
        'node_order': new_order
    }, os.path.join(hparam_dict['run_path'],'final.ckpt'))

    # Evaluate on unseen nodes
    step = 0
    results = eval_model(evaluator, unseen_nodes)
    write_results(writer, 'Unseen_0', results, step)
    best_metric_dict = {
        'Unseen_0/test_avg_loss': results['test_avg_loss'],
        'Unseen_0/test_max_loss': results['test_max_loss'],
        'Unseen_0/test_min_loss': results['test_min_loss'],
        'Unseen_0/test_std_loss': results['test_std_loss']
    }
    if isinstance(criteria, torch.nn.CrossEntropyLoss):
        best_metric_dict.update({
            'Unseen_0/test_avg_acc': results['test_avg_acc'],
            'Unseen_0/test_max_acc': results['test_max_acc'],
            'Unseen_0/test_min_acc': results['test_min_acc'],
            'Unseen_0/test_std_acc': results['test_std_acc']
        })
    metric_dict.update(best_metric_dict)
    exp, ssi, sei = hparams(hparam_dict, metric_dict)
    writer.file_writer.add_summary(exp)
    writer.file_writer.add_summary(ssi)
    writer.file_writer.add_summary(sei)
    for k, v in best_metric_dict.items():
        writer.add_scalar(k, v, step)

    step_iter = trange(hparam_dict['eval_unseen_steps'])
    num_clients = min(len(unseen_nodes), hparam_dict['num_clients'])
    for step in step_iter:
        # select client at random
        node_ids = np.random.choice(unseen_nodes, num_clients, replace=False)
        # produce & load local network weights
        for node_id in node_ids:
            # Register per-node copies exactly as the training rounds do.
            # Without this the aggregation below would average tensors in
            # nodes.weights / nodes.buffers that this loop never trained.
            nodes.weights[node_id] = detach_clone(nodes.global_weights)
            nodes.buffers[node_id] = detach_clone(nodes.global_buffers)
            for n, p in net.named_parameters():
                p.data = nodes.weights[node_id][n]
            for n, p in net.named_buffers():
                p.data = nodes.buffers[node_id][n]
            # init inner optimizer
            inner_optim = torch.optim.SGD(
                net.parameters(), lr=inner_lr_scheduler.get_lr(), momentum=0.9, weight_decay=hparam_dict['inner_wd']
            )

            # NOTE: evaluation on sent model
            prvs_loss, prvs_acc = _probe_first_test_batch(nodes, net, criteria, device, node_id)

            net.train()
            train_loader = nodes.train_loaders[node_id]
            # 'inner_epochs' may be None (its CLI default); resolve it the
            # same way as in the training rounds instead of range(None).
            inner_epochs, inner_steps = _resolve_inner_schedule(hparam_dict, len(train_loader))
            batch_count = 0
            for _ in range(inner_epochs):
                # inner updates -> obtaining theta_tilda
                logging.debug(f"Start local training on client {node_id}.")
                for batch in train_loader:
                    batch_count += 1
                    if batch_count > inner_steps:
                        break
                    inner_optim.zero_grad()
                    batch = _batch_to_device(batch, device)
                    label = batch[1]
                    pred = _train_forward(net, batch)
                    loss = criteria(pred, label)
                    loss.backward()
                    inner_optim.step()
            step_iter.set_description(
                f"Unseen Step: {step}, Node ID: {node_id}, Loss: {prvs_loss:.4f},  Acc: {prvs_acc:.4f}"
            )

        for n, p in net.named_parameters():
            nodes.global_weights[n] = torch.stack([nodes.weights[node_id][n] for node_id in node_ids]).mean(dim=0)
        for n, p in net.named_buffers():
            nodes.global_buffers[n] = torch.stack([nodes.buffers[node_id][n] for node_id in node_ids]).float().mean(dim=0).type(nodes.global_buffers[n].dtype)

        results = eval_model(evaluator, unseen_nodes)
        write_results(writer, 'Unseen', results, step+1)
        # NOTE(review): 'test_all_acc' differs from the 'test_avg_acc' key used
        # above; confirm eval_model provides it (also for regression losses).
        step_iter.set_description(f"Unseen Step: {step}, AVG Loss: {results['test_avg_loss']:.4f},  AVG Acc: {results['test_all_acc']:.4f}")
        logging.info(f"Unseen Step: {step}, AVG Loss: {results['test_avg_loss']:.4f},  AVG Acc: {results['test_all_acc']:.4f}")
    writer.close()
    if hparam_dict['eval_unseen_steps'] > 0:
        torch.save({
            'hparam_dict': hparam_dict,
            'metric_dict': metric_dict,
            'nodes.global_weights': nodes.global_weights,
            'nodes.global_buffers': nodes.global_buffers,
            'step': step,
            'node_order': new_order
        }, os.path.join(hparam_dict['run_path'],'unseen.ckpt'))


if __name__ == '__main__':
    # Command-line entry point: parse the hyper-parameters, prepare the run
    # directory, logger, seed and device, optionally register a ClearML task,
    # then run `train`.
    parser = argparse.ArgumentParser(
        description="Federated Hypernetwork with Lookahead experiment"
    )

    #############################
    #       Dataset Args        #
    #############################

    parser.add_argument("--data-name", type=str, default="DG15")
    parser.add_argument("--data-dir", type=str, default="data", help="dir name of dataset")

    ##################################
    #       Optimization args        #
    ##################################

    parser.add_argument("--pretrain-targetnet-steps", type=int, default=0)
    parser.add_argument("--num-steps", type=int, default=800)
    parser.add_argument("--eval-unseen-steps", type=int, default=0)
    parser.add_argument("--optim", type=str, default='sgd', choices=['adam', 'sgd'], help="optimizer")
    parser.add_argument("--num-clients", type=int, default=5)
    parser.add_argument("--last-layer-only", action='store_true', default=False, help="Only personalize the last layer.")
    parser.add_argument('--aug-method', type=str, default='label_flip', help='type of augmentation for training')
    parser.add_argument("--adv-ratio", type=float, default=.0, help="Ratio of adversarial clients participated FL training.")
    parser.add_argument("--sigma", type=float, default=.0, help="noise hyperparameter")

    ################################
    #       Model Prop args        #
    ################################
    parser.add_argument("--n-hidden", type=int, default=3, help="num. hidden layers")
    parser.add_argument("--inner-lr", type=float, default=1e-2, help="learning rate for inner optimizer")
    parser.add_argument("--inner-lr-min", type=float, default=1e-3, help="Min learning rate for inner optimizer")
    parser.add_argument("--inner-wd", type=float, default=5e-5, help="inner weight decay")
    parser.add_argument("--inner-lr-scheduler", action='store_true', default=False, help="Use learning rate scheduler for inner optimizer")
    parser.add_argument("--inner-epochs", type=int, default=None)
    parser.add_argument("--inner-steps", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--target-arch", type=str, default=None, help="architecture of target network")

    #############################
    #       General args        #
    #############################
    parser.add_argument("--gpu", type=int, default=0, help="gpu device ID")
    parser.add_argument("--eval-every", type=int, default=30, help="eval every X selected epochs")
    parser.add_argument("--eval-batch-size", type=int, default=64, help="eval batch size")
    parser.add_argument("--eval-batch-count", type=int, default=999999, help="eval num batch")
    parser.add_argument("--eval-finetune-epochs", type=int, default=0)
    parser.add_argument("--eval-finetune-lr", type=float, default=1e-3, help="learning rate for eval finetune optimizer")
    parser.add_argument("--eval-finetune-wd", type=float, default=5e-5, help="weight decay for eval finetune optimizer")
    parser.add_argument("--num-workers", type=int, default=0, help="num workers for dataloader")
    parser.add_argument("--run-name", type=str, default=None, help="name for output run file")
    parser.add_argument("--run-dir", type=str, default='runs', help="dir path for output run file")
    parser.add_argument("--task-name", type=str, default=None, help="ClearML task name")
    parser.add_argument("--log-level", type=str, default='INFO', help="log level")
    parser.add_argument("--seed", type=int, default=42, help="seed value")

    args = parser.parse_args()
    # Pick a default target architecture from the dataset name.
    if args.target_arch is None:
        if 'METR-LA' in args.data_name or 'PEMS-BAY' in args.data_name:
            args.target_arch = 'GRUSeq2Seq'
        elif args.data_name == 'CompCars':
            args.target_arch = 'resnet18'
        else:
            args.target_arch = 'mlp'
    # NOTE(review): `<=` also accepts gpu == device_count(), which the message
    # range excludes; presumably this keeps the default --gpu 0 usable on
    # hosts without CUDA devices -- confirm against get_device.
    assert args.gpu <= torch.cuda.device_count(), f"--gpu flag should be in range [0,{torch.cuda.device_count() - 1}]"

    if args.run_name is None:
        import time
        args.run_name = f"{time.strftime('%Y%m%d_%H-%M-%S')}_{args.data_name}_fedavg"
        if args.adv_ratio > 0:
            if args.aug_method == 'label_flip':
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}"
            elif args.aug_method != 'nominal' and args.sigma > 0:
                args.run_name += f"_adv_{args.adv_ratio}_{args.aug_method}_{args.sigma}"
            else:
                # No effective perturbation configured: run without adversarial clients.
                args.adv_ratio = 0
    args.run_name += f'_seed_{args.seed}'
    # vars(args) is the namespace's own dict, so later edits to `args` and to
    # `hparam_dict` are visible through both.
    hparam_dict = vars(args)
    hparam_dict['algorithm'] = 'fedavg'
    hparam_dict['run_path'] = os.path.join(args.run_dir, args.run_name)
    os.makedirs(hparam_dict['run_path'], exist_ok=True)
    set_logger(os.path.join(hparam_dict['run_path'],'output.log'), level=args.log_level)
    set_seed(args.seed)
    device = get_device(gpus=args.gpu)
    logging.info(json.dumps(hparam_dict, sort_keys=True, indent=4))

    # ClearML tracking is optional: any failure to import or initialise it
    # disables tracking. Only `Exception` is caught so that Ctrl-C and
    # SystemExit are not swallowed.
    try:
        from clearml import Task
        if args.task_name is None:
            hparam_dict['task_name'] = args.run_name
        task = Task.init(project_name='Panacea NeurIPS 2023', task_name=hparam_dict['task_name'])
    except Exception as err:
        logging.info("ClearML tracking disabled: %s", err)
        task = None

    try:
        train(hparam_dict)
    except Exception:
        print(traceback.format_exc())
    finally:
        # Close the task even when training is interrupted.
        if task is not None:
            task.close()
