import time
import os

import argparse
# from omegaconf import OmegaConf, DictConfig
import yaml
from pathlib import Path
# from envs.meta_world.mtrl.utils import config as config_utils
from typing import Dict

# import hydra
# # from hydra.core.global_hydra import initialize
# # from hydra.compose import compose
# import numpy as np
#
# from envs.meta_world.mtrl.agent import utils as agent_utils
# from envs.meta_world.mtrl.env import builder as env_builder
# from envs.meta_world.mtrl.env.vec_env import VecEnv  # type: ignore[attr-defined]
# from envs.meta_world.mtrl.experiment import multitask
# from envs.meta_world.mtrl.utils.types import ConfigType
# import torch

from envs.meta_world.mtrl.env import builder as env_builder
from envs.meta_world.mtrl.env.vec_env import VecEnv  # type: ignore[attr-defined]
import gym
import gym.spaces
import gym.spaces.utils
# import metaworld
import numpy as np

class MetaWorld:
    """MetaWorld environment adapter that emits dict observations.

    Wraps the environment returned by ``env_builder.build_metaworld_vec_env``
    and turns each raw state into a dict with the keys ``state``,
    ``is_terminal``, ``is_first``, ``token_embed`` (all-zeros placeholder) and
    ``image`` (all-zeros placeholder). Only ``state`` and ``token_embed`` are
    declared in ``observation_space``.

    ``step`` always reports ``done=False``: resets are meant to happen only
    when an external max-episode-step limit is reached. The underlying done
    flag is exposed via ``info["done"]``, and ``info["has_done"]`` stays True
    once any step of the current episode reported done.
    """

    metadata = {}

    # Length of the placeholder token embedding; shared by observation_space,
    # reset and step so the three always agree.
    TOKEN_EMBED_DIM = 384

    def __init__(self, name=None, seed=0, size=(64, 64), mode="train", task=None):
        """Build the underlying MetaWorld environment.

        Args:
            name: Task name; stored as ``self.task``.
            seed: Exported as the ``PYTHONHASHSEED`` environment variable.
            size: (height, width) of the placeholder ``image`` observation.
                Any sequence of two ints is accepted.
            mode: Forwarded to ``build_metaworld_vec_env``. The default
                "train" is fine for evaluation too, which uses the same mode.
            task: Forwarded to ``build_metaworld_vec_env``.
        """
        # NOTE: changing PYTHONHASHSEED at runtime does not re-seed hashing in
        # the already-running interpreter; it only affects processes spawned
        # after this point.
        os.environ["PYTHONHASHSEED"] = str(seed)

        env, _env_id_to_task_map = env_builder.build_metaworld_vec_env(
            task=task, mode=mode, env_id_to_task_map=None
        )
        self._env = env
        self.reward_range = [-np.inf, np.inf]
        self.task = name
        self._size = size
        self.has_done = False
        # NOTE: info["goal"] may look unchanged, but the _state_goal MetaWorld
        # uses internally (e.g. for reward computation) does change randomly;
        # see the MetaWorld source.

    @property
    def observation_space(self):
        """Dict space with the raw ``state`` space and a float32 ``token_embed`` box."""
        return gym.spaces.Dict(
            {
                "state": self._env.observation_space,
                "token_embed": gym.spaces.Box(
                    -np.inf, np.inf, (self.TOKEN_EMBED_DIM,), dtype=np.float32
                ),
            }
        )

    @property
    def action_space(self):
        """Action space of the wrapped environment, unchanged."""
        return self._env.action_space

    def _make_obs(self, state, is_first, is_terminal):
        """Wrap a raw state in this adapter's dict observation format."""
        return {
            "state": state,
            "is_terminal": is_terminal,
            "is_first": is_first,
            # Placeholder embedding; float32 to match observation_space.
            "token_embed": np.zeros(self.TOKEN_EMBED_DIM, dtype=np.float32),
            # Placeholder image; tuple() lets list-valued sizes work too.
            "image": np.zeros(shape=tuple(self._size) + (3,)),
        }

    def reset(self):
        """Reset the wrapped env, clear ``has_done`` and return the first observation."""
        state = self._env.reset()
        self.has_done = False
        return self._make_obs(state, is_first=True, is_terminal=False)

    def step(self, action):
        """Step the wrapped env.

        Returns:
            (obs, reward, False, info). ``obs["is_terminal"]`` and
            ``info["done"]`` carry the wrapped env's done flag, and
            ``info["has_done"]`` is True if any step since the last reset
            reported done. The returned done is always False so the caller
            only resets at its max-episode-step limit.
        """
        state, reward, done, info = self._env.step(action)
        obs = self._make_obs(state, is_first=False, is_terminal=done)
        self.has_done = self.has_done or done
        # info is mutated in place; presumably a dict from the wrapped env.
        info["done"] = done
        info["has_done"] = self.has_done
        return obs, reward, False, info










