# from sentence_transformers import SentenceTransformer
import sys
# NOTE(review): ROOT_PATH is not defined or imported anywhere before this line,
# so as written this raises NameError when the file is run. Presumably it is
# meant to make the local imports below (text_prompts, comnivore) resolvable.
# TODO: define ROOT_PATH, or remove this line if the path is configured elsewhere.
sys.path.insert(0, ROOT_PATH)
import os
import argparse

import torch
import numpy as np
from tqdm import tqdm

from text_prompts import text_prompts
import comnivore.const as const

from transformers import AutoModel, AutoTokenizer
# from transformers import BartTokenizer, BartModel
from sentence_transformers import SentenceTransformer


from api_key import API_KEY
import openai
# Credential used by openai.Embedding.create in get_openai_embedding.
openai.api_key = API_KEY

# Torch device shared by the locally-run embedding models: CUDA if present, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# SentenceTransformer shared across get_hf_embedding calls; created on first use
# so the model is loaded once per process instead of once per call.
_hf_model = None


def get_hf_embedding(texts):
    """Encode *texts* with the ``bert-base-nli-mean-tokens`` SentenceTransformer.

    The model runs on the module-level ``device`` (CUDA when available,
    otherwise CPU) instead of a hard-coded ``'cuda'``, so this also works on
    CPU-only machines.

    Args:
        texts: the sentence(s) to encode, passed straight to
            ``SentenceTransformer.encode``.

    Returns:
        The result of ``SentenceTransformer.encode`` with its default output
        settings.
    """
    global _hf_model
    if _hf_model is None:
        _hf_model = SentenceTransformer('bert-base-nli-mean-tokens', device=str(device))
    return _hf_model.encode(texts, device=str(device))

def get_openai_embedding(text_list):
    """Embed every string in *text_list* through the OpenAI embeddings endpoint.

    Newlines in each string are replaced by spaces before the request is sent.
    The request uses the ``openai.Embedding.create`` interface with the
    ``text-similarity-ada-001`` model.

    Returns:
        A list holding the ``'embedding'`` field of each record in the
        response's ``'data'`` list, in the order the API returned them.
    """
    single_line = [s.replace("\n", " ") for s in text_list]
    response = openai.Embedding.create(input=single_line, model="text-similarity-ada-001")
    return [record['embedding'] for record in response['data']]

# SimCSE checkpoint (supervised, BERT-base uncased per its name). The tokenizer
# and model are loaded once at import time, and the model is moved to `device`.
_SIMCSE_CHECKPOINT = "princeton-nlp/sup-simcse-bert-base-uncased"
tokenizer_simcse = AutoTokenizer.from_pretrained(_SIMCSE_CHECKPOINT)
model_simcse = AutoModel.from_pretrained(_SIMCSE_CHECKPOINT).to(device)


def get_simcse_embedding(text_list):
    """Return SimCSE sentence embeddings for *text_list* as a NumPy array.

    The batch is padded to its longest sequence and truncated to the
    tokenizer's maximum length. The embedding of each input is the model's
    ``pooler_output`` row for that input.
    """
    inputs = tokenizer_simcse(
        text_list, padding=True, truncation=True, return_tensors="pt"
    ).to(device)
    # Only pooler_output is consumed, so do not request every layer's hidden
    # states (output_hidden_states=True only cost memory). Under no_grad the
    # result carries no autograd graph, so no .detach() is needed.
    with torch.no_grad():
        outputs = model_simcse(**inputs, return_dict=True)
    return outputs.pooler_output.cpu().numpy()

def _next_store_dir(root_dir):
    """Create and return a fresh ``text_emb_<k>`` directory under *root_dir*.

    ``k`` is the number of entries currently in *root_dir* whose name contains
    the substring ``'text'``. Each call creates its directory before
    returning, so consecutive calls normally get consecutive indices.
    """
    n_existing = len([entry for entry in os.listdir(root_dir) if 'text' in entry])
    store_dir = f'{root_dir}/text_emb_{n_existing}'
    os.makedirs(store_dir, exist_ok=True)
    return store_dir


def _save_outputs(store_dir, texts, embedding):
    """Write ``str(texts)`` to ``texts.txt`` and *embedding* to ``texts.npy``.

    Both files go into *store_dir*. The path of the ``.npy`` file is printed.
    """
    with open(os.path.join(store_dir, "texts.txt"), "w") as output:
        output.write(str(texts))
    npy_path = os.path.join(store_dir, 'texts.npy')
    np.save(npy_path, embedding)
    print(npy_path)


def main():
    """Embed one dataset's text prompts and save them under ``<dataset>_features``.

    For the Amazon and gender-bias datasets, each entry of ``metas_text`` is
    embedded and saved to its own new ``text_emb_<k>`` directory. For every
    other dataset, a ``support``/``against`` prompt pair is built per label,
    and the per-label embeddings are stacked along a new leading axis. The
    result is saved in a single new directory.
    """
    embedder_fn = {
        'hf_sim': get_hf_embedding,
        'openai': get_openai_embedding,
        'simcse': get_simcse_embedding,
    }

    parser = argparse.ArgumentParser(description='compute text-prompt embeddings for a dataset')
    parser.add_argument('-d', '--dataset', type=str, default='civilcomments')
    # `choices` turns an unknown model into a clear CLI error instead of a KeyError.
    parser.add_argument('-m', '--model', type=str, default='hf_sim',
                        choices=sorted(embedder_fn))
    args = parser.parse_args()
    dataset_name = args.dataset
    model = args.model
    # Explicit check: a bare `assert` is stripped under `python -O`.
    if dataset_name not in const.TEXT_DATA:
        parser.error(f'unsupported dataset: {dataset_name}')
    embed = embedder_fn[model]

    root_dir = f'{dataset_name}_features/features_{model}'
    # os.listdir in _next_store_dir fails on a missing directory; create it first.
    os.makedirs(root_dir, exist_ok=True)
    prompts = text_prompts[dataset_name]

    if dataset_name in [const.AMAZON_NAME, const.GENDER_BIAS_NAME]:
        for texts in tqdm(prompts['metas_text']):
            store_dir = _next_store_dir(root_dir)
            _save_outputs(store_dir, texts, embed(texts))
    else:
        labels_text = prompts['labels_text']
        store_dir = _next_store_dir(root_dir)
        embedding_all = []
        # Every prompt that was embedded, flattened in the same order as the
        # stacked embeddings. Previously texts.txt was rewritten on each label
        # iteration and ended up holding only the last label's pair.
        all_texts = []
        for label in tqdm(labels_text):
            texts = [f'support {label} people', f'against {label} people']
            embedding_all.append(embed(texts))
            all_texts.extend(texts)
        embedding_all = np.stack(embedding_all, axis=0)
        print(embedding_all.shape)
        _save_outputs(store_dir, all_texts, embedding_all)


if __name__ == '__main__':
    main()
