"""Reinforcement Learning-based solver for Pylic predicates.

Behind the scenes, this module wraps a "step-based" program into a `gym.Env`
whose reward depends on the robustness value of the input formula.

Naturally, whether the resulting optimization problem is a Markov Decision
Process depends on the step-program.
"""
from typing import Callable, Concatenate, ParamSpec, TypeVar
from typing import Union
from typing import Generic
from enum import Enum
from enum import auto
from pylic.code_transformations import get_tape
from pylic.tape import Tape
from pylic.tape import ReturnNode
from pylic.predicates import Predicate
from pylic.predicates import predicate_interpreter
from pylic.predicates import CustomFunction
from pylic.predicates import CustomFilter
import torch
import gym
import gym.spaces
import stable_baselines3.common.base_class
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold


# Generic type variables.
T = TypeVar("T")
K = TypeVar("K")
State = TypeVar("State")  # program state threaded through the step program

# Extra positional/keyword arguments forwarded verbatim to the step program.
RuntimeData = ParamSpec("RuntimeData")

# Tensor aliases, used only to make signatures easier to read.
Parameters = torch.Tensor
Action = torch.Tensor

# Any Stable-Baselines3 algorithm produced by the user-supplied factory.
SB3Algorithm = TypeVar(
    "SB3Algorithm",
    bound=stable_baselines3.common.base_class.BaseAlgorithm,
)


class StepWrapRewardStrategy(Enum):
    """How the predicate's robustness value is turned into an RL reward."""

    # Reward only at the end of an episode (robustness of the whole episode).
    EPISODE_REWARD = auto()
    # Reward after every step (robustness of the episode so far).
    STEP_REWARD = auto()
    # Reward after every step (change in robustness since the previous step).
    STEP_REWARD_DELTA = auto()


#: Featurizes a program state into an observation tensor. Called as
#: ``fn(state, current_timestep, episode_timesteps)``, where
#: ``episode_timesteps`` is the number of steps in a full episode.
StateVectorizer = Callable[[State, int, int], torch.Tensor]


class WrappedSAT(gym.Env, Generic[T]):
    """Base class for step-program predicate solving wrappers.

    Wraps a step program ``step(action, state, t, *args, **kwargs) -> state``
    into a ``gym.Env``. Each step is traced into a ``Tape``, and the tapes of
    an episode are accumulated in ``current_trace``. Subclasses implement
    ``step`` and decide how the predicate's robustness over that trace becomes
    a reward.

    Args:
        predicate: Predicate whose robustness value drives the reward.
        step: Step program. It is called with the action (converted with
            ``torch.tensor``), the current state, the current timestep, and
            the runtime ``*args``/``**kwargs``, and must return the new state.
        initial_state: State used at construction and on every ``reset``.
        max_value: Forwarded as ``max_value`` to ``predicate_interpreter`` by
            the reward-computing subclasses.
        custom_functions: Forwarded to ``predicate_interpreter``.
        custom_filters: Forwarded to ``predicate_interpreter``.
        restart_timestep_n: Episode length. Subclasses report ``done`` once
            this many steps have been taken since the last ``reset``.
        action_space: Gym action space of the environment.
        observation_space: Gym observation space of the environment.
        state_to_observation: Called as
            ``state_to_observation(state, t, restart_timestep_n)`` to build
            observations.
        *args, **kwargs: Runtime data forwarded verbatim to ``step``.
    """
    def __init__(
        self,
        predicate: Predicate,
        step: Callable[Concatenate[
            Action,
            State,
            int,  # current timestep
            RuntimeData
            ], State],
        initial_state: State,
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        restart_timestep_n: int,  # steps before resetting to initial state
        action_space: gym.spaces.Space,
        observation_space: gym.spaces.Space,
        state_to_observation: StateVectorizer,
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ):
        super().__init__()
        self.predicate = predicate
        self.f_step = step
        self.initial_state = initial_state
        self.max_value = max_value
        self.custom_functions = custom_functions
        self.custom_filters = custom_filters
        self.restart_timestep_n = restart_timestep_n
        self.args = args
        self.kwargs = kwargs
        self.action_space = action_space
        self.observation_space = observation_space
        self.current_state = self.initial_state
        self.state_to_observation = state_to_observation
        self.current_t = 0
        # Concatenation of the tapes of every step in the current episode.
        self.current_trace = Tape()

    def reset(self):
        """Start a new episode and return the initial observation."""
        self.current_trace = Tape()
        self.current_t = 0
        self.current_state = self.initial_state
        return self.state_to_observation(
            self.current_state,
            self.current_t,
            self.restart_timestep_n,
        )

    def get_current_trace(self, action):
        """Trace the step program under the current state with ``action``.

        Returns:
            The ``Tape`` produced by ``get_tape`` for a single step call.
        """
        tape = get_tape(
            self.f_step,
            None,
            *(
                torch.tensor(action),
                self.current_state,
                self.current_t,
                *self.args,
            ),
            **self.kwargs,
        )
        return tape

    def get_new_state(self, tape):
        """Return the value of the first ``ReturnNode`` in ``tape``.

        Raises:
            ValueError: If ``tape`` contains no ``ReturnNode``.
        """
        # Stop at the first return node instead of scanning the whole tape.
        return_node = next(
            (node for node in tape if isinstance(node, ReturnNode)),
            None,
        )
        if return_node is None:
            raise ValueError(
                "Step-program tape contains no ReturnNode; "
                "cannot extract the new state."
            )
        return return_node.value



