import os

from evolution import evolution
from utils import set_seed
from args import parse_args
from llm_client import llm_init, LLMClient
from run_bbh import eval_task
import json
from utils import read_lines, setup_log
import random

if __name__ == "__main__":
    args = parse_args()
    task = args.task
    set_seed(args.seed)
    if args.client:
        client = LLMClient(args.client_host)
        llm_config = {}
    else:
        client = None
        llm_config = llm_init(f"./auth.yaml", args.llm_type, args.setting)

    out_path = args.output
    logger = setup_log(os.path.join(out_path, f"eval.log"))
    logger.info("\n\t" + "\n\t".join(f"{k} = {v}" for k, v in vars(args).items()))
    logger.info("=" * 50)

    task_data = json.load(open("data/%s.json" % task))["examples"]
    sample_num = args.sample_num
    dev_data = random.sample(task_data, sample_num)
    test_data = [i for i in task_data if i not in dev_data]
    model = "turbo" if "turbo" in args.llm_type else "davinci"
    task_prompt = open("optimize_prompt/%s.txt" % task, "r").read()
    prompts = [
        "Let's be realistic and examine this analytically, analyzing the data carefully and precisely, step by step until the evidence leads to the solution. "
    ]
    # prompts = read_lines("prompts.txt")[:2]
    with open(os.path.join(args.output, "acc.txt"), "a+") as f:
        for p in prompts:
            test_acc = eval_task(
                task=task,
                task_prompt=task_prompt,
                cot_prompt=p,
                eval_data=test_data,
                client=client,
                model_index=model,
                logger=logger,
                consis=args.consis,
                demon=args.demon,
                **llm_config,
            )
            dev_acc = eval_task(
                task=task,
                task_prompt=task_prompt,
                cot_prompt=p,
                eval_data=dev_data,
                client=client,
                model_index=model,
                logger=logger,
                consis=args.consis,
                demon=args.demon,
                **llm_config,
            )

            test_correct = test_acc * len(test_data)
            dev_correct = dev_acc * len(dev_data)
            all_acc = (test_correct + dev_correct) / (len(task_data))
            f.write(f"{p}\t{dev_acc}\t{test_acc}\t{all_acc}\n")
            f.flush()
