import numpy as np
import torch
from itertools import count
from datetime import datetime
from collections import deque
import os
import gym
import pybulletgym
import pointMass
import gym_minigrid
from gym_minigrid.wrappers import *
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
# from stable_baselines3.bench import Monitor
# from stable_baselines3.common.logger import Logger
# from stable_baselines3.common.monitor import Monitor
import re
import glob


def format_name_string(name_string):
    """Turn ``str(spec_dict)`` into a compact token usable inside file names.

    ``{`` becomes ``_``. The following are all removed: ``}``, spaces, the
    literal ``'xml_file'``, single quotes, colons and slashes.
    """
    # First pass: '{' -> '_', and drop '}' and spaces.
    without_braces = name_string.translate(str.maketrans({'{': '_', '}': None, ' ': None}))
    # The quoted key has to go before quotes are stripped below.
    without_key = without_braces.replace("'xml_file'", '')
    # Second pass: drop quotes, colons and slashes.
    return without_key.translate(str.maketrans('', '', "':/"))


def get_env_demo_files(expert_demo_dir, env_name, spec):
    """Return demo file names in ``expert_demo_dir`` that belong to this env/spec.

    A file matches when its name contains ``env_name + format_name_string(str(spec))``.
    Matches are returned in natural order: runs of digits compare as integers, so
    ``demo9`` sorts before ``demo10``.
    """
    entries = os.listdir(expert_demo_dir)
    token = env_name + format_name_string(str(spec))

    def _natural_key(file_name):
        # re.split with a capture group alternates non-digit / digit chunks, so
        # ints and strs always line up at the same positions when compared.
        return [int(chunk) if chunk.isdigit() else chunk
                for chunk in re.split(r'(\d+)', file_name)]

    return sorted((entry for entry in entries if token in entry), key=_natural_key)


def make_venv(opt, n_envs, spec, spec_test, wrapper_kwargs, use_rank=True, use_subprocess=False):
    """
    Build the training VecEnv and the evaluation env for ``opt.env_name``.

    The builder is chosen by substring match on ``opt.env_name``:
      * 'Custom'    -> make_pybullet_venv; the test env is a 1-env VecEnv
      * 'MiniGrid'  -> make_minigrid_venv; the test env is a 1-env VecEnv
      * 'pointMass' -> make_pointmass_venv; the test env is a single pmObsWrapper env
      * otherwise   -> gym.make envs in a Dummy/SubprocVecEnv; the test env is a
                       single plain gym env

    The test env is never given ``wrapper_kwargs``.

    :param opt: options object; reads ``env_name``, ``seed`` and, for MiniGrid only,
        ``minigrid_wrapper``
    :param n_envs: (int) number of training environments
    :param spec: (dict or None) ``gym.make`` kwargs for the training envs
    :param spec_test: (dict or None) ``gym.make`` kwargs for the test env
    :param wrapper_kwargs: (dict or None) kwargs for ``apply_wrappers`` on training envs
    :param use_rank: (bool) offset each training env's seed by its index
    :param use_subprocess: (bool) use SubprocVecEnv instead of DummyVecEnv
    :return: (tuple) ``(envs, testing_env)``
    """
    if spec is None:
        spec = {}
    if spec_test is None:
        spec_test = {}
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    def _seed_env(env, env_seed):
        # Seed the env and both of its spaces so sampling is reproducible.
        env.seed(env_seed)
        env.action_space.seed(env_seed)
        env.observation_space.seed(env_seed)

    print("-"*100)

    env_name = opt.env_name
    if 'Custom' in env_name:
        envs = make_pybullet_venv(env_name, env_kwargs=spec, wrapper_kwargs=wrapper_kwargs,
                                  n_envs=n_envs, seed=opt.seed, use_subprocess=use_subprocess,
                                  use_rank=use_rank)
        testing_env = make_pybullet_venv(env_name, env_kwargs=spec_test,
                                         n_envs=1, seed=opt.seed, use_subprocess=use_subprocess,
                                         use_rank=use_rank)
    elif 'MiniGrid' in env_name:
        envs = make_minigrid_venv(env_name, env_kwargs=spec, wrapper_kwargs=wrapper_kwargs,
                                  n_envs=n_envs, seed=opt.seed,
                                  wrapper_type=opt.minigrid_wrapper,
                                  use_subprocess=use_subprocess,
                                  use_rank=use_rank)
        testing_env = make_minigrid_venv(env_name, env_kwargs=spec_test,
                                         n_envs=1, seed=opt.seed, use_subprocess=use_subprocess,
                                         use_rank=use_rank,
                                         wrapper_type=opt.minigrid_wrapper)
    elif 'pointMass' in env_name:
        envs = make_pointmass_venv(env_name, env_kwargs=spec, wrapper_kwargs=wrapper_kwargs,
                                   n_envs=n_envs, seed=opt.seed,
                                   use_subprocess=use_subprocess,
                                   use_rank=use_rank)
        # The point-mass test env is a single, non-vectorised env with flattened
        # observations, seeded like the test env of the generic branch.
        testing_env = pmObsWrapper(gym.make(env_name, **spec_test))
        _seed_env(testing_env, opt.seed)
    else:
        # Not pybullet, minigrid or pointmass: plain gym environments.
        def make_env(rank):
            def _thunk():
                env = gym.make(env_name, **spec)
                _seed_env(env, opt.seed + rank if use_rank else opt.seed)
                return apply_wrappers(env, **wrapper_kwargs)

            return _thunk

        thunks = [make_env(i) for i in range(n_envs)]
        envs = SubprocVecEnv(thunks) if use_subprocess else DummyVecEnv(thunks)
        testing_env = gym.make(env_name, **spec_test)
        _seed_env(testing_env, opt.seed)

    print("-"*100)

    return envs, testing_env


