import os
from typing import List, Dict, Tuple

import streamlit as st
from langchain_core.runnables.base import Runnable

from .assume_fault_scenario import FaultScenarioAgent
from .refine_faults import FaultRefiner, FaultScenarios
from ....ce_tools.ce_tool_base import CEToolBase
from ..steady_states.steady_state_agent import SteadyStates
from ....preprocessing.preprocessor import ProcessedData
from ....utils.functions import pseudo_streaming_text
from ....utils.llms import LLMLog


class FaultAgent:
    """Orchestrate the fault-definition phase.

    Runs two LLM-backed sub-agents in sequence and reports progress to the
    Streamlit UI:

    1. ``FaultScenarioAgent.assume_scenarios`` proposes fault scenarios.
    2. ``FaultRefiner.refine_faults`` determines the faults' parameters.
    """

    def __init__(
        self,
        llm: Runnable,
        ce_tool: CEToolBase,
        test_dir: str = "sandbox/unit_test",
        namespace: str = "chaos-eater"
    ) -> None:
        """Create the agent and its two sub-agents.

        Args:
            llm: LangChain runnable shared by both sub-agents.
            ce_tool: CE tool (``CEToolBase``) shared by both sub-agents.
            test_dir: Stored as ``self.test_dir``; not referenced by this
                class's own methods.
            namespace: Stored as ``self.namespace``; not referenced by this
                class's own methods.
        """
        self.llm = llm
        self.ce_tool = ce_tool
        self.test_dir = test_dir
        self.namespace = namespace
        # agents
        self.fault_scenario_agent = FaultScenarioAgent(llm, ce_tool)
        self.refiner = FaultRefiner(llm, ce_tool)

    def convert_steady_state_to_str(self, steady_states: list) -> str:
        """Render steady states as a numbered, newline-separated text block.

        Each element must expose ``name`` and ``description`` attributes and a
        ``threshold`` mapping with ``"threshold"`` and ``"reason"`` keys.
        Attribute access is used, so plain dicts are not accepted.

        Args:
            steady_states: Iterable of steady-state objects.

        Returns:
            One three-line entry per steady state, numbered from 0, with
            entries separated by a newline. Returns an empty string when
            ``steady_states`` is empty.
        """
        # Join with "\n" so one entry's threshold reason does not run
        # directly into the next entry's "Steady state #" header.
        return "\n".join(
            f"Steady state #{i}: {steady_state.name}\n"
            f"Description: {steady_state.description}\n"
            f"Threshold: {steady_state.threshold['threshold']}; {steady_state.threshold['reason']}"
            for i, steady_state in enumerate(steady_states)
        )

    def define_faults(
        self,
        data: ProcessedData,
        steady_states: SteadyStates,
        work_dir: str
    ) -> Tuple[List[LLMLog], FaultScenarios]:
        """Assume fault scenarios, then refine them into parameterized faults.

        Args:
            data: Preprocessed input. Its K8s overview string and
                ``ce_instructions`` are passed to both sub-agents.
            steady_states: Steady states passed to both sub-agents.
            work_dir: Working directory. ``<work_dir>/faults`` is created if
                missing and handed to the refiner as its ``work_dir``.

        Returns:
            ``(logs, faults)``. ``logs`` is a list holding the scenario
            agent's log followed by the refiner's log. ``faults`` is the
            refiner's result.
        """
        #-------------------
        # 0. initialization
        #-------------------
        fault_msg = st.empty()  # single UI slot reused for progress messages
        # Reset the per-run container list in session state.
        # NOTE(review): presumably filled by the sub-agents; confirm.
        st.session_state.fault_container = []
        fault_dir = os.path.join(work_dir, "faults")
        os.makedirs(fault_dir, exist_ok=True)
        logs: List[LLMLog] = []
        # Both sub-agents receive the same overview, so compute it once.
        k8s_overview = data.to_k8s_overview_str()

        #----------------------------
        # 1. assume fault scenarios
        #----------------------------
        pseudo_streaming_text("##### Assuming fault scenarios...", obj=fault_msg)
        scenario_log, fault_scenarios = self.fault_scenario_agent.assume_scenarios(
            user_input=k8s_overview,
            ce_instructions=data.ce_instructions,
            steady_states=steady_states
        )
        logs.append(scenario_log)

        #------------------------------------------------
        # 2. refine the faults: determine the parameters
        #------------------------------------------------
        pseudo_streaming_text("##### Refining faults...", obj=fault_msg)
        fault_log, faults = self.refiner.refine_faults(
            user_input=k8s_overview,
            ce_instructions=data.ce_instructions,
            steady_states=steady_states,
            fault_scenarios=fault_scenarios,
            work_dir=fault_dir
        )
        logs.append(fault_log)

        pseudo_streaming_text("##### Completed defining faults!", obj=fault_msg)
        return logs, faults