import itertools
import os
import subprocess
import time

import psutil

# Run this:
    # echo "seed is ${seed}:"
    # CUDA_VISIBLE_DEVICES=0 python ../train/train_bridge.py  --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --map_type ${map_type} --num_agents ${num_agents} \
    #  --seed ${seed} --n_timesteps 10 --n_training_threads 1 --n_rollout_threads 50 --num_mini_batch 1 --episode_length 30 --num_env_steps 1000000 --reward_shaping_horizon 100000000 \
    #  --ppo_epoch 15 --use_latent_actions \
    #  --save_interval 25 --log_interval 10 --use_recurrent_policy --bc_loss_coef 0.5 \
    #  --wandb_name "username" --user_name "username" 

# env = "Bridge_dif_entropy3"
# Algorithm tag; inserted into every command as --algorithm_name.
algo = "diff-mappo"

# Scheduler sizing: `num_max_process` worker slots run concurrently and are
# spread round-robin over `num_gpu` GPU indices.
num_gpu = 2
num_max_process = 1

# One Popen handle per worker slot; None means nothing was launched there yet.
process = [None] * num_max_process

# One private copy of the parent environment per worker slot, with the GPU
# index and the MuJoCo rendering backend overridden for the child process.
# (A per-slot CPU_AFFINITY entry was tried previously and is disabled.)
envs = []
for slot in range(num_max_process):
    slot_env = os.environ.copy()
    slot_env.update(
        {
            "CUDA_VISIBLE_DEVICES": str(slot % num_gpu),
            "MUJOCO_GL": "osmesa",
        }
    )
    envs.append(slot_env)

# Project tag; inserted into every command as --env_name.
project = 'Robomimic-sweep15'

# Shell commands waiting to be launched, in sweep order.
que = []


# norm_reward
# Checkpoint directory passed to the trainer as --model_dir.
# Previously used checkpoints, kept for reference:
#   /workspace/Diffusion-PPO/on-policy/onpolicy/scripts/results/Robomimic-sweep-clone13/lift/diff-mappo/map_lift/wandb/run-20250315_204753-488pez0o/files
#   /workspace/Diffusion-PPO/on-policy/onpolicy/scripts/results/Robomimic-sweep-clone11/square/diff-mappo/map_square/wandb/run-20250216_155335-k0xxzgdu/files
#   /workspace/Diffusion-PPO/on-policy/onpolicy/scripts/model/run-20250215_214830-3ynb0rme_copy/files
#   /workspace/Diffusion-PPO/on-policy/onpolicy/scripts/model/run-20241205_003829-1hszdvca/files
#   /workspace/Diffusion-PPO/on-policy/onpolicy/scripts/model/run-20241205_171815-3yl03raq/files
#   /home/username/Diffusion-PPO/on-policy/onpolicy/scripts/model/run-20241117_231723-7ctiz967/files
# Optional trainer flags that have been toggled by hand in the past:
#   --running_reward_scaling --use_valuenorm --recompute_adv
model_dir = "/workspace/Diffusion-PPO/on-policy/onpolicy/scripts/model/run-20250211_153635-dlz59ixc_copy/files"

