from __future__ import annotations

import collections
from typing import Dict, Optional, Tuple, Union

import gym
import numpy as np
import torch

from offline_rl.rewards.reward_model import RewardModel
from offline_rl.rewards.tabular_reward_model import TabularRewardModel


class LineEnv(gym.Env):
    """A one-dimensional line world with a terminal state at each end.

    The line has ``2 * side_length + 1`` cells and every episode starts in the
    center cell. Action 0 moves one cell to the right and any other action
    (i.e. 1) moves one cell to the left. Reaching the rightmost cell ends the
    episode with `right_reward`, reaching the leftmost cell ends it with
    `left_reward`, and the episode is also cut off (with zero reward) after
    ``side_length**2 * 50`` steps.

    Args:
        side_length: The number of cells to the left and right of the center.
        right_reward: Reward for reaching the rightmost cell.
        left_reward: Reward for reaching the leftmost cell.
    """

    def __init__(self, side_length: int = 2, right_reward: float = 1, left_reward: float = -1):
        self.side_length = side_length
        self.right_reward = right_reward
        self.left_reward = left_reward
        self.observation_space = gym.spaces.Discrete(1 + side_length * 2)
        self.action_space = gym.spaces.Discrete(2)

        # Time limit after which an episode terminates with zero reward.
        self.max_steps = side_length**2 * 50
        # Both are populated by `reset`, which must be called before `step`.
        self.state = None
        self.t = None

    def _get_obs(self) -> int:
        """Returns the observation: the internal state shifted so the leftmost cell is 0."""
        return self.state + self.side_length

    def reset(self) -> int:
        """Puts the agent back in the center cell and returns the initial observation."""
        # Internal state is in the range [-self.side_length, self.side_length]
        self.state = 0
        self.t = 0
        return self._get_obs()

    def step(self, action: int) -> Tuple[int, float, bool, Dict]:
        """Moves the agent one cell and returns ``(observation, reward, terminal, info)``.

        Args:
            action: 0 to move right, 1 to move left.

        Returns:
            The new observation, the reward for the move, whether the episode
            ended, and an (always empty) info dict.
        """
        assert self.action_space.contains(action), f"Invalid action: {action!r}"
        self.t += 1
        dx = 1 if action == 0 else -1
        next_state = self.state + dx

        # Reaching either end takes priority over the time limit.
        if next_state == self.side_length:
            terminal = True
            reward = self.right_reward
        elif next_state == -self.side_length:
            terminal = True
            reward = self.left_reward
        elif self.t >= self.max_steps:
            terminal = True
            reward = 0
        else:
            terminal = False
            reward = 0

        # NOTE: the position is not clamped, so stepping again after a
        # terminal transition moves the state outside the observation space.
        self.state = next_state
        return self._get_obs(), reward, terminal, {}

    def _render_text(self) -> str:
        """Returns the line as underscores with a ``*`` at the agent's cell."""
        cells = list("_" * (self.side_length * 2 + 1))
        cells[self._get_obs()] = "*"
        return "".join(cells)

    def _render_rgb_array(self) -> np.ndarray:
        """Returns a 1-D ``uint8`` array of zeros with a 1 at the agent's cell."""
        arr = np.zeros(self.side_length * 2 + 1, dtype=np.uint8)
        arr[self._get_obs()] = 1
        return arr

    def render(self, mode: str = "text") -> Optional[Union[str, np.ndarray]]:
        """Renders the current state.

        Args:
            mode: ``"text"`` for the string rendering or ``"human"`` for the
                array rendering (nothing is displayed in either case).

        Returns:
            The rendering for the requested mode, or None for any other mode.
        """
        if mode == "text":
            return self._render_text()
        if mode == "human":
            return self._render_rgb_array()
        return None

    def render_policy(self, pi: Dict, *args, **kwargs) -> None:
        """Prints the action (``right``/``left``) the policy takes in each cell.

        Args:
            pi: Maps each observation index to an action.
        """
        del args
        del kwargs
        num_states = self.side_length * 2 + 1
        print(" ".join("right" if pi[i] == 0 else "left" for i in range(num_states)))

    def render_value_function(self, v: Dict, *args, **kwargs) -> None:
        """Prints the value of each cell, left to right, with four decimals.

        Args:
            v: Maps each observation index to a value.
        """
        del args
        del kwargs
        num_states = self.side_length * 2 + 1
        print(" ".join(f"{v[i]:.04f}" for i in range(num_states)))

    def render_q_function(self, q: Dict, *args, std_q=None, **kwargs) -> None:
        """Prints one row of per-cell Q-values for each action.

        Args:
            q: Indexed as ``q[observation][action]``.
            std_q: Optional standard deviations, indexed like `q`; when given
                each value is followed by ``+-<std>``.
        """
        del args
        del kwargs
        for action in [0, 1]:
            action_string = "right" if action == 0 else "left"
            print(action_string)
            value_string = []
            for i in range(self.side_length * 2 + 1):
                value = q[i][action]
                value_string.append(f"{value:.04f}")
                if std_q is not None:
                    std = std_q[i][action]
                    value_string.append(f"+-{std:.04f}")
                value_string.append("|")
            print(" ".join(value_string))

    def render_reward_function(self, rewards: Dict) -> None:
        """Prints a reward function in the same layout as `render_q_function`.

        Args:
            rewards: Either indexed as ``rewards[state][action]``, or a dict
                keyed by ``(state, action, next_state)`` transitions.
        """
        first_key = next(iter(rewards), None)
        try:
            keyed_by_transition = len(first_key) == 3
        except TypeError:
            # Plain state keys (e.g. ints) have no length: the rewards are
            # already in the ``rewards[state][action]`` form.
            keyed_by_transition = False

        if keyed_by_transition:
            # If the reward function comes in as transitions (s,a,s'), reduce it to valid (s,a).
            max_obs = self.side_length * 2
            reduced = collections.defaultdict(dict)
            for (state, action, next_state), reward in rewards.items():
                delta = 1 if action == 0 else -1
                # Moves off either end of the line are treated as staying put.
                exp_next_state = max(min(state + delta, max_obs), 0)
                if next_state == exp_next_state:
                    reduced[state][action] = reward
            rewards = reduced
        self.render_q_function(rewards)


