from __future__ import annotations
import json
import re

import os
import sys

# Append the WorkBench checkout to the import search path. The path is relative
# to the current working directory, so the script has to be launched from a
# directory that has `WorkBench` as a sibling.
# NOTE(review): presumably the `src.*` / `scripts.*` imports below resolve from
# this directory, which is why this runs before them — confirm.
workbench_path = '../WorkBench'
sys.path.append(workbench_path)

from typing import Any, Callable, Dict, List
from tqdm import tqdm

from src.data_generation.data_generation_utils import HARDCODED_CURRENT_TIME
from generate_results_openai_using_class import build_tools, _fmt_call
from jinja2 import Template
from openai import OpenAI
from openai.types.chat import ChatCompletionMessage
from scripts.inference.prompts import *
from src.evals.utils import calculate_metrics_single, calculate_metrics_single_in_class, execute_actions_and_return_result_in_class
import ast
from concurrent.futures import ProcessPoolExecutor
import pickle
import inflection

# =====================
# Size of the process pool used to evaluate workflows in parallel.
max_workers = 8
# Starting point of the per-repetition temperature schedule; each repetition
# lowers the temperature by exec_temperature / num_repetitions before running.
exec_temperature = 0.8
# How many times every (workflow, leaf) pair is executed; the best repetition
# determines the reward.
num_repetitions = 3
# Model identifier forwarded to the OpenAI-compatible endpoint.
model_name = "Qwen2.5-14B-Instruct"
# Endpoint and key handed to the OpenAI client inside each worker.
base_url = "http://localhost:6000/v1"
API_KEY = ""
# Cap applied to the static-call ratio in the 'code_lines_ratio' reward; the
# capped ratio is then divided by this value, so that reward never exceeds 1.
compress_ratio_threshold = 0.7


# When True, leaves are evaluated sequentially in-process instead of through
# the process pool.
DEBUG = False
print('DEBUG', DEBUG)
# Leaf records are loaded once at import time and read by process_leaf.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted
# source.
with open('./data/split/leaf_nodes_train.pkl', 'rb') as fp:
    leaf_nodes = pickle.load(fp)
# Store every record's dictionary key inside the record itself.
for k, v in leaf_nodes.items():
    v['id'] = k


# Accumulates every evaluation result produced by the execution-based reward.
ALL_EVAL_RESULTS = []

# NOTE(review): presumably silences HuggingFace tokenizers fork warnings when
# worker processes are created — confirm.
os.environ['TOKENIZERS_PARALLELISM'] = "false"
# ======================

def make_correct_ratio_reward_func(agent_type_inner):
    """Build a reward function bound to a fixed agent type.

    The returned callable forwards ``completions`` and every keyword argument
    to ``correct_ratio_reward_func`` with ``fixed_agent_type`` set to
    ``agent_type_inner``.

    Args:
        agent_type_inner: Agent/reward mode to bind (for example ``'react'``).

    Returns:
        A function ``reward_func(completions, **kwargs)`` whose ``__name__``
        and ``__qualname__`` are both
        ``correct_ratio_reward_func_<agent_type_inner>``.
    """
    def reward_func(completions, **kwargs):
        return correct_ratio_reward_func(
            completions, fixed_agent_type=agent_type_inner, **kwargs
        )

    # Use one name for both attributes: the qualified name previously lacked
    # the "_" separator and therefore disagreed with __name__.
    reward_name = f"correct_ratio_reward_func_{agent_type_inner}"
    reward_func.__name__ = reward_name
    reward_func.__qualname__ = reward_name
    return reward_func



def count_static_tool_calls(sample):
    """Count the tool-call steps in a JSON-encoded workflow.

    Args:
        sample: JSON text that decodes to a sequence of step dictionaries.

    Returns:
        The number of steps whose ``'from'`` is ``'function_call'`` or whose
        ``'type'`` is ``'static'``. A step matching both is counted once.
    """
    def is_tool_call(step):
        return step.get('from') == 'function_call' or step.get('type') == 'static'

    return sum(1 for step in json.loads(sample) if is_tool_call(step))



