import json
import os

import numpy as np

from lita.core.protos import Role
from lita.core.tools import Tool
from lita.utils import read_jsonl


def record_quantile(data: list, qt: list = None) -> dict[str, float]:
    if qt is None:
        qt = [0.25, 0.5, 0.75]

    m = min(data)
    M = max(data)

    data = np.array(data)
    quantiles = np.quantile(data, qt)
    ret = {f"quantile_{int(q * 100)}": float(value) for q, value in zip(qt, quantiles, strict=False)}

    avg = float(np.mean(data))
    ret["average"] = avg
    ret["min"] = m
    ret["max"] = M

    return ret


# TODO: may need to handle different tool sets for different agents
class Analytics:
    """Tool-usage, token and latency analytics for saved agent trajectories.

    A trajectory is a list of chat-message dicts read from a ``.jsonl`` file.
    :meth:`analyze_traj_for_one` fills ``self.data`` for one trajectory.
    :meth:`generate_analytics` does this for every
    ``<traj_dir>/<time_stamp>/<lang>/*.jsonl`` and saves the results under
    ``<lang>/.analytics/``. :meth:`generate_report` aggregates those saved
    files into ``<traj_dir>/<time_stamp>/report.json``.
    """

    class TextFileEditorMessage:
        """Substrings searched for in text-editor tool responses."""

        apply_diff_failed = "Apply diff failed"
        wrong_diff_format = "No valid diff blocks"
        mismatch_diff_content = "The string you want to replace"
        edit_success = "Edited the file."

    class TerminalBashMessage:
        """Substrings searched for in bash tool responses."""

        # Its presence in a response is treated as a successful command.
        terminal_success = "Command return code: 0"

    def __init__(self, traj_dir: str, time_stamp: str, tools: list[Tool], **kwargs):
        """Configure the analyzer.

        Args:
            traj_dir: Root directory of saved trajectories.
            time_stamp: Run sub-directory of ``traj_dir`` to analyze.
            tools: Tool classes available to the agent. The default of each
                one's ``name`` field decides which concrete tool names are
                matched (see :meth:`map_tool_names`).
            **kwargs: Copied verbatim into every saved result and the report
                (recommended: ``model``, ``temperature``, ``max_num_turns``).
                Also read here: ``unit_test_strings`` (substrings marking a
                bash call as a unit-test run), ``target_edit`` (default 6) and
                ``target_test`` (default 2).
        """
        # Common data merged into every per-case file and the report.
        self.commons = dict(kwargs)

        self.unit_test_strings = kwargs.get(
            "unit_test_strings",
            [
                "pytest",
                "python",
                "npm-test.sh",
                "node",
                "gradlew",
                "cpp-test.sh",
                "cargo test",
                "go test",
            ],
        )
        # NOTE(review): `model_fields` suggests pydantic model classes — confirm.
        self.tools = [tool.model_fields["name"].default for tool in tools]
        self.map_tool_names()

        self.target_edit = kwargs.get("target_edit", 6)
        self.target_test = kwargs.get("target_test", 2)

        # Results for the trajectory currently being analyzed.
        self.data = {}
        self.traj_dir = traj_dir
        self.time_stamp = time_stamp

    def map_tool_names(self):
        """Map canonical tool names to the names used by this agent.

        ``text_file_editor`` resolves to ``str_replace_editor`` and
        ``terminal_bash`` to ``bash`` when those tools are in ``self.tools``.
        Otherwise the defaults ``text_file_editor`` / ``execute_bash`` are used.
        The result is stored in ``self.tool_names``.
        """
        self.tool_names = {
            "search": "search",
            "finish": "finish",
            "think": "think",
            "plan": "planning",
            "text_file_editor": (
                "str_replace_editor" if "str_replace_editor" in self.tools else "text_file_editor"
            ),
            "terminal_bash": "bash" if "bash" in self.tools else "execute_bash",
        }

    def set_item(self, key: str, value: str):
        """Store ``value`` under ``key`` in the current results."""
        self.data[key] = value

    def get_items(self) -> dict[str, str]:
        """Return the current results dict (not a copy)."""
        return self.data

    def save_items(self, save_path: str):
        """Merge ``self.commons`` into the results and write them as JSON.

        Common values override per-case values that share a key.
        """
        self.data.update(self.commons)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(self.data, f, indent=4)

    # ------------------------------------------------------------------
    # Message helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _function(tool_call: dict) -> dict:
        """Return a tool call's ``function`` payload, or ``{}`` if absent."""
        return tool_call.get("function") or {}

    @staticmethod
    def _usage_total_tokens(message: dict) -> int:
        """Return ``message["usage"]["total_tokens"]``; missing or None counts as 0."""
        return (message.get("usage") or {}).get("total_tokens") or 0

    @staticmethod
    def _iter_tool_calls(messages: list[dict], end: int | None = None):
        """Yield ``(index, message, tool_call)`` for each tool call in ``messages[:end]``."""
        stop = len(messages) if end is None else min(end, len(messages))
        for index in range(stop):
            message = messages[index]
            for tool_call in message.get("tool_calls") or []:
                yield index, message, tool_call

    @staticmethod
    def _content_text(message: dict) -> str:
        """Return a message's ``content`` as text for substring checks.

        Missing or ``None`` content becomes ``""``. List content is flattened
        by joining the ``"text"`` of dict parts and the ``str()`` of other
        parts. This assumes OpenAI-style content parts — TODO confirm.
        """
        content = message.get("content")
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return "\n".join(
                str(part.get("text", "")) if isinstance(part, dict) else str(part) for part in content
            )
        return str(content)

    @classmethod
    def _find_tool_response(cls, messages: list[dict], start: int, tool_id) -> str | None:
        """Return the text of the first message at or after ``start`` answering ``tool_id``.

        Returns ``None`` when no such message exists or ``tool_id`` is ``None``.
        Without that guard, ``.get("tool_call_id") == None`` would match any
        message lacking the key, such as the next assistant turn.
        """
        if tool_id is None:
            return None
        for index in range(start, len(messages)):
            if messages[index].get("tool_call_id") == tool_id:
                return cls._content_text(messages[index])
        return None

    @staticmethod
    def _parse_arguments(raw) -> dict | None:
        """Decode a tool call's ``arguments``.

        Returns a dict, or ``None`` when ``raw`` cannot be decoded (invalid
        JSON, or a non-string such as ``None``). A dict is returned as-is. A
        decoded JSON value that is not an object is treated as having no keys.
        """
        if isinstance(raw, dict):
            return raw
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return None
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _edits_in_call(args: dict) -> int:
        """Count the edits represented by a text-editor call's decoded arguments.

        ``operation == "edit"`` and ``command == "str_replace"`` each count
        once. Other commands (e.g. ``insert``) are not counted.
        """
        return int(args.get("operation") == "edit") + int(args.get("command") == "str_replace")

    # ------------------------------------------------------------------
    # Per-trajectory analysis
    # ------------------------------------------------------------------

    def get_token_count(self, messages: list[dict], individual: bool = False):
        """Sum the ``usage`` token counts of assistant messages into ``self.data``.

        Writes ``self.data["token_count"]`` and ``self.data["actual_turns"]``.
        ``token_count`` holds the totals, the ``prompt_tokens`` of the last
        assistant message, and an empty ``tokens_for_each_tool`` map that the
        tool analyzers fill later. ``actual_turns`` is the number of assistant
        messages. With ``individual=True``, the per-turn ``input_tokens`` and
        ``output_tokens`` lists are stored too.
        """
        # NOTE: system prompt may need special processing
        input_tokens = []
        output_tokens = []
        total_tokens = 0
        for message in messages:
            if message.get("role") != Role.ASSISTANT:
                continue
            usage = message.get("usage") or {}
            output_tokens.append(usage.get("completion_tokens") or 0)
            input_tokens.append(usage.get("prompt_tokens") or 0)
            total_tokens += usage.get("total_tokens") or 0

        self.data["token_count"] = {
            "total_input_tokens": sum(input_tokens),
            "total_output_tokens": sum(output_tokens),
            "total_tokens": total_tokens,
            # To compare with the context length.
            "last_input_token": input_tokens[-1] if input_tokens else 0,
            "tokens_for_each_tool": {},
        }

        if individual:
            self.data["token_count"]["input_tokens"] = input_tokens
            self.data["token_count"]["output_tokens"] = output_tokens

        self.data["actual_turns"] = len(input_tokens)

    def analyze_tool_default(self, messages: list[dict], tool_name_canonical: str):
        """Count calls to one tool and the tokens of the messages issuing them.

        A message's ``total_tokens`` is added once per matching call it contains.
        """
        tool_name = self.tool_names.get(tool_name_canonical)
        total_count = 0
        token_count = 0
        for _, message, tool_call in self._iter_tool_calls(messages):
            if self._function(tool_call).get("name") == tool_name:
                total_count += 1
                token_count += self._usage_total_tokens(message)

        self.data["num_tool_calls"][tool_name_canonical] = total_count
        self.data["token_count"]["tokens_for_each_tool"][tool_name_canonical] = token_count

    def _text_editor_metrics_before_message(self, messages: list[dict], end: int) -> dict[str, int]:
        """Count text-editor outcomes for tool calls in ``messages[:end]``.

        ``end`` is exclusive. Each call's tool response is searched for anywhere
        after the calling message, so it may lie at or beyond ``end``.

        Returns ``successful_edits``, ``apply_diff_failed``,
        ``wrong_diff_format`` and ``mismatch_diff_content``: the number of
        responses containing the matching :class:`TextFileEditorMessage`
        substring. Also returns ``json_parsing_error`` (undecodable arguments),
        ``num_edits`` (see :meth:`_edits_in_call`), ``total_count`` (editor
        calls) and ``token_count`` (the calling messages' ``total_tokens``,
        once per call).
        """
        editor_name = self.tool_names["text_file_editor"]
        markers = self.TextFileEditorMessage
        total_count = token_count = 0
        apply_diff_failed = wrong_diff_format = mismatch_diff_content = 0
        json_parsing_error = successful_edits = num_edits = 0

        for index, message, tool_call in self._iter_tool_calls(messages, end):
            function = self._function(tool_call)
            if function.get("name") != editor_name:
                continue
            total_count += 1
            token_count += self._usage_total_tokens(message)

            response = self._find_tool_response(messages, index + 1, tool_call.get("id"))
            if response is not None:
                apply_diff_failed += int(markers.apply_diff_failed in response)
                wrong_diff_format += int(markers.wrong_diff_format in response)
                mismatch_diff_content += int(markers.mismatch_diff_content in response)
                successful_edits += int(markers.edit_success in response)

            args = self._parse_arguments(function.get("arguments", {}))
            if args is None:
                json_parsing_error += 1
            else:
                num_edits += self._edits_in_call(args)

        return {
            "successful_edits": successful_edits,
            "apply_diff_failed": apply_diff_failed,
            "wrong_diff_format": wrong_diff_format,
            "mismatch_diff_content": mismatch_diff_content,
            "json_parsing_error": json_parsing_error,
            "num_edits": num_edits,
            "total_count": total_count,
            "token_count": token_count,
        }

    def _find_target_edit_index(self, messages: list[dict], target_edit: int) -> int | None:
        """Return the index of the message whose editor call reaches ``target_edit`` edits.

        Returns ``None`` if the count is never reached.
        """
        editor_name = self.tool_names["text_file_editor"]
        edit_count = 0
        for index, _, tool_call in self._iter_tool_calls(messages):
            function = self._function(tool_call)
            if function.get("name") != editor_name:
                continue
            args = self._parse_arguments(function.get("arguments", {}))
            # Undecodable calls were already counted as parsing errors; they add no edits.
            if args is not None:
                edit_count += self._edits_in_call(args)
            # `>=` rather than `==`: one call may count as two edits and step over the target.
            if edit_count >= target_edit:
                return index
        return None

    def analyze_text_editor(self, messages: list[dict]):
        """Record text-editor metrics, plus metrics up to the ``target_edit``-th edit.

        The target edit exists to match another evaluation (e.g. Aider). If the
        trajectory has fewer edits, ``target_edit_index`` is ``"last-edit"`` and
        the whole-trajectory metrics are reused.
        """
        metrics = self._text_editor_metrics_before_message(messages, len(messages))
        total_count = metrics.pop("total_count")
        token_count = metrics.pop("token_count")

        self.data["text_file_editor"] = metrics
        self.data["num_tool_calls"]["text_file_editor"] = total_count
        self.data["token_count"]["tokens_for_each_tool"]["text_file_editor"] = token_count

        target_edit = self.target_edit
        edit_index = None
        if metrics["num_edits"] >= target_edit:
            edit_index = self._find_target_edit_index(messages, target_edit)

        if edit_index is None:
            # Edits within the limit: the "before target" window is the whole trajectory.
            edit_index = "last-edit"
            metrics_before_this = metrics.copy()
        else:
            metrics_before_this = self._text_editor_metrics_before_message(messages, edit_index + 1)

        self.data["text_file_editor"]["target_edit_index"] = edit_index
        self.data["text_file_editor"]["target_edit"] = target_edit
        self.data["text_file_editor"]["metrics_before_target_edit"] = metrics_before_this

    def analyze_terminal_bash(self, messages: list[dict]):
        """Record bash-call counts, errors and unit-test outcomes.

        A call fails when its tool response lacks
        ``TerminalBashMessage.terminal_success``. Calls with no response are
        not counted as terminal errors. A unit-test call with no response is
        counted as a failed test, since there is no evidence it passed. A call
        is a unit-test run when its arguments contain any of
        ``self.unit_test_strings``. The ``self.target_test``-th such run (the
        limit in non-agentic evaluation) gives ``target_test_index``.
        """
        bash_name = self.tool_names["terminal_bash"]
        success_marker = self.TerminalBashMessage.terminal_success
        target_unit_test = self.target_test
        total_count = token_count = terminal_error = 0
        unit_test = unit_test_failed = 0
        test_pass = []
        target_test_index = None

        for index, message, tool_call in self._iter_tool_calls(messages):
            function = self._function(tool_call)
            if function.get("name") != bash_name:
                continue
            total_count += 1
            token_count += self._usage_total_tokens(message)

            response = self._find_tool_response(messages, index + 1, tool_call.get("id"))
            succeeded = response is not None and success_marker in response
            if response is not None and not succeeded:
                terminal_error += 1

            arguments = function.get("arguments", "")
            if not isinstance(arguments, str):
                arguments = str(arguments)
            if any(marker in arguments for marker in self.unit_test_strings):
                unit_test += 1
                if unit_test == target_unit_test:
                    target_test_index = index
                if not succeeded:
                    unit_test_failed += 1
                test_pass.append(succeeded)

        self.data["terminal_bash"] = {
            "terminal_error": terminal_error,
            "num_unit_tests": unit_test,
            "unit_test_failed": unit_test_failed,
            "unit_test_result": {
                "test_pass": test_pass,
                "last_test_pass": test_pass[-1] if test_pass else False,
            },
            # If the number of tests doesn't reach the bound. Compare with None,
            # since index 0 is a valid position.
            "target_test_index": target_test_index if target_test_index is not None else "last-test",
            "target_unit_test": target_unit_test,
        }
        self.data["num_tool_calls"]["terminal_bash"] = total_count
        self.data["token_count"]["tokens_for_each_tool"]["terminal_bash"] = token_count

    def analyze_search(self, messages: list[dict]):
        """Count ``search`` tool calls."""
        self.analyze_tool_default(messages, "search")

    def analyze_finish(self, messages: list[dict]):
        """Count ``finish`` tool calls."""
        self.analyze_tool_default(messages, "finish")

    def analyze_think(self, messages: list[dict]):
        """Count ``think`` tool calls."""
        self.analyze_tool_default(messages, "think")

    def analyze_plan(self, messages: list[dict]):
        """Count ``plan`` tool calls (mapped to the ``planning`` tool)."""
        self.analyze_tool_default(messages, "plan")

    def coupled_tool_metrics(self, messages: list[dict]):
        """Record text-editor metrics up to the target unit-test run.

        Requires :meth:`analyze_text_editor` and :meth:`analyze_terminal_bash`
        to have run first. If the target test was reached, editor metrics are
        recounted over the messages up to and including that test's message.
        Otherwise the whole-trajectory editor metrics are reused, minus the
        target-edit fields.
        """
        target_test_index = self.data["terminal_bash"]["target_test_index"]
        if isinstance(target_test_index, int):
            edits = self._text_editor_metrics_before_message(messages, target_test_index + 1)
        else:
            dropped = {"target_edit_index", "target_edit", "metrics_before_target_edit"}
            edits = {key: value for key, value in self.data["text_file_editor"].items() if key not in dropped}
        self.data["editor_x_terminal"] = {"edit_before_target_test": edits}

    def total_tool_calls(self, messages: list[dict]):
        """Count all tool calls, including failed ones and unknown tools."""
        self.data["num_tool_calls"]["total"] = sum(1 for _ in self._iter_tool_calls(messages))

    def analyze_tool_calls(self, messages: list[dict]):
        """Run every tool analyzer. Token counts must already be in ``self.data``."""
        self.data["num_tool_calls"] = {}
        self.analyze_text_editor(messages)
        self.analyze_terminal_bash(messages)
        self.analyze_search(messages)
        self.analyze_finish(messages)
        self.analyze_think(messages)
        self.analyze_plan(messages)

        self.coupled_tool_metrics(messages)
        self.total_tool_calls(messages)

    def get_latency(self, messages: list[dict]):
        """Sum ``response_ms`` over messages that carry it into ``self.data["time_cost"]``."""
        self.data["time_cost"] = sum(
            (message["response_ms"] for message in messages if "response_ms" in message), 0.0
        )

    def analyze_traj_for_one(self, messages: list[dict]):
        """Reset ``self.data`` and compute all metrics for one trajectory."""
        self.data = {}
        self.get_token_count(messages)
        self.get_latency(messages)
        self.analyze_tool_calls(messages)

    # ------------------------------------------------------------------
    # Batch processing and reporting
    # ------------------------------------------------------------------

    @staticmethod
    def _language_dirs(run_dir: str) -> list[str]:
        """Return the sub-directories of ``run_dir`` (one per language), sorted by name."""
        candidates = (os.path.join(run_dir, name) for name in sorted(os.listdir(run_dir)))
        return [path for path in candidates if os.path.isdir(path)]

    def generate_analytics(self):
        """Analyze every trajectory in the run and save per-case results.

        Each ``<traj_dir>/<time_stamp>/<lang>/<name>.jsonl`` is written to
        ``<lang>/.analytics/<name>.json``. The report is not generated here;
        call :meth:`generate_report` afterwards.
        """
        run_dir = os.path.join(self.traj_dir, self.time_stamp)
        for lang_dir in self._language_dirs(run_dir):
            for test_case in sorted(os.listdir(lang_dir)):
                if not test_case.endswith(".jsonl"):
                    continue
                messages = read_jsonl(os.path.join(lang_dir, test_case))
                self.analyze_traj_for_one(messages)

                analytics_path = os.path.join(lang_dir, ".analytics")
                os.makedirs(analytics_path, exist_ok=True)
                # Swap only the trailing extension.
                out_name = test_case[: -len(".jsonl")] + ".json"
                self.save_items(os.path.join(analytics_path, out_name))

    @classmethod
    def _iter_saved_analytics(cls, run_dir: str):
        """Yield ``(file_name, data)`` for each ``<lang>/.analytics/*.json`` under ``run_dir``."""
        for lang_dir in cls._language_dirs(run_dir):
            analytics_path = os.path.join(lang_dir, ".analytics")
            if not os.path.isdir(analytics_path):
                continue
            for name in sorted(os.listdir(analytics_path)):
                if not name.endswith(".json"):
                    continue
                with open(os.path.join(analytics_path, name), encoding="utf-8") as f:
                    data = json.load(f)
                yield name, data

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0

    @staticmethod
    def _total_and_quantiles(values: list) -> dict:
        """Return the sum of ``values`` together with its quantile summary."""
        return {"total": sum(values), **record_quantile(values)}

    def generate_report(self):
        """Aggregate the saved per-case analytics into ``report.json``.

        Reads every ``<lang>/.analytics/*.json`` in the run directory, merges
        the aggregate statistics into ``self.commons``, and writes
        ``self.commons`` to ``<traj_dir>/<time_stamp>/report.json``. Ratios
        with a zero denominator are reported as 0. Quantile summaries of an
        empty run are ``None``.
        """
        total_edits = successful_edits = 0
        total_edits_before_target_test = successful_edits_before_target_test = 0
        total_edits_before_target_edit = successful_edits_before_target_edit = 0

        # Per-case call counts, keyed by canonical tool name (report order).
        tool_calls = {
            key: [] for key in ("text_file_editor", "terminal_bash", "think", "search", "plan", "finish")
        }
        total_tool_calls = 0

        total_pass = total_cases = pass_at_2 = 0

        elapsed = 0
        total_tokens = []
        total_input_tokens = []
        total_output_tokens = []
        last_input_tokens = []
        special_cases = []
        actual_turns = []

        run_dir = os.path.join(self.traj_dir, self.time_stamp)
        for test_case, data in self._iter_saved_analytics(run_dir):
            # Text file editor stats; missing fields contribute nothing.
            editor = data.get("text_file_editor", {})
            total_edits += editor.get("num_edits", 0)
            successful_edits += editor.get("successful_edits", 0)

            before_test = data.get("editor_x_terminal", {}).get("edit_before_target_test", {})
            total_edits_before_target_test += before_test.get("num_edits", 0)
            successful_edits_before_target_test += before_test.get("successful_edits", 0)

            before_edit = editor.get("metrics_before_target_edit", {})
            total_edits_before_target_edit += before_edit.get("num_edits", 0)
            successful_edits_before_target_edit += before_edit.get("successful_edits", 0)

            # Terminal bash stats.
            unit_test_result = data.get("terminal_bash", {}).get("unit_test_result", {})
            last_test_pass = unit_test_result.get("last_test_pass", "")
            test_pass = unit_test_result.get("test_pass", [])
            total_cases += 1
            if last_test_pass:
                total_pass += 1
                # Last test passed, and at most two test runs were made.
                if len(test_pass) <= 2:
                    pass_at_2 += 1
            elif len(test_pass) <= 1:
                # Last test failed and there were not enough test runs.
                special_cases.append(test_case)

            # Number of tool calls.
            num_tool_calls = data.get("num_tool_calls", {})
            for key, calls in tool_calls.items():
                calls.append(num_tool_calls.get(key, 0))
            total_tool_calls += num_tool_calls.get("total", 0)

            # General stats: time spent calling the LLM, and token usage.
            elapsed += data.get("time_cost", 0.0)
            token_count = data.get("token_count", {})
            total_tokens.append(token_count.get("total_tokens", 0))
            total_input_tokens.append(token_count.get("total_input_tokens", 0))
            total_output_tokens.append(token_count.get("total_output_tokens", 0))
            actual_turns.append(data.get("actual_turns", 0))
            last_input_tokens.append(token_count.get("last_input_token", 0))

        report = {
            "correct_format": self._ratio(successful_edits, total_edits),
            "pass_at_2": self._ratio(pass_at_2, total_cases),
            "pass_at_max_turns": self._ratio(total_pass, total_cases),
            "correct_format_before_target_test": self._ratio(
                successful_edits_before_target_test, total_edits_before_target_test
            ),
            "correct_format_before_target_edit": self._ratio(
                successful_edits_before_target_edit, total_edits_before_target_edit
            ),
            "special_cases": special_cases,
            "time_cost_ms": elapsed,
            "total_tokens": self._total_and_quantiles(total_tokens),
            "total_input_tokens": self._total_and_quantiles(total_input_tokens),
            "total_output_tokens": self._total_and_quantiles(total_output_tokens),
            "last_input_tokens": record_quantile(last_input_tokens),
            "actual_turns": record_quantile(actual_turns),
            "num_tool_calls": {
                "total": total_tool_calls,
                **{key: self._total_and_quantiles(calls) for key, calls in tool_calls.items()},
            },
        }
        self.commons.update(report)

        with open(os.path.join(run_dir, "report.json"), "w", encoding="utf-8") as f:
            json.dump(self.commons, f, indent=4)
