
import sys
sys.path.insert(0, ROOT_PATH)
import os
import argparse

from shared.utils import *
from api_key import API_KEY
from comnivore.dataloader import MultiEnvDataset

import comnivore.const as const

import torch
from tqdm import tqdm

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# To pin a specific GPU, call torch.cuda.set_device(<index>) here.

import openai
# Module-level key read by the openai.Embedding.create calls in this file.
openai.api_key = API_KEY

import numpy as np
# from sentence_transformers import SentenceTransformer
# from transformers import BartTokenizer, BartModel
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer


from text_prompts import text_prompts


# Loaded SentenceTransformer models keyed by checkpoint name. get_hf_embedding is
# called once per batch, so the weights must not be reloaded on every call.
_hf_models = {}


def get_hf_embedding(texts, model_name='bert-base-nli-mean-tokens'):
    """Embed ``texts`` with a sentence-transformers model.

    The model is created on first use and cached in ``_hf_models``. It is
    placed on the module-level ``device``, so the function also works on
    machines without CUDA.

    Args:
        texts: list of strings to embed.
        model_name: sentence-transformers checkpoint to load. It defaults to
            the checkpoint this script has always used.

    Returns:
        Whatever ``SentenceTransformer.encode`` returns for ``texts``. By
        default this is a NumPy array with one row per input text.
    """
    target_device = str(device)  # 'cuda' or 'cpu'
    model = _hf_models.get(model_name)
    if model is None:
        model = SentenceTransformer(model_name, device=target_device)
        _hf_models[model_name] = model
    return model.encode(texts, device=target_device)

SIMCSE_MODEL_NAME = "princeton-nlp/sup-simcse-bert-base-uncased"
tokenizer_simcse = AutoTokenizer.from_pretrained(SIMCSE_MODEL_NAME)
# eval() disables dropout so embeddings are deterministic. from_pretrained
# already returns the model in eval mode; the call here makes that explicit.
model_simcse = AutoModel.from_pretrained(SIMCSE_MODEL_NAME).to(device).eval()