class MetaflowRunner:
    """Run a meta-workflow made of static tool steps and LLM-driven dynamic steps.

    A workflow is a list of step dictionaries. ``static`` steps call a tool
    directly with parameters that may reference earlier results through
    ``${name}`` placeholders. ``dynamic`` steps let the model work on one
    sub-task (optionally calling tools) and report values back inside a
    ``\\boxed{...}`` JSON object. All intermediate values live in
    ``self.context``.
    """

    def __init__(self, domains: str | List[str], base_url: str, openai_api_key: str, model_name: str, temperature: float = 1.0, agent_type='react', ICL_information="", max_agent_steps=10):
        """Create the client, the tool instances and the prompts.

        Args:
            domains: Tool domains, either as a list or as the string form of a
                list (for example ``"['calendar', 'email']"``).
            base_url: Base URL of the OpenAI-compatible endpoint.
            openai_api_key: API key passed to the client.
            model_name: Model identifier sent with every completion request.
            temperature: Sampling temperature for every completion request.
            agent_type: Stored on the instance; not read by this class.
            ICL_information: Optional reference text appended to the prompts
                of dynamic steps when non-empty.
            max_agent_steps: Bound used by the tool-calling loop of a dynamic
                step.

        Raises:
            ValueError: If ``domains`` is neither a string nor a list.
        """
        self.context = {}
        self.history = []

        # Formatted record of every tool call that was attempted.
        self.function_calls: List[str] = []

        self.openai_client = OpenAI(base_url=base_url, api_key=openai_api_key)
        self.model_name = model_name

        if isinstance(domains, str):
            # Parse the string form of a list, e.g. "['a', 'b']" -> ['a', 'b'].
            domains = domains.strip("[]").replace("'", "").split(", ")
        elif not isinstance(domains, list):
            raise ValueError("domains must be a string or list of strings")
        self.tools, self.lookup = build_tools(domains)

        assert len(self.lookup) == len(domains)
        self.tool_instances = {}  # class name -> instantiated tool
        for class_name, tool_class in self.lookup.items():
            self.tool_instances[class_name] = tool_class()

        # Maximum number of follow-up prompts asking for missing output variables.
        self.reflection_num = 3

        # NOTE: the fragments below are concatenated exactly as written (some
        # sentences join without a separating space). The text is model input,
        # so it is deliberately left untouched.
        self.system_prompt = (
            f"You are a helpful assistant and your task is to complete the user's query using the given tools."
            f"Today's date is {HARDCODED_CURRENT_TIME.strftime('%A, %B %d, %Y')} "
            f"and the current time is {HARDCODED_CURRENT_TIME.strftime('%H:%M')}."
            "Please remember the current date and time when answering queries. Please use today's date for all dates unless otherwise specified. Meetings must not start before 9am or end after 6pm."
            "Please respond helpfully and accurately to the user."
            "All the information is in the given query or you can get it using the given tools. Please do not assume or will ask for my help."
        )

        self.dynamic_agnet_prompt = (
            f"You are a step-by-step task-solving assistant. At each stage, you are only allowed to complete the current subtask based on the available variables, tools and instructions."
            f"Today's date is {HARDCODED_CURRENT_TIME.strftime('%A, %B %d, %Y')} "
            f"and the current time is {HARDCODED_CURRENT_TIME.strftime('%H:%M')}."
            "Please remember the current date and time. Please use today's date for all dates unless otherwise specified. Meetings must not start before 9am or end after 6pm."
            "Please respond helpfully and accurately to the user."
            "All the information is in the given query or you can get it using the given tools. Please do not assume or will ask for my help."
        )

        self.temperature = temperature
        self.agent_type = agent_type
        self.ICL_information = ICL_information

        self.max_agent_steps = max_agent_steps
        # Counters reported by execute_flow.
        self.tool_calls_by_LLM = 0
        self.LLM_api_calls = 0

    def _record_history(
            self,
            role: str,
            content: str | None,
            *,
            name: str | None = None,
            tool_calls: List[Dict] | None = None,
            tool_call_id: str | None = None
    ):
        """Append one OpenAI-style chat message to ``self.history``.

        Args:
            role: Message role (user/assistant/tool/system).
            content: Text content; may be None for an assistant tool call.
            name: Tool name (required for the tool role).
            tool_calls: Tool calls to attach (assistant role only).
            tool_call_id: Id of the call being answered (required for the
                tool role).

        Raises:
            ValueError: If the role is ``tool`` and ``tool_call_id`` or
                ``name`` is missing.
        """
        entry = {"role": role}

        if content is not None:
            entry["content"] = content

        # Assistant messages carrying tool calls: arguments are serialized to JSON text.
        if role == "assistant" and tool_calls:
            entry["tool_calls"] = [
                {
                    "id": call["id"],
                    "type": "function",
                    "function": {
                        "name": call["function"]["name"],
                        "arguments": json.dumps(call["function"]["arguments"])
                    }
                } for call in tool_calls
            ]

        # Tool responses always carry content (empty string when None).
        elif role == "tool":
            if not tool_call_id:
                raise ValueError("tool_call_id is required for tool role")
            if not name:
                raise ValueError("name is required for tool role")

            entry["content"] = content or ""
            entry["tool_call_id"] = tool_call_id
            entry["name"] = name

        # System messages always get a content key, even when it is None.
        elif role == "system":
            entry["content"] = content

        self.history.append(entry)

    def _resolve_placeholders(self, params: Any) -> Any:
        """Recursively substitute ``${name}`` placeholders from ``self.context``.

        Dictionaries and lists are rebuilt with their members resolved. In
        strings every placeholder is replaced by ``str()`` of the context
        value, or by an empty string when the name is unknown. Any other
        value is returned unchanged.
        """
        if isinstance(params, dict):
            return {k: self._resolve_placeholders(v) for k, v in params.items()}
        elif isinstance(params, list):
            return [self._resolve_placeholders(v) for v in params]
        elif isinstance(params, str):
            return re.sub(
                r"\${(.*?)}",
                lambda m: str(self.context.get(m.group(1), "")),
                params
            )
        return params

    def _execute_tool(self, tool_call: Dict) -> Any:
        """Execute one tool call described as ``{"name": ..., "arguments": ...}``.

        ``name`` has the form ``<domain>.<method>``; the camel-cased domain
        selects the tool instance. The call is appended to
        ``self.function_calls`` before it is executed, so failed calls are
        recorded as well.

        Returns:
            The tool's return value, or an error string when the tool is
            unknown or the call raises.

        Raises:
            ValueError: If ``name`` contains no ``.`` separator.
        """
        fn_name = tool_call["name"]
        domain, method = fn_name.split(".", 1)

        tool = self.tool_instances.get(inflection.camelize(domain))

        if not tool:
            return f"Error: Tool {domain} not found"

        args = self._resolve_placeholders(tool_call.get("arguments", {}))

        self.function_calls.append(_fmt_call(fn_name, args))

        try:
            result = getattr(tool, method)(**args)

            return result
        except Exception as e:
            # Errors are returned as text so they can be passed back to the model.
            return f"Tool error: {str(e)}"

    def _call_llm(self, messages: List[Dict], tools: List[Dict] | None = None) -> ChatCompletionMessage:
        """Send ``messages`` (and optional tool specs) and return the first choice's message."""
        response = self.openai_client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            tools=tools,
            temperature=self.temperature,
            max_tokens=8192
        )
        return response.choices[0].message

    def _store_boxed_outputs(self, content, expected_vars, missing_vars, logs) -> None:
        """Copy expected variables from a ``\\boxed{...}`` JSON answer into the context.

        Args:
            content: Model answer text; ``None`` or empty text is ignored.
            expected_vars: Variable names the step is expected to produce.
            missing_vars: Set of names still missing; updated in place.
            logs: Log list that receives one ``Stored ...`` line per variable.

        Answers without a boxed object, or with invalid JSON inside it, are
        ignored so that the caller can ask the model again.
        """
        if not content:
            return
        match = re.search(r"\\boxed{(\{.*?\})}", content)
        if not match:
            return
        try:
            data = json.loads(match.group(1))
        except json.JSONDecodeError:
            return
        for var in expected_vars:
            if var in data:
                self.context[var] = data[var]
                # discard, not remove: a later answer may repeat a variable
                # that has already been stored (or the list may contain
                # duplicates), which must not abort the flow with KeyError.
                missing_vars.discard(var)
                logs.append(f"Stored {var}={data[var]}")

    def _execute_dynamic_step(self, total_task: str, step: Dict, workflow: list, history_logs: list) -> List[str]:
        """Execute a *dynamic* (LLM-driven) step in a meta-workflow.

        The step runs in three phases:

        1. An adaptation call asks the model for a JSON description of the
           sub-task; its ``outputs`` entry decides whether the answer-format
           instruction is added to the prompt.
        2. A tool-calling loop in which the model may call tools until it
           answers without tool calls or the step bound is reached.
        3. Extraction of the boxed output variables, followed by up to
           ``self.reflection_num`` follow-up prompts for missing variables.

        Args:
            total_task: The full user query.
            step: The dynamic step (uses ``instruction`` and ``outputs``).
            workflow: The complete workflow, shown to the model for context.
            history_logs: Logs produced by the steps executed so far.

        Returns:
            Log lines produced by this step.

        Raises:
            IndexError, KeyError, json.JSONDecodeError: When the adaptation
                answer holds no JSON code block, lacks ``outputs``, or is not
                valid JSON. ``execute_flow`` turns these into an error result.
        """
        logs = []

        adapt_messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user",
             "content": Template(ADAPTATION_PROMPT_TEMPLATE.lstrip()).render({"total_task": total_task,
                                                                              'context': self.context,
                                                                              'current_sub_task': step['instruction'],
                                                                              'workflow': workflow,
                                                                              'api_docs': self.tools,
                                                                              'history_logs': history_logs
                                                                              })}
        ]

        if self.ICL_information:
            adapt_messages[-1]['content'] += f"\n\nThe following tool calls and their corresponding results are provided for your reference in adapting the subtask.\n{self.ICL_information}\n"

        # Phase 1: adaptation. The first ```json fenced block of the answer is
        # taken as the sub-task description. This call is not counted in
        # LLM_api_calls.
        sub_task = self._call_llm(adapt_messages)
        sub_goal_pattern = r'```json\s*([\s\S]*?)```'
        matches = re.findall(sub_goal_pattern, sub_task.content, re.DOTALL)
        sub_task = json.loads(matches[0])

        messages = [
            {
                "role": "system",
                "content": self.dynamic_agnet_prompt,
            },
            {
                "role": "user",
                "content": (
                    f"You are solving a multi-step query within a pre-defined workflow. The full user query is:\n\n"
                    f"{total_task}\n\n"
                    f"And the pre-defined workflow is:\n{workflow}\n\n"
                    f"But you should now focus **only** on the following subtask:\n"
                    f"→ {step['instruction']}\n\n"
                    f"Here are the historical Logs prior to the completion of this sub_task:\n{history_logs}\n"
                    f"Use the current variables: {self.context}\n"
                    f"Do not perform any part of the full task beyond this subtask. Do not call other tools or add outputs unless explicitly instructed.\n")
            }]

        if sub_task['outputs']:
            messages[-1]['content'] += (f'\n\nAfter completing the subtask, please store the output variables in the `\\boxed{{json key value pair}}` format. Below is the list of expected output variables: {sub_task["outputs"]}'
                                        'Example: If your output variables are: `result = 42` Please return them in this format: `\\boxed{{"result": 42}}`')

        if self.ICL_information:
            messages[-1]['content'] += f"\n\nThe following tool calls and their corresponding results are provided for your reference in completing the subtask.\n{self.ICL_information}\n"

        # Phase 2: Reason -> Act -> Observe.
        # NOTE: the inclusive bound allows up to max_agent_steps + 1 model calls.
        step_index = 0
        response = None
        while step_index <= self.max_agent_steps:
            self.LLM_api_calls += 1
            response = self._call_llm(messages, tools=self.tools)

            messages.append(response.to_dict())

            # An answer without tool calls ends the loop.
            if not response.tool_calls:
                break

            for tc in response.tool_calls:
                self.tool_calls_by_LLM += 1
                result = self._execute_tool({
                    "name": tc.function.name,
                    "arguments": json.loads(tc.function.arguments)
                })
                result_json = json.dumps(result, ensure_ascii=False)

                logs.append(f"Called {tc.function.name} => {result_json}")

                messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": result_json
                })

            step_index += 1

        # Phase 3: collect output variables; re-prompt while some are missing.
        if response and response.content:
            expected_vars = step.get("outputs", [])
            missing_vars = set(expected_vars)
            self._store_boxed_outputs(response.content, expected_vars, missing_vars, logs)

            for _ in range(self.reflection_num):
                if not missing_vars:
                    break
                reflection_prmpt = Template(REFLECTION_PROMPT_TEMPLATE.lstrip()).render({"variables": missing_vars
                                                                                         })
                messages.append({"role": "user", "content": reflection_prmpt})

                response = self._call_llm(messages,
                                          tools=self.tools,
                                          )
                messages.append({"role": "assistant", "content": response.content})

                # content is None when the model answers with tool calls only;
                # the helper ignores that instead of failing inside re.search.
                self._store_boxed_outputs(response.content, expected_vars, missing_vars, logs)
        return logs

    def execute_flow(self, user_query: str, flow: list) -> Dict:
        """Run every step of ``flow`` in order and report the outcome.

        Args:
            user_query: The user's request, passed to dynamic steps.
            flow: List of step dictionaries whose ``type`` is ``static`` or
                ``dynamic``.

        Returns:
            A dictionary with ``status`` ``"success"`` (including ``context``)
            or ``"error"`` (including ``message``), plus the logs, recorded
            function calls, call counters, the query and the workflow. Any
            exception raised by a step is converted into the error form.
        """
        # self.context["user_query"] = user_query
        logs = []

        try:
            for index, step in enumerate(flow):
                if step["type"] == "static":
                    params = self._resolve_placeholders(step.get("params", {}))
                    result = self._execute_tool({
                        "name": step["action"],
                        "arguments": params
                    })
                    # Expose the result to later steps under the requested name.
                    if output := step.get("output"):
                        self.context[output] = result
                    logs.append(f"Executed {step['action']}: {result}")

                elif step["type"] == "dynamic":
                    step_logs = self._execute_dynamic_step(total_task=user_query, step=step, workflow=flow, history_logs=logs)
                    logs.extend(step_logs)
                else:
                    raise ValueError(f"Unknown step type: {step['type']}")
            return {"status": "success", "logs": logs, "context": self.context, 'function_calls': self.function_calls, 'num_llm_calls': self.LLM_api_calls, 'num_tool_calls': self.tool_calls_by_LLM, 'query': user_query, 'workflow': flow}
        except Exception as e:
            return {"status": "error", "message": str(e), "logs": logs, 'function_calls': self.function_calls, 'num_llm_calls': self.LLM_api_calls, 'num_tool_calls': self.tool_calls_by_LLM, 'query': user_query, 'workflow': flow}

