import argparse
import copy as cp
import csv
import json
import os
import os.path as osp
import pickle
import re
import string
import time
from collections import defaultdict
from typing import Callable, Iterable

import numpy as np
import pandas as pd


def MathVista_auxeval(model, line):
    """Obtain the extracted answer for one record, asking the judge only if needed.

    The raw prediction is first checked with ``post_check(line, prefetch=True)``.
    When that yields a truthy value it is returned directly and the judge model
    is never queried. Otherwise the judge is called up to ``retry`` times, with
    the sampling temperature raised by 0.5 on every attempt.

    Args:
        model: Judge object exposing ``generate(prompt, temperature=...)``.
        line: Record (mapping-like) holding ``"prediction"`` plus whatever
            ``post_check`` and ``build_mathvista_gpt4_prompt`` read from it.

    Returns:
        dict: ``{"log": str, "res": ...}``. ``res`` is the prefetched value,
        the judge's reply, or ``""`` when every attempt failed.
    """
    retry = 5

    # Evaluate the prefetch exactly once and reuse its result.
    prefetched = post_check(line, prefetch=True)
    if prefetched:
        return dict(log="Prefetch succeed", res=prefetched)

    # The prompt is only needed when the judge is actually consulted.
    prompt = build_mathvista_gpt4_prompt(line)
    prediction = line["prediction"]
    log = ""
    for i in range(retry):
        res = model.generate(prompt, temperature=i * 0.5)

        if FAIL_MSG in res:
            log += f"Try {i}: output is {prediction}, failed to parse.\n"
        else:
            log += "Succeed"
            return dict(log=log, res=res)

    # Keep the message in sync with the retry budget.
    log += f"All {retry} retries failed.\n"
    return dict(log=log, res="")


def post_check(line, prefetch=False):
    """Compare a response against the ground-truth answer of one record.

    Args:
        line: Record (mapping-like) with ``"answer"``, ``"question_type"``,
            ``"answer_type"`` and either ``"prediction"`` (when ``prefetch``)
            or ``"res"``; multi-choice records also need ``"answer_option"``
            and ``"choices"``.
        prefetch: When True the raw ``"prediction"`` is examined and the
            parsed value is returned; when False ``"res"`` is examined and a
            boolean is returned.

    Returns:
        * multi-choice with ``prefetch``: whatever ``can_infer`` returns;
        * match with ``prefetch``: the parsed response value;
        * match without ``prefetch``: ``True``;
        * otherwise ``False``.
    """
    res = None
    ans = line["answer"]
    response = line["prediction"] if prefetch else line["res"]
    try:
        if line["question_type"] == "multi_choice":
            ans = line["answer_option"]
            # NOTE(security): ``eval`` runs on text taken from the data file;
            # only use this with trusted result files.
            choices = list_to_dict(eval(line["choices"]))
            res = can_infer(response, choices)
            if prefetch:
                return res
        elif line["answer_type"] == "integer":
            res = int(response)
            ans = int(line["answer"])
        elif line["answer_type"] == "float":
            res = float(response)
            ans = float(line["answer"])
        else:
            # Compare the textual form of the response itself; stringifying
            # the still-unset ``res`` would always yield "None".
            res = str(response)
            ans = str(ans)
    except (ValueError, TypeError):
        # Unparseable or missing value: leave ``res`` as it is so that the
        # comparison below reports a mismatch.
        pass

    if res == ans:
        return res if prefetch else True
    return False


import logging

# Configure the root logger's record layout and timestamp format at import time.
logging.basicConfig(
    format="[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Registry of logger names that get_logger() has already configured
# (name -> True); consulted to avoid attaching handlers twice.
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO, file_mode="w"):
    """Return the logger called ``name``, configuring it on first request.

    A logger is returned untouched when ``name`` is already registered in
    ``logger_initialized`` or starts with a registered name. Otherwise a
    stream handler is attached, plus a file handler when ``log_file`` is
    given and the process rank is 0.

    Args:
        name: Logger name.
        log_file: Optional path of a log file (only used on rank 0).
        log_level: Level for the handlers, and for the logger on rank 0;
            other ranks get ``logging.ERROR`` on the logger.
        file_mode: Mode passed to ``logging.FileHandler``.

    Returns:
        logging.Logger: The (possibly freshly configured) logger.
    """
    target = logging.getLogger(name)

    # Exact hit, or a registered name that is a prefix of ``name``.
    if name in logger_initialized:
        return target
    if any(name.startswith(known) for known in logger_initialized):
        return target

    sinks = [logging.StreamHandler()]

    # Rank comes from torch.distributed when it is importable and
    # initialised; in every other case the process counts as rank 0.
    rank = 0
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
    except ImportError:
        pass

    if log_file is not None and rank == 0:
        sinks.append(logging.FileHandler(log_file, file_mode))

    layout = logging.Formatter(
        "[%(asctime)s] %(levelname)s - %(name)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s"
    )
    for sink in sinks:
        sink.setFormatter(layout)
        sink.setLevel(log_level)
        target.addHandler(sink)

    target.setLevel(log_level if rank == 0 else logging.ERROR)

    logger_initialized[name] = True
    return target


