from agents.multi_pg_agent import MultiDimPGAgent
from collections import OrderedDict

import os
import time
from typing import Optional, Sequence, List
import pickle as pkl

import gym
import numpy as np
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from infrastructure import pytorch_util as ptu

from infrastructure import utils
from infrastructure.logger import Logger
from envs.circuit_env_mask_stdgraph import CircuitEnv

import wandb
import random

from Netlist2Graph.netlist_dataset import NetlistDataset
import math 

# NOTE(review): presumably caps the number of rollout videos logged; it is not
# referenced anywhere in the code shown here — confirm before relying on it.
MAX_NVIDEO = 10

def get_sample_order(circuit_order, INDEX, sample_map):
    """Look up the sampled node lists for a sequence of expert circuit keys.

    Args:
        circuit_order: iterable of keys into ``sample_map``.
        INDEX: one additional key whose node list is returned separately.
        sample_map: mapping of key -> dict containing a ``"sampled_nodes"`` entry.

    Returns:
        ``(per_key_nodes, index_nodes)`` where ``per_key_nodes`` follows the
        order of ``circuit_order`` and ``index_nodes`` belongs to ``INDEX``.

    Raises:
        KeyError: if any key is missing from ``sample_map``.
    """
    ordered_nodes = [sample_map[key]["sampled_nodes"] for key in circuit_order]
    index_nodes = sample_map[INDEX]["sampled_nodes"]
    return ordered_nodes, index_nodes

def index_experts(sample_map):
    """Return ``(first_key, all_keys)`` of ``sample_map`` in insertion order.

    Raises:
        IndexError: if ``sample_map`` is empty.
    """
    keys = list(sample_map.keys())
    return keys[0], keys


def index_half(net_dataset):
    """Find dataset entries whose ``VSS`` node names one of a fixed set of circuits.

    Every ``VSS`` node encountered is printed (debug output). For each match,
    the dataset index is appended to the returned list (once per matching
    ``VSS`` node) and recorded in the name map.

    Args:
        net_dataset: object supporting ``len()`` and ``get(index)`` returning a
            ``(graph, label)`` pair, where ``graph.get_nodes()`` yields
            ``(node_id, attrs)`` pairs.

    Returns:
        ``(name_map, half_indicies)``: ``{index: circuit_name}`` and the list of
        matching indices in discovery order.
    """
    half_names = (
        "A_cir_71", "A_cir_77", "A_cir_84", "A_cir_86", "A_cir_88",
        "A_cir_90", "A_cir_92", "A_cir_281", "A_cir_293", "A_cir_299", "A_cir_309",
        "A_cir_329", "A_cir_331",
        "A_cir_364", "A_cir_402", "A_cir_407", "A_cir_429",
        "A_cir_518",
        "AAA_Comp_SA", "AAA_Comp_DSA_AP", "AAA_Comp_DoubleTail_AP",
    )
    half_indicies = []
    name_map = {}
    for index in range(len(net_dataset)):
        graph, _label = net_dataset.get(index)
        vss_nodes = (node for node in graph.get_nodes() if node[0] == "VSS")
        for node in vss_nodes:
            print(index, node)
            circuit_name = node[1]["custom_features"]["name"]
            if circuit_name not in half_names:
                continue
            half_indicies.append(index)
            name_map[index] = circuit_name
    return name_map, half_indicies

def shuffle_traj_dict(trajs_dict, traj_length=None):
    """Shuffle parallel per-trajectory lists together, optionally truncating.

    The i-th entries of every list in ``trajs_dict`` form one trajectory; those
    rows are permuted jointly (via ``random.shuffle``) so columns stay aligned.
    Lists of unequal length are truncated to the shortest (``zip`` semantics).

    Args:
        trajs_dict: mapping of key -> list, one element per trajectory.
        traj_length: if given and smaller than the number of rows, keep only
            the first ``traj_length`` shuffled rows.

    Returns:
        A new dict with the same keys (same order) mapping to shuffled lists.
    """
    keys = list(trajs_dict)
    rows = list(zip(*(trajs_dict[key] for key in keys)))
    random.shuffle(rows)
    if traj_length is not None and traj_length < len(rows):
        rows = rows[:traj_length]
    return {key: [row[col] for row in rows] for col, key in enumerate(keys)}

