import subprocess
import os
from itertools import product
from types import SimpleNamespace as SN


# Job-slot quota per account entry. Jobs are assigned round-robin: the job
# index modulo sum(max_job_nums) selects the bucket, so max_job_nums[j]
# consecutive jobs go to account_combinations[j] before moving on.
max_job_nums = [4]
# One [account, partition, qos] triple per entry of max_job_nums.
account_combinations = [["test", "test", "test"]]

# Search space: parameter name -> list of values to sweep over.
parameters = dict(
    seed=[112358, 1285842, 78590, 119527, 122529],
)

# Results root; the Slurm logs and generated job scripts live underneath it.
root_dir = "/fs/nexus-projects/Guided_MARL/pegmarl_results"
# Exported as SC2PATH inside every generated job script.
smac_dir = "~/3rdparty/StarCraftII_2410"

# Cartesian product of all parameter values; each tuple in `combinations`
# is ordered like `param_names`.
param_names = [*parameters]
param_values = [*parameters.values()]
combinations = [*product(*param_values)]

# algo_name = "qmix"
# python_command = "python src/main.py --env-config=sc2 --config=default --alg-config=qmix with env_args.map_name=5m_vs_6m t_max=10050000 name='qmix'"

# Active experiment. `algo_name` and `map_name` only feed the Slurm job name;
# `python_command` is the base training command, to which the sweep appends
# `--seed=<value>` for every job.
algo_name = "ippo"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_5v6 --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='env' update_gail=False t_max=10050000 name='ippo'"
python_command = " ".join(
    [
        "python src/main.py",
        "--env-config=sc2",
        "--config=default_ippo_3sv4z",
        "--alg-config=ippo",
        # Everything after `with` is a key=value override.
        "with",
        "env_args.map_name=3s_vs_4z",
        "rew_type='env'",
        "update_gail=False",
        "t_max=10050000",
        "name='ippo'",
    ]
)
map_name = "3s_vs_4z"

# algo_name = "dm2"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_5v6 --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='mixed' update_gail=True t_max=10050000 name='dm2'"

# algo_name = "dm2_sa"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_5v6 --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='mixed' update_gail=True t_max=10050000 gail_state_discrim=False name='dm2_sa'"

# algo_name = "ours"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_5v6 --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='mixed' update_gail=True t_max=10050000 gail_state_discrim=False learner='ippo_learner2' name='ours_nonco_add_pen'"

# algo_name = "ours_cotrain"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_5v6_cotrain --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='mixed' update_gail=True t_max=10050000 gail_state_discrim=False learner='ippo_learner2' name='ours_add_pen'"

# map_name = "5m_vs_6m"

# algo_name = "ours"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_3sv4z --alg-config=ippo with env_args.map_name=3s_vs_4z rew_type='mixed' update_gail=True t_max=10050000 gail_state_discrim=False learner='ippo_learner2' name='ours_nonco_add_pen'"

# algo_name = "ours_cotrain"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_3sv4z_cotrain --alg-config=ippo with env_args.map_name=3s_vs_4z rew_type='mixed' update_gail=True t_max=10050000 gail_state_discrim=False learner='ippo_learner2' name='ours_add_pen'"

# map_name = "3s_vs_4z"

###### Tune the reward weight (gail_rew_coef)
# algo_name = "ours_rew0.5"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_5v6 --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='mixed' gail_rew_coef=0.5 update_gail=True t_max=10050000 gail_state_discrim=False learner='ippo_learner2' name='ours_nonco_add_pen_rew0.5'"

# algo_name = "ours_cotrain_rew0.2"
# python_command = "python src/main.py --env-config=sc2 --config=default_ippo_5v6_cotrain --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='mixed' gail_rew_coef=0.2 update_gail=True t_max=10050000 gail_state_discrim=False learner='ippo_learner2' name='ours_add_pen_rew0.2'"

# map_name = "5m_vs_6m"


# Iterate over parameter combinations and submit one Slurm job for each.
#
# For every tuple in `combinations` the loop:
#   1. builds a unique job name from algo_name, map_name and the parameters,
#   2. picks the [account, partition, qos] triple for this job (round-robin
#      over the slot quotas in `max_job_nums`),
#   3. renders an sbatch script and writes it to {root_dir}/slurm_scripts,
#   4. submits the script with `sbatch` and reports whether that succeeded.

# open(..., 'w') does not create missing parent directories, and the Slurm
# --output file also lives in a sub-directory of root_dir, so make sure both
# directories exist before any script is written or submitted.
os.makedirs(f"{root_dir}/slurm_logs", exist_ok=True)
os.makedirs(f"{root_dir}/slurm_scripts", exist_ok=True)

# Loop-invariant: total number of job slots across all account entries.
total_job_slots = sum(max_job_nums)

jobs_num = 0
for combo in combinations:
    job_name = "__".join(f"{name}_{value}" for name, value in zip(param_names, combo))
    job_name = f"{algo_name}__{map_name}__{job_name}"
    param_dict = dict(zip(param_names, combo))
    param = SN(**param_dict)

    # get qos info: walk the cumulative slot counts until the bucket that
    # contains this job's position in the round-robin cycle is found.
    remainder = jobs_num % total_job_slots
    slots_seen = 0
    for j, slots in enumerate(max_job_nums):
        slots_seen += slots
        if remainder < slots_seen:
            account, partition, qos = account_combinations[j]
            print(f"job num: {jobs_num}, account: {account}, partition: {partition}, qos: {qos}")
            break
    jobs_num += 1

    # Create a unique job script for each combination
    # note: change time accordingly
    job_script_content = f'''#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --output={root_dir}/slurm_logs/%x.%j.out
#SBATCH --time=48:00:00
#SBATCH --account={account}
#SBATCH --partition={partition}
#SBATCH --qos={qos}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=8
#SBATCH --mem=64gb
#SBATCH --gres=gpu:1

# Load any necessary modules
# For example, if you need Python, you might load a Python module
CONDA_BASE=$(conda info --base)
source $CONDA_BASE/etc/profile.d/conda.sh
conda activate dm2
export SC2PATH={smac_dir}

# Your Python script with parameters
srun bash -c "{python_command} --seed={param.seed}"
'''

    # Examples of fully expanded commands as they end up in the job script:
    # python src/main.py --env-config=sc2 --config=default --alg-config=qmix with env_args.map_name=5m_vs_6m t_max=10050000 name='qmix' --seed=112358
    # python src/main.py --env-config=sc2 --config=default_ippo_5v6 --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='env' update_gail=False t_max=10050000 name='ippo' --seed=112358
    # python src/main.py --env-config=sc2 --config=default_ippo_5v6 --alg-config=ippo with env_args.map_name=5m_vs_6m rew_type='mixed' update_gail=True t_max=10050000 name='dm2' --seed=112358

    # Write the job script to a file
    job_script_path = f'{root_dir}/slurm_scripts/submit_job__{job_name}.sh'
    with open(job_script_path, 'w') as job_script_file:
        job_script_file.write(job_script_content)

    # Submit the job using sbatch. The exit status is inspected instead of
    # raising so that one rejected job does not abort the rest of the sweep.
    submission = subprocess.run(['sbatch', job_script_path])

    # Print the job submission info
    result = ", ".join(f"{name}: {value}" for name, value in zip(param_names, combo))
    if submission.returncode == 0:
        print(f'Job submitted for parameters: {result}')
    else:
        print(f'sbatch failed (exit code {submission.returncode}) for parameters: {result}')