from typing import List, Dict, Tuple, Literal

import streamlit as st

from .utils import Inspection, run_pod
from ....preprocessing.preprocessor import ProcessedData
from ....utils.wrappers import LLM, LLMBaseModel, LLMField
from ....utils.llms import build_json_agent, LLMLog, LoggingCallback
from ....utils.schemas import File
from ....utils.functions import write_file, dict_to_str, sanitize_filename


# Status-check interval that generated k6 scripts are told to use
# (interpolated into the description of the K6JS `js` field).
INTERVAL_SEC = "1s"
# Maximum inspection duration communicated to the LLM
# (interpolated into the `duration` field descriptions of K8sAPI and K6JS).
MAX_DURATION = "5s"


# System prompt used as the first chat message of every inspection agent built
# in this module. `{format_instructions}` is left as a template placeholder;
# presumably build_json_agent fills it from the pydantic schema -- TODO confirm.
SYS_DEFINE_CMD = """\
You are a helpful AI assistant for Chaos Engineering.
Given k8s manifests that define a network system and its steady state, you will determine the way to inspect the steady state.
Always keep the following rules:
- You can use either k8s Client Libraries (Python) or k6 (Javascript) to inspect the steady state.
- Pay attention to namespace specification. If the namespace is specified in the manifest, it is deployed with the namespace. If not, it is deployed with the 'default' namespace.
- Use the K8s API for checking the state of Kubernetes resources, and use k6 for obtaining communication statuses/metrics (e.g., request sending, response time, latency, etc.).
- If you use k6, consider both an appropriate number of virtual users and appropriate test duration.
- If you use k8s, consider appropriate test duration.
- {format_instructions}"""

# First human message of the inspection conversation. The placeholders
# `user_input`, `steady_state_name`, `steady_state_thought` and
# `ce_instructions` are supplied as the input dict of the agents' `stream()`
# calls in InspectionCMDAgent.
USER_DEFINE_CMD = """\
# Here is the overview of my system:
{user_input}

# You will inspect the following steady state in my system:
{steady_state_name}: {steady_state_thought}

# Please follow the instructions below regarding Chaos Engineering:
{ce_instructions}

Please define the way to inspect "{steady_state_name}" in the system defined by the above k8s manifest(s)."""

# Follow-up human message appended after each failed inspection attempt.
# `{error_message}` is substituted with the (brace-escaped) console log via
# `str.replace` in InspectionCMDAgent.debug_inspection; `{format_instructions}`
# stays a template placeholder for the agent.
USER_REWRITE_INSPECTION = """\
Your current inspection script causes errors when conducted.
The error message is as follows:
{error_message}

Please analyze the reason why the errors occur, then fix the errors.
Always keep the following rules:
- NEVER repeat the same fixes that have been made in the past.
- Fix only the parts related to the errors without changing the original content.
- You can change the tool (k8s -> k6 or k6 -> k8s) if it can keep the original intention.
- {format_instructions}"""


class K8sAPI(LLMBaseModel):
    """Schema of the LLM output when the steady state is inspected with a Python k8s-client script."""
    # How long the generated script should keep polling; the description caps it at MAX_DURATION.
    duration: str = LLMField(description=f"Duration of the status check every second in a for loop. Set appropriate duration to check the steady state (i.e., normal behavior) of the system. The maximum duration is {MAX_DURATION}.")
    # Source code of the generated Python inspection script (raw code, no surrounding code fence).
    python: str = LLMField(description="Python code with k8s client libraries to inspect the status of k8s resources (i.e., steady state). Write only the content of the code without enclosing it in a code block. Implement a for loop that checks the status every second for the duration, and print a summary of the results at the end.\n- To support docker env, please configure the client as follows: ```\n# Load Kubernetes configuration based on the environment\n    if os.getenv('KUBERNETES_SERVICE_HOST'):\n        config.load_incluster_config()\n    else:\n        config.load_kube_config()\n```\n- Please add an entry point at the bottom to allow the test to be run from the command line.\n- Please add argparse '--duration' (type=int) so that users can specify the loop duration.")