def list_to_dict(lst):
    """Label the items of ``lst`` with consecutive letters starting at ``"A"``.

    Returns:
        dict: ``{"A": lst[0], "B": lst[1], ...}`` in input order.
    """
    labelled = {}
    first_code = ord("A")
    for offset, value in enumerate(lst):
        labelled[chr(first_code + offset)] = value
    return labelled


def can_infer(answer, choices):
    """Infer the chosen option from ``answer``, by letter first and text second.

    ``answer`` is converted to ``str``. The result of ``can_infer_option`` is
    returned when it is truthy; otherwise the result of ``can_infer_text``.
    """
    text = str(answer)
    by_letter = can_infer_option(text, choices)
    if by_letter:
        return by_letter
    return can_infer_text(text, choices)


def can_infer_option(answer, choices):
    """Try to read a single option letter out of a free-form answer string.

    Args:
        answer: Response text.
        choices: Mapping whose keys are the option labels.

    Returns:
        * ``False`` if the answer contains the API failure marker;
        * ``"Z"`` if it contains one of the known refusal phrases;
        * the option label when exactly one label appears as a standalone
          token (after punctuation is replaced by spaces);
        * ``"Z"`` when no label appears but a standalone ``Z`` does;
        * ``False`` otherwise.

    Note:
        When the ``VERBOSE`` environment variable is set to a non-empty
        value, an answer with more than three tokens that contains a
        standalone ``A`` is logged as a possible quantifier and rejected
        with ``False`` instead of being returned as an option.
    """
    if "Failed to obtain answer via API" in answer:
        return False

    refusals = (
        "Sorry, I can't help with images of people yet.",
        "I can't process this file.",
        "I'm sorry, but without the image provided",
        "Cannot determine the answer",
    )
    if any(phrase in answer for phrase in refusals):
        return "Z"

    # Turn every listed punctuation mark into a space, then tokenise.
    punctuation = ".()[],:;!*#{}"
    tokens = answer.translate(str.maketrans(punctuation, " " * len(punctuation))).split()

    # Option labels that occur as standalone tokens, in ``choices`` order.
    hits = [label for label in choices if label in tokens]

    if len(hits) == 1:
        verbose = os.environ.get("VERBOSE", 0)
        if verbose and "A" in tokens and len(tokens) > 3:
            logger = get_logger("Evaluation")
            logger.info(f"A might be a quantifier in the string: {answer}.")
            return False
        return hits[0]

    # No label found: a lone "Z" is still accepted as the answer.
    if not hits and "Z" in tokens:
        return "Z"
    return False


def can_infer_text(answer, choices):
    """Infer the option whose text is the only one contained in ``answer``.

    Matching is case-insensitive. The caller's ``choices`` mapping is left
    untouched; lower-cased copies of the option texts are used internally.

    Args:
        answer: Response text.
        choices: ``dict`` mapping option labels to option contents.

    Returns:
        The label of the single option whose text occurs in ``answer``, or
        ``False`` when zero or several options match.

    Raises:
        AssertionError: If ``choices`` is not a dict or a key is not found in
            ``string.ascii_uppercase``.
    """
    answer = answer.lower()
    assert isinstance(choices, dict)

    # Work on a normalised copy so the argument is not mutated.
    lowered = {}
    for k in choices:
        assert k in string.ascii_uppercase
        lowered[k] = str(choices[k]).lower()

    cands = [k for k, text in lowered.items() if text in answer]
    if len(cands) == 1:
        return cands[0]
    return False


