import argparse
import os

from src.prompt_functions import *
from src.utility import (
    default_batch_size,
    default_openai_model_name,
    default_temperature,
    default_top_p,
    default_prompt_func,
    default_model_max_length,
    prepare_model,
    setup_logging,
    chat_format_general,
)


def query_hf_llm(
    *,
    input_file: str,
    output_file: str,
    tokenizer: PreTrainedTokenizer,
    model: PreTrainedModel,
    batch_size: int,
    temperature: float,
    top_p: float,
    prompt_function: Callable[[Record, bool], list[dict]],
    chat_format: str,
    do_sample: bool,
):
    """
    Query a HF LLM for every record of ``input_file`` not already in ``output_file``.

    The input file is loaded with ``json.load`` and iterated as a list of record
    dicts, each carrying an ``"id"`` key. If the output file already exists it is
    read line by line as JSON Lines (one object with an ``"id"`` key per line);
    records whose id appears there are skipped, so an interrupted run can be
    resumed. The remaining ids are passed, sorted, to ``get_hf_llm_result``
    together with ``output_file`` (which presumably appends the new results there).

    :param input_file: Path to input file containing records to query.
    :param output_file: Output file path (JSON Lines).
    :param tokenizer: HF tokenizer.
    :param model: HF model.
    :param batch_size: Batch size used when querying the model.
    :param temperature: Temperature used for generation.
    :param top_p: top_p used for generation.
    :param prompt_function: Prompt function to use; builds the messages for a record.
    :param chat_format: Type of the model, use chat template or not.
    :param do_sample: Do sample or not when doing generation.
    :return: None.
    """
    print(f"processing {input_file}")
    with open(input_file, "r") as fin:
        records = json.load(fin)

    # Index records by id; a duplicated id keeps the last occurrence.
    records_with_key = {record["id"]: Record.from_dict(record) for record in records}

    # Only the ids of already-written results are needed to decide what to skip.
    exist_keys = set()
    if os.path.exists(output_file):
        with open(output_file, "r") as fin:
            for line in fin:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing the whole resume on json.loads("").
                if not line.strip():
                    continue
                exist_keys.add(json.loads(line)["id"])

    all_keys = records_with_key.keys()
    left_keys = sorted(set(all_keys) - exist_keys)
    print(
        f"total record:{len(all_keys)}, existing record:{len(exist_keys)}, record to query:{len(left_keys)}"
    )

    get_hf_llm_result(
        output_file=output_file,
        record_keys=left_keys,
        records_with_key=records_with_key,
        model=model,
        tokenizer=tokenizer,
        prompt_func=prompt_function,
        batch_size=batch_size,
        temperature=temperature,
        top_p=top_p,
        chat_format=chat_format,
        do_sample=do_sample,
    )


def parse_args():
    """Define this script's command-line options and parse ``sys.argv``.

    Options are declared as ``(flag, add_argument keyword options)`` pairs and
    registered in order, so ``--help`` lists them in the same sequence.

    :return: ``argparse.Namespace`` holding the parsed options.
    """
    option_table = [
        ("--input_files", dict(required=True, nargs="+", help="Input files")),
        (
            "--output_file_suffix",
            dict(
                default=default_openai_model_name,
                type=str,
                help="Suffix to append to input file for generating output file name",
            ),
        ),
        ("--output_dir", dict(default=".", type=str, help="Output dir")),
        ("--log_file", dict(type=str, help="Log file path")),
        (
            "--model_name",
            dict(type=str, help="HuggingFace compatible LLM model name"),
        ),
        (
            "--batch_size",
            dict(
                type=int,
                default=default_batch_size,
                help="Batch size when sending request",
            ),
        ),
        (
            "--temperature",
            dict(
                type=float,
                default=default_temperature,
                help="Temperature when calling LLM",
            ),
        ),
        (
            "--top_p",
            dict(type=float, default=default_top_p, help="top_p when calling LLM"),
        ),
        (
            "--prompt_function",
            dict(
                type=str,
                default=default_prompt_func,
                help="Function to get prompt when calling LLM",
            ),
        ),
        (
            "--model_max_length",
            dict(
                type=int,
                default=default_model_max_length,
                help="Model max length",
            ),
        ),
        ("--device", dict(type=str, help="Device to use")),
        (
            "--chat_format",
            dict(type=str, default=chat_format_general, help="Type of the model"),
        ),
        (
            "--no_sample",
            dict(action="store_true", help="Do not do sample in generation"),
        ),
        ("--bf16", dict(action="store_true", help="Use BF16")),
        ("--fp16", dict(action="store_true", help="Use FP16")),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


if __name__ == "__main__":
    # Query HF model with one step, once per input file.
    setup_logging()
    logger = logging.getLogger()
    args = parse_args()
    print(f"Running {os.path.basename(__file__)} with args: {args}")
    do_sample = not args.no_sample

    # Resolve the prompt function by name in this module's namespace, which
    # presumably holds the prompt functions star-imported from
    # src.prompt_functions. At module level locals() is globals(), so use
    # globals() explicitly, and reject names that are not callable.
    prompt_function = globals().get(args.prompt_function)
    if not callable(prompt_function):
        logger.error(f"function {args.prompt_function} not found or not callable")
        # SystemExit works even when the site-provided `exit` builtin is absent.
        raise SystemExit(1)

    # bf16 takes precedence when both precision flags are given.
    if args.bf16:
        if args.fp16:
            logger.warning("Both bf16 and fp16 are set, use bf16")
        torch_dtype = torch.bfloat16
    elif args.fp16:
        torch_dtype = torch.float16
    else:
        torch_dtype = None

    model, tokenizer = prepare_model(
        args.model_name, args.model_max_length, args.device, torch_dtype
    )

    for input_file in args.input_files:
        output_file = os.path.join(
            args.output_dir,
            os.path.basename(input_file) + "." + args.output_file_suffix,
        )
        # Default to a per-output log file unless one shared log file was given.
        log_file = args.log_file if args.log_file is not None else output_file + ".log"
        file_handler = logging.FileHandler(log_file)
        logger.addHandler(file_handler)
        try:
            query_hf_llm(
                input_file=input_file,
                output_file=output_file,
                tokenizer=tokenizer,
                model=model,
                batch_size=args.batch_size,
                temperature=args.temperature,
                top_p=args.top_p,
                prompt_function=prompt_function,
                chat_format=args.chat_format,
                do_sample=do_sample,
            )
        finally:
            # removeHandler does not close the handler; close it explicitly so
            # each input file does not leak an open log-file descriptor.
            logger.removeHandler(file_handler)
            file_handler.close()
