import re
import json
import copy
from typing import List

from vita.config import DEFAULT_LLM_EVALUATOR, DEFAULT_LLM_EVALUATOR_ARGS
from vita.data_model.message import UserMessage, SystemMessage, Message
from vita.data_model.simulation import NLAssertionCheck, RewardInfo
from vita.data_model.tasks import RewardType, Task, EvaluationCriteria, StoreBaseModel, ProductBaseModel
from vita.utils.llm_utils import generate
from vita.utils import str_to_datetime, evaluator_extracter, get_weekday
from vita.prompts import get_prompts


class RubricEvaluator:
    """
    Judge that evaluates whether a trajectory adheres to all the natural-language rubrics.

    The trajectory's final state and final message are sent to an LLM judge
    (via ``generate``) together with the task's evaluation criteria. The judge's
    verdicts become ``NLAssertionCheck`` objects, which are aggregated into a
    ``RewardInfo``.
    """

    # Evaluation modes accepted by ``evaluate_nl_rubrics``.
    _RUBRIC_TYPES = ("separate", "combined")

    @classmethod
    def calculate_reward(
        cls,
        rubric_type: str,
        task: Task,
        final_state: dict,
        final_messages: list[Message],
        llm_evaluator: str = None,
        llm_args_evaluator: dict = None,
        language: str = None,
    ) -> RewardInfo:
        """
        Calculate the reward for the simulation by using an LLM to evaluate whether the trajectory adheres to all the natural-language rubrics.

        Parameters:
            rubric_type (str): ``"separate"`` or ``"combined"``; see ``evaluate_nl_rubrics``.
            task (Task): Supplies ``evaluation_criteria`` and ``environment["time"]``.
            final_state (dict): Simulation end state. Its ``"new_states"`` entry
                (default ``[]``) is what gets evaluated.
            final_messages (list[Message]): The trajectory. Only the last message's
                content is used; an empty list or ``None`` content is treated as "".
            llm_evaluator (str): Judge model; ``DEFAULT_LLM_EVALUATOR`` when ``None``.
            llm_args_evaluator (dict): Extra ``generate`` kwargs;
                ``DEFAULT_LLM_EVALUATOR_ARGS`` when ``None``.
            language (str): Passed to ``get_weekday`` and ``get_prompts``.

        Returns:
            RewardInfo: ``reward`` is 1.0 only if there is at least one check and
            every check is met, otherwise 0.0. ``reward_breakdown[RewardType.NL_ASSERTION]``
            is the fraction of checks met. Without evaluation criteria the reward
            is 1.0 and no LLM call is made.

        Raises:
            ValueError: If criteria exist and ``rubric_type`` is unsupported.
        """
        if task.evaluation_criteria is None:
            return RewardInfo(
                reward=1.0,
                nl_rubrics=[],
                info={"note": "No evaluation criteria"},
                reward_breakdown={RewardType.NL_ASSERTION: 1.0},
            )
        nl_rubrics = task.evaluation_criteria
        if not nl_rubrics:
            return RewardInfo(
                reward=1.0,
                nl_rubrics=[],
                info={"note": "No nl_rubrics to evaluate"},
                reward_breakdown={RewardType.NL_ASSERTION: 1.0},
            )

        store_dict = StoreBaseModel.get_all_stores()
        product_dict = ProductBaseModel.get_all_products()

        # Key order matters: the dict's repr is interpolated into the judge prompt.
        time_str = task.environment.get("time", "")
        env_info = {
            "system_time": time_str + " " + get_weekday(time_str, language),
            "database": [],
        }

        # Attach the store record (and product records) referenced by each new state
        # so the judge can check the state against catalogue data.
        final_state_updated = final_state.get("new_states", [])
        for state in final_state_updated:
            # Some state types identify their store as ``store_id``, others as ``shop_id``.
            store_id = state.store_id if hasattr(state, "store_id") else state.shop_id
            env_info["database"].append(store_dict.get(store_id, {}))
            for product in getattr(state, "products", None) or []:
                # Products may be model objects or plain dicts.
                if hasattr(product, "product_id"):
                    product_id = product.product_id
                else:
                    product_id = product.get("product_id")
                env_info["database"].append(product_dict.get(product_id, {}))

        final_message = (final_messages[-1].content or "") if final_messages else ""

        nl_rubric_checks = cls.evaluate_nl_rubrics(
            env_info, final_state_updated, task.evaluation_criteria, final_message, rubric_type,
            llm_evaluator, llm_args_evaluator, language
        )

        met_count = sum(1 for result in nl_rubric_checks if result.met)
        all_expectations_met = bool(nl_rubric_checks) and met_count == len(nl_rubric_checks)
        # Guard the division: an empty check list scores 0 rather than crashing.
        rubric_score = met_count / len(nl_rubric_checks) if nl_rubric_checks else 0.0
        reward = 1.0 if all_expectations_met else 0.0

        return RewardInfo(
            reward=reward,
            nl_rubrics=nl_rubric_checks,
            reward_breakdown={RewardType.NL_ASSERTION: rubric_score},
        )

    @staticmethod
    def _fallback_result(rubric: str, content: str) -> dict:
        """
        Build a single verdict when ``evaluator_extracter`` finds nothing in the reply.

        Heuristic: the verdict is "not met" iff the raw reply contains the
        lowercase substring ``"false"`` (case-sensitive, as originally implemented).
        """
        return {
            "rubrics": rubric,
            "reasoning": content,
            "meetExpectation": "false" not in content,
        }

    @classmethod
    def _run_judge(
        cls,
        label: str,
        system_prompt: str,
        user_prompt: str,
        fallback_rubric: str,
        llm_evaluator: str,
        llm_args_evaluator: dict,
    ) -> list:
        """
        Send one system/user prompt pair to the judge LLM and parse its verdicts.

        ``label`` prefixes the debug prints (e.g. ``state`` -> ``state_system_prompt = ...``).
        Returns the dicts produced by ``evaluator_extracter``. If none are produced,
        returns a single fallback verdict for ``fallback_rubric``.
        """
        messages = [
            SystemMessage(role="system", content=system_prompt),
            UserMessage(role="user", content=user_prompt),
        ]
        print(f"{label}_system_prompt = {system_prompt}\n")
        print(f"{label}_user_prompt = {user_prompt}\n")

        assistant_message = generate(
            model=llm_evaluator,
            messages=messages,
            **llm_args_evaluator,
        )
        print(f"{label}_assistant_message.content = {assistant_message.content}")

        results = list(evaluator_extracter(assistant_message.content))
        if not results:
            results = [cls._fallback_result(fallback_rubric, assistant_message.content)]
        return results

    @classmethod
    def evaluate_nl_rubrics(
        cls,
        env_info: dict,
        final_state: list,
        rubrics: EvaluationCriteria,
        final_message: str,
        rubric_type: str,
        llm_evaluator: str = None,
        llm_args_evaluator: dict = None,
        language: str = None,
    ) -> list[NLAssertionCheck]:
        """
        Evaluate whether the trajectory meets each expected outcome.

        Parameters:
            env_info (dict): Information about the environment, such as time and
                database. It is interpolated into the judge's system prompt.
            final_state (list): Final state objects (each providing ``model_dump()``) to be evaluated.
            rubrics (EvaluationCriteria): The evaluation criteria, including expected states and overall rubrics.
            final_message (str): Last message of the trajectory. If it contains a
                ``<summary>...</summary>`` section, only the first such section is
                used; otherwise the whole message is. ``None`` is treated as "".
            rubric_type (str): ``"separate"`` makes one judge call for the expected
                states, plus a second call for the overall rubrics when there are any.
                ``"combined"`` makes a single judge call covering both.
            llm_evaluator, llm_args_evaluator, language: As in ``calculate_reward``.

        Returns:
            List of ``NLAssertionCheck``, one per verdict, containing:
            - nl_rubric: The rubric being evaluated
            - met: Boolean indicating if the rubric was met
            - justification: The judge's explanation

        Raises:
            ValueError: If ``rubric_type`` is neither ``"separate"`` nor ``"combined"``.
        """
        if rubric_type not in cls._RUBRIC_TYPES:
            raise ValueError(
                f"Unsupported rubric_type {rubric_type!r}; expected one of {cls._RUBRIC_TYPES}"
            )

        # Use the configured evaluator model; fall back to the defaults when not provided.
        if llm_evaluator is None:
            llm_evaluator = DEFAULT_LLM_EVALUATOR
        if llm_args_evaluator is None:
            llm_args_evaluator = DEFAULT_LLM_EVALUATOR_ARGS

        final_message = final_message or ""
        # re.search (not findall()[0]) so unmatched/out-of-order tags fall back
        # to the whole message instead of raising IndexError.
        summary_match = re.search(r"<summary>(.*?)</summary>", final_message, re.DOTALL)
        summary = summary_match.group(1) if summary_match else final_message
        summary = "<summary>" + summary + "</summary>"

        final_state_str = "<final_state>" + json.dumps([order.model_dump() for order in final_state], ensure_ascii=False) + "</final_state>"

        expected_states_str = "<expected_states>" + json.dumps([state.model_dump() for state in rubrics.expected_states], ensure_ascii=False) + "</expected_states>"

        overall_rubrics_str = "<overall_rubrics>" + json.dumps(rubrics.overall_rubrics, ensure_ascii=False) + "</overall_rubrics>"

        prompts = get_prompts(language)
        result_data = []
        # NOTE: the indentation inside the triple-quoted user prompts below is part of
        # the text sent to the judge; it is kept as-is so evaluator inputs stay stable.
        if rubric_type == "separate":
            state_user_prompt = f"""
            # Input
            {final_state_str}
            {expected_states_str}
            """
            result_data += cls._run_judge(
                "state",
                prompts.state_eval_prompt.format(env_info=env_info),
                state_user_prompt,
                expected_states_str,
                llm_evaluator,
                llm_args_evaluator,
            )

            if rubrics.overall_rubrics:
                summary_user_prompt = f"""
                # Input
                {final_state_str}
                {summary}
                {overall_rubrics_str}
                """
                result_data += cls._run_judge(
                    "summary",
                    prompts.summary_eval_prompt.format(env_info=env_info),
                    summary_user_prompt,
                    overall_rubrics_str,
                    llm_evaluator,
                    llm_args_evaluator,
                )
        elif rubric_type == "combined":
            combined_user_prompt = f"""
            # Input
            {final_state_str}
            {summary}
            {expected_states_str}
            {overall_rubrics_str}
            """
            result_data += cls._run_judge(
                "combined",
                prompts.combined_eval_prompt.format(env_info=env_info),
                combined_user_prompt,
                overall_rubrics_str,
                llm_evaluator,
                llm_args_evaluator,
            )

        return [
            NLAssertionCheck(
                nl_rubric=result["rubrics"],
                met=result["meetExpectation"],
                justification=result["reasoning"],
            )
            for result in result_data
        ]