# Marker text for a failed API call: MathVista_auxeval treats any judge reply
# containing this string as a failed attempt and retries.
FAIL_MSG = "Failed to obtain answer via API."


def get_gpt4_ICE():
    """Return the five in-context extraction examples as a list of strings.

    Each example shows a hint, a question, a model response and the answer
    extracted from it. In order they cover: an integer, a one-decimal float,
    a two-decimal float, a Python list and a multiple-choice letter.
    """
    return [
        # Integer answer.
        """
Hint: Please answer the question requiring an integer answer and provide the final value,
e.g., 1, 2, 3, at the end.\n
Question: Which number is missing?\n
Model response: The number missing in the sequence is 14.\n
Extracted answer: 14
""",
        # Float with one decimal place.
        """
Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value,
e.g., 1.2, 1.3, 1.4, at the end.\n
Question: What is the fraction of females facing the camera?\n
Model response: The fraction of females facing the camera is 0.6,
which means that six out of ten females in the group are facing the camera.\n
Extracted answer: 0.6
""",
        # Float with two decimal places.
        """
Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value,
e.g., 1.23, 1.34, 1.45, at the end.\n
Question: How much money does Luca need to buy a sour apple candy and a butter-scotch candy? (Unit: $)\n
Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.\n
Extracted answer: 1.45
""",
        # Python list.
        """
Hint: Please answer the question requiring a Python list as an answer and provide the final list,
e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.\n
Question: Between which two years does the line graph saw its maximum peak?\n
Model response: The line graph saw its maximum peak between 2007 and 2008.\n
Extracted answer: [2007, 2008]
""",
        # Multiple-choice option letter.
        """
Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
Question: What fraction of the shape is blue?\n
Choices: (A) 3/11 (B) 8/11 (C) 6/11 (D) 3/5\n
Model response: The correct answer is (B) 8/11.\n
Extracted answer: B
""",
    ]


def build_mathvista_gpt4_prompt(line):
    """Build the answer-extraction prompt sent to the judge for one record.

    The prompt consists of the task description, every in-context example
    from ``get_gpt4_ICE()``, the record's question, its prediction labelled
    ``Model response:`` (the same label the examples use) and, on its own
    line, a trailing ``Extracted answer:`` cue for the judge to complete.

    Args:
        line: Record (mapping-like) with ``"question"`` and ``"prediction"``.

    Returns:
        str: The assembled prompt.
    """
    task_description = """
Please read the following example.
Then extract the answer from the model response and type it at the end of the prompt.\n
"""
    question = line["question"]
    prediction = str(line["prediction"])

    parts = [task_description]
    parts.extend(example + "\n" for example in get_gpt4_ICE())
    parts.append(question + "\n")
    # Use the label spelled as in the examples, and end the line so the cue
    # below is not glued onto the last word of the prediction.
    parts.append("Model response: " + prediction + "\n")
    parts.append("Extracted answer:")
    return "".join(parts)


def MathVista_acc(result_file):
    """Aggregate prefetch and accuracy statistics from a judged result file.

    Every row counts towards ``"Overall"``, towards each of its skills and
    towards its task. A row is a prefetch when its ``"log"`` equals
    ``"Prefetch succeed"`` and a hit when ``post_check(row, prefetch=False)``
    is truthy.

    Args:
        result_file: Path understood by ``load``; the table must provide the
            columns ``"task"``, ``"skills"`` and ``"log"`` plus those read by
            ``post_check``.

    Returns:
        pandas.DataFrame: One row per category with the columns
        ``Task&Skill``, ``tot``, ``prefetch``, ``hit``, ``prefetch_rate`` and
        ``acc`` (both rates in percent).
    """
    data = load(result_file)
    tot = defaultdict(int)
    fetch = defaultdict(int)
    hit = defaultdict(int)

    for i in range(len(data)):
        item = data.iloc[i]
        cate = item["task"]
        try:
            # NOTE(security): ``eval`` runs on text taken from the result
            # file; only use this with trusted files.
            skills = eval(item["skills"])
        except (SyntaxError, NameError):
            # Not a Python literal (e.g. a bare skill name): treat the raw
            # cell as a single skill.
            skills = [item["skills"]]

        tot["Overall"] += 1
        for skill in skills:
            tot[skill] += 1
        tot[cate] += 1

        if item["log"] == "Prefetch succeed":
            fetch["Overall"] += 1
            fetch[cate] += 1
            for skill in skills:
                fetch[skill] += 1
        if post_check(item, prefetch=False):
            hit["Overall"] += 1
            hit[cate] += 1
            for skill in skills:
                hit[skill] += 1

    res = defaultdict(list)
    # Every key in ``tot`` has a count of at least one, so the divisions
    # below cannot hit zero.
    for k in tot:
        res["Task&Skill"].append(k)
        res["tot"].append(tot[k])
        res["prefetch"].append(fetch[k])
        res["hit"].append(hit[k])
        res["prefetch_rate"].append(fetch[k] / tot[k] * 100)
        res["acc"].append(hit[k] / tot[k] * 100)
    return pd.DataFrame(res)


