from multiprocessing.context import Process
from multiprocessing import Pipe
import numpy as np

from utils import RunningMeanStd



import cloudpickle
class CloudpickleWrapper:
    """Hold an object and pickle it through ``cloudpickle``.

    The wrapped object lives on ``data``. When the wrapper itself is pickled,
    its state is the ``cloudpickle.dumps`` serialization of ``data``; when it
    is unpickled, ``data`` is rebuilt with ``cloudpickle.loads``.
    """

    def __init__(self, data):
        self.data = data

    def __getstate__(self):
        # The pickled state is the cloudpickle byte string, not a __dict__.
        payload = cloudpickle.dumps(self.data)
        return payload

    def __setstate__(self, data):
        # ``data`` here is the byte string produced by __getstate__.
        restored = cloudpickle.loads(data)
        self.data = restored


def _worker(parent, p, env_fn_wrapper):
    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            try:
                cmd, data = p.recv()
            except EOFError:  # the pipe has been closed
                p.close()
                break
            if cmd == "step":
                if data is None:  # reset
                    obs = env.reset()
                else:
                    obs, reward, done, info = env.step(data)
                if data is None:
                    p.send(obs)
                else:
                    p.send((obs, reward, done, info))
            elif cmd == "close":
                p.send(env.close())
                p.close()
                break
            elif cmd == "seed":
                p.send(env.seed(data) if hasattr(env, "seed") else None)
            elif cmd == "getattr":
                if data=="body_mass":
                    p.send(env.sim.model.body_mass)
                elif data=="body_inertia":
                    p.send(env.sim.model.body_inertia)
                elif data=="dof_damping":
                    p.send(env.sim.model.dof_damping)
                elif data=="geom_friction":
                    p.send(env.sim.model.geom_friction)
                else:
                    p.send(getattr(env, data) if hasattr(env, data) else None)
            elif cmd == "setattr":
                if data["key"]=="body_mass":
                    env.model.body_mass[:] = data["value"]
                elif data["key"]=="body_inertia":
                    env.model.body_inertia[:] = data["value"]
                elif data["key"]=="dof_damping":
                    env.model.dof_damping[:] = data["value"]
                elif data["key"]=="geom_friction":
                    env.model.geom_friction[:] = data["value"]    
                else:
                    setattr(env, data["key"], data["value"])
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()



class SubprocEnvWorker:
    """Parent-side handle for one environment running in a child process.

    Construction spawns a daemon ``Process`` that runs ``_worker`` and then
    talks to it over a ``Pipe``. Every request is a two-element list
    ``[command, payload]`` sent on ``parent_remote``.
    """

    def __init__(self, env_fn):
        """Start the child process and cache the environment's action space.

        Args:
            env_fn: Zero-argument callable that builds the environment; it is
                wrapped in ``CloudpickleWrapper`` before being handed to the
                child process.
        """
        self.parent_remote, self.child_remote = Pipe()
        self.process = Process(
            target=_worker,
            args=(
                self.parent_remote,
                self.child_remote,
                CloudpickleWrapper(env_fn),
            ),
            daemon=True,
        )
        self.process.start()
        # The child end is not used on this side once the process is running.
        self.child_remote.close()
        self.is_reset = False
        self._env_fn = env_fn
        self.is_closed = False
        self.action_space = self.get_env_attr("action_space")  # noqa: B009

    def get_env_attr(self, key):
        """Request attribute ``key`` from the child and return its reply."""
        self.parent_remote.send(["getattr", key])
        reply = self.parent_remote.recv()
        return reply

    def set_env_attr(self, key, value):
        """Ask the child to set ``key`` to ``value``; no reply is awaited."""
        request = {"key": key, "value": value}
        self.parent_remote.send(["setattr", request])

    def send(self, action) -> None:
        """Send a step request; ``action=None`` is the reset request."""
        self.parent_remote.send(["step", action])

    def recv(self):
        """Receive the reply to a previous ``send``.

        Returns:
            ``(obs, rew, done, info)`` when the reply is a tuple, otherwise
            the reply itself (the observation from a reset).
        """
        result = self.parent_remote.recv()
        if not isinstance(result, tuple):
            return result
        obs, rew, done, info = result
        return obs, rew, done, info

    def seed(self, seed=None):
        """Seed the cached action space and the child environment.

        Returns:
            Whatever the child replies to the ``"seed"`` request.
        """
        self.action_space.seed(seed)
        self.parent_remote.send(["seed", seed])
        reply = self.parent_remote.recv()
        return reply

    def close_env(self):
        """Ask the child to close, wait for it, then terminate the process.

        Pipe errors (and ``AttributeError``) during the handshake are
        ignored; ``terminate`` is called regardless.
        """
        try:
            self.parent_remote.send(["close", None])
            self.parent_remote.recv()
            self.process.join()
        except (BrokenPipeError, EOFError, AttributeError):
            pass
        self.process.terminate()

    def reset(self) -> np.ndarray:
        """Reset the child environment and return what it sends back."""
        self.send(None)
        return self.recv()  # type: ignore

    def step(self, action):
        """Step the child environment with ``action`` and return its reply."""
        self.send(action)
        return self.recv()  # type: ignore

    def close(self):
        """Close the worker once; later calls do nothing."""
        if not self.is_closed:
            self.is_closed = True
            self.close_env()