class EpisodicRewardWrappedSAT(Generic[T], WrappedSAT[T]):
    """Step-program env with a sparse, end-of-episode reward.

    Every step yields reward ``0.0`` except the last step of an episode (the
    one that brings the step count to ``restart_timestep_n``). The last step's
    reward is the robustness value of the predicate evaluated on the
    concatenation of all step traces recorded during the episode.
    """
    def step(self, action):
        step_tape = self.get_current_trace(action)
        next_state = self.get_new_state(step_tape)

        # Record this step in the episode trace and advance the env.
        self.current_trace.extend(step_tape)
        self.current_state = next_state
        self.current_t += 1

        episode_over = self.current_t >= self.restart_timestep_n
        obs = self.state_to_observation(
            self.current_state, self.current_t, self.restart_timestep_n,
        )
        if episode_over:
            # The predicate is evaluated once, on the full episode trace.
            reward = predicate_interpreter(
                predicate=self.predicate,
                input_tape=self.current_trace,
                max_value=self.max_value,
                custom_functions=self.custom_functions,
                custom_filters=self.custom_filters,
            )
        else:
            reward = 0
        return obs, float(reward), episode_over, {}


class StepRewardWrappedSAT(Generic[T], WrappedSAT[T]):
    """Step-program env with a dense, per-step reward.

    After every step, the reward is the robustness value of the predicate
    evaluated on the concatenation of the traces of all steps taken so far
    in the current episode. The predicate must therefore be well defined on
    partial episodes.
    """
    def step(self, action):
        step_tape = self.get_current_trace(action)
        next_state = self.get_new_state(step_tape)

        # Record this step in the episode trace and advance the env.
        self.current_trace.extend(step_tape)
        self.current_state = next_state
        self.current_t += 1

        episode_over = self.current_t >= self.restart_timestep_n
        obs = self.state_to_observation(
            self.current_state, self.current_t, self.restart_timestep_n,
        )
        robustness = predicate_interpreter(
            predicate=self.predicate,
            input_tape=self.current_trace,
            max_value=self.max_value,
            custom_functions=self.custom_functions,
            custom_filters=self.custom_filters,
        )
        return obs, float(robustness), episode_over, {}


