import os
import sys
import tempfile
import json
# import argparse
import azureml
from azureml.core import Workspace, Experiment, Environment
from azureml.core.compute import ComputeTarget
from azureml.core import ScriptRunConfig, RunConfiguration
from azureml.core.runconfig import DockerConfiguration
from azureml.contrib.core.gjdrunconfig import GlobalJobDispatcherConfiguration
# from azureml.widgets import RunDetails
from datetime import datetime
def parser():
    """Build the default launcher settings as an attribute-accessible dict.

    The ``__main__`` section of this file reads:
      * ``use_gjd``: submit through the Global Job Dispatcher instead of a named cluster.
      * ``region``: key into the region -> workspace lookup.
      * ``cpu``: when not using GJD, pick the ``evalNC24`` cluster instead of ``PrimaryTrain``.
      * ``exp_name``: AzureML experiment name.
      * ``info``: run tag prefix (overwritten by the ``--info`` CLI flag).
      * ``nb_repro``: how many identical runs to submit.

    ``train_config``, ``checkpoint_path`` and ``checkpoint_dir`` are not read by
    the visible launcher code.

    Returns:
        easydict.EasyDict: the settings above.
    """
    import easydict

    defaults = dict(
        use_gjd=True,
        region="westus2",
        cpu=False,
        # other names used before: PLC_pytorch/PLC_synthesize/PLC_challenge/PLC_multi_task
        exp_name="zhijun_pytorch",
        info="nar_40duration_200epoch",
        nb_repro=1,
        train_config="config.yaml",
        checkpoint_path=None,
        checkpoint_dir=None,
    )
    return easydict.EasyDict(defaults)