def process_leaf(args):
    """Run one meta-workflow against one leaf task and score the outcome.

    Args:
        args: Tuple ``(meta_id, meta_query, meta_workflow, leaf_id,
            model_name, rep_index, rep_exec_temperature, agent_type,
            ICL_information)``.

    Returns:
        A dictionary with ``meta_id``, ``rep_index``, ``eval_result``
        (``accuracy and (not side_effects)``) and the runner's ``result``
        extended with the query, ground truth and workflow.
    """
    (meta_id, meta_query, meta_workflow, leaf_id, model_name,
     rep_index, rep_exec_temperature, agent_type, ICL_information) = args

    leaf = leaf_nodes[str(leaf_id)]

    # Privileged agents without explicit reference text get the leaf's
    # recorded tool calls and results as reference material.
    if ICL_information == "" and 'privileged' in agent_type:
        ICL_information = f"```json\n{json.dumps(leaf['toolcall_and_results'], indent=2)}\n```"

    runner = MetaflowRunner(
        domains=leaf['domains'],
        base_url=base_url,
        openai_api_key=API_KEY,
        model_name=model_name,
        temperature=rep_exec_temperature,
        agent_type=agent_type,
        ICL_information=ICL_information
    )
    query = leaf['query']
    result = runner.execute_flow(query, meta_workflow)

    ground_truth = leaf['ground_truth']
    accuracy, side_effects = calculate_metrics_single_in_class(
        result.get('function_calls'), ground_truth, result.get('message')
    )

    # Kept in the result to make individual runs easy to inspect.
    result.update(query=query, ground_truth=ground_truth, meta_workflow=meta_workflow)

    return {
        'meta_id': meta_id,
        'rep_index': rep_index,
        'eval_result': accuracy and (not side_effects),
        'result': result
    }