class StepRewardDeltaWrappedSAT(Generic[T], WrappedSAT[T]):
    """Wrap a step-program into a `gym.Env` for predicate solving, where reward
    corresponds to the step-wise difference in robustness value of the
    predicate evaluated on the concatenation of execution traces of the step
    function in the episode so far.

    The first step of an episode has no previous robustness value to compare
    against, so its reward is ``0.0``.
    """
    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``WrappedSAT.__init__``."""
        super().__init__(*args, **kwargs)
        # Robustness value after the previous step of the current episode.
        # It is set here (not only in ``reset``) so that ``step`` cannot hit
        # an AttributeError if it is called before the first ``reset``.
        self.previous_reward = None

    def step(self, action):
        # Trace step-program under given action and extract the new state
        tape = self.get_current_trace(action)
        new_state = self.get_new_state(tape)

        # Extend episode trace and update internal state
        self.current_trace.extend(tape)
        self.current_t += 1
        self.current_state = new_state

        done = self.current_t >= self.restart_timestep_n
        observation = self.state_to_observation(
            self.current_state,
            self.current_t,
            self.restart_timestep_n,
        )
        current_reward = predicate_interpreter(
            predicate=self.predicate,
            input_tape=self.current_trace,
            max_value=self.max_value,
            custom_functions=self.custom_functions,
            custom_filters=self.custom_filters,
        )
        # Change in robustness since the previous step (0.0 on the first step).
        reward = 0.0 \
            if self.previous_reward is None \
            else current_reward - self.previous_reward
        self.previous_reward = current_reward
        return observation, float(reward), done, {}

    def reset(self):
        """Start a new episode and forget the previous robustness value."""
        observation = super().reset()
        self.previous_reward = None
        return observation



def solver(
        predicate: Predicate,
        step: Callable[Concatenate[
            Action,
            State,
            int,  # current timestep
            RuntimeData
            ], State],
        initial_state: State,
        max_value: torch.Tensor,
        custom_functions: dict[str, CustomFunction],
        custom_filters: dict[str, CustomFilter],
        restart_timestep_n: int,  # how many times to step before resetting to initial state
        train_timestep_n: int,
        get_sb_algorithm: Callable[
            [gym.Env|SubprocVecEnv],
            SB3Algorithm,
            ],
        action_space: gym.spaces.Space,
        observation_space: gym.spaces.Space,
        state_to_observation: StateVectorizer,
        worker_n: Union[int, None],  # None for no vectorization
        wrap_strategy: StepWrapRewardStrategy,
        verbose: bool,
        *args: RuntimeData.args,
        **kwargs: RuntimeData.kwargs,
        ) -> SB3Algorithm:
    """Solve the predicate using the given Reinforcement Learning
    algorithm.

    The program will be wrapped into a `gym.Env` depending on the value of
    `wrap_strategy`:

    - `EPISODE_REWARD`: the robustness value of the predicate is given as
      reward at the end of each episode (after `restart_timestep_n` steps).

    - `STEP_REWARD`: the robustness value of the predicate is given as
      reward at each timestep. Naturally, the predicate must have well defined
      values for every step sequence.

    - `STEP_REWARD_DELTA`: the step-wise difference in robustness value of the
      predicate is given as reward at each timestep. Naturally, the predicate
      must have well defined values for every step sequence.

    In every case, robustness values correspond to the predicate evaluated on
    the concatenation of the execution traces of the steps in the current
    episode.

    Training stops early once an evaluation episode (deterministic policy,
    episode-reward wrapping) reaches a mean reward of at least ``1e-6``.

    Args:
        predicate: Predicate to satisfy.
        step: Step program ``step(action, state, t, *args, **kwargs)``
            returning the next state.
        initial_state: State every episode starts from.
        max_value: Forwarded to ``predicate_interpreter``.
        custom_functions: Forwarded to ``predicate_interpreter``.
        custom_filters: Forwarded to ``predicate_interpreter``.
        restart_timestep_n: Episode length in steps.
        train_timestep_n: ``total_timesteps`` passed to ``model.learn``.
        get_sb_algorithm: Factory building the SB3 algorithm from the env.
        action_space: Gym action space of the wrapped env.
        observation_space: Gym observation space of the wrapped env.
        state_to_observation: Featurizes ``(state, t, restart_timestep_n)``.
        worker_n: Number of ``SubprocVecEnv`` workers, or ``None`` to train
            on a single, non-vectorized env.
        wrap_strategy: Reward strategy used for the training env(s).
        verbose: Verbosity of the evaluation/stopping callbacks.
        *args, **kwargs: Runtime data forwarded verbatim to ``step``.

    Returns:
        The trained SB3 model.

    Raises:
        ValueError: If ``worker_n`` is not ``None`` and is smaller than 1, or
            if ``wrap_strategy`` is not a ``StepWrapRewardStrategy`` member.
    """
    if worker_n is not None and worker_n < 1:
        raise ValueError(
            f"worker_n must be None or a positive integer, got {worker_n}!"
        )

    # Env class implementing each reward strategy.
    env_classes = {
        StepWrapRewardStrategy.EPISODE_REWARD: EpisodicRewardWrappedSAT,
        StepWrapRewardStrategy.STEP_REWARD: StepRewardWrappedSAT,
        StepWrapRewardStrategy.STEP_REWARD_DELTA: StepRewardDeltaWrappedSAT,
    }

    def make_env(strategy: StepWrapRewardStrategy) -> WrappedSAT:
        """Build a fresh env for the program and predicate under ``strategy``."""
        env_cls = env_classes.get(strategy) \
            if isinstance(strategy, StepWrapRewardStrategy) else None
        if env_cls is None:
            raise ValueError(f"Unrecognized wrap strategy {strategy}!")
        # The named parameters are passed positionally on purpose. The env
        # constructors declare ``*args`` after them, so passing them by keyword
        # would bind any runtime positional argument to ``predicate`` and raise
        # "got multiple values for argument".
        return env_cls(
            predicate,
            step,
            initial_state,
            max_value,
            custom_functions,
            custom_filters,
            restart_timestep_n,
            action_space,
            observation_space,
            state_to_observation,
            *args,
            **kwargs,
        )

    # Check env
    check_env(make_env(wrap_strategy))

    # Instantiate environment
    env = make_env(wrap_strategy)\
        if worker_n is None\
        else SubprocVecEnv([lambda: make_env(wrap_strategy)] * worker_n)

    # Instantiate the RL algorithm
    model = get_sb_algorithm(env)

    # Setup a call-back to stop training as soon as the predicate is satisfied.
    # The call-back evaluates the robustness value of the predicate in an
    # entire episode. The policy is executed in deterministic mode.
    eval_env = make_env(StepWrapRewardStrategy.EPISODE_REWARD)
    callback_on_best = StopTrainingOnRewardThreshold(
        reward_threshold=1e-6,  # minimum robustness value to stop training
        verbose=verbose,
    )
    eval_callback = EvalCallback(
        eval_env,
        callback_on_new_best=callback_on_best,
        verbose=verbose,
        deterministic=True,
    )

    # Optimize model
    model.learn(
        total_timesteps=train_timestep_n,
        callback=eval_callback,
    )
    return model
