import os
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from nltk.tokenize import sent_tokenize

# Folder holding the generated text files; the results CSV is written here too.
llm_gen_folder = 'llm_gen'

# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing when the model is moved to `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Paired element-wise: entry i of each list is substituted into the
# 'stage1-step{}-tokens{}B' revision name of checkpoint i.
step_range = [150, 600, 1000, 3000, 7000, 19000, 51000, 133000, 352000, 928646]
tokens_range = [1, 3, 5, 13, 30, 80, 214, 558, 1477, 3896]

# One generation file is expected per (prompt index, trial index) pair.
n_trials = 5
prompts = ['Why not', 'Are you', 'This is', 'Alice was', 'Bob went']

# Load the CoLA sequence classifier and its tokenizer from the same
# checkpoint, and place the model on the configured device.
_COLA_CHECKPOINT = "yiiino/deberta-v3-large-cola"
model_cola = AutoModelForSequenceClassification.from_pretrained(_COLA_CHECKPOINT).to(device)
tokenizer_cola = AutoTokenizer.from_pretrained(_COLA_CHECKPOINT)

def evaluate_grammaticality_cola(text):
    """Score the grammaticality of ``text`` with the CoLA classifier.

    The text is split into sentences with ``sent_tokenize``. Each sentence is
    classified on its own by ``model_cola`` and the softmax probability of
    class index 1 is collected. The score is the unweighted mean of those
    per-sentence probabilities.

    Args:
        text: The string to evaluate.

    Returns:
        The mean class-1 probability over the sentences of ``text``, or NaN
        when ``text`` is empty / whitespace-only or yields no sentences.
    """
    # Nothing to score: return NaN explicitly rather than letting
    # np.mean([]) emit a "Mean of empty slice" RuntimeWarning.
    if not text.strip():
        return float('nan')

    p_sents = []
    for sent in sent_tokenize(text):
        tokens = tokenizer_cola(sent, return_tensors='pt').to(device)
        with torch.no_grad():
            outputs = model_cola(tokens['input_ids'])
        # NOTE(review): index 1 is presumably the "acceptable" label of this
        # classifier -- confirm against the model's label mapping.
        p_sents.append(torch.softmax(outputs.logits[0], dim=0).cpu().numpy()[1])

    # Defensive: the sentence splitter produced no sentences.
    if not p_sents:
        return float('nan')
    return np.mean(p_sents)

# Score every generated sample of every checkpoint and write one CSV with a
# row per sample.
# The model name is loop-invariant and is also needed for the output file
# name after the loop, so it is defined once up front.
model_name = 'allenai/OLMo-2-1124-7B'
file_prefix = model_name.replace('/', '_')

# Materialise the (step, tokens) pairs so tqdm knows how many checkpoints
# there are and can show a total / ETA (a bare zip object has no length).
checkpoints = list(zip(step_range, tokens_range))

df_results = []
for step, tokens in tqdm(checkpoints):
    revision = 'stage1-step{}-tokens{}B'.format(step, tokens)

    # Read the generations in (trial, prompt) order; the row order of the
    # per-checkpoint table follows this order.
    texts = []
    for idx_trial in range(n_trials):
        for idx_prompt in range(len(prompts)):
            filename = os.path.join(
                llm_gen_folder,
                '{}_{}_gen_{}_{}.txt'.format(file_prefix, revision, idx_prompt, idx_trial))
            # NOTE(review): the file is read with the platform default
            # encoding -- confirm it matches how the files were written.
            with open(filename, 'r') as f:
                texts.append(f.read())

    cola_values = [evaluate_grammaticality_cola(text) for text in texts]
    # The scalar columns are broadcast against the per-sample score array.
    df_results.append(pd.DataFrame({'model_name': model_name,
                                    'revision': revision,
                                    'step': step,
                                    'tokens': tokens,
                                    'cola_value': np.array(cola_values)}))

# Each per-checkpoint frame keeps its own 0-based index, which is written as
# the first CSV column.
df_results = pd.concat(df_results)
df_results.to_csv(os.path.join(llm_gen_folder, '{}_gen_cola.csv'.format(file_prefix)))