import argparse
import os

from src.prompt_functions import *
from src.utility import (
    default_batch_size,
    default_temperature,
    default_top_p,
    default_prompt_func,
    setup_logging,
    default_gemini_model_name,
)


def query_gemini(
    *,
    input_file: str,
    output_file: str,
    model_name: str,
    batch_size: int,
    temperature: float,
    top_p: float,
    gemini_client: genai.Client,
    prompt_function: Callable[[Record, bool], list[dict]],
):
    """
    Query the Gemini API for every record in the input file that is not yet in the output file.

    If the output file already exists, it is read line by line as JSON objects and every
    record whose "id" already appears there is skipped, so an interrupted run can be resumed.
    Blank lines in the output file are ignored.

    :param input_file: Path to a JSON file whose top-level value is an iterable of record
        dicts, each with an "id" key (converted with ``Record.from_dict``).
    :param output_file: Output file path; one JSON object (with an "id" key) per line.
    :param model_name: Gemini model name.
    :param batch_size: Batch size when querying Gemini.
    :param temperature: Temperature used when calling Gemini.
    :param top_p: top_p used when calling Gemini.
    :param gemini_client: Gemini client to use.
    :param prompt_function: Prompt function used to build the request for each record.
    """
    print(f"processing {input_file}")
    with open(input_file, "r") as fin:
        records = json.load(fin)

    # Index records by id; if an id is duplicated, the later record wins.
    records_with_key = {record["id"]: Record.from_dict(record) for record in records}

    exist_results = {}
    if os.path.exists(output_file):
        with open(output_file, "r") as fin:
            for line in fin:
                # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
                if not line.strip():
                    continue
                obj = json.loads(line)
                exist_results[obj["id"]] = obj

    # Key views support set difference directly; sorting gives a deterministic query order
    # (ids are assumed to be mutually comparable, e.g. all str or all int).
    left_keys = sorted(records_with_key.keys() - exist_results.keys())
    print(
        f"total record:{len(records_with_key)}, existing record:{len(exist_results)}, record to query:{len(left_keys)}"
    )

    # NOTE(review): get_gemini_result is defined elsewhere; presumably it writes results for
    # left_keys to output_file — confirm against its implementation.
    get_gemini_result(
        output_file=output_file,
        record_keys=left_keys,
        records_with_key=records_with_key,
        model_name=model_name,
        prompt_func=prompt_function,
        gemini_client=gemini_client,
        batch_size=batch_size,
        temperature=temperature,
        top_p=top_p,
    )


def parse_args(argv=None):
    """
    Parse the command-line arguments of the Gemini query script.

    :param argv: Argument list to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``, matching the previous zero-argument behavior.
    :return: ``argparse.Namespace`` holding the parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_files", required=True, nargs="+", help="Input files")
    parser.add_argument(
        "--output_dir",
        default=".",
        type=str,
        help="Output dir",
    )
    parser.add_argument(
        "--output_file_suffix",
        default=default_gemini_model_name,
        type=str,
        help="Suffix appended (after a '.') to each input file's base name to form the output file name",
    )
    parser.add_argument(
        "--log_file",
        type=str,
        help="Log file path (default: '<output file>.log' for each input file)",
    )
    parser.add_argument(
        "--model_name",
        default=default_gemini_model_name,
        type=str,
        help="Gemini model name",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=default_batch_size,
        help="Batch size when sending request",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=default_temperature,
        help="Temperature when calling Gemini API",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=default_top_p,
        help="top_p when calling Gemini API",
    )
    parser.add_argument(
        "--prompt_function",
        type=str,
        default=default_prompt_func,
        help="Name of the prompt function used to build each Gemini request",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: call Gemini with one step for each input file.
    setup_logging()
    logger = logging.getLogger()
    args = parse_args()
    print(f"Running {os.path.basename(__file__)} with args: {args}")

    # Resolve the prompt function by name among this module's globals (which include the
    # names star-imported from src.prompt_functions). Validate it before creating the API
    # client so a typo fails fast, and log only the requested name, not the whole namespace.
    prompt_function = globals().get(args.prompt_function)
    if not callable(prompt_function):
        logger.error(f"prompt function {args.prompt_function!r} not found or not callable")
        raise SystemExit(1)

    gemini_client = genai.Client()

    for input_file in args.input_files:
        output_file = os.path.join(
            args.output_dir,
            os.path.basename(input_file) + "." + args.output_file_suffix,
        )
        # Default to a per-input log file next to the output file.
        log_file = args.log_file if args.log_file is not None else output_file + ".log"
        file_handler = logging.FileHandler(log_file)
        logger.addHandler(file_handler)
        try:
            query_gemini(
                input_file=input_file,
                output_file=output_file,
                model_name=args.model_name,
                batch_size=args.batch_size,
                temperature=args.temperature,
                top_p=args.top_p,
                gemini_client=gemini_client,
                prompt_function=prompt_function,
            )
        finally:
            # Detach AND close the handler so its file descriptor is not leaked across
            # input files, even when query_gemini raises.
            logger.removeHandler(file_handler)
            file_handler.close()
