import os

from easy_runner import EasyRunner
import json
from pathlib import Path
# import tools
# task = "boxing"
# Initialize the EasyRunner
import argparse
# Accumulators filled below: shell commands to launch and their experiment names
# (kept index-aligned, one name per command).
exp_names = []
instructions = []
# Build the command-line parser.
parser = argparse.ArgumentParser(description="Indicating train mode")
# --train_mode selects which stages get scheduled further down the script.
parser.add_argument('--train_mode', type=str, default="both", choices=["both","stage1","stage2","eval","abl_wo_stage2","test"])

# Parse the CLI arguments.
args = parser.parse_args()
print(f"train mode: {args.train_mode}")
#[todo] variable start
batch_size = 4
debug = False
seed = 0
eval_episode_num = 2
prefix = "dmc_proprio"
log_name = f"{prefix}_seperate_expert_task_grouping_sameseed{seed}"


steps = 1e5 #1e5
add_token_embed = False
moe_start_steps_ratio = 0.8
stage1_only = None  #None or list
specific_num_states = 24
specific_num_actions = 6
group_num = 9
#[todo] variable end
if not add_token_embed:
    # str.replace returns a new string; the result must be assigned back,
    # otherwise the "token_" tag is never removed from the run name.
    log_name = log_name.replace("token_", "")
if args.train_mode != "abl_wo_stage2":
    # Expert (stage-1) runs get only the first moe_start_steps_ratio of the budget.
    expert_steps = int(steps * moe_start_steps_ratio)
else:
    # Ablation without stage 2: expert runs use the full budget.
    # NOTE(review): this stays a float (1e5 -> "100000.0" on the command line),
    # unlike the int produced in the other branch — confirm taskloom accepts both.
    expert_steps = steps
def config2list(input):
    """Normalize a list-like config value into a Python list.

    - A ``list`` is returned unchanged (same object).
    - A ``tuple`` is converted to a new list.
    - Anything else is treated as a string such as ``"[a, b, c]"`` or
      ``"a,b,c"``. Surrounding whitespace and square brackets are removed,
      the text is split on commas, and each item is whitespace-stripped.
      Empty items are dropped, so ``""``, ``"[]"`` and a trailing comma do
      not produce ``""`` entries.
    """
    if isinstance(input, list):
        return input
    if isinstance(input, tuple):
        return list(input)
    # Strip outer whitespace first so " [a,b] " still loses its brackets.
    parts = (part.strip() for part in input.strip().strip("[]").split(","))
    return [part for part in parts if part]
def list_to_str(lst):
    """Render *lst* as ``"[a,b,c]"``: each element passed through ``str()``,
    joined by commas with no spaces, wrapped in square brackets."""
    body = ",".join(map(str, lst))
    return f"[{body}]"
# Resolve the task groups for this run. A copy of "<prefix>_final_groups.json"
# lives under ./logdir/<log_name>/. If that copy exists it is reused; otherwise
# it is seeded from the same-named file in the current working directory.
# (Groups are read from JSON; they are no longer derived from env-name prefixes.)
group_save_dir = Path(f"./logdir/{log_name}").expanduser()
group_save_dir.mkdir(exist_ok=True, parents=True)
groups_snapshot = group_save_dir / f"{prefix}_final_groups.json"
if os.path.exists(groups_snapshot):
    with open(groups_snapshot, "r", encoding="utf-8") as f:
        prefix_dict = json.load(f)
else:
    with open(f"{prefix}_final_groups.json", "r", encoding="utf-8") as f:
        prefix_dict = json.load(f)
    with open(groups_snapshot, "w", encoding="utf-8") as fout:
        json.dump(prefix_dict, fout, indent=2, ensure_ascii=False)

# Order groups by member count, largest first. This must run while the values
# are still lists; they are turned into strings just below.
sorted_keys = sorted(prefix_dict, key=lambda group: len(prefix_dict[group]), reverse=True)
print(sorted_keys)

# Replace each member list with its "[a,b,c]" form for --train_env_name_list.
for group, members in prefix_dict.items():
    prefix_dict[group] = list_to_str(members)
    print(f"{group}:{prefix_dict[group]}")

# "test" mode schedules only a single group.
if args.train_mode == "test":
    sorted_keys = ["group_1"]


def _taskloom_command(logdir, wm_with_moe, env_list, n_steps):
    """Build one ``nohup python -u taskloom.py ...`` shell command.

    Only the log directory, the ``--wm_with_moe`` value, the environment list
    and the ``--steps`` budget differ between stages; every other flag comes
    from the module-level settings.
    """
    return (
        "nohup python -u taskloom.py --configs dmc_proprio --task dmc_multitask_proprio"
        f" --logdir {logdir} --batch_size {batch_size} --wm_with_moe {wm_with_moe}"
        f" --envs 1 --train_env_name_list '{env_list}'"
        f" --eval_episode_num {eval_episode_num} --multi_actor_train None"
        " --actor_train_mode seperate --multi_actor_sample None"
        f" --seed {seed} --steps {n_steps} --add_token_embed {add_token_embed}"
        f" --moe_start_steps_ratio {moe_start_steps_ratio} "
        f" --specific_num_states {specific_num_states}"
        f" --specific_num_actions {specific_num_actions} --group_num {group_num}"
    )


if debug:
    raise NotImplementedError("Not debug mode yet")

runner = EasyRunner(log_name=log_name)

# Stage 1 (expert runs): one job per group with --wm_with_moe False and the
# expert step budget. Optionally restricted to the groups listed in stage1_only.
if args.train_mode in ("both", "stage1", "abl_wo_stage2"):
    for name in sorted_keys:
        if stage1_only is not None and name not in stage1_only:
            continue
        instructions.append(
            _taskloom_command(f"./logdir/{log_name}/{name}", False, prefix_dict[name], expert_steps)
        )
        exp_names.append(f"{name}_expert")

# Final stage: one job per group with --wm_with_moe True and the full step
# budget, logged under ./logdir/<log_name>/final/<group>.
if args.train_mode in ("both", "stage2", "eval", "test"):
    for name in sorted_keys:
        instructions.append(
            _taskloom_command(f"./logdir/{log_name}/final/{name}", True, prefix_dict[name], steps)
        )
        exp_names.append(f"{name}_final")

# Concurrency cap per mode; any mode not listed runs up to 3 jobs at once.
max_parallel = {"stage1": 10, "abl_wo_stage2": 7, "eval": 1}.get(args.train_mode, 3)
runner.start(instructions, exp_names=exp_names, max_parallel=max_parallel)