import ast
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import numpy as np
import requests
import yaml
from loguru import logger as eval_logger
from openai import OpenAI
from PIL import Image
from tqdm import tqdm

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
from lmms_eval.tasks.capability.prompt import Prompts

# Load this task's default YAML template. Lines carrying the custom `!function`
# tag are dropped first because `yaml.safe_load` has no constructor for that tag.
with open(Path(__file__).parent / "_default_template_yaml", "r", encoding="utf-8") as f:
    raw_data = f.readlines()
    safe_data = [line for line in raw_data if "!function" not in line]
config = yaml.safe_load("".join(safe_data))

# Backend used by the GPT judge: "openai", "azure", or anything else, in which
# case placeholder URL/key values are used and must be filled in by hand.
API_TYPE = os.getenv("API_TYPE", "openai")

if API_TYPE == "azure":
    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
    # Azure authenticates with an `api-key` header rather than a bearer token.
    headers = {"api-key": API_KEY, "Content-Type": "application/json"}
else:
    if API_TYPE == "openai":
        API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
        API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
    else:
        API_URL = "YOUR_API_URL"
        API_KEY = "YOUR_API_KEY"
    headers = {"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"}

# Local Hugging Face cache root (with `~` expanded); sample files referenced by
# `doc["file_path"]` are resolved inside `cache_dir` below it.
HF_HOME = os.path.expanduser(os.getenv("HF_HOME", "~/.cache/huggingface"))
cache_dir = os.path.join(HF_HOME, config["dataset_kwargs"]["cache_dir"])


def capability_doc_to_visual(doc, lmms_eval_specific_kwargs=None):
    """Build the visual input list for one capability sample.

    The first five characters of ``doc["file_path"]`` are dropped and the rest is
    resolved inside the local dataset cache directory. Image samples are opened and
    converted to RGB; any other data type (video) is passed on as a file path.
    A missing file is only logged here, not raised.
    """
    kind = doc["data_type"]
    local_path = os.path.join(cache_dir, doc["file_path"][5:])
    if not os.path.exists(local_path):
        eval_logger.error(f"File path: {local_path} does not exist, please check.")

    if kind != "image":  # video
        return [local_path]
    return [Image.open(local_path).convert("RGB")]


