import os
import numpy as np
import pandas as pd
import glob
import torch
import argparse
from llm_brain_asym import compute_logprobs, home_folder, make_dir
from transformers import AutoModelForCausalLM, AutoTokenizer

# Process-wide environment switches for the tokenizers library and the
# PyTorch CUDA allocator; both are plain string settings read by those
# libraries from the environment.
os.environ.update({
    "TOKENIZERS_PARALLELISM": "false",
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
})

# Command-line interface of the script. The parsed values are copied into
# module-level names that the rest of the script reads directly.
parser = argparse.ArgumentParser()
# --model has no usable default: it is passed straight to `from_pretrained`,
# so require it here and let argparse report a clear usage error instead of
# failing later with `None` as the model name.
parser.add_argument('--model', type=str, required=True, help='model name')
parser.add_argument('--revision', type=str, default='main', help='details of the revision')
parser.add_argument('--device', type=str, default='cuda', help='device (cuda or cpu)')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--output_folder', type=str, default='blimp', help='output folder')

args = parser.parse_args()
model_name = args.model                # model identifier given to from_pretrained
revision = args.revision               # model revision given to from_pretrained
device = args.device                   # target of model.to(...)
batch_size = args.batch_size           # forwarded to compute_logprobs
output_folder = args.output_folder     # directory that receives the result CSV
    
# BLiMP categories to evaluate; a task file is scored only when the
# 'linguistics_term' of its first row appears in this list.
linguistics_terms = [
    'anaphor_agreement',
    'argument_structure',
    'binding',
    'control_raising',
    'determiner_noun_agreement',
    'ellipsis',
    'filler_gap_dependency',
    'irregular_forms',
    'island_effects',
    'npi_licensing',
    'quantifiers',
    's-selection',
    'subject_verb_agreement',
]

# Project helper; presumably creates the folder if it is missing -- confirm
# against llm_brain_asym.make_dir.
make_dir(output_folder)

# The script only scores sentences, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Load the tokenizer and the half-precision model, then move the model to
# the requested device.
# NOTE(review): the tokenizer is loaded without `revision=revision` while the
# model uses it -- confirm this is intentional (e.g. checkpoints that share
# one tokenizer).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, revision=revision, torch_dtype=torch.float16
).to(device)

# BLiMP task files (one .jsonl per task), sorted for a deterministic
# evaluation order.
blimp_folder = os.path.join(home_folder, 'blimp', 'data')
files = sorted(glob.glob(os.path.join(blimp_folder, '*.jsonl')))

# Pad on the right, and fall back to the EOS token as padding for tokenizers
# that do not define a pad token.
tokenizer.padding_side = 'right'
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
        
# Score every selected BLiMP task and collect one result row per task file.
rows = []
for task_path in files:
    task = pd.read_json(task_path, lines=True)

    # Skip tasks whose category (taken from the first row) is not selected;
    # a `None` filter means "evaluate everything".
    if linguistics_terms is not None and task['linguistics_term'][0] not in linguistics_terms:
        continue

    good = task['sentence_good'].to_list()
    bad = task['sentence_bad'].to_list()

    # Score both halves in one call, then split the result back at the
    # boundary between good and bad sentences.
    scores = compute_logprobs(model, tokenizer, good + bad,
                              batch_size=batch_size)
    boundary = len(good)
    good_scores, bad_scores = scores[:boundary], scores[boundary:]

    # Mean of the pairwise "good > bad" comparison.
    # NOTE(review): assumes compute_logprobs returns an array-like that
    # supports element-wise `>` -- confirm against llm_brain_asym.
    rows.append({
        'model': model_name,
        'revision': revision,
        'field': task['field'][0],
        'linguistics_term': task['linguistics_term'][0],
        'UID': task['UID'][0],
        'accuracy': np.mean(good_scores > bad_scores),
    })

df_results = pd.DataFrame(rows).reset_index(drop=True)

# Output name is "<model with '/' replaced by '_'>_<revision>_blimp.csv".
filename = os.path.join(
    output_folder,
    '{}_{}_blimp.csv'.format(model_name.replace('/', '_'), revision),
)

df_results.to_csv(filename)