class K6JS(LLMBaseModel):
    """Schema of the LLM output when the steady state is inspected with a k6 load-test script."""
    # Number of k6 virtual users the generated script should run with.
    vus: int = LLMField(description="The number of virtual users. You can run a load test with the number of virtual users.")
    # Length of the load test; the description caps it at MAX_DURATION.
    duration: str = LLMField(description=f"Duration of the load test. Set appropriate duration to check the steady state (i.e., normal behavior) of the system. The maximum duration is {MAX_DURATION}.")
    # Source code of the generated k6 JavaScript (raw code, no surrounding code fence).
    js: str = LLMField(description=f"k6 javascript to inspect the steady state. In options in the javascript, set the same 'vus' and 'duration' options as the above. The interval of status check must be {INTERVAL_SEC}. Write only the content of the code without enclosing it in a code block.")

class _Inspection(LLMBaseModel):
    """Top-level schema the inspection agents ask the LLM to fill in."""
    # Free-text rationale for the tool choice; shown in the UI while streaming.
    thought: str = LLMField(description="Describe your thoughts for the tool usage. e.g., the reason why you choose the tool and how to use.")
    # Discriminator telling which of the two tool schemas `tool` holds.
    tool_type: Literal["k8s", "k6"] = LLMField(description="Tool to inspect the steady state. Select from ['k8s', 'k6'].")
    # Tool-specific payload; the description must name the `tool_type` field exactly
    # so the LLM can relate the two fields.
    tool: K8sAPI | K6JS = LLMField(description="If tool_type='k8s', write here K8sAPI. If tool_type='k6', write here K6JS.")


