"""
Step-1: Perform Text2SVG task
"""

import json
import logging
import os
from argparse import ArgumentParser
from pathlib import Path
from time import time

from models import GenModelOpenAI, GenModelVllm
from render import render_entry
from utils import is_container_env, read_jsonl, setup_logger, write_jsonl

# Named logger shared by every function in this module.
logger = logging.getLogger("text2svg")


# NOTE(review): presumably disables HuggingFace tokenizers' internal parallelism
# (and its fork warning) -- confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# File name of the evaluation set; perform_inference() looks it up inside the
# "data" directory that sits next to this script.
SVG_DATA_FILE = "text2svg_eval_n164.jsonl"


def build_prompt(sample, question_mode, mode):
    """Build the model prompt for one evaluation sample.

    Args:
        sample: Mapping for one sample. Must contain ``"svg_description"`` when
            ``question_mode == "description"`` and ``"svg_question"`` when
            ``question_mode == "question"``.
        question_mode: ``"description"`` wraps the sample description in a
            fixed instruction; ``"question"`` uses the sample's own question.
        mode: ``"base"`` wraps the question in a ``Question:/Answer:`` template
            that already opens the ```` ```xml ```` fence; ``"chat"`` returns
            the question as-is.

    Returns:
        A ``(question_str, raw_question_str)`` tuple: the full prompt
        (including the output-format instruction and any mode template) and
        the bare question without those additions.

    Raises:
        ValueError: If ``question_mode`` or ``mode`` is not a supported value.
    """
    if question_mode == "description":
        raw_question_str = f"""\
Please write an SVG code that will render an image matching the following description:

<description>
{sample["svg_description"]}
</description>"""
    elif question_mode == "question":
        raw_question_str = sample["svg_question"]
    else:
        raise ValueError(f"{question_mode = }")

    question_str = f"""\
{raw_question_str}

Please write the SVG code in this format:
```xml
<!-- YOUR CODE HERE -->
```"""

    if mode == "base":
        # The prompt ends with an opened code fence so that a base model
        # continues directly with the SVG source.
        question_str = f"""\
Question:
{question_str}

Answer:
```xml
"""
    elif mode != "chat":
        # Report the offending argument (``mode``), not ``question_mode``.
        raise ValueError(f"{mode = }")

    return question_str, raw_question_str


def postprocess_generation(generation, mode):
    """Return the completion, re-adding the opening fence in base mode.

    In ``"base"`` mode the ```` ```xml ```` fence is prepended to
    ``generation``; in any other mode ``generation`` is returned unchanged.
    """
    if mode != "base":
        return generation
    return "```xml\n" + generation


def perform_inference(args, model, save_file):
    """Generate completions for the evaluation set and write them as JSONL.

    Args:
        args: Parsed CLI namespace. Reads ``limit``, ``num_samples``,
            ``question_mode``, ``mode`` and ``force``.
        model: Generation backend exposing ``preprocess_prompt(prompt, force=...)``
            and ``generate(prompts, enable_tqdm=...)``.
        save_file: Destination path of the merged generations JSONL file.

    Returns:
        A ``(samples, prompts, generations)`` tuple of equal-length lists.

    Raises:
        AssertionError: If the model returns a different number of generations
            than the number of prompts it was given.
    """
    data_file = Path(__file__).parent / "data" / SVG_DATA_FILE
    logger.info(f"Eval file: {data_file}")

    original_samples = read_jsonl(data_file)
    if args.limit is not None:
        original_samples = original_samples[: args.limit]
    logger.info(f"Loaded original_samples = {len(original_samples)}")

    # Duplicate every sample ``num_samples`` times; ``gen_idx`` encodes
    # "<origin_idx><><1-based sampling index><><num_samples>".
    samples = []
    for sample in original_samples:
        origin_idx = sample["origin_idx"]
        for sampling_idx in range(1, args.num_samples + 1):
            gen_idx = f"{origin_idx}<>{sampling_idx:03d}<>{args.num_samples:03d}"
            samples.append({"gen_idx": gen_idx, **sample})

    logger.info(f"After num_samples={args.num_samples} expansion, {len(samples)=}")

    prompts, raw_prompts = [], []
    for sample in samples:
        prompt, raw_prompt = build_prompt(
            sample,
            question_mode=args.question_mode,
            mode=args.mode,
        )
        prompts.append(model.preprocess_prompt(prompt, force=args.force))
        raw_prompts.append(raw_prompt)

    # Only preview a prompt when there is one; an empty sample list must not
    # crash with IndexError here.
    if prompts:
        logger.info("prompt[0]:")
        logger.info("-" * 20)
        logger.info(prompts[0])
        logger.info("-" * 20)

    generations_file = Path(save_file).resolve()

    logger.info("Generating...")
    tic = time()
    # Query the environment directly instead of relying on a global that only
    # exists when this file is executed as a script.
    generations = model.generate(prompts, enable_tqdm=not is_container_env())
    generations = [postprocess_generation(generation, mode=args.mode) for generation in generations]
    toc = time()
    logger.info(f"Generation DONE: in {toc - tic:.2f} s")

    # Validate before merging: zip() below would silently truncate to the
    # shorter list and a partial file would be written. Raised explicitly so
    # the check is not stripped under ``python -O``.
    if len(samples) != len(generations):
        raise AssertionError("Mismatch between samples and generations.")

    merged_generations = [
        {
            **s,
            "model_input": p,
            "prompt": rp,
            "completion": g,
        }
        for s, p, rp, g in zip(samples, prompts, raw_prompts, generations)
    ]
    write_jsonl(generations_file, merged_generations)
    logger.info(f"Generations => {generations_file}")

    return samples, prompts, generations


