import json
from typing import List, Dict, Tuple, Literal

import streamlit as st

from ....ce_tools.ce_tool_base import CEToolBase
from ....hypothesis.llm_agents.steady_states.steady_state_agent import SteadyStates
from ....utils.wrappers import LLM, LLMBaseModel, LLMField
from ....utils.llms import build_json_agent, LLMLog, LoggingCallback
from ....utils.app_utils import st_spinner


# System prompt for FaultScenarioAgent. Template placeholders:
#   {ce_tool_name}, {ce_tool_fault_types}: filled from the CE tool
#     (ce_tool.name / ce_tool.get_chaos_var_candidates()) in assume_scenarios.
#   {format_instructions}: NOT passed by assume_scenarios; presumably filled
#     by build_json_agent from the FaultScenarios schema -- verify in utils.llms.
# This text is sent to the LLM verbatim, so edits change model behavior.
SYS_ASSUME_FAULT_SCENARIOS = """\
You are a helpful AI assistant for Chaos Engineering. 
Given k8s manifests (i.e., system), their steady states, and user's instructions for the Chaos Engineering, please define the most efficient fault injections to break the system.
Always keep the following rules:
- Assume a real-world event that may occur in the system. For example, promotion campaign, cyber attacks, disasters, etc.
- If the number of scenarios is not specified, please propose only one.
- The injected faults should be selected from the fault types of {ce_tool_name}:
{ce_tool_fault_types}
- Design the strongest fault injections with the intention of breaking the given system. This is to build a resilient system in advance.
- {format_instructions}"""

# Human-turn prompt template. {user_input}, {steady_states} (rendered via
# SteadyStates.to_overview_str()) and {ce_instructions} are filled in
# FaultScenarioAgent.assume_scenarios.
USER_ASSUME_FAULT_SCENARIOS = """\
Here is the overview of my system:
{user_input}

Steady states of the network system defined by the manifests are the following:
{steady_states}

Please follow the instructions below regarding Chaos Engineering as necessary:
{ce_instructions}

Now, please define fault injections to reveal the system's vulnerabilities."""


# Output schema for one fault injection proposed by the LLM.
# NOTE: only `#` comments here -- a class docstring would likely become the
# JSON-schema "description" and thus alter the format instructions sent to
# the model (assuming LLMBaseModel is a pydantic model -- verify).
class Fault(LLMBaseModel):
    # TODO: support other CE tools
    # The Literal values and the list in the description string duplicate each
    # other; keep them in sync when adding fault kinds.
    name: Literal["PodChaos", "NetworkChaos", "DNSChaos", "HTTPChaos", "StressChaos", "IOChaos", "TimeChaos"] = LLMField(description='Select a fault type from ["PodChaos", "NetworkChaos", "DNSChaos", "HTTPChaos", "StressChaos", "IOChaos", "TimeChaos"]')
    # Sequential index distinguishing repeated faults of the same `name`.
    name_id: int = LLMField(description="An identifier to prevent name conflicts when the same Fault appears. Assign numbers starting from 0 in sequential order to prevent name conflicts.")
    # Free-form str -> str mapping; its keys are not constrained by this schema.
    scope: Dict[str, str] = LLMField(description="Specify only the fault injection scope in advance here.")

# Output schema for one fault scenario. Field names are also the keys that
# FaultScenarioAgent._extract_json_items_streaming pulls from each partial
# JSON item (via FaultScenario.__fields__), so renaming a field affects the UI.
class FaultScenario(LLMBaseModel):
    event: str = LLMField(description="Assume a real-world fault event that is challenging for the system.")
    # NOTE(review): "the whole your thought process" is ungrammatical, but this
    # text is part of the prompt, so it is left unchanged here.
    thought: str = LLMField(description="Write down the whole your thought process as follows: Describe the weak points of the given K8s manifests. Then, list ALL the fault types that are related to the weak points. Lastly, plan a fault sequence consisting of the most effective faults to break the system through the weak points.")
    # Outer list = sequential steps; inner list = faults injected simultaneously.
    faults: List[List[Fault]] = LLMField(description="Define the most effective fault injections to break the system. In the inner list, a set of simultaneously injected faults are listed, while in the outer list, the sets are listed in the injection order. For example, [[fault_a], [fault_b, fault_c]] indicates that fault_a is injected, then fault_b and fault_c are injected simultaneously.")
    # Keyed by steady-state name; value describes the expected impact.
    effects: Dict[str, str] = LLMField(description="Describe how the list of faults affects the steady states directly or indirectly. The format is dict[steady_state_name: str, details_of_the_effects: str].")

