import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from ast import literal_eval
from tqdm import tqdm
import random
import numpy as np
tqdm.pandas()


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Python's ``random`` module, NumPy's global generator and PyTorch are
    seeded in that order; when CUDA is available, the generators of all
    CUDA devices are seeded as well.

    Args:
        seed: Seed value handed unchanged to each library's seeding call.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def generate(tokens, model, tokenizer, device, max_length=128):
    """Continue a pre-tokenized prompt with ``model.generate``.

    Args:
        tokens: Prompt given as a list of tokenizer token strings (not raw
            text); they are mapped to ids with
            ``tokenizer.convert_tokens_to_ids``.
        model: Object exposing a Hugging Face style ``generate`` method.
        tokenizer: Tokenizer providing ``convert_tokens_to_ids`` and
            ``convert_ids_to_tokens``.
        device: Device on which the input tensor is created; it should be
            the device the model lives on.
        max_length: Forwarded to ``model.generate``. Defaults to 128, the
            value that used to be hard-coded here.

    Returns:
        The token strings of the single sequence returned by
        ``model.generate``, or ``[]`` when ``tokens`` is empty. For a
        causal LM the returned sequence starts with the prompt, so the
        continuation alone is ``result[len(tokens):]``.
    """
    if not tokens:
        return []
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # Hugging Face models document input_ids as LongTensor (int64); the
    # nested list adds the batch dimension of size 1.
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)
    # A single, unpadded sequence: every position is a real token. Passing
    # the mask explicitly avoids relying on it being inferred, which is
    # ambiguous when the pad token equals the EOS token.
    attention_mask = torch.ones_like(input_ids)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
    )
    # Index the batch dimension rather than squeeze(): squeeze() would also
    # collapse a length-1 sequence into a scalar and break the list result.
    return tokenizer.convert_ids_to_tokens(output[0].tolist())

if __name__ == '__main__':
    # Script entry point: read the prefix/suffix table, generate a model
    # continuation for every remaining prefix with GPT-Neo 125M, and write
    # the table back out with the extra 'wrong_generation' column.
    set_seed(36)
    # Fall back to CPU when no GPU is visible, mirroring the availability
    # check in set_seed, instead of crashing on the hard-coded CUDA device.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    df = pd.read_csv('/mnt/home/dongkeun/L2U/outputs/doc_id_reset_count.csv')
    # The CSV stores these columns as Python-literal strings; parse them
    # back into objects (presumably lists of token strings, since 'prefix'
    # is passed to generate() as its token list).
    df['prefix'] = df['prefix'].apply(literal_eval)
    df['suffix'] = df['suffix'].apply(literal_eval)
    # NOTE(review): rows with eoe_init == 127 are dropped; the meaning of
    # 127 is not visible here -- confirm against the code that produced
    # the CSV.
    # .copy() makes the filtered frame independent, so the column added
    # below is not written to a slice of the original frame (avoids
    # pandas' SettingWithCopyWarning).
    df = df[df['eoe_init'] != 127].copy()

    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
    # The EOS token id is passed as pad_token_id when loading the model.
    model = AutoModelForCausalLM.from_pretrained(
        'EleutherAI/gpt-neo-125M',
        pad_token_id=tokenizer.eos_token_id,
    ).to(device)

    # One generate() call per row, with a tqdm progress bar.
    df['wrong_generation'] = df['prefix'].progress_apply(
        lambda x: generate(x, model, tokenizer, device)
    )

    df.to_csv('/mnt/home/dongkeun/L2U/outputs/doc_id_reset_count_generation.csv', index=False)

