import numpy as np
import gym
import matplotlib.pyplot as plt
from gym.spaces import Box, Dict
from IPython import embed

import sys
# Debug-friendly numpy printing: never summarise arrays with "...", allow very
# long lines, and show floats with 3 significant digits.
np.set_printoptions(
    threshold=sys.maxsize,
    edgeitems=30,
    linewidth=100000,
    formatter={"float": lambda value: "%.3g" % value},
)


class CountBasedExplorationWrapper(gym.Wrapper):
    """Count-based exploration bonus and grid-coverage metrics for a grid env.

    ``self.heatmap`` holds per-cell visit counts, initialised to 1 (meaning
    "never visited"). On every step the count of the agent's cell *before*
    the visit yields an intrinsic reward:

    * default: ``beta / sqrt(count)``;
    * ``salesman=True``: 1 on the first visit to a cell, 0 afterwards.

    The reward returned by ``step`` is the intrinsic reward alone (default),
    task reward + intrinsic reward (``add_true_rew``), or the task reward alone
    (``no_intr_rew``). In the latter two modes any positive task reward ends
    the episode.

    Two coverage maps use the encoding -1 = wall, 0 = not visited,
    1 = visited. ``covered_states`` persists for the wrapper's lifetime and
    ``episodic_covered_states`` is rebuilt on every ``reset``. Both
    percentages are reported in the ``info`` dict returned by ``step``.

    Args:
        env: Environment to wrap. Its ``reset`` must accept ``random_gen`` and
            ``level_string``. In partial/pixel mode it must provide
            ``get_state()`` returning a dict whose ``"Objects"`` entry is a
            list of dicts with ``"Name"`` and ``"Location"`` keys.
        heatmap_shape: Shape of the visit-count and coverage grids, indexed
            as ``[x, y]``.
        beta: Scale of the ``1 / sqrt(count)`` bonus.
        add_true_rew: Add the task reward to the intrinsic reward.
        proc_gen: Generate a random level on every ``reset``.
        episodic: Reset the visit counts on every ``reset``.
        augment: Add a ``"heatmap_and_pos"`` entry (normalised heatmap plus a
            one-hot map of the agent position) to the observation.
        use_cnn: Keep 2-D observations for CNNs (the image only when not
            ``partial_obs``; the heatmap always).
        partial_obs: Observations are partial; the agent position is then read
            from ``env.get_state()``.
        salesman: Use the first-visit bonus instead of ``beta / sqrt(count)``.
        no_intr_rew: Return only the task reward.
        level_string: Level passed to ``env.reset`` when not ``proc_gen``.
    """

    def __init__(
            self,
            env,
            heatmap_shape,
            beta=1.0,
            add_true_rew=False,
            proc_gen=False,
            episodic=False,
            augment=False,
            use_cnn=False,
            partial_obs=False,
            salesman=False,
            no_intr_rew=False,
            level_string=""
        ):
        super().__init__(env)  # also sets self.env

        self.partial_obs = partial_obs
        self.augment = augment
        self.salesman = salesman
        self.use_cnn = use_cnn
        self.episodic = episodic
        self.level_string = level_string
        self.proc_gen = proc_gen
        # last known agent cell; (-1, -1) until the agent is first located
        self.x, self.y = -1, -1

        self.t = 0
        self.max_steps = 1500 if self.episodic == True else 250

        # coverage maps: -1 = wall, 0 = not visited, 1 = visited
        self.covered_states = np.zeros(heatmap_shape)
        self.episodic_covered_states = np.zeros(heatmap_shape)
        obs = self.env.reset(random_gen=False, level_string=level_string)

        # NOTE(review): pixel_obs is never assigned in this class; it presumably
        # resolves to env.pixel_obs via gym.Wrapper attribute forwarding -- confirm.
        if partial_obs == False and self.pixel_obs == False:
            # fully observed grid: cells whose argmax channel is 2 are walls
            obs_ = np.argmax(obs, axis=0)
            self.covered_states[np.where(obs_ == 2)] = -1
            self.episodic_covered_states[np.where(obs_ == 2)] = -1
        else:
            avatar = self._mark_from_state(self.env.get_state())
            if avatar is not None:
                self.x, self.y = avatar

        # visit counts used for the bonus and the augmentation (1 = never visited)
        self.heatmap_shape = heatmap_shape
        self.heatmap = np.ones(heatmap_shape)

        # reward-mixing parameters
        self.add_true_rew = add_true_rew
        self.no_intr_rew = no_intr_rew
        self.beta = beta

        # build observation space
        self.old_obs_space_shape = env.observation_space.shape

        # A CNN is NOT used for partial observations (the image is flattened),
        # but the heatmaps can still be fed to a CNN.
        # NOTE(review): get_obs collapses the channel axis with argmax, so,
        # assuming env observations are (C, H, W), the returned "image" is
        # (H, W) or (H * W,) while the spaces below declare (1, H, W) and
        # (C * H * W,) -- confirm consumers tolerate this.
        if use_cnn == True and partial_obs == False:
            image_space = Box(-np.inf, np.inf, shape = (1, self.old_obs_space_shape[1], self.old_obs_space_shape[2]))
        else:
            image_space = Box(-np.inf, np.inf, shape = (np.array(self.old_obs_space_shape).prod(), ))

        self.observation_space = Dict({"image": image_space})

        if augment == True:
            if use_cnn == True:
                heatmap_space = Box(-np.inf, np.inf, shape = (2, heatmap_shape[0], heatmap_shape[1]))
            else:
                heatmap_space = Box(-np.inf, np.inf, shape = (np.array(heatmap_shape).prod() * 2, ))
            self.observation_space["heatmap_and_pos"] = heatmap_space

        print("======= OBS SPACE =======")
        print(self.observation_space)

    def _mark_from_state(self, state):
        """Mark walls (-1) and the avatar (1) from ``state`` in both coverage maps.

        Returns the avatar's ``(x, y)`` Location (the last one if several are
        listed), or None when ``state`` contains no "avatar" object.
        """
        avatar = None
        for obj in state["Objects"]:
            if obj["Name"] == "wall":
                wx, wy = obj["Location"]
                self.episodic_covered_states[wx, wy] = -1
                self.covered_states[wx, wy] = -1
            elif obj["Name"] == "avatar":
                ax, ay = obj["Location"]
                self.episodic_covered_states[ax, ay] = 1
                self.covered_states[ax, ay] = 1
                # kept separate from the wall coordinates so a later wall
                # cannot overwrite the agent position
                avatar = (ax, ay)
        return avatar

    @staticmethod
    def _find_avatar(state):
        """Return the ``(x, y)`` Location of the first "avatar" object in ``state``, or None."""
        for obj in state["Objects"]:
            if obj["Name"] == "avatar":
                ax, ay = obj["Location"]
                return ax, ay
        return None

    @staticmethod
    def _coverage_percent(covered):
        """Percentage of cells marked 1 among cells marked 0 or 1, rounded to 2 decimals."""
        visited = np.count_nonzero(covered == 1)
        unvisited = np.count_nonzero(covered == 0)
        return np.around(visited / (unvisited + visited) * 100, 2)

    def step(self, action):
        """Step the env, add the count-based bonus and report coverage.

        Returns:
            ``(obs, reward, done, info)``. ``obs`` comes from ``get_obs``.
            ``info`` is the env's info merged with ``coverage``,
            ``episodic_coverage`` (percentages), ``task_reward`` (the env's
            reward) and ``heatmap`` (a copy of the visit counts before this
            step's increment).
        """
        obs, r, d, info = self.env.step(action)

        # when the task reward is part of the objective, reaching the goal
        # (any positive task reward) terminates the episode
        if r > 0 and (self.no_intr_rew == True or self.add_true_rew == True):
            d = True

        # locate the agent
        if self.partial_obs or self.pixel_obs:
            avatar = self._find_avatar(self.env.get_state())
            # keep the last known position if the avatar is not reported
            x, y = avatar if avatar is not None else (self.x, self.y)
        else:
            x, y = np.unravel_index(np.argmax(obs[0], axis=None), obs[0].shape)
        self.x, self.y = x, y

        # intrinsic reward from the count *before* this visit
        if self.salesman == True:
            count_rew = 1 if self.heatmap[x, y] == 1 else 0
        else:
            count_rew = self.beta / np.sqrt(self.heatmap[x, y])

        # coverage
        self.covered_states[x, y] = 1
        self.episodic_covered_states[x, y] = 1
        maze_coverage = {
            "coverage": self._coverage_percent(self.covered_states),
            "episodic_coverage": self._coverage_percent(self.episodic_covered_states),
            "task_reward": r,
            # copy: self.heatmap is incremented in place below and on later
            # steps, which would otherwise silently change this info entry
            "heatmap": self.heatmap.copy(),
        }

        self.heatmap[x, y] += 1

        # mix task and intrinsic reward
        if self.no_intr_rew == True:
            reward = r
        elif self.add_true_rew == True:
            reward = r + count_rew
        else:
            reward = count_rew

        return self.get_obs(obs), reward, d, {**info, **maze_coverage}

    def get_obs(self, obs):
        """Convert a channel-first env observation into the wrapper's dict observation.

        Channels are collapsed into a single map of channel indices with argmax
        over axis 0. The agent cell (argmax of channel 0) is then set to
        ``max + 1`` and the map is divided by 2. With ``augment``, a
        ``"heatmap_and_pos"`` entry is added: the normalised visit heatmap (or
        a visited/not-visited map in salesman mode) plus a one-hot map of
        ``(self.x, self.y)``.
        """
        if self.add_true_rew == False and self.no_intr_rew == False:
            # pure exploration: keep only channels 0 and 2 (player and walls)
            obs = obs[[0, 2]]

        # agent cell = argmax of channel 0 (also works with partial_obs)
        x, y = np.unravel_index(np.argmax(obs[0], axis=None), obs[0].shape)

        # collapse to one channel of channel indices
        obs = np.argmax(obs, axis=0)

        # give the agent a value above every other cell in this view
        # NOTE(review): max + 1 depends on what else is visible (e.g. 1 instead
        # of 2 when no wall is in view) and the fixed "/ 2" below assumes the
        # two-channel case -- confirm this encoding is intended.
        obs[x, y] = np.max(obs) + 1

        if self.use_cnn == True and self.partial_obs == False:
            image = np.array(obs) / 2
        else:
            image = np.array(obs).flatten() / 2

        new_obs = {"image": image}

        if self.augment == True:
            pos_map = np.zeros((self.heatmap_shape))[None, :]
            pos_map[0, self.x, self.y] = 1

            # the salesman reward only depends on visited / not visited
            if self.salesman == True:
                heatmap_ = np.sign(self.heatmap - 1)
            else:
                heatmap_ = self.heatmap / np.max(self.heatmap)

            if self.use_cnn == True:
                # (2, H, W); unlike stack+squeeze this keeps the declared shape
                # even when H or W is 1
                new_obs["heatmap_and_pos"] = np.concatenate([heatmap_[None, :], pos_map])
            else:
                new_obs["heatmap_and_pos"] = np.concatenate([
                    heatmap_.flatten(),
                    pos_map.flatten()
                ])

        return new_obs

    def reset(self):
        """Reset the env and the per-episode coverage, and return the first observation.

        The visit counts are reset only when ``episodic`` is set;
        ``covered_states`` is never cleared.
        """
        if self.proc_gen:
            obs = self.env.reset(random_gen=True)
        else:
            obs = self.env.reset(random_gen=False, level_string=self.level_string)

        # reset for computing episode metrics
        self.episodic_covered_states = np.zeros((self.heatmap_shape))
        if self.partial_obs == True or self.pixel_obs == True:
            avatar = self._mark_from_state(self.env.get_state())
            # keep the last known position if the avatar is not reported
            x, y = avatar if avatar is not None else (self.x, self.y)
        else:
            x, y = np.unravel_index(np.argmax(obs[0], axis=None), obs[0].shape)
            obs_ = np.argmax(obs, axis=0)
            self.episodic_covered_states[np.where(obs_ == 2)] = -1
            self.episodic_covered_states[x, y] = 1
            self.covered_states[x, y] = 1

        # episodic bonus: forget visit counts at each episode
        if self.episodic == True:
            self.heatmap = np.ones(self.heatmap_shape)

        self.x, self.y = x, y
        return self.get_obs(obs)