import os
import sys
import tempfile
import json
# import argparse
import azureml
from azureml.core import Workspace, Experiment, Environment
from azureml.core.compute import ComputeTarget
from azureml.core import ScriptRunConfig, RunConfiguration
from azureml.core.runconfig import DockerConfiguration
from azureml.contrib.core.gjdrunconfig import GlobalJobDispatcherConfiguration
from azureml.core.conda_dependencies import CondaDependencies  # from azureml.widgets import RunDetails
from datetime import datetime
def parser():
    """Return the default job-submission settings as an attribute-access dict.

    The result is an ``easydict.EasyDict`` so callers can write ``args.region``
    as well as ``args["region"]``. Keys, in order:

    use_gjd         -- when True the ``__main__`` section submits through
                       ``GlobalJobDispatcherConfiguration`` instead of a named cluster.
    region          -- lookup key into ``REGION_DICT`` to pick the workspace.
    cpu             -- without GJD, selects "evalNC24" instead of "PrimaryTrain".
    exp_name        -- Azure ML experiment name the runs are submitted under.
    info            -- prefix of the "info" tag attached to each submitted run.
    nb_repro        -- how many identical runs are submitted.
    train_config, checkpoint_path, checkpoint_dir
                    -- not read by the submission code in this script.
    """
    from easydict import EasyDict

    defaults = dict(
        use_gjd=True,
        region="westus2",
        cpu=False,
        exp_name="zhijun_pytorch",  # other experiment names used: PLC_pytorch / PLC_synthesize / PLC_challenge / PLC_multi_task
        info="ar libritts tfcodec 16",
        nb_repro=1,
        train_config="config.yaml",
        checkpoint_path=None,
        checkpoint_dir=None,
    )
    return EasyDict(defaults)

