"""Augment MDPs with state machines using Gymnasium's API."""
from collections import defaultdict
from gymnasium import Wrapper
from swmpo.state_machine import StateMachine
from swmpo.state_machine import state_machine_model
from swmpo.transition import Transition
import numpy as np
from gymnasium import spaces
from gymnasium import Env
import statistics
import torch


def get_one_hot_encoding(
    active_i: int,
    total_n: int,
) -> list[float]:
    """Build a one-hot vector of length ``total_n`` as a list of floats.

    Position ``active_i`` holds 1.0 and every other position holds 0.0.
    If ``active_i`` is not in ``range(total_n)`` the result is all zeros.
    """
    encoding = []
    for position in range(total_n):
        encoding.append(1.0 if position == active_i else 0.0)
    return encoding


def get_augmented_observation(
    obs: np.ndarray,
    active_i: int,
    total_n: int,
) -> np.ndarray:
    """Return ``obs`` with a one-hot state-machine indicator appended.

    The last ``total_n`` entries of the result are
    ``get_one_hot_encoding(active_i, total_n)``. That indicator is built as a
    float64 array, so the result dtype follows NumPy's promotion rules between
    ``obs``'s dtype and float64.
    """
    indicator = np.array(get_one_hot_encoding(active_i, total_n))
    return np.concatenate([obs, indicator])


def get_exploration_reward(
    current_episode_mode_rewards: list[float],
    prev_episode_mode_rewards: list[float] | None,
    window_size: int,
) -> float:
    """
    Reward the agent for exploring modes where reward has been improving.
    - current_reward: reward in current timestep
    - mode_rewards: previous rewards in current timestep
    """
    if len(current_episode_mode_rewards) == 0:
        return 0.0
    if prev_episode_mode_rewards is None:
        return 0.0
    if len(prev_episode_mode_rewards) == 0:
        return 0.0

    # Return the rate of improvement in the window of the last rewards
    new_mean = statistics.mean(current_episode_mode_rewards)
    old_mean = statistics.mean(prev_episode_mode_rewards)
    delta_mean = new_mean - old_mean
    return max(0.0, delta_mean)


def get_total_reward(
    base_reward: float,
    current_episode_mode_rewards: list[float],
    prev_episode_mode_rewards: list[float] | None,
    current_mode: int,
    current_episode_visited_modes: set[int],
    exploration_window_size: int,
    extrinsic_reward_scale: float,
) -> float:
    """Combine the environment reward with the two state-machine bonuses.

    The result is the sum of:
    - ``base_reward``, the reward returned by the environment;
    - ``extrinsic_reward_scale`` if ``current_mode`` is not yet in
      ``current_episode_visited_modes``, otherwise 0.0;
    - the bonus from ``get_exploration_reward`` for the given reward logs.
    """
    novelty_bonus = (
        0.0 if current_mode in current_episode_visited_modes
        else extrinsic_reward_scale
    )
    improvement_bonus = get_exploration_reward(
        current_episode_mode_rewards=current_episode_mode_rewards,
        prev_episode_mode_rewards=prev_episode_mode_rewards,
        window_size=exploration_window_size,
    )
    components = [base_reward, novelty_bonus, improvement_bonus]
    return sum(components)


def cast_state(state):
    """Convert ``state`` to a ``torch.Tensor``.

    Tensors are returned unchanged (same object). NumPy arrays are wrapped
    with ``torch.from_numpy``, so they share memory with the input. Anything
    else goes through ``torch.tensor``, which copies.
    """
    if isinstance(state, torch.Tensor):
        return state
    if isinstance(state, np.ndarray):
        return torch.from_numpy(state)
    return torch.tensor(state)


