import asyncio
from typing import List, Dict, Tuple

import streamlit as st

from .write_cmd import Inspection
from ....preprocessing.preprocessor import ProcessedData
from ....utils.wrappers import LLM, LLMBaseModel, LLMField
from ....utils.llms import build_json_agent, LLMLog, LoggingCallback


# System prompt for the threshold-definition agent.
# `{format_instructions}` is a template placeholder that is presumably filled with
# the output-format/JSON-schema instructions for `Threshold` by build_json_agent
# (TODO confirm); keep the braces intact.
SYS_DEFINE_THRESHOLD = """\
You are a helpful AI assistant for Chaos Engineering.
Given k8s manifests that define a network system, its steady state, and the current state of the steady state, you will define the threshold for the steady state.
Always keep the following rules:
- The threshold should at least be satisfied in the current state.
- The threshold must be a representative value (e.g., ratio, percentage, etc.), not a fixed absolute value.
- The threshold must have a reasonable tolerance.
- You may set the threshold with a certain tolerance.
- NEVER output any sentences but the JSON format.
- {format_instructions}"""

# Human prompt template. Its placeholders (user_input, steady_state_name,
# steady_state_thought, inspection_summary, ce_instructions) are supplied per
# steady state by ThresholdAgent.define_thresholds.
USER_DEFINE_THRESHOLD = """\
# Here is the overview of my system:
{user_input}

# You will determine a reasonable threshold for the following steady state of my system:
{steady_state_name}: {steady_state_thought}

{inspection_summary}

# Please follow the instructions below regarding Chaos Engineering:
{ce_instructions}

Now, please define a reasonable threshold for the steady state according to the above information."""


# Structured output that the threshold agent asks the LLM to produce.
# The field descriptions are runtime text, presumably rendered into the prompt via
# `{format_instructions}` (TODO confirm in build_json_agent), so keep them precise.
# NOTE(review): documented with comments rather than a class docstring in case
# LLMBaseModel is pydantic-based, where a docstring would also enter the schema.
class Threshold(LLMBaseModel):
    # The proposed threshold for the steady state, as free text.
    threshold: str = LLMField(description="the threshold of the steady state, which should be satisfied in the current state.")
    # The LLM's justification for the chosen threshold.
    reason_for_threshold: str = LLMField(description="reason for setting the threshold")


class ThresholdAgent:
    """Use an LLM to propose a threshold for each steady state and stream it into the UI."""

    def __init__(self, llm: LLM) -> None:
        """Build the synchronous JSON agent that answers with a ``Threshold`` object.

        Args:
            llm: Language model wrapper; also handed to the logging callback.
        """
        self.llm = llm
        self.agent = build_json_agent(
            llm=llm,
            chat_messages=[("system", SYS_DEFINE_THRESHOLD), ("human", USER_DEFINE_THRESHOLD)],
            pydantic_object=Threshold,
            is_async=False
        )

    def define_thresholds(
        self,
        input_data: ProcessedData,
        steady_state_names: List[Dict[str, str]],
        inspections: List[Inspection]
    ) -> Tuple[LLMLog, List[Dict[str, str]]]:
        """Define a threshold for every (steady state, inspection) pair.

        Steady states and inspections are paired positionally via ``zip``, so any
        surplus items in the longer list are ignored.

        Args:
            input_data: Preprocessed input; supplies the k8s overview and the
                Chaos Engineering instructions used in the prompt.
            steady_state_names: One dict per steady state, read for the keys
                ``"name"`` and ``"reason"``. Each dict is updated in place with
                ``"threshold"`` and ``"threshold_reason"`` as values stream in.
            inspections: Inspection results aligned index-by-index with
                ``steady_state_names``.

        Returns:
            The LLM call log and one ``{"threshold": ..., "reason": ...}`` dict per
            processed steady state. A value is ``None`` if the agent never
            produced that field.
        """
        logger = LoggingCallback(name="threshold_definition", llm=self.llm)
        # These prompt inputs do not depend on the steady state, so build them once.
        user_input = input_data.to_k8s_overview_str()
        ce_instructions = input_data.ce_instructions
        results = []
        for idx, (steady_state, inspection) in enumerate(zip(steady_state_names, inspections)):
            # Reset per steady state so values never leak from the previous iteration
            # and the names are always bound when the result is recorded.
            threshold = None
            reason = None
            for token in self.agent.stream({
                "user_input": user_input,
                "ce_instructions": ce_instructions,
                "steady_state_name": steady_state["name"],
                "steady_state_thought": steady_state["reason"],
                "inspection_summary": inspection.to_str()},
                {"callbacks": [logger]}
            ):
                # Only overwrite with non-None values, keeping the latest known value
                # in case a streamed chunk lacks a field.
                if (new_threshold := token.get("threshold")) is not None:
                    threshold = new_threshold
                    steady_state["threshold"] = threshold
                    st.session_state.steady_states[idx]["threshold"]["spinner"].end()
                    st.session_state.steady_states[idx]["threshold"]["empty"].write(threshold)
                if (new_reason := token.get("reason_for_threshold")) is not None:
                    reason = new_reason
                    steady_state["threshold_reason"] = reason
                    # Avoid `threshold + ": " + reason`, which raises TypeError when the
                    # threshold is still missing or is not a str.
                    label = reason if threshold is None else f"{threshold}: {reason}"
                    st.session_state.steady_states[idx]["threshold"]["empty"].write(label)
            results.append({"threshold": threshold, "reason": reason})
        return logger.log, results