def capability_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Return the prompt configured for this sample's data type.

    Looks up the ``"<data_type>_prompt"`` entry (e.g. ``image_prompt``) in
    ``lmms_eval_specific_kwargs``.
    """
    prompt_key = f"{doc['data_type']}_prompt"
    return lmms_eval_specific_kwargs[prompt_key]


def capability_process_results(doc, results):
    """Turn one model prediction into the record consumed by every capability metric.

    Args:
        doc: one instance of the eval dataset (needs ``file_id``, ``annotation``, ``task``).
        results: ``[pred]`` -- the model's caption as a single-element list.
    Returns:
        A dict mapping each metric name (inference result, precision, recall, f1)
        to the same record ``{"file_id", "caption", "annotation", "task"}``.
    """
    raw_annotation = doc["annotation"]
    if isinstance(raw_annotation, dict):
        # Drop annotation fields that are unset (None) for this sample.
        raw_annotation = {key: value for key, value in raw_annotation.items() if value is not None}

    record = dict(
        file_id=doc["file_id"],
        caption=results[0].strip(),
        annotation=raw_annotation,
        task=doc["task"],
    )
    metric_names = (
        "capability_inference_result",
        "capability_precision",
        "capability_recall",
        "capability_f1_score",
    )
    return dict.fromkeys(metric_names, record)


def capability_aggregate_inference_result(results, args):
    """Persist a task's raw inference records as ``inference/<task>.jsonl``.

    The file is written under ``config["metadata"]["eval_save_root"]`` when that is
    set, otherwise to the path returned by ``generate_submission_file`` with subpath
    ``capability_results/<suffix>/inference``. Any ``../evaluation/<task>.jsonl``
    next to it is deleted so that the next scoring run re-judges these captions
    instead of resuming from judgements of an older inference run.

    Args:
        results: records from ``capability_process_results`` (all of one task).
        args: lmms-eval CLI arguments (``model`` and ``log_samples_suffix`` are read).
    Returns:
        None; this aggregator only saves the records.
    """
    task = results[0]["task"]
    eval_save_root = config["metadata"].get("eval_save_root")
    if eval_save_root is not None:
        save_path = os.path.join(eval_save_root, f"inference/{task}.jsonl")
    else:
        # Use the model name when the default log suffix is in effect.
        suffix = args.model if args.log_samples_suffix == "model_outputs" else args.log_samples_suffix
        save_path = generate_submission_file(file_name=f"{task}.jsonl", args=args, subpath=f"capability_results/{suffix}/inference")

    save_dir = os.path.dirname(save_path)
    if save_dir:
        # A fresh eval_save_root has no inference/ directory yet; open(..., "w") would fail.
        os.makedirs(save_dir, exist_ok=True)

    # lmms-eval does not auto-resume inference, so evaluation records from an earlier
    # run no longer match these captions: delete them to force re-evaluation.
    eval_save_path = os.path.join(save_dir, f"../evaluation/{task}.jsonl")
    if os.path.exists(eval_save_path):
        eval_logger.warning(f"Found EXISTING evaluation records: {eval_save_path}, REMOVING it!")
        os.remove(eval_save_path)

    with open(save_path, "w") as f:
        f.writelines(json.dumps(result) + "\n" for result in results)
    return None


def capability_aggregate_results(results, args):
    """Score a task's captions with the GPT judge and return the aggregated metrics.

    Args:
        results: records from ``capability_process_results``
            (``{"file_id", "caption", "annotation", "task"}``), all of one task.
        args: lmms-eval CLI arguments; ``model`` / ``log_samples_suffix`` choose the
            output sub-directory when ``eval_save_root`` is not configured.
    Returns:
        dict with ``precision``, ``recall``, ``hit_rate`` and ``f1_score`` (percent),
        as produced by ``Evaluator.calculate_metric``.
    """
    task = results[0]["task"]
    metadata = config["metadata"]
    eval_save_root = metadata.get("eval_save_root")
    if eval_save_root is not None:
        save_path = os.path.join(eval_save_root, f"evaluation/{task}.jsonl")
    else:
        suffix = args.model if args.log_samples_suffix == "model_outputs" else args.log_samples_suffix
        save_path = generate_submission_file(file_name=f"{task}.jsonl", args=args, subpath=f"capability_results/{suffix}/evaluation")

    # The evaluator appends judge replies to save_path, which fails if the directory
    # does not exist yet (e.g. a fresh eval_save_root).
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    evaluator = Evaluator(
        task,
        results,
        save_path,
        metadata["eval_model_name"],
        headers,
        metadata["eval_num_process"],
        metadata["eval_max_allow_missing"],
        metadata["eval_max_retry_times"],
        metadata["eval_auto_resume"],
        metadata["eval_strict_match"],
    )
    score_dict = evaluator.evaluate_scores()
    return evaluator.calculate_metric(score_dict)


def capability_aggregate_precision(results, args):
    """Aggregation hook for ``capability_precision``: judge the captions, log and return precision."""
    precision = capability_aggregate_results(results, args)["precision"]
    eval_logger.info(f"[{results[0]['task']}] precision: {precision:.1f}")
    return precision


def capability_aggregate_recall(results, args):
    """Aggregation hook for ``capability_recall``: judge the captions, log and return recall."""
    metric_values = capability_aggregate_results(results, args)
    recall = metric_values["recall"]
    task_name = results[0]["task"]
    eval_logger.info(f"[{task_name}] recall: {recall:.1f}")
    return recall


def capability_aggregate_f1score(results, args):
    """Aggregation hook for ``capability_f1_score``: judge the captions, log and return F1."""
    metric_values = capability_aggregate_results(results, args)
    task_name, f1_score = results[0]["task"], metric_values["f1_score"]
    eval_logger.info(f"[{task_name}] f1_score: {f1_score:.1f}")
    return f1_score


class Evaluator:
    """GPT-as-judge scorer for a single capability task.

    For each inference record (``file_id``, ``caption``, ``annotation``) the judge
    model is prompted via ``Prompts.get_prompts_by_task``. Its structured reply is
    checked by ``post_validate_format_<task>``, appended to ``save_path`` as a JSON
    line, and turned into a per-sample score by ``post_process_<task>``. A score of
    1 counts as correct, -1 as a wrong hit, and 0 (e.g. an "N/A" prediction) as not
    hit. ``calculate_metric`` aggregates the scores into percentages.
    """

    def __init__(
        self,
        task,
        results,
        save_path,
        eval_model,
        headers,
        num_process=0,
        max_allow_missing=5,
        max_retry_times=10,
        auto_resume=True,
        strict_match=True,
    ):
        """
        Args:
            task: capability task name; selects ``post_validate_format_<task>`` and
                ``post_process_<task>``.
            results: records produced by ``capability_process_results``.
            save_path: JSON-lines file that judge replies are appended to (and resumed from).
            eval_model: judge model name sent in the request payload.
            headers: HTTP headers (authentication) for the judge API.
            num_process: number of worker threads; 0 evaluates sequentially.
            max_allow_missing: after the first pass, stop retrying once at most this
                many samples remain unjudged.
            max_retry_times: extra passes over samples whose judgement failed.
            auto_resume: reuse replies already stored in ``save_path``.
            strict_match: require the judge to echo back the annotation it scored.
        """
        self.task = task
        self.results = results
        self.save_path = save_path
        self.eval_model = eval_model
        self.headers = headers
        self.num_process = num_process
        self.max_allow_missing = max_allow_missing
        self.max_retry_times = max_retry_times
        self.auto_resume = auto_resume
        self.strict_match = strict_match
        self.prompts = Prompts()

        # Resolve the task-specific hooks by name. getattr replaces eval(): an unknown
        # task still raises AttributeError, but no string is executed as code.
        self.post_validate_format_func = getattr(self, f"post_validate_format_{task}")
        self.post_process_func = getattr(self, f"post_process_{task}")

        # file_id -> annotation, used when scoring saved replies.
        self.file2anno = {r["file_id"]: r["annotation"] for r in self.results}

    # ------------------------------------------------------------------
    # Shared validation / scoring helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _require(condition, message):
        """Raise ``ValueError(message)`` unless ``condition`` is truthy.

        Used instead of ``assert`` so reply validation still runs under ``python -O``.
        """
        if not condition:
            raise ValueError(message)

    def _normalize_score(self, response):
        """Convert a string score "-1"/"0"/"1" to int in place; reject any other value."""
        if response["score"] in ["-1", "0", "1"]:
            response["score"] = int(response["score"])
        self._require(response["score"] in [1, 0, -1], f"invalid score: {response['score']!r}")

    def _validate_copied_field(self, response, key, anno):
        """Validate a ``{key: <copied annotation>, "score": ..., "reason": ...}`` reply."""
        self._require(isinstance(response, dict), "response is not a dict")
        if self.strict_match:
            # The judge must copy back exactly the annotation it was asked to score.
            self._require(response[key].strip() == anno.strip(), f"{key} does not match the annotation")
        self._normalize_score(response)

    @staticmethod
    def _split_object_pair(value, field):
        """Split an ``{object: value}`` pair given as a dict or as an ``"{object: value}"`` string.

        Returns ``(object, raw_value)``. Strings are split at the last colon, so object
        names that contain colons stay intact. Raises ValueError for any other type.
        """
        if isinstance(value, str):
            if ":" not in value:
                raise ValueError(f"Invalid {field} format")
            category, raw_value = value.lstrip("{").rstrip("}").rsplit(":", 1)
            return category, raw_value
        if isinstance(value, dict):
            return list(value.items())[0]
        raise ValueError(f"Invalid {field} format")

    def _validate_category_prediction(self, response, category_explains, categories):
        """Validate a ``{"pred": [...], "reason": ...}`` reply for a closed-set category task.

        A ``pred`` equal to, or containing, "N/A" collapses to ``["N/A"]``. A string
        ``pred`` is parsed as a Python literal. Entries found in ``category_explains``
        are cut at the first ":" and lower-cased. Every entry must then be "N/A" or
        one of ``categories``.
        """
        self._require(isinstance(response, dict), "response is not a dict")
        self._require("pred" in response, "missing 'pred'")
        if response["pred"] == "N/A" or "N/A" in response["pred"]:
            response["pred"] = ["N/A"]
        if isinstance(response["pred"], str):
            response["pred"] = ast.literal_eval(response["pred"])
        self._require(isinstance(response["pred"], list), "'pred' is not a list")
        preds = response["pred"]
        for i, pred in enumerate(preds):
            if pred in category_explains:
                # Reduce the copied "category: explanation" entry itself (not the list).
                preds[i] = pred.split(":")[0].lower()
            self._require(preds[i] == "N/A" or preds[i] in categories, f"unknown category: {preds[i]!r}")

    @staticmethod
    def _score_category_prediction(response, anno):
        """Return 0 for a lone "N/A", 1 if the annotation is among the predictions, else -1."""
        preds = response["pred"]
        if len(preds) == 1 and preds[0] == "N/A":
            return 0
        return 1 if anno in preds else -1

    # ------------------------------------------------------------------
    # Task-specific hooks (looked up by name in __init__)
    # ------------------------------------------------------------------

    def post_validate_format_event(self, response, anno):
        # "{\"action\": \"copy provided action here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._validate_copied_field(response, "event", anno)

    def post_process_event(self, response, anno):
        return response["score"]

    def post_validate_format_action(self, response, anno):
        # "{\"action\": \"copy provided action here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._validate_copied_field(response, "action", anno)

    def post_process_action(self, response, anno):
        return response["score"]

    def post_validate_format_object_category(self, response, anno):
        # "{\"object_category\": \"copy provided object here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._validate_copied_field(response, "object_category", anno)

    def post_process_object_category(self, response, anno):
        return response["score"]

    def post_validate_format_object_number(self, response, anno):
        # "{\"object_number\": \"copy the provided {object: number} here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._require(isinstance(response, dict), "response is not a dict")
        _, raw_number = self._split_object_pair(response["object_number"], "object_number")
        # str() first: JSON replies often carry the count as an int, not a string.
        object_number = int(str(raw_number).strip())
        if self.strict_match:
            self._require(object_number == list(anno.values())[0], "object_number does not match the annotation")
        self._normalize_score(response)

    def post_process_object_number(self, response, anno):
        return response["score"]

    def post_validate_format_dynamic_object_number(self, response, anno):
        # "{\"object_number\": \"copy the provided {object: number} here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        # One reply per prompt, collected by evaluate_sample_worker under "response".
        self._require(isinstance(response, dict), "response is not a dict")
        self._require("response" in response, "missing 'response' list")
        for i, item in enumerate(response["response"]):
            _, raw_number = self._split_object_pair(item["object_number"], "object_number")
            object_number = int(str(raw_number).strip())
            if self.strict_match:
                self._require(object_number == list(anno.values())[i], "object_number does not match the annotation")
            self._normalize_score(item)

    def post_process_dynamic_object_number(self, response, anno):
        return [item["score"] for item in response["response"]]

    def post_validate_format_object_color(self, response, anno):
        # "{\"object_color\": \"copy the provided {object: color} here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._require(isinstance(response, dict), "response is not a dict")
        _, raw_color = self._split_object_pair(response["object_color"], "object_color")
        object_color = raw_color.strip()
        if self.strict_match:
            self._require(object_color == list(anno.values())[0], "object_color does not match the annotation")
        self._normalize_score(response)

    def post_process_object_color(self, response, anno):
        return response["score"]

    def post_validate_format_spatial_relation(self, response, anno):
        # "{\"spatial_relation\": \"copy the provided spatial relationship here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._validate_copied_field(response, "spatial_relation", anno)

    def post_process_spatial_relation(self, response, anno):
        return response["score"]

    def post_validate_format_scene(self, response, anno):
        # "{\"scene\": \"copy the provided scene here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._validate_copied_field(response, "scene", anno)

    def post_process_scene(self, response, anno):
        return response["score"]

    def post_validate_format_camera_angle(self, response, anno):
        # "{\"pred\": \"put your predicted category here\", \"reason\": \"give your reason here\"}\n"\
        self._validate_category_prediction(response, self.prompts.camera_angle_category_explains, self.prompts.camera_angle_categories)

    def post_process_camera_angle(self, response, anno):
        return self._score_category_prediction(response, anno)

    def post_validate_format_camera_movement(self, response, anno):
        # "{\"pred\": \"put your predicted category here\", \"reason\": \"give your reason here\"}\n"\
        self._validate_category_prediction(response, self.prompts.camera_movement_category_explains, self.prompts.camera_movement_categories)

    def post_process_camera_movement(self, response, anno):
        return self._score_category_prediction(response, anno)

    def post_validate_format_OCR(self, response, anno):
        # "{\"OCR\": \"copy the provided real OCR text here\", \"score\": put your score here, \"reason\": \"give your reason here\"},\n"\
        self._validate_copied_field(response, "OCR", anno)

    def post_process_OCR(self, response, anno):
        return response["score"]

    def post_validate_format_style(self, response, anno):
        # "{\"pred\": \"put your predicted category here\", \"reason\": \"give your reason here\"}\n"\
        self._validate_category_prediction(response, self.prompts.style_category_explains, self.prompts.style_categories)

    def post_process_style(self, response, anno):
        return self._score_category_prediction(response, anno)

    def post_validate_format_character_identification(self, response, anno):
        # "{\"name\": \"copy the provided name here\", \"score\": \"put your score here\",  \"reason\": \"give your reason here\"}\n"\
        self._validate_copied_field(response, "character_identification", anno)

    def post_process_character_identification(self, response, anno):
        return response["score"]

    # ------------------------------------------------------------------
    # Persistence and judge calls
    # ------------------------------------------------------------------

    def load_saved_records(self):
        """Load previously saved judge replies from ``self.save_path`` (JSON lines).

        Returns an empty list if the file does not exist. Blank lines are ignored, and
        malformed lines (e.g. a record cut off by an interrupted run) are skipped with
        a warning instead of aborting the resume.
        """
        if not os.path.exists(self.save_path):
            return []
        saved_responses = []
        with open(self.save_path, "r") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    saved_responses.append(json.loads(line))
                except json.JSONDecodeError:
                    eval_logger.warning(f"[{self.task}] Skipping malformed record at {self.save_path}:{line_no}")
        return saved_responses

    def call_gpt(self, system_prompt, user_prompt):
        """Send one chat-completion request to the judge model.

        Returns the stripped message content, or None if the HTTP call fails or the
        reply lacks ``choices[0].message.content``.
        """
        payload = {
            "model": self.eval_model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        }
        try:
            response = requests.post(API_URL, headers=self.headers, json=payload, timeout=60)
            response.raise_for_status()
            response = response.json()
        except Exception as e:
            eval_logger.info(f"Error calling {self.eval_model}: {e}")
            return None

        try:
            return response["choices"][0]["message"]["content"].strip()
        except Exception as e:
            eval_logger.info(f"Error parsing {self.eval_model} response: {e}\nResponse: {response}")
            return None

    @staticmethod
    def _strip_code_fence(text):
        """Return the payload inside a ```json / ```python / ``` fence, or ``text`` unchanged."""
        if "```json" in text:
            return text.split("```json")[-1].split("```")[0].strip()
        if "```python" in text:
            return text.split("```python")[-1].split("```")[0].strip()
        if "```" in text:
            return text.split("```")[1].strip()
        return text

    @staticmethod
    def _parse_structured_reply(text):
        """Parse the judge's reply as a Python literal, falling back to JSON.

        JSON-only literals (``true``/``false``/``null``) make ``ast.literal_eval`` fail,
        so ``json.loads`` is tried second. Raises ValueError if neither parser accepts it.
        """
        try:
            return ast.literal_eval(text)
        except (SyntaxError, ValueError, TypeError, RecursionError):
            pass
        try:
            return json.loads(text)
        except (ValueError, RecursionError) as err:
            raise ValueError("reply is neither a Python literal nor JSON") from err

    # NOTE: the misspelled method name is kept for backward compatibility.
    def call_and_parse_single_meaasge(self, file, system_prompt, user_prompt):
        """Call the judge once and parse its reply; return the parsed object or None."""
        response_message = self.call_gpt(system_prompt, user_prompt)
        if response_message is None:
            return None

        response_message = self._strip_code_fence(response_message)
        try:
            return self._parse_structured_reply(response_message)
        except ValueError:
            eval_logger.info(f"Invalid response format for {file}: {response_message}")
            return None

    def evaluate_sample_worker(self, args):
        """Judge one sample; return the validated reply tagged with ``file_id``, or None.

        ``args`` is ``(file_id, annotation, system_prompt, user_prompt)``. When
        ``user_prompt`` is a list, each prompt is sent separately and the replies are
        gathered as ``{"response": [...]}`` (the shape expected by
        ``post_validate_format_dynamic_object_number``); any failed call aborts the sample.
        """
        file, anno, system_prompt, user_prompt = args
        if isinstance(user_prompt, list):
            response = {"response": []}
            for prompt in user_prompt:
                single_response = self.call_and_parse_single_meaasge(file, system_prompt, prompt)
                if single_response is None:
                    return None
                response["response"].append(single_response)
        else:
            response = self.call_and_parse_single_meaasge(file, system_prompt, user_prompt)
            if response is None:
                return None

        try:
            self.post_validate_format_func(response, anno)
        except Exception as e:
            eval_logger.info(f"Format validation failed for {file}: {e}, anno: {anno}, response: {response}")
            return None

        response["file_id"] = file
        return response

    def _append_lines(self, lines):
        """Append already-serialized JSON lines to ``self.save_path``."""
        with open(self.save_path, "a") as f:
            f.writelines(lines)

    def evaluate_scores(self):
        """Judge every sample (with resume and retries) and return per-file scores.

        Replies are appended to ``save_path`` as they arrive so an interrupted run can
        resume. Each pass re-judges only samples without a saved reply. Retrying stops
        when none are missing, or, after the first pass, when at most
        ``max_allow_missing`` remain. Saved replies whose ``file_id`` is not in
        ``results`` are ignored.

        Returns:
            dict mapping file_id -> score (an int, or a list of ints for
            dynamic_object_number).
        """
        if self.auto_resume:
            saved_responses = self.load_saved_records()
            eval_logger.info(f"[{self.task}] Loaded {len(saved_responses)} records")
        else:
            saved_responses = []

        buffer = []
        buffer_size = 100
        try:
            for retry_count in range(self.max_retry_times + 1):
                saved_files = {r["file_id"] for r in saved_responses}
                remaining_results = [r for r in self.results if r["file_id"] not in saved_files]
                if not remaining_results:
                    break
                # Tolerate a few persistently failing samples, but only after they were
                # attempted at least once; otherwise tiny tasks would never be judged.
                if retry_count > 0 and len(remaining_results) <= self.max_allow_missing:
                    break
                if retry_count != 0:
                    print(f"\nRetrying {retry_count} times")

                process_args = []
                for res in remaining_results:
                    system_prompt, user_prompt = self.prompts.get_prompts_by_task(self.task, res["caption"], res["annotation"])
                    process_args.append((res["file_id"], res["annotation"], system_prompt, user_prompt))

                if self.num_process == 0:
                    for args in tqdm(process_args, desc=f"Evaluating {self.task}"):
                        response = self.evaluate_sample_worker(args)
                        if response is not None:
                            # Persist immediately so progress survives an interruption.
                            self._append_lines([json.dumps(response) + "\n"])
                            saved_responses.append(response)
                else:
                    with ThreadPoolExecutor(max_workers=self.num_process) as executor:
                        futures = [executor.submit(self.evaluate_sample_worker, arg) for arg in process_args]
                        # Results are consumed (and written) on this thread only.
                        for future in tqdm(as_completed(futures), total=len(futures), desc=f"Evaluating {self.task}"):
                            result = future.result()
                            if result is None:
                                continue
                            buffer.append(json.dumps(result) + "\n")
                            saved_responses.append(result)
                            if len(buffer) >= buffer_size:
                                self._append_lines(buffer)
                                buffer.clear()
                    if buffer:
                        self._append_lines(buffer)
                        buffer.clear()
        finally:
            # Flush whatever is still buffered if evaluation was interrupted.
            if buffer:
                self._append_lines(buffer)
                buffer.clear()

        score_dict = {}
        stale_count = 0
        for response in tqdm(saved_responses, desc=f"Calculating {self.task} scores"):
            file = response["file_id"]
            if file not in self.file2anno:
                # Saved by a run over different inference results; nothing to score against.
                stale_count += 1
                continue
            score_dict[file] = self.post_process_func(response, self.file2anno[file])
        if stale_count:
            eval_logger.warning(f"[{self.task}] Ignored {stale_count} saved records with unknown file_id")

        return score_dict

    def calculate_metric(self, score_dict):
        """Aggregate per-sample scores into percentage metrics.

        All scores (list-valued entries are flattened) form the total; non-zero scores
        are "hits" and scores equal to 1 are "correct".
        precision = correct / hit, recall = correct / total, hit_rate = hit / total,
        f1_score = harmonic mean of precision and recall. Each metric is 0 when its
        denominator is 0.
        """
        all_scores = []
        for scores in score_dict.values():
            if isinstance(scores, list):
                all_scores.extend(scores)
            else:
                all_scores.append(scores)
        all_scores = np.array(all_scores)
        sum_count = len(all_scores)
        hit_count = np.count_nonzero(all_scores != 0)
        correct_count = np.count_nonzero(all_scores == 1)
        precision = 0 if hit_count == 0 else 100 * correct_count / hit_count
        recall = 0 if sum_count == 0 else 100 * correct_count / sum_count
        hit_rate = 0 if sum_count == 0 else 100 * hit_count / sum_count
        f1_score = 0 if precision + recall == 0 else 2 * precision * recall / (precision + recall)
        eval_logger.info(f"[{self.task}] all: {sum_count}, hit: {hit_count}, correct: {correct_count}")
        return {"precision": precision, "recall": recall, "hit_rate": hit_rate, "f1_score": f1_score}