def download_file(url, filename=None):
    """Download ``url`` to ``filename`` while showing a progress bar.

    When the download fails and the URL contains ``huggingface.co``, one
    further attempt is made with that host replaced by ``hf-mirror.com``.

    Args:
        url: Source URL.
        filename: Destination path; defaults to the last path segment of
            ``url``.

    Returns:
        str: The destination path.

    Raises:
        Exception: ``"Failed to download {url}"`` when the download (and the
            mirror attempt, if applicable) fails.
    """
    import logging
    import urllib.request

    from tqdm import tqdm

    class DownloadProgressBar(tqdm):
        """tqdm bar that can be used as a ``urlretrieve`` report hook."""

        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            # Advance by the difference between bytes seen and bytes shown.
            self.update(b * bsize - self.n)

    last_segment = url.split("/")[-1]
    if filename is None:
        filename = last_segment

    try:
        with DownloadProgressBar(
            unit="B", unit_scale=True, miniters=1, desc=last_segment
        ) as bar:
            urllib.request.urlretrieve(url, filename=filename, reporthook=bar.update_to)
    except Exception as err:
        logging.warning(f"{type(err)}: {err}")

        # Only huggingface.co URLs have a fallback mirror.
        if "huggingface.co" not in url:
            raise Exception(f"Failed to download {url}")

        mirror_url = url.replace("huggingface.co", "hf-mirror.com")
        try:
            download_file(mirror_url, filename)
        except Exception as mirror_err:
            logging.warning(f"{type(mirror_err)}: {mirror_err}")
            raise Exception(f"Failed to download {url}")

    return filename


