import numpy as np
import os

os.environ["MUJOCO_GL"] = "egl"

import gym
import matplotlib

# matplotlib.use("Agg")
import matplotlib.pyplot as plt

from matplotlib import patches
import imageio

# from envs.d4rl.pixel_wrappers import RenderWrapper
# import d4rl

import functools as ft

import dowel_wrapper

assert dowel_wrapper is not None
import dowel

from garage import wrap_experiment
from garage.experiment.deterministic import set_seed
from garage.torch.distributions import TanhNormal

from garagei.replay_buffer.path_buffer_ex import PathBufferEx
from garagei.experiment.option_local_runner import OptionLocalRunner
from garagei.envs.consistent_normalized_env import consistent_normalize
from garagei.sampler.option_multiprocessing_sampler import OptionMultiprocessingSampler
from garagei.torch.modules.with_encoder import WithEncoder, Encoder
from garagei.torch.modules.gaussian_mlp_module_ex import (
    GaussianMLPTwoHeadedModuleEx,
    GaussianMLPIndependentStdModuleEx,
    GaussianMLPModuleEx,
)
from garagei.torch.modules.gaussian_lstm_module_ex import (
    GaussianLSTMTwoHeadedModuleEx,
    GaussianLSTMIndependentStdModuleEx,
    GaussianLSTMModuleEx,
)
from garagei.torch.modules.parameter_module import ParameterModule
from garagei.torch.policies.policy_ex import PolicyEx, RecurrentPolicyEx
from garagei.torch.q_functions.continuous_mlp_q_function_ex import (
    ContinuousMLPQFunctionEx,
)
from garagei.torch.q_functions.continuous_lstm_q_function_ex import (
    ContinuousLSTMQFunctionEx,
)
from garagei.torch.optimizers.optimizer_group_wrapper import OptimizerGroupWrapper
from garagei.torch.utils import xavier_normal_ex
from iod.metra import METRA
from iod.recurrent_metra import RecurrentMETRA
from iod.dads import DADS
from iod.utils import get_normalizer_preset

from dm_control import mujoco
from dm_control.mujoco.wrapper import mjbindings
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import xml_tools
from lxml import etree
import numpy as np

from envs.custom_dmc_tasks import dmc
from envs.custom_dmc_tasks.pixel_wrappers import RenderWrapper
from envs.custom_dmc_tasks.pixel_wrappers import FrameStackWrapper

import d4rl

from main import make_env

import sys

# Command-line contract: ``python <script> --env <env_name>``.
# Validate argv up front so a missing or misplaced flag produces the usage
# message and a non-zero exit status, instead of an IndexError (no argument)
# or silently treating some other token as the environment name.
if len(sys.argv) < 3 or sys.argv[1] != "--env":
    sys.exit("usage: --env env_name")
env_name = sys.argv[2]

from envs.custom_dmc_tasks import dmc

# env = dmc.make(
#     "quadruped_escape",
#     obs_type="states",
#     frame_stack=1,
#     action_repeat=2,
#     seed=0,
#     task_kwargs={
#         "random": 0,
#     },
# )


# def _find_non_contacting_height(physics, orientation, x_pos=0.0, y_pos=0.0):
#     """Find a height with no contacts given a body orientation.
#     Args:
#       physics: An instance of `Physics`.
#       orientation: A quaternion.
#       x_pos: A float. Position along global x-axis.
#       y_pos: A float. Position along global y-axis.
#     Raises:
#       RuntimeError: If a non-contacting configuration has not been found after
#       10,000 attempts.
#     """
#     z_pos = 0.0  # Start embedded in the floor.
#     num_contacts = 1
#     num_attempts = 0
#     # Move up in 1cm increments until no contacts.
#     while num_contacts > 0:
#         try:
#             with physics.reset_context():
#                 physics.named.data.qpos["root"][:3] = x_pos, y_pos, z_pos
#                 physics.named.data.qpos["root"][3:] = orientation
#         except control.PhysicsError:
#             # We may encounter a PhysicsError here due to filling the contact
#             # buffer, in which case we simply increment the height and continue.
#             pass
#         num_contacts = physics.data.ncon
#         z_pos += 0.01
#         num_attempts += 1
#         if num_attempts > 10000:
#             raise RuntimeError("Failed to find a non-contacting configuration.")


# env.reset()
# # first we simply plot the height field of this env
# fig = plt.figure(figsize=(10, 10))
# physics = env.physics
# # with physics.reset_context():
# #                 physics.named.data.qpos["root"][:3] = x_pos, y_pos, z_pos
# #                 physics.named.data.qpos["root"][3:] = orientation

