import gym
import d4rl
import cv2
import numpy as np
import joblib
from copy import deepcopy

from gym.spaces import Box
from .multigoal_maze_map import BIG_MAZE_MULTIGOAL, HARDEST_MAZE_MULTIGOAL


class D4rlEnvWrapper:
    """Adapter around a ``gym.make``-constructed D4RL environment.

    The offline dataset is fetched once, at construction time, and cached on
    ``self.dataset``. ``get_dataset`` hands out deep copies so callers cannot
    mutate the cached version. Spaces, ``reset``, ``step`` and ``render`` are
    forwarded unchanged to the wrapped env.
    """

    def __init__(self, env_name):
        wrapped = gym.make(env_name)
        self.env = wrapped
        self.dataset = wrapped.get_dataset()

    @property
    def observation_space(self):
        """Observation space of the wrapped environment."""
        wrapped = self.env
        return wrapped.observation_space

    @property
    def action_space(self):
        """Action space of the wrapped environment."""
        wrapped = self.env
        return wrapped.action_space

    def reset(self):
        """Reset the wrapped env and return whatever it returns."""
        wrapped = self.env
        return wrapped.reset()

    def step(self, action):
        """Forward a single ``action`` to the wrapped env."""
        wrapped = self.env
        return wrapped.step(action)

    def render(self, mode='human'):
        """Render via the wrapped env (mode passed positionally) and return its result."""
        wrapped = self.env
        return wrapped.render(mode)

    def get_dataset(self):
        """Return an independent deep copy of the cached offline dataset."""
        cached = self.dataset
        return deepcopy(cached)


class Maze2d(D4rlEnvWrapper):
    """D4RL maze2d wrapper with action repeat and optional random-action noise.

    Constructor arguments:
        env_name: gym id passed to ``gym.make``.
        action_repeat: number of consecutive wrapped ``step`` calls per ``step``;
            must be >= 1.
        p: (keyword, default 0.) probability that ``step`` replaces the given
            action with a uniformly random one in [-1, 1].
        data_pth: (keyword) path to a joblib file holding per-noise-level
            datasets keyed ``f'p{p:.1f}'``; replaces the stock D4RL dataset.

    Raises:
        ValueError: if ``action_repeat`` < 1.
    """

    def __init__(self, env_name, action_repeat=3, **kwargs):
        # Validate before the (expensive) env/dataset construction: with zero
        # repeats ``step`` would never bind obs/rew/done/info.
        if action_repeat < 1:
            raise ValueError(f'action_repeat must be >= 1, got {action_repeat}')
        super().__init__(env_name)
        self.p = kwargs.get('p', 0.)
        self.action_repeat = action_repeat
        if 'data_pth' in kwargs:
            # NOTE: joblib.load unpickles the file; only load trusted paths.
            self.dataset = joblib.load(kwargs['data_pth'])[f'p{self.p:.1f}']
        elif self.p != 0:
            print('Warning: using non-stochastic dataset for stochastic environment!')
            print('Please specify data path (data_pth=\'path/to/dataset\')')

    def step(self, action):
        """Apply ``action`` up to ``action_repeat`` times; return the last transition.

        With probability ``self.p`` the action is first replaced by a uniform
        sample in [-1, 1] drawn from the global NumPy RNG. The returned reward
        is that of the final executed sub-step only (not a sum). Repetition
        stops early once the wrapped env reports ``done``.
        """
        if np.random.rand() < self.p:
            action = np.random.uniform(-1, 1, size=self.action_space.shape)
        for _ in range(self.action_repeat):
            obs, rew, done, info = self.env.step(action)
            if done:
                # Do not keep stepping an episode the env has already ended.
                break
        return obs, rew, done, info

    def sample_target(self):
        """Sample a target location near a random empty/goal cell of the maze.

        Draws from the wrapped env's ``np_random``. Returns the chosen cell's
        location (cast to the observation dtype) plus uniform jitter in
        [-0.1, 0.1) on each of the ``model.nq`` coordinates.
        """
        idx = self.env.np_random.choice(len(self.env.empty_and_goal_locations))
        reset_location = np.array(self.env.empty_and_goal_locations[idx]).astype(self.observation_space.dtype)
        target_location = reset_location + self.env.np_random.uniform(low=-.1, high=.1, size=self.env.model.nq)
        return target_location

    def get_goal_state(self):
        """Return an observation-sized float32 vector: target xy first, zeros elsewhere."""
        goal = np.zeros(self.observation_space.shape[0], dtype=np.float32)
        goal[:2] = self.env.get_target()
        return goal


