import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from itertools import product
from typing import Union, List


def __seq_id_similarity_tok(seq_a: torch.Tensor, seq_b: torch.Tensor) -> torch.Tensor:
    """Return 1.0 if two token sequences are identical, else 0.0, as a 0-d float tensor.

    Sequences are expected to be padded consistently: tensors of different shape are
    never considered identical.
    """
    if seq_a.shape != seq_b.shape:
        # `==` would otherwise broadcast (e.g. [1, 1, 1] vs [1] -> all True, a false match)
        # or raise on incompatible shapes; differently shaped sequences are simply different.
        return torch.zeros((), dtype=torch.float, device=seq_a.device)
    return torch.all(seq_a == seq_b).to(dtype=torch.float)

def __seq_id_similarity_txt(seq_a: str, seq_b: str) -> float:
    """Exact-match similarity for two strings: 1.0 when equal, 0.0 otherwise."""
    identical = seq_a == seq_b
    return float(identical)

def compute_seq_id_similarity(sequences: Union[torch.Tensor, List[str]]) -> torch.FloatTensor:
    """Return the pairwise sequence-identity matrix for ``sequences``.

    Entry (i, j) is 1.0 when ``sequences[i]`` and ``sequences[j]`` are identical and
    0.0 otherwise.

    Args:
        sequences: either a tensor whose first dimension indexes the sequences (e.g. padded
            token ids; each slice ``sequences[i]`` is compared element-wise), or a list of
            strings compared with ``==``.

    Returns:
        A float32 tensor of shape (n, n) on CPU, where n = len(sequences).
    """
    # len() first: keeps the original TypeError for 0-d tensors.
    n_seqs = len(sequences)
    if isinstance(sequences, torch.Tensor):
        # One row per sequence: a 1-D tensor is n length-1 sequences; trailing dims are flattened.
        rows = sequences.unsqueeze(1) if sequences.dim() == 1 else sequences.flatten(start_dim=1)
        # Broadcast compare (n, 1, L) == (1, n, L) -> (n, n, L); a pair is identical only if
        # every position agrees. Uses O(n^2 * L) booleans of temporary memory.
        identical = (rows.unsqueeze(1) == rows.unsqueeze(0)).all(dim=-1)
        return identical.to(dtype=torch.float).cpu()
    # reshape keeps the (0, 0) shape for empty input (torch.tensor([]) would be 1-D).
    return torch.tensor(
        [[float(a == b) for b in sequences] for a in sequences],
        dtype=torch.float,
    ).reshape(n_seqs, n_seqs)


@torch.no_grad()
def compute_similarity_matrix(
    outputs,
    model="all-MiniLM-L6-v2",
    return_daisy_chain=False,
    **kwargs
):
    """Return the pairwise embedding-similarity matrix of ``outputs`` with itself.

    Args:
        outputs: sentences to compare pairwise.
        model: a ``SentenceTransformer`` instance, or a model name that is loaded with
            ``SentenceTransformer(model, **kwargs)``.
        return_daisy_chain: if True, also return ``{'model': model, **kwargs}`` so the
            already-loaded model can be passed back in on the next call.
        **kwargs: forwarded to the ``SentenceTransformer`` constructor when ``model`` is a name.

    Returns:
        A symmetric (len(outputs), len(outputs)) tensor on CPU, computed with the model's
        configured similarity function. When ``return_daisy_chain`` is True, the return value
        is the tuple ``(similarities, daisy_chain_dict)``.
    """
    if isinstance(model, str):
        model = SentenceTransformer(model, **kwargs)
    assert isinstance(model, SentenceTransformer), f"Provided model is not a SentenceTransformer: {type(model)}"

    # Outputs are compared against themselves, so a single encoding pass is enough.
    embeddings = model.encode(outputs)

    # `model.similarity(a, b)` takes only the two embedding collections and already returns
    # a tensor, so no conversion flag is needed.
    similarities = model.similarity(embeddings, embeddings).cpu()
    # Symmetrize in case of tiny numerical asymmetries.
    similarities = (similarities + similarities.T) / 2

    if return_daisy_chain:
        return similarities, {'model': model, **kwargs}
    return similarities

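# Example: pass the returned daisy-chain dict back in so the already-loaded model is reused
# instead of being re-instantiated (`records` stands for the caller's own data):
# dc = {'model':"all-MiniLM-L6-v2", 'device':'cuda'}
# sim, dc = compute_similarity_matrix(
#     list(records[0]['txt_y']),
#     return_daisy_chain=True,
#     **dc
# )
# sim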



@torch.no_grad()
def compute_similarity_matrix_nli(
    outputs,
    model,
    context='',
    entailment_idx=1,
):
    """Return a binary bidirectional-entailment matrix for ``outputs`` using an NLI model.

    Every ordered pair (outputs[i], outputs[j]) is scored by ``model.predict``. A pair counts
    as entailing when its arg-max class equals ``entailment_idx``. The matrix is then AND-ed
    with its transpose, so entry (i, j) is 1 only when i entails j and j entails i
    (consistent with the Kuhn et al. semantic-equivalence check).

    Args:
        outputs: iterable of answer strings.
        model: object with ``predict(pairs, convert_to_tensor=True)`` returning one row of
            class scores per (premise, hypothesis) pair (e.g. a sentence-transformers
            ``CrossEncoder``).
        context: optional text prepended to every output, separated by a space. When empty,
            outputs are used unchanged.
        entailment_idx: index of the entailment class in the model's output. The default of 1
            matches the original hard-coded assumption; label order differs between NLI
            checkpoints, so verify it for the model in use.

    Returns:
        A float tensor of shape (n, n) on CPU with entries in {0., 1.}, where n is the number
        of outputs.
    """
    # Only add the separator when there is a context; otherwise every input would gain a
    # stray leading space.
    prefix = f"{context} " if context else ""
    texts = [prefix + o for o in outputs]
    n = len(texts)
    if n == 0:
        return torch.zeros(0, 0, dtype=torch.float)

    # All ordered pairs, row-major: pair k corresponds to (texts[k // n], texts[k % n]).
    nli_input = list(product(texts, texts))
    scores = model.predict(nli_input, convert_to_tensor=True)

    # Keep only the "entailment" decision per pair.
    predicted = scores.argmax(dim=1).cpu()
    entails = (predicted == entailment_idx).to(torch.float).reshape(n, n)

    # Logical AND in both directions (bidirectional entailment).
    return entails * entails.T


@torch.no_grad()
def compute_answer_embeddings(
    outputs,
    model="all-MiniLM-L6-v2",
    return_daisy_chain=False,
    **kwargs
):
    """Embed every answer in ``outputs`` with a sentence-transformers model.

    Args:
        outputs: sentences passed to ``model.encode``.
        model: a ``SentenceTransformer`` instance, or a model name that is loaded with
            ``SentenceTransformer(model, **kwargs)``.
        return_daisy_chain: if True, also return ``{'model': model, **kwargs}`` so the loaded
            model can be passed back in on the next call.
        **kwargs: forwarded to the ``SentenceTransformer`` constructor when ``model`` is a name.

    Returns:
        The tensor from ``model.encode(outputs, convert_to_tensor=True)`` moved to CPU, or
        ``(embeddings, daisy_chain_dict)`` when ``return_daisy_chain`` is True.
    """
    model = SentenceTransformer(model, **kwargs) if isinstance(model, str) else model
    assert isinstance(model, SentenceTransformer), f"Provided model is not a SentenceTransformer: {type(model)}"

    # A single list of sentences: one embedding per entry, gathered on CPU.
    answer_embeddings = model.encode(outputs, convert_to_tensor=True)
    answer_embeddings = answer_embeddings.cpu()

    return (answer_embeddings, {'model': model, **kwargs}) if return_daisy_chain else answer_embeddings


@torch.no_grad()
def compute_question_answer_embeddings(
    outputs,
    prompt,
    prompt_prefix='Represent answers to the following question: ',
    model="hkunlp/instructor-large",
    return_daisy_chain=False,
    **kwargs
):
    """Embed answers conditioned on a question via the encoder's ``prompt`` argument.

    The instruction given to ``model.encode`` is ``prompt_prefix + prompt``.

    Args:
        outputs: answer sentences to embed.
        prompt: question text appended to ``prompt_prefix``.
        prompt_prefix: instruction text placed before the question.
        model: a ``SentenceTransformer`` instance, or a model name that is loaded with
            ``SentenceTransformer(model, **kwargs)``.
        return_daisy_chain: if True, also return ``{'model': model, **kwargs}`` so the loaded
            model can be passed back in on the next call.
        **kwargs: forwarded to the ``SentenceTransformer`` constructor when ``model`` is a name.

    Returns:
        The tensor from ``model.encode(..., convert_to_tensor=True)`` moved to CPU, or
        ``(embeddings, daisy_chain_dict)`` when ``return_daisy_chain`` is True.
    """
    model = SentenceTransformer(model, **kwargs) if isinstance(model, str) else model
    assert isinstance(model, SentenceTransformer), f"Provided model is not a SentenceTransformer: {type(model)}"

    instruction = prompt_prefix + prompt
    qa_embeddings = model.encode(outputs, prompt=instruction, convert_to_tensor=True)
    qa_embeddings = qa_embeddings.cpu()

    return (qa_embeddings, {'model': model, **kwargs}) if return_daisy_chain else qa_embeddings
