import base64
import io
import json
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Literal, Optional

import gdown
import ray
import tyro
from PIL import Image

# Initialize Ray at import time, before the project-local modules below are
# imported. NOTE(review): presumably `inferer` expects an initialized Ray
# runtime — confirm; as written, merely importing this module calls ray.init().
ray.init()

from inferer import infer
from score_mathvision import main as score
from utils import get_name, set_qwen_args, EvalMathVisionConfig, download

# Literal type enumerating the accepted template names.
# NOTE(review): not referenced in this file's visible code — presumably used
# by the config/prompting helpers; confirm before changing the set of values.
TEMPLATE_TYPE = Literal[
    "base",
    "v1",
    "v2",
    "v3",
    "none",
]

def _resolve_output_paths(args, output_name):
    """Return ``(out_path, rerun_out_path)`` for this run.

    When ``args.out_path`` is unset, both files live under ``outputs/`` and are
    named after the split (``mini``/``full``) and ``output_name``. Otherwise the
    rerun file sits next to ``args.out_path`` with ``_rerun_force`` appended to
    its stem.
    """
    if args.out_path is None:
        split = "mini" if args.is_mini else "full"
        stem = f"outputs/mathvision_{split}_{output_name}"
        return Path(f"{stem}.json"), Path(f"{stem}_rerun_force.json")
    out_path = Path(args.out_path)
    rerun_out_path = out_path.parent / f"{out_path.stem}_rerun_force{out_path.suffix}"
    return out_path, rerun_out_path


def _decode_image(encoded):
    """Decode a base64-encoded image string into a PIL ``Image``."""
    return Image.open(io.BytesIO(base64.b64decode(encoded)))


def _load_mathvision(is_mini, indices=None):
    """Fetch the MathVision TSV and build the rows passed to ``infer``.

    Args:
        is_mini: use the MathVision_MINI split instead of the full one.
        indices: optional collection of ``index`` values (as strings) to keep;
            ``None`` keeps every row.

    Returns:
        A list of dicts with keys ``decoded_image``, ``query`` and ``id``.
    """
    import pandas as pd

    # Source: https://github.com/open-compass/VLMEvalKit/blob/main/vlmeval/dataset/image_vqa.py
    # NOTE(review): the local path is resolved against the current working
    # directory, so the script must be launched from the expected folder.
    name = "MathVision_MINI" if is_mini else "MathVision"
    url = f"https://opencompass.openxlab.space/utils/VLMEval/{name}.tsv"
    path = f"../../data/temp/{name}.tsv"
    download(url, path)

    df = pd.read_csv(path, sep="\t")
    data = []
    for raw_idx, image, question in zip(df["index"], df["image"], df["question"]):
        idx = str(raw_idx)
        # Filter before decoding so unselected rows never pay the decode cost
        # (and a bad image in an unselected row cannot abort the run).
        if indices is not None and idx not in indices:
            continue
        data.append({"decoded_image": _decode_image(image), "query": question, "id": idx})
    return data


def main():
    """Evaluate a model on MathVision: run inference, then score the outputs.

    Reads an ``EvalMathVisionConfig`` from the command line, resolves where
    results are stored (loading any existing results file), loads the dataset,
    calls ``infer`` and finally ``score_mathvision.main`` on the output file.

    Raises:
        FileNotFoundError: if a rerun is requested but the original output
            file does not exist.
    """
    from types import SimpleNamespace

    args = tyro.cli(EvalMathVisionConfig)
    random.seed(args.seed)

    if "Qwen/Qwen2.5-VL" in args.model:
        print(f"Qwen/Qwen2.5-VL in {args.model}, setting arguments for Qwen models")
        args = set_qwen_args(args)

    output_name = get_name(args)
    out_path, rerun_out_path = _resolve_output_paths(args, output_name)

    if args.rerun:
        # A rerun needs the original results: fail loudly instead of stopping
        # at a debugger breakpoint (and instead of an `assert`, which -O strips).
        if not out_path.is_file():
            raise FileNotFoundError(f"Output path {out_path} does not exist!")
        # Continue from an earlier rerun's output if one was already written.
        if rerun_out_path.is_file():
            out_path = rerun_out_path

    print(f"Output path: {out_path}")
    if args.rerun:
        print(f"Rerun output path: {rerun_out_path}")

    # Previously saved results (if any) are handed to `infer`; presumably so
    # it can skip finished items — confirm in inferer.
    results = {}
    if out_path.is_file():
        with open(out_path, "r") as f:
            results = json.load(f)

    indices = None
    if args.indices is not None:
        # Comma-separated ids; strip so "1, 2" selects the same rows as "1,2".
        indices = {v.strip() for v in args.indices.split(",")}

    data = _load_mathvision(args.is_mini, indices)
    if indices is not None:
        missing = indices - {row["id"] for row in data}
        if missing:
            print(f"Warning: requested indices not found in dataset: {sorted(missing)}")

    if args.shuffle:
        random.shuffle(data)
    infer(args, data, results, out_path, rerun_out_path)

    score(
        SimpleNamespace(
            data_path=str(out_path),
            ipdb=False,
            do_machine_eval=False,
            verbose=True,
        )
    )


# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    main()
