# Adapted from https://raw.githubusercontent.com/mxu34/prompt-dt/refs/heads/main/pdt_main.py
# Original work Copyright (c) 2022 Mengdi Xu et al. 
# Modifications Copyright (c) 2025 King.com Ltd
from ast import parse
import gym
import numpy as np
import torch
import wandb

import argparse
import pickle
import random
import sys
import time
import itertools

from prompt_dt.prompt_decision_transformer import PromptDecisionTransformer
from prompt_dt.prompt_seq_trainer import PromptSequenceTrainer
from prompt_dt.prompt_utils import get_env_list
from prompt_dt.prompt_utils import get_prompt_batch, get_prompt, get_batch, get_batch_finetune
from prompt_dt.prompt_utils import process_total_data_mean, load_data_prompt, process_info
from prompt_dt.prompt_utils import eval_episodes

from collections import namedtuple
import json, pickle, os


# Task-split config file (relative to <cwd>/config) for each supported --env value.
_CONFIG_PATHS = {
    'cheetah_vel': "cheetah_vel/cheetah_vel_40.json",
    'cheetah_dir': "cheetah_dir/cheetah_dir_2.json",
    'ant_dir': "ant_dir/ant_dir_50.json",
    'ML1-pick-place-v2': "ML1-pick-place-v2/ML1_pick_place.json",
}

# Indices into the training-task list evaluated when `eval_train_task` is set.
# NOTE(review): requires at least 31 training tasks, otherwise indexing raises IndexError.
_TRAIN_EVAL_INDICES = (5, 10, 15, 20, 30)

# --prompt_tuner value -> name of the PromptSequenceTrainer method implementing it.
# Resolved with getattr so only the selected method has to exist on the trainer.
_PROMPT_TUNER_METHODS = {
    'bandit': 'bandit_evaluation_multienv',
    'hillclimbing': 'hillclimbing_evaluation_multienv',
    'zoranksgd': 'zoranksgd_evaluation_multienv',
    'bandit_ts': 'thompson_evaluation_multienv',  # bandit prompt tuning with Thompson sampling
    'bandit_ucb': 'ucb_evaluation_multienv',
}

# wandb project every run of this script is logged to.
_WANDB_PROJECT = 'prompt-decision-transformer'


def _parse_prompt_params_from_path(saved_model_path):
    """Return the prompt size (J, H) encoded in a checkpoint path, as strings.

    Training runs are named ``...-J-<prompt_episode>-H-<prompt_length>-...`` and the
    checkpoint folder is that run name, so both values can be recovered from the path.

    Raises:
        IndexError: if the path does not contain both ``-J-`` and ``-H-``.
    """
    num_segments = saved_model_path.split('-J-')[1].split('-')[0]
    seg_length = saved_model_path.split('-H-')[1].split('-')[0]
    return num_segments, seg_length


def _build_eval_exp_string(group_name, seed, num_segments, seg_length, variant, eval_train_task):
    """Build the run name for an evaluation / prompt-tuning run.

    The name encodes the evaluation seed, the prompt size of the loaded model, whether
    in-distribution (training) or out-of-distribution (test) tasks are used, and the
    prompt-tuner hyper-parameters (or the prompt-sampling mode when not tuning).
    """
    def bandit_features_tag():
        # Which representation the bandit tuners use for prompt segments.
        if variant["bandit_use_transformer_features"]:
            return "-transformerFeatures"
        return "-rawSegments"

    parts = [
        f"{group_name}-eval-evalSeed-{seed}",
        f"-J-{num_segments}",
        f"-H-{seg_length}",
        "-IDTasks" if eval_train_task else "-OODTasks",
    ]
    if variant["prompt_tune"]:
        tuner = variant["prompt_tuner"]
        parts.append("-" + tuner)
        if tuner == "zoranksgd":
            parts.append(f"-eta-{variant['zorank_eta']}")
            parts.append(f"-m-{variant['zorank_m']}")
            parts.append(f"-mode-{variant['zorank_mode']}")
        elif tuner == "hillclimbing":
            parts.append(f"-eta-{variant['hillclimbing_eta']}")
        elif tuner == "bandit":
            parts.append(f"-eps-{variant['bandit_epsilon']}")
            parts.append(bandit_features_tag())
            parts.append(f"-arch-{variant['bandit_arch']}")
        elif tuner == "bandit_ts":
            parts.append("-thompsonSampling")
            parts.append(bandit_features_tag())
        elif tuner == "bandit_ucb":
            parts.append("-UCB")
            parts.append(bandit_features_tag())
    else:  # no tuning
        parts.append("-noTune")
        parts.append("-randomPrompts" if variant["random_prompt_time"] else "-endOfTrajPrompts")

    if variant["extra_exp_str"]:
        parts.append("_" + variant["extra_exp_str"])
    return "".join(parts)