if __name__ == "__main__":
    # Progress bars are turned off inside containers. Kept as a module-level
    # name because code outside this block may read it.
    disable_tqdm = is_container_env()

    parser = ArgumentParser()
    parser.add_argument("--model", type=str, default="")
    # Required: without a backend no model can be built, so fail at parse time
    # rather than after the output directory and config were already written.
    parser.add_argument("--backend", type=str, choices=["vllm", "openai"], required=True)
    parser.add_argument("--tp", default=1, type=int)

    parser.add_argument("--max_new_tokens", default=4096, type=int)
    parser.add_argument("--num_samples", default=1, type=int)

    parser.add_argument("--temperature", default=None, type=float)
    parser.add_argument("--top_p", default=None, type=float)
    parser.add_argument("--top_k", default=None, type=int)
    parser.add_argument("--presence_penalty", default=None, type=float)
    parser.add_argument("--repetition_penalty", default=None, type=float)

    # Required: Path(None) below would otherwise raise an opaque TypeError.
    parser.add_argument("--save_folder", type=str, required=True)

    parser.add_argument("--mode", choices=["base", "chat"], default="base")
    parser.add_argument("--question_mode", choices=["description", "question"], default="description")
    parser.add_argument("--force", default=None, choices=["normal", "thinking", "no_thinking"])
    parser.add_argument("--limit", default=None, type=int)
    parser.add_argument("--no_render", default=False, action="store_true")

    args = parser.parse_args()

    # Outputs are grouped per question mode: <save_folder>/<question_mode>/
    save_dir = Path(args.save_folder).joinpath(args.question_mode)
    save_dir.mkdir(exist_ok=True, parents=True)

    setup_logger(save_dir, console_output=True)
    logger.info(f"Save => {save_dir}")

    # Persist the exact CLI configuration next to the results.
    configs_dumped = vars(args)
    # ensure_ascii=False may emit non-ASCII characters, so pin the encoding.
    with save_dir.joinpath("config.json").open("w", encoding="utf-8") as f:
        json.dump(configs_dumped, f, indent=2, ensure_ascii=False)

    if args.limit is not None:
        logger.warning(f"args.limit is set ({args.limit = }). This should only be used for debugging.")

    # Optional sampling parameters; only the ones set on the CLI are forwarded.
    inference_args = {
        "top_p": args.top_p,
        "top_k": args.top_k,
        "presence_penalty": args.presence_penalty,
        "repetition_penalty": args.repetition_penalty,
    }
    kwargs = dict(
        model=args.model,
        temperature=args.temperature,
        max_new_tokens=args.max_new_tokens,
        stop=["</s>", "<|endoftext|>", "\nQuestion:", "<|im_end|>"],
        is_chat_model=args.mode == "chat",
        inference_args={k: v for k, v in inference_args.items() if v is not None},
    )
    if args.mode == "base":
        # In base mode the prompt already opens the code fence, so the closing
        # fence marks the end of the answer.
        kwargs["stop"] += ["```"]

    logger.info(f"{kwargs = }")
    if args.backend == "vllm":
        model = GenModelVllm(tp=args.tp, **kwargs)
    elif args.backend == "openai":
        model = GenModelOpenAI(**kwargs)
    else:
        raise ValueError(f"Unrecognized {args.backend = }")

    save_file = save_dir.joinpath("generations.jsonl")
    try:
        samples, prompts, generations = perform_inference(args, model, save_file)
    finally:
        # Always close the model, even when inference raises.
        model.close()

    if not args.no_render:
        rendered_file = save_dir.joinpath("rendered_svg.jsonl")
        render_entry(save_file, rendered_file)

    logger.info("===============[Success]===============")