if __name__ == '__main__':
    # Submit one or more Azure ML runs of
    # egs/libritts/bin/combine_ar_nar_vc_dir_onlyar.py with its model/data inputs
    # mounted from a workspace datastore.
    import argparse

    # Baseline settings; selected fields are overridden from the command line below.
    args = parser()

    external_parser = argparse.ArgumentParser(description='Process command line parameters.')
    # Default to parser()'s value so omitting --info keeps the previous run tag; an explicit
    # --info is now honoured (it used to be parsed and then overwritten by a hard-coded string).
    external_parser.add_argument('--info', type=str, default=args.info,
                                 help='info of experiments')
    # NOTE(review): --config is accepted for CLI compatibility but nothing below consumes it.
    external_parser.add_argument('--config', type=str, default=None,
                                 help='config')
    ext_args = external_parser.parse_args()
    args.info = ext_args.info

    # The private-registry password must not be committed to source control. Read it from the
    # environment and fail fast, before any Azure call, when it is missing.
    registry_password = os.environ.get("ACR_PASSWORD")
    if not registry_password:
        sys.exit("ACR_PASSWORD is not set: export the msrresrchcr.azurecr.io registry password before submitting.")

    DATASTORE_DICT = {"packet_loss_concealment": 'zhijun_data'}  # datastore need to be registered first on the workspace
    REGION_DICT = {"westus2": 'packet_loss_concealment'}
    VC_DICT = {"packet_loss_concealment": 'Westus2.PacketLossConcealment2'}
    workspace_name = REGION_DICT[args.region]
    datastore_name = DATASTORE_DICT[workspace_name]
    vc_name = VC_DICT[workspace_name]

    # Connect to the workspace. If cached credentials are rejected, run first:
    #   from azureml.core.authentication import InteractiveLoginAuthentication
    #   InteractiveLoginAuthentication(force=True)
    ws = Workspace(
        subscription_id="f1f491ac-0340-4e5d-87b7-47692be1cb31",
        resource_group="IC3_Common_GPU_cluster",
        workspace_name=workspace_name,
    )

    # Docker environment pulled from a private registry, using the image's own Python.
    env = Environment(name="valle-4-23-noconda-v10")
    env.docker.enabled = True
    env.docker.base_image = "jiazhijun/valle-4-23-noconda:v10"
    env.docker.base_image_registry.address = 'msrresrchcr.azurecr.io'
    env.docker.base_image_registry.username = 'msrresrchcr'
    env.docker.base_image_registry.password = registry_password
    # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-use-environments#specify-your-own-python-interpreter
    env.python.user_managed_dependencies = True
    env.python.interpreter_path = "/opt/conda/bin/python"
    env.register(workspace=ws)
    env.build(ws)

    # Build the run configuration from a dict written to a temporary .runconfig file.
    # A blank target (" ") is used for the Global Job Dispatcher path; otherwise name the cluster.
    # NOTE(review): _serialize_to_dict is a private azureml API and may change between SDK versions.
    run_config_struct = {
        "node_count": 1,
        "environment": env._serialize_to_dict(env),
        "target": " " if args.use_gjd else ("evalNC24" if args.cpu else "PrimaryTrain"),
    }
    project_dir = tempfile.mkdtemp()
    os.makedirs(os.path.join(project_dir, ".amlcompute"))
    run_config_path = os.path.join(project_dir, ".amlcompute", "simple.runconfig")
    with open(run_config_path, "w") as outfile:
        json.dump(run_config_struct, outfile)
    run_config = RunConfiguration.load(run_config_path)
    run_config.docker = DockerConfiguration(use_docker=True, shared_volumes=True, shm_size='2000g')

    # Choose where the job runs.
    if args.use_gjd:
        # Enable global job dispatcher to leverage idle compute resources from other workspaces.
        vc_list = ["Microsoft.IC3.Unified.Westus2.{}".format(vc_name)]
        run_config.global_job_dispatcher = GlobalJobDispatcherConfiguration(
            compute_type="Amlcompute",
            region=[],
            vc_block_list=vc_list,
        )
        compute_target = None
    elif args.cpu:
        compute_target = "evalNC24"
    else:
        compute_target = "PrimaryTrain"

    num_runs = 1
    num_runs_stage2 = 1
    top_k = 1
    top_k_stage2 = 10
    # Fixed: this previously referenced an undefined name (current_time) and crashed.
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    outputdir_name = f"converted_pretrain_mode_5_mask_0_15_{timestamp}"

    # Look the datastore up once instead of once per mount.
    datastore = ws.datastores[datastore_name]
    semantic_tokens = datastore.path("data/LibriTTS/lhotse_vc/initial_data/unique_semantic_tokens.k2symbols").as_mount()
    semantic_sys_dir = datastore.path("data/inference_test_ICLR/data/benchmark_librispeech_10speakers/source").as_mount()
    audio_prompts_dir = datastore.path("data/inference_test_ICLR/data/benchmark_librispeech_10speakers/prompt/").as_mount()
    checkpoint1 = datastore.path("data/inference_test_ICLR/data/valle-tensorboard-models/pretrain/mode_5_mask_0_15/epoch-53.pt").as_mount()
    checkpoint2 = datastore.path("data/inference_test_ICLR/data/valle-tensorboard-models/vc/only_ar/epoch-40.pt").as_mount()
    hubert_path = datastore.path("data/inference_test_ICLR/data/valle-tensorboard-models/other_models/hubert/hubert_base_ls960.pt").as_mount()
    hubert_km_path = datastore.path("data/inference_test_ICLR/data/valle-tensorboard-models/other_models/hubert/hubert_base_ls960_L9_km500.bin").as_mount()
    tfnet_ckpt = datastore.path("data/inference_test_ICLR/data/valle-tensorboard-models/other_models/tfcodec/890hrs16khz_tfnet_v2i_vqvae_20msVQ_hop5_combine4_rd1_6kbps/tfnet_v2i_vqvae_combineVQ-iter-514000-val1-0.348327-entropy-118.968-mel-0.067548-recon-0.012822.ckpt").as_mount()

    # Command line for the remote script. Numeric values are passed as strings because
    # ScriptRunConfig documents `arguments` as a list of strings.
    arg_list = [
        "--model-name", "valle",
        "--norm-first", "true",
        "--add-prenet", "False",
        "--decoder-dim", "1024",
        "--nhead", "16",
        "--decoder-dim-stage2", "1024",
        "--nhead-stage2", "16",
        "--num-decoder-layers", "12",
        "--num-decoder-layers-stage2", "12",
        "--share-embedding", "true",
        "--nums", str(num_runs),
        "--nums-stage2", str(num_runs_stage2),
        "--semantic-tokens", str(semantic_tokens),
        "--semantic-sys-dir", str(semantic_sys_dir),
        "--audio-prompts-dir", str(audio_prompts_dir),
        "--input-semantic", "True",
        "--only-autoregressive", "True",
        "--prefix-mode", "1",
        "--checkpoint1", str(checkpoint1),
        "--checkpoint2", str(checkpoint2),
        "--top-k", str(top_k),
        "--top-k-stage2", str(top_k_stage2),
        "--shared-linear-stage2", "False",
        "--temperature", "1.0",
        "--num-quantizers", "1",
        "--num-quantizers-stage2", "16",
        "--input-codec", "1",
        "--target-mode", "2",
        "--accent-remove", "False",
        "--mode", "1",
        "--mode-stage2", "0",
        "--is-pretrain", "True",
        "--pret-mode", "0",
        "--outputdir-name", str(outputdir_name),
        "--hubert-path", str(hubert_path),
        "--hubert-km-path", str(hubert_km_path),
        "--tfnet-ckpt", str(tfnet_ckpt),
    ]

    src_dir = "./"  # uploaded as the job's source snapshot
    estimator = ScriptRunConfig(
        source_directory=src_dir,
        script="egs/libritts/bin/combine_ar_nar_vc_dir_onlyar.py",
        run_config=run_config,
        arguments=arg_list,
        compute_target=compute_target,
    )
    # Only the GJD path needs a blank target; blanking it unconditionally used to discard the
    # named cluster (evalNC24 / PrimaryTrain) that ScriptRunConfig set from compute_target.
    if args.use_gjd:
        estimator.run_config.target = " "

    # Register every mount so the $AZUREML_DATAREFERENCE_* placeholders in arg_list resolve.
    for data_ref in (semantic_tokens, semantic_sys_dir, audio_prompts_dir, checkpoint1,
                     checkpoint2, hubert_path, hubert_km_path, tfnet_ckpt):
        estimator.run_config.data_references[data_ref.data_reference_name] = data_ref.to_config()

    exp = Experiment(workspace=ws, name=args.exp_name)

    # Submit nb_repro identical runs, tagged "<info> r<i>".
    for i in range(args.nb_repro):
        info_tag = {"info": "{} r{}".format(args.info, i)}
        run = exp.submit(estimator, tags=info_tag)
        print(run.get_details())
