import itertools
import os
import subprocess
import time

import psutil

# Algorithm name forwarded to the training script.
algo = "diff-mappo"


# Worker-slot layout: at most `num_max_process` jobs run at once, spread over
# `num_gpu` GPUs.
num_gpu = 8
num_max_process = 8

# Each slot gets its own copy of the launcher's environment with
# CUDA_VISIBLE_DEVICES set to (slot + 3) % num_gpu, i.e. slot 0 uses GPU 3 and
# the assignment wraps around after the last GPU.
# (Per-slot CPU affinity was experimented with and is currently disabled.)
envs = [
    {**os.environ, "CUDA_VISIBLE_DEVICES": str((slot + 3) % num_gpu)}
    for slot in range(num_max_process)
]

# One process handle per slot; None means nothing has been launched there yet.
process = [None] * num_max_process

# Project label; passed to the trainer as both --tensorboard_project and --env_name.
project = 'Kitchen-rebuttal_mixed_final2'

# Shell command lines, in the order they should be launched.
que = []

# Checkpoint directory passed to the trainer as --model_dir (mixed dataset).
# Alternatives that were used for other datasets:
#   d4rl:     ~/Diffusion-PPO/on-policy/onpolicy/scripts/model/d4rl/{scenario_name}/files
#   partial:  /home/yangningyuan/Diffusion-PPO/on-policy/onpolicy/scripts/results/Kitchen-rebuttal_pretrain3_act4/kitchen-partial-v0/diff-mappo/map_kitchen-partial-v0/run3/models
#   complete: /home/yangningyuan/Diffusion-PPO/on-policy/onpolicy/scripts/results/Kitchen-rebuttal_pretrain_complete_act4/kitchen-complete-v0/diff-mappo/map_kitchen-complete-v0/run2/models
model_dir = "/home/yangningyuan/Diffusion-PPO/on-policy/onpolicy/scripts/results/Kitchen-rebuttal_pretrain_mixed_act4/kitchen-mixed-v0/diff-mappo/map_kitchen-mixed-v0/run2/models"

# Defaults plus per-scenario overrides. halfcheetah runs ten times longer and
# uses a single gamma; keeping that as a one-element list (instead of
# overwriting the loop variable) guarantees each configuration is queued
# exactly once.
default_num_env_steps = 2000000
default_gammas = [0.995, 0.99]
num_env_steps_by_scenario = {'halfcheetah-medium-v2': 20000000}
gammas_by_scenario = {'halfcheetah-medium-v2': [0.99]}

# Hyperparameter grid, outermost axis first. itertools.product advances the
# right-most axis fastest, so the queue order is the same as nesting one
# `for` loop per axis in this order.
outer_grid = itertools.product(
    [True],                             # joint
    [1, 2, 3],                          # seed
    [False],                            # use_recurrent_policy
    [0],                                # entropy_coef
    [5],                                # n_timesteps
    # (scenario_name, rnum_agents, episode_length, clone_episodes)
    # Previously used, in an older 3-field form (add clone_episodes before
    # re-enabling): ('halfcheetah-medium-v2', 1, 250),
    # ('halfcheetah-medium-v2', 1, 256), ('hopper-medium-v2', 1, 500)
    [('kitchen-mixed-v0', 1, 280, 0)],
    [3e-5, 1e-5, 1e-4],                 # lr
    [0.2, 0.5],                         # clip_param
    [0],                                # weight_decay
    [1024],                             # unet_hidden_size
    [256],                              # hidden_size
    [0],                                # bc_loss_coef
    [10, 5],                            # ppo_epoch
    [1],                                # num_mini_batch
    [-2.3],                             # initial_logstd
    [4],                                # act_step
)

for (joint, seed, use_recurrent_policy, entropy_coef, n_timesteps, scenario,
     lr, clip_param, weight_decay, unet_hidden_size, hidden_size, bc_loss_coef,
     ppo_epoch, num_mini_batch, initial_logstd, act_step) in outer_grid:
    scenario_name, rnum_agents, episode_length, clone_episodes = scenario
    num_env_steps = num_env_steps_by_scenario.get(scenario_name, default_num_env_steps)
    gammas = gammas_by_scenario.get(scenario_name, default_gammas)
    # Left with a leading "~": the command is meant to be run through a shell,
    # which expands it.
    normalization_path = f"~/Diffusion-PPO/norm/gym/{scenario_name}/normalization.npz"

    for gamma, gae_lambda, eta in itertools.product(gammas, [0.985, 0.95], [0.0]):
        # Adjacent literals are concatenated by the parser, so the command
        # contains single spaces only, regardless of source indentation.
        cmd = (
            "python ../train/train_d4rl_diff.py --unet_num_layer 7 --use_wandb"
            f" --model_dir {model_dir} --tensorboard_project {project}"
            f" --normalization_path {normalization_path} --beta_schedule cosine"
            f" --env_name {project} --algorithm_name {algo}"
            f" --scenario_name {scenario_name} --experiment_name map_{scenario_name}"
            f" --num_agents 1 --initial_logstd {initial_logstd} --rnum_agents {rnum_agents}"
            f" --seed {seed} --n_timesteps {n_timesteps} --act_step {act_step}"
            f" --vault_uid {scenario_name}_mixed --logit_scaling 1"
            f" --weight_decay {weight_decay} --n_training_threads 8 --n_rollout_threads 40"
            f" --num_mini_batch {num_mini_batch} --episode_length {episode_length}"
            f" --num_env_steps {num_env_steps} --reward_shaping_horizon 100000000"
            f" --eta {eta} --use_proper_time_limits"
            f" --ppo_epoch {ppo_epoch} --bc_epoch 0 --lr {lr} --max_grad_norm 10.0"
            f" --use_ReLU --layer_N 1 --gae_lambda {gae_lambda} --value_loss_coef 0.5"
            f" --bc_buffer_limit 0 --hidden_size {hidden_size}"
            f" --unet_hidden_size {unet_hidden_size} --repeat_num 1"
            f" --clone_episodes {clone_episodes} --use_latent_actions"
            # NOTE(review): --save_interval is passed twice (50 here, 20
            # further down). Both are kept, in the original order, so the
            # trainer sees the same arguments; presumably the later value
            # wins -- confirm against the trainer's parser and drop one.
            " --save_interval 50 --log_interval 1"
            f" --bc_loss_coef {bc_loss_coef} --clip_param {clip_param} --use_attention"
            f" --reward_shaping_factor 1 --gamma {gamma}"
            " --save_interval 20 --negative_sample_scale 1 --eval_interval 5"
            " --n_eval_rollout_threads 20 --eval_episodes 1 --use_eval"
            f" --entropy_coef {entropy_coef} --gain 0.01"
            " --wandb_name '' --user_name ''"
        )
        if joint:
            cmd += " --joint_train"
        # NOTE(review): the flag is emitted when use_recurrent_policy is False,
        # so the trainer presumably treats it as a "turn off" switch -- confirm.
        if not use_recurrent_policy:
            cmd += " --use_recurrent_policy"
        que.append(cmd)
def _first_free_slot():
    """Return the lowest slot index that can take a new job, or None.

    A slot is free when nothing was ever started in it or when the process
    started there has already exited. Slots are checked in index order and
    checking stops at the first free one.
    """
    for slot in range(num_max_process):
        handle = process[slot]
        still_running = handle is not None and handle.poll() is None
        if not still_running:
            return slot
    return None


# Dispatch the queued commands one at a time, never running more than
# `num_max_process` of them concurrently.
# (CPU pinning via taskset/psutil was tried here at launch time and is
# currently disabled.)
for cmd in que:
    slot = _first_free_slot()
    while slot is None:
        # Every slot is busy: wait a bit, then look again.
        time.sleep(20)
        slot = _first_free_slot()

    # Launch through the shell with the slot's environment (which carries its
    # CUDA_VISIBLE_DEVICES), then pause before dispatching the next command.
    process[slot] = subprocess.Popen(cmd, env=envs[slot], shell=True)
    time.sleep(40)

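# Handy shell one-liners for cleaning up after a sweep (run manually, not from this script).
# Kill every process whose command line matches 'diff':
# kill $(ps -ef | grep 'diff' | awk '{print $2}')
# Kill every process whose command line matches 'python' (WARNING: this also hits unrelated Python jobs):
# kill $(ps -ef | grep 'python' | awk '{print $2}')

# List the processes currently scheduled on CPU 60:
# ps -eo pid,comm,psr | awk '$3 == 60'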