import numpy as np
import os

os.environ["MUJOCO_GL"] = "egl"

import gym
import matplotlib

# matplotlib.use("Agg")
import matplotlib.pyplot as plt

from matplotlib import patches
import imageio

# from envs.d4rl.pixel_wrappers import RenderWrapper
# import d4rl

import functools as ft

import dowel_wrapper

assert dowel_wrapper is not None
import dowel

from garage import wrap_experiment
from garage.experiment.deterministic import set_seed
from garage.torch.distributions import TanhNormal

from garagei.replay_buffer.path_buffer_ex import PathBufferEx
from garagei.experiment.option_local_runner import OptionLocalRunner
from garagei.envs.consistent_normalized_env import consistent_normalize
from garagei.sampler.option_multiprocessing_sampler import OptionMultiprocessingSampler
from garagei.torch.modules.with_encoder import WithEncoder, Encoder
from garagei.torch.modules.gaussian_mlp_module_ex import (
    GaussianMLPTwoHeadedModuleEx,
    GaussianMLPIndependentStdModuleEx,
    GaussianMLPModuleEx,
)
from garagei.torch.modules.gaussian_lstm_module_ex import (
    GaussianLSTMTwoHeadedModuleEx,
    GaussianLSTMIndependentStdModuleEx,
    GaussianLSTMModuleEx,
)
from garagei.torch.modules.parameter_module import ParameterModule
from garagei.torch.policies.policy_ex import PolicyEx, RecurrentPolicyEx
from garagei.torch.q_functions.continuous_mlp_q_function_ex import (
    ContinuousMLPQFunctionEx,
)
from garagei.torch.q_functions.continuous_lstm_q_function_ex import (
    ContinuousLSTMQFunctionEx,
)
from garagei.torch.optimizers.optimizer_group_wrapper import OptimizerGroupWrapper
from garagei.torch.utils import xavier_normal_ex
from iod.metra import METRA
from iod.recurrent_metra import RecurrentMETRA
from iod.dads import DADS
from iod.utils import get_normalizer_preset

from dm_control import mujoco
from dm_control.mujoco.wrapper import mjbindings
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import xml_tools
from lxml import etree
import numpy as np

from envs.custom_dmc_tasks import dmc
from envs.custom_dmc_tasks.pixel_wrappers import RenderWrapper
from envs.custom_dmc_tasks.pixel_wrappers import FrameStackWrapper

import d4rl

from main import make_env


class args:
    """Static stand-in for the argument namespace that ``make_env`` receives.

    Only class attributes are defined; the class itself (not an instance) is
    handed to ``make_env``. The meaning of each field is decided by
    ``make_env`` and is not visible in this file.
    """

    # Which environment to build, and the seed value to hand over with it.
    env = "dmc_quadruped_state_step"
    seed = 0

    # Observation-related switches. NOTE(review): 0 / "off" / None look like
    # "disabled" settings -- confirm against make_env.
    encoder = 0
    normalizer_type = "off"
    frame_stack = None  # an earlier variant of this script used 3 here


# Build the environment described by ``args``; the literal 200 is forwarded to
# ``make_env`` as its second positional argument.
env = make_env(args, 200)

# One rendered frame is collected per loop iteration, in iteration order.
tensor = []
for tot in range(1000):
    if tot % 500 != 0:
        # Regular iteration: draw a random action from the wrapped env's
        # action space and advance one step. Reward/done/info are discarded.
        action = env.env.action_space.sample()
        ob, _, _, _ = env.step(action)
    else:
        # Iterations 0 and 500 reset the environment instead of stepping.
        ob = env.reset()
        print("reset", tot)

    # Request a 256x256 "rgb_array" render after the reset/step and keep it.
    res = env.render(height=256, width=256, mode="rgb_array")
    tensor.append(res)


import moviepy.editor as mpy

# Hand moviepy a shallow copy of the collected frames; playback rate is 30 fps.
frame_sequence = list(tensor)
clip = mpy.ImageSequenceClip(frame_sequence, fps=30)

# NOTE: an earlier, commented-out variant built the output path under a
# per-run "plots" directory of a runner's snapshot dir. This script writes
# "vid.mp4" relative to the current working directory instead.

# Encode without an audio track and with moviepy's progress output disabled.
clip.write_videofile("vid.mp4", logger=None, verbose=False, audio=False)