def correct_ratio_reward_func(completions, fixed_agent_type, **kwargs):
    """Score generated meta-workflows, one reward per completion.

    Each completion must contain a ``**Meta Query**`` section and a
    ``**Meta Workflow**`` section holding a fenced JSON list of steps.
    Completions that cannot be parsed or validated receive ``-1``.

    Args:
        completions: Sequence whose items expose the generated text at
            ``item[0]["content"]``.
        fixed_agent_type: ``'react'`` / ``'privileged_react'`` for the
            execution-based reward (best pass ratio over the repetitions), or
            ``'code_lines_ratio'`` for the capped static-call ratio.
        **kwargs: Must contain ``leaf_ids``; the ``'code_lines_ratio'`` mode
            also reads ``workflow1_lines`` and ``workflow2_lines``.

    Returns:
        List of rewards aligned with ``completions``.

    Raises:
        Exception: If ``fixed_agent_type`` is not one of the supported modes.
    """
    global ALL_EVAL_RESULTS

    completion_contents = [completion[0]["content"] for completion in completions]

    query_pattern = r'\*\*Meta Query\*\*:\s*(.*?)\s*\*\*Meta Workflow\*\*'
    workflow_pattern = r'\*\*Meta Workflow\*\*:\s*```json\s*([\s\S]*?)```'

    rewards = []
    args_list = []
    invalid_meta_ids = set()

    # Parse and validate every completion, queueing one job per
    # (repetition, leaf) for the valid ones.
    for meta_id, (completion, leaf_ids) in enumerate(zip(completion_contents, kwargs['leaf_ids'])):
        try:
            query_match = re.search(query_pattern, completion, re.DOTALL)
            meta_query = query_match.group(1).strip() if query_match else None

            workflow_match = re.search(workflow_pattern, completion, re.DOTALL)
            meta_workflow = workflow_match.group(1).strip() if workflow_match else None

            if meta_query is None or meta_workflow is None:
                raise Exception("Metaflow query or Metaflow not found")

            meta_workflow = json.loads(meta_workflow)

            if len(meta_workflow) == 0:
                raise Exception('Meta flow is invalid.')

            # A step is rejected when its type is unknown, or when it is a
            # static step whose action is missing or still contains a '$'.
            # Every step is inspected before the verdict is taken.
            rejected = [
                step['type'] not in ['static', 'dynamic']
                or (step['type'] == 'static' and '$' in step.get('action', '$'))
                for step in meta_workflow
            ]
            if any(rejected):
                raise Exception('Meta flow is invalid.')

            # The temperature is lowered by one equal step before each repetition.
            temperature_step = (exec_temperature - 0.0) / num_repetitions
            rep_exec_temperature = exec_temperature
            for rep_index in range(num_repetitions):
                rep_exec_temperature = max(rep_exec_temperature - temperature_step, 0)
                args_list.extend(
                    (meta_id, meta_query, meta_workflow, leaf_id, model_name, rep_index,
                     rep_exec_temperature, fixed_agent_type, "")
                    for leaf_id in leaf_ids
                )
        except Exception as e:
            print('Metaflow Exception:', e)
            invalid_meta_ids.add(meta_id)

    if fixed_agent_type == 'code_lines_ratio':
        for i, completion in enumerate(completion_contents):
            if i in invalid_meta_ids:
                rewards.append(-1)
                continue

            workflow_match = re.search(workflow_pattern, completion, re.DOTALL)
            workflow_text = workflow_match.group(1).strip() if workflow_match else None
            lines = count_static_tool_calls(workflow_text)

            compress_ratio = max(float(lines) / (kwargs['workflow1_lines'][i] + 0.001),
                                 float(lines) / (kwargs['workflow2_lines'][i] + 0.001))
            # Cap the ratio, then rescale so the cap maps to a reward of 1.
            compress_ratio = min(compress_ratio, compress_ratio_threshold)
            rewards.append(compress_ratio / compress_ratio_threshold)

        return rewards

    if fixed_agent_type not in ['react', 'privileged_react']:
        raise Exception(f'{fixed_agent_type} is not supported.')

    if not args_list:
        eval_results = []
    elif DEBUG:
        eval_results = [process_leaf(args) for args in args_list]
    else:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            eval_results = list(tqdm(
                executor.map(process_leaf, args_list),
                total=len(args_list), desc=f"Processing Metaflows."
            ))

    ALL_EVAL_RESULTS += eval_results

    # (meta_id, rep_index) -> [number of passes, number of runs]
    tally = {}
    for sample in eval_results:
        counts = tally.setdefault((sample['meta_id'], sample['rep_index']), [0, 0])
        counts[1] += 1
        if sample.get('eval_result', False) == True:
            counts[0] += 1

    for i, _ in enumerate(completion_contents):
        if i in invalid_meta_ids:
            rewards.append(-1)
            continue

        # Best pass ratio over the repetitions; stays at -1 when nothing ran.
        best_ratio = -1
        for rep_index in range(num_repetitions):
            num_passes, num_runs = tally.get((i, rep_index), (0, 0))
            if num_runs != 0:
                best_ratio = max(best_ratio, num_passes / num_runs)
        rewards.append(best_ratio)

    return rewards


