# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

import argparse
import csv
import logging
import pickle

import numpy as np
import torch

import transformers

import source.controller.retriever.slurm
import source.controller.retriever.contriever
import source.controller.retriever.utils
import source.controller.retriever.data
import source.controller.retriever.normalize_text

from source.controller.retriever.contriever import Contriever
from transformers import AutoTokenizer

def _iter_batches(items, batch_size):
    """Yield consecutive lists of at most ``batch_size`` items from ``items``."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        # Trailing partial batch.
        yield batch


def _encode_batch(batch_text, model, tokenizer, passage_maxlength):
    """Tokenize one batch of texts, run ``model`` on it, and return the output on CPU."""
    encoded_batch = tokenizer.batch_encode_plus(
        batch_text,
        return_tensors="pt",
        max_length=passage_maxlength,
        padding=True,
        truncation=True,
    )
    if torch.cuda.is_available():
        # Assumes the model has been moved to CUDA whenever a GPU is available.
        encoded_batch = {key: value.cuda() for key, value in encoded_batch.items()}
    else:
        encoded_batch = dict(encoded_batch)
    return model(**encoded_batch).cpu()


def embed_passages(passages, model, tokenizer, passage_maxlength=300, per_gpu_batch_size=1236/10):
    """Encode passages into dense embeddings.

    Each passage's ``"text"`` is converted to ``str``, lowercased and passed
    through ``normalize_text.normalize`` before tokenization. Passages are
    encoded in batches under ``torch.no_grad()``.

    Args:
        passages: Iterable of mappings with ``"id"`` and ``"text"`` keys.
        model: Callable taking the tokenizer's tensors as keyword arguments
            and returning a tensor whose first dimension is the batch
            (presumably ``(batch, dim)`` -- TODO confirm).
        tokenizer: Tokenizer exposing ``batch_encode_plus``.
        passage_maxlength: Maximum number of tokens per passage; longer
            passages are truncated.
        per_gpu_batch_size: Number of passages per forward pass. Non-integer
            values are truncated to ``int`` (the default ``1236/10`` becomes
            123). Values below 1 are treated as 1.

    Returns:
        ``(allids, allembeddings)``. ``allids`` is the list of passage ids in
        input order. ``allembeddings`` is a numpy array of the per-batch
        model outputs concatenated along dim 0. For empty input,
        ``allembeddings`` is an empty float32 array of shape ``(0, 0)``.
    """
    # A float batch size never compares equal to an integer list length, so
    # the original equality check never fired. Coerce it to a usable int.
    batch_size = max(1, int(per_gpu_batch_size))

    total = 0
    allids, allembeddings = [], []
    with torch.no_grad():
        for batch in _iter_batches(passages, batch_size):
            batch_ids = [p["id"] for p in batch]
            batch_text = [
                source.controller.retriever.normalize_text.normalize(str(p["text"]).lower())
                for p in batch
            ]
            allembeddings.append(_encode_batch(batch_text, model, tokenizer, passage_maxlength))
            allids.extend(batch_ids)
            total += len(batch_ids)
            print(f"Encoded passages {total}")

    if not allembeddings:
        # torch.cat raises on an empty list.
        return allids, np.empty((0, 0), dtype=np.float32)
    return allids, torch.cat(allembeddings, dim=0).numpy()


def main(args):
    """Embed one shard of a passage file and pickle the result to disk.

    Args:
        args: Parsed command-line namespace. Reads ``model_name_or_path``,
            ``no_fp16``, ``passages``, ``num_shards``, ``shard_id``,
            ``passage_maxlength``, ``per_gpu_batch_size``, ``output_dir``
            and ``prefix``.

    Raises:
        ValueError: If ``num_shards < 1`` or ``shard_id`` is not in
            ``[0, num_shards)``.

    Writes the tuple ``(allids, allembeddings)`` with pickle to
    ``<output_dir>/<prefix>_<shard_id:02d>``.
    """
    if args.num_shards < 1 or not 0 <= args.shard_id < args.num_shards:
        raise ValueError(
            f"Invalid sharding: shard_id={args.shard_id}, num_shards={args.num_shards}"
        )

    model, tokenizer, _ = source.controller.retriever.contriever.load_retriever(args.model_name_or_path)
    print(f"Model loaded from {args.model_name_or_path}.", flush=True)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
        # Half precision is only applied on GPU; fp16 CPU inference is
        # poorly supported.
        if not args.no_fp16:
            model = model.half()

    passages = source.controller.retriever.data.load_passages(args.passages)

    # Contiguous shards; the last shard absorbs the remainder.
    shard_size = len(passages) // args.num_shards
    start_idx = args.shard_id * shard_size
    end_idx = start_idx + shard_size
    if args.shard_id == args.num_shards - 1:
        end_idx = len(passages)

    passages = passages[start_idx:end_idx]
    print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")

    # Pass arguments by keyword: `args` is not a parameter of embed_passages.
    allids, allembeddings = embed_passages(
        passages,
        model,
        tokenizer,
        passage_maxlength=args.passage_maxlength,
        per_gpu_batch_size=args.per_gpu_batch_size,
    )

    save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}")
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Saving {len(allids)} passage embeddings to {save_file}.")
    with open(save_file, mode="wb") as f:
        pickle.dump((allids, allembeddings), f)

    print(f"Total passages processed {len(allids)}. Written to {save_file}.")
    
    

def main_modified(model_name_or_path, passages, num_shards=1, shard_id=0,
                  passage_maxlength=300, per_gpu_batch_size=128):
    """Embed one shard of an in-memory list of passages and return the result.

    Args:
        model_name_or_path: Checkpoint name or path passed to
            ``Contriever.from_pretrained`` and ``AutoTokenizer.from_pretrained``.
            Falls back to ``"facebook/contriever"`` when empty or ``None``.
        passages: Sliceable sequence of mappings with ``"id"`` and ``"text"`` keys.
        num_shards: Total number of contiguous shards to split ``passages`` into.
        shard_id: Index of the shard to embed; the last shard absorbs the remainder.
        passage_maxlength: Maximum number of tokens per passage.
        per_gpu_batch_size: Number of passages per forward pass.

    Returns:
        ``(allids, allembeddings)`` as produced by ``embed_passages``.

    Raises:
        ValueError: If ``num_shards < 1`` or ``shard_id`` is not in ``[0, num_shards)``.
    """
    if num_shards < 1 or not 0 <= shard_id < num_shards:
        raise ValueError(f"Invalid sharding: shard_id={shard_id}, num_shards={num_shards}")

    # Previously the checkpoint was hard-coded and model_name_or_path was only printed.
    checkpoint = model_name_or_path or "facebook/contriever"
    model = Contriever.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    print(f"Model loaded from {checkpoint}.", flush=True)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    shard_size = len(passages) // num_shards
    start_idx = shard_id * shard_size
    end_idx = start_idx + shard_size
    if shard_id == num_shards - 1:
        end_idx = len(passages)

    passages = passages[start_idx:end_idx]
    print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")

    # Pass an explicit integer batch size so batching does not depend on
    # embed_passages' default.
    allids, allembeddings = embed_passages(
        passages,
        model,
        tokenizer,
        passage_maxlength=passage_maxlength,
        per_gpu_batch_size=per_gpu_batch_size,
    )
    return allids, allembeddings

# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()

#     parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)")
#     parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings")
#     parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings")
#     parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard")
#     parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards")
#     parser.add_argument(
#         "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
#     )
#     parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage")
#     parser.add_argument(
#         "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
#     )
#     parser.add_argument("--no_fp16", action="store_true", help="inference in fp32")
#     parser.add_argument("--no_title", action="store_true", help="title not added to the passage body")
#     parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding")
#     parser.add_argument("--normalize_text", action="store_true", help="normalize text before encoding")

#     args = parser.parse_args()

#     source.controller.retriever.slurm.init_distributed_mode(args)

#     main(args)
