import torch
import os
from .moving_checkpoints import get_run_name


def load_and_filter_state_dict_keys(checkpoint_filename: str):
    """
    Load model weights from a checkpoint file onto the CPU.

    torch lightning module saves not only the model state dict but also other stuff.
    When loading the best model auto-saved by torch lightning, only the
    'state_dict' item of the loaded object holds the model weights, and its keys
    carry a leading "model." prefix that is stripped here.
    Manually saved checkpoints (no 'state_dict' item) are returned as-is.

    Args:
        checkpoint_filename: path to a file readable by ``torch.load``.

    Returns:
        A mapping from parameter name to the stored value. For lightning
        checkpoints this is a new dict whose keys have the leading "model."
        removed; otherwise it is the loaded object itself, unchanged.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # torch version / weights_only setting) -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_filename, map_location="cpu")

    if "state_dict" not in checkpoint:
        # Manually saved checkpoint: the file already is the state dict.
        return checkpoint

    prefix = "model."
    # Strip the prefix only where it leads the key. A blanket
    # str.replace would also mangle keys that merely contain "model."
    # further in (e.g. "model.sub_model.bias" -> "sub_bias").
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }

def get_checkpoint_path_gpt_neo_125m(
    checkpoints_dir,
    topo_scale: int, ## 0 = baseline
    global_step,
    filename = "pytorch_model.bin"
):
    """
    Build the path to a saved checkpoint file and verify that it exists.

    The path has the form
    ``<checkpoints_dir>/<run_name>/checkpoint-<global_step>/<filename>``,
    where ``run_name`` is whatever ``get_run_name(topo_scale=...)`` returns.

    Args:
        checkpoints_dir: root directory that contains the per-run folders.
        topo_scale: passed through to ``get_run_name``; 0 = baseline.
        global_step: step identifier used in the ``checkpoint-<global_step>``
            folder name.
        filename: name of the file inside the checkpoint folder.

    Returns:
        The joined checkpoint path.

    Raises:
        AssertionError: if nothing exists at the resulting path.
    """
    run_name = get_run_name(
        topo_scale=topo_scale,
    )
    checkpoint_path = os.path.join(
        checkpoints_dir,
        run_name,
        f"checkpoint-{global_step}",
        filename
    )
    # Explicit raise instead of a bare `assert`: assert statements are removed
    # when Python runs with -O, which would let a bad path through silently.
    # AssertionError is kept so existing callers see the same exception type.
    if not os.path.exists(checkpoint_path):
        raise AssertionError(f'Invalid checkpoint_path: {checkpoint_path}')
    return checkpoint_path