def load_env():
    """Copy the non-empty entries of ``./.env`` into ``os.environ``.

    The path is resolved with ``realpath`` relative to the current working
    directory. If the file is missing, an error is logged and nothing else
    happens.
    """
    import logging

    logging.basicConfig(
        format="[%(asctime)s] %(levelname)s - %(filename)s: %(funcName)s - %(lineno)d: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    env_path = osp.realpath("./.env")
    if not osp.exists(env_path):
        logging.error(f"Did not detect the .env file at {env_path}, failed to load. ")
        return

    from dotenv import dotenv_values

    for key, value in dotenv_values(env_path).items():
        # Skip keys without a value and keys with an empty value.
        if value is None or not len(value):
            continue
        os.environ[key] = value
    logging.info(f"API Keys successfully loaded from {env_path}")


def get_args():
    """Parse the command line and return the resulting namespace.

    Recognised option:
        --result-file: Path of the result file to evaluate
            (default ``results/llava_v1.5_7b_MathVista_MINI-new.xlsx``).
    """
    cli = argparse.ArgumentParser()
    default_result = "results/llava_v1.5_7b_MathVista_MINI-new.xlsx"
    cli.add_argument("--result-file", type=str, default=default_result)
    parsed = cli.parse_args()
    return parsed


def build_judge(**kwargs):
    """Create the judge model wrapper used for answer extraction.

    ``model`` and ``nproc`` are removed from ``kwargs``; everything else is
    forwarded to ``OpenAIWrapper``. Environment variables are loaded with
    ``load_env()`` first. When ``LOCAL_LLM`` is set in the environment its
    value is used as the model version and ``model`` is ignored; otherwise
    ``model`` is translated through the alias table below.

    Args:
        **kwargs: ``model`` (alias of the judge), optional ``nproc``
            (discarded), and any ``OpenAIWrapper`` keyword arguments.

    Returns:
        OpenAIWrapper: The configured judge.

    Raises:
        KeyError: If ``LOCAL_LLM`` is unset and ``model`` is not a known alias.
    """
    from .api import OpenAIWrapper

    model = kwargs.pop("model", None)
    kwargs.pop("nproc", None)
    load_env()

    local_llm = os.environ.get("LOCAL_LLM", None)
    if local_llm is not None:
        model_version = local_llm
    else:
        model_map = {
            "gpt-4-turbo": "gpt-4-1106-preview",
            "gpt-4-0613": "gpt-4-0613",
            "gpt-4-0125": "gpt-4-0125-preview",
            "gpt-4-0409": "gpt-4-turbo-2024-04-09",
            "chatgpt-1106": "gpt-3.5-turbo-1106",
            "chatgpt-0125": "gpt-3.5-turbo-0125",
            "gpt-4o": "gpt-4o-2024-05-13",
            "gpt-4o-0806": "gpt-4o-2024-08-06",
            "gpt-4o-mini": "gpt-4o-mini-2024-07-18",
            "qwen-7b": "Qwen/Qwen2.5-7B-Instruct",
            "qwen-72b": "Qwen/Qwen2.5-72B-Instruct",
            "deepseek": "deepseek-ai/DeepSeek-V2.5",
        }
        if model not in model_map:
            # Same exception type as the plain lookup, but say what went wrong.
            raise KeyError(
                f"Unknown judge model {model!r}; expected one of "
                f"{sorted(model_map)} or set the LOCAL_LLM environment variable."
            )
        model_version = model_map[model]

    return OpenAIWrapper(model_version, **kwargs)


def track_progress_rich(
    func: Callable,
    tasks: Iterable = tuple(),
    nproc: int = 1,
    save=None,
    keys=None,
    **kwargs,
) -> list:
    """Run ``func`` over ``tasks`` in a thread pool with a progress bar.

    Each task is passed to ``func`` as keyword arguments when it is a dict,
    as positional arguments when it is a tuple or list, and as a single
    positional argument otherwise. Completed results are polled roughly
    every 0.1 s; when ``keys`` is given they are also stored in a dict under
    the matching key, and that dict is written to ``save`` after every poll
    that found newly finished tasks.

    Args:
        func: Callable applied to every task.
        tasks: Iterable of task inputs (materialised into a list).
        nproc: Number of worker threads; must be positive.
        save: Optional path of a checkpoint file handled by ``load``/``dump``.
            Existing content is loaded and extended.
        keys: Optional sequence of checkpoint keys, one per task.
        **kwargs: Accepted and ignored.

    Returns:
        list: Results in task order.

    Raises:
        TypeError: If ``func`` is not callable or ``tasks`` is not iterable.
        AssertionError: If ``nproc`` is not positive, ``keys`` and ``tasks``
            differ in length, or the directory of ``save`` does not exist.
    """
    from concurrent.futures import ThreadPoolExecutor

    from tqdm import tqdm

    # Validate the arguments before touching the file system.
    if not callable(func):
        raise TypeError("func must be a callable object")
    if not isinstance(tasks, Iterable):
        raise TypeError(f"tasks must be an iterable object, but got {type(tasks)}")
    assert nproc > 0, "nproc must be a positive number"

    # Materialise once so that unsized iterables (e.g. generators) work too.
    tasks = list(tasks)
    if keys is not None:
        assert len(keys) == len(tasks)

    if save is not None:
        assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ""
        if not osp.exists(save):
            dump({}, save)
    res = load(save) if save is not None else {}
    results = [None] * len(tasks)

    with ThreadPoolExecutor(max_workers=nproc) as executor:
        futures = []
        for inputs in tasks:
            if not isinstance(inputs, (tuple, list, dict)):
                inputs = (inputs,)
            if isinstance(inputs, dict):
                futures.append(executor.submit(func, **inputs))
            else:
                futures.append(executor.submit(func, *inputs))

        unfinished = set(range(len(tasks)))
        # The context manager closes the bar even if a task raised.
        with tqdm(total=len(unfinished)) as pbar:
            while unfinished:
                new_finished = {idx for idx in unfinished if futures[idx].done()}
                for idx in new_finished:
                    # Re-raises here if the task itself raised.
                    results[idx] = futures[idx].result()
                    if keys is not None:
                        res[keys[idx]] = results[idx]
                if new_finished:
                    if save is not None:
                        dump(res, save)
                    pbar.update(len(new_finished))
                    unfinished -= new_finished
                # Throttle polling (and checkpoint writes); no need to wait
                # once everything is done.
                if unfinished:
                    time.sleep(0.1)

    if save is not None:
        dump(res, save)
    return results


def evaluate(eval_file, **judge_kwargs):
    """Judge every prediction in ``eval_file`` and compute MathVista scores.

    Output paths are derived from ``eval_file`` with its extension removed:
    ``<base>_<model>.xlsx`` (judged table), ``<base>_<model>.pkl``
    (checkpoint of judge results) and ``<base>_<model>_score.csv`` (scores).
    If the judged table already exists, judging is skipped and only the
    scores are recomputed. Records already present in the checkpoint are not
    judged again.

    Args:
        eval_file: Path of the prediction table, readable by ``load``; it
            must have an ``"index"`` column.
        **judge_kwargs: Must contain ``model``; ``nproc`` (default 50) sets
            the number of worker threads; the rest goes to ``build_judge``.

    Returns:
        pandas.DataFrame: The score table produced by ``MathVista_acc``.
    """
    model_name = judge_kwargs["model"]
    # Strip only the trailing extension; a textual replace of ".<suffix>"
    # would also rewrite earlier occurrences inside the path.
    base, _ = osp.splitext(eval_file)
    storage = f"{base}_{model_name}.xlsx"
    tmp_file = f"{base}_{model_name}.pkl"
    score_pth = f"{base}_{model_name}_score.csv"
    nproc = judge_kwargs.pop("nproc", 50)

    if not osp.exists(storage):
        data = load(eval_file)
        judge = build_judge(max_tokens=128, **judge_kwargs)
        lines = [data.iloc[i] for i in range(len(data))]

        # Resume from the checkpoint when there is one.
        ans = load(tmp_file) if osp.exists(tmp_file) else {}
        todo = [(line["index"], (judge, line)) for line in lines]
        todo = [(idx, tup) for idx, tup in todo if idx not in ans]
        indices = [idx for idx, _ in todo]
        tups = [tup for _, tup in todo]

        if len(indices):
            new_results = track_progress_rich(
                MathVista_auxeval,
                tups,
                nproc=nproc,
                chunksize=nproc,
                keys=indices,
                save=tmp_file,
            )
            # Re-read the checkpoint and make sure it agrees with what the
            # workers returned.
            ans = load(tmp_file)
            for k, v in zip(indices, new_results):
                assert k in ans
                assert ans[k]["log"] == v["log"] and ans[k]["res"] == v["res"]

        data["res"] = [ans[idx]["res"] for idx in data["index"]]
        data["log"] = [ans[idx]["log"] for idx in data["index"]]
        dump(data, storage)

    score = MathVista_acc(storage)
    dump(score, score_pth)
    return score


def LMUDataRoot():
    """Return the data root directory.

    The value of the ``LMUData`` environment variable is used when it is set
    and names an existing path. Otherwise ``~/LMUData`` is created if
    necessary and returned.
    """
    configured = os.environ.get("LMUData")
    if configured is not None and osp.exists(configured):
        return configured

    fallback = osp.join(osp.expanduser("~"), "LMUData")
    os.makedirs(fallback, exist_ok=True)
    return fallback


def load(f, fmt=None):
    """Load a file, choosing the reader from ``fmt`` or the file extension.

    Supported formats: ``pkl``, ``json``, ``jsonl``, ``xlsx``, ``csv`` and
    ``tsv``. When ``f`` is a URL it is first downloaded into
    ``<LMUDataRoot()>/files/`` (unless already present there) and the local
    copy is loaded.

    Args:
        f: Path or URL.
        fmt: Optional format name overriding the extension.

    Returns:
        The loaded object; its type depends on the format.

    Raises:
        KeyError: If the format is not supported.
    """

    def load_pkl(pth):
        # NOTE(security): unpickling can execute arbitrary code; only load
        # pickle files from trusted sources.
        with open(pth, "rb") as fin:
            return pickle.load(fin)

    def load_json(pth):
        with open(pth, "r", encoding="utf-8") as fin:
            return json.load(fin)

    def load_jsonl(f):
        # special fix for math: merge the jsonl outputs into the reference
        # sheet, which is read from the current working directory.
        excel_ori = pd.read_excel("llava_v1.5_7b_MathVista_MINI.xlsx")

        with open(f, encoding="utf-8") as fin:
            lines = [x.strip() for x in fin]
        # Drop one trailing blank line; an empty file simply yields no rows.
        if lines and lines[-1] == "":
            lines = lines[:-1]
        data = [json.loads(x) for x in lines]

        for d in data:
            # "testmini_<n>" -> row label n of the reference sheet.
            id_excel = int(d["question_id"].replace("testmini_", ""))
            outputs = d["text"]
            excel_ori.at[id_excel, "prediction"] = outputs

        # Join instead of "<dir>/result.xlsx": with an empty dirname the
        # latter would point at the file-system root.
        dump(excel_ori, osp.join(osp.dirname(f), "result.xlsx"))

        return excel_ori

    def load_xlsx(f):
        return pd.read_excel(f)

    def load_csv(f):
        return pd.read_csv(f)

    def load_tsv(f):
        return pd.read_csv(f, sep="\t")

    import validators

    if validators.url(f):
        tgt = osp.join(LMUDataRoot(), "files", osp.basename(f))
        if not osp.exists(tgt):
            download_file(f, tgt)
        f = tgt

    handlers = dict(
        pkl=load_pkl,
        json=load_json,
        jsonl=load_jsonl,
        xlsx=load_xlsx,
        csv=load_csv,
        tsv=load_tsv,
    )
    if fmt is not None:
        return handlers[fmt](f)

    suffix = f.split(".")[-1]
    return handlers[suffix](f)


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to plain Python.

    The abstract scalar types ``np.integer``, ``np.floating`` and
    ``np.complexfloating`` are used for the checks; the ``np.float_`` and
    ``np.complex_`` aliases no longer exist in NumPy 2.0.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Conversions: integer scalars -> ``int``; floating scalars ->
        ``float``; complex scalars -> ``{"real": ..., "imag": ...}``; arrays
        -> nested lists; ``np.bool_`` -> ``bool``; ``np.void`` -> ``None``.
        Anything else is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.complexfloating):
            return {"real": float(obj.real), "imag": float(obj.imag)}
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.void):
            return None
        return json.JSONEncoder.default(self, obj)


