import logging
import sys
import functools

# Send log records to stdout at INFO level, prefixed with a timestamp, the
# logger name and the level.
# NOTE(review): this call sits above the bof4 imports, presumably so that any
# log output emitted while those modules are imported is already formatted
# this way — confirm before moving it.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format="%(asctime)s, %(name)s:%(levelname)s: %(message)s",
    datefmt="%d.%m.%y %H:%M:%S",
)

from bof4.evaluation.evaluate_quant import (
    evaluate_quantizers,
    evaluate_quantized_model,
    Benchmarks,
)
from bof4.quantization import *
from bof4.evaluation.harness import DEFAULT_HARNESS_TASKS
from bof4.evaluation.evalplus import DEFAULT_EVALPLUS_TASKS

# Lookup table from each benchmark's ``.value`` (the string accepted by the
# ``--benchmarks`` option) to the corresponding ``Benchmarks`` member.
BENCHMARKS_DICT = {member.value: member for member in Benchmarks}
# Benchmark suites that are run when ``--benchmarks`` is not given.
DEFAULT_BENCHMARKS = [Benchmarks.ERRORS, Benchmarks.HARNESS]


def main():
    """Parse the command line and run the quantizer evaluation.

    Reads the options from ``sys.argv``, resolves the benchmark suites and
    tasks to run, loads every quantizer given with ``-q`` via
    ``load_from_file`` and hands everything to ``evaluate_quantizers``.
    A ``None`` entry is appended to the quantizer list when ``-u`` is given
    or when no quantizer was specified; it is reported as "unquantized".

    Raises:
        ValueError: If ``--benchmarks`` names a suite that is not a key of
            ``BENCHMARKS_DICT``.
        FileNotFoundError: If the file passed to ``--chat-template`` does not
            exist.
    """
    _logger = logging.getLogger(__name__)

    import argparse

    # ``torch.bfloat16`` is used below. Import torch explicitly instead of
    # relying on ``from bof4.quantization import *`` happening to export it.
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-q",
        "--quantizer",
        nargs="*",
        help="Paths to the evaluated quantizer(s)",
    )

    parser.add_argument(
        "-u",
        "--unquantized",
        action="store_true",
        help="Additionally evaluate the model in full precision",
    )

    parser.add_argument(
        "-o", "--out", default="quant_results", help="The output folder"
    )

    parser.add_argument(
        "-m",
        "--model",
        default="meta-llama/Llama-3.2-3B",
        help="The model to evaluate (Hugging Face ID or path to directory)",
    )
    parser.add_argument(
        "--benchmarks",
        nargs="*",
        help=f"The benchmark suites to use for the evaluation. Available: {list(BENCHMARKS_DICT.keys())}",
    )
    parser.add_argument(
        "-t",
        "--harness-tasks",
        nargs="*",
        help=f"The tasks to evaluate with the LLM evaluation harness. Default: {DEFAULT_HARNESS_TASKS}",
    )
    parser.add_argument(
        "-p",
        "--evalplus-tasks",
        nargs="*",
        help=f"The tasks to evaluate with the evalplus library. Default: {DEFAULT_EVALPLUS_TASKS}",
    )
    parser.add_argument("--chat-template", default=None)
    parser.add_argument(
        "--merge-results",
        action="store_true",
        help="Merge new results into existing file",
    )
    parser.add_argument(
        "--disable-flash-attn",
        action="store_true",
        help="Disable Flash Attention",
        default=False,
    )

    args = parser.parse_args()

    # ``-q`` is optional; treat a missing option like an empty list of paths.
    quantizer_paths = args.quantizer or []

    if args.benchmarks is None:
        selected_benchmarks = DEFAULT_BENCHMARKS
    else:
        try:
            selected_benchmarks = [
                BENCHMARKS_DICT[name] for name in args.benchmarks
            ]
        except KeyError as e:
            # Keep the ValueError callers may rely on, but chain the original
            # lookup failure for easier debugging.
            raise ValueError(
                f"One of the provided benchmarks does not exist: {e.args[0]}"
            ) from e

    chat_template = None
    if args.chat_template is not None:
        try:
            with open(args.chat_template, "r", encoding="utf-8") as f:
                raw_template = f.read()
        except FileNotFoundError:
            # Fail right away: continuing would leave no template to use and
            # previously crashed later with an unrelated UnboundLocalError.
            _logger.error(f"Could not find template file at {args.chat_template}")
            raise
        # Collapse the template to a single line and drop leading indentation.
        chat_template = (
            raw_template.replace("\n", "").replace("\r", "").lstrip("\t ")
        )
        _logger.info(f"Using chat template: {args.chat_template}")

    # NOTE(review): ``load_from_file`` is expected to come from the star
    # import of ``bof4.quantization`` — confirm it is exported there.
    quantizers = [load_from_file(path) for path in quantizer_paths]
    if args.unquantized or not quantizers:
        # ``None`` stands for the unquantized model.
        quantizers = quantizers + [None]

    quantizer_names = [
        getattr(q, "name", "unnamed") if q is not None else "unquantized"
        for q in quantizers
    ]
    _logger.info(
        f"Evaluating the following quantizers:\n{quantizer_names}",
    )
    _logger.info(f"Evaluating benchmarks {selected_benchmarks}")

    # Resolve the task lists once so the log output and the actual run agree.
    harness_tasks = args.harness_tasks or DEFAULT_HARNESS_TASKS
    evalplus_tasks = args.evalplus_tasks or DEFAULT_EVALPLUS_TASKS
    if Benchmarks.HARNESS in selected_benchmarks:
        _logger.info(f"Evaluating harness tasks {harness_tasks}")
    if Benchmarks.EVALPLUS in selected_benchmarks:
        _logger.info(f"Evaluating evalplus tasks {evalplus_tasks}")

    evaluate_quantizers(
        args.model,
        functools.partial(
            evaluate_quantized_model,
            benchmarks=selected_benchmarks,
            harness_benchmarks=harness_tasks,
            evalplus_benchmarks=evalplus_tasks,
            chat_template=chat_template,
            repo_id=args.model,
        ),
        quantizers,
        args.out,
        model_dtype=torch.bfloat16,
        disable_flash_attn=args.disable_flash_attn,
        skip_existing=not args.merge_results,
        merge_results=args.merge_results,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
