from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union

from tp_lodge.task_planning.models.sas.sas_action import SasAction
from tp_lodge.task_planning.models.sas.sas_plan import SasPlan


@dataclass
class PlanActionResult:
    """Common base for the results attached to individual plan actions.

    Carries no fields of its own; subclasses override :meth:`to_string`
    to render themselves as text.
    """

    def to_string(self) -> str:
        """Render this result as text; subclasses must override this."""
        raise NotImplementedError


# Refinement attached to a plan step: either a leaf (a single skill, stored as
# `py_function`) or a composite (a nested sequence of sub-actions). String
# forward references are used because both classes are defined further below.
PlanAnyActionResult = Union["PlanLeafActionResult", "PlanCompositeActionResult"]


@dataclass
class PlanLeafActionResult(PlanActionResult):
    """Terminal plan result that resolves an action to a single skill.

    Attributes:
        py_function: Text identifying the skill; returned verbatim by to_string().
    """

    py_function: str

    def to_string(self):
        """Return the stored ``py_function`` text unchanged."""
        rendered = self.py_function
        return rendered


@dataclass
class PlanCompositeActionResult(PlanActionResult):
    """Ordered list of SAS actions, each optionally refined into a nested result.

    Each entry of ``sub_actions`` pairs a ``SasAction`` with its refinement:
    ``None`` while the action is still unspecified, a ``PlanLeafActionResult``
    once it maps to a single skill, or another ``PlanCompositeActionResult``
    when it is decomposed into further sub-actions.
    """

    sub_actions: List[Tuple[SasAction, Optional[PlanAnyActionResult]]] = field(default_factory=list)

    @staticmethod
    def from_plan(sas_plan: SasPlan) -> "PlanCompositeActionResult":
        """Build a composite whose sub-actions are ``sas_plan.actions``, all unspecified."""
        result = PlanCompositeActionResult()
        for action in sas_plan.actions:
            result.add_action(action)
        return result

    def get_flattened_skills(self) -> List[str]:
        """Return every leaf's ``py_function``, depth-first in sub-action order.

        Unspecified (``None``) sub-actions are skipped.

        Raises:
            NotImplementedError: if a sub-action result is of an unknown type.
        """
        skills: List[str] = []
        for _, result in self.sub_actions:
            if result is None:
                continue
            if isinstance(result, PlanLeafActionResult):
                skills.append(result.py_function)
            elif isinstance(result, PlanCompositeActionResult):
                skills.extend(result.get_flattened_skills())
            else:
                raise NotImplementedError()
        return skills

    def inject_composite(self, composite: "PlanCompositeActionResult"):
        """Attach ``composite`` as the refinement of the deepest trailing unspecified action.

        The target is the composite returned by :meth:`get_last_composite`; its
        final sub-action (currently ``None``) is replaced by ``composite``.

        Raises:
            IndexError: if the chain of last sub-actions reaches an empty composite.
            TypeError: if the chain ends in an already-specified non-composite result.
        """
        target = self.get_last_composite()
        last_action, _ = target.sub_actions[-1]
        target.sub_actions[-1] = (last_action, composite)

    def add_action(self, sas_action: SasAction):
        """Append ``sas_action`` as a new, unspecified sub-action."""
        self.sub_actions.append((sas_action, None))

    def _index_of(self, sas_action: SasAction) -> int:
        """Return the index of the first sub-action equal (``==``) to ``sas_action``.

        Raises:
            ValueError: if no sub-action matches.
        """
        for i, (action, _) in enumerate(self.sub_actions):
            if action == sas_action:
                return i
        raise ValueError(f"Action {sas_action} not found in sub_actions.")

    def unspecify_action(self, sas_action: SasAction):
        """Reset the refinement of the first sub-action equal to ``sas_action`` to ``None``.

        Raises:
            ValueError: if no sub-action matches.
        """
        i = self._index_of(sas_action)
        # Keep the stored action object (equal to, but possibly not identical to, the argument).
        self.sub_actions[i] = (self.sub_actions[i][0], None)

    def specify_action(self, sas_action: SasAction, result: PlanAnyActionResult):
        """Set the refinement of the first sub-action equal to ``sas_action`` to ``result``.

        Raises:
            ValueError: if no sub-action matches.
        """
        i = self._index_of(sas_action)
        self.sub_actions[i] = (self.sub_actions[i][0], result)

    def to_string(self):
        """Render the sub-actions as an indented bullet list.

        Each action becomes a ``- <action>`` line. A leaf result is appended after
        a colon. A nested composite is rendered recursively beneath its action,
        indented by two spaces; an empty nested composite contributes no lines.
        """
        lines = []
        for sas_action, sub_action in self.sub_actions:
            title = "- %s" % sas_action.to_string()
            if isinstance(sub_action, PlanLeafActionResult):
                title += ": %s" % sub_action.to_string()
            lines.append(title)
            if isinstance(sub_action, PlanCompositeActionResult):
                nested_text = sub_action.to_string()
                # Skip empty nested output; joining "" would add a blank line.
                if nested_text:
                    lines.append("\n".join("  " + t for t in nested_text.splitlines()))
        return "\n".join(lines)

    def get_last_composite(self) -> "PlanCompositeActionResult":
        """Follow the last sub-action down to the composite whose last entry is unspecified.

        Starting from ``self``, repeatedly descend into the last sub-action's
        composite result until one whose last sub-action is ``None`` is found.

        Raises:
            IndexError: if a composite on the path has no sub-actions.
            TypeError: if the last sub-action on the path is already specified
                with a non-composite (e.g. leaf) result.
        """
        current = self
        while True:
            if not current.sub_actions:
                raise IndexError("Cannot descend into a PlanCompositeActionResult with no sub_actions.")
            last_action, last_result = current.sub_actions[-1]
            if last_result is None:
                return current
            if not isinstance(last_result, PlanCompositeActionResult):
                raise TypeError(
                    f"Last sub-action {last_action} is already resolved to "
                    f"{type(last_result).__name__}; there is no unspecified slot to descend into."
                )
            current = last_result

    def get_mapping_from_action_to_skill(self) -> Dict[str, str]:
        """Map each leaf action's ``name`` to its ``py_function``, recursing into composites.

        Keys are ``sas_action.name`` (presumably a str — confirm against SasAction).
        Composite actions themselves get no entry; only their leaves do. When two
        leaves share a name, the later one in depth-first order wins.

        Raises:
            NotImplementedError: if a sub-action result is of an unknown type.
        """
        mapping: Dict[str, str] = {}
        for sas_action, sub_action in self.sub_actions:
            if isinstance(sub_action, PlanLeafActionResult):
                mapping[sas_action.name] = sub_action.py_function
            elif isinstance(sub_action, PlanCompositeActionResult):
                mapping.update(sub_action.get_mapping_from_action_to_skill())
            elif sub_action is None:
                continue
            else:
                raise NotImplementedError()
        return mapping


