import argparse
import json
import os
import numpy as np  # type: ignore
import random
import torch  # type: ignore
import time

from utils import load_jsonl, write_jsonl, batch_decode_vllm, init_seed, load_llm


def eval_subject(args, subject, llm, test_records, kk_proc, exist_result_test_records):
    """Evaluate one subject and return its correctness flags, accuracy and records.

    Records found in ``exist_result_test_records`` are copied over the head of
    ``test_records`` and their stored ``"correct"`` value is reused; the
    remaining records are prompted, answered and scored. Every processed record
    is annotated in place with ``"predicts"``, ``"labels"``, ``"correct"``,
    ``"response"`` and ``"prompts"``, and the bulky dataset fields are dropped.

    Args:
        args: Namespace providing ``ntrain``, ``model``, ``use_vllm``,
            ``infer_mode`` and ``batch_size``.
        subject: Name of the subject, used only in log output.
        llm: Model handle; ``llm.query`` is called when batched decoding is
            not selected, otherwise it is handed to ``batch_decode_vllm``.
        test_records: List of record dicts; mutated in place.
        kk_proc: Object exposing ``gen_test_prompt`` and ``_parse_cot_eval``.
        exist_result_test_records: Previously saved result records, or ``None``.

    Returns:
        Tuple ``(cors, acc, test_records)``: a NumPy array of correctness
        values, their mean, and the (mutated) ``test_records`` list.
    """
    # Dataset fields dropped from every record to keep the output file small.
    bulky_fields = (
        "knight_knave",
        "cot_foot",
        "cot_steps",
        "cot_head",
        "cot_repeat_steps",
        "statements",
    )

    def _strip(record):
        for field in bulky_fields:
            record.pop(field, None)

    # Re-use whatever a previous run already produced.
    previous = [] if exist_result_test_records is None else exist_result_test_records
    resume_at = len(previous)
    if exist_result_test_records is not None:
        print(f"find existing {resume_at} records in {subject}")

    correctness = []
    for pos, saved in enumerate(previous):
        test_records[pos] = saved
        correctness.append(saved["correct"])
        _strip(saved)

    started = time.time()

    # Build the prompt/label pair for every record that still needs evaluating.
    prompts, labels = [], []
    for pos in range(resume_at, len(test_records)):
        prompt, label = kk_proc.gen_test_prompt(
            args.ntrain, test_records, pos, args.model
        )
        if not prompts:
            # Show the first prompt once as a sanity check.
            print(prompt)
        prompts.append(prompt)
        labels.append(label)

    # Generate the responses, batched when vLLM generation is selected.
    batched = args.use_vllm and args.infer_mode == "generation"
    if not batched:
        responses = []
        for prompt, label in zip(prompts, labels):
            response, _ = llm.query(prompt, choices=None)
            responses.append(response)
            print("\nresponse\n", response)
            print("\nlabel\n", label)
    else:
        responses = batch_decode_vllm(llm, prompts, batch_size=args.batch_size)

    # Score each response and attach the outcome to its record.
    for offset, (prompt, label, response) in enumerate(
        zip(prompts, labels, responses)
    ):
        record = test_records[resume_at + offset]
        cor, parsed_pred, reformat_gold_conditions = kk_proc._parse_cot_eval(
            response, label, args.model
        )

        print(prompt)
        print("\nresponse\n", response)
        print("\npredict\n", parsed_pred)
        print("\nlabel\n", reformat_gold_conditions)
        print("\ncorrect\n", cor)

        correctness.append(cor)
        record.update(
            {
                "predicts": parsed_pred,
                "labels": reformat_gold_conditions,
                "correct": cor,
                "response": response,
                "prompts": prompt,
            }
        )
        _strip(record)

    elapsed = time.time() - started
    acc = np.mean(correctness)
    cors = np.array(correctness)

    print(f"Average accuracy {acc:.3f} - {subject}")
    print(f"Total evaluation time: {elapsed:.2f} seconds")

    return cors, acc, test_records