class Antmaze(D4rlEnvWrapper):
    """D4RL antmaze wrapper that can swap in a multi-goal maze layout.

    Pass ``multigoal=True`` to replace the maze map via ``env.set_maze_map``:
    ``BIG_MAZE_MULTIGOAL`` for env names containing 'medium',
    ``HARDEST_MAZE_MULTIGOAL`` for names containing 'large'.

    Raises:
        NotImplementedError: if ``multigoal`` is requested for any other env name.
    """

    def __init__(self, env_name, **kwargs):
        # Resolve the requested layout before building the env so an
        # unsupported name fails before the dataset is loaded.
        maze_map = None
        if kwargs.get('multigoal', False):
            if 'medium' in env_name:
                maze_map = BIG_MAZE_MULTIGOAL
            elif 'large' in env_name:
                maze_map = HARDEST_MAZE_MULTIGOAL
            else:
                raise NotImplementedError(
                    f"multigoal maze map is only available for 'medium' or 'large' "
                    f"antmaze environments, got {env_name!r}")
        super().__init__(env_name)
        if maze_map is not None:
            self.env.set_maze_map(maze_map)

    def get_goal_state(self):
        """Return an observation-sized float32 vector: ``target_goal`` xy first, zeros elsewhere."""
        goal = np.zeros(self.observation_space.shape[0], dtype=np.float32)
        goal[:2] = self.env.target_goal
        return goal


class Kitchen(D4rlEnvWrapper):
    """D4RL kitchen wrapper exposing truncated observations and [-1, 1] actions.

    Only the first ``OBS_DIM`` entries of each raw observation are kept as the
    observation. The remaining entries of the *first* dataset observation are
    stored as the goal. Dataset actions are min-max rescaled per dimension to
    [-1, 1], and ``step`` maps agent actions back to the raw range before
    forwarding them. ``timeouts`` in the dataset mirrors ``terminals``.
    """

    OBS_DIM = 30  # leading raw-observation entries kept as the observation

    def __init__(self, env_name, **kwargs):
        super().__init__(env_name)
        observations = self.dataset['observations']
        # Own copy so callers/later edits cannot alias the dataset buffer.
        self.goal = np.copy(observations[0, self.OBS_DIM:])
        self.dataset['observations'] = observations[:, :self.OBS_DIM]
        # Independent array: aliasing 'terminals' would make in-place edits to one
        # flag array silently change the other (deepcopy in get_dataset keeps aliases).
        self.dataset['timeouts'] = np.copy(self.dataset['terminals'])
        self.max_action = np.max(self.dataset['actions'], axis=0)
        self.min_action = np.min(self.dataset['actions'], axis=0)
        self.dataset['actions'] = self.normalize_action(self.dataset['actions'])

    @property
    def observation_space(self):
        """Box over the first ``OBS_DIM`` dims of the wrapped observation space (same dtype)."""
        wrapped_space = self.env.observation_space
        return Box(
            low=wrapped_space.low[:self.OBS_DIM],
            high=wrapped_space.high[:self.OBS_DIM],
            dtype=wrapped_space.dtype,
        )

    def reset(self):
        """Reset the wrapped env and return the truncated observation."""
        obs = self.env.reset()
        return obs[:self.OBS_DIM]

    def step(self, action):
        """Un-normalize ``action`` from [-1, 1], step the env, and truncate the observation."""
        action = self.unnormalize_action(action)
        obs, rew, done, info = self.env.step(action)
        return obs[:self.OBS_DIM], rew, done, info

    def get_goal_state(self):
        """Return a copy of the goal vector taken from the first dataset observation."""
        return self.goal.copy()

    def _action_span(self):
        """Per-dimension action range, with zero spans replaced by 1 to avoid division by zero."""
        span = self.max_action - self.min_action
        return np.where(span > 0, span, 1.)

    def normalize_action(self, action):
        """Map raw actions (single or batched along axis 0) to [-1, 1] per dimension.

        Dimensions that are constant in the dataset map to -1 instead of NaN/inf.
        """
        return 2. * (action - self.min_action) / self._action_span() - 1.

    def unnormalize_action(self, action):
        """Map actions in [-1, 1] back to the raw range seen in the dataset."""
        return (self.max_action - self.min_action) * (action + 1.) / 2. + self.min_action


