import json
from pathlib import Path
from dataclasses import dataclass
from pprint import pprint

import tyro

from mathv_utils import load_jsonl
from scorer import score, get_stats
from geminieval import machine_eval


@dataclass
class Config:
    """Command-line options for this scoring script (parsed with ``tyro.cli``)."""

    # Path of the predictions file to score. It is forwarded to `score`, its
    # basename names the output file, and `get_metas` requires it to contain
    # "mathvision_full" or "mathvision_mini".
    data_path: str = ""
    # Forwarded to `score` as its `verbose` keyword argument.
    verbose: bool = False
    # When true, `main` opens an ipdb breakpoint after the scores are written.
    ipdb: bool = True
    # When true, `main` additionally passes the outputs through `machine_eval`.
    do_machine_eval: bool = False


def math_level_subject_acc(lines):
    """Summarise accuracy overall and per level, subject and level/subject pair.

    Args:
        lines: Iterable of dicts, each providing ``"correct"`` (truthy when the
            answer was right), ``"subject"`` and ``"level"``.

    Returns:
        A dict whose keys are in sorted order. Each key is a bucket name
        (``"-all"``, ``"-level{level}"``, ``"{subject}"``,
        ``"{subject}_level{level}"`` or ``"-level{level}_{subject}"``) and each
        value is a string of the form ``"correct/total=pct%"`` with the
        percentage rounded to two decimals. Empty input yields an empty dict.
    """
    tallies = {}
    for record in lines:
        gained = 1 if record["correct"] else 0
        subject = record["subject"]
        level = record["level"]
        bucket_names = (
            "-all",
            f"-level{level}",
            f"{subject}",
            f"{subject}_level{level}",
            f"-level{level}_{subject}",
        )
        for name in bucket_names:
            # Each tally is a two-slot list: [number correct, number seen].
            tally = tallies.setdefault(name, [0, 0])
            tally[0] += gained
            tally[1] += 1

    def render(tally):
        # A bucket is incremented in the same step that creates it, so the
        # total is always at least 1 and the division is safe.
        n_correct, n_total = tally
        return f"{n_correct}/{n_total}={round(n_correct / n_total * 100, 2)}%"

    return {name: render(tallies[name]) for name in sorted(tallies)}


def get_gt_answer_value(raw_exampe, gt_answer):
    """Return the option text that a ground-truth answer letter points to.

    Args:
        raw_exampe: Example dict with an ``"options"`` sequence.
        gt_answer: Single-character answer label; ``"A"`` selects the first
            option, ``"B"`` the second, and so on.

    Returns:
        The entry of ``raw_exampe["options"]`` at the letter's offset from "A".
    """
    choices = raw_exampe["options"]
    offset_from_a = ord(gt_answer) - ord("A")
    return choices[offset_from_a]


def get_metas(data_path):
    """Load MathVision metadata and attach each example's image size.

    The dataset split is picked by substring match on ``data_path``:
    ``"mathvision_full"`` selects the full TSV / ``test.jsonl`` pair and
    ``"mathvision_mini"`` selects the MINI TSV / ``testmini.jsonl`` pair. All
    file locations are relative to the current working directory.

    Args:
        data_path: Path of the predictions file being scored; only its text is
            inspected here.

    Returns:
        Dict mapping each example id to its record from the JSONL file, with an
        added ``"image_size"`` entry holding the PIL ``Image.size`` of the
        matching TSV image as a list.

    Raises:
        ValueError: If ``data_path`` names neither supported split.
        KeyError: If a JSONL record's id has no row in the TSV.
    """
    if "mathvision_full" in data_path:
        tsv_path = "../../data/temp/MathVision.tsv"
        jsonl_path = "./MATH-V/data/test.jsonl"
    elif "mathvision_mini" in data_path:
        tsv_path = "../../data/temp/MathVision_MINI.tsv"
        jsonl_path = "./MATH-V/data/testmini.jsonl"
    else:
        # Fail immediately; a debugger breakpoint here would hang batch runs.
        raise ValueError(f"Invalid data path: {data_path}")

    # Heavy dependencies are imported lazily, only once the path is accepted.
    import base64
    import io

    import pandas as pd
    from PIL import Image

    df = pd.read_csv(tsv_path, sep="\t")
    image_sizes = {}
    for index, encoded in zip(df["index"], df["image"]):
        # Only the dimensions are needed, so the image is closed right away
        # instead of being kept in memory for the whole dataset.
        with Image.open(io.BytesIO(base64.b64decode(encoded))) as image:
            image_sizes[str(index)] = list(image.size)

    metas = {ex["id"]: ex for ex in load_jsonl(jsonl_path)}

    # NOTE(review): assumes JSONL ids are strings equal to str(TSV "index").
    for meta in metas.values():
        meta["image_size"] = list(image_sizes[meta["id"]])
    return metas


def calc(data_path):
    """Score the predictions at ``data_path`` and return only the outputs.

    Runs the same ``get_metas`` + ``score`` step as ``main`` but without the
    ``verbose`` argument, machine evaluation, statistics or file writing.
    Original author's note: intended "for comparison".

    Args:
        data_path: Path of the predictions file to score.

    Returns:
        The second element of the three-item result returned by ``score``.
    """
    scored = score(data_path, get_metas(data_path), get_gt_answer_value)
    _data, outs, _maybe_correct = scored
    return outs


def main(args):
    """Score a predictions file, print the statistics and save them as JSON.

    Steps: load metadata with ``get_metas``, score with ``score``, optionally
    re-evaluate with ``machine_eval``, aggregate with ``get_stats`` using
    ``math_level_subject_acc``, then write ``{"stats", "response"}`` to
    ``./scores/score_<basename of data_path>``.

    Args:
        args: Parsed ``Config`` providing ``data_path``, ``verbose``,
            ``do_machine_eval`` and ``ipdb``.
    """
    metas = get_metas(args.data_path)
    data, outs, maybe_correct = score(
        args.data_path, metas, get_gt_answer_value, verbose=args.verbose
    )
    all_outs = outs

    # Machine eval is off by default for MathVision because its answers are
    # already well-formed; it only runs when explicitly requested.
    if args.do_machine_eval:
        questions = {k: row["question"] for k, row in metas.items()}
        all_outs, machine_outs = machine_eval(questions, outs, task="mathvision")

    stats = get_stats(all_outs, math_level_subject_acc)
    pprint(stats)
    filename = Path(args.data_path).name
    # The output directory is created relative to the current working directory.
    Path("./scores").mkdir(exist_ok=True)
    to_save = {"stats": stats, "response": all_outs}
    # Print the overall-accuracy entry on its own line after the full dump.
    print(stats['-all'])
    with open(f"./scores/score_{filename}", "w") as f:
        json.dump(to_save, f, indent=4)
    # Optional interactive breakpoint so the locals above can be inspected.
    if args.ipdb:
        import ipdb; ipdb.set_trace() # noqa # fmt: skip

if __name__ == "__main__":
    # Build a Config from the command line with tyro, then run the scoring.
    args = tyro.cli(Config)
    main(args)