def _flatten(nested):
    """Concatenate per-trajectory sequences into one flat list, preserving order."""
    return [item for seq in nested for item in seq]


def _flatten_valid(nested_obs, nested_valid):
    """Flatten per-trajectory observations, keeping only steps whose valid flag is truthy."""
    return [
        obs
        for obs_list, valid_list in zip(nested_obs, nested_valid)
        for obs, valid in zip(obs_list, valid_list)
        if valid
    ]


def _log_grad_histograms(module, step):
    """Best-effort upload of each parameter's gradient histogram for ``module`` to wandb."""
    for name, parameter in module.named_parameters():
        if parameter.grad is None:
            continue
        try:
            wandb.log({f"gradients/{name}": wandb.Histogram(parameter.grad.cpu().numpy())}, step=step)
        except Exception:
            # Diagnostics only: a failed upload must not abort training. Narrowed
            # from a bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
            continue


def _behavior_cloning_step(agent, optimizer, trajs_bc_list, num_heads=4):
    """Run one behaviour-cloning gradient step of ``agent.actor`` on expert trajectories.

    Args:
        agent: agent whose ``actor`` maps a flat list of observations to a
            sequence of per-action-head logits.
        optimizer: optimizer over ``agent.actor``'s parameters; it is zeroed and
            stepped here.
        trajs_bc_list: non-empty list of expert trajectory dicts with at least
            ``"observation"`` and ``"action"`` entries.
        num_heads: number of action heads (action columns) to clone.

    Returns:
        OrderedDict with the summed loss ``loss_bc`` and the per-head mean
        negative log-likelihoods ``log_prob_ac{k}``, as Python floats.
    """
    trajs_bc = {k: [traj[k] for traj in trajs_bc_list] for k in trajs_bc_list[0]}
    obs = _flatten(trajs_bc["observation"])
    actions = ptu.from_numpy(np.array(_flatten(trajs_bc["action"])))

    optimizer.zero_grad()
    logits = agent.actor(obs)
    head_nll = []
    for head in range(num_heads):
        # reshape(-1) rather than squeeze(): a one-step batch stays 1-D instead of 0-D.
        targets = actions[:, head].to(torch.long).reshape(-1)
        # log_softmax is the numerically stable form of log(softmax(.)); the latter
        # gives -inf/NaN once a probability underflows to 0.
        log_probs = F.log_softmax(logits[head], dim=-1)
        rows = torch.arange(targets.shape[0], device=targets.device)
        head_nll.append(-log_probs[rows, targets].mean())
    loss_bc = torch.stack(head_nll).sum()
    loss_bc.backward()
    # Step the same optimizer whose gradients were zeroed above (the dedicated
    # BC optimizer), not the agent's RL optimizer.
    optimizer.step()

    logs = OrderedDict()
    logs['loss_bc'] = loss_bc.item()
    for head, nll in enumerate(head_nll):
        logs[f'log_prob_ac{head}'] = nll.item()
    return logs