class InspectionCMDAgent:
    """Ask an LLM to write a steady-state inspection script and repair it until it runs.

    For every steady state, a script (Python k8s client or k6 JavaScript) is
    generated, written under the working directory, and executed with
    ``run_pod``. While the run returns a non-zero code, the console log is fed
    back to the LLM and a corrected script is requested.
    """

    def __init__(
        self,
        llm: LLM,
        namespace: str = "chaos-eater"
    ) -> None:
        """Store the LLM and namespace, and build the first-attempt agent.

        Args:
            llm: Model used both for the first attempt and for debugging agents.
            namespace: Namespace passed to ``run_pod`` when a script is executed.
        """
        self.llm = llm
        self.namespace = namespace
        self.agent = build_json_agent(
            llm=llm,
            chat_messages=[("system", SYS_DEFINE_CMD), ("human", USER_DEFINE_CMD)],
            pydantic_object=_Inspection,
            is_async=False
        )

    def write_cmd(
        self,
        input_data: ProcessedData,
        steady_state_names: List[Dict[str, str]],
        work_dir: str,
        max_mod_loop: int = 3
    ) -> Tuple[LLMLog, List[Inspection]]:
        """Generate and validate one inspection per steady state.

        Args:
            input_data: Preprocessed user input forwarded to the prompts.
            steady_state_names: Steady states to inspect; each entry is read
                through its ``"name"`` and ``"reason"`` keys.
            work_dir: Directory the generated scripts are written to.
            max_mod_loop: Maximum number of LLM repair rounds per steady state.

        Returns:
            The log collected by the logging callback and the list of
            inspections whose scripts ran with return code 0.

        Raises:
            AssertionError: If a script still fails after ``max_mod_loop``
                repair rounds.
        """
        # The callback is (re)created per call; generate_inspection and
        # debug_inspection read it from ``self.logger``.
        self.logger = LoggingCallback(name="tool_command_writing", llm=self.llm)
        results = []
        for idx, steady_state in enumerate(steady_state_names):
            output_history = []
            error_history = []

            #---------------
            # first attempt
            #---------------
            raw_output, inspection = self.generate_inspection(
                idx,
                input_data,
                steady_state,
                work_dir
            )
            output_history.append(raw_output)

            #-----------------------------------------
            # validate loop for the inspection script
            #-----------------------------------------
            mod_count = 0
            while True:
                # run the inspection script
                returncode, console_log = run_pod(
                    inspection,
                    work_dir,
                    self.namespace
                )
                inspection.result = console_log
                st.session_state.steady_states[idx]["value"]["spinner"].end()
                st.session_state.steady_states[idx]["value"]["empty"].code(console_log, language="powershell")

                # validation
                if returncode == 0:
                    break
                error_history.append(console_log)
                print(console_log)

                # The budget is checked only after a failed run and before asking
                # for another fix, so every generated fix is actually executed and
                # no LLM call is spent on a script that would never run.
                # An explicit raise (not ``assert``) keeps the limit under ``python -O``.
                if mod_count >= max_mod_loop:
                    raise AssertionError(f"MAX_MOD_COUNT_EXCEEDED: {max_mod_loop}")

                # modify the inspections
                raw_output, inspection = self.debug_inspection(
                    idx,
                    mod_count,
                    input_data,
                    steady_state,
                    output_history,
                    error_history,
                    work_dir
                )
                output_history.append(raw_output)
                mod_count += 1
            results.append(inspection)
        return self.logger.log, results

    def generate_inspection(
        self,
        idx: int,
        input_data: ProcessedData,
        steady_state: Dict[str, str],
        work_dir: str
    ) -> Tuple[dict, Inspection]:
        """Produce the first inspection script for one steady state.

        Args:
            idx: Index of the steady state in ``st.session_state.steady_states``.
            input_data: Preprocessed user input forwarded to the prompt.
            steady_state: Steady state read through its ``"name"`` and ``"reason"`` keys.
            work_dir: Directory the script file is written to.

        Returns:
            The last streamed LLM output and the inspection built from it.

        Raises:
            TypeError: If the LLM did not select ``'k8s'`` or ``'k6'``.
            ValueError: If a tool was selected but no script was produced.
        """
        raw_output, tool_type, code, duration = self._stream_inspection(
            self.agent,
            idx,
            input_data,
            steady_state
        )
        inspection = self._save_inspection(
            tool_type,
            code,
            duration,
            steady_state["name"],
            work_dir
        )
        return raw_output, inspection

    def debug_inspection(
        self,
        idx: int,
        mod_count: int,
        input_data: ProcessedData,
        steady_state: Dict[str, str],
        output_history: List[dict],
        error_history: List[str],
        work_dir: str
    ) -> Tuple[dict, Inspection]:
        """Ask the LLM to fix a failing inspection script.

        Args:
            idx: Index of the steady state in ``st.session_state.steady_states``.
            mod_count: Zero-based repair round, used in the script file name.
            input_data: Preprocessed user input forwarded to the prompt.
            steady_state: Steady state read through its ``"name"`` and ``"reason"`` keys.
            output_history: Previous LLM outputs, oldest first.
            error_history: Console logs of the failed runs, paired by position
                with ``output_history``.
            work_dir: Directory the script file is written to.

        Returns:
            The last streamed LLM output and the inspection built from it.

        Raises:
            TypeError: If the LLM did not select ``'k8s'`` or ``'k6'``.
            ValueError: If a tool was selected but no script was produced.
        """
        #------------------------------------------
        # update chat messages & build a new agent
        #------------------------------------------
        chat_messages = [("system", SYS_DEFINE_CMD), ("human", USER_DEFINE_CMD)]
        for output, error in zip(output_history, error_history):
            chat_messages.append(("ai", dict_to_str(output)))
            # Braces in the log are doubled so they are not parsed as template
            # variables; str.replace leaves {format_instructions} untouched.
            escaped_error = error.replace('{', '{{').replace('}', '}}')
            chat_messages.append(("human", USER_REWRITE_INSPECTION.replace("{error_message}", escaped_error)))

        debugging_agent = build_json_agent(
            llm=self.llm,
            chat_messages=chat_messages,
            pydantic_object=_Inspection,
            is_async=False
        )

        #-----------------------------
        # debug the inspection script
        #-----------------------------
        raw_output, tool_type, code, duration = self._stream_inspection(
            debugging_agent,
            idx,
            input_data,
            steady_state
        )
        inspection = self._save_inspection(
            tool_type,
            code,
            duration,
            steady_state["name"],
            work_dir,
            suffix=f"_mod{mod_count}"
        )
        return raw_output, inspection

    def _stream_inspection(
        self,
        agent,
        idx: int,
        input_data: ProcessedData,
        steady_state: Dict[str, str]
    ) -> Tuple[dict | None, str | None, str | None, str | None]:
        """Stream the agent's output into the UI and collect the final script fields.

        Returns ``(last_chunk, tool_type, code, duration)``; any element is
        ``None`` when the stream never provided it.
        """
        # Pre-bound so an empty or truncated stream surfaces as a clear error in
        # _save_inspection instead of an UnboundLocalError.
        last_chunk = None
        tool_type = None
        code = None
        duration = None

        # NOTE(review): `to_k8s_overview_str` is read as an attribute; if it is a
        # plain method on ProcessedData rather than a property it must be called -- confirm.
        for chunk in agent.stream({
            "user_input": input_data.to_k8s_overview_str,
            "ce_instructions": input_data.ce_instructions,
            "steady_state_name": steady_state["name"],
            "steady_state_thought": steady_state["reason"]},
            {"callbacks": [self.logger]}
        ):
            last_chunk = chunk
            widgets = st.session_state.steady_states[idx]["cmd"]

            if (thought := chunk.get("thought")) is not None:
                widgets["tool"]["spinner"].end()
                widgets["tool"]["empty"].write(f"{thought}")

            if (tool := chunk.get("tool")) is None:
                continue
            widgets["cmd"]["spinner"].end()

            if (tool_type := chunk.get("tool_type")) is None:
                continue
            if tool_type == "k8s":
                if (code := tool.get("python")) is not None:
                    duration = tool.get("duration")
                    widgets["tool"]["empty"].write(f"tool: ```{tool_type}``` duration: ```{duration}```  \n{thought}")
                    widgets["cmd"]["empty"].code(code, language="python")
            elif tool_type == "k6":
                if (code := tool.get("js")) is not None:
                    vus = tool.get("vus")
                    duration = tool.get("duration")
                    widgets["tool"]["empty"].write(f"tool: ```{tool_type}``` vus: ```{vus}``` duration: ```{duration}```  \n{thought}")
                    widgets["cmd"]["empty"].code(code, language="javascript")
        return last_chunk, tool_type, code, duration

    @staticmethod
    def _save_inspection(
        tool_type: str | None,
        code: str | None,
        duration: str | None,
        steady_state_name: str,
        work_dir: str,
        suffix: str = ""
    ) -> Inspection:
        """Write the script to ``work_dir`` and wrap it in an ``Inspection``.

        The file is named ``<tool_type>_<sanitized steady-state name><suffix><ext>``.
        """
        if tool_type == "k8s":
            extension = ".py"
        elif tool_type == "k6":
            extension = ".js"
        else:
            raise TypeError(f"Invalid tool type selected: {tool_type}. Select either 'k8s' or 'k6'.")
        if code is None:
            raise ValueError(f"No script was generated for tool type '{tool_type}'.")

        fname = f"{tool_type}_" + sanitize_filename(steady_state_name) + f"{suffix}{extension}"
        fpath = f"{work_dir}/{fname}"
        write_file(fpath, code)
        return Inspection(
            tool_type=tool_type,
            duration=duration,
            script=File(
                path=fpath,
                content=code,
                work_dir=work_dir,
                fname=fname
            )
        )