"""
Step-4: Summarize scores
"""

from collections import defaultdict
import json
import logging
from pathlib import Path

import fire
import numpy as np

from utils import CompactJSONEncoder, setup_logger, load_jsonl_iter


# Module-level logger named "text2svg"; summarize_entry calls setup_logger
# when this logger has no handlers attached yet.
logger = logging.getLogger("text2svg")


def score_to_level(score):
    """Map a 1-10 difficulty score to a coarse difficulty level.

    Args:
        score: Numeric difficulty in the closed range [1, 10].

    Returns:
        ``"easy"`` for scores up to 3, ``"medium"`` for scores above 3 and up
        to 6, and ``"hard"`` for scores above 6.

    Raises:
        ValueError: If ``score`` is outside [1, 10] (NaN is rejected too,
            because every comparison against NaN is false).
    """
    # Raise explicitly instead of asserting: asserts are stripped under
    # ``python -O`` and would let out-of-range scores through.
    if not 1 <= score <= 10:
        raise ValueError(f"score must be within [1, 10], got {score!r}")
    # Threshold comparisons (rather than membership in integer lists) so a
    # fractional score such as 3.5 is bucketed instead of returning None.
    if score <= 3:
        return "easy"
    if score <= 6:
        return "medium"
    return "hard"


def convert_to_normal_dict(d):
    """Recursively convert a mapping and all nested mappings to plain dicts.

    Any ``dict`` instance (including subclasses such as ``defaultdict``) is
    rebuilt as a plain ``dict``, and the conversion descends into its values.
    The recursion goes through every dict rather than only through
    ``defaultdict`` instances, so a plain dict whose values are
    ``defaultdict`` objects is converted as well.

    Args:
        d: Any object. Non-dict values are returned unchanged.

    Returns:
        A new plain ``dict`` when ``d`` is a dict; otherwise ``d`` itself.
    """
    if isinstance(d, dict):
        return {k: convert_to_normal_dict(v) for k, v in d.items()}
    return d


def summarize_entry(
    input_file: str,
    output_file: str,
    verbose: bool = False,
):
    """Aggregate per-sample judge scores from a JSONL file into a JSON summary.

    Each input record is read for these keys: ``origin_idx`` (groups repeated
    samples of one example), ``score``, ``render_status``, and either
    ``text2svg_difficulty`` (used directly as the level) or ``svg_difficulty``
    (converted with ``score_to_level``). A record whose ``render_status`` is
    falsy or whose ``score`` is not an int/float counts as failed and
    contributes a score of 0.0.

    For every level ("easy", "medium", "hard") and for "overall", the summary
    reports the mean over all samples (``avg_of_n``), the mean of per-example
    best scores (``best_of_n``) and the mean of per-example worst scores
    (``worst_of_n``). Levels without any example are skipped with a warning.

    Args:
        input_file: Path of the JSONL file to read.
        output_file: Path of the JSON summary to write; its parent folder is
            created if needed.
        verbose: When true, also store the raw per-example score lists under
            the ``"detail"`` key.
    """
    input_file = Path(input_file)
    save_file = Path(output_file)

    # Create the output folder before handing it to setup_logger, so the
    # folder already exists whatever setup_logger does with it.
    save_folder = save_file.parent
    save_folder.mkdir(parents=True, exist_ok=True)

    if not logger.hasHandlers():
        setup_logger(save_folder, console_output=True)
    logger.info(f"{save_folder = }")

    level2id2scores = {k: defaultdict(list) for k in ["easy", "medium", "hard", "overall"]}
    num_failed, num_good = 0, 0
    for ex in load_jsonl_iter(input_file):
        ex_id = ex["origin_idx"]

        if "text2svg_difficulty" in ex:
            level = ex["text2svg_difficulty"]
        else:
            level = score_to_level(ex["svg_difficulty"])

        judged_score = ex["score"]

        if ex["render_status"] and isinstance(judged_score, (float, int)):
            num_good += 1
        else:
            # Failed render or non-numeric judge output: counted as a zero.
            num_failed += 1
            judged_score = 0.0

        level2id2scores[level][ex_id].append(judged_score)
        level2id2scores["overall"][ex_id].append(judged_score)

    # Largest number of samples seen for any single example; default=0 keeps
    # an empty input file from crashing with "max() arg is an empty sequence".
    num_sample = max(
        (
            len(scores)
            for id2scores in level2id2scores.values()
            for scores in id2scores.values()
        ),
        default=0,
    )
    logger.info(f"{num_sample = }")

    result = {}
    logger.info("-" * 50)
    logger.info(f"{num_failed = }")
    logger.info("-" * 50)
    for level, id2scores in level2id2scores.items():
        if not id2scores:
            logger.warning(f"[{level=}] There's no instance.")
            continue

        all_scores, max_scores, min_scores = [], [], []
        for scores in id2scores.values():
            all_scores.extend(scores)
            max_scores.append(max(scores))
            min_scores.append(min(scores))

        # float(...) turns the numpy scalar into a builtin float so the log
        # line and the JSON output show plain numbers.
        pack = {
            "avg_of_n": float(round(np.mean(all_scores), 4)),
            "best_of_n": float(round(np.mean(max_scores), 4)),
            "worst_of_n": float(round(np.mean(min_scores), 4)),
            "n": num_sample,
            "count": len(all_scores),
        }

        logger.info(f"[{level}] {pack}")
        result[level] = pack

    logger.info("-" * 50)

    data = result.copy()
    data["num_failed"] = num_failed
    data["num_good"] = num_good
    data["num_all"] = num_failed + num_good
    data["num_sample"] = num_sample
    data["input_file"] = str(input_file)
    if verbose:
        data["detail"] = convert_to_normal_dict(level2id2scores)

    # ensure_ascii=False may emit non-ASCII characters, so pin the encoding
    # instead of relying on the platform default.
    with save_file.open("w", encoding="utf-8") as f:
        json.dump(data, f, indent=4, ensure_ascii=False, cls=CompactJSONEncoder)
    logger.info(f"Result => {output_file}")


# Script entry point: hand summarize_entry to python-fire, which builds the
# command-line interface from the function.
if __name__ == "__main__":
    fire.Fire(summarize_entry)
