import os
import json
import argparse
import torch
import yaml
from tqdm import tqdm

from models.multimodal_encoder.t5_encoder import T5Embedder


def encode_lang(
    DATA_FILE_PATH,
    TARGET_DIR,
    GPU,
    desc_type="seen",
    tokenizer=None,
    text_encoder=None,
):
    current_dir = os.path.dirname(__file__)

    with open(os.path.join(current_dir, "../configs/base.yaml"), "r") as fp:
        config = yaml.safe_load(fp)

    device = torch.device(f"cuda:{GPU}")
    if tokenizer is None or text_encoder is None:
        text_embedder = T5Embedder(
            from_pretrained=os.path.join(current_dir, "../../weights/RDT/t5-v1_1-xxl"),
            model_max_length=config["dataset"]["tokenizer_max_length"],
            device=device,
            use_offload_folder=None,
        )
        tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model

    with open(DATA_FILE_PATH, "r") as f_instr:
        instruction_dict = json.load(f_instr)

    instructions = instruction_dict[desc_type]

    # Encode the instructions
    tokenized_res = tokenizer(instructions, return_tensors="pt", padding="longest", truncation=True)
    tokens = tokenized_res["input_ids"].to(device)
    attn_mask = tokenized_res["attention_mask"].to(device)

    with torch.no_grad():
        text_embeds = (text_encoder(input_ids=tokens, attention_mask=attn_mask)["last_hidden_state"].detach().cpu())

    attn_mask = attn_mask.cpu().bool()
    if not os.path.exists(f"{TARGET_DIR}/instructions"):
        os.makedirs(f"{TARGET_DIR}/instructions")
    # Save the embeddings for training use
    for i in range(len(instructions)):
        text_embed = text_embeds[i][attn_mask[i]]
        save_path = os.path.join(TARGET_DIR, f"instructions/lang_embed_{i}.pt")
        # print("encoded instructions save_path:",save_path)
        torch.save(text_embed, save_path)

    return tokenizer, text_encoder
