import itertools
import os
import subprocess
import time

import psutil

# Reference only: an example manual run of the (separate) Bridge training
# script. This launcher does not execute it; it builds its own commands below.
    # echo "seed is ${seed}:"
    # CUDA_VISIBLE_DEVICES=0 python ../train/train_bridge.py  --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --map_type ${map_type} --num_agents ${num_agents} \
    #  --seed ${seed} --n_timesteps 10 --n_training_threads 1 --n_rollout_threads 50 --num_mini_batch 1 --episode_length 30 --num_env_steps 1000000 --reward_shaping_horizon 100000000 \
    #  --ppo_epoch 15 --use_latent_actions \
    #  --save_interval 25 --log_interval 10 --use_recurrent_policy --bc_loss_coef 0.5 \
    #  --wandb_name "username" --user_name "username"

# env = "Bridge_dif_entropy3"
algo = "diff-mappo"

# ---------------------------------------------------------------------------
# Launcher resources: at most `num_max_process` training runs are alive at
# once. Slot i gets its own copy of the current environment with
# CUDA_VISIBLE_DEVICES pinned to GPU (i % num_gpu).
# ---------------------------------------------------------------------------
num_gpu = 2
num_max_process = 2
envs = []
for slot in range(num_max_process):
    slot_env = os.environ.copy()
    slot_env["CUDA_VISIBLE_DEVICES"] = str(slot % num_gpu)
    envs.append(slot_env)
# Popen handle of the run currently occupying each slot (None = never used).
process = [None] * num_max_process

# ---------------------------------------------------------------------------
# Sweep definition. Every combination of the axes below becomes one command
# line appended to `que`. itertools.product varies the LAST axis fastest,
# i.e. the same order as a stack of nested for-loops listed top to bottom.
# ---------------------------------------------------------------------------
project = 'D4RL-sweep9_clone'
que = []

sweep = {
    "joint": [True],
    "seed": [1],
    "use_recurrent_policy": [False],
    "entropy_coef": [0],
    "clone_episodes": [60],
    "n_timesteps": [5],
    # (scenario_name, rnum_agents, episode_length) are swept together.
    "scenario": [
        ('halfcheetah-medium-v2', 1, 1000),
        ('hopper-medium-v2', 1, 1000),
        ('walker2d-medium-v2', 1, 1000),
    ],
    "lr": [1e-3],
    "clip_param": [0.2],
    "weight_decay": [0],
    "unet_hidden_size": [1024],
    "hidden_size": [256],
    "bc_loss_coef": [1],
    "ppo_epoch": [10],
    "num_mini_batch": [4],
    "initial_logstd": [-2],
    "act_step": [4],
    "gamma": [0.995],
    "gae_lambda": [0.985],
    "eta": [0.0],
}
# Set to True to append --no_rand_train to every command.
no_rand_train = False

for combo in itertools.product(*sweep.values()):
    c = dict(zip(sweep, combo))
    scenario_name, rnum_agents, episode_length = c["scenario"]
    # halfcheetah always trains with gamma=0.99, overriding the swept value.
    gamma = 0.99 if scenario_name == 'halfcheetah-medium-v2' else c["gamma"]

    args = [
        "python ../train/train_d4rl_diff.py",
        "--beta_schedule cosine",
        "--normalize_advantage",
        f"--env_name {project}",
        f"--algorithm_name {algo}",
        f"--scenario_name {scenario_name}",
        f"--experiment_name map_{scenario_name}",
        "--num_agents 1",
        f"--initial_logstd {c['initial_logstd']}",
        f"--rnum_agents {rnum_agents}",
        f"--seed {c['seed']}",
        f"--n_timesteps {c['n_timesteps']}",
        f"--act_step {c['act_step']}",
        f"--vault_uid {scenario_name}_mixed",
        "--logit_scaling 10",
        f"--weight_decay {c['weight_decay']}",
        "--n_training_threads 8",
        "--n_rollout_threads 32",
        f"--num_mini_batch {c['num_mini_batch']}",
        f"--episode_length {episode_length}",
        "--num_env_steps 120000",
        "--reward_shaping_horizon 100000000",
        f"--eta {c['eta']}",
        "--use_proper_time_limits",
        f"--ppo_epoch {c['ppo_epoch']}",
        "--bc_epoch 0",
        f"--lr {c['lr']}",
        "--max_grad_norm 0.5",
        "--use_ReLU",
        "--layer_N 1",
        f"--gae_lambda {c['gae_lambda']}",
        "--value_loss_coef 0.5",
        "--bc_buffer_limit 0",
        f"--hidden_size {c['hidden_size']}",
        f"--unet_hidden_size {c['unet_hidden_size']}",
        "--repeat_num 1",
        f"--clone_episodes {c['clone_episodes']}",
        "--use_latent_actions",
        "--log_interval 1",
        f"--bc_loss_coef {c['bc_loss_coef']}",
        f"--clip_param {c['clip_param']}",
        "--use_attention",
        "--reward_shaping_factor 1",
        f"--gamma {gamma}",
        # Passed once; previously also given as 50 earlier in the command,
        # which the later 20 silently overrode.
        "--save_interval 20",
        "--negative_sample_scale 1",
        "--eval_interval 5",
        "--n_eval_rollout_threads 20",
        "--eval_episodes 1",
        "--use_eval",
        f"--entropy_coef {c['entropy_coef']}",
        "--gain 0.01",
        "--wandb_name 'username'",
        "--user_name 'username'",
    ]
    if c["joint"]:
        args.append("--joint_train")
    # NOTE(review): the flag is emitted when use_recurrent_policy is *False*.
    # This is presumably intentional: in MAPPO-style configs
    # `--use_recurrent_policy` is a store_false flag, so passing it DISABLES
    # the recurrent policy. Confirm against the ../train argument parser.
    if not c["use_recurrent_policy"]:
        args.append("--use_recurrent_policy")
    if no_rand_train:
        args.append("--no_rand_train")

    que.append(" ".join(args))
# ---------------------------------------------------------------------------
# Scheduler: run the queued commands with at most `num_max_process` of them
# alive at once. Each command goes to the first free slot (never used, or its
# previous process has exited) and inherits that slot's environment, including
# its CUDA_VISIBLE_DEVICES pin. While every slot is busy, re-check every 20 s.
# ---------------------------------------------------------------------------
for cmd in que:
    while True:
        # poll() returns None while the child is still running; otherwise it
        # records and returns the child's exit code.
        free_slot = next(
            (slot for slot, proc in enumerate(process)
             if proc is None or proc.poll() is not None),
            None,
        )
        if free_slot is not None:
            break
        time.sleep(20)
    # shell=True: `cmd` is one shell command line containing quoted arguments
    # (e.g. --wandb_name 'username'). It is built from literals in this file,
    # not from untrusted input.
    process[free_slot] = subprocess.Popen(cmd, env=envs[free_slot], shell=True)
    # Pause after every launch (presumably to stagger start-up of concurrent
    # runs).
    time.sleep(40)

# Wait for the runs still in flight. Without this, the launcher exits as soon
# as the last command starts and never collects those processes' exit
# status (Python then emits "subprocess ... is still running" ResourceWarnings
# at shutdown).
for proc in process:
    if proc is not None:
        proc.wait()

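# Handy shell one-liners (run manually; not executed by this script).
# Kill every process whose command line mentions the algorithm name:
# kill $(ps -ef | grep 'diff-mappo' | awk '{print $2}')
# Kill every python process listed by `ps -ef` (heavy-handed):
# kill $(ps -ef | grep 'python' | awk '{print $2}')

# List processes currently assigned to CPU core 60:
# ps -eo pid,comm,psr | awk '$3 == 60'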