def _pretrain_actor(args, env, agent, optimizer_bc, logger, len_dataset):
    """Behaviour-clone ``agent.actor`` for ``args.pretrain_itr`` steps, then save or load it.

    Expert circuits are taken in windows of ``args.pretrain_batch_size`` from a
    random permutation of the dataset indices, reshuffled once exhausted. With
    ``args.pretrain_itr == 0`` nothing is trained and the previously saved
    fine-tuned actor is loaded instead. Either way ``agent.old_actor`` is
    synchronised with ``agent.actor`` afterwards.
    """
    circuit_indicies = np.random.permutation(len_dataset)
    start_step = 0
    for itr in range(args.pretrain_itr):
        print(f"\n********** Pretraining {itr} **********")
        # Reshuffle once the window has walked off the end; `>=` (not `>`) so a
        # window starting exactly at len_dataset never comes out empty.
        if start_step >= len_dataset:
            circuit_indicies = np.random.permutation(len_dataset)
            start_step = 0
        end_step = min(start_step + args.pretrain_batch_size, len_dataset)
        print(f"start {start_step} , end {end_step} , len {len_dataset}")
        print(circuit_indicies)
        circuit_order = circuit_indicies[start_step:end_step]
        start_step += args.pretrain_batch_size

        if args.simple:
            trajs_bc_list = env.reconstruct_inverter(args.pretrain_batch_size)
        else:
            # still want rollout for behavioral cloning
            trajs_bc_list = env.get_expert_trajectory(args.pretrain_batch_size,
                                                      args.random_start, args.random_circuit,
                                                      circuit_order=circuit_order)

        logs = _behavior_cloning_step(agent, optimizer_bc, trajs_bc_list)
        for key, value in logs.items():
            print("{} : {}".format(key, value))
            logger.log_scalar(value, key, step=itr)

        if (itr + 1) % 500 == 0:
            torch.save(agent.actor.state_dict(),
                       f'../RL/trained_models/fine_tuned_model_{args.pretrain_name}{itr}.pt')

    final_path = f'../RL/trained_models/fine_tuned_model_{args.pretrain_name}.pt'
    if args.pretrain_itr == 0:
        agent.actor.load_state_dict(torch.load(final_path, map_location=ptu.device), strict=False)
        agent.old_actor.load_state_dict(agent.actor.state_dict(), strict=False)
        print("loading previous pretrained model :)")
    else:
        torch.save(agent.actor.state_dict(), final_path)
        agent.old_actor.load_state_dict(agent.actor.state_dict())


def _load_ppo_checkpoint(agent, env, args):
    """Restore actor/old-actor/optimizer/scheduler (+discriminator, +critic) for ``args.load_itr``.

    Reads ``../RL/trained_models/{stem}_{load_itr}_{agent_name}.pt``.
    NOTE(review): ``_save_ppo_checkpoint`` writes ``{stem}_load{N}_...`` instead,
    so saved files must be renamed before they can be resumed — confirm intent.
    """
    def path(stem):
        return f'../RL/trained_models/{stem}_{args.load_itr}_{args.agent_name}.pt'

    agent.actor.load_state_dict(torch.load(path('save_actor'), map_location=ptu.device))
    agent.old_actor.load_state_dict(torch.load(path('save_old_actor'), map_location=ptu.device))
    agent.optimizer.load_state_dict(torch.load(path('save_optimizer'), map_location=ptu.device))
    agent.scheduler.load_state_dict(torch.load(path('save_scheduler'), map_location=ptu.device))
    if not args.simple:
        # NOTE(review): only ``env.discriminator`` is restored; the rollout envs in
        # ``envs_list`` keep the discriminator weights they were given earlier.
        env.discriminator.load_state_dict(torch.load(path('save_disc'), map_location=ptu.device))
    if args.use_baseline:
        agent.critic.load_state_dict(torch.load(path('save_critic'), map_location=ptu.device))


def _save_ppo_checkpoint(agent, env, args, tag):
    """Save actor/old-actor/optimizer/scheduler (+discriminator, +critic) as ``{stem}_load{tag}_{agent_name}.pt``."""
    def path(stem):
        return f'../RL/trained_models/{stem}_load{tag}_{args.agent_name}.pt'

    torch.save(agent.actor.state_dict(), path('save_actor'))
    torch.save(agent.old_actor.state_dict(), path('save_old_actor'))
    torch.save(agent.optimizer.state_dict(), path('save_optimizer'))
    torch.save(agent.scheduler.state_dict(), path('save_scheduler'))
    if not args.simple:
        torch.save(env.discriminator.state_dict(), path('save_disc'))
        torch.save(env.discriminator.optimizer.state_dict(), path('save_disc_optimizer'))
        torch.save(env.discriminator.scheduler.state_dict(), path('save_disc_scheduler'))
    if args.use_baseline:
        torch.save(agent.critic.state_dict(), path('save_critic'))


