import torch
from transformers import BertTokenizer, BertModel
from tqdm import tqdm
from datasets import load_dataset
import fire

class Embedder:
    """Turn batches of strings into embedding tensors.

    Two backends are supported, selected by ``embd_type``:

    * ``'bert'`` -- ``bert-base-uncased``; the embedding is the model's
      ``pooler_output``.
    * ``'gtr'`` -- ``sentence-transformers/gtr-t5-base`` through
      ``SentenceTransformer.encode``.
    """

    def __init__(self, device, seq_len, embd_type='bert'):
        """Load the requested model onto ``device``.

        Args:
            device: Torch device string (e.g. ``'cuda:0'`` or ``'cpu'``) the
                model is placed on and the BERT inputs are moved to.
            seq_len: Maximum token length handed to the BERT tokenizer.
                Values above 512 are clamped to 512. Only the ``'bert'``
                backend reads this value.
            embd_type: ``'bert'`` or ``'gtr'``.

        Raises:
            ValueError: If ``embd_type`` is not a known backend.
        """
        # Use the device the caller asked for instead of a hard-coded GPU.
        self.device = device

        self.seq_len = seq_len
        if self.seq_len > 512:
            self.seq_len = 512
            print('Reducing the embeddings seq_len to 512, as that is the maximum bert supports.')

        self.embd_type = embd_type
        if embd_type == 'bert':
            self.model = BertModel.from_pretrained("bert-base-uncased", device_map=device)
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        elif embd_type == 'gtr':
            # Imported lazily so sentence_transformers is only required
            # when the GTR backend is actually selected.
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer('sentence-transformers/gtr-t5-base')
            self.model.to(device)
        else:
            raise ValueError(f'Unknown embedding type: {embd_type}')

    @torch.no_grad()
    def batch_embed(self, batch, out_device=None):
        """Embed one batch of texts.

        Args:
            batch: The texts to embed, passed as-is to the tokenizer
                (``'bert'``) or to ``SentenceTransformer.encode`` (``'gtr'``).
            out_device: If not ``None``, the result is moved to this device
                before being returned.

        Returns:
            A tensor with the embeddings of ``batch``.

        Raises:
            ValueError: If ``self.embd_type`` is not a known backend.
        """
        if self.embd_type == 'bert':
            model_inputs = self.tokenizer(
                batch,
                return_tensors='pt',
                max_length=self.seq_len,
                padding=True,
                truncation=True,
            )
            # The tokenizer output has to live on the same device as the model.
            model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
            embd = self.model(**model_inputs).pooler_output
        elif self.embd_type == 'gtr':
            # encode() yields a numpy array here; wrap it as a tensor.
            embd = torch.from_numpy(self.model.encode(batch))
        else:
            # Fail clearly instead of hitting an unbound local below.
            raise ValueError(f'Unknown embedding type: {self.embd_type}')

        if out_device is not None:
            embd = embd.to(out_device)
        return embd

    def embed_all(self, texts, batch_size=64):
        """Embed ``texts`` in chunks of ``batch_size`` and concatenate on CPU.

        Args:
            texts: Sliceable sequence of texts. Must be non-empty, because
                ``torch.cat`` rejects an empty list of tensors.
            batch_size: Number of texts embedded per ``batch_embed`` call.

        Returns:
            A single CPU tensor holding the per-batch embeddings
            concatenated along dimension 0.
        """
        embds = []
        for i in tqdm(range(0, len(texts), batch_size)):
            # Each batch is moved to CPU right away so results do not
            # accumulate on the model's device.
            embd = self.batch_embed(texts[i:i+batch_size], out_device='cpu')
            embds.append(embd)
        return torch.cat(embds, dim=0)

def get_dataset(dataset_name):
    """Load a training split and add ``inp`` / ``label`` columns to it.

    Supported names:

    * ``'sql'`` -- local ``./data/sql/train.jsonl``; ``inp`` and ``label``
      are the ``content`` of the first and second entries of ``messages``,
      and the ``messages`` column is dropped.
    * ``'viggo'`` -- ``GEM/viggo``; ``inp`` is ``target`` and ``label`` is
      ``meaning_representation``.
    * ``'gsm8k'`` -- ``gsm8k`` (``main`` config); ``inp`` is ``question``
      and ``label`` is ``answer``.

    Raises:
        ValueError: If ``dataset_name`` is none of the above.
    """
    extra_map_args = {}

    if dataset_name == 'sql':
        raw = load_dataset(
            'json',
            data_files="./data/sql/train.jsonl",
            split="train",
        )

        def to_pair(example):
            turns = example['messages']
            return {
                'inp': turns[0]['content'],
                'label': turns[1]['content'],
            }

        # Only the sql variant drops its source column.
        extra_map_args['remove_columns'] = ['messages']
    elif dataset_name == 'viggo':
        raw = load_dataset('GEM/viggo', split='train')

        def to_pair(example):
            return {
                'inp': example['target'],
                'label': example['meaning_representation'],
            }
    elif dataset_name == 'gsm8k':
        raw = load_dataset('gsm8k', 'main', split='train')

        def to_pair(example):
            return {
                'inp': example['question'],
                'label': example['answer'],
            }
    else:
        raise ValueError(f'Unknown dataset name: {dataset_name}')

    return raw.map(to_pair, **extra_map_args)

def main(dset='gsm8k', embd_type='bert'):
    """Embed every example of a dataset and save the result to disk.

    Each example is rendered as ``inp + '\\n\\n' + label``, embedded on
    ``cuda:0`` in batches of 64, and the resulting tensor is written with
    ``torch.save`` to ``{dset}_{embd_type}_embds.pt`` in the current
    directory.

    Args:
        dset: Dataset name understood by ``get_dataset``.
        embd_type: Embedding backend understood by ``Embedder``.
    """
    dataset = get_dataset(dset)

    # print args
    print(f'dset: {dset}')
    print(f'embd_type: {embd_type}')

    # Iterate rows directly: one pass over the dataset instead of two
    # indexed row lookups per example.
    texts = [row['inp'] + '\n\n' + row['label'] for row in dataset]

    print(f'generated {len(texts)} texts')

    embedder = Embedder(device='cuda:0', seq_len=512, embd_type=embd_type)
    embds = embedder.embed_all(texts, batch_size=64)
    print(f'embedded {len(embds)} texts')

    out_path = f'{dset}_{embd_type}_embds.pt'
    print(f'saving embeddings to {out_path}')
    torch.save(embds, out_path)

# Run as a script: python-fire builds the command-line interface from
# main's signature, so `dset` and `embd_type` can be passed as arguments.
if __name__ == '__main__':
    fire.Fire(main)