if __name__ == '__main__':
    # Launcher: builds a Docker-based AzureML environment and run configuration,
    # mounts the LibriTTS datastore paths, then submits
    # egs/libritts/bin/multiprocess_caller.py `args.nb_repro` times.
    import argparse
    external_parser = argparse.ArgumentParser(description='Process command line parameters.')
    external_parser.add_argument('--info', type=str, default="baseline",
                                 help='info of experiments')
    # NOTE(review): --config is accepted for CLI compatibility but is not used below.
    external_parser.add_argument('--config', type=str, default=None,
                                 help='config')
    ext_args = external_parser.parse_args()
    args = parser()
    args.info = ext_args.info  # the CLI value overrides the default run tag

    # region -> workspace -> (datastore, virtual cluster) lookups
    DATASTORE_DICT = {"packet_loss_concealment": 'zhijun_data'}  # datastore must be registered on the workspace first
    REGION_DICT = {"westus2": 'packet_loss_concealment'}
    VC_DICT = {"packet_loss_concealment": 'Westus2.PacketLossConcealment2'}
    workspace_name = REGION_DICT[args.region]
    datastore_name = DATASTORE_DICT[workspace_name]
    vc_name = VC_DICT[workspace_name]

    # Init workspace. To force an interactive login first:
    #   from azureml.core.authentication import InteractiveLoginAuthentication
    #   InteractiveLoginAuthentication(force=True)
    ws = Workspace(
        subscription_id="f1f491ac-0340-4e5d-87b7-47692be1cb31",
        resource_group="IC3_Common_GPU_cluster",
        workspace_name=workspace_name,
    )

    # Custom Docker image with a user-managed Python interpreter, see:
    # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-use-environments#specify-your-own-python-interpreter
    env = Environment(name="valle-4-23-noconda-v8")
    env.docker.enabled = True
    env.docker.base_image = "jiazhijun/valle-4-23-noconda:v8"
    env.docker.base_image_registry.address = 'msrresrchcr.azurecr.io'
    env.docker.base_image_registry.username = 'msrresrchcr'
    # The registry password was previously hard-coded in this file (it should be
    # rotated). Read it from the environment so it never lands in source control.
    acr_password = os.environ.get("ACR_PASSWORD")
    if not acr_password:
        sys.exit("Set the ACR_PASSWORD environment variable to the msrresrchcr registry password.")
    env.docker.base_image_registry.password = acr_password
    env.python.user_managed_dependencies = True
    env.python.interpreter_path = "/opt/conda/bin/python"
    env.register(workspace=ws)
    env.build(ws)

    # Build the run configuration from a dict serialized to a scratch JSON file.
    # With GJD the cluster is chosen by the dispatcher, so the target is a
    # placeholder; otherwise it names a fixed cluster.
    run_config_struct = {
        "node_count": 1,
        "environment": env._serialize_to_dict(env),
        "target": " " if args.use_gjd else ("evalNC24" if args.cpu else "PrimaryTrain"),
    }
    # TemporaryDirectory removes the scratch dir once the config is loaded
    # (mkdtemp previously leaked one directory per launch).
    with tempfile.TemporaryDirectory() as project_dir:
        os.makedirs(os.path.join(project_dir, ".amlcompute"))
        run_config_path = os.path.join(project_dir, ".amlcompute", "simple.runconfig")
        with open(run_config_path, "w") as outfile:
            json.dump(run_config_struct, outfile)
        run_config = RunConfiguration.load(run_config_path)
    run_config.docker = DockerConfiguration(use_docker=True, shared_volumes=True, shm_size='2000g')

    # Pick the compute.
    if args.use_gjd:
        # Global job dispatcher: leverage idle compute resources from other workspaces.
        vc_list = ["Microsoft.IC3.Unified.Westus2.{}".format(vc_name)]
        run_config.global_job_dispatcher = GlobalJobDispatcherConfiguration(
            compute_type="Amlcompute",
            region=[],
            vc_block_list=vc_list,
        )
        compute_target = None
    elif args.cpu:
        compute_target = "evalNC24"
    else:
        compute_target = "PrimaryTrain"

    # Training hyper-parameters (also encoded in the output directory name).
    name = "VALLE"
    max_duration = 50
    dtype = "float32"
    base_lr = 0.05
    world_size = 8
    echo = 2  # passed as --num-epochs
    train_stage = 2
    start_epoch = 1
    accumulate_grad_steps = 4
    prefix_mode = 1
    input_semantic = "True"
    valid_interval = 500  # 4000
    dir_name = f"Name_{name}_max-duration_{max_duration}_dtype_{dtype}_base-lr_{base_lr}_world-size_{world_size}_train-stage_{train_stage}_echo_{echo}_start_echo_{start_epoch}_accumulate_grad_steps_{accumulate_grad_steps}_prefix_mode_{prefix_mode}_input_semantic_{input_semantic}_valid_interval_{valid_interval}"

    # Single timestamp shared by the output directory and the file suffix.
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    # Datastore mounts for inputs and outputs.
    datastore = ws.datastores[datastore_name]
    data_root = "LibriTTS/data_valle/data_azure/data"
    manifest_dir = datastore.path(f"{data_root}/vc_tokenized_v2").as_mount()
    text_tokens = datastore.path(f"{data_root}/vc_tokenized_v2/unique_text_tokens.k2symbols").as_mount()
    semantic_tokens = datastore.path(f"{data_root}/vc_tokenized_v2/unique_semantic_tokens.k2symbols").as_mount()
    exp_dir = datastore.path(f"{data_root}/output_vc/{dir_name}_{timestamp}").as_mount()

    print(f"timestamp : {timestamp}")
    arg_list = [
        "--nproc-per-node", world_size,
        "--nnodes", 1,
        "--max-duration", max_duration,
        "--filter-min-duration", 0.5,
        "--filter-max-duration", 14,
        "--train-stage", train_stage,
        "--num-buckets", 6,
        "--dtype", dtype,
        "--save-every-n", 4000000000,
        "--valid-interval", valid_interval,
        "--model-name", "valle",
        "--share-embedding", True,
        "--norm-first", True,
        "--add-prenet", False,
        "--decoder-dim", 1024,
        "--nhead", 16,
        "--num-decoder-layers", 12,
        "--prefix-mode", prefix_mode,
        "--base-lr", base_lr,
        "--warmup-steps", 200,
        "--average-period", 0,
        "--num-epochs", echo,
        "--start-epoch", start_epoch,
        "--start-batch", 0,
        "--accumulate-grad-steps", accumulate_grad_steps,
        "--world-size", world_size,
        "--manifest-dir", str(manifest_dir),
        "--text-tokens", str(text_tokens),
        "--semantic-tokens", str(semantic_tokens),
        "--exp-dir", str(exp_dir),
        "--newfile-suffix", str(timestamp),
        "--is-local", False,
        "--input-semantic", str(input_semantic),
    ]

    src_dir = "./"  # parent dir; uploaded as the run snapshot
    estimator = ScriptRunConfig(
        source_directory=src_dir,
        script="egs/libritts/bin/multiprocess_caller.py",
        run_config=run_config,
        arguments=arg_list,
        compute_target=compute_target,
    )
    if args.use_gjd:
        # Placeholder target for GJD. This used to run unconditionally, which
        # clobbered the evalNC24/PrimaryTrain target on the non-GJD path.
        estimator.run_config.target = " "
    for ref in (manifest_dir, text_tokens, semantic_tokens, exp_dir):
        estimator.run_config.data_references[ref.data_reference_name] = ref.to_config()

    exp = Experiment(workspace=ws, name=args.exp_name)

    # Submit nb_repro identical runs, each tagged "<info> r<i>".
    for i in range(args.nb_repro):
        info_tag = {"info": "{} r{}".format(args.info, i)}
        run = exp.submit(estimator, tags=info_tag)
        print(run.get_details())