import datasets
import openai
import json
import os
import argparse
import numpy as np
from typing import List
from tqdm import tqdm, trange
from sentence_transformers import SentenceTransformer
from plugin_dataset import PlugInDataset


def batch_get_openai_embedding(text: List[str], model="text-embedding-ada-002"):
    """Embed a batch of strings through ``openai.Embedding.create``.

    Newlines inside each string are replaced by single spaces before the
    request is sent.

    Args:
        text: Strings to embed.
        model: Name of the embedding model passed to the API.

    Returns:
        A list with one entry per input string, where position ``k`` holds
        the embedding whose reported ``index`` is ``k``. An empty ``text``
        yields an empty list and no request is made.
    """
    if len(text) == 0:
        # Nothing to embed: skip the network round trip instead of sending a
        # request that carries no inputs.
        return []

    cleaned = [s.replace("\n", " ") for s in text]
    resp = openai.Embedding.create(input=cleaned, model=model)

    # Place every embedding by its reported index rather than relying on the
    # order of the items in the response.
    res = [None] * len(text)
    for dic in resp['data']:
        res[dic['index']] = dic['embedding']
    return res


def batch_get_sbert_embedding(text: List[str], model):
    """Encode a batch of strings with the supplied model.

    Args:
        text: Strings to embed.
        model: Object exposing an ``encode`` method; it is called with
            ``text`` as its only argument.

    Returns:
        Whatever ``model.encode(text)`` returns, unchanged.
    """
    embeddings = model.encode(text)
    return embeddings


def main():
    """Embed the SST-2 source inputs and save one ``.npy`` matrix per split.

    Parses the command-line options, picks the embedding backend from
    ``--model`` (the OpenAI helper for names starting with
    ``text-embedding-``, a ``SentenceTransformer`` otherwise), embeds the
    ``train`` and ``validation`` inputs in batches of ``--batch_size`` and
    writes ``<output_dir>/<split>_source.npy`` as a float32 array with one
    row per entry of the split's ``PlugInDataset.all_data``.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --data_dir is parsed but never read in this function; the
    # dataset is always loaded as 'sst2'. Kept so existing command lines work.
    parser.add_argument('--data_dir', default='sst')
    parser.add_argument('--model', type=str, default="text-embedding-ada-002",
                        choices=['all-mpnet-base-v2', 'text-embedding-ada-002'])
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--output_dir', default='data/embeddings')
    parser.add_argument(
        "--src_key", default="question", help="source key in jsonl  (required for batch generation)"
    )
    parser.add_argument(
        "--tgt_key", default="answer", help="target key in jsonl  (required for batch generation)"
    )
    args = parser.parse_args()
    if args.batch_size < 1:
        # A zero step makes range() raise, and a negative step makes the
        # batching loop empty, which would silently save an all-zero matrix.
        parser.error('--batch_size must be a positive integer')
    print(args)

    # Width of the embedding vectors produced by each supported model.
    model_dim = {
        'all-mpnet-base-v2': 768,
        'text-embedding-ada-002': 1536
    }[args.model]

    # Bind the backend once so the batching loop below is backend-agnostic.
    if args.model.startswith('text-embedding-'):
        def embed(batch):
            return batch_get_openai_embedding(batch, model=args.model)
    else:
        # No explicit device: SentenceTransformer picks a GPU when one is
        # available and falls back to CPU otherwise, so the script also runs
        # on machines without CUDA.
        sbert_model = SentenceTransformer(args.model)

        def embed(batch):
            return batch_get_sbert_embedding(batch, sbert_model)

    os.makedirs(args.output_dir, exist_ok=True)

    dataset = datasets.load_dataset('sst2')
    for data_type in ["train", "validation"]:
        gen_dataset = PlugInDataset(data_dict=dataset, data_type=data_type, src_key=args.src_key, tgt_key=args.tgt_key,
                                    batch_size=args.batch_size)
        # Only the source side of each example is embedded.
        all_inputs = [ex.source_input for ex in gen_dataset.all_data.values()]

        # Preallocate the full matrix and fill it one batch at a time.
        X = np.zeros((len(all_inputs), model_dim), dtype=np.float32)
        for i in trange(0, len(all_inputs), args.batch_size):
            j = min(i + args.batch_size, len(all_inputs))
            X[i:j] = np.asarray(embed(all_inputs[i:j]), dtype=np.float32)

        output_path = os.path.join(args.output_dir, f'{data_type}_source.npy')
        np.save(output_path, X)
        print(f'Saved to {output_path}')


# Run the embedding job only when this file is executed as a script.
if __name__ == '__main__':
    main()