import os
import torch
import argparse
import numpy as np
from tqdm import tqdm

from prj_rag import common, constants
from poisonedrag.src.utils import load_beir_datasets, load_models


# Fraction of the corpus that main() encodes. Values below 1 keep only the
# first int(subset_frac * len(corpus)) documents in sorted-key order; any
# value >= 1 encodes the whole corpus.
subset_frac = 1.0


def main(args: dict):
    """Encode a BEIR corpus with a retriever's context encoder and save it.

    The corpus keys are sorted, optionally truncated according to the
    module-level ``subset_frac``, and encoded in batches. The result is a
    dict mapping each corpus key to one row of the encoder output (a NumPy
    array), saved with ``np.save`` as ``corpus_encoded.npy`` inside
    ``<constants.prj_dir>/<dataset>_<ret_model>``.

    Args:
        args: Run configuration. The keys read here are ``seed``,
            ``dataset``, ``ret_model``, ``device`` and ``batch_size``.
    """
    common.setup_seeds(args["seed"])

    corpus, _, _ = load_beir_datasets(args["dataset"], "test")
    len_data = len(corpus)
    print("Total data len: ", len_data)

    # Sorting the keys makes both the subset selection and the encoding
    # order deterministic across runs.
    corpus_keys = sorted(corpus)
    if subset_frac < 1:
        corpus_keys = corpus_keys[: int(subset_frac * len_data)]

    _, ret_model_context, ret_tokenizer, ret_get_enc = load_models(
        args["ret_model"], cache_dir=constants.hf_dir
    )
    device = args["device"]
    batch_size = args["batch_size"]
    ret_model_context = ret_model_context.eval().to(device)

    corpus_encoded = {}

    # Encode the dataset batching over the keys
    for start in tqdm(range(0, len(corpus_keys), batch_size)):
        batch_keys = corpus_keys[start : start + batch_size]

        # Tokenize the whole batch in one call. Every row is padded to
        # max_length, so this yields the same tensors as tokenizing each
        # document separately and concatenating along dim 0.
        tok_out = ret_tokenizer(
            [corpus[k]["text"] for k in batch_keys],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )
        model_inputs = {
            name: tok_out[name].to(device)
            for name in ("input_ids", "attention_mask", "token_type_ids")
        }

        # Inference only: skip autograd bookkeeping so no graph is kept
        # alive for each batch.
        with torch.no_grad():
            enc = ret_get_enc(ret_model_context, model_inputs)
        enc = enc.detach().cpu().numpy()

        for j, k in enumerate(batch_keys):
            corpus_encoded[k] = enc[j]

    # Save the encoded dataset. np.save pickles the dict inside a 0-d object
    # array, so read it back with np.load(..., allow_pickle=True).item().
    out_pth = os.path.join(constants.prj_dir, f"{args['dataset']}_{args['ret_model']}")
    os.makedirs(out_pth, exist_ok=True)
    np.save(os.path.join(out_pth, "corpus_encoded.npy"), corpus_encoded)


if __name__ == "__main__":
    # Command-line entry point: collect the run configuration and hand it to
    # main() as a plain dict.
    cli = argparse.ArgumentParser(description="Encode dataset")
    cli.add_argument(
        "--dataset",
        default="nq",
        choices=["nq", "msmarco", "hotpotqa"],
        type=str,
        help="Dataset name",
    )
    cli.add_argument(
        "--ret_model", default="contriever", type=str, help="Retriever model"
    )
    cli.add_argument("--seed", default=42, type=int, help="Random seed")
    cli.add_argument("--device", default="cuda:0", type=str, help="Device name")
    cli.add_argument("--batch_size", default=32, type=int, help="Batch size")

    main(vars(cli.parse_args()))
