import os
import torch
from sentence_transformers import SentenceTransformer
from transformers import LlamaForCausalLM, LlamaTokenizer, AutoTokenizer, AutoModel
from tqdm.autonotebook import trange

# Output embedding width for each backend name accepted by SentenceEncoder.
# "ST" and "roberta" are 768-dim sentence-transformers models, "e5" is
# e5-large-v2, and the llama entries are the hidden sizes of Llama-2 7B / 13B.
ENCODER_DIM_DICT = {"ST": 768, "e5": 1024, "llama2_7b": 4096, "llama2_13b": 5120, "roberta": 768}


class SentenceEncoder:
    """Wrap one of several pretrained text encoders behind a common ``encode`` method.

    ``encode`` is bound at construction time to the backend-specific method
    (``ST_encode``, ``llama_encode`` or ``e5_encode``) selected by ``name``.

    Args:
        name: Backend key: ``"ST"``, ``"llama2_7b"``, ``"llama2_13b"``, ``"e5"``
            or ``"roberta"``.
        batch_size: Batch size used by the ST/roberta and e5 backends (the
            llama backend tokenizes all texts in a single call).
        root: For ``"ST"`` and ``"llama2_7b"``, the directory that contains the
            local model folder; for the other backends, the download cache
            directory passed to ``from_pretrained`` / ``SentenceTransformer``.
        multi_gpu: If True, ``ST_encode`` encodes through a sentence-transformers
            multi-process pool.

    Raises:
        ValueError: If ``name`` is not a recognised backend.
    """

    def __init__(self, name, batch_size, root, multi_gpu=False):
        self.name = name
        self.batch_size = batch_size
        self.multi_gpu = multi_gpu

        if name == "ST":
            self.model = SentenceTransformer(os.path.join(root, "multi-qa-distilbert-cos-v1"))
            self.encode = self.ST_encode

        elif name == "llama2_7b":
            self.model, self.tokenizer = self._load_llama(os.path.join(root, "Llama-2-7b-hf"))
            self.encode = self.llama_encode

        elif name == "llama2_13b":
            self.model, self.tokenizer = self._load_llama("meta-llama/Llama-2-13b-hf", cache_dir=root)
            self.encode = self.llama_encode

        elif name == "e5":
            model_name = "intfloat/e5-large-v2"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=root)
            self.model = AutoModel.from_pretrained(model_name, cache_dir=root)
            self.encode = self.e5_encode

        elif name == "roberta":
            self.model = SentenceTransformer(
                "sentence-transformers/roberta-base-nli-stsb-mean-tokens", cache_folder=root
            )
            self.encode = self.ST_encode

        else:
            raise ValueError(f"Unknown language model: {name}.")

    @staticmethod
    def _load_llama(name_or_path, **kwargs):
        """Load a Llama causal LM and its tokenizer, with right-side padding and truncation."""
        model = LlamaForCausalLM.from_pretrained(name_or_path, **kwargs)
        tokenizer = LlamaTokenizer.from_pretrained(name_or_path, **kwargs)
        tokenizer.padding_side = "right"
        tokenizer.truncation_side = "right"
        return model, tokenizer

    def ST_encode(self, texts, to_tensor=True):
        """Encode ``texts`` with the sentence-transformers model.

        Returns one embedding row per input text: a ``torch.Tensor`` if
        ``to_tensor`` is True, otherwise a numpy array.
        """
        if self.multi_gpu:
            # Start the multi-process pool on all available CUDA devices, and
            # always stop it afterwards so worker processes are not leaked,
            # even when encoding raises.
            pool = self.model.start_multi_process_pool()
            try:
                embeddings = self.model.encode_multi_process(texts, pool=pool, batch_size=self.batch_size)
            finally:
                self.model.stop_multi_process_pool(pool)
            # encode_multi_process yields a numpy array; honour to_tensor the
            # same way the single-process path does.
            return torch.from_numpy(embeddings) if to_tensor else embeddings

        return self.model.encode(texts, batch_size=self.batch_size, show_progress_bar=True,
                                 convert_to_tensor=to_tensor, convert_to_numpy=not to_tensor)

    def llama_encode(self, texts, to_tensor=True, max_length=500):
        """Look up Llama input-embedding vectors for ``texts``.

        These are the outputs of the model's input embedding layer, not
        contextual hidden states: no transformer layers are run. All texts are
        tokenized in one call, right-padded to the longest sequence and
        truncated to ``max_length`` tokens.

        Returns:
            ``(word_embeddings, mask)``: the token embeddings and the
            tokenizer's attention mask. Both are ``torch.Tensor`` objects if
            ``to_tensor`` is True, otherwise numpy arrays.
        """
        # Use EOS as the pad token so that padding="longest" can be applied.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        tokens = self.tokenizer(texts, return_tensors="pt", padding="longest",
                                truncation=True, max_length=max_length)
        mask = tokens.attention_mask

        # Only an embedding lookup is needed. no_grad keeps the result free of
        # an autograd graph; otherwise .numpy() below raises.
        with torch.no_grad():
            word_embeddings = self.model.get_input_embeddings()(tokens.input_ids)

        if not to_tensor:
            return word_embeddings.numpy(), mask.numpy()
        return word_embeddings, mask

    @staticmethod
    def _average_pool(last_hidden_states, attention_mask):
        """Average hidden states over the sequence axis, ignoring masked-out positions.

        ``last_hidden_states`` is indexed as (batch, seq, hidden) and
        ``attention_mask`` as (batch, seq). Positions with mask 0 are zeroed
        before summing and excluded from the divisor.
        """
        masked = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return masked.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def e5_encode(self, texts, to_tensor=True):
        """Encode ``texts`` with E5 using mask-aware mean pooling of the last hidden state.

        Texts are processed in batches of ``self.batch_size``. Each batch is
        moved to the model's device, and the pooled embeddings are copied back
        to CPU.

        Returns:
            A (len(texts), hidden_size) ``torch.Tensor`` if ``to_tensor`` is
            True, otherwise a numpy array. Empty input yields a (0, hidden_size)
            result.
        """
        device = self.model.device
        pooled_batches = []
        with torch.no_grad():
            for start in trange(0, len(texts), self.batch_size, desc="Batches", disable=False):
                chunk = texts[start:start + self.batch_size]
                batch = self.tokenizer(chunk, padding="longest", truncation=True, return_tensors="pt")
                batch = {key: value.to(device) for key, value in batch.items()}
                outputs = self.model(**batch)
                pooled = self._average_pool(outputs.last_hidden_state, batch["attention_mask"])
                pooled_batches.append(pooled.cpu())

        if pooled_batches:
            all_embeddings = torch.cat(pooled_batches, dim=0)
        else:
            # torch.cat rejects an empty list, so build an empty matrix of the right width.
            all_embeddings = torch.empty(0, self.model.config.hidden_size)

        return all_embeddings if to_tensor else all_embeddings.numpy()
