# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import json
import logging
import random
from pathlib import Path
from typing import Any, Dict, Literal, Optional

import tree_sitter_python as tspython
from tqdm import tqdm
from tree_sitter import Language, Parser

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.utils import download_github_subdirectory

logger = logging.getLogger(__name__)


# File names of each split, keyed by dataset name.
# The question files are the 'oracle' retriever variants, i.e. the full
# reference API documentation is included in each prompt.
# ('train' is listed for completeness; APIBenchBenchmark.load does not read
# it.)
dataset_mapping = {
    "huggingface": dict(
        api="huggingface_api.jsonl",
        eval="huggingface_eval.json",
        train="huggingface_train.json",
        questions="questions_huggingface_oracle.jsonl",
    ),
    "tensorflowhub": dict(
        api="tensorflowhub_api.jsonl",
        eval="tensorflow_eval.json",
        train="tensorflow_train.json",
        questions="questions_tensorflowhub_oracle.jsonl",
    ),
    "torchhub": dict(
        api="torchhub_api.jsonl",
        eval="torchhub_eval.json",
        train="torchhub_train.json",
        questions="questions_torchhub_oracle.jsonl",
    ),
}


# This function is migrated from the original repo:
# https://github.com/ShishirPatil/gorilla
def encode_question(question: str, dataset_name: str) -> str:
    r"""Encode multiple prompt instructions into a single string.

    Builds the Gorilla-style prompt: the question, an instruction to call an
    API from ``dataset_name``, the required answer format, and the list of
    allowed ``$DOMAIN`` values for that dataset.

    Args:
        question (str): The natural-language task description.
        dataset_name (str): One of ``"huggingface"``, ``"tensorflowhub"``
            or ``"torchhub"``.

    Returns:
        str: The full prompt text.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """

    # NOTE: each backslash continuation below keeps the next line's leading
    # whitespace inside the string literal; re-indenting these lines changes
    # the prompt text.
    if dataset_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and \
        should include one of {Classification, Semantic Segmentation, \
        Object Detection, Audio Separation, Video Classification, \
        Text-to-Speech}."
    elif dataset_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature \
            Extraction, Multimodal Text-to-Image, Multimodal \
            Image-to-Text, Multimodal Text-to-Video, \
            Multimodal Visual Question Answering, Multimodal Document \
            Question Answer, Multimodal Graph Machine Learning, \
            Computer Vision Depth Estimation, Computer Vision Image \
            Classification, Computer Vision Object Detection, \
            Computer Vision Image Segmentation, Computer Vision \
            Image-to-Image, Computer Vision Unconditional \
            Image Generation, Computer Vision Video Classification, \
            Computer Vision Zero-Shor Image Classification, \
            Natural Language Processing Text Classification, \
            Natural Language Processing Token Classification, \
            Natural Language Processing Table Question Answering, \
            Natural Language Processing Question Answering, \
            Natural Language Processing, Zero-Shot Classification \
            Natural Language Processing Translation, Natural Language \
            Processing Summarization, Natural Language Processing \
            Conversational, Natural Language Processing Text \
            Generation, Natural Language Processing Fill-Mask, \
            Natural Language Processing Text2Text Generation, \
            Natural Language Processing Sentence Similarity, \
            Audio Text-to-Speech, Audio Automatic Speech Recognition, \
            Audio Audio-to-Audio, Audio Audio Classification, \
            Audio Voice Activity Detection, Tabular Tabular \
            Classification, Tabular Tabular Regression, \
            Reinforcement Learning Reinforcement Learning, \
            Reinforcement Learning Robotics }"
    elif dataset_name == "tensorflowhub":
        domains = "1. $DOMAIN is inferred from the task description \
        and should include one of {text-sequence-alignment, \
        text-embedding, text-language-model, text-preprocessing, \
        text-classification, text-generation, text-question-answering, \
        text-retrieval-question-answering, text-segmentation, \
        text-to-mel, image-classification, image-feature-vector, \
        image-object-detection, image-segmentation, \
        image-generator, image-pose-detection, image-rnn-agent, \
        image-augmentation, image-classifier, image-style-transfer, \
        image-aesthetic-quality, image-depth-estimation, \
        image-super-resolution, image-deblurring, image-extrapolation, \
        image-text-recognition, image-dehazing, image-deraining, \
        image-enhancemenmt, image-classification-logits, \
        image-frame-interpolation, image-text-detection, image-denoising, \
        image-others, video-classification, video-feature-extraction, \
        video-generation, video-audio-text, video-text, \
        audio-embedding, audio-event-classification, audio-command-detection, \
        audio-paralinguists-classification, audio-speech-to-text, \
        audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        # Without this, `domains` would be unbound below and the caller
        # would get an UnboundLocalError instead of a clear message.
        raise ValueError(f"API name is not supported: {dataset_name}")

    prompt = (
        question
        + "\nWrite a python program in 1 to 2 lines to call API in "
        + dataset_name
        + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
        <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
        <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
        Here are the requirements:\n"
        + domains
        + "\n2. The $API_CALL should have only 1 line of code \
        that calls api.\n 3. The $API_PROVIDER should be the \
        programming framework used.\n4. $EXPLANATION should be \
        a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
        Do not repeat the format in your answer."
    )
    return prompt