@dataclass
class PlanResult:
    """Root of a hierarchical plan result.

    ``composite`` is ``None`` until the first call to :meth:`inject_composite`;
    every other accessor requires it to be set.
    """

    composite: Optional[PlanCompositeActionResult] = None

    def _require_composite(self) -> PlanCompositeActionResult:
        """Return ``self.composite``, raising if no composite has been injected yet.

        Raises:
            AssertionError: if ``composite`` is ``None``. Raised explicitly rather
                than via ``assert`` so the check survives ``python -O``, while
                keeping the exception type of the original ``assert`` statements.
        """
        if self.composite is None:
            raise AssertionError("PlanResult has no composite yet; call inject_composite() first.")
        return self.composite

    def inject_composite(self, composite: PlanCompositeActionResult):
        """Set ``composite`` as the root, or delegate injection into the existing root."""
        if self.composite is None:
            self.composite = composite
        else:
            self.composite.inject_composite(composite=composite)

    def get_last_composite(self) -> PlanCompositeActionResult:
        """Return the root's deepest composite whose last sub-action is unspecified."""
        return self._require_composite().get_last_composite()

    def get_flattened_skills(self) -> List[str]:
        """Return the root's leaf ``py_function`` values in depth-first order."""
        return self._require_composite().get_flattened_skills()

    def get_mapping_from_action_to_skill(self) -> Dict[str, str]:
        """Return the root's mapping from leaf action ``name`` to ``py_function``.

        NOTE(review): keys come from ``SasAction.name``; assumed to be str — confirm.
        """
        return self._require_composite().get_mapping_from_action_to_skill()

    def to_string(self) -> str:
        """Render the root composite as an indented bullet list."""
        return self._require_composite().to_string()