def experiment_mix_env(
        exp_prefix,
        variant,
):
    """Train, evaluate, prompt-tune or fine-tune a Prompt Decision Transformer.

    The mode is selected from ``variant`` (normally ``vars(args)`` from the CLI):

    * ``evaluation`` False: train on the training tasks for ``max_iters`` iterations,
      periodically evaluating on test/train tasks and saving checkpoints under
      ``<cwd>/model_saved/``.
    * ``evaluation`` True, ``finetune`` False: load ``load_path``, then either evaluate
      it (``prompt_tune`` False) or run the selected ``prompt_tuner``. Both run on the
      held-out test tasks, or on a fixed subset of training tasks when
      ``eval_train_task`` is set.
    * ``evaluation`` True, ``finetune`` True: load ``load_path`` and run repeated
      fine-tune evaluations on the test tasks.

    Args:
        exp_prefix: Prefix of the run name. The env name, task count, seed and prompt
            size are appended to it.
        variant: Dict of experiment options. It is mutated in place:
            ``total_state_mean``/``total_state_std`` are added when
            ``average_state_mean`` is set. In evaluation mode,
            ``num_traj_prompt_j``/``prompt_length`` are overwritten with the values
            parsed from the checkpoint path.

    Raises:
        NotImplementedError: if prompt tuning is requested with an unknown ``prompt_tuner``.
    """
    device = variant['device']
    log_to_wandb = variant['log_to_wandb']
    seed = variant['seed']
    env = variant['env']
    no_prompt = variant['no_prompt']
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    ######
    # construct train and test environments
    ######

    cur_dir = os.getcwd()
    config_save_path = os.path.join(cur_dir, 'config')
    data_save_path = os.path.join(cur_dir, 'data')
    save_path = os.path.join(cur_dir, 'model_saved/')
    os.makedirs(save_path, exist_ok=True)

    task_config_path = os.path.join(config_save_path, _CONFIG_PATHS[env])
    with open(task_config_path, 'r') as f:
        # Expose JSON objects as attribute-access namedtuples (task_config.train_tasks, ...).
        task_config = json.load(f, object_hook=lambda d: namedtuple('X', d.keys())(*d.values()))
    train_env_name_list = [f'{env}-{task_ind}' for task_ind in task_config.train_tasks]
    test_env_name_list = [f'{env}-{task_ind}' for task_ind in task_config.test_tasks]
    # training envs
    info, env_list = get_env_list(train_env_name_list, config_save_path, device, seed)
    # testing envs
    test_info, test_env_list = get_env_list(test_env_name_list, config_save_path, device, seed)

    ######
    # process train and test datasets
    ######

    K = variant['K']
    batch_size = variant['batch_size']
    pct_traj = variant.get('pct_traj', 1.)
    mode = variant.get('mode', 'normal')
    dataset_mode = variant['dataset_mode']
    test_dataset_mode = variant['test_dataset_mode']
    train_prompt_mode = variant['train_prompt_mode']
    test_prompt_mode = variant['test_prompt_mode']

    # load_data_prompt takes an argparse-style object. Build one from `variant` rather
    # than relying on a module-level `args` that only exists when run as a script.
    # NOTE(review): this is a snapshot, so it assumes load_data_prompt only reads from it.
    data_args = argparse.Namespace(**variant)
    trajectories_list, prompt_trajectories_list = load_data_prompt(
        train_env_name_list, data_save_path, dataset_mode, train_prompt_mode, data_args)
    test_trajectories_list, test_prompt_trajectories_list = load_data_prompt(
        test_env_name_list, data_save_path, test_dataset_mode, test_prompt_mode, data_args)

    # State mean/std pooled over every train + test trajectory, stored in `variant`.
    if variant['average_state_mean']:
        train_total = list(itertools.chain.from_iterable(trajectories_list))
        test_total = list(itertools.chain.from_iterable(test_trajectories_list))
        total_traj_list = train_total + test_total
        print(len(total_traj_list))
        total_state_mean, total_state_std = process_total_data_mean(total_traj_list, mode)
        variant['total_state_mean'] = total_state_mean
        variant['total_state_std'] = total_state_std

    # process train / test info
    info = process_info(train_env_name_list, trajectories_list, info, mode, dataset_mode, pct_traj, variant)
    test_info = process_info(test_env_name_list, test_trajectories_list, test_info, mode, test_dataset_mode,
                             pct_traj, variant)

    ######
    # construct dt model and trainer
    ######

    exp_prefix = f'{exp_prefix}-{env}'
    num_env = len(train_env_name_list)
    group_name = f'{exp_prefix}-{num_env}-Env-{dataset_mode}'
    random_prompt_suffix = "-randomPromptTime" if variant["random_prompt_time"] else ""
    # The -J-/-H- fields are parsed back by _parse_prompt_params_from_path at eval time.
    exp_prefix = (f'{group_name}-seed-{seed}-J-{variant["prompt_episode"]}-H-{variant["prompt_length"]}'
                  f'-{random.randint(int(1e5), int(1e6) - 1)}{random_prompt_suffix}')

    state_dim = test_env_list[0].observation_space.shape[0]
    act_dim = test_env_list[0].action_space.shape[0]

    model = PromptDecisionTransformer(
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=K,
        max_ep_len=1000,
        hidden_size=variant['embed_dim'],
        n_layer=variant['n_layer'],
        n_head=variant['n_head'],
        n_inner=4 * variant['embed_dim'],
        activation_function=variant['activation_function'],
        n_positions=1024,
        resid_pdrop=variant['dropout'],
        attn_pdrop=variant['dropout'],
    )
    model = model.to(device=device)

    warmup_steps = variant['warmup_steps']
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant['learning_rate'],
        weight_decay=variant['weight_decay'],
    )
    # Linear LR warm-up, then constant. max(..., 1) avoids division by zero when warmup is 0.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda steps: min((steps + 1) / max(warmup_steps, 1), 1)
    )

    env_name = train_env_name_list[0]
    trainer = PromptSequenceTrainer(
        model=model,
        optimizer=optimizer,
        batch_size=batch_size,
        get_batch=get_batch(trajectories_list[0], info[env_name], variant),
        scheduler=scheduler,
        loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
        eval_fns=None,
        get_prompt=get_prompt(prompt_trajectories_list[0], info[env_name], variant),
        get_prompt_batch=get_prompt_batch(trajectories_list, prompt_trajectories_list, info, variant,
                                          train_env_name_list)
    )

    if not variant['evaluation']:
        ######
        # start training
        ######
        if log_to_wandb:
            wandb.init(
                name=exp_prefix,
                group=group_name,
                project=_WANDB_PROJECT,
                config=variant
            )
            save_path += wandb.run.name
            os.makedirs(save_path, exist_ok=True)

        # construct model post fix
        model_post_fix = '_TRAIN_' + variant['train_prompt_mode'] + '_TEST_' + variant['test_prompt_mode']
        if variant['no_prompt']:
            model_post_fix += '_NO_PROMPT'
        if variant['finetune']:
            model_post_fix += '_FINETUNE'
        if variant['no_r']:
            model_post_fix += '_NO_R'

        for it in range(variant['max_iters']):
            outputs = trainer.pure_train_iteration_mix(
                num_steps=variant['num_steps_per_iter'],
                no_prompt=no_prompt
            )

            # evaluate on test tasks
            if it % variant['test_eval_interval'] == 0:
                if not variant['finetune']:
                    test_eval_logs = trainer.eval_iteration_multienv(
                        get_prompt, test_prompt_trajectories_list,
                        eval_episodes, test_env_name_list, test_info, variant, test_env_list, iter_num=it + 1,
                        print_logs=True, no_prompt=no_prompt, group='test')
                else:
                    test_eval_logs = trainer.finetune_eval_iteration_multienv(
                        get_prompt, get_batch_finetune, test_prompt_trajectories_list, test_trajectories_list,
                        eval_episodes, test_env_name_list, test_info,
                        variant, test_env_list, iter_num=it + 1,
                        print_logs=True, no_prompt=no_prompt,
                        group='finetune-test', finetune_opt=variant['finetune_opt'])
                outputs.update(test_eval_logs)

            # evaluate on train tasks
            if it % variant['train_eval_interval'] == 0:
                train_eval_logs = trainer.eval_iteration_multienv(
                    get_prompt, prompt_trajectories_list,
                    eval_episodes, train_env_name_list, info, variant, env_list, iter_num=it + 1,
                    print_logs=True, no_prompt=no_prompt, group='train')
                outputs.update(train_eval_logs)

            if it % variant['save_interval'] == 0:
                trainer.save_model(
                    env_name=env,
                    postfix=model_post_fix + '_iter_' + str(it),
                    folder=save_path)

            outputs.update({"global_step": it})  # set global step as iteration

            if log_to_wandb:
                wandb.log(outputs)

        # Final checkpoint, tagged with the last iteration index. Skipped when no
        # iteration ran, since there is no iteration number to tag it with.
        last_iter = variant['max_iters'] - 1
        if last_iter >= 0:
            trainer.save_model(env_name=env, postfix=model_post_fix + '_iter_' + str(last_iter), folder=save_path)

    else:
        ####
        # start evaluating
        ####

        saved_model_path = variant['load_path']
        print(saved_model_path)
        # torch.load unpickles the file: only load trusted checkpoints.
        # map_location lets GPU-saved checkpoints load on the current device (e.g. CPU).
        model.load_state_dict(torch.load(saved_model_path, map_location=device))
        print('model initialized from: ', saved_model_path)

        # update prompt params to match the loaded model (number and length of segments)
        num_segments, seg_length = _parse_prompt_params_from_path(saved_model_path)
        variant['num_traj_prompt_j'] = int(num_segments)
        variant['prompt_length'] = int(seg_length)

        if not variant['finetune']:
            # whether to evaluate on (a subset of) the training tasks or on the held-out test tasks
            eval_train_task = variant["eval_train_task"]
            if eval_train_task:
                eval_prompt_trajs = [prompt_trajectories_list[i] for i in _TRAIN_EVAL_INDICES]
                eval_env_names = [train_env_name_list[i] for i in _TRAIN_EVAL_INDICES]
                eval_info = {name: info[name] for name in eval_env_names}
                eval_envs = [env_list[i] for i in _TRAIN_EVAL_INDICES]
            else:
                eval_prompt_trajs = test_prompt_trajectories_list
                eval_env_names = test_env_name_list
                eval_info = test_info
                eval_envs = test_env_list

            exp_string = _build_eval_exp_string(
                group_name, seed, num_segments, seg_length, variant, eval_train_task)
            print(f"Experiment: {exp_string}")

            if not variant['prompt_tune']:
                """ MODEL EVALUATION"""
                if log_to_wandb:
                    wandb.init(
                        name=exp_string,
                        group=group_name,
                        project=_WANDB_PROJECT,
                        config=variant,
                        reinit=True,
                    )
                # checkpoint iteration number, taken from the trailing `_<iter>` of the path
                eval_iter_num = int(saved_model_path.split('_')[-1])
                for rollout in range(variant["eval_rollouts"]):
                    print(f"Iter: {rollout}")
                    outputs = trainer.eval_iteration_multienv(
                        get_prompt, eval_prompt_trajs,
                        eval_episodes, eval_env_names, eval_info, variant, eval_envs,
                        iter_num=eval_iter_num,
                        print_logs=True, no_prompt=no_prompt, group='eval')
                    outputs.update({"global_step": rollout})  # set global step as iteration

                    if log_to_wandb:
                        wandb.log(outputs)

            else:
                """ MODEL BANDIT TRAINING"""
                prompt_tuner = variant['prompt_tuner']
                print(f"Prompt Tuner: {prompt_tuner}")

                # Fail before creating a wandb run if the tuner is unknown.
                method_name = _PROMPT_TUNER_METHODS.get(prompt_tuner)
                if method_name is None:
                    raise NotImplementedError(f"Unknown prompt tuner: {prompt_tuner!r}")

                if log_to_wandb:
                    wandb.init(
                        name=exp_string,
                        group=group_name,
                        project=_WANDB_PROJECT,
                        config=variant,
                        reinit=True,
                    )
                eval_iter_num = int(saved_model_path.split('_')[-1])
                # Every tuner method takes the same arguments; they log through `wandb` themselves.
                getattr(trainer, method_name)(
                    get_prompt, eval_prompt_trajs,
                    eval_episodes, eval_env_names, eval_info, variant, eval_envs,
                    iter_num=eval_iter_num,
                    print_logs=True, no_prompt=no_prompt, group='eval', wandb=wandb)

        else:
            """ MODEL FINE-TUNING """
            # NOTE(review): assumes the seed of the trained model is the 9th '-'-separated field of the path.
            model_seed = saved_model_path.split('-')[8]
            exp_prefix = (f'{group_name}-seed-{model_seed}-J-{variant["prompt_episode"]}'
                          f'-H-{variant["prompt_length"]}-{random.randint(int(1e5), int(1e6) - 1)}'
                          f'-finetune-seed-{seed}')

            if log_to_wandb:
                wandb.init(
                    name=exp_prefix,
                    group=group_name,
                    project=_WANDB_PROJECT,
                    config=variant,
                    reinit=True,
                )

            for step in range(variant["eval_rollouts"]):
                outputs = trainer.finetune_eval_iteration_multienv(
                    get_prompt, get_batch_finetune, test_prompt_trajectories_list, test_trajectories_list,
                    eval_episodes, test_env_name_list, test_info,
                    variant, test_env_list, iter_num=step + 1,
                    print_logs=True, no_prompt=no_prompt,
                    group='finetune-test', finetune_opt=variant['finetune_opt'])
                outputs.update({"global_step": step})  # set global step as iteration

                if log_to_wandb:
                    wandb.log(outputs)

    # Close the run in every mode so repeated calls (e.g. one per eval seed) start cleanly.
    if log_to_wandb:
        wandb.finish()


