from __future__ import annotations

import inspect
import json
import os
import re
import csv
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List
from tqdm import tqdm
import pandas as pd
from openai import OpenAI, RateLimitError
from src.data_generation.data_generation_utils import HARDCODED_CURRENT_TIME
import inflection

# ---------------------------------------------------------------------------
# Configuration helpers
# ---------------------------------------------------------------------------
# These domain objects are assumed to be the same as in the original repo
from src.tools_class import (
    calendar,
    email,
    analytics,
    project_management,
    customer_relationship_manager,
    company_directory,
)

# The six toolkit objects imported from `src.tools_class`, gathered in one list.
# NOTE(review): nothing in the visible part of this file reads DOMAINS —
# confirm whether other code depends on it before changing or removing it.
DOMAINS = [
    calendar,
    email,
    analytics,
    project_management,
    customer_relationship_manager,
    company_directory,
]

# ---------------------------------------------------------------------------
# Utility: convert python callables → OpenAI tool schemas
# ---------------------------------------------------------------------------

def _annotation_to_json_type(annotation: Any) -> str:
    mapping = {int: "number", float: "number", bool: "boolean", str: "string"}
    return mapping.get(annotation, "string")


def function_to_tool(fn):
    """
    Build an OpenAI function schema from a LangChain-style tool object.

    Expects `fn` to have:
      - `name`: the tool identifier (e.g. "email.delete_email")
      - `args_schema`: an object exposing `model_json_schema()` (e.g. a
        Pydantic BaseModel class) describing its parameters
      - `description`: optional human-readable description; the placeholder
        "No description provided." is used when it is missing or empty

    Returns:
        A dict in the OpenAI `tools` format, or None when `fn` has no `name`
        or no `args_schema` (attribute missing or set to None).
    """
    name = getattr(fn, "name", None)
    schema_model = getattr(fn, "args_schema", None)
    if name is None or schema_model is None:
        # Honour the documented contract: incomplete tools are skipped, not fatal.
        return None

    model_schema = schema_model.model_json_schema()
    properties = model_schema.get("properties", {})
    required = model_schema.get("required", [])
    description = getattr(fn, "description", None) or "No description provided."

    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description.strip(),
            "parameters": {
                "type": "object",
                # Any `self` entry is dropped so it is never exposed as a
                # model-visible parameter.
                "properties": {k: v for k, v in properties.items() if k != "self"},
                "required": [param for param in required if param != "self"],
            },
        },
    }


def build_tools(selected_toolkits: List[str]) -> tuple[list[dict[str, Any]], dict[str, Callable]]:
    """
    Select the OpenAI tool schemas and tool classes for the given toolkits.

    Tool schemas are read from `./tools_description.json` (relative to the
    current working directory) and filtered down to the selected toolkits.
    Each toolkit name is resolved to the module-level object of the same name
    (`crm` is accepted as an alias for `customer_relationship_manager`); the
    class whose name is the CamelCase form of that module's name is then
    looked up on it.

    Args:
        selected_toolkits: Toolkit names, e.g. ["email", "calendar"].

    Returns:
        A `(tools_json, lookup)` pair. `tools_json` holds the schemas whose
        function-name prefix (the part before the first ".") camelizes to one
        of the selected class names. `lookup` maps CamelCase class name to the
        tool class itself; classes are NOT instantiated here.

    Raises:
        KeyError: If a toolkit name does not match any module-level name.
        AttributeError: If the module lacks a class with the expected name.
    """
    module_alias = {
        "crm": "customer_relationship_manager",
    }
    lookup: dict[str, Callable] = {}

    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open("./tools_description.json", "r", encoding="utf-8") as file:
        all_tools = json.load(file)

    for module_name in selected_toolkits:
        actual_module_name = module_alias.get(module_name, module_name)
        try:
            module = globals()[actual_module_name]
        except KeyError:
            # Same exception type as before, but say which toolkit was rejected.
            raise KeyError(f"Unknown toolkit {module_name!r}") from None

        # e.g. module "...project_management" -> class "ProjectManagement"
        class_name = inflection.camelize(module.__name__.split(".")[-1])
        lookup[class_name] = getattr(module, class_name)

    tools_json = [
        tool
        for tool in all_tools
        if inflection.camelize(tool["function"]["name"].split(".")[0]) in lookup
    ]

    return tools_json, lookup

def _fmt_call(name: str, args: Dict[str, Any]) -> str:
    """Generate the function call string 'tool.func(a="1", b="2")'."""
    arg_list = [f'{k}="{v}"' for k, v in args.items()]
    return f"{name}.func(" + ", ".join(arg_list) + ")"


