import os
from pathlib import Path


def get_checkpoint_iteration(loading_config, state_dict_name: str):
    """Resolve which checkpoint iteration to use for ``state_dict_name``.

    Checkpoint files are expected at
    ``<checkpoint_path>/<state_dict_name>_<iteration>.pt``. The iteration
    requested in ``loading_config`` is returned unchanged when that file
    exists. Otherwise the directory is scanned for other checkpoints of the
    *same* state dict, preferring ``"final"`` and falling back to the highest
    numeric iteration.

    Args:
        loading_config: Object exposing ``checkpoint_path`` (directory that
            holds the checkpoints) and ``checkpoint_iteration`` (requested
            iteration; ``None`` is treated as ``"final"``).
        state_dict_name: File-name prefix identifying the state dict.

    Returns:
        The requested iteration if its file exists, else ``"final"`` or the
        largest available iteration as an ``int``.

    Raises:
        FileNotFoundError: If no usable checkpoint for ``state_dict_name``
            exists (or the checkpoint directory itself is missing).
    """
    checkpoint_path = Path(loading_config.checkpoint_path)
    requested = loading_config.checkpoint_iteration
    # A missing iteration means "use the final checkpoint".
    epoch = "final" if requested is None else requested

    # Fast path: the requested checkpoint is present.
    if (checkpoint_path / f"{state_dict_name}_{epoch}.pt").exists():
        return epoch

    print(f"No checkpoint '{loading_config.checkpoint_iteration}' found for state_dict {state_dict_name}. Try using the last checkpoint instead.")

    # Search only the files belonging to the requested state dict; a
    # hard-coded prefix would report iterations of an unrelated state dict.
    prefix = f"{state_dict_name}_"
    suffix = ".pt"
    available = [
        file_name[len(prefix):-len(suffix)]
        for file_name in os.listdir(checkpoint_path)
        if file_name.startswith(prefix) and file_name.endswith(suffix)
    ]

    if "final" in available:
        epoch = "final"
    else:
        # isdecimal() only admits characters that int() can parse.
        iterations = [int(tag) for tag in available if tag.isdecimal()]
        if not iterations:
            raise FileNotFoundError(f"No checkpoint at all found for state_dict {state_dict_name}.")
        epoch = max(iterations)
    print(f"Using iteration {epoch} instead as checkpoint.")

    return epoch

