"""Gymnasium wrapper that augments the environment observation with a one-hot
vector encoding the ground-truth mode, and reshapes the reward with
`get_total_reward`."""
from gymnasium import Wrapper
from swmpo.gymnasium_wrapper import get_augmented_observation
from swmpo.gymnasium_wrapper import get_one_hot_encoding
from swmpo.gymnasium_wrapper import get_total_reward
import numpy as np
from gymnasium import spaces
from gymnasium import Env
from collections import defaultdict


class GroundTruthWrapper(Wrapper):
    """Augment observations with a one-hot encoding of the ground-truth mode.

    Rewards returned by `step` are computed by `get_total_reward` from the
    wrapped environment's reward, the current mode, the modes visited so far
    in the episode, and per-mode reward records.

    Expects the wrapped environment to provide a `ground_truth_mode` entry
    in the `info` dictionaries returned by both `reset` and `step`.
    """

    def __init__(
            self,
            env: Env,
            extrinsic_reward_scale: float,
            mode_n: int,
            exploration_window_size: int,
            ):
        """Wrap `env` and extend its observation space with the mode encoding.

        Args:
            env: Environment to wrap. Its observation space must be a
                `spaces.Box`.
            extrinsic_reward_scale: Forwarded unchanged to `get_total_reward`.
            mode_n: Number of ground-truth modes, used as the total size
                passed to the one-hot encoding helpers.
            exploration_window_size: Forwarded unchanged to
                `get_total_reward`.

        Raises:
            TypeError: If `env.observation_space` is not a `spaces.Box`.
        """
        super().__init__(env)
        self.env = env
        self.mode_n = mode_n
        old_space = env.observation_space
        # Explicit check instead of `assert`, which is removed under `-O`.
        if not isinstance(old_space, spaces.Box):
            raise TypeError(
                "GroundTruthWrapper requires a spaces.Box observation space, "
                f"got {type(old_space).__name__}"
            )
        encoding_len = len(get_one_hot_encoding(
            0, mode_n,
        ))
        # The appended one-hot entries are bounded to [0, 1].
        # NOTE(review): concatenation along axis 0 assumes the original Box
        # bounds are 1-D — confirm for multi-dimensional observation spaces.
        low = np.concatenate((
            old_space.low,
            np.zeros(encoding_len),
        ))
        high = np.concatenate((
            old_space.high,
            np.ones(encoding_len),
        ))
        self.observation_space = spaces.Box(
            low=low,
            high=high,
            dtype=np.float32,
        )
        self.extrinsic_reward_scale = extrinsic_reward_scale
        # Modes seen so far in the current episode (before the current step).
        self.visited_nodes = set()
        self.exploration_window_size = exploration_window_size
        # One `defaultdict(list)` keyed by mode per episode; `reset` appends a
        # new one, so the first episode's "previous" record is this initial
        # empty dict rather than None.
        # NOTE(review): this class never appends to the per-mode lists — check
        # whether `get_total_reward` does, otherwise the records stay empty.
        # The history also grows by one dict per episode and is never pruned.
        self.mode_rewards = [defaultdict(list)]
        # Mode reported by the most recent `reset`/`step`; None until then.
        self.current_node = None

    def reset(self, *args, **kwargs):
        """Reset the wrapped env and the per-episode bookkeeping.

        All arguments are forwarded to the wrapped environment's `reset`.

        Returns:
            The observation augmented with the encoding of the initial
            ground-truth mode, and the wrapped environment's `info` dict.
        """
        self.visited_nodes = set()
        # Start a fresh per-mode record for this episode; the previous
        # episode's record remains available at index -2 for `step`.
        self.mode_rewards.append(defaultdict(list))

        obs, info = self.env.reset(*args, **kwargs)
        node_i = info["ground_truth_mode"]
        # Keep `current_node` in sync with the new episode instead of leaving
        # the last mode of the previous episode in place.
        self.current_node = node_i
        augmented_obs = get_augmented_observation(
            obs=obs,
            active_i=node_i,
            total_n=self.mode_n,
        )
        return augmented_obs, info

    def step(self, action):
        """Step the wrapped env, augment the observation and reshape reward.

        Returns:
            `(augmented_obs, total_reward, terminated, truncated, info)` where
            `total_reward` comes from `get_total_reward` and the other items
            (besides the observation) are passed through unchanged.
        """
        # Step environment
        obs, reward, terminated, truncated, info = self.env.step(action)

        # Get new node
        new_node = info["ground_truth_mode"]

        # Get total reward
        if len(self.mode_rewards) < 2:
            # Only reachable if `step` is called before any `reset`.
            prev_episode_mode_rewards = None
        else:
            prev_episode_mode_rewards = self.mode_rewards[-2]
        # NOTE(review): the previous episode is passed as the whole per-mode
        # dict while the current episode is passed as the list for
        # `new_node` only — confirm this asymmetry matches
        # `get_total_reward`'s expectations.
        total_reward = get_total_reward(
            prev_episode_mode_rewards=prev_episode_mode_rewards,
            current_episode_mode_rewards=self.mode_rewards[-1][new_node],
            base_reward=float(reward),
            current_mode=new_node,
            current_episode_visited_modes=self.visited_nodes,
            exploration_window_size=self.exploration_window_size,
            extrinsic_reward_scale=self.extrinsic_reward_scale,
        )

        # Get augmented observation
        augmented_obs = get_augmented_observation(
            obs=obs,
            active_i=new_node,
            total_n=self.mode_n,
        )

        # Update visited nodes only after the reward is computed, so the
        # reward sees the modes visited *before* this step. A new set is
        # built (rather than mutated in place) so any reference held to the
        # previous set is left unchanged.
        self.visited_nodes = self.visited_nodes | {new_node}
        self.current_node = new_node

        return augmented_obs, total_reward, terminated, truncated, info
