import os
from typing import List, Tuple

from .llm_agents.steady_states.steady_state_agent import SteadyStateAgent, SteadyStates
from .llm_agents.faults.fault_agent import FaultAgent, FaultScenarios
from ..utils.wrappers import LLM, BaseModel
from ..utils.functions import save_json, recursive_to_dict
from ..utils.llms import LLMLog
from ..ce_tools.ce_tool_base import CEToolBase
from ..preprocessing.preprocessor import ProcessedData


# Text template used by ``Hypothesis.to_str`` to describe the hypothesis.
# Placeholders: ``steady_state_overview`` and ``fault_scenario_overview``.
HYPOTHESIS_OVERVIEW_TEMPLATE = """\
The hypothesis is "The steady states of the system are maintained even when the fault scenario occurs (i.e., when the faults are injected)".
The steady states here are as follows:
{steady_state_overview}

The fault scenario here is as follows:
{fault_scenario_overview}"""


class Hypothesis(BaseModel):
    """Pair of steady states and fault scenarios forming one hypothesis.

    The hypothesis asserts that the steady states are maintained while
    the fault scenario is injected (see ``HYPOTHESIS_OVERVIEW_TEMPLATE``).
    """

    # steady states the hypothesis claims are maintained
    steady_states: SteadyStates
    # fault scenarios; only the first entry of ``elems`` is rendered by ``to_str``
    faults: FaultScenarios

    def to_str(self) -> str:
        """Render this hypothesis as text using ``HYPOTHESIS_OVERVIEW_TEMPLATE``.

        Only ``self.faults.elems[0]`` is described; indexing it raises
        ``IndexError`` if ``elems`` is empty.
        """
        steady_text = self.steady_states.to_str()
        leading_fault = self.faults.elems[0]
        fault_text = leading_fault.to_str()
        return HYPOTHESIS_OVERVIEW_TEMPLATE.format(
            fault_scenario_overview=fault_text,
            steady_state_overview=steady_text,
        )


class Hypothesizer:
    """Build a ``Hypothesis`` by defining steady states and then faults.

    Steady-state definition is delegated to ``SteadyStateAgent`` and fault
    definition to ``FaultAgent``. Their results and LLM logs are persisted
    as JSON under ``{work_dir}/hypothesis``.
    """

    def __init__(
        self,
        llm: LLM,
        ce_tool: CEToolBase,
        test_dir: str = "sandbox/unit_test",
        namespace: str = "chaos-eater",
        max_mod_loop: int = 3
    ) -> None:
        """
        Args:
            llm: LLM wrapper shared by both agents.
            ce_tool: tool handed to ``FaultAgent``.
            test_dir: directory passed through to both agents.
            namespace: passed through to both agents (presumably a
                Kubernetes namespace — TODO confirm).
            max_mod_loop: passed to ``SteadyStateAgent`` only; presumably
                bounds its modification loop — TODO confirm.
        """
        self.llm = llm
        self.ce_tool = ce_tool
        # params
        self.test_dir = test_dir
        self.namespace = namespace
        self.max_mod_loop = max_mod_loop
        # agents
        self.steady_state_agent = SteadyStateAgent(llm, test_dir, namespace, max_mod_loop)
        self.fault_agent = FaultAgent(llm, ce_tool, test_dir, namespace)

    def hypothesize(
        self,
        data: ProcessedData,
        work_dir: str
    ) -> Tuple[List[LLMLog], Hypothesis]:
        """Define steady states and faults, then combine them into a hypothesis.

        Files written under ``{work_dir}/hypothesis`` (created if missing):
        ``steady_states.json``, ``steady_states_log.json``, ``faults.json``,
        ``faults_log.json``, ``hypothesis.json`` and ``hypothesis_log.json``.

        Args:
            data: preprocessed input handed to both agents.
            work_dir: base directory for the artifacts.

        Returns:
            A tuple ``(logs, hypothesis)``. ``logs`` holds the steady-state
            agent's logs followed by the fault agent's logs.
        """
        #----------------
        # initialization
        #----------------
        hypothesis_dir = f"{work_dir}/hypothesis"
        os.makedirs(hypothesis_dir, exist_ok=True)
        logs: List[LLMLog] = []

        #-------------------------
        # 1. define steady states
        #-------------------------
        steady_state_logs, steady_states = self.steady_state_agent.define_steady_states(data=data, work_dir=hypothesis_dir)
        logs += steady_state_logs
        save_json(f"{hypothesis_dir}/steady_states.json", steady_states.dict())
        # Logs go to their own file so they do not overwrite the steady-state
        # definition saved just above (mirrors faults.json / faults_log.json).
        save_json(f"{hypothesis_dir}/steady_states_log.json", recursive_to_dict(steady_state_logs))

        #------------------
        # 2. define faults
        #------------------
        # the fault agent is conditioned on the steady states defined in step 1
        fault_logs, faults = self.fault_agent.define_faults(data=data, steady_states=steady_states, work_dir=hypothesis_dir)
        logs += fault_logs
        save_json(f"{hypothesis_dir}/faults.json", faults.dict())
        save_json(f"{hypothesis_dir}/faults_log.json", recursive_to_dict(fault_logs))

        #-------------------
        # make a hypothesis
        #-------------------
        hypothesis = Hypothesis(steady_states=steady_states, faults=faults)
        save_json(f"{hypothesis_dir}/hypothesis.json", hypothesis.dict())
        save_json(f"{hypothesis_dir}/hypothesis_log.json", recursive_to_dict(logs))
        return logs, hypothesis