class LineEnvReward(TabularRewardModel):
    """Tabular reward models for LineEnv.

    Every `@classmethod` in this class is an alternate constructor that
    builds one particular reward table.
    """
    @classmethod
    def make_ground_truth_reward(
            cls,
            side_length: int = 2,
            right_reward: float = 1,
            left_reward: float = -1,
    ) -> LineEnvReward:
        """Builds the reward model matching the environment's true rewards.

        Args:
            side_length: How many states lie on each side of the center state.
            right_reward: Reward for arriving at the rightmost state.
            left_reward: Reward for arriving at the leftmost state.

        Returns:
            An instance of this class holding the ground-truth reward table.
        """
        state_count = 2 * side_length + 1
        action_count = 2
        # Table is indexed as (state, action, next_state).
        table = torch.zeros(state_count, action_count, state_count)

        # Some of the entries filled in here can never occur under the
        # environment's transition model. It still seems reasonable that
        # arriving at the leftmost or rightmost state, from whatever state and
        # via whatever action, yields the corresponding end reward.
        table[..., 0] = left_reward
        table[..., -1] = right_reward
        return cls(table)

    @classmethod
    def make_reverse_reward(
            cls,
            side_length: int = 2,
            right_reward: float = 1,
            left_reward: float = -1,
    ) -> LineEnvReward:
        """Builds the ground-truth reward model with its two end rewards swapped.

        Args:
            side_length: How many states lie on each side of the center state.
            right_reward: Reward for arriving at the rightmost state (before swapping).
            left_reward: Reward for arriving at the leftmost state (before swapping).

        Returns:
            An instance of this class holding the reversed ground-truth reward table.
        """
        # The swap happens here: each end receives the other end's reward.
        return cls.make_ground_truth_reward(
            side_length=side_length,
            left_reward=right_reward,
            right_reward=left_reward,
        )

    @classmethod
    def make_zero_reward(cls, side_length: int = 2) -> LineEnvReward:
        """Builds a reward model that is zero for every transition.

        Args:
            side_length: How many states lie on each side of the center state.

        Returns:
            An instance of this class holding an all-zero reward table.
        """
        return cls.make_ground_truth_reward(
            side_length=side_length,
            left_reward=0,
            right_reward=0,
        )

    @classmethod
    def make_center_reward(cls, side_length: int = 2) -> LineEnvReward:
        """Builds a reward model that pays 1 for arriving at the center state.

        Args:
            side_length: How many states lie on each side of the center state.

        Returns:
            An instance of this class holding the center-1 reward table.
        """
        state_count = 2 * side_length + 1
        action_count = 2
        table = torch.zeros(state_count, action_count, state_count)

        # The center state sits at index `side_length`.
        table[..., side_length] = 1

        return cls(table)


class LineEnvRightwardPotential(RewardModel):
    """A potential shaping of a base reward that rewards moving right.

    Each state's potential is ``potential_value * state``, and the shaped
    reward is ``base + potential(next_state) - potential(state)``. With the
    default `potential_value` of 1 this adds ``next_states - states`` to the
    base reward.

    Args:
        base_reward: The base reward model to shape.
        potential_value: The value to add/subtract for moving right/left.
    """
    def __init__(self, base_reward: RewardModel, potential_value: float = 1):
        self.base_reward = base_reward
        self.potential_value = potential_value

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Computes the base reward plus the potential-shaping term.

        Args:
            states: The states the transitions start from; their values are
                used directly to compute the potential.
            actions: The actions taken; only forwarded to the base reward.
            next_states: The states the transitions end in. Required here,
                because the shaping term depends on them.
            terminals: Terminal flags; only forwarded to the base reward.

        Returns:
            The shaped rewards.

        Raises:
            ValueError: If `next_states` is None.
        """
        if next_states is None:
            raise ValueError("LineEnvRightwardPotential requires next_states to compute the shaping term.")
        original = self.base_reward.reward(states, actions, next_states, terminals)
        # Match the base reward's dtype so the sum below does not mix dtypes.
        states_potential = self.potential_value * states.to(original.dtype)
        next_states_potential = self.potential_value * next_states.to(original.dtype)
        # Assumes a discount of one.
        shaped = original + next_states_potential - states_potential
        return shaped

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space of the base reward model."""
        return self.base_reward.observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space of the base reward model."""
        return self.base_reward.action_space
