import argparse
import os

import esm
import esm.inverse_folding
import torch
from tqdm import tqdm


def generate_esm_if1b_embedding(dataset):
    """Compute and save ESM-IF1 encoder outputs for every structure of a dataset.

    Every entry of the dataset's PDB directory is loaded (chain "A" only),
    its coordinates are passed through the inverse-folding model's encoder,
    and the resulting tensor is written to
    ``data/processed/<dataset>/esm_if1_embeddings/<name>.pt``, where
    ``<name>`` is the entry's file name with its last 29 characters removed.

    Args:
        dataset: Name of the dataset directory under ``data/raw`` and
            ``data/processed``. The value ``"tim"`` reads structures from
            the ``pdb_dom`` sub-directory instead of ``pdb``.

    Raises:
        FileNotFoundError: If the dataset's PDB directory does not exist.
    """
    # Define paths
    pretrained_path = "models/esm_if1_gvp4_t16_142M_UR50.pt"
    output_dir = f"data/processed/{dataset}/esm_if1_embeddings"
    pdb_subdir = "pdb_dom" if dataset == "tim" else "pdb"
    pdb_dir = f"data/raw/{dataset}/{pdb_subdir}"

    # Extract sequence names; sorted so runs process files in a stable order.
    pdb_paths = sorted(os.listdir(pdb_dir))

    # torch.save does not create missing parent directories, so make sure the
    # destination exists before any (expensive) embedding is computed.
    os.makedirs(output_dir, exist_ok=True)

    # Load model
    model, alphabet = esm.pretrained.load_model_and_alphabet_local(pretrained_path)
    print("Successfully loaded model.")
    model.eval()

    with torch.no_grad():
        for path in tqdm(pdb_paths):
            structure = esm.inverse_folding.util.load_structure(
                os.path.join(pdb_dir, path), chain="A"
            )
            # The native sequence is returned alongside the coordinates but is
            # not needed for the encoder output.
            coords, _native_seq = (
                esm.inverse_folding.util.extract_coords_from_structure(structure)
            )
            rep = esm.inverse_folding.util.get_encoder_output(model, alphabet, coords)
            # NOTE(review): the slice drops the last 29 characters of the file
            # name -- presumably a fixed-length suffix shared by all input
            # files; confirm against the dataset's naming scheme. Names of 29
            # characters or fewer would all collapse to ".pt".
            stem = path[:-29]
            torch.save(rep.cpu().detach(), os.path.join(output_dir, f"{stem}.pt"))


def main(dataset):
    """Run ESM-IF1 embedding generation for the dataset named ``dataset``."""
    generate_esm_if1b_embedding(dataset=dataset)


if __name__ == "__main__":
    # Command-line entry point: a single positional argument names the
    # dataset whose structures should be embedded.
    cli = argparse.ArgumentParser()
    cli.add_argument("dataset", type=str)
    cli_args = cli.parse_args()
    main(cli_args.dataset)
