import os
from typing import List, Tuple

import streamlit as st

from .draft_steady_states import SteadyStateDraftAgent
from .refine_steady_states import SteadyStateRefiner, SteadyStates
from ....preprocessing.preprocessor import ProcessedData
from ....utils.functions import pseudo_streaming_text
from ....utils.wrappers  import LLM
from ....utils.llms import LLMLog


class SteadyStateAgent:
    """Coordinate drafting and refinement of steady states.

    Drafting of candidate steady-state names is delegated to a
    ``SteadyStateDraftAgent``; turning those candidates into final
    definitions is delegated to a ``SteadyStateRefiner``.
    """

    def __init__(
        self,
        llm: LLM,
        test_dir: str = "sandbox/unit_test",
        namespace: str = "chaos-eater",
        max_mod_loop: int = 5
    ) -> None:
        """Store configuration and build the two sub-agents.

        Args:
            llm: Model wrapper shared by both sub-agents.
            test_dir: Stored on the instance; not read by the methods
                visible in this class.
            namespace: Stored on the instance and forwarded to the refiner.
            max_mod_loop: Stored on the instance; not read by the methods
                visible in this class.
        """
        # plain configuration
        self.llm = llm
        self.namespace = namespace
        self.test_dir = test_dir
        self.max_mod_loop = max_mod_loop

        # collaborating sub-agents (drafting first, then refinement)
        self.draft_agent = SteadyStateDraftAgent(llm)
        self.refiner = SteadyStateRefiner(llm, namespace)

    def define_steady_states(
        self,
        data: ProcessedData,
        work_dir: str
    ) -> Tuple[List[LLMLog], SteadyStates]:
        """Draft candidate steady states for ``data`` and then refine them.

        Side effects: creates a Streamlit placeholder for progress text,
        resets ``st.session_state.steady_states`` to an empty list, and
        creates the directory ``f"{work_dir}/steady_states"`` if missing.

        Args:
            data: Preprocessed input handed to both sub-agents.
            work_dir: Base working directory; the refiner works inside its
                ``steady_states`` subdirectory.

        Returns:
            ``(logs, steady_states)``, where ``logs`` is the drafting log
            followed by every log returned by the refiner, and
            ``steady_states`` is whatever the refiner produced.
        """
        # --- setup: progress placeholder, session reset, output directory ---
        status_slot = st.empty()
        st.session_state.steady_states = []
        output_dir = f"{work_dir}/steady_states"
        os.makedirs(output_dir, exist_ok=True)

        # --- phase 1: ask the draft agent for candidate steady-state names ---
        pseudo_streaming_text("##### Drafting steady states...", obj=status_slot)
        draft_log, candidate_names = self.draft_agent.draft_steady_states(data)

        # --- phase 2: let the refiner turn each candidate into a definition ---
        pseudo_streaming_text("##### Refining each steady state...", obj=status_slot)
        refine_logs, steady_states = self.refiner.refine(
            data=data,
            steady_state_names=candidate_names,
            work_dir=output_dir,
        )

        # drafting log first, refiner logs after, in their original order
        return [draft_log, *refine_logs], steady_states