import numpy as np
from numpy.random import randint
import os
import gym
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from collections import deque


class ColorWrapper(gym.Wrapper):
    """Wrapper for the color experiments.

    When ``mode`` contains ``'color'``, a table of color settings is loaded
    from ``<mode>.pt`` in this module's directory, and on ``reset`` the
    physics of the wrapped environment is reloaded with a randomly drawn
    entry of that table. Observations returned by ``reset`` and ``step`` are
    converted with ``np.asarray``.

    Args:
        env: Environment to wrap; must be a ``FrameStack`` instance.
        mode: Experiment mode. Color randomization is active when the string
            contains ``'color'``; only ``'color_easy'`` and ``'color_hard'``
            are accepted by the color loader.
        domain_name: Domain name; ``<domain_name>.xml`` is handed to the
            settings helper when the physics is reloaded.
        seed: Seed for the private ``np.random.RandomState`` used to draw
            colors.
        fix_color: If true, colors are randomized on the first ``reset``
            only; otherwise on every ``reset``.
    """
    def __init__(self, env, mode, domain_name, seed=None, fix_color=True):
        # TODO: remove when you replace FrameStack
        assert isinstance(env, FrameStack), 'wrapped env must be a framestack'

        super().__init__(env)
        self._max_episode_steps = env._max_episode_steps
        self._mode = mode
        self._random_state = np.random.RandomState(seed)
        self._domain_name = domain_name
        self._fix_color = fix_color
        # True until the first reset has applied a color.
        self._first = True
        if 'color' in self._mode:
            self._load_colors()

    def reset(self):
        """Reset the wrapped env, randomizing colors first when required.

        Returns:
            The initial observation converted with ``np.asarray``.
        """
        if 'color' in self._mode and (not self._fix_color or self._first):
            self.randomize()
            self._first = False
        return np.asarray(self.env.reset())

    def step(self, action):
        """Step the wrapped env and return its result tuple unchanged in
        order, with the observation converted via ``np.asarray``."""
        # The wrapped FrameStack returns (obs, reward, done, info).
        obs, reward, done, info = self.env.step(action)
        return np.asarray(obs), reward, done, info

    def randomize(self):
        """Reload the physics with a randomly drawn color setting."""
        assert 'color' in self._mode, f'can only randomize in color mode, received {self._mode}'
        self.reload_physics(self.get_random_color())

    def _load_colors(self):
        """Load the color table for the current mode into ``self._colors``."""
        assert self._mode in {'color_easy', 'color_hard'}
        current_dir = os.path.dirname(__file__)
        # os.path.join (rather than string concatenation) keeps the path
        # relative when the directory component is empty.
        self._colors = torch.load(os.path.join(current_dir, f'{self._mode}.pt'))

    def get_random_color(self):
        """Return one entry of the color table, drawn with the wrapper's
        own random state."""
        assert len(self._colors) >= 100, 'env must include at least 100 colors'
        return self._colors[self._random_state.randint(len(self._colors))]

    def reload_physics(self, setting_kwargs=None, state=None):
        """Rebuild the physics from the domain XML and restore a state.

        Args:
            setting_kwargs: Settings forwarded to the settings helper;
                defaults to an empty dict.
            state: Physics state to set after reloading; defaults to the
                state read from the physics just before the reload.
        """
        domain_name = self._domain_name

        if setting_kwargs is None:
            setting_kwargs = {}
        if state is None:
            # Capture the state before reloading so it can be restored.
            state = self._get_state()

        from .settings import get_model_and_assets_from_setting_kwargs
        self._reload_physics(
            *get_model_and_assets_from_setting_kwargs(
                domain_name+'.xml', setting_kwargs
            )
        )
        self._set_state(state)

    def get_state(self):
        """Return the current physics state."""
        return self._get_state()

    def set_state(self, state):
        """Set the physics state."""
        self._set_state(state)

    def _get_gym_dmc_wrapper(self):
        """Walk the ``.env`` chain and return the first
        ``gym_dmc.dm_env.DMCEnv`` found; assert if there is none."""
        import gym_dmc
        _env = self.env
        while not isinstance(_env, gym_dmc.dm_env.DMCEnv) and hasattr(_env, 'env'):
            _env = _env.env
        assert isinstance(_env, gym_dmc.dm_env.DMCEnv), f'environment is not gym_dmc env: {_env}'

        return _env

    def _reload_physics(self, xml_string, assets=None):
        """Reload the underlying physics from an XML string and assets."""
        self._get_physics().reload_from_xml_string(xml_string, assets=assets)

    def _get_physics(self):
        """Walk the ``.env`` chain and return the ``physics`` attribute of
        the first environment that has one; assert if none does."""
        _env = self.env
        while not hasattr(_env, 'physics') and hasattr(_env, 'env'):
            _env = _env.env
        assert hasattr(_env, 'physics'), 'environment does not have physics attribute'

        return _env.physics

    def _get_state(self):
        """Return the state reported by the underlying physics."""
        return self._get_physics().get_state()

    def _set_state(self, state):
        """Write ``state`` into the underlying physics."""
        self._get_physics().set_state(state)


