import math
import operator
from functools import reduce

import numpy as np
import gym
from gym import error, spaces, utils
# from .minigrid import OBJECT_TO_IDX, COLOR_TO_IDX, STATE_TO_IDX, Goal

class PureStateBonus(gym.Wrapper):
    """
    Replace the environment reward with a count-based exploration bonus.

    After every step, the agent position (``agent_pos`` of the unwrapped
    env, taken after the step is applied) is counted. The reward returned
    to the caller is ``1 / sqrt(n)``, where ``n`` is the number of times that
    position has been visited, including this step. The wrapped env's own
    reward is discarded.

    Visit counts live in ``self.counts`` (position tuple -> int). They are
    NOT cleared by ``reset()``, so they accumulate across episodes. They are
    also exposed as ``info["counts"]``. Episode termination (``done``) is
    passed through from the wrapped env unchanged.
    """

    def __init__(self, env):
        super().__init__(env)
        # Position tuple -> number of visits; shared across all episodes.
        self.counts = {}

    def step(self, action):
        """Step the wrapped env and return the visit bonus as the reward.

        Returns:
            ``(obs, bonus, done, info)``, where ``bonus = 1 / sqrt(count)``
            for the agent's new position. ``info["counts"]`` holds the live
            counts dict (not a copy), so later steps keep mutating it.
        """
        obs, _env_reward, done, info = self.env.step(action)

        # Index counts by the agent position *after* the update.
        pos = tuple(self.unwrapped.agent_pos)
        new_count = self.counts.get(pos, 0) + 1
        self.counts[pos] = new_count
        info["counts"] = self.counts

        bonus = 1 / math.sqrt(new_count)
        return obs, bonus, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env; visit counts are deliberately kept."""
        return self.env.reset(**kwargs)


class PureStateBonusNeverStop(gym.Wrapper):
    """
    Count-based exploration bonus whose episodes end only at the step limit.

    Reward: after every step, the agent position (``agent_pos`` of the
    unwrapped env, taken after the step is applied) is counted. The reward
    returned is ``1 / sqrt(n)``, where ``n`` is the visit count of that
    position including this step. The wrapped env's own reward is discarded.

    Termination: a ``done`` signal from the wrapped env is suppressed
    unless the unwrapped env's ``step_count`` has reached ``max_steps``.
    Any earlier termination is ignored, which encourages exploration.

    Visit counts live in ``self.counts`` (position tuple -> int). They are
    NOT cleared by ``reset()`` and are exposed as ``info["counts"]``.
    """

    def __init__(self, env):
        super().__init__(env)
        # Position tuple -> number of visits; shared across all episodes.
        self.counts = {}

    def step(self, action):
        """Step the wrapped env, masking early termination.

        Returns:
            ``(obs, bonus, done, info)``. ``done`` is true only if the
            wrapped env reported done AND ``step_count >= max_steps``.
            ``info["counts"]`` holds the live counts dict (not a copy).
        """
        obs, _env_reward, done, info = self.env.step(action)
        env = self.unwrapped

        # Read the step counters from the unwrapped env explicitly rather
        # than relying on gym.Wrapper's implicit attribute forwarding.
        # An early `done` (e.g. reaching a terminal cell) is discarded.
        done = done and env.step_count >= env.max_steps

        # Index counts by the agent position *after* the update.
        pos = tuple(env.agent_pos)
        new_count = self.counts.get(pos, 0) + 1
        self.counts[pos] = new_count
        info["counts"] = self.counts

        bonus = 1 / math.sqrt(new_count)
        return obs, bonus, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env; visit counts are deliberately kept."""
        return self.env.reset(**kwargs)

class RecordStateVisitation(gym.Wrapper):
    """
    Record how often each agent position is visited, without altering rewards.

    After every step, the agent position (``agent_pos`` of the unwrapped
    env, taken after the step is applied) is counted in ``self.counts``
    (position tuple -> int). The counts are exposed as ``info["counts"]``.
    Observation, reward and ``done`` are all passed through from the
    wrapped env unchanged; no exploration bonus is applied.

    Counts are NOT cleared by ``reset()``, so they accumulate across
    episodes.
    """

    def __init__(self, env):
        super().__init__(env)
        # Position tuple -> number of visits; shared across all episodes.
        self.counts = {}

    def step(self, action):
        """Step the wrapped env and record the agent's new position.

        Returns:
            The wrapped env's ``(obs, reward, done, info)``, with
            ``info["counts"]`` set to the live counts dict (not a copy).
        """
        obs, reward, done, info = self.env.step(action)

        # Index counts by the agent position *after* the update.
        pos = tuple(self.unwrapped.agent_pos)
        self.counts[pos] = self.counts.get(pos, 0) + 1
        info["counts"] = self.counts

        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env; visit counts are deliberately kept."""
        return self.env.reset(**kwargs)