import os
import numpy as np
import pandas as pd
import torch
import argparse
from llm_brain_asym import compute_logprobs, make_dir
from transformers import AutoModelForCausalLM, AutoTokenizer

# Disable the HF `tokenizers` library's internal parallelism to avoid its
# fork-related warnings when other workers are spawned.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Read by PyTorch's CUDA caching allocator when it is first initialized, so it
# must be set before any CUDA tensor is created (nothing above touches CUDA).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

parser = argparse.ArgumentParser()
# --model is mandatory: without it model_name would be None and the script
# would only fail much later, inside from_pretrained / model_name.replace.
parser.add_argument('--model', type=str, required=True, help='model name')
parser.add_argument('--revision', type=str, help='details of the revision')
parser.add_argument('--device', type=str, default='cuda', help='device (cuda or cpu)')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--seed', type=int, default=12345, help='random seed for the creation of the dataset')
parser.add_argument('--output_folder', type=str, default='arithmetic', help='output folder')

args = parser.parse_args()

model_name = args.model
revision = args.revision  # None -> default branch of the model repository
device = args.device
batch_size = args.batch_size
seed = args.seed
output_folder = args.output_folder

# Pure inference: no autograd graph is needed anywhere in this script.
torch.set_grad_enabled(False)
# Seed the global RNGs so any randomness outside the dataset builder is reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

# NOTE(review): presumably creates output_folder if it does not exist yet —
# see llm_brain_asym.make_dir.
make_dir(output_folder)

def create_equations_minimal_pairs(max_int, n_examples, kind='addition', seed=12345,
                                   errors=(-10, -2, -1, 1, 2, 10)):
    """Build minimal pairs of correct / incorrect two-operand arithmetic equations.

    Each pair shares the same operands ``a`` and ``b``, drawn uniformly from
    ``[0, max_int]``. The "good" equation states the true result. The "bad"
    equation states the true result plus an offset drawn uniformly from
    ``errors``.

    Args:
        max_int: Largest operand value (inclusive).
        n_examples: Number of pairs to generate.
        kind: ``'addition'`` (``a + b``) or ``'multiplication'`` (``a * b``).
        seed: Seed for a private legacy ``np.random.RandomState``. The pairs are
            identical to those produced by ``np.random.seed(seed)`` followed by
            ``np.random.choice`` draws, but the caller's global NumPy RNG state
            is left untouched.
        errors: Offsets that may be added to the true result to form the bad
            equation. None of them should be 0; otherwise a "bad" equation
            could coincide with the good one.

    Returns:
        Tuple ``(equations_good, equations_bad)``: two lists of ``n_examples``
        strings such as ``'3 + 4 = 7'``. Element ``i`` of each list forms
        pair ``i``.

    Raises:
        ValueError: If ``kind`` is not a supported operation.
    """
    # Maps each kind to its printed operator and to how the true result is
    # computed from the two drawn operands.
    operations = {
        'addition': ('+', np.sum),
        'multiplication': ('*', np.prod),
    }
    if kind not in operations:
        raise ValueError('This kind of operation is not implemented.')
    symbol, reduce_fn = operations[kind]
    template = '{} ' + symbol + ' {} = {}'

    # Private legacy generator: same random stream as np.random.seed(seed), so
    # datasets stay byte-identical to earlier runs, without clobbering global state.
    rng = np.random.RandomState(seed)
    equations_good = []
    equations_bad = []
    for _ in range(n_examples):
        # Draw order (operands first, then the error) is part of the dataset
        # definition; changing it would change every generated pair.
        numbers = rng.choice(max_int + 1, size=2)
        error = rng.choice(errors)
        result = reduce_fn(numbers)
        equations_good.append(template.format(*numbers, result))
        equations_bad.append(template.format(*numbers, result + error))
    return equations_good, equations_bad

# Number of minimal pairs generated per operation.
n_examples = 2048
# Pass the CLI --seed through; previously the dataset always used the
# function's default seed, so --seed had no effect on the data.
equations_add_good, equations_add_bad = create_equations_minimal_pairs(
    1000, n_examples, kind='addition', seed=seed)
equations_mult_good, equations_mult_bad = create_equations_minimal_pairs(
    100, n_examples, kind='multiplication', seed=seed)

# Load the tokenizer from the same revision as the model so that both come
# from the same checkpoint of the repository.
tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision)
model = AutoModelForCausalLM.from_pretrained(
    model_name, revision=revision, torch_dtype=torch.float16).to(device)

# Good and bad equations are scored together. The first n_examples entries
# correspond to the good equations and the last n_examples to their paired
# bad equations.
logprob_add = compute_logprobs(model, tokenizer, equations_add_good + equations_add_bad,
                               batch_size=batch_size)
logprob_mult = compute_logprobs(model, tokenizer, equations_mult_good + equations_mult_bad,
                                batch_size=batch_size)

# A pair counts as correct when the true equation gets a strictly higher score
# (ties count as errors).
# NOTE(review): assumes compute_logprobs returns one score per input, in input
# order, as a NumPy-compatible 1-D array — TODO confirm.
accuracy_add = np.mean(logprob_add[:n_examples] > logprob_add[n_examples:])
accuracy_mult = np.mean(logprob_mult[:n_examples] > logprob_mult[n_examples:])

# 'accuracy' is the unweighted mean of the two per-operation accuracies.
df_results = pd.DataFrame([{'model': model_name,
                            'revision': revision,
                            'accuracy': .5 * (accuracy_add + accuracy_mult),
                            'accuracy_add': accuracy_add,
                            'accuracy_mult': accuracy_mult}])
filename = os.path.join(
    output_folder,
    '{}_{}_arithmetic.csv'.format(model_name.replace('/', '_'), revision))
df_results.to_csv(filename)