"""
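Evaluate model predictions with an LLM acting as a judge (LLM-as-a-judge).

For each QA example, the judge model is shown the question, the gold answers,
the predicted answer, and the preceding task steps; its verdict and rationale
are written to a JSON file.
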
"""

from argparse import ArgumentParser
from collections import defaultdict
from copy import deepcopy
import json
import logging
from pathlib import Path
import re
import sys

from tqdm import tqdm
import yaml

from utils_api import (
    call_openai_single_response_api,
    # call_anthropic_api_single,
    # call_google_api_single,
    estimate_cost,
)

from utils import (
    get_date,
)


def format_prev_steps(example, annotation):
    """Render the annotated steps that finish within the example's video clip.

    A step qualifies when its ``end`` time (converted to float) is at or
    before the clip's ``end`` time. Each qualifying step becomes a
    ``"* <action>"`` bullet. The bullets are newline-joined and the result is
    stripped of surrounding whitespace. Returns "" when no step qualifies.
    """
    clip_end = float(example["video"]["end"])
    bullets = [
        f"* {step['action']}"
        for step in annotation["mistake"]["steps"]
        if float(step["end"]) <= clip_end
    ]
    return "\n".join(bullets).strip()


def format_answers(example):
    """Render ``example["answers"]`` as a ``"- "`` bulleted list.

    Bullets are newline-joined and the result is stripped. An empty answer
    list yields "".
    """
    return "\n".join(f"- {gold}" for gold in example["answers"]).strip()


def _fill_placeholders(template: str, values: dict[str, str]) -> str:
    """Substitute ``{name}`` placeholders in *template* in a single pass.

    Chained ``str.replace`` calls re-scan text inserted by earlier calls, so a
    question that literally contains ``{predicted_answer}`` would be
    corrupted. A single regex pass never revisits inserted text. Braces that
    do not name a key in *values* (e.g. JSON in the template) are kept as-is.
    """
    if not values:
        return template
    pattern = re.compile("|".join(re.escape("{" + name + "}") for name in values))
    # match.group(0) is "{name}"; strip the braces to look up the value.
    return pattern.sub(lambda match: values[match.group(0)[1:-1]], template)


def _extract_predicted_answer(prediction: dict) -> str:
    """Return the stripped predicted answer from a prediction record.

    Accepted layouts:
    ``{"prediction": {"answer": ...}}``, ``{"prediction": {"response": ...}}``
    and ``{"answer": ...}``. Any other layout terminates the program via
    ``sys.exit`` with a message naming the record.
    """
    # todo: unify the output file format
    # currently, do ad-hoc process
    if "prediction" in prediction:
        inner = prediction["prediction"]
        if "answer" in inner:
            return inner["answer"].strip()
        return inner["response"].strip()
    if "answer" in prediction:
        return prediction["answer"].strip()
    sys.exit(f"Undefined output file {prediction=}")


def format_eval_openai_response_api(
    args, example, prediction, template_components, annotation
):
    """Build the judge prompt in OpenAI Responses API format.

    The prompt is made of five template sections, in order: ``prefix``,
    ``step`` (with ``{previous_steps}`` filled), ``option``, ``note`` and
    ``task`` (with ``{question}``, ``{gold_answer}`` and ``{predicted_answer}``
    filled).

    Args:
        args: CLI arguments; currently unused, kept for a uniform signature.
        example: QA example providing ``question``, ``answers`` and ``video``.
        prediction: prediction record (see ``_extract_predicted_answer``).
        template_components: mapping loaded from the YAML template.
        annotation: annotation providing ``mistake.steps``.

    Returns:
        ``(content, text_prompt)``. ``content`` is a list of
        ``{"type": "input_text", "text": ...}`` items, one per section.
        ``text_prompt`` is the sections each followed by a newline, used for
        logging and for saving alongside the result.
    """
    step = _fill_placeholders(
        template_components["step"],
        {"previous_steps": format_prev_steps(example, annotation)},
    )
    task = _fill_placeholders(
        template_components["task"],
        {
            "question": example["question"],
            "gold_answer": format_answers(example),
            "predicted_answer": _extract_predicted_answer(prediction),
        },
    )

    sections = [
        template_components["prefix"],
        step,
        template_components["option"],
        template_components["note"],
        task,
    ]
    content = [{"type": "input_text", "text": section} for section in sections]
    text_prompt = "".join(section + "\n" for section in sections)

    return content, text_prompt


def parse_feedback(feedback: str) -> tuple[str, str]:
    """
    Split a judge response into its verdict and rationale.

    The response is expected to contain the marker ``[Judge]``. Text before
    the last marker is the rationale and text after it is the verdict.
    Splitting on the last marker lets the rationale mention ``[Judge]``
    without breaking parsing.

    Returns:
        ``(judge, rationale)``, both stripped of surrounding whitespace.

    Raises:
        ValueError: if ``[Judge]`` does not occur in *feedback*.
    """
    rationale, marker, judge = feedback.rpartition("[Judge]")
    if not marker:
        raise ValueError(f"missing '[Judge]' marker in feedback: {feedback!r}")

    return judge.strip(), rationale.strip()


def main(args):
    with open(args.filepath_qa, "r") as f:
        examples = json.load(f)

    with open(args.filepath_prediction, "r") as f:
        predictions = json.load(f)

    assert len(examples) == len(predictions["examples"])

    with open(args.filepath_annotation, "r") as f:
        annotations = json.load(f)

    with open(args.filepath_template, "r") as f:
        template_components = yaml.safe_load(f)

    logging.info(f"#target examples: {len(examples)} ({args.template_type=})")

    logging.info("Call API")
    filepath_output = (
        args.dirpath_output
        / f"{Path(args.filepath_prediction.parent.name)}_{Path(args.model_id).name}"
        f"_{args.template_type}_{args.filepath_prediction.name}"
    )
    if not args.dirpath_output.exists():
        args.dirpath_output.mkdir(parents=True)

    new_examples = []
    total_usage = defaultdict(int)
    for idx, (example, prediction) in tqdm(
        enumerate(zip(examples, predictions["examples"])), total=len(examples)
    ):
        annotation = annotations["examples"][example["sequence_id"]]

        if prediction != "Error" and prediction["prediction"] != "Error":
            # sys.exit('stop')

            match args.model_id:
                case "gpt-4o-2024-11-20":
                    content, text_prompt = format_eval_openai_response_api(
                        args, example, prediction, template_components, annotation
                    )
                    if idx == 0:
                        logging.info(f"Sanity check: prompt (text part only) for {idx=}")
                        logging.info(text_prompt)

                    system_developer_message = None
                    messages = [{"role": "user", "content": content}]
                    response, usage = call_openai_single_response_api(
                        args,
                        system_developer_message,
                        messages=messages,
                    )
                case _:
                    logging.error(f"Undefined: {args.model_id=}")

            prediction = ""
            for _response in response.output:
                match _response.type:
                    case "message":
                        if len(_response.content) > 1:
                            logging.warning(f"Multiple {_response.content=}")
                        prediction = "\n".join([x.text for x in _response.content])
                    case _:
                        logging.error(f"Undefined {_response.type=}")

            if len(response.output) > 1:
                logging.warning(f"Multiple {response.output=}")

            total_usage["input"] += usage["input"]
            total_usage["cached"] += usage["cached"]
            total_usage["output"] += usage["output"]

            new_example = deepcopy(example)
            new_example["evaluation"] = {
                "prompt": text_prompt,
                "model_id": args.model_id,
                "template_type": args.template_type,
                "response": prediction,
            }
            try:
                judge, rationale = parse_feedback(prediction)
                new_example["evaluation"]["judge"] = judge
                new_example["evaluation"]["rationale"] = rationale
            except Exception as e:
                logging.warning(f"Error happened during postprocess: {e}")
        else:
            new_example = deepcopy(example)
            new_example["evaluation"] = {
                "prompt": text_prompt,
                "model_id": args.model_id,
                "template_type": args.template_type,
                "response": "error",
                "judge": "0",
                "rationale": "error",
            }

        new_examples.append(new_example)

        with open(filepath_output, "w") as f:
            json.dump(new_examples, f, indent=4)
            f.write("\n")

    # assert len(examples) == len(new_examples)

    cost = estimate_cost(args.model_id, total_usage)
    logging.info(f"Estimated total cost: ${cost:.4f}.")


if __name__ == "__main__":
    parser = ArgumentParser(description="Evaluate")
    # Arguments without which main() cannot run are marked required, so a
    # missing flag gives a clear argparse error instead of a TypeError on None.
    parser.add_argument(
        "--filepath_qa", type=Path, required=True, help="filepath to qa data"
    )
    parser.add_argument(
        "--filepath_prediction", type=Path, required=True, help="filepath to predictions"
    )
    parser.add_argument(
        "--filepath_annotation",
        type=Path,
        required=True,
        help="filepath to original example",
    )
    parser.add_argument(
        "--filepath_instruction", type=Path, help="filepath for instruction"
    )
    parser.add_argument(
        "--filepath_template", type=Path, required=True, help="filepath to template"
    )
    parser.add_argument(
        "--dirpath_output", type=Path, required=True, help="dirpath to output"
    )
    parser.add_argument("--template_type", type=str, help="template_type")
    parser.add_argument("--model_id", type=str, required=True, help="model id")
    parser.add_argument("--temperature", type=float, help="temperature", default=0.0)
    parser.add_argument(
        "--max_tokens", type=int, help="max tokens to generate", default=512
    )
    parser.add_argument("--reasoning", action="store_true", help="Enable reasoning")
    # float, not int: the default is 0.5, and `--wait_time 0.5` must parse.
    parser.add_argument(
        "--wait_time", type=float, help="API call wait time", default=0.5
    )
    parser.add_argument("--dirpath_log", type=Path, help="dirpath to log")

    args = parser.parse_args()

    # Always log to the console. Also log to a dated file when --dirpath_log
    # is given, creating the directory if needed.
    handlers = [logging.StreamHandler()]
    if args.dirpath_log is not None:
        args.dirpath_log.mkdir(parents=True, exist_ok=True)
        handlers.append(
            logging.FileHandler(args.dirpath_log / f"evaluate_{get_date()}.log")
        )

    logging.basicConfig(
        format="%(asctime)s:%(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=handlers,
    )

    logging.info(f"Arguments: {vars(args)}")

    main(args)
