import os
import torch
import torch.nn as nn

from args_factory import get_args
from loaders import get_loaders
from utils import Scheduler, Statistics
from networks import get_network, fuse_BN_wrt_Flatten, add_BN_wrt_Flatten
from model_wrapper import BasicModelWrapper, PGDModelWrapper, BoxModelWrapper, HZonoModelWrapper, TAPSModelWrapper, STAPSModelWrapper, SmallBoxModelWrapper, SmallHZonoModelWrapper, DeepPolyModelWrapper, SmallDPModelWrapper, GradAccuModelWrapper, MultiFacetModelWrapper, MNBaBDeepPolyModelWrapper
from utils import write_perf_to_json, load_perf_from_json, fuse_BN, seed_everything
from tqdm import tqdm
import random
import numpy as np
from regularization import compute_fast_reg, compute_vol_reg, compute_L1_reg, compute_PI_reg, compute_neg_reg
import time
from datetime import datetime
from AIDomains.abstract_layers import Sequential
import logging
from AIDomains.zonotope import HybridZonotope
import math
import json

import warnings
warnings.filterwarnings("ignore")

try:
    import neptune
except Exception:  # optional dependency: any import failure just disables Neptune logging
    neptune = None
# # You can manually disable neptune by setting it to None even if you installed it
# neptune = None

from get_stat import PI_loop, relu_loop, test_loop
from evo_wrapper_new import PGPEEvoWrapper, GSMEvoWrapper, MixedEvoWrapper, BasicEvoWrapper, evo_train_loop
# from rso_wrapper import RSOWrapper

import torch.multiprocessing as mp
import torch.distributed as dist

