import os

import fire
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import LlamaModel, LlamaTokenizer

from utils.seed import seed_everything


# Taxonomic ranks ordered from broadest to most specific. The prompt for a
# given rank lists every rank above it first.
_TAXONOMY_LEVELS = ("class", "order", "family", "genus", "species")


def _taxonomy_prompt(row, taxonomy_level):
    """Return the ``{class: Aves, order: ..., ...}`` prompt for one taxon row."""
    # NOTE(review): the class is hard-coded to "Aves", including when
    # taxonomy_level == "class" (every class value then gets the same prompt).
    # Kept as-is so existing embeddings stay reproducible -- confirm that the
    # input data is bird-only.
    depth = _TAXONOMY_LEVELS.index(taxonomy_level)
    fields = ["class: Aves"]
    fields.extend(
        f"{rank}: {row[rank]}" for rank in _TAXONOMY_LEVELS[1 : depth + 1]
    )
    return "{" + ", ".join(fields) + "}"


def get_text_embeddings(
    parquet_path: str,
    taxonomy_level: str = "species",
    llm_type: str = "Llama-2-7b-hf",
    device: str = "cpu",
    seed: int = 42,
) -> None:
    """Embed every distinct taxon at ``taxonomy_level`` with a Llama model.

    For each distinct value of the ``taxonomy_level`` column a prompt such as
    ``{class: Aves, order: ..., family: ...}`` is tokenized and passed through
    the model; the model's ``last_hidden_state`` is stored under the taxon's
    name. The resulting dict is written with ``np.save`` to
    ``data/text_embeddings/{taxonomy_level}_{llm_type}.npy`` (relative to the
    current working directory). Because the payload is a pickled dict, read it
    back with ``np.load(path, allow_pickle=True).item()``.

    Args:
        parquet_path: Parquet file holding the observations. It must contain
            the ``taxonomy_level`` column and, below ``class``, a column for
            every rank between ``order`` and ``taxonomy_level``.
        taxonomy_level: One of 'class', 'order', 'family', 'genus', 'species'.
        llm_type: Checkpoint name under the ``meta-llama/`` namespace, e.g.
            'Llama-2-7b-hf', 'Llama-2-13b-hf' or 'Llama-2-70b-hf'.
        device: Device the model is loaded on and the inputs are moved to.
        seed: Seed forwarded to ``seed_everything``.

    Raises:
        ValueError: If ``taxonomy_level`` is not a supported rank.
    """
    # Fail fast, before any file or model is loaded.
    if taxonomy_level not in _TAXONOMY_LEVELS:
        raise ValueError(
            f"Unsupported taxonomy_level {taxonomy_level!r}; "
            f"expected one of {', '.join(_TAXONOMY_LEVELS)}"
        )

    seed_everything(seed)

    # Create the output directory up front so a missing folder cannot throw
    # away the embeddings after they have all been computed.
    output_dir = "data/text_embeddings"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f"{output_dir}/{taxonomy_level}_{llm_type}.npy"

    obs = pd.read_parquet(parquet_path)
    # One row per distinct taxon. Rows without a value at the requested rank
    # are skipped: they cannot be named, and would otherwise abort the run.
    obs = obs.dropna(subset=[taxonomy_level]).drop_duplicates(
        subset=[taxonomy_level]
    )

    # The tokenizer is not placed on a device; only the model takes device_map.
    tokenizer = LlamaTokenizer.from_pretrained(
        f"meta-llama/{llm_type}", legacy=True
    )
    model = LlamaModel.from_pretrained(
        f"meta-llama/{llm_type}", device_map=device
    ).eval()

    embeddings = {}
    # Inference only: disabling autograd avoids building a graph per forward.
    with torch.no_grad():
        for _, row in tqdm(obs.iterrows(), total=len(obs)):
            tokens = tokenizer.encode(_taxonomy_prompt(row, taxonomy_level))
            # unsqueeze(0) adds the batch axis (batch size 1).
            input_ids = torch.LongTensor(tokens).unsqueeze(0).to(device)
            out = model(input_ids)
            embeddings[f"{row[taxonomy_level]}"] = (
                out.last_hidden_state.cpu().numpy()
            )

    np.save(output_path, embeddings, allow_pickle=True)


if __name__ == "__main__":
    # Command-line entry point: Fire exposes the function's parameters as CLI
    # arguments, e.g. `--taxonomy_level=genus --device=cuda`.
    fire.Fire(component=get_text_embeddings)