# Adapted from mtenv.
#
# An environment wrapper that normalizes actions, observations and rewards.
# type: ignore
class NormalizedEnvWrapper(gym.Wrapper):
    """An environment wrapper for normalization.

    This wrapper normalizes action, and optionally observation and reward.

    Args:
        env (garage.envs.GarageEnv): An environment instance.
        scale_reward (float): Scale of environment reward.
        normalize_obs (bool): If True, normalize observation.
        normalize_reward (bool): If True, normalize reward. scale_reward is
            applied after normalization.
        expected_action_scale (float): Assuming action falls in the range of
            [-expected_action_scale, expected_action_scale] when normalize it.
        flatten_obs (bool): Flatten observation if True.
        obs_alpha (float): Update rate of moving average when estimating the
            mean and variance of observations.
        reward_alpha (float): Update rate of moving average when estimating the
            mean and variance of rewards.

    """

    def __init__(
        self,
        env,
        scale_reward=1.0,
        normalize_obs=False,
        normalize_reward=False,
        expected_action_scale=1.0,
        flatten_obs=True,
        obs_alpha=0.001,
        reward_alpha=0.001,
    ):
        super().__init__(env)

        self._scale_reward = scale_reward
        self._normalize_obs = normalize_obs
        self._normalize_reward = normalize_reward
        self._expected_action_scale = expected_action_scale
        self._flatten_obs = flatten_obs

        # Running statistics over the flattened observation vector.
        self._obs_alpha = obs_alpha
        flat_obs_dim = gym.spaces.utils.flatdim(env.observation_space)
        self._obs_mean = np.zeros(flat_obs_dim)
        self._obs_var = np.ones(flat_obs_dim)

        # Running statistics over scalar rewards.
        self._reward_alpha = reward_alpha
        self._reward_mean = 0.0
        self._reward_var = 1.0

    def _update_obs_estimate(self, obs):
        """Update the exponential moving mean/variance of observations.

        Args:
            obs: Observation from ``self.env.observation_space``.

        Returns:
            np.ndarray: The flattened observation, so callers do not need to
            flatten it a second time.

        """
        flat_obs = gym.spaces.utils.flatten(self.env.observation_space, obs)
        self._obs_mean = (
            1 - self._obs_alpha
        ) * self._obs_mean + self._obs_alpha * flat_obs
        # The variance update deliberately uses the freshly updated mean.
        self._obs_var = (
            1 - self._obs_alpha
        ) * self._obs_var + self._obs_alpha * np.square(flat_obs - self._obs_mean)
        return flat_obs

    def _update_reward_estimate(self, reward):
        """Update the exponential moving mean/variance of rewards."""
        self._reward_mean = (
            1 - self._reward_alpha
        ) * self._reward_mean + self._reward_alpha * reward
        self._reward_var = (
            1 - self._reward_alpha
        ) * self._reward_var + self._reward_alpha * np.square(
            reward - self._reward_mean
        )

    def _apply_normalize_obs(self, obs):
        """Compute normalized observation.

        Also updates the running observation statistics.

        Args:
            obs (np.ndarray): Observation.

        Returns:
            np.ndarray: Normalized observation; flattened unless
            ``flatten_obs`` is False, in which case it is unflattened back
            into the original observation space's structure.

        """
        flat_obs = self._update_obs_estimate(obs)
        normalized_obs = (flat_obs - self._obs_mean) / (np.sqrt(self._obs_var) + 1e-8)
        if not self._flatten_obs:
            normalized_obs = gym.spaces.utils.unflatten(
                self.env.observation_space, normalized_obs
            )
        return normalized_obs

    def _apply_normalize_reward(self, reward):
        """Compute normalized reward.

        Also updates the running reward statistics. The reward is divided by
        the running standard deviation only; the mean is not subtracted.

        Args:
            reward (float): Reward.

        Returns:
            float: Normalized reward.

        """
        self._update_reward_estimate(reward)
        return reward / (np.sqrt(self._reward_var) + 1e-8)

    def reset(self, **kwargs):
        """Reset environment.

        Args:
            **kwargs: Additional parameters for reset.

        Returns:
            The initial observation of the wrapped environment, normalized
            (see ``_apply_normalize_obs``) when ``normalize_obs`` is True.

        """
        ret = self.env.reset(**kwargs)
        if self._normalize_obs:
            return self._apply_normalize_obs(ret)
        return ret

    def step(self, action):
        """Feed environment with one step of action and get result.

        For a Box action space with finite bounds, ``action`` is assumed to
        lie in [-expected_action_scale, expected_action_scale]. It is mapped
        affinely onto [low, high] and clipped before being passed on.

        Args:
            action (np.ndarray): An action fed to the environment.

        Returns:
            tuple:
                * observation (np.ndarray): The observation of the environment.
                * reward (float): The reward acquired at this time step,
                    optionally normalized, then multiplied by scale_reward.
                * done (boolean): Whether the environment was completed at this
                    time step.
                * infos (dict): Environment-dependent additional information.

        Raises:
            Any exception raised by the wrapped environment's ``step``
            propagates unchanged.

        """
        scaled_action = action
        if isinstance(self.action_space, gym.spaces.Box):
            lb, ub = self.action_space.low, self.action_space.high
            # Rescale only when every bound is finite; with an infinite bound
            # the affine map below would produce inf/nan actions.
            if np.all(np.isfinite(lb)) and np.all(np.isfinite(ub)):
                scaled_action = lb + (action + self._expected_action_scale) * (
                    0.5 * (ub - lb) / self._expected_action_scale
                )
                scaled_action = np.clip(scaled_action, lb, ub)

        next_obs, reward, done, info = self.env.step(scaled_action)

        if self._normalize_obs:
            next_obs = self._apply_normalize_obs(next_obs)
        if self._normalize_reward:
            reward = self._apply_normalize_reward(reward)

        return next_obs, reward * self._scale_reward, done, info