def setup(rank, args):
    """Initialise the single-host NCCL process group for this rank.

    Args:
        rank: Global rank of this process.
        args: Must provide ``port`` (``None`` selects the default ``'12356'``),
            ``world_size`` and ``device`` (the CUDA device index for this rank).
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    # os.environ only accepts str values; args.port may have been parsed as an int.
    os.environ['MASTER_PORT'] = str(args.port) if args.port is not None else '12356'

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
    torch.cuda.set_device(args.device)

def cleanup():
    """Destroy the default ``torch.distributed`` process group created by ``setup``."""
    dist.destroy_process_group()

def rank_print(rank, *args, **kwargs):
    """Call ``print(*args, **kwargs)`` only on rank 0, so multi-process output is not duplicated."""
    if rank != 0:
        return
    print(*args, **kwargs)

def get_model_wrapper(args, net, device, input_dim):
    """Build the model wrapper that defines the training loss and robust statistics.

    Without ``args.multi_facets``, the first enabled ``use_*`` training flag
    (checked in the order below) selects a single wrapper. ``use_small_box``
    picks the small-box variant where one exists. The wrapper is optionally
    wrapped in ``GradAccuModelWrapper`` when ``args.grad_accu_batch`` is set.
    With ``args.multi_facets``, one wrapper per facet is built and combined in
    a ``MultiFacetModelWrapper``.

    Args:
        args: Parsed command-line arguments.
        net: Abstract network to wrap.
        device: Device forwarded to every wrapper constructor.
        input_dim: Input shape without batch dimension (``run`` passes
            ``(channels, size, size)``).

    Returns:
        The model wrapper, with ``robust_weight`` set to 1.

    Raises:
        NotImplementedError: If no training-mode flag is enabled, or a facet
            name is unknown.
        ValueError: If fewer than two facets are given.
    """
    # Truthiness check, consistent with get_train_mode (an empty list means "no facets").
    if not args.multi_facets:
        if args.use_pgd_training:
            model_wrapper = PGDModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
        elif args.use_vanilla_ibp:
            if args.use_small_box:
                model_wrapper = SmallBoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
            else:
                model_wrapper = BoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
        elif args.use_HBox_training:
            if args.use_small_box:
                model_wrapper = SmallHZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage, domain='hbox')
            else:
                model_wrapper = HZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, domain='hbox')
        elif args.use_Zono_training:
            if args.use_small_box:
                model_wrapper = SmallHZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage, domain='zono')
            else:
                model_wrapper = HZonoModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, domain='zono')
        elif args.use_TAPS_training:
            if args.use_small_box:
                model_wrapper = STAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes, eps_shrinkage=args.eps_shrinkage)
            else:
                model_wrapper = TAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes)
        elif args.use_DP_training:
            if args.use_small_box:
                model_wrapper = SmallDPModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
            else:
                model_wrapper = DeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
                # model_wrapper = MNBaBDeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
        elif args.use_DPBox_training:
            model_wrapper = DeepPolyModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, use_dp_box=True)
        else:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise NotImplementedError("Unknown training mode.")
        if args.grad_accu_batch is not None:
            model_wrapper = GradAccuModelWrapper(model_wrapper, args)
    else:
        if len(args.multi_facets) < 2:
            raise ValueError("At least two facets if multiple facet loss is used.")
        wrappers = []
        for facet in args.multi_facets:
            facet_name = facet.lower()
            if facet_name == "ibp":
                wrapper = BoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
            elif facet_name == "pgd":
                wrapper = PGDModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args)
            elif facet_name == "sabr":
                wrapper = SmallBoxModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, eps_shrinkage=args.eps_shrinkage)
            elif facet_name == "taps":
                wrapper = TAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes)
            elif facet_name == "staps":
                wrapper = STAPSModelWrapper(net, nn.CrossEntropyLoss(), input_dim, device, args, block_sizes=args.block_sizes, eps_shrinkage=args.eps_shrinkage)
            else:
                raise NotImplementedError(f"Unknown facet specified: {facet}.")
            wrappers.append(wrapper)
        model_wrapper = MultiFacetModelWrapper(wrappers, args.facets_eps_ratio, args.facets_weight, args)
    model_wrapper.robust_weight = 1
    return model_wrapper

def get_train_mode(args):
    """Return a short name for the training objective.

    With ``args.multi_facets`` set, the name joins each facet with its weight
    (e.g. ``ibp0.5_pgd1``). Otherwise the first enabled ``use_*`` flag, in the
    priority order below, picks the name. IBP/HBox/Zono/TAPS have a separate
    name when ``args.use_small_box`` is set.

    Raises:
        NotImplementedError: If no training flag is enabled.
    """
    if args.multi_facets:
        return "_".join(f"{facet}{weight}" for facet, weight in zip(args.multi_facets, args.facets_weight))

    # (flag attribute, plain name, small-box name or None if there is no variant),
    # in priority order: reorder to change which flag wins when several are set.
    candidates = (
        ("use_pgd_training", "PGD_trained", None),
        ("use_vanilla_ibp", "IBP_trained", "SABR_trained"),
        ("use_HBox_training", "HBox_trained", "SHBox_trained"),
        ("use_Zono_training", "Zono_trained", "SZono_trained"),
        ("use_TAPS_training", "TAPS_trained", "STAPS_trained"),
        ("use_DP_training", "DP_trained", None),
        ("use_DPBox_training", "DPBox_trained", None),
    )
    for flag, plain_name, small_box_name in candidates:
        if not getattr(args, flag):
            continue
        if small_box_name is not None and args.use_small_box:
            return small_box_name
        return plain_name
    raise NotImplementedError("Unknown training mode.")

def get_evo_mode(args):
    """Return the name of the selected evolution strategy.

    When several flags are set, the first one in the order PGPE, GSM, Mixed wins.

    Raises:
        NotImplementedError: If none of ``use_PGPE_evo``, ``use_GSM_evo`` or
            ``use_Mixed_evo`` is set.
    """
    if args.use_PGPE_evo:
        return "PGPE"
    if args.use_GSM_evo:
        return "GSM"
    if args.use_Mixed_evo:
        return "Mixed"
    # Previously reported "Unknown training mode.", which pointed at the wrong flags.
    raise NotImplementedError("Unknown evolution mode: set one of use_PGPE_evo, use_GSM_evo, use_Mixed_evo.")

def parse_save_root(args, evo_mode, cert_mode):
    """Compute the directory for this run's artifacts and, on rank 0, create it.

    The path is ``save_dir/evo_mode/dataset/eps<test_eps>/cert_mode/net/init_*``.
    One extra level is added for each active option: ``fast_reg > 0``,
    ``use_small_box``, ``use_TAPS_training``, ``IBPR_reg > 0`` and ``L1_reg > 0``.

    On rank 0 this also sets ``args.with_neptune_id`` to ``None``. If
    ``args.resume_train`` (a run id) is given, it records that id in
    ``args.neptune_id`` / ``args.with_neptune_id``, rewrites
    ``args.resume_train`` to the run's directory under the save root, and
    asserts that directory exists.

    Returns:
        The save root (without any run-id component).
    """
    if args.load_model is None:
        init_part = f"init_{args.init}"
    else:
        init_part = "init_pretrained"

    parts = [args.save_dir, evo_mode, args.dataset, f"eps{args.test_eps:.5g}", cert_mode, args.net, init_part]
    if args.fast_reg > 0:
        parts.append(f"fast_reg_{args.fast_reg}")
    if args.use_small_box:
        parts.append(f"lambda_{args.eps_shrinkage}")
    if args.use_TAPS_training:
        parts.append(f"last_block_{args.block_sizes[-1]}")
    if args.IBPR_reg > 0:
        parts.append(f"IBPR_{args.IBPR_reg}")
    if args.L1_reg > 0:
        parts.append(f"L1_{args.L1_reg}")
    save_root = os.path.join(*parts)

    # Only rank 0 touches args bookkeeping and the filesystem.
    if args.rank != 0:
        return save_root

    args.with_neptune_id = None
    if args.resume_train:
        print(f'Attempting to resume training of {args.resume_train}',flush=True)
        run_id = args.resume_train
        args.neptune_id = args.with_neptune_id = run_id
        args.resume_train = os.path.join(save_root, run_id)
        assert os.path.isdir(args.resume_train), f"Training session {args.neptune_id} not found in {save_root}. There might be an argument mismatch."

    os.makedirs(save_root, exist_ok=True)
    logging.info(f"The model will be saved at: {save_root}")
    return save_root

def init_nep_log_from_prev_run(args):
    """Restore state from a previous run's directory and reattach its Neptune run.

    Reads ``train_args.json`` and ``monitor.json`` from ``args.resume_train``.
    Arguments missing from ``args`` are copied over; arguments whose values
    differ are reported with a warning but kept as currently given. The
    checkpoint to load is ``Every_Epoch_Model/epoch_<last>.ckpt`` when that
    directory exists, otherwise ``model.ckpt``. Training resumes at the epoch
    after the last validated one.

    Returns:
        ``(nep_log, old_monitor)``: the Neptune run opened with
        ``args.with_neptune_id`` and the previous run's monitor dict.
    """
    prev_dir = args.resume_train

    def _read_json(file_name):
        with open(os.path.join(prev_dir, file_name), 'rt') as handle:
            return json.load(handle)

    old_args = _read_json('train_args.json')
    old_monitor = _read_json('monitor.json')

    for key, old_value in old_args.items():
        if key in args:
            current_value = getattr(args, key)
            if old_value != current_value:
                print('WARNING: different args:', key, old_value, current_value)
        else:
            setattr(args, key, old_value)

    # One entry per completed epoch, so the last finished epoch index is len - 1.
    last_epoch = len(old_monitor["val_nat_curve"]) - 1
    every_epoch_dir = os.path.join(prev_dir, 'Every_Epoch_Model')
    if os.path.exists(every_epoch_dir):
        args.load_model = os.path.join(every_epoch_dir, f'epoch_{last_epoch}.ckpt')
    else:
        args.load_model = os.path.join(prev_dir, 'model.ckpt')
    args.start_epoch = last_epoch + 1

    nep_log = neptune.init_run(
        project="ethsri/Grad-free",
        with_id=args.with_neptune_id,
        tags=args.neptune_tags or [],
    )
    return nep_log, old_monitor

def run(args):
    """Train with an evolution-strategy optimiser in one process, then test.

    Steps:
      1. Derive the run names and save directory.
      2. On rank 0, start or resume Neptune logging, then broadcast the
         updated ``args`` and any restored monitor to the other ranks when
         DDP is enabled.
      3. Build the data loaders, epsilon schedule, abstract network, model
         wrapper and evo wrapper.
      4. Train for epochs ``args.start_epoch .. args.n_epochs - 1``,
         validating every epoch. Rank 0 saves checkpoints,
         ``optimizer.ckpt``, ``monitor.json`` and ``train_args.json``.
      5. Reload the selected ``model.ckpt`` and evaluate it on the test set
         at ``args.test_eps``.

    Args:
        args: Parsed arguments; must carry ``rank``, ``device`` and ``use_ddp``
            as set by ``run_ddp`` / ``main``.
    """
    # Define train mode (name of this train) here
    mode = get_train_mode(args)

    # Define evo mode (name of this train) here
    evo_mode = get_evo_mode(args)

    # Parse save root here
    save_root = parse_save_root(args, evo_mode, mode)
    args.save_root = save_root

    # neptune logging: placed here to enable remote interruption.
    old_monitor = {}
    args.start_epoch = 0
    rank = args.rank
    device = args.device

    if rank == 0:
        if neptune is not None and not args.disable_neptune:
            if args.resume_train:
                nep_log, old_monitor = init_nep_log_from_prev_run(args)
            else:
                nep_log = neptune.init_run(
                    project="ethsri/Grad-free",
                    tags=args.neptune_tags if args.neptune_tags else [],
                )
            neptune_id = nep_log["sys/id"].fetch()
        else:
            nep_log = None
            neptune_id = None
        args.neptune_id = neptune_id
        if args.neptune_id is not None:
            args.save_root = os.path.join(args.save_root, f"{args.neptune_id}")
            os.makedirs(args.save_root, exist_ok=True)
            logging.info(f"The model will be saved at: {args.save_root}")
        obj_list = [args, old_monitor]
    else:
        nep_log = None
        neptune_id = None
        obj_list = [None, None]

    # Share rank 0's args (save_root incl. run id, start_epoch, load_model, ...)
    # and the restored monitor with every other rank.
    if args.use_ddp:
        dist.broadcast_object_list(obj_list, src=0)
    args, old_monitor = obj_list
    save_root = args.save_root
    # The broadcast args carry rank 0's rank/device; restore this process's own.
    args.rank = rank
    args.device = device

    perf_dict = {'val_nat_curve':[], 'val_cert_curve':[], 'val_loss_curve':[], 'train_nat_curve':[], 'train_cert_curve':[], 'train_loss_curve':[], 'lr_curve':[], "eps_curve":[]}
    PI_dict = {"PI_curve":[]}
    relu_dict = {"dead_relu_curve":[], "active_relu_curve":[], "unstable_relu_curve":[]}
    reg_dict = {"fastreg_curve":[]}
    adv_dict = {"train_adv_accu_curve":[], "val_adv_accu_curve":[]}

    # Add more perf_dict here to track more statistics.
    perf_dict = perf_dict | PI_dict | relu_dict | reg_dict | adv_dict
    # When resuming, continue from the previous run's curves and best-so-far values.
    perf_dict.update(old_monitor)
    perf_dict["start_time"] = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
    perf_dict["neptune_id"] = neptune_id

    # Define dataset here
    loaders, input_size, input_channel, n_class = get_loaders(args, shuffle_train=True, num_workers=0)
    input_dim = (input_channel, input_size, input_size)
    args.num_classes = n_class
    if len(loaders) == 4:
        train_loader, val_loader, test_loader, train_test_loader = loaders
    else:
        train_loader, test_loader, train_test_loader = loaders
        val_loader = None
    # Without a dedicated validation split, per-epoch validation uses the test set.
    eval_loader = val_loader if val_loader is not None else test_loader
    perf_dict['model_selection'] = args.model_selection # tradition for certified training: use test set for model selection :(

    # Schedule for input epsilon (scheduler steps are counted in batches).
    n_batches = len(train_loader)
    if args.no_anneal:
        # use const eps == train_eps
        eps_scheduler = Scheduler(args.start_epoch_eps*n_batches, args.end_epoch_eps*n_batches, args.train_eps, args.train_eps, "linear", s=n_batches)
    elif args.schedule in ["smooth", "linear", "step"]:
        eps_scheduler = Scheduler(args.start_epoch_eps*n_batches, args.end_epoch_eps*n_batches, args.eps_start, args.eps_end, args.schedule, s=args.step_epoch*n_batches)
    else:
        raise NotImplementedError(f"Unknown schedule: {args.schedule}")

    # Define concrete (torch) model and convert it to abstract model here
    torch_net = get_network(args.net, args.dataset, device, init=args.init)
    net = Sequential.from_concrete_network(torch_net, input_dim, disconnect=False, retain_concrete=True)
    net.set_dim(torch.zeros((test_loader.batch_size, *input_dim), device=device))
    if args.load_model:
        net.load_state_dict(torch.load(args.load_model))
        rank_print(rank, "Loaded:", args.load_model)
    rank_print(rank, net)

    # Define model wrapper here: this wraps how to compute the loss and how to compute the robust accuracy.
    model_wrapper = get_model_wrapper(args, net, device, input_dim)

    # Evolution-strategy wrapper: performs the parameter updates from the model wrapper's loss.
    if args.use_PGPE_evo:
        evo_wrapper:BasicEvoWrapper = PGPEEvoWrapper(model_wrapper, args.lr, args.lr_std, device=device, args=args, popsize=args.popsize, symmetric=True, std_init=args.std_init, use_current_std=args.use_current_std, std_min=args.std_min, std_max=args.std_max, subbatch_size=None, nep_log=nep_log, optimizer=args.opt)
    elif args.use_GSM_evo:
        evo_wrapper:BasicEvoWrapper = GSMEvoWrapper(model_wrapper, args.lr, args.lr_std, device=device, args=args, popsize=args.popsize, symmetric=False, std_init=args.std_init, use_current_std=args.use_current_std, std_min=args.std_min, std_max=args.std_max, subbatch_size=None, nep_log=nep_log, optimizer=args.opt)
    elif args.use_Mixed_evo:
        evo_wrapper:BasicEvoWrapper = MixedEvoWrapper(model_wrapper, args.lr, args.lr_std, device=device, args=args, popsize=args.popsize, GSM_popsize=args.GSM_popsize, symmetric=True, std_init=args.std_init, use_current_std=args.use_current_std, std_min=args.std_min, std_max=args.std_max, subbatch_size=None, nep_log=nep_log, optimizer=args.opt, PGPE_weight=args.PGPE_weight)
    else:
        raise NotImplementedError("Must specify an evolutional wrapper for using evo_train.")

    # Keep best-so-far values restored from a previous run: resetting them here
    # would let a worse model overwrite the saved checkpoint after resuming.
    perf_dict.setdefault("best_val_cert_accu", -1)
    perf_dict.setdefault("best_val_loss", 1e8)

    if nep_log is not None:
        nep_log["args"] = args.__dict__

    train_time = 0.0
    for epoch_idx in range(args.start_epoch, args.n_epochs):
        rank_print(rank, "Epoch", epoch_idx + 1, "/", args.n_epochs)

        if epoch_idx in args.lr_milestones:
            evo_wrapper.step_lr(args.lr_decay_factor)

        #### ---- train loop below ----
        train_start_time = time.time()
        train_nat_accu, train_cert_accu, train_adv_accu, train_loss, mean_eval = evo_train_loop(evo_wrapper, eps_scheduler, train_loader, epoch_idx, device, args, nep_log=nep_log, extra_adv_attack=True)
        train_time += time.time() - train_start_time
        rank_print(rank, f"train_nat_accu: {train_nat_accu: .4f}, train_cert_accu: {train_cert_accu: .4f}, train_loss:{train_loss: .4f}, train_adv_accu:{train_adv_accu: .4f}")

        # Track train statistics here.
        perf_dict['train_nat_curve'].append(train_nat_accu)
        perf_dict['train_cert_curve'].append(train_cert_accu)
        perf_dict["train_loss_curve"].append(train_loss)
        perf_dict["train_adv_accu_curve"].append(train_adv_accu)

        eps = evo_wrapper.current_eps
        perf_dict["eps_curve"].append(eps)
        rank_print(rank, "current eps:", eps)
        rank_print(rank, "current robust_weight:", model_wrapper.robust_weight)

        #### ---- test loop below ----
        val_nat_accu, val_cert_accu, val_loss, val_adv_accu = test_loop(model_wrapper, eps, eval_loader, device, args, extra_adv_attack=True, max_batches=args.max_batches)
        rank_print(rank, f"val_nat_accu: {val_nat_accu: .4f}, val_cert_accu: {val_cert_accu: .4f}, val_loss:{val_loss: .4f}, val_adv_accu:{val_adv_accu: .4f}")
        perf_dict['val_nat_curve'].append(val_nat_accu)
        perf_dict['val_cert_curve'].append(val_cert_accu)
        perf_dict["val_loss_curve"].append(val_loss)
        perf_dict["val_adv_accu_curve"].append(val_adv_accu)

        # #### ---- additional model statistics tracking (disabled) ----
        # PI = PI_loop(model_wrapper.net, 1e-5, eval_loader, device, args.num_classes, args, relu_adjust="local")
        # perf_dict["PI_curve"].append(PI)
        # dead, unstable, active = relu_loop(model_wrapper.net, max(eps, 1e-6), eval_loader, device, args)
        # perf_dict["dead_relu_curve"].append(dead)
        # perf_dict["unstable_relu_curve"].append(unstable)
        # perf_dict["active_relu_curve"].append(active)

        if nep_log is not None:
            nep_log["train_nat_accu_curve"].append(train_nat_accu)
            nep_log["train_cert_accu_curve"].append(train_cert_accu)
            nep_log["train_loss_curve"].append(train_loss)
            nep_log["val_nat_accu_curve"].append(val_nat_accu)
            nep_log["val_cert_accu_curve"].append(val_cert_accu)
            nep_log["val_loss_curve"].append(val_loss)
            nep_log["eps_curve"].append(eps)
            nep_log["epoch"].append(epoch_idx)
            nep_log["train_adv_accu_curve"].append(train_adv_accu)
            nep_log["val_adv_accu_curve"].append(val_adv_accu)

        #### ---- model selection below ----
        # Only epochs trained at the full test epsilon are candidates.
        if eps == args.test_eps and rank == 0:
            if (perf_dict["model_selection"] == "robust_accu" and val_cert_accu > perf_dict["best_val_cert_accu"]) or (perf_dict["model_selection"] == "loss" and val_loss < perf_dict["best_val_loss"]):
                torch.save(model_wrapper.net.state_dict(), os.path.join(save_root, "model.ckpt"))
                if nep_log is not None:
                    nep_log["model"].upload(os.path.join(save_root, "model.ckpt"))
                    nep_log["best_val_cert_accu"].append(val_cert_accu)
                    nep_log["best_val_nat_accu"].append(val_nat_accu)
                print("New checkpoint saved.")
                perf_dict["best_ckpt_epoch"] = epoch_idx
            perf_dict["best_val_cert_accu"] = max(perf_dict["best_val_cert_accu"], val_cert_accu)
            perf_dict["best_val_loss"] = min(perf_dict["best_val_loss"], val_loss)

        if args.save_every_epoch and rank == 0:
            os.makedirs(os.path.join(save_root, "Every_Epoch_Model"), exist_ok=True)
            torch.save(model_wrapper.net.state_dict(), os.path.join(save_root, "Every_Epoch_Model", f"epoch_{epoch_idx}.ckpt"))

        if perf_dict["model_selection"] is None and rank == 0:
            # No model selection. Save the final model.
            torch.save(model_wrapper.net.state_dict(), os.path.join(save_root, "model.ckpt"))
            if nep_log is not None:
                nep_log["model"].upload(os.path.join(save_root, "model.ckpt"))
            print("New checkpoint saved.")

        if rank == 0:
            opt_dict = dict(
                center=evo_wrapper.optimizer_center.state_dict() if evo_wrapper.optimizer_center is not None else None,
                std=evo_wrapper.optimizer_std.state_dict() if evo_wrapper.optimizer_std is not None else None,
            )
            torch.save(opt_dict,os.path.join(save_root,'optimizer.ckpt'))

            # Write the logs to json file
            write_perf_to_json(perf_dict, save_root, "monitor.json")
            write_perf_to_json(args.__dict__, save_root, "train_args.json")

    # test for the best ckpt
    ckpt_path = os.path.join(save_root, "model.ckpt")
    if args.use_ddp:
        # Only rank 0 writes model.ckpt; wait until it is done so other ranks
        # never read a missing or partially written file.
        dist.barrier()
    rank_print(rank, "-"*10 + f"Model Selection: {perf_dict['model_selection']}. Testing selected checkpoint." + "-"*10)
    if os.path.exists(ckpt_path):
        model_wrapper.net.load_state_dict(torch.load(ckpt_path))
    else:
        # E.g. eps never reached test_eps: test the final weights instead of crashing.
        rank_print(rank, f"WARNING: no checkpoint found at {ckpt_path}; testing the final weights instead.")
    test_nat_accu, test_cert_accu, loss = test_loop(model_wrapper, args.test_eps, test_loader, device, args, max_batches=args.max_batches)
    rank_print(rank, f"test_nat_accu: {test_nat_accu: .4f}, test_cert_accu: {test_cert_accu: .4f}")
    perf_dict["test_nat_accu"] = test_nat_accu
    perf_dict["test_cert_accu"] = test_cert_accu
    perf_dict["train_time"] = train_time
    perf_dict["end_time"] = datetime.now().strftime("%Y/%m/%d %H:%M:%S")

    if rank == 0:
        write_perf_to_json(perf_dict, save_root, "monitor.json")
        write_perf_to_json(args.__dict__, save_root, "train_args.json")

    if nep_log is not None:
        nep_log.stop()

def run_ddp(rank,args):
    """Per-process entry point for ``mp.spawn``: set up this rank, train, tear down.

    Args:
        rank: Process index supplied by ``mp.spawn``.
        args: Parsed arguments; ``actors_per_gpu`` maps ranks onto GPU indices
            and ``use_ddp`` controls process-group setup and teardown.
    """
    args.rank = rank
    args.device = rank // args.actors_per_gpu
    if args.use_ddp:
        setup(rank, args)
    try:
        # Offset the seed by rank so ranks do not share one RNG stream.
        seed_everything(args.random_seed + args.rank)
        print(rank,"my_rank",rank,flush=True)
        run(args)
    finally:
        # Destroy the process group even if training raises.
        if args.use_ddp:
            cleanup()

def _actors_per_gpu(n_actors, n_gpus):
    """Return how many actor processes share each GPU.

    Uses ceiling division, so ``rank // actors_per_gpu`` is a valid device
    index for every rank in ``range(n_actors)``. This holds even when there
    are fewer actors than GPUs, or the actor count is not a multiple of the
    GPU count.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    if n_gpus < 1:
        raise RuntimeError("No CUDA device found: evo training needs at least one GPU.")
    return -(-n_actors // n_gpus)


def main():
    """Parse arguments, derive the per-process layout, and spawn one process per actor."""
    args = get_args(["basic", "train", "evo"])
    if args.precomp_bounds:
        print("Using precomp bounds with",args.reuse_bound_mode,flush=True)
    else:
        print("NOT Using precomp bounds with",args.reuse_bound_mode,flush=True)

    n_gpus = torch.cuda.device_count()
    args.actors = args.num_actors
    args.world_size = args.actors
    # Floor division gave 0 (-> ZeroDivisionError in run_ddp) when actors < GPUs,
    # and out-of-range device indices when actors were not a multiple of GPUs.
    args.actors_per_gpu = _actors_per_gpu(args.actors, n_gpus)
    args.use_ddp = args.actors > 1
    # The population size is split evenly (integer division) across actors.
    args.popsize = args.popsize // args.actors

    print(f'Using DDP with {args.actors_per_gpu} actors per GPU')

    seed_everything(args.random_seed)
    mp.spawn(run_ddp,
             args=(args,),
             nprocs=args.actors,
             join=True)


if __name__ == '__main__':
    main()
