import numpy as np
from collections import defaultdict
import d4rl

from environments.kitchen.spirl.utils.general_utils import AttrDict
from environments.kitchen.spirl.utils.general_utils import ParamDict
from environments.kitchen.spirl.rl.components.environment import GymEnv


class KitchenEnv(GymEnv):
    """Thin GymEnv wrapper for Kitchen tasks that tracks solved subtasks.

    After every step the "completed_tasks" entry of the step info is folded
    into ``self.solved_subtasks``, a 0/1 flag per name in ``SUBTASKS``. Those
    flags are merged into the result of ``get_episode_info``.

    ``self.solved_subtasks`` is created by ``reset``, so ``reset`` must be
    called before ``step`` or ``get_episode_info``.
    """

    # Subtask names that are looked up in the per-step "completed_tasks" entry.
    SUBTASKS = [
        "microwave",
        "kettle",
        "slide cabinet",
        "hinge cabinet",
        "bottom burner",
        "light switch",
        "top burner",
    ]

    def _default_hparams(self):
        """Return the parent defaults with ``name`` set to "kitchen-mixed-v0"."""
        return (
            super()
            ._default_hparams()
            .overwrite(
                ParamDict(
                    {
                        "name": "kitchen-mixed-v0",
                    }
                )
            )
        )

    def step(self, *args, **kwargs):
        """Step the parent env and update the solved-subtask flags.

        All arguments are forwarded unchanged to the parent ``step``.

        Returns:
            A ``(obs, reward, done, info)`` tuple in which ``reward`` has been
            cast to ``np.float64`` and ``info`` no longer contains the
            "completed_tasks" entry.
        """
        obs, rew, done, info = super().step(*args, **kwargs)
        return (
            obs,
            np.float64(rew),
            done,
            self._postprocess_info(info),
        )  # casting reward to float64 is important for getting shape later

    def reset(self, *args, **kwargs):
        """Clear the solved-subtask flags and reset the parent env.

        Arguments are forwarded to the parent ``reset`` (mirroring ``step``),
        so subclasses that pass arguments through do not fail at this level.

        Returns:
            Whatever the parent ``reset`` returns.
        """
        # int() is 0, the same default as before; unlike a lambda, the built-in
        # factory keeps the mapping picklable.
        self.solved_subtasks = defaultdict(int)
        return super().reset(*args, **kwargs)

    def get_episode_info(self):
        """Return the parent's episode info extended with the subtask flags."""
        info = super().get_episode_info()
        info.update(AttrDict(self.solved_subtasks))
        return info

    def _postprocess_info(self, info):
        """Sort solved subtasks into separately logged elements.

        Removes the "completed_tasks" entry from ``info`` (in place) and sets
        ``self.solved_subtasks[task]`` to 1 for every name in ``SUBTASKS`` that
        is contained in it. A flag that is already set stays set until the
        next ``reset``; all other flags are written as 0.

        Args:
            info: Step info mapping; must contain "completed_tasks", a
                container that supports ``in`` tests against subtask names.

        Returns:
            The same ``info`` object, without the "completed_tasks" entry.

        Raises:
            KeyError: If ``info`` has no "completed_tasks" entry (assuming
                ``info`` is a plain dict).
        """
        completed_subtasks = info.pop("completed_tasks")
        for task in self.SUBTASKS:
            # Sticky flag: once solved in an episode, a subtask stays solved.
            self.solved_subtasks[task] = (
                1 if task in completed_subtasks or self.solved_subtasks[task] else 0
            )
        return info


class NoGoalKitchenEnv(KitchenEnv):
    """KitchenEnv variant that returns only the first half of each observation.

    NOTE(review): the second half is presumably the goal, which is what gets
    split off; the observation layout is not visible here, so confirm it
    against the wrapped env.
    """

    @staticmethod
    def _first_half(observation):
        """Return the leading ``shape[0] // 2`` entries of ``observation``."""
        keep = observation.shape[0] // 2
        return observation[:keep]

    def step(self, *args, **kwargs):
        """Step the parent env and truncate the returned observation."""
        observation, reward, finished, info = super().step(*args, **kwargs)
        return self._first_half(observation), reward, finished, info

    def reset(self, *args, **kwargs):
        """Reset the parent env and truncate the returned observation."""
        return self._first_half(super().reset(*args, **kwargs))
