import numpy as np
import gym
import matplotlib.pyplot as plt
import cv2
import torch
from gym.spaces import Box, Dict
from IPython import embed
from countbased.models.models import NatureCNN

# this wrapper is always for pixel-based envs!!!

class E3BWrapper(gym.Wrapper):
    def __init__(
            self,
            env,
            hidden_dim,
            map_shape=(32,32),
            ridge=0.1,
            add_true_rew=False,
            proc_gen=False,
            episodic=False,
            augment=False,
            level_string="",
            is_biomes=False
        ):
        """Wrap a pixel-based env with an elliptical (E3B-style) exploration bonus.

        Args:
            env: Environment to wrap. It must expose ``get_state()`` returning
                a dict whose ``"Objects"`` entries carry ``"Name"`` and
                ``"Location"`` fields.
            hidden_dim: Size of the encoder embedding; the inverse covariance
                matrix is ``hidden_dim x hidden_dim``.
            map_shape: Shape of the coverage grids (non-biome mode only).
            ridge: Regulariser; the inverse covariance starts at ``I / ridge``.
            add_true_rew: If true, ``step`` returns task reward + bonus,
                otherwise the bonus alone.
            proc_gen: If true, ``reset`` asks the env for a randomly generated
                level instead of ``level_string``.
            episodic: If true, the ellipsoid is re-initialised on every reset.
            augment: If true, observations also carry the inverse covariance
                matrix under the ``"ellipsoid"`` key.
            level_string: Level description handed to ``env.reset`` when
                ``proc_gen`` is false.
            is_biomes: Track visited biomes (player object named ``"player"``)
                instead of grid coverage (player object named ``"avatar"``).
        """
        super().__init__(env)

        self.augment = augment
        self.episodic = episodic
        self.level_string = level_string
        self.proc_gen = proc_gen
        # (-1, -1) means "player position not observed yet".
        self.x, self.y = -1, -1
        self.map_shape = map_shape
        self.is_biomes = is_biomes

        self.player_name = "player" if is_biomes else "avatar"

        self.env = env

        if self.is_biomes:
            # One visited-flag per biome (3 x 3 layout, see get_biome_from_obs).
            self.visited_biomes = [0] * 9
        else:
            # Coverage grids: 0 = unvisited, 1 = visited, -1 = wall.
            # `covered_states` is never cleared by reset();
            # `episodic_covered_states` is rebuilt on every reset().
            self.covered_states = np.zeros(self.map_shape)
            self.episodic_covered_states = np.zeros(self.map_shape)

            state = self.env.get_state()
            for obj in state["Objects"]:
                name = obj["Name"]
                if name == "wall":
                    wall_x, wall_y = obj["Location"]
                    self.episodic_covered_states[wall_x, wall_y] = -1
                    self.covered_states[wall_x, wall_y] = -1
                elif name == "avatar":
                    avatar_x, avatar_y = obj["Location"]
                    self.x, self.y = avatar_x, avatar_y
                    self.episodic_covered_states[avatar_x, avatar_y] = 1
                    self.covered_states[avatar_x, avatar_y] = 1

        # Ellipsoid state: inverse covariance starts as (1 / ridge) * I.
        self.hidden_dim = hidden_dim
        self.ridge = ridge
        self.cov_inverse = torch.eye(hidden_dim) * (1.0 / ridge)
        self.outer_product_buffer = torch.empty(hidden_dim, hidden_dim)

        # Put the encoder on the GPU when one is present, otherwise stay on the
        # CPU instead of failing at construction time on CUDA-less hosts.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.encoder_model = NatureCNN(hidden_dim=hidden_dim).to(device)
        self.encoder_model.eval()

        # params
        self.add_true_rew = add_true_rew

        # Full or partial observability does not matter: frames are always
        # resized to the same 3 x 84 x 84 image.
        spaces = {
            "image": Box(-np.inf, np.inf, shape=(3, 84, 84)),
        }
        # Same `== True` predicate as get_obs(), so the declared space and the
        # emitted observations always agree, even for non-bool flag values.
        if augment == True:
            spaces["ellipsoid"] = Box(-np.inf, np.inf, shape=(1, hidden_dim, hidden_dim))
        # Build the Dict space in one go rather than assigning into it later,
        # which relies on item assignment being supported by the gym version.
        self.observation_space = Dict(spaces)

        print("======= OBS SPACE =======")
        print(self.observation_space)
    
    def step(self, action):
        """Step the env and replace (or augment) its reward with the elliptical bonus.

        Args:
            action: Action forwarded unchanged to the wrapped env.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is the dict built by
            ``get_obs``, ``reward`` is the bonus (plus the task reward when
            ``add_true_rew`` is set) and ``info`` is the env info merged with
            coverage statistics: ``"visited_biomes"`` in biome mode, otherwise
            ``"coverage"``, ``"episodic_coverage"`` (percentages rounded to two
            decimals; cells marked -1 are excluded) and ``"task_reward"``.
        """
        def coverage_percent(grid):
            # Share of visited (1) cells among visited + unvisited (0) cells.
            visited = np.count_nonzero(grid == 1)
            unvisited = np.count_nonzero(grid == 0)
            return np.around(visited / (visited + unvisited) * 100, 2)

        # step environment
        obs, r, d, info = self.env.step(action)
        obs = self.preprocess_obs(obs)

        # Get the new agent position. Start from the last known one so a state
        # without a player object no longer leaves x / y unbound.
        x, y = self.x, self.y
        state = self.env.get_state()
        for obj in state["Objects"]:
            if self.player_name in obj["Name"]:
                x, y = obj["Location"]
                break
        self.x, self.y = x, y

        # compute intrinsic reward
        with torch.no_grad():
            # Run the encoder on whatever device it lives on instead of
            # assuming CUDA; fall back to the old default if it has no params.
            first_param = next(self.encoder_model.parameters(), None)
            if first_param is not None:
                device = first_param.device
            else:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

            frame = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            # reshape(-1) always yields a 1-D embedding; a bare squeeze() would
            # collapse it to 0-d when hidden_dim == 1 and break torch.mv.
            h = self.encoder_model(frame).reshape(-1).cpu()
            u = torch.mv(self.cov_inverse, h)

            # Bonus: h^T C^{-1} h.
            count_rew = torch.dot(h, u).item()

            # Sherman-Morrison rank-one update of the inverse covariance:
            # C^{-1} <- C^{-1} - (u u^T) / (1 + h^T C^{-1} h).
            self.outer_product_buffer = torch.outer(u, u)
            self.cov_inverse = torch.add(
                self.cov_inverse,
                self.outer_product_buffer,
                alpha=-(1. / (1. + count_rew)),
            )

        if self.is_biomes:
            biome_idx = self.get_biome_from_obs((x, y))
            # Positions outside every biome come back as -1 or as a string;
            # -1 would silently flag the last biome and a string would raise,
            # so only genuine indices are recorded.
            if isinstance(biome_idx, int) and 0 <= biome_idx < len(self.visited_biomes):
                self.visited_biomes[biome_idx] = 1
            coverage_info = {"visited_biomes": sum(self.visited_biomes)}
        else:
            # compute coverage
            self.covered_states[x, y] = 1
            self.episodic_covered_states[x, y] = 1
            coverage_info = {
                "coverage": coverage_percent(self.covered_states),
                "episodic_coverage": coverage_percent(self.episodic_covered_states),
            }
            # Logged before `r` is overwritten by the bonus below.
            coverage_info["task_reward"] = r

        # add true reward
        if self.add_true_rew == True:
            r = r + count_rew
        else:
            r = count_rew

        # compute obs
        obs = self.get_obs(obs)
        return obs, r, d, {**info, **coverage_info}

    def preprocess_obs(self, obs):
        """Resize a raw frame to 84 x 84 and scale its values by 1 / 255.

        The frame is presumably channel-first (C, H, W) -- TODO confirm against
        the env. It is moved to channel-last for ``cv2.resize`` and moved back
        afterwards, so the result has the same axis order as the input.
        """
        channel_last = obs.transpose(1, 2, 0)
        resized = cv2.resize(channel_last, (84, 84), interpolation=cv2.INTER_AREA)
        scaled = resized / 255.0
        return scaled.transpose(2, 0, 1)

    def get_obs(self, obs):
        obs = {"image" : obs}
                
        if self.augment == True:
            obs["ellipsoid"] = self.cov_inverse.unsqueeze(0)
            
        return obs

    def reset(self):
        """Reset the wrapped env and the per-episode bookkeeping.

        In biome mode the env is reset without arguments; otherwise it is reset
        with ``random_gen=True`` when ``proc_gen`` is set, or with the stored
        ``level_string``. Episodic coverage (or the visited-biome flags) is
        rebuilt from the fresh env state, and the ellipsoid is re-initialised
        when ``episodic`` is set.

        Returns:
            The first observation, as built by ``get_obs``.
        """
        # reset env
        if self.is_biomes:
            obs = self.env.reset()
        elif self.proc_gen:
            obs = self.env.reset(random_gen=True)
        else:
            obs = self.env.reset(random_gen=False, level_string=self.level_string)

        # Player position: only the player/avatar object may update it. Wall
        # coordinates use their own names so they can never leak into
        # self.x / self.y, and a state without a player keeps the last value.
        x, y = self.x, self.y

        # reset for computing episode metrics
        if self.is_biomes:
            self.visited_biomes = [0] * 9
            state = self.env.get_state()
            for obj in state["Objects"]:
                if self.player_name in obj["Name"]:
                    x, y = obj["Location"]
                    biome_idx = self.get_biome_from_obs((x, y))
                    # Out-of-biome positions come back as -1 or as a string;
                    # neither is a valid index into the visited flags.
                    if isinstance(biome_idx, int) and 0 <= biome_idx < len(self.visited_biomes):
                        self.visited_biomes[biome_idx] = 1
                    break
        else:
            self.episodic_covered_states = np.zeros(self.map_shape)
            state = self.env.get_state()
            for obj in state["Objects"]:
                name = obj["Name"]
                if name == "wall":
                    wall_x, wall_y = obj["Location"]
                    self.episodic_covered_states[wall_x, wall_y] = -1
                    self.covered_states[wall_x, wall_y] = -1
                elif name == "avatar":
                    x, y = obj["Location"]
                    self.episodic_covered_states[x, y] = 1
                    self.covered_states[x, y] = 1

        # if episodic bonus reset ellipsoid at each episode
        if self.episodic == True:
            self.cov_inverse = torch.eye(self.hidden_dim) * (1.0 / self.ridge)
            self.outer_product_buffer = torch.empty(self.hidden_dim, self.hidden_dim)

        self.x, self.y = x, y
        obs = self.preprocess_obs(obs)
        obs = self.get_obs(obs)
        return obs
    
    def load_elliptical_encoder(self, state_dict):
        self.encoder_model.load_state_dict(state_dict)
        self.encoder_model.eval()

    def get_biome_from_obs(self, pos):
        """Map an ``(x, y)`` position to its biome index in a 3 x 3 layout.

        Each axis is split into the inclusive bands 4-12, 13-25 and 26-34.
        The index is ``3 * x_band + y_band``, i.e. 0..8.

        Returns:
            The biome index; ``-1`` when ``x`` falls in a band but ``y`` does
            not; the string ``"incorrect x and y"`` when ``x`` falls in no band.
        """
        bands = ((4, 12), (13, 25), (26, 34))

        def band_of(value):
            # Position of the inclusive band containing `value`, or None.
            for band_idx, (low, high) in enumerate(bands):
                if low <= value <= high:
                    return band_idx
            return None

        x, y = pos[0], pos[1]

        x_band = band_of(x)
        if x_band is None:
            return "incorrect x and y"

        y_band = band_of(y)
        if y_band is None:
            return -1

        return 3 * x_band + y_band