# Apply the custom wrappers to each environment *before* vectorizing it.
# Adapted from stable-baselines.
def make_pybullet_venv(env_id, n_envs, seed, env_kwargs=None,
                       wrapper_kwargs=None,
                       allow_early_resets=True,
                       start_method=None,
                       use_rank=True,
                       use_subprocess=False):
    """
    Create a wrapped VecEnv for (custom) PyBullet environments.

    Each sub-environment is created with ``gym.make(env_id, **env_kwargs)`` and
    seeded. It is then wrapped with ``apply_wrappers(env, **wrapper_kwargs)``
    before being vectorised.

    :param env_id: (str) the environment ID passed to ``gym.make``
    :param n_envs: (int) the number of environments to create
    :param seed: (int) the base seed. Env ``i`` gets ``seed + i`` when ``use_rank``
        is True, otherwise every env gets ``seed``.
    :param env_kwargs: (dict) keyword arguments for ``gym.make``; None means none
    :param wrapper_kwargs: (dict) keyword arguments for ``apply_wrappers``; None means none
    :param allow_early_resets: (bool) currently unused; kept for API compatibility
    :param start_method: (str) method used to start the subprocesses.
        See SubprocVecEnv doc for more information
    :param use_rank: (bool) whether to offset each env's seed by its index
    :param use_subprocess: (bool) whether to use ``SubprocVecEnv`` or ``DummyVecEnv``
        when ``n_envs`` > 1. ``DummyVecEnv`` is usually faster. Default: False
    :return: (VecEnv) the vectorised environment
    """
    # ``gym.make(env_id, **None)`` would raise TypeError, so normalise the defaults.
    if env_kwargs is None:
        env_kwargs = {}
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    print(">>> Making environments with parameters: ", env_kwargs)

    def make_env(rank):
        def _thunk():
            env = gym.make(env_id, **env_kwargs)
            seedr = seed + rank if use_rank else seed
            env.seed(seedr)
            env.action_space.seed(seedr)
            env.observation_space.seed(seedr)
            return apply_wrappers(env, **wrapper_kwargs)

        return _thunk

    # When using one environment, no need to start subprocesses
    if n_envs == 1 or not use_subprocess:
        return DummyVecEnv([make_env(i) for i in range(n_envs)])

    return SubprocVecEnv([make_env(i) for i in range(n_envs)],
                         start_method=start_method)


