import re
import click
import pandas as pd
from eval.util import (
    # load_model_and_tokenizer,
    # batched_generate,
    parse_number,
    format_example,
    prep_incontext_examples,
    write_results,
)
from text_conditioning import *
from eval.eval_hacks import load_sampler, batched_generate
from utils import read_json, seed_all
from tqdm.auto import tqdm as tqdm
import datasets
from pathlib import Path
import json
# Seed once at import time with a fixed value so repeated runs start from the
# same RNG state. NOTE(review): which RNGs `utils.seed_all` covers (python /
# numpy / torch) is not visible from this file -- confirm in utils.
seed_all(42)


def evaluate_alpacaeval(sampler_factory, eval_set, batch_size):
    """Run batched generation over the instructions of ``eval_set``.

    Args:
        sampler_factory: Forwarded unchanged to ``batched_generate`` as its
            ``bsf`` argument.
        eval_set: Iterable of mappings; only the ``'instruction'`` entry of
            each item is read and used as the prompt.
        batch_size: Forwarded unchanged to ``batched_generate``.

    Returns:
        Whatever ``batched_generate`` returns for the collected prompts.
    """
    instructions = []
    for example in eval_set:
        instructions.append(example['instruction'])

    # Generation settings are fixed for this benchmark: sampling disabled,
    # a 4000 byte cap, and two stop strings that differ only in the number
    # of leading newlines.
    generation_kwargs = {
        "prompts": instructions,
        "bsf": sampler_factory,
        "do_sample": False,
        "max_new_bytes": 4000,
        "batch_size": batch_size,
        "stop_strings": ("\nQuestion: ","\n\nQuestion: "),
    }
    return batched_generate(**generation_kwargs)


@click.command()
@click.option("--model_name_or_path", type=str, default="pile-npt25k")
@click.option("--output_dir", type=str, required=True)
@click.option("--eval_batch_size", type=int, default=1)
@click.option("--start", type=int)
@click.option("--end", type=int)
def main(
    model_name_or_path: str,
    output_dir: str,
    eval_batch_size: int,
    start: int,
    end: int,
):
    """Generate outputs for examples [start:end] of the AlpacaEval eval split.

    Results are written as JSON to ``<output_dir>/out_<start>_<end>.json``.
    If that file already exists the command returns immediately, so a shard
    that has already been produced is not recomputed.

    ``--start`` / ``--end`` may be omitted; they are then ``None``, which
    slices from the beginning / to the end and appears literally as ``None``
    in the output file name.
    """
    out_dir = Path(output_dir)
    # parents=True so a nested, not-yet-existing output path works too.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_file = out_dir / f'out_{start}_{end}.json'
    if out_file.exists():
        return

    # A model is treated as a base model when its name ends with "Base", or
    # when it is a "meta-llama/Llama-3*" name ending in "B".
    is_base = model_name_or_path.endswith("Base") or (
        model_name_or_path.startswith("meta-llama/Llama-3")
        and model_name_or_path.endswith("B")
    )

    if is_base:
        # BytewiseQAFactory / TextConditionedSampler are not imported by name
        # in this file -- presumably they come from the wildcard import of
        # `text_conditioning`; confirm there.
        sampler_factory = BytewiseQAFactory(TextConditionedSampler(model_name_or_path))
    else:
        sampler_factory = load_sampler(model_name_or_path)

    eval_set = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
    print("running evals")
    results = evaluate_alpacaeval(
        sampler_factory,
        list(eval_set)[start:end],
        batch_size=eval_batch_size,
    )

    # The existence of `out_file` is what marks this shard as done (see the
    # early return above), so it must never be visible half-written. Dump to
    # a sibling temp file first and move it into place only after the write
    # has completed.
    tmp_file = out_file.with_name(out_file.name + ".tmp")
    with open(tmp_file, 'w') as f:
        json.dump(results, f)
    tmp_file.replace(out_file)


# Script entry point: run the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()