import os
import numpy as np
import pandas as pd
import glob
import torch
import argparse
from llm_brain_asym import compute_logprobs, home_folder, make_dir
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


def parse_args():
    """Parse the command-line options of the Zorro evaluation script."""
    parser = argparse.ArgumentParser()
    # Every later step (tokenizer, model, output filename) needs a model name.
    parser.add_argument('--model', type=str, required=True, help='model name')
    parser.add_argument('--revision', type=str, help='details of the revision')
    parser.add_argument('--device', type=str, default='cuda', help='device (cuda or cpu)')
    parser.add_argument('--batch_size', type=int, default=128, help='batch size')
    parser.add_argument('--output_folder', type=str, default='zorro', help='output folder')
    return parser.parse_args()


def load_sentences(path):
    """Read one Zorro sentence file and normalise each line.

    For every non-blank line: trailing whitespace is stripped, every " ."
    is collapsed to "." and the first character is upper-cased. Blank (or
    whitespace-only) lines are skipped, since they carry no sentence and
    would otherwise break the capitalisation step.

    Args:
        path: Path to a UTF-8 text file with one sentence per line.

    Returns:
        The list of normalised sentences, in file order.
    """
    sentences = []
    with open(path, 'r', encoding='utf-8') as file:
        for raw in file:
            line = raw.rstrip()
            if not line:
                continue
            line = line.replace(' .', '.')
            sentences.append(line[0].upper() + line[1:])
    return sentences


def pair_accuracy(logprobs):
    """Return the fraction of consecutive pairs whose second item scores higher.

    Items are grouped as (0, 1), (2, 3), ...; a pair counts as correct when
    the odd-indexed score is strictly greater than the even-indexed one
    (ties count as wrong). Presumably the second sentence of each Zorro pair
    is the grammatical one -- confirm against the dataset layout.

    Args:
        logprobs: Sequence, ndarray or torch tensor of per-sentence scores.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if there are no scores or an odd number of them.
    """
    if torch.is_tensor(logprobs):
        # Move off the GPU and out of half precision before comparing.
        logprobs = logprobs.detach().float().cpu().numpy()
    # np.asarray guards against a plain list, where `list > list` would be a
    # single lexicographic comparison instead of an element-wise one.
    scores = np.asarray(logprobs, dtype=float)
    if scores.ndim == 0 or len(scores) == 0 or len(scores) % 2:
        raise ValueError(
            'expected a non-empty, even number of scores (minimal pairs), '
            f'got shape {scores.shape}')
    return float(np.mean(scores[1::2] > scores[::2]))


def main():
    """Score one model/revision on every Zorro task and write a CSV of accuracies."""
    args = parse_args()
    model_name = args.model
    revision = args.revision
    device = args.device
    output_folder = args.output_folder

    torch.set_grad_enabled(False)

    make_dir(output_folder)

    # Half-precision matmuls are unsupported or very slow on CPU in many
    # torch builds; keep float16 for accelerators as before.
    dtype = torch.float32 if device == 'cpu' else torch.float16
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, revision=revision, torch_dtype=dtype).to(device)

    zorro_folder = os.path.join(home_folder, 'Zorro', 'sentences', 'babyberta')
    filenames = sorted(glob.glob(os.path.join(zorro_folder, '*.txt')))
    if not filenames:
        raise FileNotFoundError(f'No Zorro *.txt files found in {zorro_folder}')

    tokenizer.padding_side = 'right'
    if tokenizer.pad_token_id is None:
        # Batched scoring needs a pad token; fall back to EOS like many causal LMs.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    rows = []
    for filename in filenames:
        sentences = load_sentences(filename)
        # Validate pairing before spending GPU time on the file.
        if not sentences or len(sentences) % 2:
            raise ValueError(
                f'{filename}: expected a non-empty, even number of sentences '
                f'(minimal pairs), got {len(sentences)}')

        logprobs = compute_logprobs(model, tokenizer, sentences,
                                    batch_size=args.batch_size)
        rows.append({'model': model_name,
                     'revision': revision,
                     'task': os.path.splitext(os.path.basename(filename))[0],
                     'accuracy': pair_accuracy(logprobs)})

    df_results = pd.DataFrame(rows)

    # Note: a missing --revision is rendered as "None" in the filename.
    out_name = f"{model_name.replace('/', '_')}_{revision}_zorro.csv"
    df_results.to_csv(os.path.join(output_folder, out_name))


if __name__ == '__main__':
    main()