import numpy as np

import gym
import matplotlib

# matplotlib.use("Agg")
import matplotlib.pyplot as plt

from matplotlib import patches
import imageio

import functools as ft

from envs.mujoco.ant_env import AntEnv

from lxml import etree

# env = AntEnv(render_hw=100, model_path="ant_hill.xml", hill_height=10, half=True)
from gym.vector import AsyncVectorEnv
from gym.vector.utils import (
    create_shared_memory,
    create_empty_array,
    write_to_shared_memory,
    read_from_shared_memory,
    concatenate,
    CloudpickleWrapper,
    clear_mpi_env_vars,
)
import sys


def sim_get_state(self, data=None):
    """Collect one simulator state from every worker of an AsyncVectorEnv.

    Meant to be bound to an ``AsyncVectorEnv`` instance (this script does so
    with ``ft.partial``). A ``("sim.get_state", data)`` command is written to
    every parent pipe first; only after all requests are out is one
    ``(state, success)`` reply read back from each pipe, in pipe order.

    NOTE(review): relies on the worker passed to ``AsyncVectorEnv`` handling
    the custom ``"sim.get_state"`` command -- confirm for the worker in use.

    Args:
        data: payload forwarded unchanged alongside the command to each worker.

    Returns:
        A tuple holding one state per worker, ordered like ``self.parent_pipes``.
    """
    self._assert_is_running()
    connections = self.parent_pipes
    request = ("sim.get_state", data)

    # Phase 1: broadcast the request to all workers.
    for conn in connections:
        conn.send(request)

    # Phase 2: gather one (state, success) reply per worker.
    replies = [conn.recv() for conn in connections]
    states, successes = zip(*replies)

    # Presumably raises if any worker reported a failure (success == False).
    self._raise_if_errors(successes)
    return states


def sim_set_state(self, states):
    """Push one simulator state to each worker of an AsyncVectorEnv.

    Meant to be bound to an ``AsyncVectorEnv`` instance (this script does so
    with ``ft.partial``). ``states[i]`` is sent to worker ``i`` as a
    ``("sim.set_state", states[i])`` command; afterwards one
    ``(result, success)`` reply is read from every pipe and any failure is
    handed to ``self._raise_if_errors``.

    NOTE(review): relies on the worker passed to ``AsyncVectorEnv`` handling
    the custom ``"sim.set_state"`` command -- confirm for the worker in use.

    Args:
        states: iterable with exactly one entry per worker pipe, e.g. one
            element of the list of tuples collected via ``sim_get_state``.

    Returns:
        True once every worker has replied.

    Raises:
        ValueError: if the number of states differs from the number of workers.
    """
    self._assert_is_running()
    states = list(states)  # accept any iterable and allow len() below
    if len(states) != len(self.parent_pipes):
        # zip() would silently drop the surplus pipes, and the recv() loop
        # below would then block forever on a worker that was never sent a
        # command. Fail fast instead.
        raise ValueError(
            f"sim_set_state expected {len(self.parent_pipes)} states "
            f"(one per worker), got {len(states)}"
        )
    for pipe, state in zip(self.parent_pipes, states):
        pipe.send(("sim.set_state", state))
    _, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
    self._raise_if_errors(successes)
    return True


from gym.vector.async_vector_env import (
    _worker_shared_memory as _gym_worker_shared_memory,
)


class _SimCommandPipe:
    """Worker-side pipe wrapper that serves the custom ``sim.*`` commands.

    ``sim_get_state`` / ``sim_set_state`` send ``("sim.get_state", data)`` and
    ``("sim.set_state", state)`` over the vector-env pipes. This wrapper
    answers those two commands itself and returns every other message to
    gym's worker loop unchanged. NOTE(review): as far as I know gym's stock
    worker raises on unknown commands -- confirm for the installed version.
    """

    def __init__(self, pipe):
        self._pipe = pipe
        self.env = None  # assigned by _worker_shared_memory once the env exists

    def recv(self):
        while True:
            command, data = self._pipe.recv()
            if command == "sim.get_state":
                # Flattened array: cheap to pickle, has a .shape, and is
                # accepted back by set_state_from_flattened below.
                # NOTE(review): assumes AntEnv exposes a mujoco_py MjSim as
                # ``env.sim`` -- confirm.
                self._pipe.send((self.env.sim.get_state().flatten(), True))
            elif command == "sim.set_state":
                self.env.sim.set_state_from_flattened(data)
                # Recompute derived quantities, as gym's MujocoEnv.set_state does.
                self.env.sim.forward()
                self._pipe.send((None, True))
            else:
                return command, data

    def __getattr__(self, name):
        # send(), close(), ... go straight to the real pipe.
        return getattr(self._pipe, name)


def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
    """AsyncVectorEnv worker: gym's shared-memory worker plus ``sim.*`` commands.

    An exception raised while serving a ``sim.*`` command propagates out of
    ``recv()`` into gym's worker loop. NOTE(review): that loop is expected to
    report it like any other worker error -- confirm for the installed gym.
    """
    sim_pipe = _SimCommandPipe(pipe)

    def make_env():
        sim_pipe.env = env_fn()
        return sim_pipe.env

    _gym_worker_shared_memory(
        index, make_env, sim_pipe, parent_pipe, shared_memory, error_queue
    )


# NOTE(review): this runs at import time; under the "spawn" start method
# (macOS/Windows) it must sit behind ``if __name__ == "__main__":``.
vectorized_env = AsyncVectorEnv(
    [
        lambda: AntEnv(
            render_hw=100, model_path="ant_hill.xml", hill_height=10, half=True
        )
        for _ in range(8)
    ],
    worker=_worker_shared_memory,
)
# Attach the custom state accessors as bound "methods" of this instance.
vectorized_env.sim_get_state = ft.partial(sim_get_state, vectorized_env)
vectorized_env.sim_set_state = ft.partial(sim_set_state, vectorized_env)

print(vectorized_env)
# Phase 1: take 100 random steps and, after each one, record the simulator
# state reported by every sub-environment.
states = []
for _ in range(100):
    vectorized_env.step(vectorized_env.action_space.sample())
    states.append(vectorized_env.sim_get_state())
print(len(states))
print(len(states[0]))
print(states[0][0].shape)  # NOTE(review): needs array-like states from the worker

# Phase 2: 1000 random steps. For the first len(states) of them the recorded
# phase-1 state is written back after the step, so the rendered frames replay
# the phase-1 trajectory; afterwards the envs run freely.
tensor = []
for i in range(1000):
    vectorized_env.step(vectorized_env.action_space.sample())
    if i < len(states):
        vectorized_env.sim_set_state(states[i])
    # Render through the workers and keep sub-env 0's frame. Every worker
    # renders, so this costs num_envs renders per step.
    # NOTE(review): relies on VectorEnv.call and on AntEnv.render accepting
    # these keyword arguments -- confirm for the installed gym version.
    frames = vectorized_env.call(
        "render", height=256, width=256, mode="rgb_array", camera_id=0
    )
    tensor.append(frames[0])

from moviepy import editor as mpy

# Encode the collected frames at 30 fps into vid.mp4 (current working
# directory), without an audio track; verbose=False / logger=None keep moviepy
# from printing progress output.
clip = mpy.ImageSequenceClip([*tensor], fps=30)
clip.write_videofile(
    "vid.mp4",
    audio=False,
    verbose=False,
    logger=None,
)
