import argparse
import os

import esm
import torch
from esm import FastaBatchedDataset
from tqdm import tqdm


def generate_esm_2_embedding(
    dataset: str,
    pretrained_path: str = "models/esm2_t36_3B_UR50D.pt",
    toks_per_batch: int = 1024,
) -> None:
    """Compute ESM-2 embeddings for every sequence in a dataset's FASTA file.

    Reads ``data/raw/<dataset>/<dataset>.fasta`` (``<dataset>_dom.fasta`` for
    the ``"tim"`` dataset), runs each sequence through the ESM-2 checkpoint at
    ``pretrained_path`` and writes one ``<label>.pt`` file per sequence into
    ``data/processed/<dataset>/esm_2_embeddings`` (created if missing).

    Each saved file is a dict with keys:
        ``"label"``: the FASTA record label.
        ``"representations"``: ``{layer_index: per-residue tensor}``.
        ``"mean_representations"``: ``{layer_index: residue-mean tensor}``.
    Only the final layer is extracted.

    Args:
        dataset: Dataset name used to build the input and output paths.
        pretrained_path: Local path to the ESM-2 checkpoint file.
        toks_per_batch: Token budget per batch handed to
            ``FastaBatchedDataset.get_batch_indices``.
    """
    # Dataset-specific paths.
    fasta_path = f"data/raw/{dataset}/{dataset}.fasta"
    if dataset == "tim":
        fasta_path = f"data/raw/{dataset}/{dataset}_dom.fasta"
    output_dir = f"data/processed/{dataset}/esm_2_embeddings"

    # Read the FASTA before loading the (multi-GB) model so a bad path or
    # dataset name fails immediately instead of after the model load.
    fasta_dataset = FastaBatchedDataset.from_file(fasta_path)
    print(f"Read {fasta_path} with {len(fasta_dataset)} sequences")
    # torch.save does not create parent directories.
    os.makedirs(output_dir, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    model, alphabet = esm.pretrained.load_model_and_alphabet_local(pretrained_path)
    print("Successfully loaded model.")
    model.eval()
    if use_cuda:
        model = model.cuda()
        print("Transferred model to GPU")

    # Only the final representation is of interest. With num_layers + 1 valid
    # layer indices, the last one is num_layers (i.e. -1 wrapped around).
    repr_layers = [model.num_layers]

    batches = fasta_dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        fasta_dataset,
        collate_fn=alphabet.get_batch_converter(),
        batch_sampler=batches,
    )

    with torch.no_grad():
        for labels, strs, toks in tqdm(data_loader, total=len(batches)):
            if use_cuda:
                toks = toks.to(device="cuda", non_blocking=True)
            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                seq_len = len(strs[i])
                # Positions 1..seq_len hold the residues: position 0 is skipped
                # (presumably the prepended start token) and anything after
                # seq_len is end token / padding from longer batch members.
                residue_reprs = {
                    layer: t[i, 1 : seq_len + 1] for layer, t in representations.items()
                }
                # .clone() detaches each slice from the full batch tensor's
                # storage; torch.save would otherwise serialize the whole batch.
                result = {
                    "label": label,
                    "representations": {
                        layer: r.clone() for layer, r in residue_reprs.items()
                    },
                    "mean_representations": {
                        layer: r.mean(0).clone() for layer, r in residue_reprs.items()
                    },
                }
                # NOTE(review): the label is used verbatim as a file name; a
                # header containing "/" would point outside output_dir.
                torch.save(result, os.path.join(output_dir, f"{label}.pt"))


def main(dataset: str) -> None:
    """Script entry point: build ESM-2 embeddings for the named *dataset*."""
    generate_esm_2_embedding(dataset)


if __name__ == "__main__":
    # Single positional argument: the dataset name that selects the FASTA
    # input and the embedding output directory.
    cli = argparse.ArgumentParser()
    cli.add_argument("dataset", type=str)
    main(cli.parse_args().dataset)