class APIBenchBenchmark(BaseBenchmark):
    r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
    Connected with Massive APIs`
    <https://huggingface.co/datasets/gorilla-llm/APIBench>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)

    Note:
        :meth:`run` evaluates questions sequentially; ``processes`` is only
        forwarded to the base class.
    """

    # TODO: Integrate retriever (pending)

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBench benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("apibench", data_dir, save_to, processes)

    def download(self):
        r"""Download the APIBench dataset.

        Fetches the ``gorilla-llm/APIBench`` Hugging Face dataset into
        ``self.data_dir``, then fetches the Gorilla repository's
        ``eval-data/questions`` subdirectory with
        ``download_github_subdirectory`` (target: ``self.data_dir``).
        """
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gorilla-llm/APIBench",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )

        repo = "ShishirPatil/gorilla"
        subdir = "/gorilla/eval/eval-data/questions"
        download_github_subdirectory(repo, subdir, self.data_dir)

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBench Benchmark dataset.

        Populates ``self._data`` with:

        - ``'api'``: records from the dataset's API file,
        - ``'eval'``: the ``'api_data'`` entry of each eval record,
        - ``'questions'``: records from the oracle question file,
        - ``'ast'``: tree-sitter parse trees of each API's ``'api_call'``,
          in the same order as ``'api'``.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool, optional): Whether to force
                download the data. (default: :obj:`False`)

        Raises:
            FileNotFoundError: If the dataset directory or a data file is
                missing.
            ValueError: If ``dataset_name`` is not a known dataset or a file
                contains invalid JSON.
        """

        if force_download:
            logger.info("Force downloading data.")
            self.download()

        def load_json_lines(file_path: Path):
            r"""Load one JSON object per non-blank line of ``file_path``."""
            try:
                # Explicit UTF-8 so decoding does not depend on the locale.
                with open(file_path, "r", encoding="utf-8") as f:
                    return [json.loads(line) for line in f if line.strip()]
            except FileNotFoundError as e:
                raise FileNotFoundError(f"File not found: {file_path}") from e
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON in file {file_path}: {e}"
                ) from e

        dataset_path = self.data_dir / dataset_name
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset directory does not exist: {dataset_path}"
            )
        if dataset_name not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {dataset_name}.")

        file_names = dataset_mapping[dataset_name]
        # API and eval files sit directly in data_dir; question files live
        # in the per-dataset subdirectory.
        self._data['api'] = load_json_lines(self.data_dir / file_names['api'])
        eval_records = load_json_lines(self.data_dir / file_names['eval'])
        self._data['eval'] = [item['api_data'] for item in eval_records]
        self._data['questions'] = load_json_lines(
            dataset_path / file_names['questions']
        )

        self._data['ast'] = [
            ast_parse(api['api_call']) for api in self._data['api']
        ]

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Each question is sent to ``agent``, graded with
        ``evaluate_response``, and the per-question result is appended to
        ``self._results`` and written (as indented JSON) to ``save_to``.

        Args:
            agent (ChatAgent): The agent to run the
                benchmark.
            dataset_name (Literal["huggingface",
                "tensorflowhub", "torchhub"]):
                The dataset to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: ``total``, ``correct``, ``hallucination``,
            ``accuracy`` and ``hallucination rate``; the two rates are
            ``"N/A"`` when no task was run.

        Raises:
            ValueError: If ``dataset_name`` is not a known dataset.
        """

        if dataset_name not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {dataset_name}.")

        logger.info(f"Running APIBench benchmark on {dataset_name}.")
        self.load(dataset_name)
        # Copy so shuffling does not reorder the loaded questions in place.
        datas = list(self._data['questions'])

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Reference data is the same for every question; look it up once.
        api_database = self._data['api']
        qa_pairs = self._data['eval']
        ast_database = self._data['ast']

        # Initialize results storage
        self._results = []

        # UTF-8 explicitly: records are dumped with ensure_ascii=False, so
        # non-ASCII agent output must not depend on the platform locale.
        with open(self.save_to, "w", encoding="utf-8") as f:
            for question in tqdm(datas, desc="Running"):
                prompt = encode_question(question["text"], dataset_name)
                try:
                    responses = agent.step(prompt)
                    response = responses.msgs[0].content
                    error, is_correct, is_hallucination = evaluate_response(
                        response,
                        question['question_id'],
                        dataset_name,
                        api_database,
                        qa_pairs,
                        ast_database,
                    )
                    result = {
                        "question": question,
                        "agent_response": response,
                        "correct": is_correct,
                        "hallucination": is_hallucination,
                        "error": str(error) if error else None,
                    }
                except Exception as e:
                    logger.warning(
                        f"Error in processing task: {question}: {e}"
                    )
                    result = {
                        "question": question,
                        "agent_response": None,
                        "correct": False,
                        "hallucination": False,
                        "error": str(e),
                    }

                self._results.append(result)
                agent.reset()

                f.write(json.dumps(result, indent=2, ensure_ascii=False) + "\n")
                f.flush()

        total = len(self._results)
        correct = sum(r["correct"] for r in self._results)
        hallucination = sum(r["hallucination"] for r in self._results)

        return {
            "total": total,
            "correct": correct,
            "hallucination": hallucination,
            "accuracy": correct / total if total else "N/A",
            "hallucination rate": hallucination / total if total else "N/A",
        }


# This code is modified from the evaluators in the original repo:
# https://github.com/ShishirPatil/gorilla
def get_all_sub_trees(root_node):
    r"""Collect the subtrees of a tree-sitter syntax tree.

    Walks the tree depth-first with an explicit stack. The root is always
    recorded; below it, only nodes that themselves have children are
    visited, so leaf nodes never get their own entry.

    Args:
        root_node: Root node of the tree (e.g. the result of
            :func:`ast_parse`).

    Returns:
        list: One ``[sexp, depth, node, first_child_text]`` list per visited
        node. ``sexp`` is ``str(node)``; ``depth`` is 1 for the root and
        increases by one per level; ``first_child_text`` is the ``text`` of
        the node's first child, or ``None`` if the node has no children.
    """
    collected = []
    pending = [(root_node, 1)]
    while pending:
        node, level = pending.pop()
        first_child_text = (
            node.children[0].text if node.child_count > 0 else None
        )
        collected.append([str(node), level, node, first_child_text])
        # Only children that have children of their own are explored.
        pending.extend(
            (child, level + 1)
            for child in node.children
            if len(child.children) != 0
        )
    return collected


def ast_parse(candidate):
    r"""Parse Python source text with tree-sitter.

    A fresh Python ``Language``/``Parser`` pair is created on every call.

    Args:
        candidate (str): Python source code, e.g. a single API call.

    Returns:
        The ``root_node`` of the resulting tree-sitter tree.
    """
    parser = Parser(Language(tspython.language()))
    source_bytes = bytes(candidate, "utf8")
    return parser.parse(source_bytes).root_node


def get_args(node, dataset_name):
    r"""Extract the reference arguments from a parsed API call.

    The arguments are read from
    ``node.children[0].children[0].children[1].children``. Under the
    tree-sitter Python grammar this is presumably module -> expression
    statement -> call -> argument list (TODO confirm for all database
    entries).

    Per dataset:

    - ``"huggingface"``: for a child whose text contains ``"="``, its third
      child's text (the keyword value); otherwise the child's own text,
      skipping ``(``, ``)`` and ``,``.
    - ``"tensorflowhub"``: the same, but only children containing
      ``model=`` / ``model =`` are treated as keyword arguments.
    - ``"torchhub"``: only children whose text contains ``repo_or_dir`` or
      ``model``, taking their third child's text.
    - any other name: no arguments.

    Args:
        node: Root node of the parsed reference API call.
        dataset_name (str): The dataset the call belongs to.

    Returns:
        list: ``bytes`` texts of the extracted arguments; ``[]`` if ``node``
        has no children or the dataset is unknown.
    """
    if node.child_count == 0:
        return []
    if dataset_name not in ("huggingface", "tensorflowhub", "torchhub"):
        return []

    extracted = []
    for arg in node.children[0].children[0].children[1].children:
        text = arg.text.decode()
        if dataset_name == "torchhub":
            if "repo_or_dir" in text or "model" in text:
                extracted.append(arg.children[2].text)
            continue

        if dataset_name == "huggingface":
            is_keyword = "=" in text
        else:
            is_keyword = "model=" in text or "model =" in text

        if is_keyword:
            extracted.append(arg.children[2].text)
        elif text not in ("(", ")", ","):
            extracted.append(arg.text)
    return extracted


def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
    r"""Return the index of the first reference API the candidate matches.

    For each reference tree, the API name is the text of the first child of
    ``base_tree.children[0].children[0]`` (for a call expression, the
    callee). The candidate matches that reference when one of its subtrees
    has the API name as its first-child text, and every argument returned by
    :func:`get_args` (with surrounding single quotes stripped) occurs in that
    subtree's source text.

    Args:
        candidate_subtree_list (list): Output of
            :func:`get_all_sub_trees` for the candidate program, i.e.
            ``[sexp, depth, node, first_child_text]`` entries.
        base_tree_list (list): Root nodes of the parsed reference API calls.
        dataset_name (str): Dataset name, forwarded to :func:`get_args`.

    Returns:
        int: Index into ``base_tree_list`` of the first match, or ``-1``
        when nothing matches.
    """
    for idx, base_tree in enumerate(base_tree_list):
        # Skip references without a statement (e.g. an empty api_call),
        # which would otherwise raise IndexError and abort evaluation.
        if base_tree.child_count == 0 or base_tree.children[0].child_count == 0:
            continue
        statement = base_tree.children[0].children[0]
        if statement.child_count == 0:
            continue
        api_name = statement.children[0].text

        for subtree in candidate_subtree_list:
            if subtree[3] == api_name:
                matched_node = subtree[2]
                break
        else:
            # The candidate never uses this API. Without this skip the loop
            # variable would still hold the last, unrelated subtree and the
            # arguments would be checked against it.
            continue

        args_list = get_args(base_tree, dataset_name)
        if not args_list:
            continue
        candidate_text = matched_node.text.decode()
        if all(
            arg.decode().strip("'") in candidate_text for arg in args_list
        ):
            return idx
    return -1


def evaluate_response(
    response, question_id, dataset_name, api_database, qa_pairs, ast_database
):
    r"""Grade one agent response with Gorilla's AST sub-tree matching.

    The API call is taken from the text between ``api_call`` and
    ``api_provider`` in ``response`` (or the whole response if ``api_call``
    does not appear). It is then parsed with :func:`ast_parse` and matched
    against ``ast_database`` with :func:`ast_check`.

    Args:
        response (str): The agent's raw answer.
        question_id (int): Question id, treated as 1-based:
            ``qa_pairs[question_id - 1]`` is its reference entry.
        dataset_name (str): Dataset name, forwarded to :func:`ast_check`.
        api_database (list): API records with a ``'domain'`` key, in the
            same order as ``ast_database``.
        qa_pairs (list): Reference entries with a ``'domain'`` key.
        ast_database (list): Parsed reference API calls.

    Returns:
        tuple: ``(error, correct, hallucination)``:

        - no reference API matched: ``(None, False, True)``;
        - matched an API whose domain equals the question's:
          ``(None, True, False)``;
        - matched an API with a different domain: ``(None, False, False)``;
        - parsing or evaluation failed: ``(exception, False, False)``.
    """
    try:
        segments = response.split("api_call")
        if len(segments) == 1:
            api_call = segments[0]
        else:
            # Keep the part before "api_provider", skip the ": " separator,
            # and cut after the last ")" (or drop the final character when
            # there is no ")").
            output = segments[1].split("api_provider")[0]
            start = output.index(":") if ":" in output else 0
            end = output.rindex(")") if ")" in output else -2
            api_call = output[start + 2 : end + 1]

        try:
            ast_tree = ast_parse(api_call)
        except Exception as parse_error:
            logger.warning(
                f"Error parsing api_call: {api_call}, error: {parse_error}"
            )
            return parse_error, False, False

        ast_subtree_list = get_all_sub_trees(ast_tree)
        database_index = ast_check(
            ast_subtree_list, ast_database, dataset_name
        )
        if database_index == -1:
            # The call matches no known API, so it counts as a hallucination.
            # Falling through would index api_database[-1] and grade the
            # response against the last API in the database.
            return None, False, True

        ref_api_call = api_database[database_index]
        # Functional check: the matched API must serve the question's domain.
        if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
            return None, True, False
        return None, False, False
    except Exception as e:
        logger.warning(f'Error parsing response: {response}, error: {e}')
        return e, False, False