def run_agent(
    client,
    model: str,
    query: str,
    tools: List[Dict[str, Any]],
    fn_lookup: Dict[str, Callable],
    max_iter: int = 20,
    max_seconds: int = 120,
    temperature: float = 0,
):
    """
    Run a tool-calling chat loop for a single user query.

    Args:
        client: Object exposing `chat.completions.create(...)` (OpenAI-style).
        model: Model name passed to the completion endpoint.
        query: The user request.
        tools: Tool schemas passed as the `tools` argument of each request.
        fn_lookup: Maps CamelCase toolkit class name to the class. Each class
            is instantiated once per call so runs do not share tool state.
        max_iter: Maximum number of model round-trips.
        max_seconds: Wall-clock budget, checked before each round-trip.
        temperature: Sampling temperature.

    Returns:
        A dict with `output` (the model's final message content, or a fixed
        notice when the iteration/time limit is hit) and `function_calls`
        (the executed calls, formatted by `_fmt_call`, in execution order).

    Raises:
        Exception: If the model requests a tool whose name cannot be resolved
            to `<toolkit>.<method>` on one of the instantiated toolkits.
        Any error raised by the client or by a tool method propagates.
    """
    limit_notice = "Agent stopped due to iteration or time limit."

    # Fresh tool instances per run keep state independent between queries.
    tool_instances = {
        class_name: tool_class() for class_name, tool_class in fn_lookup.items()
    }

    # Each fragment ends with a space so the sentences are not glued together
    # ("...tools.Today's date...") in the prompt the model receives.
    system_prompt = (
        "You are a helpful assistant and your task is to complete the user's query using the given tools. "
        f"Today's date is {HARDCODED_CURRENT_TIME.strftime('%A, %B %d, %Y')} "
        f"and the current time is {HARDCODED_CURRENT_TIME.strftime('%H:%M')}. "
        "Please remember the current date and time when answering queries. Meetings must not start before 9am or end after 6pm. "
        "Please respond helpfully and accurately to the user."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]

    start_time = datetime.now()
    function_calls: List[str] = []

    for _ in range(max_iter):
        # Time budget check
        if (datetime.now() - start_time).total_seconds() > max_seconds:
            return {"output": limit_notice, "function_calls": function_calls}

        # LLM call
        resp = client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            tool_choice="auto",
            temperature=temperature,
        )
        msg = resp.choices[0].message

        # Decide on the presence of tool calls rather than on `finish_reason`:
        # a response can carry tool calls with a different finish reason, and
        # `tool_calls` may be None, which must not be iterated.
        tool_calls = getattr(msg, "tool_calls", None) or []
        if not tool_calls:
            # Final answer
            messages.append({"role": "assistant", "content": msg.content})
            return {"output": msg.content, "function_calls": function_calls}

        for call in tool_calls:
            fn_name = call.function.name
            try:
                args = json.loads(call.function.arguments or "{}")
            except json.JSONDecodeError:
                args = {}
            if not isinstance(args, dict):
                # Valid JSON that is not an object cannot be used as keyword
                # arguments; treat it like undecodable arguments.
                args = {}

            # Tool names look like "<toolkit>.<method>", e.g. "email.send_email".
            name_parts = fn_name.split(".")
            tool_instance = tool_instances.get(inflection.camelize(name_parts[0]))
            if tool_instance is None or len(name_parts) < 2:
                raise Exception(f"No function called {fn_name}")

            result = getattr(tool_instance, name_parts[1])(**args)

            # Log the formatted function call
            function_calls.append(_fmt_call(fn_name, args))

            # Send the call and its result back to the model. One assistant
            # message is recorded per call, each followed by its tool result.
            messages.append(
                {"role": "assistant", "content": None, "tool_calls": [call]}
            )
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": call.id,
                    "name": fn_name,
                    "content": str(result),
                }
            )

    # Iteration limit reached without a final answer
    return {"output": limit_notice, "function_calls": function_calls}


def _parse_domains(raw: Any) -> List[str]:
    """Parse one `domains` cell (e.g. "['email', 'calendar']") into toolkit names."""
    import ast  # local import: only this helper needs it

    text = str(raw).strip()
    try:
        parsed = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal (e.g. unquoted names): fall back to plain splitting.
        parsed = text.strip("[]").replace("'", "").split(", ")
    if not isinstance(parsed, (list, tuple)):
        parsed = [parsed]
    names = (str(item).strip() for item in parsed)
    # Empty entries (e.g. from "[]") are dropped instead of becoming toolkit "".
    return [name for name in names if name]