def make_minigrid_venv(env_id, n_envs, seed, env_kwargs=None,
                       wrapper_type='flat',
                       wrapper_kwargs=None,
                       allow_early_resets=True,
                       start_method=None,
                       use_rank=True,
                       use_subprocess=False):
    """
    Create a wrapped VecEnv for MiniGrid.

    Each sub-environment goes through these steps:
      1. It is created with ``gym.make(env_id, **env_kwargs)`` and seeded.
      2. It is wrapped with ``FlatObsWrapper`` when ``wrapper_type == 'flat'``,
         otherwise with ``ImgObsWrapper``.
      3. It is wrapped by ``apply_wrappers``, so any reward or feature-stack
         wrappers are outermost and see the processed observations.

    :param env_id: (str) the environment ID passed to ``gym.make``
    :param n_envs: (int) the number of environments to create
    :param seed: (int) the base seed. Env ``i`` gets ``seed + i`` when ``use_rank``
        is True, otherwise every env gets ``seed``.
    :param env_kwargs: (dict) keyword arguments for ``gym.make``; None means none
    :param wrapper_type: (str) 'flat' for FlatObsWrapper, anything else for ImgObsWrapper
    :param wrapper_kwargs: (dict) keyword arguments for ``apply_wrappers``. Keys other
        than reward_fn / stack_size / disc / use_actions are also forwarded to the
        observation wrapper's constructor.
    :param allow_early_resets: (bool) currently unused; kept for API compatibility
    :param start_method: (str) method used to start the subprocesses.
        See SubprocVecEnv doc for more information
    :param use_rank: (bool) whether to offset each env's seed by its index
    :param use_subprocess: (bool) whether to use ``SubprocVecEnv`` or ``DummyVecEnv``
        when ``n_envs`` > 1. ``DummyVecEnv`` is usually faster. Default: False
    :return: (VecEnv) the vectorised environment
    """
    # ``gym.make(env_id, **None)`` would raise TypeError, so normalise the defaults.
    if env_kwargs is None:
        env_kwargs = {}
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    # Options meant for apply_wrappers must not reach the observation wrapper's
    # constructor; forward only the remaining keys there.
    reward_keys = ('reward_fn', 'stack_size', 'disc', 'use_actions')
    obs_wrapper_kwargs = {k: v for k, v in wrapper_kwargs.items() if k not in reward_keys}

    print(">>> Making environments with parameters: ", env_kwargs)

    def make_env(rank):
        def _thunk():
            env = gym.make(env_id, **env_kwargs)
            seedr = seed + rank if use_rank else seed
            env.seed(seedr)
            env.action_space.seed(seedr)
            env.observation_space.seed(seedr)
            # Convert the observation first so reward wrappers (added below) see
            # the same array observations the agent does, not the raw MiniGrid obs.
            if wrapper_type == 'flat':
                env = FlatObsWrapper(env, **obs_wrapper_kwargs)
            else:
                env = ImgObsWrapper(env, **obs_wrapper_kwargs)
            return apply_wrappers(env, **wrapper_kwargs)

        return _thunk

    # When using one environment, no need to start subprocesses
    if n_envs == 1 or not use_subprocess:
        return DummyVecEnv([make_env(i) for i in range(n_envs)])

    return SubprocVecEnv([make_env(i) for i in range(n_envs)],
                         start_method=start_method)


def make_pointmass_venv(env_id, n_envs, seed, env_kwargs=None,
                       wrapper_kwargs=None,
                       allow_early_resets=True,
                       start_method=None,
                       use_rank=True,
                       use_subprocess=False):
    """
    Create a wrapped VecEnv for the point-mass environments.

    Each sub-environment goes through these steps:
      1. It is created with ``gym.make(env_id, **env_kwargs)`` and seeded.
      2. It is wrapped in ``pmObsWrapper``, which flattens the dict observation
         and returns a distance-to-goal shaped reward.
      3. It is wrapped by ``apply_wrappers``, so any reward wrappers are outermost:
         they see flat observations, and their reward is the one returned.

    :param env_id: (str) the environment ID passed to ``gym.make``
    :param n_envs: (int) the number of environments to create
    :param seed: (int) the base seed. Env ``i`` gets ``seed + i`` when ``use_rank``
        is True, otherwise every env gets ``seed``.
    :param env_kwargs: (dict) keyword arguments for ``gym.make``; None means none
    :param wrapper_kwargs: (dict) keyword arguments for ``apply_wrappers``; None means none
    :param allow_early_resets: (bool) currently unused; kept for API compatibility
    :param start_method: (str) method used to start the subprocesses.
        See SubprocVecEnv doc for more information
    :param use_rank: (bool) whether to offset each env's seed by its index
    :param use_subprocess: (bool) whether to use ``SubprocVecEnv`` or ``DummyVecEnv``
        when ``n_envs`` > 1. ``DummyVecEnv`` is usually faster. Default: False
    :return: (VecEnv) the vectorised environment
    """
    # ``gym.make(env_id, **None)`` would raise TypeError, so normalise the defaults.
    if env_kwargs is None:
        env_kwargs = {}
    if wrapper_kwargs is None:
        wrapper_kwargs = {}

    print(">>> Making environments with parameters: ", env_kwargs)

    def make_env(rank):
        def _thunk():
            env = gym.make(env_id, **env_kwargs)
            seedr = seed + rank if use_rank else seed
            env.seed(seedr)
            env.action_space.seed(seedr)
            env.observation_space.seed(seedr)
            # pmObsWrapper.__init__ takes only ``env``. It goes innermost so the
            # reward wrappers operate on the flattened observation.
            env = pmObsWrapper(env)
            return apply_wrappers(env, **wrapper_kwargs)

        return _thunk

    # When using one environment, no need to start subprocesses
    if n_envs == 1 or not use_subprocess:
        return DummyVecEnv([make_env(i) for i in range(n_envs)])

    return SubprocVecEnv([make_env(i) for i in range(n_envs)],
                         start_method=start_method)