class FrameStack(gym.Wrapper):
    """Observation wrapper that returns the ``k`` most recent frames.

    The frames are kept in a bounded deque and handed out as a
    ``LazyFrames`` object. The advertised observation space multiplies the
    first dimension of the wrapped env's observation shape by ``k``.

    Args:
        env: Environment to wrap; must expose ``observation_space`` and
            ``_max_episode_steps``.
        k: Number of frames to stack.
    """
    def __init__(self, env, k):
        gym.Wrapper.__init__(self, env)
        self._k = k
        # Bounded buffer: appending beyond k drops the oldest frame.
        self._frames = deque(maxlen=k)

        base_space = env.observation_space
        stacked_shape = (base_space.shape[0] * k,) + base_space.shape[1:]
        # TODO: Check where the observation is normalized.
        self.observation_space = gym.spaces.Box(
            low=0,
            high=1,
            shape=stacked_shape,
            dtype=base_space.dtype,
        )
        self._max_episode_steps = env._max_episode_steps

    def reset(self):
        """Reset the env and fill the buffer with ``k`` references to the
        initial observation."""
        first_obs = self.env.reset()
        self._frames.extend([first_obs] * self._k)
        return self._get_obs()

    def step(self, action):
        """Step the env, push the new frame and return the stacked
        observation together with reward, done and info."""
        next_obs, reward, done, info = self.env.step(action)
        self._frames.append(next_obs)
        stacked = self._get_obs()
        return stacked, reward, done, info

    def _get_obs(self):
        """Wrap a snapshot of the frame buffer in ``LazyFrames``."""
        n_buffered = len(self._frames)
        assert n_buffered == self._k
        snapshot = list(self._frames)
        return LazyFrames(snapshot)


class LazyFrames(object):
    """Sequence of frames that is concatenated along axis 0 only on demand.

    Args:
        frames: List of array-like frames to concatenate along axis 0.
        extremely_lazy: If true, nothing is cached and every access
            concatenates the frames again. If false, the first access
            caches the concatenated array and drops the frame list.
    """
    def __init__(self, frames, extremely_lazy=True):
        self._frames = frames
        self._extremely_lazy = extremely_lazy
        # Cached concatenation; only used when not extremely lazy.
        self._out = None

    @property
    def frames(self):
        """The stored frame list (``None`` once a cached concatenation has
        replaced it in non-extremely-lazy mode)."""
        return self._frames

    def _force(self):
        """Return the frames concatenated along axis 0."""
        if self._extremely_lazy:
            return np.concatenate(self._frames, axis=0)
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=0)
            # Release the individual frames; the cache now holds the data.
            self._frames = None
        return self._out

    def __array__(self, dtype=None, copy=None):
        """NumPy array protocol.

        Args:
            dtype: Optional dtype to convert the result to.
            copy: NumPy 2 protocol keyword. ``True`` guarantees the caller a
                fresh array; other values return whatever ``_force`` yields.
        """
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        elif copy and not self._extremely_lazy:
            # In cached mode _force hands out the shared cache, so copy it
            # when the caller explicitly asked for one.
            out = out.copy()
        return out

    def __len__(self):
        """Number of stored frames in extremely lazy mode, otherwise the
        length of the concatenated array along axis 0."""
        if self._extremely_lazy:
            return len(self._frames)
        return len(self._force())

    def __getitem__(self, i):
        """Index into the concatenated array."""
        return self._force()[i]

    def count(self):
        """Return the number of frames.

        In extremely lazy mode this is the length of the frame list;
        otherwise it is the first dimension of the concatenated array
        divided by 3 (presumably 3 channels per frame - confirm).
        """
        if self._extremely_lazy:
            return len(self._frames)
        frames = self._force()
        return frames.shape[0]//3

    def frame(self, i):
        """Return the ``i``-th group of 3 entries along axis 0 of the
        concatenated array."""
        return self._force()[i*3:(i+1)*3]
