import torch
import numpy as np
import tqdm
import datasets
import pickle
import os
import torch
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          BitsAndBytesConfig,
                          AutoConfig,
                          pipeline)

# Marker strings that are tokenized further down to build `eos_token_ids`.
# Several entries differ only in surrounding whitespace; the order is kept
# because the derived id list follows it one-to-one.
END_OF_GENERATION_TOKENS = [
    # "Question" markers (bare, leading space, trailing space)
    "Question:", " Question:", "Question: ",
    # plain line break
    "\n",
    # remaining "Answer"/"Question" variants and the short "Q:" form
    "Answer:", "\nQuestion:", " Answer:", "Q:",
]


# Script-level settings.
# `device` is handed to `from_pretrained` as `device_map` when the model loads.
device = "cuda"
# 'trivia_qa' selects the TriviaQA download; any other value takes the branch
# that loads 'coqa_dataset/' from disk.
data_name = "trivia_qa"
# Hugging Face access token forwarded to the tokenizer loader (`token=`).
HF_TOKEN = ""

# 4-bit NF4 quantization settings with double quantization and bfloat16
# compute.
# NOTE(review): nothing in the visible part of this script passes
# `bnb_config` to `from_pretrained`, so the model below appears to be loaded
# without quantization -- confirm whether that is intended.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)


# Checkpoint identifier shared by the tokenizer, config and model loaders.
model_name = "lmsys/vicuna-7b-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN)

# One single-element id list per marker string: only the id at position 1 of
# each encoding is kept.
# NOTE(review): presumably position 0 is a BOS token the tokenizer prepends,
# so position 1 is the first "real" piece -- confirm for this tokenizer.
eos_token_ids = [
    [marker_ids[1]]
    for marker_ids in (
        tokenizer(marker)["input_ids"] for marker in END_OF_GENERATION_TOKENS
    )
]



# Load the model configuration first so it can be passed explicitly below.
config = AutoConfig.from_pretrained(model_name)

# Causal LM weights in float16, placed according to `device`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.float16,
    device_map=device,
    low_cpu_mem_usage=True,
)

# Make sampling the default decoding mode on the model's generation config.
model.generation_config.do_sample = True

# Pick the question source. Every `data_name` other than 'trivia_qa' is
# treated as the CoQA-on-disk case.
if data_name != 'trivia_qa':
    dataset = datasets.load_from_disk('coqa_dataset/')
    # Lookup table from example id to its question text.
    id_to_question_mapping = dict(zip(dataset['id'], dataset['question']))
else:
    # First 12000 rows of the no-context TriviaQA training split.
    train_data = datasets.load_dataset("trivia_qa", "rc.nocontext", split="train[:12000]")
    dataset = train_data

def encode(examples):
    """Tokenize one row as a ``<story> Q: <question> A:`` prompt.

    This is the per-row callback given to ``Dataset.map``. That API requires
    the callback to return a mapping of new columns, and the caller then
    selects the ``input_ids`` and ``attention_mask`` columns, so the prompt
    text is run through the module-level ``tokenizer`` instead of being
    returned as a bare string (which ``map`` rejects).

    Args:
        examples: A single row providing string fields ``'story'`` and
            ``'question'``.

    Returns:
        The tokenizer output for the assembled prompt (a mapping containing
        the token ids and attention mask).
    """
    prompt = examples['story'] + ' Q: ' + examples['question'] + ' A:'
    return tokenizer(prompt)
def encode_and_format_dataset(dataset):
    """Apply ``encode`` to every row and switch the result to torch format.

    Args:
        dataset: Object exposing ``map`` and ``set_format`` (the script passes
            the dataset loaded earlier).

    Returns:
        The mapped dataset, formatted so ``input_ids`` and ``attention_mask``
        come back as torch values while all other columns are still output.
    """
    # Row-by-row mapping with the on-disk cache bypassed.
    encoded = dataset.map(encode, batched=False, load_from_cache_file=False)
    encoded.set_format(
        type='torch',
        columns=['input_ids', 'attention_mask'],
        output_all_columns=True,
    )
    return encoded

train_dataset = dataset
# Only the CoQA path goes through the encode/format step; otherwise the raw
# dataset rows are used directly.
questions = encode_and_format_dataset(train_dataset) if data_name == 'coqa' else train_dataset

# The first three rows supply the question/answer demonstrations that are
# prepended to every prompt.
# NOTE(review): these same three rows are still part of `questions`, so they
# are also iterated by the dataloader below -- confirm that is acceptable.
data_for_few_shot_prompt = questions.select(range(0, 3))
prompt_parts = ['This is a bot that correctly answers questions. \n']
for sample in data_for_few_shot_prompt:
    prompt_parts.append(
        'Question: ' + sample['question'] + ' Answer: ' + sample['answer']['value'] + ' '
    )
few_shot_prompt = ''.join(prompt_parts)

# One question per batch.
dataloader = torch.utils.data.DataLoader(questions, batch_size=1)


# How many sampled answers to draw for each question.
number_of_generations = 30
# Accumulates one result record per processed question.
sequences = []

# Token id at position 1 of the encoding of '. '; used as the stop token
# during generation.
period_token_id = tokenizer('. ')['input_ids'][1]

# Shorter marker list and its position-1 token ids, one single-element list
# per marker.
eos_tokens = ['Question:', ' Question:', '\n', 'Answer:', ' Answer:', 'Q:']
question_framing_ids = [
    [framing_ids[1]]
    for framing_ids in (tokenizer(eos_token)['input_ids'] for eos_token in eos_tokens)
]

# Results are checkpointed after every question. The pickle is written to a
# sibling temp file and then moved over the real file with os.replace, so an
# interruption during the dump cannot truncate the only copy of the results
# gathered so far.
output_path = 'generations_triviaqa.pkl'
partial_output_path = output_path + '.tmp'

for batch in tqdm.tqdm(dataloader):
    # batch_size is 1, so index 0 is the single question in this batch.
    question = batch['question'][0]
    prompt = few_shot_prompt + 'Question: ' + question + ' Answer:'
    # The same prompt repeated, so one generate() call yields all samples.
    sentences = [prompt] * number_of_generations

    # The tokenizer output is moved to the model device once; no second
    # transfer of `input_ids` is needed.
    inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)
    input_ids = inputs['input_ids']
    input_length = input_ids.shape[1]
    output_sequences = model.generate(
        **inputs,
        max_new_tokens=35,
        eos_token_id=period_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # Drop the prompt columns so only the newly generated tokens are decoded.
    responses = tokenizer.batch_decode(output_sequences[:, input_length:], skip_special_tokens=True)

    generated_texts = list(responses)
    sequences_dict = {
        'question': question,
        'generated_texts': generated_texts,
        'answer': batch['answer'],
    }
    sequences.append(sequences_dict)

    with open(partial_output_path, 'wb') as outfile:
        pickle.dump(sequences, outfile)
    os.replace(partial_output_path, output_path)