if __name__ == "__main__":
    # Manual smoke test: run a four-step workflow (two dynamic, two static)
    # against a local endpoint and print the recorded calls and metrics.
    extract_date_step = {
        "type": "dynamic",
        "instruction": "Extract date from user query and format to YYYY-MM-DD",
        "outputs": ["formatted_date"]
    }
    search_events_step = {
        "type": "static",
        "action": "calendar.search_events",
        "params": {
            "time_min": "${formatted_date} 09:00:00",
            "time_max": "${formatted_date} 18:00:00"
        },
        "output": "event_data"
    }
    pick_event_step = {
        "type": "dynamic",
        "instruction": "Extract the name and participant_email of the earliest event from the `event_data` list",
        "outputs": ["event_name", "participant_email"]
    }
    send_email_step = {
        "type": "static",
        "action": "email.send_email",
        "params": {
            "recipient": "${participant_email}",
            "subject": "${event_name}",
            "body": "Remember to attend this event."
        },
        "output": "email_result"
    }
    metaworkflow = [extract_date_step, search_events_step, pick_event_step, send_email_step]

    runner = MetaflowRunner(
        domains=["calendar", "email"],
        base_url="http://localhost:6000/v1",
        openai_api_key="",
        model_name="qwen2.5-14b-instruct"
    )

    demo_query = 'Can you send an email to attendees of the first event on November 30? Title it with the event name and tell them \'Remember to attend this event.\''
    result = runner.execute_flow(demo_query, metaworkflow)

    # The expected calls are written as the string form of a Python list.
    ground_truth = ast.literal_eval(
        """[\'email.send_email.func(recipient="fatima.khan@atlas.com", subject="Annual Budget Planning Session", body="Remember to attend this event.")\']"""
    )

    recorded_calls = result.get('function_calls')
    error_message = result.get('message')

    print(recorded_calls)
    print(calculate_metrics_single(recorded_calls, ground_truth, error_message))
    print(calculate_metrics_single_in_class(recorded_calls, ground_truth, error_message))
    print(execute_actions_and_return_result_in_class(recorded_calls))
