import json
import csv
import pickle
from pathlib import Path
import re
from argparse import Namespace
from dataclasses import dataclass
from typing import Optional, Literal
import gdown
import pandas as pd
import numpy as np


@dataclass
class EvalMathVisionConfig:
    """Settings container for an evaluation run (MathVision, per the class name).

    Several of these fields are read by ``get_name`` to build the run name, and
    ``set_qwen_args`` overrides ``model_is_pgn``, ``repetition_penalty`` and
    ``template_type``. How the remaining fields are consumed is not visible in
    this file, so their comments below are hedged.
    """

    # presumably selects a reduced benchmark subset — TODO confirm against caller
    is_mini: bool = True

    # Path to the model checkpoint; get_name() uses its last two path components.
    model: str = (
        "../cache/v1/v1_Qwen2_5-VL-7B-Instruct_zloss/checkpoint-23000"
    )
    # When True, get_name() appends the do_copy flag to the run name.
    model_is_pgn: bool = True
    # presumably the generation length cap — TODO confirm
    max_new_tokens: int = 8192
    # Recorded under the "nowait" label in the run name built by get_name().
    use_bad_words: bool = False
    repetition_penalty: float = 1.05
    # NOTE(review): TEMPLATE_TYPE is not defined in the visible part of this
    # file; it must be in scope before this class body executes. The values
    # seen in this file are "base" (default) and "qwen2" (set_qwen_args).
    template_type: TEMPLATE_TYPE = "base"

    # Image preprocessing switches. When do_resize is True, get_name() also
    # encodes max_image_size and min_image_size in the run name.
    do_resize: bool = True
    resize_max_pixels: bool = False
    max_image_size: Optional[int] = 448
    min_image_size: Optional[int] = 448
    # presumably sends images base64-encoded — TODO confirm
    do_base64: bool = False

    # Run-control options; NOTE(review): their consumers are outside this file.
    rerun: bool = False
    seed: int = 0
    # Encoded in the run name by get_name() only when model_is_pgn is True.
    do_copy: bool = True
    batch_size: int = 1
    indices: Optional[str] = None
    shuffle: bool = True
    out_path: Optional[str] = None


def download(url: str, path: str):
    """Fetch ``url`` into ``path`` via gdown unless a file already exists there.

    Args:
        url: Source location handed to ``gdown.download``.
        path: Destination file path; nothing happens if it is already a file.
    """
    destination = Path(path)
    if destination.is_file():
        # Already present: skip the download entirely.
        return

    gdown.download(url, str(path), quiet=False)
    print(f"Downloaded: {url} → {path}")


def get_name(args):
    """Build the run name that encodes the key evaluation settings.

    The name is a sequence of ``label_value`` segments joined by underscores.
    The image-size segments are included only when ``args.do_resize`` is true,
    and the ``do_copy`` segment only when ``args.model_is_pgn`` is true.

    Args:
        args: Object exposing ``model``, ``template_type``, ``use_bad_words``,
            ``do_resize``, ``resize_max_pixels``, ``repetition_penalty`` and,
            depending on the flags above, ``max_image_size``,
            ``min_image_size`` and ``do_copy``.

    Returns:
        The encoded run name as a string.
    """
    model_path = Path(args.model)
    # Identify the model by its last two path components.
    model_name = f"{model_path.parent.name}_{model_path.name}"

    segments = [
        f"model_{model_name}",
        f"template_{args.template_type}",
        f"nowait_{args.use_bad_words}",
        f"resize1_{args.do_resize}",
        f"resize2_{args.resize_max_pixels}",
        f"penalty_{args.repetition_penalty}",
    ]
    if args.do_resize:
        segments.append(f"max_image_size_{args.max_image_size}")
        segments.append(f"min_image_size_{args.min_image_size}")
    if args.model_is_pgn:
        segments.append(f"do_copy_{args.do_copy}")

    return "_".join(segments)