def run_training_loop(args, data_path):
    """Train the circuit agent: optional BC pretraining, discriminator setup, then PPO.

    Phases:
      1. ``--pretrain``: behaviour-clone the actor on expert trajectories (or load
         a saved fine-tuned actor when ``--pretrain_itr 0``).
      2. Unless ``--simple``: pretrain the discriminator for 51 steps
         (``--pretrain_disc``) or load a saved one, and copy it into every
         rollout env.
      3. Optionally resume from a checkpoint (``--load_itr > 0`` with ``--pretrain``).
      4. PPO for ``args.n_iter + 1`` iterations, optionally with a tapering BC term
         (``--bc``), periodic checkpoints and scalar logging.

    Args:
        args: parsed command-line namespace from ``main``; must also carry
            ``args.logdir``.
        data_path: directory passed to every ``CircuitEnv`` as ``save_dir``.
    """
    # Hyper-parameters shared verbatim by the logger config and the agent constructor.
    hparams = dict(
        embed_dim=args.embed_size,
        n_gcn_layers=args.n_gcn_layers,
        n_layers=args.n_layers,
        layer_size=args.layer_size,
        gamma=args.discount,
        learning_rate=args.learning_rate,
        eps_clip=args.eps_clip,
        K_epochs=args.K_epochs,
        use_baseline=args.use_baseline,
        use_reward_to_go=args.use_reward_to_go,
        normalize_advantages=args.normalize_advantages,
        baseline_learning_rate=args.baseline_learning_rate,
        baseline_gradient_steps=args.baseline_gradient_steps,
        gae_lambda=args.gae_lambda,
    )
    logger = Logger(args.logdir, config=dict(hparams))

    # Seed every RNG in use; `random` drives expert sampling (random.choices).
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ptu.init_gpu(use_gpu=not args.no_gpu, gpu_id=args.which_gpu)

    net_dataset = NetlistDataset("../RL/scripts/dataset/full_edit", "data_netlist_opampedit.pt")
    print(args.agent_name)

    # Map of circuit index -> {"circuit_name": ..., "sampled_nodes": [...]} for the
    # op-amp expert graphs; keep only the experts used for training.
    # NOTE: pickle can execute arbitrary code on load — only use a trusted local file.
    with open("opamp_sampled_circuits.pkl", "rb") as f:
        full_sample_map = pkl.load(f)
    expert_names = ('A_cir_77', 'A_cir_84', 'A_cir_86')
    sample_map = {index: samp_dict for index, samp_dict in full_sample_map.items()
                  if samp_dict["circuit_name"] in expert_names}

    len_dataset = len(net_dataset)

    discrete = True
    env_kwargs = dict(set_masks=not args.no_mask, simple=args.simple,
                      embed_dim=args.embed_size, n_gcn_layers=args.n_gcn_layers,
                      n_layers=args.n_layers, layer_size=args.layer_size,
                      learning_rate=args.learning_rate, gpu_id=args.which_gpu,
                      save_dir=data_path, diff=True, se=False)
    # `env` serves expert trajectories and the trained discriminator; `envs_list`
    # (args.num_workers of them) is what PPO rollouts are sampled from.
    env = CircuitEnv(80, 10, 7, net_dataset, **env_kwargs)
    envs_list = [CircuitEnv(80, 10, 7, net_dataset, stack_index=i, **env_kwargs)
                 for i in range(args.num_workers)]

    max_ep_len = args.max_ep_len
    ob_feature_dim = env.num_types
    ac_dim_list = [int(a_space.n) for a_space in env.action_space]

    agent = MultiDimPGAgent(ob_feature_dim, ac_dim_list, discrete, **hparams)
    agent.share_memory()

    total_envsteps = 0
    start_time = time.time()

    # Dedicated behaviour-cloning optimizer at a quarter of the RL learning rate.
    optimizer_bc = optim.Adam(agent.actor.parameters(), args.learning_rate / 4)
    if args.pretrain:
        _pretrain_actor(args, env, agent, optimizer_bc, logger, len_dataset)

    # Expert training
    INDEX, INDEXES = index_experts(sample_map)
    disc_ckpt = 'pretrained_disc_opamp_partial_similar_half_formal2.pt'
    last_itr = -1  # PPO starts at last_itr + 1

    if args.simple:
        print("No discriminator training, running simple agent")
    elif args.pretrain_disc:
        # Start numbering at the last BC-pretraining iteration (0 if none ran).
        disc_start = args.pretrain_itr - 1 if (args.pretrain and args.pretrain_itr > 0) else 0
        gen_length = 64
        for disc_itr in range(disc_start, disc_start + 51):
            print("----Discriminator {}----".format(disc_itr))

            expert_indicies = random.choices(INDEXES, k=args.batch_size // 10)
            sample_order, sample_INDEX = get_sample_order(expert_indicies, INDEX, sample_map)
            trajs, _, envsteps_this_batch, _ = utils.sample_trajectories(
                env, agent.actor, args.batch_size, max_ep_len,
                args.random_start, args.random_circuit, domain=False, circuit_index=INDEX,
                num_workers=1, gpu_id=args.which_gpu, sample_nodes=sample_INDEX)
            trajs_dict = {k: [traj[k] for traj in trajs] for k in trajs[0]}
            gen_nextobs_flat = _flatten_valid(trajs_dict["next_observation"], trajs_dict["valid"])

            traj_expert = env.get_expert_trajectory(gen_length, args.random_start, args.random_circuit,
                                                    circuit_order=expert_indicies, num_sample=17,
                                                    sample_nodes_list=sample_order)
            trajs_bc_expert = {k: [traj[k] for traj in traj_expert] for k in traj_expert[0]}
            expert_nextobs_flat = _flatten(trajs_bc_expert["next_observation"])
            total_envsteps += envsteps_this_batch

            if len(gen_nextobs_flat) > 10:
                disc_info = env.discriminator.update(expert_nextobs_flat, gen_nextobs_flat, True, pretrain=True)
                for key, value in disc_info.items():
                    print("{} : {}".format(key, value))
                    logger.log_scalar(value, key, step=disc_itr)
                _log_grad_histograms(env.discriminator, disc_itr)
        last_itr = disc_start + 50

        torch.save(env.discriminator.state_dict(), disc_ckpt)
        disc_state = torch.load(disc_ckpt, map_location=ptu.device)
        for e in envs_list:
            e.discriminator.load_state_dict(disc_state)
    else:
        disc_state = torch.load(disc_ckpt, map_location=ptu.device)
        env.discriminator.load_state_dict(disc_state)
        for e in envs_list:
            e.discriminator.load_state_dict(disc_state)

    # PPO
    load_prev = (args.load_itr > 0) and args.pretrain
    if load_prev:
        print("LOADING PREV AGENT STATES")
        _load_ppo_checkpoint(agent, env, args)
        print(args.load_itr, args.agent_name)
        last_itr = args.load_itr

    train_disc = not args.simple
    first_itr = last_itr + 1
    # Checkpoint tags add one extra step when resuming (kept for filename compatibility).
    save_offset = args.load_itr + (1 if load_prev else 0)
    # Adaptive rollout parallelism, never asking for more workers than envs_list holds.
    min_workers = min(2, args.num_workers)
    workers = min_workers

    for itr in range(first_itr, first_itr + args.n_iter + 1):
        print(f"\n********** Iteration {itr} ************")
        # Reset every iteration so stale discriminator stats are never re-logged.
        disc_info = {}

        expert_indicies = random.choices(INDEXES, k=args.pretrain_batch_size)
        sample_order, _ = get_sample_order(expert_indicies, INDEX, sample_map)

        trajs, _, envsteps_this_batch, traj_time = utils.sample_trajectories(
            envs_list, agent.actor, args.batch_size, max_ep_len,
            args.random_start, args.random_circuit,
            circuit_order=expert_indicies, sample_order=sample_order,
            check_curriculum_reward=False,
            num_workers=workers, gpu_id=args.which_gpu, iteration=itr)
        # Grow the worker pool when sampling is slow, shrink it when it is fast.
        if traj_time > 400:
            workers = min(workers + 1, args.num_workers)
        elif traj_time < 200:
            workers = max(workers - 1, min_workers)
        total_envsteps += envsteps_this_batch

        # list of per-trajectory dicts -> dict of per-trajectory lists
        trajs_dict = {k: [traj[k] for traj in trajs] for k in trajs[0]}
        valid_flat = _flatten(trajs_dict["valid"])

        if args.bc:
            l1 = args.lambda1 ** itr  # BC-loss tapering factor
            l0 = args.lambda0

            traj_expert = env.get_expert_trajectory(args.pretrain_batch_size, args.random_start, args.random_circuit,
                                                    num_sample=17, circuit_order=expert_indicies,
                                                    sample_nodes_list=sample_order)
            trajs_bc_expert = {k: [traj[k] for traj in traj_expert] for k in traj_expert[0]}

            if train_disc:
                gen_nextobs_flat = _flatten_valid(trajs_dict["next_observation"], trajs_dict["valid"])
                expert_nextobs_flat = _flatten(trajs_bc_expert["next_observation"])
                if len(gen_nextobs_flat) > 10 and itr % 5 == 0:
                    # step=False: per the original note, only computes the loss for logging.
                    disc_info = env.discriminator.update(expert_nextobs_flat, gen_nextobs_flat, step=False)

            train_info = agent.update(trajs_dict["observation"], trajs_dict["action"],
                                      trajs_dict["reward"], trajs_dict["terminal"], trajs_dict["mask"],
                                      trajs_bc_expert["observation"], trajs_bc_expert["action"],
                                      trajs_bc_expert["mask"], l0, l1, valid_flat, entropy_weight=args.entropy)
        else:
            # train the agent using the sampled trajectories and the agent's update function
            train_info = agent.update(trajs_dict["observation"], trajs_dict["action"],
                                      trajs_dict["reward"], trajs_dict["terminal"], trajs_dict["mask"],
                                      None, None, None, 0, 0,
                                      valid=valid_flat, entropy_weight=args.entropy)

        if itr % args.save_itr == 0:
            _save_ppo_checkpoint(agent, env, args, save_offset + itr)

        if itr % args.scalar_log_freq == 0:
            # save eval metrics
            agent.actor.eval()
            with torch.no_grad():
                eval_trajs = trajs  # no separate eval rollout; training trajectories are reused
                logs = utils.compute_metrics(trajs, eval_trajs)
                logs.update(train_info)
                if not args.simple:
                    logs.update(disc_info)
                logs["Train_EnvstepsSoFar"] = total_envsteps
                logs["TimeSinceStart"] = time.time() - start_time
                if itr == 0:
                    logs["Initial_DataCollection_AverageReturn"] = logs["Train_AverageReturn"]

                for key, value in logs.items():
                    print("{} : {}".format(key, value))
                    logger.log_scalar(value, key, step=itr)

                _log_grad_histograms(agent.actor, itr)
                print("Done logging...\n\n")
            agent.actor.train()
    
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

def _build_arg_parser():
    """Return the command-line parser for the circuit BC/PPO training script."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name", type=str, default='circuit', required=False)
    parser.add_argument("--exp_name", type=str, default='basic', required=False)
    parser.add_argument("--n_iter", "-n", type=int, default=1000)

    parser.add_argument("--use_reward_to_go", "-rtg", action="store_true")
    parser.add_argument("--use_baseline", action="store_true")
    parser.add_argument("--baseline_learning_rate", "-blr", type=float, default=5e-4)
    parser.add_argument("--baseline_gradient_steps", "-bgs", type=int, default=5)
    parser.add_argument("--gae_lambda", type=float, default=None)
    parser.add_argument("--normalize_advantages", "-na", action="store_true")
    # steps collected per train iteration
    parser.add_argument("--batch_size", "-b", type=int, default=1000)
    # steps collected per eval iteration
    parser.add_argument("--eval_batch_size", "-eb", type=int, default=400)

    parser.add_argument("--discount", type=float, default=1.0)
    parser.add_argument("--learning_rate", "-lr", type=float, default=1e-3)
    parser.add_argument("--embed_size", type=int, default=6)
    parser.add_argument("--n_gcn_layers", type=int, default=2)
    parser.add_argument("--n_layers", "-l", type=int, default=2)
    parser.add_argument("--layer_size", "-s", type=int, default=64)
    parser.add_argument("--eps_clip", type=float, default=0.1)
    parser.add_argument("--K_epochs", type=int, default=5)
    parser.add_argument("--pretrain", action="store_true")
    parser.add_argument("--pretrain_itr", type=int, default=0)
    parser.add_argument("--pretrain_disc", action="store_true")
    parser.add_argument("--pretrain_critic", action="store_true")
    # NOTE(review): not read by run_training_loop, which decides whether to resume
    # from --load_itr and --pretrain instead.
    parser.add_argument("--load_prev", action="store_true")
    parser.add_argument("--no_mask", action="store_true")
    parser.add_argument("--simple", action="store_true")
    parser.add_argument("--random_start", action="store_true")
    parser.add_argument("--random_circuit", action="store_true")
    parser.add_argument("--agent_name", type=str, default='', required=False)
    parser.add_argument("--pretrain_name", type=str, default='', required=False)
    parser.add_argument("--save_itr", type=int, default=75)
    parser.add_argument("--load_itr", type=int, default=0)
    parser.add_argument("--pretrain_batch_size", type=int, default=10)
    parser.add_argument("--bc", action="store_true")  # for the behavior cloning
    parser.add_argument("--curriculum", action="store_true")
    # students shouldn't change this away from env's default
    parser.add_argument("--ep_len", type=int)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--no_gpu", "-ngpu", action="store_true")
    # type=int so a CLI-supplied GPU id has the same type as the default.
    parser.add_argument("--which_gpu", "-gpu_id", type=int, default=0)
    parser.add_argument("--log_freq", type=int, default=10)
    parser.add_argument("--scalar_log_freq", type=int, default=1)
    parser.add_argument("--max_ep_len", type=int, default=40)  # max episode length
    parser.add_argument("--num_workers", "-nw", type=int, default=4)  # number of env workers
    parser.add_argument("--lambda0", type=float, default=0.9)  # for the bc loss tapering
    parser.add_argument("--lambda1", type=float, default=0.995)  # for the bc loss tapering
    parser.add_argument("--entropy", type=float, default=5e-5)  # for the entropy regularization

    parser.add_argument("--action_noise_std", type=float, default=0)
    return parser


def main():
    """Parse CLI arguments, create a timestamped log directory, and run training.

    The log directory is ``<this file's dir>/../../data/circuit<exp_name>_<env_name>_<timestamp>``.
    It is stored on ``args.logdir`` and also passed as ``data_path`` to
    ``run_training_loop``.
    """
    args = _build_arg_parser().parse_args()

    logdir_prefix = "circuit"
    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../data")
    os.makedirs(data_path, exist_ok=True)

    logdir = (
        logdir_prefix
        + args.exp_name
        + "_"
        + args.env_name
        + "_"
        + time.strftime("%Y-%m-%d_%H-%M-%S")
    )
    logdir = os.path.join(data_path, logdir)
    args.logdir = logdir
    os.makedirs(logdir, exist_ok=True)

    run_training_loop(args, logdir)


import pdb, traceback, sys
import torch.multiprocessing as mp
if __name__ == "__main__":
    # "spawn" start method (forced); torch.multiprocessing requires spawn or
    # forkserver whenever child processes use CUDA.
    mp.set_start_method("spawn", force=True)
    try:
        main()
    except Exception:
        # Drop into a post-mortem debugger on crashes. Catching Exception (not a
        # bare except) lets argparse's SystemExit (--help, bad arguments) and
        # Ctrl-C terminate normally instead of opening pdb.
        _, _, tb = sys.exc_info()
        traceback.print_exc()
        pdb.post_mortem(tb)