# LOAD & DUMP
def dump(data, f, **kwargs):
    """Write ``data`` to ``f`` in the format given by the file extension.

    Supported extensions: ``pkl``, ``json``, ``jsonl``, ``xlsx``, ``csv`` and
    ``tsv``. The tabular formats call the ``to_excel``/``to_csv`` methods of
    ``data``.

    Args:
        data: Object to store.
        f: Destination path.
        **kwargs: Passed to the format writer; the ``csv``/``tsv`` writers
            accept ``quoting`` (default ``csv.QUOTE_ALL``).

    Raises:
        KeyError: If the extension is not supported.
    """

    def dump_pkl(data, pth, **kwargs):
        with open(pth, "wb") as fout:
            pickle.dump(data, fout)

    def dump_json(data, pth, **kwargs):
        # Non-ASCII characters are written verbatim, so fix the encoding
        # instead of relying on the platform default.
        with open(pth, "w", encoding="utf-8") as fout:
            json.dump(data, fout, indent=4, ensure_ascii=False, cls=NumpyEncoder)

    def dump_jsonl(data, f, **kwargs):
        lines = [json.dumps(x, ensure_ascii=False, cls=NumpyEncoder) for x in data]
        with open(f, "w", encoding="utf8") as fout:
            fout.write("\n".join(lines))

    def dump_xlsx(data, f, **kwargs):
        data.to_excel(f, index=False, engine="xlsxwriter")

    def dump_csv(data, f, quoting=csv.QUOTE_ALL):
        data.to_csv(f, index=False, encoding="utf-8", quoting=quoting)

    def dump_tsv(data, f, quoting=csv.QUOTE_ALL):
        data.to_csv(f, sep="\t", index=False, encoding="utf-8", quoting=quoting)

    handlers = dict(
        pkl=dump_pkl,
        json=dump_json,
        jsonl=dump_jsonl,
        xlsx=dump_xlsx,
        csv=dump_csv,
        tsv=dump_tsv,
    )
    suffix = f.split(".")[-1]
    return handlers[suffix](data, f, **kwargs)


if __name__ == "__main__":
    # Script entry point: judge the file named on the command line with the
    # fixed judge settings below.
    cli_args = get_args()

    if cli_args.result_file is not None:
        evaluate(cli_args.result_file, verbose=True, retry=3, model="gpt-4o-mini")