def str2bool(value):
    """Parse a command-line boolean such as ``true``, ``False``, ``1`` or ``no``.

    Replaces ``type=bool``. Because any non-empty string is truthy, ``bool("False")``
    is ``True``, so ``-w False`` could never turn logging off.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str,
                        default='cheetah_vel')  # ['cheetah_vel', 'cheetah_dir', 'ant_dir', 'ML1-pick-place-v2']
    parser.add_argument('--dataset_mode', type=str, default='expert')
    parser.add_argument('--test_dataset_mode', type=str, default='expert')
    parser.add_argument('--train_prompt_mode', type=str, default='expert')
    parser.add_argument('--test_prompt_mode', type=str, default='expert')
    parser.add_argument('--seed', type=int, default=1)

    parser.add_argument('--prompt-episode', type=int, default=1)
    parser.add_argument('--prompt-length', type=int, default=5)
    # Always True (store_true with default=True). Not read in this file; it may be
    # consumed elsewhere through `variant`. Possibly overlaps with --random_prompt_time.
    parser.add_argument('--stochastic-prompt', action='store_true', default=True)
    parser.add_argument('--no-prompt', action='store_true', default=False)
    parser.add_argument('--no-r', action='store_true', default=False)
    parser.add_argument('--no-rtg', action='store_true', default=False)
    # Sample prompts at a random time instead of always using the last steps of the prompt trajectories.
    parser.add_argument('--random_prompt_time', action='store_true', default=True)
    parser.add_argument('--no_random_prompt_time', dest='random_prompt_time', action='store_false')

    # prompt tuning args
    parser.add_argument('--load_path', type=str, default='')  # PDT checkpoint file to load
    # If set: no training, only evaluate the given model (must be set for prompt tuning).
    parser.add_argument('--evaluation', action='store_true', default=False)
    parser.add_argument('--eval_rollouts', type=int, default=250)  # number of evaluation iterations / rollouts
    parser.add_argument('--prompt_tune', action='store_true', default=False)  # if set, do prompt tuning during evaluation
    # If set, evaluate (and prompt-tune) on a subset of the in-distribution training tasks;
    # otherwise evaluate (and prompt-tune) on the held-out test tasks.
    parser.add_argument('--eval_train_task', action='store_true', default=False)
    parser.add_argument('--prompt_tuner', type=str, default='bandit')  # bandit, bandit_ts, bandit_ucb, hillclimbing, zoranksgd
    parser.add_argument('--zorank_mode', type=str, default='online')  # online, offline
    # 1e-4, 1e-3, 1e-2, or 'schedule': learning rate applied to the estimated gradient
    parser.add_argument('--zorank_eta', type=str, default="1e-4")
    parser.add_argument('--zorank_m', type=int, default=5)  # number of perturbed prompts used to estimate the gradient in ZO-RankSGD
    # 1e-4, 1e-3, 1e-2, or 'schedule': scale of the noise added to the prompt
    parser.add_argument('--hillclimbing_eta', type=str, default="1e-2")
    parser.add_argument("--bandit_arch", type=int, nargs='+', default=[32, 32, 32])
    parser.add_argument('--bandit-use-transformer-features', action='store_true', default=True)
    parser.add_argument('--no-bandit-use-transformer-features', dest='bandit_use_transformer_features',
                        action='store_false')
    # float in [0, 1], or 'schedule': epsilon-greedy exploration rate
    parser.add_argument('--bandit_epsilon', type=str, default="0.1")
    parser.add_argument('--extra_exp_str', type=str, default="")

    parser.add_argument('--num_traj_prompt_j', type=int, default=1)
    parser.add_argument('--finetune', action='store_true', default=False)
    parser.add_argument('--finetune_steps', type=int, default=10)
    parser.add_argument('--finetune_batch_size', type=int, default=256)
    parser.add_argument('--finetune_opt', action='store_true', default=True)
    parser.add_argument('--no_finetune_opt', dest='finetune_opt', action='store_false')
    parser.add_argument('--finetune_lr', type=float, default=1e-4)
    parser.add_argument('--no_state_normalize', action='store_true', default=False)
    parser.add_argument('--average_state_mean', action='store_true', default=True)
    parser.add_argument('--no_average_state_mean', dest='average_state_mean', action='store_false')
    parser.add_argument('--render', action='store_true', default=False)

    parser.add_argument('--mode', type=str, default='normal')
    parser.add_argument('--K', type=int, default=20)
    parser.add_argument('--pct_traj', type=float, default=1.)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000)
    parser.add_argument('--num_eval_episodes', type=int, default=1)
    parser.add_argument('--max_iters', type=int, default=5000)
    parser.add_argument('--num_steps_per_iter', type=int, default=10)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--log_to_wandb', '-w', type=str2bool, default=True)  # accepts true/false, yes/no, 1/0
    parser.add_argument('--train_eval_interval', type=int, default=500)
    parser.add_argument('--test_eval_interval', type=int, default=100)
    parser.add_argument('--save-interval', type=int, default=500)

    args = parser.parse_args()
    print(args.prompt_tune, args.evaluation)

    if args.evaluation:
        # Repeat the evaluation with several seeds. vars(args) returns args.__dict__,
        # so every run shares (and may mutate) the same options dict.
        for eval_seed in (11, 12, 13):
            args.seed = eval_seed
            experiment_mix_env('gym-experiment', variant=vars(args))
    else:
        experiment_mix_env('gym-experiment', variant=vars(args))
