### Preamble ##########################################################################################################

"""
Helper functions for use in experiments.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import numpy as np
import torch
from diffusers import DiffusionPipeline

from typing import Union

#######################################################################################################################


@torch.no_grad()
def get_embedding(
    prompt: Union[str, list, torch.Tensor],
    diffusion_model: DiffusionPipeline,
    embedding_type: str = "eos",
    device: str = "cuda",
):
    """
    Computes the CLIP text embedding of one or more prompts.

    :param prompt: str, list, or torch.Tensor
        String or list-like of strings representing the prompts to return CLIP text embeddings for.
        NOTE(review): the "eos" path passes `prompt` straight to the tokenizer, so tensor prompts are presumably
        not usable with embedding_type="eos" -- confirm against callers.
    :param diffusion_model: DiffusionPipeline
        Instance of a DiffusionPipeline from the `diffusers` library. It must provide `encode_prompt` and
        `tokenizer`.
    :param embedding_type: str
        The type of embedding to return from the CLIP model, one of "eos", "average", "flat". "eos" returns the
        embedding at the first "end of string" token, "average" returns the average embedding across all tokens,
        "flat" returns the flattened embedding across all tokens.
    :param device: str
        Device to compute the embedding on.

    :return: torch.Tensor
        One row per prompt. For "eos" and "average" the token axis is reduced away; for "flat" the token and
        embedding axes are merged into one.
    :raises ValueError:
        If `embedding_type` is not one of the supported values.
    """

    supported_types = ("eos", "average", "flat")
    if embedding_type not in supported_types:
        # Fail before running the text encoder; previously an unknown type surfaced as an UnboundLocalError.
        raise ValueError(f"Unknown embedding_type {embedding_type!r}; expected one of {supported_types}.")

    # Positional arguments are presumably (device, num_images_per_prompt, do_classifier_free_guidance) -- confirm
    # for the pipeline class in use. The first returned element is taken as the per-token prompt embeddings and is
    # indexed below as (batch, tokens, embedding_dim).
    encoded = diffusion_model.encode_prompt(prompt, device, 1, True)[0]

    if embedding_type == "eos":
        # Re-tokenize the prompt to locate the first EOS token of every sequence.
        tokenizer = diffusion_model.tokenizer
        eos_id = tokenizer.eos_token_id
        tokenized_prompt = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            # Without truncation an over-long prompt yields more than `max_length` tokens, so the EOS position
            # could fall outside the token axis of `encoded`.
            truncation=True,
            return_tensors="pt",
        )
        # argmax over a 0/1 mask gives the position of the first EOS (padding may repeat the EOS id afterwards).
        is_eos = (tokenized_prompt["input_ids"] == eos_id).to(dtype=torch.int)
        idxs = torch.argmax(is_eos, dim=1).to(encoded.device)
        # Keep the index tensors on the same device as the embeddings.
        rows = torch.arange(encoded.shape[0], device=encoded.device)
        return encoded[rows, idxs, :]

    if embedding_type == "average":
        return torch.mean(encoded, dim=1)

    # embedding_type == "flat"
    return torch.flatten(encoded, start_dim=1)


@torch.no_grad()
def get_distance_mat(embedding: torch.Tensor, distance_type: str = "l2"):
    """
    Computes the distance matrix between every pair of embedding vectors and returns the distance matrix.

    :param embedding: torch.Tensor
        A matrix of embeddings, should be of shape (N, E), where N is the batch size and E is the embedding dimension.
        It is converted to float32 before any distance is computed.
    :param distance_type: str
        The type of distance metric to use, one of "l2" or "cosine". "l2" uses Euclidean distance and "cosine" uses
        cosine distance, i.e. 1 - cosine similarity, clamped to the range [0, 1].

    :return: torch.Tensor
        A float32 matrix of shape (N, N) whose entry (i, j) is the distance between rows i and j of `embedding`.
    :raises ValueError:
        If `distance_type` is not one of the supported values.
    """

    supported_types = ("l2", "cosine")
    if distance_type not in supported_types:
        # Previously an unknown type fell through both branches and surfaced as an UnboundLocalError.
        raise ValueError(f"Unknown distance_type {distance_type!r}; expected one of {supported_types}.")

    embedding = embedding.to(dtype=torch.float32)

    if distance_type == "l2":
        return torch.cdist(embedding, embedding, p=2).clamp(min=0)

    # distance_type == "cosine": unit-normalize the rows so the Gram matrix holds the pairwise cosine similarities.
    X = torch.nn.functional.normalize(embedding, dim=1)
    cos_sim = X @ X.T
    # The lower clamp removes small negative values caused by rounding. The upper clamp means vectors with negative
    # similarity get the same distance (1) as orthogonal ones; kept as-is for backward compatibility.
    return (1 - cos_sim).clamp(min=0, max=1)


#######################################################################################################################