def main(args):
    """Evaluate the configured model on every selected subject and save results.

    Derives the output locations from ``args``, picks up any result files
    already on disk, evaluates each subject with ``eval_subject``, writes the
    per-subject records as JSONL and stores the accuracies in a JSON summary.

    Args:
        args: Parsed command-line namespace. ``args.config`` is extended in
            place with suffixes describing the run (token budget, CoT, split,
            problem type).

    Side effects:
        Creates the output folder, writes one ``<subject>.jsonl`` per evaluated
        subject and, when at least one subject was evaluated, rewrites
        ``result_<config>.json``.
    """
    # The run directory is named after the last two path components of the model.
    last_two_layers = "/".join(args.model.split("/")[-2:])
    if args.no_linebreak:
        last_two_layers += "_no_linebreak"

    prefix = os.path.join(
        args.save_dir, "{}_{}shot".format(last_two_layers, args.ntrain)
    )

    if args.infer_mode == "generation":
        args.config += f"_token{args.max_token}"

    if args.cot:
        args.config += "_cot"

    args.config += f"_{args.split}"

    if args.problem_type != "clean":
        args.config += f"_{args.problem_type}"

    print("args.config", args.config, "prefix", prefix)
    output_folder = os.path.join(prefix, args.config)
    print(output_folder)

    results_file = os.path.join(prefix, f"result_{args.config}.json")

    os.makedirs(output_folder, exist_ok=True)

    from dataset.kk import KKProcessor

    kk_proc = KKProcessor(cot=args.cot, no_linebreak=args.no_linebreak)

    test_folder = f"{args.data_dir}/{args.split}/{args.problem_type}"
    if args.split == "test":
        if args.eval_train > 0:
            subjects = [f"people{args.eval_train}_num100"]
        else:
            subjects = [f"people{count}_num100" for count in range(2, 9)]
    else:  # train
        if args.eval_train > 0:
            if args.eval_train == 2:
                subjects = ["people2_num200"]
            else:
                subjects = [f"people{args.eval_train}_num1000"]
        else:
            subjects = []

    all_cors = []
    results = {"subject": {}}
    if os.path.isfile(results_file):
        with open(results_file, "r", encoding="utf-8") as file:
            results = json.load(file)
        print("Previous Results loaded successfully.")
        print(results)
    else:
        print("Previous Results does not exist.")
    # A previously saved summary may lack the per-subject map; make sure it
    # exists so the assignment below cannot raise KeyError.
    results.setdefault("subject", {})

    llm = None  # delay llm loading to save time

    for subject in subjects:
        test_output_file = os.path.join(
            output_folder, "{}.jsonl".format(subject))
        exist_result_test_records = None

        if os.path.exists(test_output_file):
            exist_result_test_records = load_jsonl(test_output_file)

        test_file_path = os.path.join(test_folder, f"{subject}.jsonl")
        test_records = load_jsonl(test_file_path)
        if args.limit is not None:
            test_records = test_records[: args.limit]
            # Nothing left to do when the saved results already cover the limit.
            if (exist_result_test_records is not None) and (args.limit <= len(exist_result_test_records)):
                print(
                    "skip",
                    subject,
                    " : because limit=",
                    args.limit,
                    "smaller than or equal to exisit=",
                    len(exist_result_test_records),
                )
                continue

        if llm is None:
            llm = load_llm(args)
        cors, acc, result_test_records = eval_subject(
            args, subject, llm, test_records, kk_proc, exist_result_test_records
        )

        all_cors.append(cors)

        write_jsonl(test_output_file, result_test_records)

        # Store a built-in float so the summary never depends on a NumPy
        # scalar type being JSON-serialisable.
        results["subject"][subject] = float(acc)

    # NOTE(review): the weighted accuracy only covers subjects evaluated in this
    # invocation; subjects skipped above are not included -- confirm intended.
    if len(all_cors) > 0:
        weighted_acc = float(np.mean(np.concatenate(all_cors)))
        results["weighted_accuracy"] = weighted_acc
        print("Average accuracy: {:.3f}".format(weighted_acc))

        # Write with the same encoding that is used when the file is read back.
        with open(results_file, "w", encoding="utf-8") as f:
            json.dump(results, f)


if __name__ == "__main__":
    # Command-line entry point: parse the options, seed the RNGs via
    # init_seed() and run the evaluation.
    parser = argparse.ArgumentParser()
    # Number of shots; passed to prompt generation and used in the output path.
    parser.add_argument("--ntrain", "-k", type=int, default=0)
    # Root folder of the input data (<data_dir>/<split>/<problem_type>).
    parser.add_argument("--data_dir", "-d", type=str, default="data")
    # Root folder under which results are written.
    parser.add_argument("--save_dir", "-s", type=str, default="result")
    # Model identifier/path. Required: main() splits it on "/" to name the
    # output directory, so a missing value would otherwise fail there with an
    # AttributeError instead of a usage message.
    parser.add_argument("--model", "-m", type=str, required=True)
    # Not read in this file; presumably consumed by load_llm -- TODO confirm.
    parser.add_argument("--arch", type=str, default=None)
    # Base name of the run configuration; main() appends descriptive suffixes.
    parser.add_argument("--config", "-c", type=str, default="")

    parser.add_argument(
        "--infer_mode", type=str, default="generation", choices=["generation"]
    )
    # Appears in the config name; presumably the generation budget used by
    # load_llm -- TODO confirm.
    parser.add_argument("--max_token", type=int, default=1024)
    # Evaluate only the first LIMIT records of each subject.
    parser.add_argument("--limit", type=int, default=None)
    parser.add_argument("--cot", action="store_true")
    parser.add_argument("--no_linebreak", action="store_true")
    # Use batched decoding through batch_decode_vllm instead of llm.query.
    parser.add_argument("--use_vllm", action="store_true")

    # Batch size handed to batch_decode_vllm.
    parser.add_argument("--batch_size", type=int, default=4)

    # "test" selects the test subjects; any other value takes the train branch.
    parser.add_argument("--split", type=str, default="test")
    # When > 0, evaluate only the subject with this number of people.
    parser.add_argument("--eval_train", type=int, default=0)
    # Sub-folder of the split; anything but "clean" is appended to the config.
    parser.add_argument("--problem_type", type=str, default="clean")

    args = parser.parse_args()
    init_seed()
    main(args)