class DeepSynthWrapper(Wrapper):
    """Augment an environment with information from a state machine.

    This wrapper modifies observations and rewards:
    - reward is augmented with ``extrinsic_reward_scale`` the first time each
      state-machine node is reached during an episode, plus an exploration
      bonus from ``get_exploration_reward``. That bonus compares the rewards
      seen in the new node this episode with those seen in the same node
      during the previous episode.
    - observations are augmented with a one-hot encoding of the state
      of the state machine and cast to the dtype of ``observation_space``
      (float32).

    It is expected the wrapped env returns an ``info`` dict with a ``state``
    key, both from ``reset`` and ``step``. Its value is converted with
    ``cast_state`` and used to update the state machine.

    This augmented version of an MDP was proposed in Hasanbeig et al, 2021.
    """

    def __init__(
            self,
            env: Env,
            state_machine: StateMachine,
            initial_state_machine_state: int,
            extrinsic_reward_scale: float,
            exploration_window_size: int,
            dt: float,
            ):
        super().__init__(env)
        old_space = env.observation_space
        assert isinstance(old_space, spaces.Box)
        # One extra observation dimension per state-machine node.
        encoding_len = len(get_one_hot_encoding(
            initial_state_machine_state, len(state_machine.local_models),
        ))
        low = np.concatenate((
            old_space.low,
            np.array([0.0 for _ in range(encoding_len)]),
        ))
        high = np.concatenate((
            old_space.high,
            np.array([1.0 for _ in range(encoding_len)]),
        ))
        self.observation_space = spaces.Box(
            low=low,
            high=high,
            dtype=np.float32,
        )
        self.initial_node = initial_state_machine_state
        self.current_node = initial_state_machine_state
        self.state_machine = state_machine
        self.extrinsic_reward_scale = extrinsic_reward_scale
        self.visited_nodes = set()
        self.exploration_window_size = exploration_window_size
        self.dt = dt

        # Per-episode reward logs: each entry maps node -> list of base
        # rewards received while stepping into that node. A new entry is
        # appended on every reset and none are removed, so this list grows
        # by one dict per episode.
        self.mode_rewards = [defaultdict(list)]

    def _augment_obs(self, obs, node: int) -> np.ndarray:
        """Append the one-hot encoding of ``node`` and match the space dtype."""
        augmented = get_augmented_observation(
            obs=obs,
            active_i=node,
            total_n=len(self.state_machine.local_models),
        )
        # The one-hot part is float64, so concatenation yields float64 for
        # float32 inputs. gymnasium's Box.contains rejects dtypes that cannot
        # be safely cast to the declared float32 space dtype.
        return augmented.astype(self.observation_space.dtype, copy=False)

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and return the state machine to its initial node."""
        self.current_node = self.initial_node
        self.visited_nodes = set()
        # Start a fresh reward log for the new episode.
        self.mode_rewards.append(defaultdict(list))

        obs, info = self.env.reset(*args, **kwargs)
        augmented_obs = self._augment_obs(obs, self.current_node)
        info["active_mode"] = self.current_node
        self.state = info["state"]
        return augmented_obs, info

    def step(self, action):
        """Step the env, advance the state machine and shape the reward.

        ``info["active_mode"]`` is the node that was active before this
        transition. The returned observation encodes the node after it.
        """
        # Step environment
        prev_state = self.state
        obs, reward, terminated, truncated, info = self.env.step(action)
        info["active_mode"] = self.current_node

        # Step node machine
        new_state = info["state"]
        _, new_node = state_machine_model(
            state_machine=self.state_machine,
            state=cast_state(new_state),
            # cast_state matches torch.from_numpy for ndarrays and also
            # accepts Python/NumPy scalars, lists and tensors.
            prev_action=cast_state(action),
            prev_state=cast_state(prev_state),
            current_node=self.current_node,
            dt=self.dt,
        )

        # Rewards seen in new_node during the previous episode (None if there
        # is no previous episode or the node was not visited then). Use .get
        # so the lookup does not insert a key into the defaultdict.
        if len(self.mode_rewards) < 2:
            prev_episode_mode_rewards = None
        else:
            prev_episode_mode_rewards = self.mode_rewards[-2].get(new_node)
        total_reward = get_total_reward(
            base_reward=float(reward),
            prev_episode_mode_rewards=prev_episode_mode_rewards,
            current_episode_mode_rewards=self.mode_rewards[-1][new_node],
            current_mode=new_node,
            current_episode_visited_modes=self.visited_nodes,
            exploration_window_size=self.exploration_window_size,
            extrinsic_reward_scale=self.extrinsic_reward_scale,
        )

        # Get augmented observation
        augmented_obs = self._augment_obs(obs, new_node)

        # Update per-mode rewards (after computing the bonus, so the current
        # step's reward is not part of its own comparison).
        self.mode_rewards[-1][new_node].append(float(reward))

        # Update state machine and visited nodes
        self.visited_nodes = self.visited_nodes | {new_node}
        self.current_node = new_node

        self.state = new_state

        return augmented_obs, total_reward, terminated, truncated, info