def get_simcse_embedding(text_list):
    """Embed ``text_list`` with the supervised SimCSE BERT-base model.

    The texts are padded and truncated as one batch, run through the model
    without gradient tracking, and the model's ``pooler_output`` is returned.

    Args:
        text_list: list of strings to embed.

    Returns:
        NumPy array with one row per input text, on the CPU.
    """
    if len(text_list) == 0:
        # Nothing to encode; return an empty matrix with the right width.
        return np.empty((0, model_simcse.config.hidden_size), dtype=np.float32)
    inputs = tokenizer_simcse(
        text_list, padding=True, truncation=True, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        # Only the pooled output is used. Requesting output_hidden_states
        # would keep every layer's activations alive for no benefit.
        outputs = model_simcse(**inputs, return_dict=True)
    return outputs.pooler_output.cpu().numpy()

def get_openai_embedding(text_list, model="text-similarity-ada-001"):
    """Embed ``text_list`` through the OpenAI embeddings endpoint.

    Uses the legacy ``openai.Embedding.create`` client API. Newlines in each
    text are replaced with single spaces before the request is sent.

    Args:
        text_list: sequence of strings to embed.
        model: embedding model name. It defaults to the model this script
            has always used.

    Returns:
        List of embedding vectors (lists of floats), one per input text, in
        the same order as ``text_list``. Empty input returns ``[]`` without
        contacting the API.
    """
    if len(text_list) == 0:
        return []
    cleaned = [text.replace("\n", " ") for text in text_list]
    response = openai.Embedding.create(input=cleaned, model=model)
    # Each result carries the position of its input. Order by it explicitly
    # instead of relying on the response list order.
    data = sorted(response['data'], key=lambda item: item['index'])
    return [item['embedding'] for item in data]

def _stack_metadata(chunks):
    """Concatenate per-batch metadata arrays into one array with a row per sample.

    1-D (or 0-d) chunks hold one value per sample. They are reshaped into
    column vectors first, so equal-length chunks are not stacked into a
    (num_batches, batch_size) matrix. Chunks that are already 2-D or more are
    concatenated along axis 0 unchanged.
    """
    columns = [chunk.reshape(-1, 1) if chunk.ndim <= 1 else chunk for chunk in chunks]
    return np.concatenate(columns, axis=0)


def _embed_dataloader(dataloader, embedder, desc=None):
    """Embed every batch of ``dataloader`` with ``embedder``.

    Batches are either ``(x, y)`` or ``(x, y, metadata)``, where ``y`` and
    ``metadata`` are tensors and ``x`` is an iterable of texts.

    Returns:
        ``(embeddings, labels, metadata)``, where ``metadata`` is ``None`` if
        no batch carried any. Returns ``None`` if the dataloader yielded no
        samples.
    """
    embeddings, labels, metadata_chunks = [], [], []
    for batch in tqdm(dataloader, desc=desc):
        if len(batch) == 3:
            x, y, metadata = batch
            metadata_chunks.append(metadata.detach().cpu().numpy())
        else:
            x, y = batch
        # The embedders expect a plain list of strings.
        embeddings.extend(embedder(list(x)))
        labels.extend(y.detach().cpu().numpy().tolist())
    if not embeddings:
        return None
    metadata_all = _stack_metadata(metadata_chunks) if metadata_chunks else None
    return np.vstack(embeddings), np.array(labels), metadata_all


def _save_split(split_dir, embeddings, labels, metadata):
    """Write one split's arrays as emb.npy, y.npy and (optionally) metadata.npy."""
    os.makedirs(split_dir, exist_ok=True)
    np.save(os.path.join(split_dir, 'emb.npy'), embeddings)
    if metadata is not None:
        np.save(os.path.join(split_dir, 'metadata.npy'), metadata)
    np.save(os.path.join(split_dir, 'y.npy'), labels)
    print(f"features etc saved to {split_dir}")


if __name__ == '__main__':
    embedder_fn_dict = {
        'openai': get_openai_embedding,
        'hf_sim': get_hf_embedding,
        'simcse': get_simcse_embedding,
    }

    parser = argparse.ArgumentParser(description='extract text embeddings for a text dataset')
    parser.add_argument('-d', '--dataset', type=str, default='civilcomments')
    parser.add_argument('-bs', '--batch_size', type=int, default=32)
    # The previous default ('hf') was not a key of embedder_fn_dict and crashed.
    parser.add_argument('-m', '--model', type=str, default='hf_sim',
                        choices=sorted(embedder_fn_dict))
    args = parser.parse_args()

    dataset_name = args.dataset
    model_name = args.model
    batch_size = args.batch_size
    if dataset_name not in const.TEXT_DATA:
        parser.error(f"unsupported dataset '{dataset_name}'")
    embedder = embedder_fn_dict[model_name]

    store_dir = f'{dataset_name}_features/features_{model_name}'
    os.makedirs(store_dir, exist_ok=True)

    # Prefer model-agnostic label prompts; fall back to model-specific ones.
    prompts = text_prompts[dataset_name]
    try:
        labels_text = prompts['labels']
    except KeyError:
        labels_text = prompts[f'labels_{model_name}']
    np.save(os.path.join(store_dir, 'labels.npy'), embedder(labels_text))

    dataloaders = MultiEnvDataset().get_dataloaders(dataset_name, batch_size)
    for i, dataloader in enumerate(dataloaders):
        # Each split is embedded and saved inside the loop so every dataloader
        # gets its own output directory.
        result = _embed_dataloader(dataloader, embedder, desc=f'split {i}')
        if result is None:
            print(f"split {i} yielded no samples; nothing saved", file=sys.stderr)
            continue
        _save_split(os.path.join(store_dir, str(i)), *result)



