import os
import pandas as pd
import numpy as np
import torch

def huggingface_avg_subwords(model_name, context, language_embeds, tokenizer):
    """Average the subword embeddings of the final word in ``context``.

    Walks ``context`` word by word, using ``tokenizer.tokenize`` to count how
    many subword tokens each word occupies, and averages the matching slice
    of ``language_embeds`` along axis 1. Every word's average is checked for
    NaN, but only the average of the last word is returned.

    Args:
        model_name: Model identifier. For ``'gpt-2'`` the first word starts
            at token position 0; for any other value it starts at position 1
            (presumably to skip a leading special token -- TODO confirm).
        context: Non-empty iterable of word strings. Every word after the
            first is prefixed with a single space before tokenization.
        language_embeds: 3-D tensor indexed as ``[:, token_position, :]``
            (assumed to be (batch, tokens, hidden) -- TODO confirm).
        tokenizer: Object exposing ``tokenize(str)`` that returns a sized
            sequence of subword tokens.

    Returns:
        Tensor of shape ``(1, language_embeds.shape[2])``: the subword
        average of the last word, taken from the last entry along axis 0 of
        ``language_embeds``.

    Raises:
        ValueError: If ``context`` is empty, or if any word's averaged
            embedding contains NaN. An empty token slice (a word whose span
            lies past the end of ``language_embeds``) averages to NaN and is
            therefore reported this way too.
    """
    # Token position at which the current word's subwords begin.
    idx = 0 if model_name == 'gpt-2' else 1
    # Stays None only when the loop body never runs (empty context).
    embedding = None
    for context_idx, context_word in enumerate(context):
        if context_idx > 0:
            # Non-initial words are tokenized with their leading space.
            context_word = f' {context_word}'
        # TODO: Check that words are being sanitized of punctuation and casing
        n_subwords = len(tokenizer.tokenize(context_word))
        word_slice = language_embeds[:, idx:idx + n_subwords, :]
        embedding = torch.mean(word_slice, dim=1).unsqueeze(0)
        if torch.isnan(embedding).any():
            raise ValueError(f'Embedding for context \"{context}\" has nan')
        idx += n_subwords
    if embedding is None:
        # Previously surfaced as an UnboundLocalError on the return below.
        raise ValueError('Cannot compute an embedding for an empty context')
    return embedding[:, -1, :]

def clip_avg_subwords(context, language_embeds):
    """Average the CLIP subword embeddings of the final word in ``context``.

    Each word is tokenized on its own with ``clip.tokenize``; the number of
    subwords is the count of token ids between column 0 and the column
    holding the EOS id. The matching slice of ``language_embeds`` is averaged
    along axis 1. Every word's average is checked for NaN, but only the
    average of the last word is returned.

    Args:
        context: Non-empty iterable of word strings. Every word after the
            first is prefixed with a single space before tokenization.
        language_embeds: 3-D tensor indexed as ``[:, token_position, :]``
            (assumed to be (batch, tokens, hidden) -- TODO confirm).

    Returns:
        Tensor of shape ``(1, language_embeds.shape[2])``: the subword
        average of the last word, taken from the last entry along axis 0 of
        ``language_embeds``.

    Raises:
        ValueError: If ``context`` is empty, or if any word's averaged
            embedding contains NaN. An empty token slice averages to NaN and
            is therefore reported this way too.
    """
    # Token id that ends a tokenized word (presumably CLIP's end-of-text id,
    # with 49406 as the matching start-of-text id -- TODO confirm).
    EOS_token = 49407
    # Words start at token position 1 (presumably position 0 of
    # language_embeds holds the start-of-text token -- TODO confirm).
    idx = 1
    # Stays None only when the loop body never runs (empty context).
    embedding = None
    for context_idx, context_word in enumerate(context):
        if context_idx > 0:
            # Non-initial words are tokenized with their leading space.
            context_word = f' {context_word}'
        # NOTE(review): `clip` is not imported in the visible part of this
        # file -- confirm it is brought into scope elsewhere.
        subwords = clip.tokenize(context_word)
        # `.item()` requires exactly one EOS match in the tokenized output.
        eos_index = (subwords == EOS_token).nonzero(as_tuple=True)[1].item()
        # Number of token ids strictly between column 0 and the EOS column.
        n_subwords = subwords[:, 1:eos_index].shape[1]
        word_slice = language_embeds[:, idx:idx + n_subwords, :]
        embedding = torch.mean(word_slice, dim=1).unsqueeze(0)
        if torch.isnan(embedding).any():
            raise ValueError(f'Embedding for context \"{context}\" has nan')
        idx += n_subwords
    if embedding is None:
        # Previously surfaced as an UnboundLocalError on the return below.
        raise ValueError('Cannot compute an embedding for an empty context')
    return embedding[:, -1, :]