import os
import subprocess
import sys

# --- Run configuration ---------------------------------------------------------
# Selects which training checkpoint (step), run, and layer to map. All three are
# embedded in the output folder names below.
train_step = 1700
run_name = 'supreme_topo_scale_50'
layer_name = "transformer.h.10.mlp.c_fc"

# Categories to map, each written as "<domain>.<subtopic>". The flat list is
# expanded from a (domain, subtopics) table. Order is domain by domain, with
# subtopics in the order listed.
target_categories = [
    f"{domain}.{subtopic}"
    for domain, subtopics in (
        ("politics", ("left_wing", "right_wing")),
        ("science", ("physics", "chemistry", "biology", "math")),
        ("sports", ("football", "olympics", "tennis")),
        (
            "technology",
            ("software", "artificial_intelligence", "blockchain", "space_exploration"),
        ),
        ("musicians", ("rock", "pop", "hip_hop")),
        ("actors", ("drama", "action", "comedy")),
        ("history", ("ancient", "medieval", "modern")),
    )
    for subtopic in subtopics
]

# Relative paths are resolved against the current working directory.
checkpoints_root = "../../../training/gpt_neo_125m/checkpoints/"
# NOTE: the spelling "hierarchial" is used throughout these paths and in the
# generator script's name. Keep it as-is so existing outputs stay in place.
output_folder = f"maps/hierarchial/{train_step}_{layer_name}_{run_name}"
output_folder_numpy_arrays = f"maps/hierarchial_numpy_arrays/{train_step}_{layer_name}_{run_name}"

# Create both output directories, plus any missing parents, before generating.
# os.makedirs(..., exist_ok=True) is the shell-free equivalent of `mkdir -p`.
# It tolerates directories that already exist and raises OSError if one cannot
# be created, so a bad output path fails here rather than later.
os.makedirs(output_folder, exist_ok=True)
os.makedirs(output_folder_numpy_arrays, exist_ok=True)

# Inputs shared by every per-category generation run below.
dataset_filename = "dataset_hierarchial.json"
checkpoint_folder_name = f"checkpoint-{train_step}"

# Every category uses the same checkpoint, so resolve and validate it once,
# before any work starts. An explicit raise, rather than `assert`, keeps the
# check active under `python -O`.
checkpoint_folder = os.path.join(checkpoints_root, run_name, checkpoint_folder_name)
if not os.path.exists(checkpoint_folder):
    raise FileNotFoundError(f"Invalid path:\n{checkpoint_folder}")

# Categories whose generator run exited non-zero; reported at the end.
failed_categories = []

for target_category in target_categories:
    image_filename = f"{target_category}_{train_step}_{layer_name}_{run_name}.png"
    numpy_array_filename = f"{target_category}_{train_step}_{layer_name}_{run_name}.npy"

    output_filename = os.path.join(output_folder, image_filename)
    output_filename_numpy = os.path.join(output_folder_numpy_arrays, numpy_array_filename)

    # The argument list is passed without a shell, so paths containing spaces or
    # shell metacharacters reach the script verbatim. sys.executable runs the
    # child under the same interpreter (and virtualenv) as this driver.
    command = [
        sys.executable,
        "generate_map_hierarchial.py",
        "--dataset-filename", dataset_filename,
        "--checkpoint-folder", checkpoint_folder,
        "--layer-name", layer_name,
        "--output-filename", output_filename,
        "--target-category", target_category,
        "--output-filename-numpy", output_filename_numpy,
    ]
    result = subprocess.run(command)
    if result.returncode != 0:
        # Keep going so one failing category does not block the rest, but record
        # the failure instead of silently ignoring the exit status.
        print(
            f"generate_map_hierarchial.py failed for {target_category} "
            f"(exit code {result.returncode})",
            file=sys.stderr,
        )
        failed_categories.append(target_category)

if failed_categories:
    # A non-zero exit lets callers (shell scripts, CI) detect partial failure.
    sys.exit(
        f"Map generation failed for {len(failed_categories)} categories: "
        + ", ".join(failed_categories)
    )