import torch
from sentence_transformers import SentenceTransformer

import utils.const as C
import utils.datasets as UD
import utils.helpers as UH
import utils.ops as UO


def get_encoder():
    """Load the sentence encoder configured in ``utils.const``.

    The model named by ``C.ENCODER_NAME`` is instantiated on ``C.DEVICE`` and
    switched to evaluation mode before being returned.

    Returns:
        The ``SentenceTransformer`` instance, in eval mode.
    """
    # ``device`` has to be passed by keyword: the second positional parameter
    # of ``SentenceTransformer`` is ``modules``, so a positional device would
    # not be used as the target device.
    encoder = SentenceTransformer(C.ENCODER_NAME, device=C.DEVICE)
    encoder.eval()
    return encoder

def get_embeddings(qrels, query_texts, doc_texts, name):
    """Compute (or load from cache) all embedding artifacts for one dataset.

    Queries are grouped with ``UD.get_qrels_by_len``. For every group the
    following artifacts are obtained through ``UH.load_or_compute`` using the
    cache directory ``cache/<encoder>/<kind>/<group key>`` and the file name
    ``<name>.pt``:

    * ``q_embs``       - sentence embeddings of the group's queries,
    * ``doc_embs``     - sentence embeddings of the group's documents,
    * ``tl_embs``      - token-level embeddings of the group's queries,
    * ``centroids``    - ``UO.get_centroids`` of the document embeddings,
    * ``score_fields`` - ``UO.generate_score_fields`` of the random points
      against the document embeddings.

    Args:
        qrels: Mapping from query id to an iterable of document ids.
        query_texts: Mapping from query id to query text.
        doc_texts: Mapping from document id to document text.
        name: Base name (without extension) of every cache file.

    Returns:
        None. The results are only handed to ``UH.load_or_compute``.
    """
    qrels_by_len = UD.get_qrels_by_len(qrels)

    # Use the part after the organisation prefix ("org/model" -> "model").
    # A name without "/" is used as-is instead of raising IndexError.
    name_parts = C.ENCODER_NAME.split("/")
    encoder_name = name_parts[1] if len(name_parts) > 1 else name_parts[0]

    encoder = get_encoder()
    dim = encoder.get_sentence_embedding_dimension()

    points = UH.generate_random_points(dim)
    points = UH.rescale_random_points(points, 0.2)

    def cached(kind, group_key, compute_fn):
        # Single place that defines the cache layout for every artifact kind.
        return UH.load_or_compute(
            cache_dir=f"cache/{encoder_name}/{kind}/{group_key}",
            filename=f"{name}.pt",
            compute_fn=compute_fn,
        )

    # The group key is passed to get_document_embeddings as ``num_docs``, so
    # it is presumably the number of relevant documents per query - confirm
    # against UD.get_qrels_by_len.
    for doc_len, relevant_qids in qrels_by_len.items():
        relevant_query_texts = [query_texts[qid] for qid in relevant_qids]
        relevant_doc_texts = [
            [doc_texts[did] for did in qrels[qid]] for qid in relevant_qids
        ]
        flat_doc_texts = [doc for docs in relevant_doc_texts for doc in docs]

        # NOTE(review): the lambdas below close over loop variables. This
        # relies on UH.load_or_compute calling ``compute_fn`` before it
        # returns - confirm if that helper ever becomes lazy.
        cached(
            "q_embs",
            doc_len,
            lambda: get_query_embeddings(encoder, relevant_query_texts),
        )

        document_embeddings = cached(
            "doc_embs",
            doc_len,
            lambda: get_document_embeddings(encoder, flat_doc_texts, doc_len),
        )

        cached(
            "tl_embs",
            doc_len,
            lambda: get_token_level_embeddings(encoder, relevant_query_texts),
        )

        cached(
            "centroids",
            doc_len,
            lambda: UO.get_centroids(document_embeddings),
        )

        cached(
            "score_fields",
            doc_len,
            lambda: UO.generate_score_fields(points, document_embeddings),
        )


def get_query_embeddings(encoder, query_texts):
    """Encode ``query_texts`` with ``encoder.encode`` and return the result.

    The encoder is asked for tensor output with normalized embeddings, no
    progress bar, and computation on ``C.DEVICE``.

    Args:
        encoder: Object exposing a SentenceTransformer-style ``encode`` method.
        query_texts: The query strings to embed.

    Returns:
        Whatever ``encoder.encode`` returns for these options.
    """
    encode_options = {
        "convert_to_tensor": True,
        "normalize_embeddings": True,
        "show_progress_bar": False,
        "device": C.DEVICE,
    }
    return encoder.encode(query_texts, **encode_options)

def get_document_embeddings(encoder, doc_texts, num_docs):
    """Encode a flat list of documents and group the embeddings.

    The flat encoder output is reshaped to ``(-1, num_docs, dim)``, where
    ``dim`` is the size of the output's last axis.

    Args:
        encoder: Object exposing a SentenceTransformer-style ``encode`` method.
        doc_texts: Flat sequence of document strings. The reshape requires
            its length to be a multiple of ``num_docs``.
        num_docs: Size of the second axis of the returned tensor.

    Returns:
        The reshaped embedding tensor.
    """
    flat_embeddings = encoder.encode(
        doc_texts,
        convert_to_tensor=True,
        normalize_embeddings=True,
        show_progress_bar=False,
        device=C.DEVICE,
    )
    embedding_dim = flat_embeddings.shape[-1]
    return flat_embeddings.reshape(-1, num_docs, embedding_dim)

def get_token_level_embeddings(encoder, relevant_query_texts):
    """Return L2-normalized per-token embeddings for a batch of queries.

    The queries are tokenized with the encoder's tokenizer (padded and
    truncated) and fed directly to ``encoder[0].auto_model``. Each token
    vector of ``last_hidden_state`` is normalized along the last axis, and
    positions whose attention mask is 0 are multiplied to zero.

    Args:
        encoder: SentenceTransformer-style model whose first module exposes
            ``auto_model`` and which has a ``tokenizer`` attribute.
        relevant_query_texts: The query strings to embed.

    Returns:
        The masked, normalized token embeddings. Presumably shaped
        ``(num_queries, seq_len, hidden_dim)`` - confirm against the model.
    """
    backbone = encoder[0].auto_model
    batch = encoder.tokenizer(
        relevant_query_texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(C.DEVICE)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        token_states = backbone(
            **batch, output_hidden_states=False
        ).last_hidden_state

    # Add a trailing axis so the mask broadcasts over the embedding axis.
    token_mask = batch["attention_mask"].unsqueeze(-1)
    unit_states = torch.nn.functional.normalize(token_states, p=2, dim=-1)
    return unit_states * token_mask