class GoalReachingD4rlEnv:
    """Goal-conditioned front-end over the antmaze / kitchen / maze2d wrappers.

    The wrapper class is chosen from the prefix of ``env_name`` before the
    first '-'. Extra ``kwargs`` go to that wrapper. With ``normalize=True``,
    observations (dataset, reset, step) and goals are min-max scaled to
    [-1, 1] per dimension using the dataset's observation min/max.

    Raises:
        NotImplementedError: for an unsupported domain prefix.
    """

    def __init__(self, env_name, normalize=True, **kwargs):
        self.domain = env_name.split('-')[0]
        wrappers = {'antmaze': Antmaze, 'kitchen': Kitchen, 'maze2d': Maze2d}
        if self.domain not in wrappers:
            raise NotImplementedError(
                f'unsupported domain {self.domain!r} (from env_name {env_name!r})')
        self.env = wrappers[self.domain](env_name, **kwargs)
        self.dataset = self.env.get_dataset()
        self.normalize = normalize
        if normalize:
            self.obs_max = np.max(self.dataset['observations'], axis=0)
            self.obs_min = np.min(self.dataset['observations'], axis=0)
            self.dataset['observations'] = self.normalize_obs(self.dataset['observations'])

    def reset(self):
        """Reset the env; return ``(obs, goal)``, both normalized when enabled."""
        obs = self.env.reset()
        goal = self.env.get_goal_state()
        if self.normalize:
            obs = self.normalize_obs(obs)
            goal = self.normalize_obs(goal)
            # Normalization moves zero entries away from 0; re-zero the dims
            # that are not part of the goal specification for these domains.
            # NOTE(review): maze2d goals are not re-zeroed here — confirm intended.
            if self.domain == 'antmaze':
                goal[2:] = 0.
            elif self.domain == 'kitchen':
                goal[:9] = 0.
        return obs, goal

    def step(self, action):
        """Step the wrapped env; the observation is normalized when enabled."""
        obs, rew, done, info = self.env.step(action)
        if self.normalize:
            obs = self.normalize_obs(obs)
        return obs, rew, done, info

    def normalize_obs(self, obs):
        """Min-max scale ``obs`` (single or batched along axis 0) to [-1, 1] per dimension.

        Dimensions that are constant in the dataset (zero range) use a unit
        scale instead of dividing by zero, so the result stays finite.
        """
        span = self.obs_max - self.obs_min
        span = np.where(span > 0, span, 1.)
        return 2. * (obs - self.obs_min) / span - 1.

    def render(self, mode='human'):
        """Render via the wrapped env and return its result (e.g. a frame for 'rgb_array')."""
        return self.env.render(mode=mode)

    def get_dataset(self):
        """Return an independent deep copy of the (possibly normalized) dataset."""
        return deepcopy(self.dataset)
