import torch
from torch.nn.functional import cosine_similarity
from tqdm import tqdm

class TokenizedInput:
    """
    Thin wrapper that exposes a tokenizer result under the attribute
    ``tokenized_title``.
    """

    def __init__(self, tokenized_title) -> None:
        # The payload is kept by reference: no copy and no validation.
        self.tokenized_title = tokenized_title

def predict_embeddings(model, data_loader, device):
    """
    Run the model over every batch of a data loader and gather its two outputs.

    Each batch is moved to ``device`` and passed to ``model``, which must
    return a pair of tensors. Both outputs are copied to the CPU per batch
    and finally concatenated along dimension 0. Gradient tracking is
    disabled for the whole loop.

    Returns:
        A tuple ``(output_cry, output_text)`` holding the concatenated first
        and second model outputs, respectively.
    """
    print("Predicting embeddings...")
    cry_chunks, text_chunks = [], []

    with torch.no_grad():
        for batch in tqdm(data_loader):
            cry_out, text_out = model(batch.to(device))
            # Move results off the device right away so only one batch of
            # outputs lives there at a time.
            cry_chunks.append(cry_out.cpu())
            text_chunks.append(text_out.cpu())

    output_cry = torch.cat(cry_chunks, dim=0)
    output_text = torch.cat(text_chunks, dim=0)
    print("Embedding prediction completed.")
    return output_cry, output_text


def normalize_embedding(embedding, norm_type):
    """
    Normalize each row of ``embedding`` with the requested method.

    The input tensor is never modified; a new tensor is always returned
    (previously the 'l2' branch divided in place and silently changed the
    caller's tensor, while 'minmax' did not).

    Args:
        embedding: Tensor normalized along ``dim=1``.
        norm_type: ``'l2'`` divides each row by its Euclidean norm;
            ``'minmax'`` rescales each row linearly so that its minimum
            maps to 0 and its maximum to 1.

    Returns:
        The normalized tensor. Degenerate rows (zero norm for 'l2', all
        values equal for 'minmax') are returned as zeros instead of NaN.

    Raises:
        ValueError: If ``norm_type`` is neither ``'l2'`` nor ``'minmax'``.
    """
    if norm_type == 'l2':
        numerator = embedding
        denominator = embedding.norm(p=2, dim=1, keepdim=True)
    elif norm_type == 'minmax':
        row_min = embedding.min(dim=1, keepdim=True)[0]
        row_max = embedding.max(dim=1, keepdim=True)[0]
        numerator = embedding - row_min
        denominator = row_max - row_min
    else:
        raise ValueError(f"Unsupported normalization type: {norm_type}")

    # A zero denominator implies the numerator row is all zeros as well, so
    # dividing by 1 yields zeros rather than the NaN produced by 0 / 0.
    # Rows with a non-zero denominator are divided exactly as before.
    safe_denominator = torch.where(
        denominator == 0, torch.ones_like(denominator), denominator
    )
    return numerator / safe_denominator


def encode_texts(text_list, tokenizer, text_encoder, cfg, device):
    """
    Tokenize a list of texts, run the text encoder, and return CPU embeddings.

    The tokenizer is called with padding and truncation enabled and asked
    for PyTorch tensors. Its ``input_ids`` and ``attention_mask`` are moved
    to ``device`` and handed to ``text_encoder`` wrapped in a
    ``TokenizedInput``. If ``cfg.embedding_normalize`` is not ``None``, the
    result is passed through ``normalize_embedding`` with that value.
    """
    tokens = tokenizer(text_list, padding=True, truncation=True, return_tensors='pt')
    model_input = TokenizedInput({
        "input_ids": tokens['input_ids'].to(device),
        "attention_mask": tokens['attention_mask'].to(device),
    })

    # Inference only: no gradients are needed for the encoder pass.
    with torch.no_grad():
        encoded = text_encoder(model_input).cpu()

    if cfg.embedding_normalize is None:
        return encoded
    return normalize_embedding(encoded, cfg.embedding_normalize)


def calculate_material_category_similarities(target_embeddings, categories, tokenizer, text_encoder, cfg, device):
    """
    Compute the cosine similarity between every target embedding and the
    text embedding of every category.

    Args:
        target_embeddings: Embeddings to compare; indexed here as a 2-D
            tensor (one row per item).
        categories: Iterable of category texts. Each one is encoded
            separately through ``encode_texts``.
        tokenizer, text_encoder, cfg, device: Forwarded unchanged to
            ``encode_texts``.

    Returns:
        A tensor of pairwise similarities with one row per target embedding
        and one column per category, on the device of ``target_embeddings``.
    """
    per_category = [
        encode_texts([category], tokenizer, text_encoder, cfg, device)
        for category in categories
    ]
    # Each entry comes from a single text; stacking and dropping dimension 1
    # yields one row per category.
    category_embeddings = torch.stack(per_category).squeeze(1)

    # encode_texts always returns CPU tensors, so align with the targets'
    # device to avoid a device-mismatch error when they live elsewhere
    # (this is a no-op when both are already on the same device).
    category_embeddings = category_embeddings.to(target_embeddings.device)

    # Broadcast targets against categories and reduce over the feature axis.
    return cosine_similarity(
        target_embeddings[:, None, :], category_embeddings[None, :, :], dim=2
    )