# Top-level output schema the LLM's JSON answer is parsed into.
# NOTE: only `#` comments here -- a class docstring would likely become the
# JSON-schema "description" and alter the format instructions sent to the LLM.
class FaultScenarios(LLMBaseModel):
    # `description=` (same keyword as every other LLMField in this module) so
    # the text is actually registered as the field description.
    fault_scenarios: List[FaultScenario] = LLMField(description="A list of fault scenarios")


class FaultScenarioAgent:
    """Ask an LLM for fault-injection scenarios and render them live in Streamlit.

    The agent combines ``SYS_ASSUME_FAULT_SCENARIOS`` and
    ``USER_ASSUME_FAULT_SCENARIOS`` with the ``FaultScenarios`` schema. It
    streams the parsed JSON and, as partial results arrive, fills per-scenario
    placeholders kept in ``st.session_state.fault_container``.
    """

    def __init__(self, llm: LLM, ce_tool: CEToolBase) -> None:
        """Build the streaming JSON agent.

        Args:
            llm: LLM wrapper used for generation and passed to the logging callback.
            ce_tool: Chaos Engineering tool. Its ``name`` and
                ``get_chaos_var_candidates()`` are injected into the system prompt.
        """
        self.llm = llm
        self.ce_tool = ce_tool
        self.agent = build_json_agent(
            llm=llm,
            chat_messages=[("system", SYS_ASSUME_FAULT_SCENARIOS), ("human", USER_ASSUME_FAULT_SCENARIOS)],
            pydantic_object=FaultScenarios,
            is_async=False,
            streaming_func=self._extract_json_items_streaming
        )
        # Icon shown in the left column of each row in a scenario expander.
        # The dict order is the display order of the rows (see get_fault_items).
        self.ICON: Dict[str, str] = {
            "description": "💬",
            "faults": "🐞",
            "effects": "💥",
            "settings": "⚙"
        }

    def assume_scenarios(
        self,
        user_input: str,
        ce_instructions: str,
        steady_states: SteadyStates
    ) -> Tuple[LLMLog, List[Dict[str, str]]]:
        """Stream fault scenarios from the LLM and render them as they arrive.

        Args:
            user_input: Overview of the user's system (``{user_input}`` in the prompt).
            ce_instructions: User's Chaos Engineering instructions
                (``{ce_instructions}`` in the prompt).
            steady_states: Steady states, rendered with ``to_overview_str()``.

        Returns:
            ``(log, fault_scenarios)``. ``log`` is the ``LoggingCallback``'s
            ``log``. ``fault_scenarios`` has one dict per scenario index seen in
            the stream. Each dict holds whichever of ``"event"``, ``"faults"``,
            ``"description"`` (the LLM's ``thought`` field) and ``"effects"``
            arrived with a non-None value. Note that ``"faults"`` holds the
            nested list from the LLM, not a string.

        ``st.session_state.fault_container`` must already exist as a list. It
        is extended in place with one set of placeholders per new scenario.
        """
        logger = LoggingCallback(name="fault_scenario_assumption", llm=self.llm)
        fault_scenarios: List[Dict[str, str]] = []
        prompt_inputs = {
            "user_input": user_input,
            "ce_instructions": ce_instructions,
            "steady_states": steady_states.to_overview_str(),
            "ce_tool_name": self.ce_tool.name,
            "ce_tool_fault_types": self.ce_tool.get_chaos_var_candidates(),
        }
        for token in self.agent.stream(prompt_inputs, {"callbacks": [logger]}):
            idx = token.get("idx")
            if idx is None:
                continue
            containers = st.session_state.fault_container
            # Grow the UI placeholders and the result list independently.
            # The session-state container may already hold entries (e.g. from
            # an earlier run), so its length says nothing about fault_scenarios.
            while len(containers) <= idx:
                containers.append(self.get_fault_items())
            while len(fault_scenarios) <= idx:
                fault_scenarios.append({})
            ui = containers[idx]
            scenario = fault_scenarios[idx]

            if (event := token.get("event")) is not None:
                # Replaces the element held by the st.empty placeholder,
                # presumably to refresh the expander label with the event.
                ui["event"].expander("##### " + f"⬜  Scenario#{idx+1}: {event}", expanded=True)
                scenario["event"] = event
            if (faults := token.get("faults")) is not None:
                ui["faults"]["spinner"].end()
                ui["faults"]["empty"].write(self.convert_fault_list_to_str(faults))
                scenario["faults"] = faults
            if (description := token.get("thought")) is not None:
                ui["description"]["spinner"].end()
                ui["description"]["empty"].write(description)
                scenario["description"] = description
            if (effects := token.get("effects")) is not None:
                ui["effects"]["spinner"].end()
                ui["effects"]["empty"].write(self.convert_effect_to_str(effects))
                scenario["effects"] = effects
        return logger.log, fault_scenarios

    def _extract_json_items_streaming(self, input_stream):
        """Turn streamed partial JSON objects into per-scenario tokens.

        For every dict in ``input_stream`` that has a list under
        ``"fault_scenarios"``, yield one token per dict-valued scenario. Each
        token is ``{"idx": i}`` plus every ``FaultScenario`` field name mapped
        to its (possibly partial) value, or None if not yet present. Non-dict
        items, such as incomplete fragments, are skipped. Indices are kept
        aligned with the list positions.
        """
        field_names = list(FaultScenario.__fields__.keys())
        for partial in input_stream:
            if not isinstance(partial, dict):
                continue
            scenarios = partial.get("fault_scenarios")
            if not isinstance(scenarios, list):
                continue
            for idx, scenario in enumerate(scenarios):
                if not isinstance(scenario, dict):
                    continue
                yield {"idx": idx} | {key: scenario.get(key) for key in field_names}

    def get_fault_items(self):
        """Create the Streamlit placeholders for one scenario.

        Returns:
            A dict with the following keys:

            - ``"event"``: an ``st.empty`` that holds the scenario expander.
            - one key per ``self.ICON`` entry: ``{"spinner": <st_spinner>,
              "empty": <placeholder>}``. The placeholder is ``st.container()``
              for ``"settings"`` (presumably so several elements can be written
              into it) and ``st.empty()`` otherwise.
        """
        fault_items = {}
        fault_items["event"] = st.empty()
        expander = fault_items["event"].expander("##### " + "⬜  ", expanded=True)
        with expander:
            for key, icon in self.ICON.items():
                row = st.empty()
                icon_col, body_col = row.columns([1, 20])
                with icon_col:
                    st.write(icon)
                with body_col:
                    # The spinner must be created before the placeholder so it
                    # renders above it.
                    spinner = st_spinner("Pending...")
                    body = st.container() if key == "settings" else st.empty()
                    fault_items[key] = {"spinner": spinner, "empty": body}
        return fault_items

    def convert_fault_list_to_str(self, faults: List[List[Fault]]) -> str:
        """Render a fault sequence as ``"A(scope), B  ➡  C"``.

        Despite the annotation, each fault is accessed as a mapping (the parsed
        JSON dict). Faults inside one step are joined with ``", "``, and steps
        are joined with ``"  ➡  "``. Faults that are not dicts or have no
        ``"name"`` yet (e.g. while streaming) are skipped entirely, so they
        leave no stray separators.
        """
        steps = []
        for para_faults in faults:
            labels = []
            for fault in para_faults:
                if not isinstance(fault, dict) or "name" not in fault:
                    continue
                label = str(fault["name"])
                if "scope" in fault:
                    label += f"({fault['scope']})"
                labels.append(label)
            steps.append(", ".join(labels))
        return "  ➡  ".join(steps)

    def convert_effect_to_str(self, effects: Dict[str, str]) -> str:
        """Render ``{steady_state: description}`` as Markdown lines.

        Each entry becomes ``"```name```: description"``. Entries are separated
        by a Markdown hard line break (``"  \\n"``). When there is more than one
        entry, each line is prefixed with ``"- "``.
        """
        bullet = "- " if len(effects) > 1 else ""
        return "  \n".join(
            f"{bullet}```{steady_state}```: {description}"
            for steady_state, description in effects.items()
        )