# print(physics.model.hfield_data.shape)
# _HEIGHTFIELD_ID = 0
# res = physics.model.hfield_nrow[_HEIGHTFIELD_ID]
# hfield = physics.model.hfield_data
# print(hfield.min(), hfield.max(), hfield.mean())

# hfield_data = hfield.reshape(res, res)
# hfield_data = hfield_data[50:150, 50:150]  # 201
# print(hfield_data.shape)
# plt.imshow(hfield_data)
# plt.colorbar()  # Adds a colorbar to indicate the scale of the imshow plot

# plt.savefig("heightfield.png")


# def generate_goal_obs():
#     env.reset()
#     state = env.physics.get_state().copy()
#     orientation = np.random.randn(4)

#     # make a radius xpos and ypos
#     # Generate points around the circumference of the circle
#     radius = 10
#     theta = np.linspace(0, 2 * np.pi, 48)
#     x = radius * np.cos(theta)
#     y = radius * np.sin(theta)
#     z = np.zeros_like(x)

#     goals = []
#     for i in range(len(x)):
#         x_pos = x[i]
#         y_pos = y[i]
#         _find_non_contacting_height(physics, orientation, x_pos, y_pos)
#         ob, *_ = env.step(env.action_space.sample())
#         # ob should be our goal
#         goals.append(ob.copy())
#         env.reset()

#     return goals


# # and also we need to fix the


# # img = env.render(camera_id=0)
# # plt.imsave("img.png", img)
# goals = generate_goal_obs()

# print(goals)


# state = env.physics.get_state().copy()
# goal_loc = (np.random.rand(2) * 2 - 1) * 100
# print(goal_loc)
# state[:2] = goal_loc

# state[2] = 40
# env.physics.set_state(state)
# img = env.render(camera_id=0)
# plt.imsave("img.png", img)
# exit(0)

# with FigManager(
#     runner, f"{description_prefix}TrajPhiPlot_RandomZ", subplot_spec=(1, 2)
# ) as fm:
#     runner._env.render_trajectories(
#         random_trajectories, option_colors, self.eval_plot_axis, fm.ax[0]
#     )
#     if self.goal_reaching:
#         # draw options on top of this
#         goals = runner._env._apply_unnormalize_obs(options)
#         for goal, color in zip(goals, option_colors):
#             fm.ax[0].plot(goal[0], goal[1], "*", color=color, markersize=10)


class args:
    """Hand-built stand-in for a parsed argument namespace.

    Passed to ``make_env`` below in place of real CLI options. Presumably
    ``make_env`` reads these attributes -- confirm against ``main.make_env``.
    Only the environment name is configurable (via ``--env``); everything
    else is fixed.
    """

    # Environment to build, taken from the command line. Names used with this
    # script before include "dmc_humanoid_state", "dmc_jaco_state" and
    # "fetchpush".
    env = env_name
    seed = 0
    encoder = 0
    normalizer_type = "off"
    frame_stack = None


# Number of environment steps to roll out, and how often (in steps) to force
# a fresh episode regardless of termination.
_NUM_STEPS = 300
_RESET_INTERVAL = 100

print(env_name)
env = make_env(args, 200)

# Show the spaces (and one sample action) so the env can be sanity-checked.
print(env.observation_space)
print(env.action_space)
print(env.action_space.sample())

ob = env.reset()
print(ob.shape)

# Roll out uniformly random actions, recording one RGB frame per step.
# The frames accumulate in `tensor`, which the video-writing code consumes.
tensor = []
for i in range(_NUM_STEPS):
    ob, _, done, _ = env.step(env.action_space.sample())
    tensor.append(env.render(mode="rgb_array"))

    # Start a new episode when the env reports the episode is over (stepping
    # past the end is not defined by the gym API), or every _RESET_INTERVAL
    # steps. Using (i + 1) means the reset happens after exactly
    # _RESET_INTERVAL steps rather than right after the very first step.
    if done or (i + 1) % _RESET_INTERVAL == 0:
        env.reset()
        if done:
            print("episode ended, reset")
        else:
            print("reset every 100 steps")
from moviepy import editor as mpy

# Encode the collected rollout frames as a 30 fps MP4 named "vid.mp4" in the
# current working directory. No audio track is written, and moviepy's
# progress output is turned off.
video_frames = list(tensor)
clip = mpy.ImageSequenceClip(video_frames, fps=30)
clip.write_videofile("vid.mp4", audio=False, verbose=False, logger=None)

print("result saved as vid.mp4")