def _parse_name_string(name_string: str) -> dict | None:
    """Helper function to parse the name string using regex."""
    # Regex pattern based on the structure defined in get_name
    pattern = re.compile(
        r"^model_(?P<model_name>.*?)_"
        r"template_(?P<template_type>.*?)_"
        r"nowait_(?P<use_bad_words>True|False)_"
        r"resize1_(?P<do_resize>True|False)_"
        r"resize2_(?P<resize_max_pixels>True|False)_"
        r"penalty_(?P<repetition_penalty>\d+\.?\d*)"
        r"(?:_do_copy_(?P<do_copy>True|False))?$" # Optional _do_copy_ part
    )

    match = pattern.match(name_string)

    if not match:
        print(f"Warning: Name string '{name_string}' did not match expected format.")
        return None

    parsed_data = match.groupdict()

    # Convert parsed strings to appropriate types
    args = {}
    # Note: Cannot reconstruct the full original 'model' path, only the derived name part.
    args['model_name'] = parsed_data['model_name']
    args['template_type'] = parsed_data['template_type']
    args['use_bad_words'] = parsed_data['use_bad_words'] == 'True'
    args['do_resize'] = parsed_data['do_resize'] == 'True'
    args['resize_max_pixels'] = parsed_data['resize_max_pixels'] == 'True'
    args['repetition_penalty'] = float(parsed_data['repetition_penalty'])

    if parsed_data['do_copy'] is not None:
        args['model_is_pgn'] = True
        args['do_copy'] = parsed_data['do_copy'] == 'True'
    else:
        args['model_is_pgn'] = False
        # do_copy is not set if it wasn't in the name string

    return args


def parse_args(out_path: str) -> dict | None:
    """
    Parses arguments encoded in a name string, typically the last part of out_path.
    This is the inverse operation of get_name, recovering a subset of arguments.

    Args:
        out_path: The path string containing the encoded arguments (e.g., output directory or file path).

    Returns:
        A dictionary containing the parsed arguments, or None if parsing fails.
    """
    name_string = Path(out_path).name

    # Path.stem would cut at the LAST dot, and the encoded name itself contains
    # a dot (the decimal point of the penalty). For an extension-less path such
    # as ".../..._penalty_1.05_do_copy_True" that would drop everything after
    # "penalty_1". Strip only a trailing piece that looks like a real file
    # extension: a dot, a letter, then letters/digits (".json", ".jsonl", ...).
    extension = re.search(r"\.[A-Za-z][A-Za-z0-9]*$", name_string)
    if extension is not None:
        name_string = name_string[:extension.start()]

    return _parse_name_string(name_string)


def set_qwen_args(args):
    """Overwrite the settings on ``args`` with the Qwen values and return it.

    Sets ``model_is_pgn`` to False, ``repetition_penalty`` to 1.0 and
    ``template_type`` to "qwen2". ``args`` is modified in place.
    """
    qwen_overrides = (
        ("model_is_pgn", False),
        ("repetition_penalty", 1.0),
        ("template_type", "qwen2"),
    )
    for attribute, value in qwen_overrides:
        setattr(args, attribute, value)

    return args