# Directly run this file to evaluate existing inference record
if __name__ == "__main__":
    results_dir = "logs/capability_results/llava_onevision_7b/inference"
    save_dir = "logs/capability_results/llava_onevision_7b/evaluation"
    os.makedirs(save_dir, exist_ok=True)

    tasks = ["object_category", "object_number", "object_color", "spatial_relation", "scene", "camera_angle", "OCR", "style", "character_identification", "dynamic_object_number", "action", "camera_movement", "event"]

    # Judge settings are the same for every task, so read them from the config once.
    metadata = config["metadata"]
    judge_settings = {
        "eval_model": metadata["eval_model_name"],
        "headers": headers,
        "num_process": metadata["eval_num_process"],
        "max_allow_missing": metadata["eval_max_allow_missing"],
        "max_retry_times": metadata["eval_max_retry_times"],
        "auto_resume": metadata["eval_auto_resume"],
        "strict_match": metadata["eval_strict_match"],
    }

    metrics = []
    for task in tasks:
        inference_path = os.path.join(results_dir, f"{task}.jsonl")
        if not os.path.exists(inference_path):
            # A missing task must not abort the evaluation of the remaining ones.
            eval_logger.warning(f"[{task}] Inference record not found: {inference_path}, skipping.")
            continue
        with open(inference_path, "r") as f:
            # Skip blank lines (e.g. a trailing newline), which json.loads rejects.
            result = [json.loads(line) for line in f if line.strip()]
        if not result:
            eval_logger.warning(f"[{task}] Inference record is empty: {inference_path}, skipping.")
            continue
        save_path = os.path.join(save_dir, f"{task}.jsonl")
        evaluator = Evaluator(task, result, save_path, **judge_settings)
        score_dict = evaluator.evaluate_scores()
        metric = evaluator.calculate_metric(score_dict)
        metrics.append(metric)
        eval_logger.info(f"[{task}] " + ", ".join([f"{k}: {v:.1f}" for k, v in metric.items()]))

    # summarize metrics
    if not metrics:
        eval_logger.warning("No task was evaluated; nothing to summarize.")
    else:
        eval_logger.info("Summarized Results:")
        avg_precision = np.mean([m["precision"] for m in metrics])
        avg_recall = np.mean([m["recall"] for m in metrics])
        avg_hit_rate = np.mean([m["hit_rate"] for m in metrics])
        avg_f1_score = np.mean([m["f1_score"] for m in metrics])
        eval_logger.info(f"Average precision: {avg_precision:.3f}, recall: {avg_recall:.3f}, f1_score: {avg_f1_score:.3f}, hit_rate: {avg_hit_rate:.3f}")