def apply_wrappers(env, reward_fn=None, stack_size=1, disc=None, **kwargs):
    """Optionally wrap ``env`` with feature stacking and reward-replacement wrappers.

    Wrappers are applied innermost-first in this order:
    FeatureStack (if ``stack_size > 1``), then CustomReward (if ``reward_fn`` is
    given), then DiscReward (if ``disc`` is given). With all defaults, ``env`` is
    returned unchanged.

    :param env: the environment to wrap
    :param reward_fn: callable mapping a tensor to a reward tensor, or None
    :param stack_size: (int) number of feature vectors to stack; <= 1 disables stacking
    :param disc: discriminator exposing ``get_reward(obs, action)``, or None
    :param kwargs: extra options. Only ``use_actions`` is read; it is passed to the
        reward wrappers as a bool. Other keys are ignored.
    :return: the (possibly) wrapped environment
    """
    # Read the flag explicitly. Passing the kwargs dict itself would make
    # ``use_actions`` truthy whenever any unrelated extra key was present.
    use_actions = bool(kwargs.get('use_actions', False))

    if stack_size > 1:
        env = FeatureStack(env, stack_size)
    if reward_fn is not None:
        env = CustomReward(env, reward_fn, use_actions)
    if disc is not None:
        env = DiscReward(env, disc, use_actions)

    return env


def repack_vecenv(vecenv, disc, use_subprocess=False):
    """Build a new VecEnv from the sub-environments of ``vecenv``, each wrapped by
    ``apply_wrappers(env, disc=disc)`` (i.e. a DiscReward wrapper).

    ``vecenv`` must expose its sub-environments as ``vecenv.envs``. The existing env
    objects are reused, not re-created.

    :param vecenv: vectorised env exposing an ``envs`` list
    :param disc: discriminator handed to ``apply_wrappers``
    :param use_subprocess: (bool) build a SubprocVecEnv instead of a DummyVecEnv
    :return: the new VecEnv
    """
    def _thunk_for(base_env):
        # Binding base_env as an argument gives each thunk its own env.
        return lambda: apply_wrappers(base_env, disc=disc)

    thunks = [_thunk_for(base_env) for base_env in vecenv.envs]
    vec_env_cls = SubprocVecEnv if use_subprocess else DummyVecEnv
    return vec_env_cls(thunks)


## Gym wrapper classes
class FeatureStack(gym.Wrapper):
    """Replace ``info["features"]`` with the concatenation of the last
    ``stack_size`` feature vectors (oldest first, along axis 0).

    The wrapped env must provide ``get_current_features()``, which is used to
    pre-fill the window on reset. Each ``step`` must also return
    ``info["features"]``. Observations and rewards pass through unchanged.
    """

    def __init__(self, env, stack_size=2):
        super().__init__(env=env)

        self.stack_size = stack_size
        # Bounded window: appending past capacity drops the oldest entry.
        self.stack = deque(maxlen=stack_size)

    def reset(self):
        observation = self.env.reset()

        # Fill the whole window with the initial features.
        initial_features = self.env.get_current_features()
        self.stack.extend([initial_features] * self.stack_size)

        return observation

    def step(self, action):
        observation, reward, done, info = self.env.step(action)

        self.stack.append(info["features"])
        info["features"] = np.concatenate(tuple(self.stack), axis=0)

        return observation, reward, done, info