def load(f, fmt=None):
    """Load a data file, choosing the reader from ``fmt`` or the file suffix.

    Supported formats: ``pkl``, ``json``, ``jsonl``, ``xlsx``, ``csv``, ``tsv``.
    Only local paths are handled; URLs are not downloaded.

    Args:
        f: Path to the file (``str`` or ``os.PathLike``).
        fmt: Optional format key that overrides suffix-based detection.

    Returns:
        The unpickled object for ``pkl``, the decoded document for ``json``,
        a list of decoded records for ``jsonl`` (blank lines are skipped, an
        empty file yields ``[]``), or a pandas DataFrame for ``xlsx`` /
        ``csv`` / ``tsv``.

    Raises:
        AssertionError: If the requested or detected format is not supported.
    """
    def load_pkl(pth):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(pth, "rb") as fin:
            return pickle.load(fin)

    def load_json(pth):
        with open(pth, "r", encoding="utf-8") as fin:
            return json.load(fin)

    def load_jsonl(pth):
        with open(pth, encoding="utf-8") as fin:
            stripped = (line.strip() for line in fin)
            # Skipping blank lines makes empty files and trailing/interior
            # blank lines harmless instead of raising.
            return [json.loads(line) for line in stripped if line]

    def load_xlsx(pth):
        return pd.read_excel(pth)

    def load_csv(pth):
        return pd.read_csv(pth)

    def load_tsv(pth):
        return pd.read_csv(pth, sep="\t")

    handlers = dict(
        pkl=load_pkl,
        json=load_json,
        jsonl=load_jsonl,
        xlsx=load_xlsx,
        csv=load_csv,
        tsv=load_tsv,
    )
    if fmt is not None:
        assert fmt in handlers, f"Unsupported format: {fmt!r}"
        return handlers[fmt](f)

    # str() lets pathlib.Path (and other path-like) inputs work as well.
    suffix = str(f).split(".")[-1]
    assert suffix in handlers, f"Unsupported file suffix: {suffix!r}"
    return handlers[suffix](f)


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to plain Python values.

    Conversions:
        * NumPy integers -> ``int``
        * NumPy floats -> ``float``
        * NumPy complex numbers -> ``{"real": ..., "imag": ...}``
        * ``np.ndarray`` -> nested lists
        * ``np.bool_`` -> ``bool``
        * ``np.void`` -> ``None``

    Anything else falls through to ``json.JSONEncoder.default``, which raises
    ``TypeError``.
    """

    def default(self, obj):
        # The abstract scalar base classes are used instead of enumerating
        # concrete types: they cover every sized variant and avoid the
        # ``np.float_`` / ``np.complex_`` aliases that NumPy 2.0 removed.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.complexfloating):
            # The parts are NumPy floats and are encoded by a later call
            # to this method.
            return {"real": obj.real, "imag": obj.imag}
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.void):
            return None
        return super().default(obj)


def dump(data, f, **kwargs):
    """Write ``data`` to ``f`` using the writer selected by the file suffix.

    Supported suffixes: ``pkl``, ``json``, ``jsonl``, ``xlsx``, ``csv``, ``tsv``.

    Args:
        data: Object to write. The ``xlsx`` / ``csv`` / ``tsv`` writers call
            ``data.to_excel`` / ``data.to_csv`` on it; ``jsonl`` iterates it
            and writes one JSON document per item.
        f: Destination path (``str`` or ``os.PathLike``).
        **kwargs: Forwarded to the selected writer. Only ``quoting`` (csv/tsv)
            is used; other keys are ignored by every writer.

    Returns:
        Whatever the selected writer returns (``None`` for all current writers).

    Raises:
        KeyError: If the suffix of ``f`` has no registered writer.
    """
    def dump_pkl(data, pth, **kwargs):
        with open(pth, "wb") as fout:
            pickle.dump(data, fout)

    def dump_json(data, pth, **kwargs):
        # ensure_ascii=False emits raw non-ASCII characters, so the encoding is
        # pinned to UTF-8 rather than left to the platform default.
        with open(pth, "w", encoding="utf-8") as fout:
            json.dump(data, fout, indent=4, ensure_ascii=False, cls=NumpyEncoder)

    def dump_jsonl(data, pth, **kwargs):
        lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
        with open(pth, "w", encoding="utf8") as fout:
            fout.write("\n".join(lines))

    def dump_xlsx(data, pth, **kwargs):
        data.to_excel(pth, index=False, engine="xlsxwriter")

    # csv/tsv accept (and ignore) extra kwargs like the other writers do, so an
    # unrelated keyword no longer fails only for these two formats.
    def dump_csv(data, pth, quoting=csv.QUOTE_ALL, **kwargs):
        data.to_csv(pth, index=False, encoding="utf-8", quoting=quoting)

    def dump_tsv(data, pth, quoting=csv.QUOTE_ALL, **kwargs):
        data.to_csv(pth, sep="\t", index=False, encoding="utf-8", quoting=quoting)

    handlers = dict(
        pkl=dump_pkl,
        json=dump_json,
        jsonl=dump_jsonl,
        xlsx=dump_xlsx,
        csv=dump_csv,
        tsv=dump_tsv,
    )
    # str() lets pathlib.Path (and other path-like) destinations work as well.
    suffix = str(f).split(".")[-1]
    return handlers[suffix](data, f, **kwargs)