# Hyper-parameter sweep.  Each positional argument of itertools.product is the
# list of candidate values for one axis; the first axis varies slowest and the
# last one fastest, exactly like the equivalent nested `for` loops.  Every
# combination is rendered into one shell command and appended to `que`.
for (
    joint,
    seed,
    use_recurrent_policy,
    entropy_coef,
    clone_episodes,
    n_timesteps,
    (scenario_name, rnum_agents, episode_length),
    lr,
    clip_param,
    clone_weight_decay,
    unet_hidden_size,
    hidden_size,
    critic_lr,
    ppo_epoch,
    num_mini_batch,
    initial_logstd,
    logit_scaling,
    gamma,
    gae_lambda,
    eta,
) in itertools.product(
    [True],                     # joint
    [4, 5],                     # seed
    [False],                    # use_recurrent_policy
    [0],                        # entropy_coef
    [0],                        # clone_episodes
    [5],                        # n_timesteps
    [('transport', 1, 800)],    # (scenario_name, rnum_agents, episode_length)
    [3e-5],                     # lr
    [0.5],                      # clip_param
    [0],                        # clone_weight_decay
    [1024],                     # unet_hidden_size
    [256],                      # hidden_size
    [5e-4],                     # critic_lr
    [5],                        # ppo_epoch
    [4],                        # num_mini_batch
    [-2.3],                     # initial_logstd
    [1],                        # logit_scaling
    [0.999],                    # gamma
    [0.99],                     # gae_lambda
    [0],                        # eta
):
    # Per-scenario settings and input files.
    max_episode_length = 800
    dataset_path = f'/workspace/Diffusion-PPO/on-policy/onpolicy/datasets/{scenario_name}/mh/low_dim_v141.hdf5'
    robomimic_env_cfg_path = f'/workspace/Diffusion-PPO/on-policy/onpolicy/envs/robomimic/env_meta/{scenario_name}.json'
    normalization_path = f"/workspace/Diffusion-PPO/norm/robomimic/{scenario_name}/normalization.npz"

    # The command is assembled from individual arguments and joined with
    # single spaces, so no source indentation leaks into the shell string.
    cmd_parts = [
        "python ../train/train_robomimic.py",
        "--normalize_advantage",
        f"--normalization_path {normalization_path}",
        "--act_step 8",
        f"--model_dir {model_dir}",
        "--critic_epoch 0",
        f"--env_name {project}",
        f"--algorithm_name {algo}",
        f"--scenario_name {scenario_name}",
        "--beta_schedule cosine",
        f"--experiment_name map_{scenario_name}",
        "--num_agents 1",
        f"--initial_logstd {initial_logstd}",
        f"--rnum_agents {rnum_agents}",
        "--t_dim 16",
        f"--seed {seed}",
        f"--n_timesteps {n_timesteps}",
        f"--vault_uid {scenario_name}_mixed",
        f"--logit_scaling {logit_scaling}",
        f"--clone_weight_decay {clone_weight_decay}",
        "--n_training_threads 8",
        "--data_chunk_length 1",
        "--n_rollout_threads 200",
        f"--num_mini_batch {num_mini_batch}",
        f"--dataset_path {dataset_path}",
        f"--robomimic_env_cfg_path {robomimic_env_cfg_path}",
        f"--max_episode_length {max_episode_length}",
        f"--episode_length {episode_length}",
        "--num_env_steps 15000000",
        "--reward_shaping_horizon 100000000",
        f"--eta {eta}",
        "--use_proper_time_limits",
        f"--ppo_epoch {ppo_epoch}",
        "--bc_epoch 0",
        f"--lr {lr}",
        f"--critic_lr {critic_lr}",
        "--max_grad_norm 10",
        "--layer_N 1",
        f"--gae_lambda {gae_lambda}",
        "--value_loss_coef 0.5",
        "--bc_buffer_limit 0",
        f"--hidden_size {hidden_size}",
        f"--unet_hidden_size {unet_hidden_size}",
        "--repeat_num 1",
        f"--clone_episodes {clone_episodes}",
        "--use_latent_actions",
        "--log_interval 1",
        "--bc_loss_coef 5",
        f"--clip_param {clip_param}",
        "--use_attention",
        "--reward_shaping_factor 1",
        f"--gamma {gamma}",
        # The command used to pass --save_interval twice (50, then 20).  Only
        # the later value is kept here.
        # NOTE(review): assumes the trainer parses flags argparse-style, where
        # the last occurrence wins, so 20 was the effective value -- confirm.
        "--save_interval 20",
        "--negative_sample_scale 0",
        "--use_eval",
        "--eval_interval 5",
        "--n_eval_rollout_threads 100",
        "--eval_episodes 1",
        f"--entropy_coef {entropy_coef}",
        "--gain 0.01",
        "--wandb_name 'username'",
        "--user_name 'username'",
    ]
    if joint:
        cmd_parts.append("--joint_train")
    # NOTE(review): the flag is emitted when use_recurrent_policy is False,
    # which presumes the trainer declares it as a switch that turns the
    # recurrent policy off -- confirm against the trainer's argument parser.
    if not use_recurrent_policy:
        cmd_parts.append("--use_recurrent_policy")

    que.append(" ".join(cmd_parts))
# Launch the queued commands one after another, keeping at most
# `num_max_process` of them alive at once.  Worker slot k always runs with the
# environment stored in envs[k].
for command in que:
    while True:
        # First slot, in index order, that has never been used or whose
        # process has already exited (poll() returns an exit code).
        free_slot = next(
            (
                slot
                for slot in range(num_max_process)
                if process[slot] is None or process[slot].poll() is not None
            ),
            None,
        )
        if free_slot is None:
            # Every slot is still busy: look again in 20 seconds.
            time.sleep(20)
            continue

        # NOTE: per-launch CPU pinning (taskset / psutil cpu_affinity) was
        # experimented with at this point and is currently disabled.
        process[free_slot] = subprocess.Popen(command, env=envs[free_slot], shell=True)
        # Pause 40 seconds after each launch before handling the next command.
        time.sleep(40)
        break

# kill $(ps -ef | grep 'diff-mappo' | awk '{print $2}')
# kill $(ps -ef | grep 'python' | awk '{print $2}')

# ps -eo pid,comm,psr | awk '$3 == 60'