class BaseVectorEnv(object):
    """Drive several subprocess environments as one batched environment.

    One ``SubprocEnvWorker`` is created per entry of ``env_fns``. Observations
    returned by ``reset`` and ``step`` can optionally be normalized with a
    running mean/variance estimate.

    Accessing any of ``metadata``, ``reward_range``, ``spec``,
    ``action_space`` or ``observation_space`` on an instance returns a list
    with that attribute fetched from every worker.
    """

    def __init__(self, env_fns, norm_obs=False, obs_rms=None, update_obs_rms=True):
        """Spawn the workers and set up observation normalization.

        Args:
            env_fns: Sequence of zero-argument callables, one per environment.
            norm_obs: Whether ``normalize_obs`` actually normalizes.
            obs_rms: Existing running statistics to use. When ``None`` and
                ``norm_obs`` is true, a fresh ``RunningMeanStd`` is created.
            update_obs_rms: Whether ``reset``/``step`` update the statistics.
        """
        self._env_fns = env_fns
        self.workers = [SubprocEnvWorker(fn) for fn in env_fns]

        self.env_num = len(env_fns)

        self.ready_id = list(range(self.env_num))
        self.is_closed = False

        # initialize observation running mean/std
        self.norm_obs = norm_obs
        self.update_obs_rms = update_obs_rms
        self.obs_rms = RunningMeanStd() if obs_rms is None and norm_obs else obs_rms
        self.__eps = np.finfo(np.float32).eps.item()

    def __len__(self) -> int:
        """Return the number of environments."""
        return self.env_num

    def __getattribute__(self, key: str):
        """Redirect the reserved ``gym.Env`` attribute names to the workers."""
        if key in (
            'metadata', 'reward_range', 'spec', 'action_space', 'observation_space'
        ):  # reserved keys in gym.Env
            return self.get_env_attr(key)
        return super().__getattribute__(key)

    def get_env_attr(self, key, id=None):
        """Fetch attribute ``key`` from the selected workers.

        Args:
            key: Attribute name.
            id: ``None`` for all workers, a scalar index, or a sequence of
                indices.

        Returns:
            A list with one value per selected worker, in ``id`` order.
        """
        id = self._wrap_id(id)
        return [self.workers[j].get_env_attr(key) for j in id]

    def set_env_attr(self, key, value, id=None):
        """Set attribute ``key`` to ``value`` on the selected workers.

        Args:
            key: Attribute name.
            value: Value sent to every selected worker.
            id: ``None`` for all workers, a scalar index, or a sequence of
                indices.
        """
        # Always normalize ``id`` so a scalar index works here exactly as it
        # does in get_env_attr/reset/step (previously only ``None`` was
        # wrapped, and a scalar raised TypeError in the loop below).
        id = self._wrap_id(id)
        for j in id:
            self.workers[j].set_env_attr(key, value)

    def _wrap_id(self, id=None):
        """Turn ``id`` into an iterable of worker indices."""
        if id is None:
            return list(range(self.env_num))
        return [id] if np.isscalar(id) else id  # type: ignore

    @staticmethod
    def _stack_obs(obs_list):
        """Stack observations, falling back to an object array on ValueError."""
        try:
            return np.stack(obs_list)
        except ValueError:
            return np.array(obs_list, dtype=object)

    def reset(self, id=None):
        """Reset the selected environments.

        All reset requests are sent before any reply is read, so the
        environments reset concurrently.

        Returns:
            The stacked observations after ``normalize_obs``.
        """
        id = self._wrap_id(id)
        for i in id:
            self.workers[i].send(None)
        obs_list = [self.workers[i].recv() for i in id]
        obs = self._stack_obs(obs_list)
        if self.obs_rms and self.update_obs_rms:
            self.obs_rms.update(obs)
        return self.normalize_obs(obs)

    def step(self, action, id=None):
        """Step the selected environments with one action each.

        Args:
            action: Sequence of actions; ``action[i]`` goes to worker
                ``id[i]``. Its length must equal the number of selected
                workers.
            id: ``None`` for all workers, a scalar index, or a sequence of
                indices.

        Returns:
            ``(obs, rew, done, info)`` where ``obs`` has been passed through
            ``normalize_obs`` and the other three are ``np.stack`` results.
            Each info dict gets an ``"env_id"`` entry with its worker index.
        """
        id = self._wrap_id(id)

        assert len(action) == len(id)
        # Dispatch every action first, then collect, so workers run in parallel.
        for i, j in enumerate(id):
            self.workers[j].send(action[i])
        result = []
        for j in id:
            obs, rew, done, info = self.workers[j].recv()
            info["env_id"] = j
            result.append((obs, rew, done, info))

        obs_list, rew_list, done_list, info_list = zip(*result)
        obs_stack = self._stack_obs(obs_list)
        rew_stack, done_stack, info_stack = map(
            np.stack, [rew_list, done_list, info_list]
        )
        if self.obs_rms and self.update_obs_rms:
            self.obs_rms.update(obs_stack)
        return self.normalize_obs(obs_stack), rew_stack, done_stack, info_stack

    def seed(self, seed=None):
        """Seed every worker.

        Args:
            seed: ``None`` seeds each worker with ``None``; an ``int`` seeds
                worker ``i`` with ``seed + i``; anything else is treated as a
                per-worker sequence of seeds.

        Returns:
            A list with each worker's ``seed`` reply.
        """
        if seed is None:
            seed_list = [seed] * self.env_num
        elif isinstance(seed, int):
            seed_list = [seed + i for i in range(self.env_num)]
        else:
            seed_list = seed
        return [w.seed(s) for w, s in zip(self.workers, seed_list)]

    def close(self):
        """Close every worker and mark this vector environment as closed."""
        for w in self.workers:
            w.close()
        self.is_closed = True

    def normalize_obs(self, obs):
        """Normalize ``obs`` with the running statistics, if enabled.

        When both ``obs_rms`` and ``norm_obs`` are truthy, the observation is
        centred and scaled by ``sqrt(var + eps)`` and clipped to
        ``[-10, 10]``; otherwise it is returned unchanged.
        """
        if self.obs_rms and self.norm_obs:
            clip_max = 10.0
            obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.__eps)
            obs = np.clip(obs, -clip_max, clip_max)
        return obs

if __name__ == "__main__":
    # Demo: print four MuJoCo model parameters from every sub-environment,
    # overwrite them with randomized values, then print them again.
    import gym
    from random_env import get_init_params, get_random_params

    param_names = ("body_mass", "body_inertia", "dof_damping", "geom_friction")
    separator = "################################"

    # A stand-alone environment, used only to read the initial parameters.
    env = gym.make("HalfCheetah-v3")
    envs = BaseVectorEnv(
        [lambda: gym.make("HalfCheetah-v3") for _ in range(16)], norm_obs=True
    )

    envs.reset()
    for name in param_names:
        p = envs.get_env_attr(name)
        print(len(p), p[0])
    print(separator)

    init_params = get_init_params(env)
    random_params = get_random_params(init_params, log_scale_limit=3.0)
    for name in param_names:
        envs.set_env_attr(name, random_params[name])
    print(separator)

    envs.reset()
    for name in param_names:
        p = envs.get_env_attr(name)
        print(len(p), p[0])
    print(separator)
    
