import os
import shlex
import sys

from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.utils.folder import make_folder_if_does_not_exist

# Location of the GPT-Neo-125M training checkpoints and the training step
# whose checkpoint is used for every run.
checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"
global_step = 10500

# Run label -> checkpoint path. "untrained" has no checkpoint (None); every
# other label is resolved through get_checkpoint_path_gpt_neo_125m with its
# topo_scale ("baseline" is topo_scale=0). Entries are inserted in the order
# listed here, which is the order the loops below iterate them in.
checkpoints_map = {"untrained": None}
checkpoints_map.update(
    (
        run_label,
        get_checkpoint_path_gpt_neo_125m(
            checkpoints_dir=checkpoint_dir,
            topo_scale=scale,
            global_step=global_step,
        ),
    )
    for run_label, scale in (
        ("baseline", 0),
        ("topo_1", 1),
        ("topo_5", 5),
        ("topo_10", 10),
        ("topo_50", 50),
    )
)

# ---------------------------------------------------------------------------
# Step 1: build the commands that extract hook outputs for every checkpoint,
# and record where each checkpoint's hook outputs live (step 2 reads
# ``hook_output_folders``).
# ---------------------------------------------------------------------------

# Set to True to (re-)run hook output extraction. The output folders are
# created and recorded either way, because step 2 depends on them.
run_hook_extraction = False

commands = []
hook_output_folders = {}

for checkpoint_name in checkpoints_map:
    # None for "untrained". It reaches the command line as the literal string
    # "None"; presumably obtain_hook_outputs.py treats that as "no
    # checkpoint" (TODO confirm).
    checkpoint_filename = checkpoints_map[checkpoint_name]
    hook_output_folder = os.path.join(
        "/research/XXXX-1/toponets_hook_outputs_gpt_neo_125m",
        checkpoint_name,
    )

    make_folder_if_does_not_exist(folder=hook_output_folder)

    # Quote each argument so paths containing spaces or shell metacharacters
    # survive os.system; shlex.quote is a no-op for plain paths. str() keeps
    # the original "None" rendering for the untrained run.
    command = (
        "python3 obtain_hook_outputs.py"
        f" --checkpoint-filename {shlex.quote(str(checkpoint_filename))}"
        f" --hook-output-folder {shlex.quote(hook_output_folder)}"
        " --layer-names-json layer_names.json"
    )
    commands.append(command)
    hook_output_folders[checkpoint_name] = hook_output_folder

if run_hook_extraction:
    for command in commands:
        os.system(command)


# ---------------------------------------------------------------------------
# Step 2: compute the effective dimensionality of each checkpoint's hook
# outputs, writing one JSON result file per checkpoint under ``results/``.
# ---------------------------------------------------------------------------

result_folder = "results"
# Create the destination up front, as is done for the hook output folders,
# in case compute_effective_dimensionality.py does not create parent
# directories itself (NOTE(review): confirm whether it does).
make_folder_if_does_not_exist(folder=result_folder)

commands = []

for checkpoint_name in checkpoints_map:
    result_filename = os.path.join(result_folder, f"{checkpoint_name}.json")
    hook_output_folder = hook_output_folders[checkpoint_name]
    # Quote arguments so unusual paths cannot break the shell command.
    command = (
        "python3 compute_effective_dimensionality.py"
        f" --hook-output-folder {shlex.quote(hook_output_folder)}"
        " --layer-names-json layer_names.json"
        f" --result-filename {shlex.quote(result_filename)}"
    )
    commands.append(command)

# Run every command even if an earlier one fails (best-effort batch), but
# record non-zero exit statuses instead of silently discarding them.
failed_commands = []
for command in commands:
    exit_status = os.system(command)
    if exit_status != 0:
        failed_commands.append((exit_status, command))

for exit_status, command in failed_commands:
    print(
        f"[warning] command exited with status {exit_status}: {command}",
        file=sys.stderr,
    )