class CustomReward(gym.Wrapper):
    """Replace the environment reward with the output of ``reward_fn``.

    ``reward_fn`` is evaluated under ``torch.no_grad()`` on the observation the
    action was taken from. When ``use_actions`` is truthy, that observation is first
    concatenated with the action along the last axis. The environment's own reward
    is kept in ``info['gt_reward']``. The returned reward is the ``reward_fn`` output
    with a leading axis of size 1 added, converted to a numpy array.
    """

    def __init__(self, env, reward_fn, use_actions=False):
        super().__init__(env=env)
        self.use_actions = use_actions
        self.reward_fn = reward_fn
        # Observation the next action will be taken from; refreshed on every reset.
        self.obs = None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Remember the initial state so the first step of each episode scores
        # s_0 rather than a stale observation from the previous episode.
        self.obs = obs
        return obs

    def step(self, action):
        # The reward is computed on the state *before* applying the action.
        next_obs, gt_reward, done, info = self.env.step(action)
        # Fall back to next_obs only if step() runs without a reset through this wrapper.
        obs = next_obs if self.obs is None else self.obs
        info['gt_reward'] = gt_reward

        dtype = torch.get_default_dtype()
        with torch.no_grad():
            obs_t = torch.tensor(obs, dtype=dtype)
            if self.use_actions:
                inputs = torch.cat([obs_t, torch.tensor(action, dtype=dtype)], dim=-1)
            else:
                inputs = obs_t
            reward = self.reward_fn(inputs)
            reward = reward.unsqueeze(0).cpu().numpy()

        self.obs = next_obs

        return next_obs, reward, done, info


class DiscReward(gym.Wrapper):
    """Replace the environment reward with ``discriminator.get_reward(obs, action)``.

    ``obs`` is the observation the action was taken from. The environment's own
    reward is kept in ``info['gt_reward']``. The discriminator output is returned
    as a numpy array. ``use_actions`` is stored but not read by this wrapper.
    """

    def __init__(self, env, discriminator, use_actions=False):
        super().__init__(env=env)
        self.use_actions = use_actions
        self.discriminator = discriminator
        # Observation the next action will be taken from; refreshed on every reset.
        self.obs = None

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        # Remember the initial state so the first step of each episode scores
        # s_0 rather than a stale observation from the previous episode.
        self.obs = obs
        return obs

    def step(self, action):
        next_obs, gt_reward, done, info = self.env.step(action)
        info['gt_reward'] = gt_reward

        dtype = torch.get_default_dtype()
        # Fall back to next_obs only if step() runs without a reset through this wrapper.
        prev_obs = next_obs if self.obs is None else self.obs
        obs_t = torch.tensor(prev_obs, dtype=dtype)
        acs_t = torch.tensor(action, dtype=dtype)

        with torch.no_grad():
            # Discriminators scoring (s, a, s') would additionally need next_obs here.
            irl_reward = self.discriminator.get_reward(obs_t, acs_t).cpu().numpy()
        self.obs = next_obs

        return next_obs, irl_reward, done, info

# pointmass dict extraction wrapper
class pmObsWrapper(gym.Wrapper):
    """Flatten the point-mass dict observation and return a shaped reward.

    Each observation becomes ``concat(o['full_positional_state'], o['desired_goal'])``.
    The reward is the negative Euclidean distance between those two entries; the
    environment's own reward is discarded.
    """

    def __init__(self, env):
        # NOTE: this rebinds ``observation_space`` on the wrapped env itself, not
        # just on the wrapper.
        # NOTE(review): the 'observation' subspace may not match the shape of the
        # concatenated array emitted below — confirm.
        env.observation_space = env.observation_space['observation']
        print(env.observation_space)
        super().__init__(env=env)

    @staticmethod
    def _flatten(obs_dict):
        return np.concatenate([obs_dict['full_positional_state'], obs_dict['desired_goal']])

    def reset(self):
        return self._flatten(self.env.reset())

    def step(self, action):
        obs_dict, _env_reward, done, info = self.env.step(action)
        # Shaped reward: closer to the goal is better.
        goal_offset = obs_dict['desired_goal'] - obs_dict['full_positional_state']
        shaped_reward = -np.linalg.norm(goal_offset)
        return self._flatten(obs_dict), shaped_reward, done, info


class CustomImgObsWrapper(gym.core.ObservationWrapper):
    """
    Use the transposed image as the only observation output, no language/mission.

    The ``'image'`` entry is transposed with axes ``(2, 1, 0)``. ``observation_space``
    is declared with that same transposed layout so that it matches what
    ``observation()`` returns.
    """

    def __init__(self, env):
        super().__init__(env)
        image_space = env.observation_space.spaces['image']
        # Assumes the image space is a Box (exposes low/high/dtype) — TODO confirm
        # for non-MiniGrid envs.
        self.observation_space = gym.spaces.Box(
            low=np.transpose(image_space.low, (2, 1, 0)),
            high=np.transpose(image_space.high, (2, 1, 0)),
            dtype=image_space.dtype,
        )

    def observation(self, obs):
        return np.transpose(obs['image'], (2, 1, 0))
