from __future__ import annotations

from typing import Any, Optional, Protocol


# Opaque aliases: nothing in this module inspects these objects directly;
# they are only passed between Getter/Putter implementations.
State = Any       # input consumed by Getter/Putter methods
Value = Any       # returned by Getter.get/forward, consumed by Putter.put
Prediction = Any  # returned by Getter.__call__ and Putter.__call__


class Getter(Protocol):
    """Structural interface for the "get" side of a getter/putter pair.

    Any object providing these four methods satisfies the protocol; no
    inheritance is required.
    """

    def decode_state(self, state: State) -> str:
        """Return a string rendering of ``state``."""

    def get(self, state: State) -> Value:
        """Return the value extracted from ``state``."""

    def forward(self, state: State) -> Value:
        """Run the getter on ``state`` and return its value output."""

    def __call__(self, state: State) -> Prediction:
        """Return the getter's raw prediction for ``state``."""


class Putter(Protocol):
    """Structural interface for the "put" side of a getter/putter pair.

    Any object providing these four methods satisfies the protocol; no
    inheritance is required.
    """

    def __call__(self, state: State, value: Value) -> Prediction:
        """Return the putter's raw prediction for writing ``value`` into ``state``."""
        ...

    def put(self, state: State, value: Value) -> State:
        """Return the state obtained by writing ``value`` into ``state``."""
        ...

    def decode_state(self, state: State) -> str:
        """Return a string rendering of ``state``."""
        ...

    def generate(self,
                 state: State,
                 value: Value,
                 strategy: str,
                 prepare_for_getter: bool = False,
                 prepare_for_putter: bool = False,
                 valid_state: Optional[State] = None) -> State:
        """Generate a new state from ``state`` and ``value``.

        Args:
            state: Input state.
            value: Value to write into ``state``.
            strategy: Generation strategy name (callers in this module pass
                ``'beam'``).
            prepare_for_getter: Request output suitable as getter input.
            prepare_for_putter: Request output suitable as putter input.
            valid_state: Optional reference state passed through to the
                implementation.

        The two ``prepare_for_*`` flags default to ``False`` because callers
        pass only the one they need. A signature without defaults would make
        those calls ill-typed against this protocol.
        """
        ...


def classify(getter: Getter,
             state: State,
             value: Value,
             *,
             verbose: bool = False) -> tuple[Value, Value]:
    """Build a (prediction, target) pair for the getter applied to ``state``.

    Args:
        getter: Object implementing the ``Getter`` protocol.
        state: Input passed to ``getter.forward``.
        value: Expected value; returned unchanged as the target.
        verbose: When true, print ``getter.decode_state(state)`` before
            running the getter. Off by default so that callers do not get
            unconditional stdout output and do not pay for a ``decode_state``
            call on every invocation.

    Returns:
        ``(getter.forward(state), value)``.
    """
    if verbose:
        # The ``=`` specifier prints the expression text followed by its repr.
        print(f'classify: {getter.decode_state(state) = }')
    prediction = getter.forward(state)
    return prediction, value


def get_put(getter: Getter,
            putter: Putter,
            state: State,
            valid_state: State) -> tuple[State, State]:
    """Build a (prediction, target) pair for the get-then-put round trip.

    The value read from ``valid_state`` is written into ``state`` by the
    putter. The target is ``valid_state`` itself.
    """
    extracted = getter.get(valid_state)
    return putter(state, extracted), valid_state


def put(putter: Putter,
        state: State,
        value: Value) -> tuple[State, State]:
    """Build a (prediction, target) pair for a single put.

    The prediction is ``putter(state, value)``. The target is the
    unchanged input ``state``.
    """
    return putter(state, value), state


def put_get(putter: Putter,
            getter: Getter,
            state: State,
            value: Value) -> tuple[Value, Value]:
    """Build a (prediction, target) pair for the put-then-get round trip.

    The putter generates (beam strategy, prepared for the getter) a state
    from ``state`` and ``value``. The getter's ``forward`` output on that
    state is the prediction, and ``value`` is the target.
    """
    generated = putter.generate(
        state,
        value,
        strategy='beam',
        prepare_for_getter=True,
    )
    return getter.forward(generated), value


def put_put(putter: Putter,
            state: State,
            value1: Value,
            value2: Value,
            valid_state: State) -> tuple[State, State]:
    """Build a (prediction, target) pair for two consecutive puts.

    The prediction is ``putter.put`` applied with ``value2`` to a state
    generated from ``state`` and ``value1`` (beam strategy, prepared for the
    putter). The target is ``putter.put(state, value2)``.
    """
    first_pass = putter.generate(
        state,
        value1,
        strategy='beam',
        prepare_for_putter=True,
        valid_state=valid_state,
    )
    chained = putter.put(first_pass, value2)
    direct = putter.put(state, value2)
    return chained, direct


def undo(putter: Putter,
         getter: Getter,
         state: State,
         value: Value,
         valid_state: State) -> tuple[State, State]:
    """Build a (prediction, target) pair for put-then-restore.

    A state is generated from ``state`` and ``value`` (beam strategy,
    prepared for the putter). The value read from ``valid_state`` is then
    written back into it via ``putter(...)``. The target is ``valid_state``.
    """
    intermediate = putter.generate(
        state,
        value,
        strategy='beam',
        prepare_for_putter=True,
        valid_state=valid_state,
    )
    restored_value = getter.get(valid_state)
    return putter(intermediate, restored_value), valid_state
