import json
import abc
import dataclasses as dc
from pathlib import Path
from typing import TypedDict, TypeVar, Literal, Callable, Any, Self, Protocol


@dc.dataclass(slots=True)
class CoT[In, Thought, Out]: 
    """The generic class for a CoT (Chain-of-Thought) reasoning."""

    input: In
    thought: Thought
    outcome: Out
    
    def as_dict(self):
        return dc.asdict(self)

    @classmethod
    def load_instances(cls, file: str | Path):
        with open(file, 'rt') as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise TypeError
        return [cls(**sample) for sample in data]



class MDPStep[State, Action](TypedDict):
    """A single step of an MDP-style CoT, pairing a ``state`` with an ``action``."""

    state: State
    # May be None; presumably for a step where no action is taken
    # (e.g. a terminal state) — TODO confirm with producers of MDPStep.
    action: Action | None


# Old-style TypeVars with plain assignments (rather than `type` statements)
# keep the aliases below callable and subscriptable, just like `CoT[...]`.
_In = TypeVar('_In')
_Step = TypeVar('_Step')
_Out = TypeVar('_Out')
_State = TypeVar('_State')
_Action = TypeVar('_Action')


StepwiseCoT = CoT[_In, list[_Step], _Out]
"""The type of long CoT whose thought is a list of multiple steps."""


MDPCoT = StepwiseCoT[_In, MDPStep[_State, _Action], _Out]
"""The type of stepwise CoT whose steps are `MDPStep`s.

In such an MDP (Markov decision process), the next state is explicitly included
in each step, leading to an MDP-style CoT structure. This simplifies the
implementation of the transition function, which simply fetches the next state
from the step.
"""


type ReasonerFn[In, Thought, Out] = Callable[[In], tuple[Out, Thought]]


class Reasoner[In, Thought, Out](Protocol):
    """
    The abstract reasoner following arbitrary mechanism.
    """

    def __call__(self, input: In) -> tuple[Thought, Out]:  ...


class StepReasoner[In, Thought, Out](Reasoner[In, Thought, Out]):
    """
    The logical (non-generative) reasoner that have multiple steps.
    """

    def __call__(self, input: In) -> tuple[Thought, Out]:
        self.begin(input)
        while True:
            done = self.step()
            if done:
                break
        return self.end()
    
    @abc.abstractmethod
    def begin(self, input: In) -> None:
        raise NotImplementedError
    
    @abc.abstractmethod
    def step(self) -> bool | None:
        raise NotImplementedError
    
    @abc.abstractmethod
    def end(self) -> tuple[Thought, Out]:
        """
        The behavior of ending the reasoning.
        """

        raise NotImplementedError


class AgentReasoner[In, Action, Thought, Out](StepReasoner[In, Thought, Out]):
    """
    The logical (non-generative) reasoner that runs a decision process.
    """
    
    def step(self) -> bool | None:

        action = self.actor()
        return self.transit(action)
    
    @abc.abstractmethod
    def actor(self) -> Action:
        raise NotImplementedError

    @abc.abstractmethod
    def transit(self, action: Action) -> bool | None:
        raise NotImplementedError