def generate_results(
    queries_path: str | os.PathLike,
    model_name: str,
    tool_selection: str = "all",
    temperature: float = 0,
) -> pd.DataFrame:
    """Run the agent over every query in a CSV file and save the results.

    Args:
        queries_path: CSV with a `query` column, plus a `domains` column when
            `tool_selection == "domains"`.
        model_name: Model name passed to `run_agent`.
        tool_selection: "domains" builds the tool set per row from the row's
            `domains` value; any other value uses the fixed default toolkits.
        temperature: Sampling temperature passed to `run_agent`.

    Returns:
        DataFrame with columns `query`, `function_calls`, `full_response` and
        `error`. A failing row is recorded with an empty call list and the
        exception text in both `full_response` and `error`. The same frame is
        written to `data/results/<domain>/<model>_<tool_selection>_<timestamp>.csv`,
        where `<domain>` is the query file's stem without
        "_queries_and_answers".
    """

    client = OpenAI()

    df_queries = pd.read_csv(queries_path)
    queries = df_queries["query"].tolist()

    # Pre-build tools / look-up once (unless overridden per-row)
    toolkits = [
        "email",
        "calendar",
        "analytics",
        "project_management",
        "customer_relationship_manager",
    ]

    base_tools, base_lookup = build_tools(toolkits)

    per_row_domains = tool_selection == "domains"
    domain_cells = (
        df_queries["domains"].tolist() if per_row_domains else [None] * len(queries)
    )
    # Rows that share a toolkit combination reuse one build_tools() result
    # instead of re-reading the tool description file for every query.
    tools_cache: dict = {}

    results_rows = []
    progress = tqdm(
        zip(queries, domain_cells),
        total=len(queries),
        desc=f"{model_name} Processing {queries_path} queries... [temperature {temperature}]",
    )
    for q, domain_cell in progress:
        error_msg = ""
        try:
            if per_row_domains:
                # Inside the try so one malformed row is recorded as an error
                # instead of aborting the whole run.
                cache_key = tuple(_parse_domains(domain_cell))
                if cache_key not in tools_cache:
                    tools_cache[cache_key] = build_tools(list(cache_key))
                tools, lookup = tools_cache[cache_key]
            else:
                tools, lookup = base_tools, base_lookup

            agent_result = run_agent(client, model_name, q, tools, lookup, temperature=temperature)
            final = agent_result["output"]
            calls = agent_result["function_calls"]
        except Exception as exc:
            # Store text, not the exception object, so the column stays string-typed.
            error_msg = str(exc)
            final = error_msg
            calls = []
        results_rows.append({
            "query": q,
            "function_calls": calls,
            "full_response": final,
            "error": error_msg,
        })

    results_df = pd.DataFrame(results_rows)

    domain_name = Path(queries_path).stem.replace("_queries_and_answers", "")
    save_dir = Path("data/results") / domain_name
    save_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # Model names such as "org/model" would otherwise point into a missing
    # sub-directory and lose the results at write time.
    safe_model_name = model_name.replace("/", "_").replace("\\", "_")
    save_path = save_dir / f"{safe_model_name}_{tool_selection}_{timestamp}.csv"
    results_df.to_csv(save_path, index=False, quoting=csv.QUOTE_ALL)
    print(f"Results saved to {save_path}")
    return results_df



# ---------------------------------------------------------------------------
# CLI usage
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    import time

    # Models to evaluate in this run.
    AVAILABLE_LLMS = ["qwen2.5-32b-instruct"]

    # Client settings are supplied through the environment: a local endpoint
    # and an empty API key.
    os.environ["OPENAI_BASE_URL"] = "http://localhost:6000/v1"
    os.environ["OPENAI_API_KEY"] = ""

    # No options are defined; parsing still provides --help and rejects
    # unexpected command-line arguments.
    parser = argparse.ArgumentParser(description="OpenAI ReAct Runner")
    args = parser.parse_args()

    query_paths = [
        "data/processed/queries_and_answers/multi_domain_queries_and_answers.csv",
        "data/processed/queries_and_answers/email_queries_and_answers.csv",
        "data/processed/queries_and_answers/calendar_queries_and_answers.csv",
        "data/processed/queries_and_answers/analytics_queries_and_answers.csv",
        "data/processed/queries_and_answers/project_management_queries_and_answers.csv",
        "data/processed/queries_and_answers/customer_relationship_manager_queries_and_answers.csv",
    ]
    tool_selections = ["domains"]

    # Evaluate every (tool selection, model, query file) combination in order.
    for selection in tool_selections:
        for llm_name in AVAILABLE_LLMS:
            for path in query_paths:
                generate_results(
                    queries_path=path,
                    model_name=llm_name,
                    tool_selection=selection,
                    temperature=0,
                )