import numpy as np

import gym
import matplotlib

# matplotlib.use("Agg")
import matplotlib.pyplot as plt

from matplotlib import patches
import imageio

# from envs.d4rl.pixel_wrappers import RenderWrapper
# import d4rl

import functools as ft


def valid_goal_sampler(self, np_random):
    """Sample a goal ``(x, y)`` position over the free cells of a maze.

    Written as a free function taking the maze env as ``self`` so it can be
    bound with ``functools.partial`` and installed as the env's
    ``goal_sampler``.

    Args:
        self: maze env exposing ``_maze_map``, ``_rowcol_to_xy`` and
            ``_maze_size_scaling``.
        np_random: numpy random source (``RandomState`` or ``Generator``).
            All randomness drawn here comes from it, so seeding it makes the
            sampled goals reproducible.

    Returns:
        ``(x, y)`` tuple with each coordinate clamped to be >= 0.

    Raises:
        ValueError: if the maze map contains no free cell.
    """
    # Free cells: open floor (0) plus cells marked "r"/"g" (presumably the
    # reset and goal markers of the maze map).
    valid_cells = [
        (i, j)
        for i in range(len(self._maze_map))
        for j in range(len(self._maze_map[0]))
        if self._maze_map[i][j] in [0, "r", "g"]
    ]
    if not valid_cells:
        raise ValueError("maze has no free cell to sample a goal from")

    cell = valid_cells[np_random.choice(len(valid_cells))]
    # NOTE(review): ``_rowcol_to_xy(..., add_random_noise=True)`` may draw its
    # own noise from a different RNG -- verify if full reproducibility matters.
    xy = self._rowcol_to_xy(cell, add_random_noise=True)

    # Extra positive jitter of up to 1/8 of a cell, drawn from ``np_random``
    # (not the global ``np.random``) so seeding the env also seeds this.
    random_x = np_random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling
    random_y = np_random.uniform(low=0, high=0.5) * 0.25 * self._maze_size_scaling

    return (max(xy[0] + random_x, 0), max(xy[1] + random_y, 0))


class GoalReachingAnt(gym.Wrapper):
    """Goal-conditioned view of a D4RL AntMaze environment.

    Observations become ``{"observation": obs, "goal": goal}``, where ``goal``
    is a copy of ``obs`` whose first two entries are replaced by the wrapped
    env's ``target_goal``. The class also provides helpers to query the maze
    layout and to draw it with matplotlib.
    """

    # Euclidean x/y distance below which ``info["success"]`` is 1.0.
    SUCCESS_THRESHOLD = 0.5

    def __init__(self, env_name):
        """Build the wrapped env with ``gym.make(env_name)``."""
        # Let gym.Wrapper set ``self.env`` and its own internal state (several
        # gym versions keep attributes there that properties such as
        # ``reward_range``/``metadata`` read). The spaces are overridden below.
        super().__init__(gym.make(env_name))
        # self.env.env.env._wrapped_env.goal_sampler = ft.partial(
        #     valid_goal_sampler, self.env.env.env._wrapped_env
        # )
        self.observation_space = gym.spaces.Dict(
            {
                "observation": self.env.observation_space,
                "goal": self.env.observation_space,
            }
        )
        self.action_space = self.env.action_space

    def step(self, action):
        """Step the env and attach goal-reaching diagnostics to ``info``.

        Adds ``x``/``y`` (from ``get_xy()``), ``achieved_goal``,
        ``desired_goal`` and ``success`` to ``info``. The returned ``done`` is
        True only when ``info`` contains ``"TimeLimit.truncated"``; the wrapped
        env's own termination flag is discarded.

        Returns:
            ``(obs_dict, reward, done, info)`` with ``obs_dict`` built by
            ``get_obs``.
        """
        next_obs, r, _, info = self.env.step(action)

        # Copy in case ``get_xy`` returns a view of live simulator state.
        achieved = np.array(self.get_xy())
        desired = np.asarray(self.target_goal)
        distance = np.linalg.norm(achieved - desired)
        info["x"], info["y"] = achieved
        info["achieved_goal"] = np.array(achieved)
        info["desired_goal"] = np.copy(desired)
        info["success"] = float(distance < self.SUCCESS_THRESHOLD)
        # NOTE(review): episodes end only on time-limit truncation; confirm this
        # is intended rather than also honouring the env's own ``done``.
        done = "TimeLimit.truncated" in info

        return self.get_obs(next_obs), r, done, info

    def get_obs(self, obs):
        """Pair ``obs`` with a goal vector derived from ``self.target_goal``.

        The goal is a copy of ``obs`` with its first two entries (presumably
        the ant's x/y position) replaced by the target goal, so it has the
        same shape as ``obs``.
        """
        goal = obs.copy()
        goal[:2] = self.target_goal
        return dict(observation=obs, goal=goal)

    def reset(self, **kwargs):
        """Reset the wrapped env and return the goal-conditioned observation.

        Keyword arguments are forwarded unchanged to the wrapped env's
        ``reset``.
        """
        return self.get_obs(self.env.reset(**kwargs))

    def _maze_env(self):
        """Return the first wrapper layer that exposes the maze attributes.

        ``gym.Wrapper`` does not forward names starting with an underscore,
        so walk the ``.env`` chain until a layer with ``_maze_map`` is found
        instead of hard-coding the wrapper depth.
        """
        env = self.env
        while not hasattr(env, "_maze_map"):
            if not hasattr(env, "env"):
                raise AttributeError("no wrapped environment exposes '_maze_map'")
            env = env.env
        return env

    def get_starting_boundary(self):
        """Return ``((x_min, y_min), (x_max, y_max))`` of the maze interior.

        Coordinates are relative to the ant's initial torso position; the
        outermost ring of cells (one cell on each side) is excluded.
        """
        maze = self._maze_env()
        torso_x, torso_y = maze._init_torso_x, maze._init_torso_y
        S = maze._maze_size_scaling
        n_rows, n_cols = len(maze._maze_map), len(maze._maze_map[0])
        bottom_left = (0 - S / 2 + S - torso_x, 0 - S / 2 + S - torso_y)
        top_right = (
            n_cols * S - torso_x - S / 2 - S,
            n_rows * S - torso_y - S / 2 - S,
        )
        return bottom_left, top_right

    def XY(self, n=20):
        """Return an ``(n * n, 2)`` array of x/y grid points inside the maze.

        The grid spans the central 92% of ``get_starting_boundary`` along
        each axis (a 4% margin on every side).
        """
        (x_lo, y_lo), (x_hi, y_hi) = self.get_starting_boundary()
        x_pad = 0.04 * (x_hi - x_lo)
        y_pad = 0.04 * (y_hi - y_lo)
        X, Y = np.meshgrid(
            np.linspace(x_lo + x_pad, x_hi - x_pad, n),
            np.linspace(y_lo + y_pad, y_hi - y_pad, n),
        )
        return np.array([X.flatten(), Y.flatten()]).T

    def four_goals(self):
        """Return the centres of the four "corner" free cells of the maze.

        Free cells are map entries equal to ``0``, ``"r"`` or ``"g"``. The
        goals are, in order, the cell centre minimising ``x + y``, maximising
        ``x - y``, maximising ``x + y`` and maximising ``y - x``; ties go to
        the first cell in row-major order.
        """
        maze = self._maze_env()
        centres = [
            maze._rowcol_to_xy((i, j), add_random_noise=False)
            for i, row in enumerate(maze._maze_map)
            for j, cell in enumerate(row)
            if cell in [0, "r", "g"]
        ]
        return [
            max(centres, key=lambda p: -p[0] - p[1]),
            max(centres, key=lambda p: p[0] - p[1]),
            max(centres, key=lambda p: p[0] + p[1]),
            max(centres, key=lambda p: -p[0] + p[1]),
        ]

    def draw(self, ax=None):
        """Draw the maze walls as grey squares on ``ax``.

        Args:
            ax: matplotlib axes to draw on; defaults to ``plt.gca()``.

        Wall cells are map entries equal to ``1``. Axis limits are set to the
        maze extent inset by 0.6 cell on each side, and the axis is hidden.
        """
        if ax is None:
            ax = plt.gca()
        maze = self._maze_env()
        torso_x, torso_y = maze._init_torso_x, maze._init_torso_y
        S = maze._maze_size_scaling
        for i, row in enumerate(maze._maze_map):
            for j, struct in enumerate(row):
                if struct == 1:
                    ax.add_patch(
                        patches.Rectangle(
                            (j * S - torso_x - S / 2, i * S - torso_y - S / 2),
                            S,
                            S,
                            linewidth=1,
                            edgecolor="none",
                            facecolor="grey",
                            alpha=1.0,
                        )
                    )
        n_rows, n_cols = len(maze._maze_map), len(maze._maze_map[0])
        ax.set_xlim(
            0 - S / 2 + 0.6 * S - torso_x,
            n_cols * S - torso_x - S / 2 - S * 0.6,
        )
        ax.set_ylim(
            0 - S / 2 + 0.6 * S - torso_y,
            n_rows * S - torso_y - S / 2 - S * 0.6,
        )
        ax.axis("off")
        ax.axis("off")


# env = GoalReachingAnt("antmaze-medium-play-v0")
# dataset = d4rl.qlearning_dataset(env)
# dataset["masks"] = 1.0 - dataset["terminals"]
# dataset["dones_float"] = 1.0 - np.isclose(
#     np.roll(dataset["observations"], -1, axis=0), dataset["next_observations"]
# ).all(-1)
# dataset = Dataset.create(**dataset)
# return env, dataset


# env = gym.make("antmaze-medium-play-v0")
# env = RenderWrapper(env)

import dowel_wrapper

# Reference the module so the import above is not flagged or removed as
# unused; presumably it is imported for its side effects on ``dowel`` (which
# is imported next) -- confirm before reordering these imports.
assert dowel_wrapper is not None
import dowel

from garage import wrap_experiment
from garage.experiment.deterministic import set_seed
from garage.torch.distributions import TanhNormal

from garagei.replay_buffer.path_buffer_ex import PathBufferEx
from garagei.experiment.option_local_runner import OptionLocalRunner
from garagei.envs.consistent_normalized_env import consistent_normalize
from garagei.sampler.option_multiprocessing_sampler import OptionMultiprocessingSampler
from garagei.torch.modules.with_encoder import WithEncoder, Encoder
from garagei.torch.modules.gaussian_mlp_module_ex import (
    GaussianMLPTwoHeadedModuleEx,
    GaussianMLPIndependentStdModuleEx,
    GaussianMLPModuleEx,
)
from garagei.torch.modules.gaussian_lstm_module_ex import (
    GaussianLSTMTwoHeadedModuleEx,
    GaussianLSTMIndependentStdModuleEx,
    GaussianLSTMModuleEx,
)
from garagei.torch.modules.parameter_module import ParameterModule
from garagei.torch.policies.policy_ex import PolicyEx, RecurrentPolicyEx
from garagei.torch.q_functions.continuous_mlp_q_function_ex import (
    ContinuousMLPQFunctionEx,
)
from garagei.torch.q_functions.continuous_lstm_q_function_ex import (
    ContinuousLSTMQFunctionEx,
)
from garagei.torch.optimizers.optimizer_group_wrapper import OptimizerGroupWrapper
from garagei.torch.utils import xavier_normal_ex
from iod.metra import METRA
from iod.recurrent_metra import RecurrentMETRA
from iod.dads import DADS
from iod.utils import get_normalizer_preset

# Scene-construction switches. In this file they are only referenced by the
# commented-out quadruped MJCF-editing snippet further down.
floor_size, walls_and_ball, terrain, rangefinders = None, False, False, False

from dm_control import mujoco
from dm_control.mujoco.wrapper import mjbindings
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.utils import xml_tools
from lxml import etree
import numpy as np

# enums = mjbindings.enums
# mjlib = mjbindings.mjlib

# parser = etree.XMLParser(remove_blank_text=True)
# xml_string = common.read_model(
#     os.path.join(root_dir, "custom_dmc_tasks", "quadruped_color.xml")
# )
# mjcf = etree.XML(xml_string, parser)

# # Set floor size.
# if floor_size is not None:
#     floor_geom = mjcf.find(".//geom[@name='floor']")
#     floor_geom.attrib["size"] = f"{floor_size} {floor_size} .5"

# # Remove walls, ball and target.
# if not walls_and_ball:
#     for wall in _WALLS:
#         wall_geom = xml_tools.find_element(mjcf, "geom", wall)
#         wall_geom.getparent().remove(wall_geom)

#     # Remove ball.
#     ball_body = xml_tools.find_element(mjcf, "body", "ball")
#     ball_body.getparent().remove(ball_body)

#     # Remove target.
#     target_site = xml_tools.find_element(mjcf, "site", "target")
#     target_site.getparent().remove(target_site)

# # Remove terrain.
# if not terrain:
#     terrain_geom = xml_tools.find_element(mjcf, "geom", "terrain")
#     terrain_geom.getparent().remove(terrain_geom)

# # Remove rangefinders if they're not used, as range computations can be
# # expensive, especially in a scene with heightfields.
# if not rangefinders:
#     rangefinder_sensors = mjcf.findall(".//rangefinder")
#     for rf in rangefinder_sensors:
#         rf.getparent().remove(rf)

# return etree.tostring(mjcf, pretty_print=True)


from envs.custom_dmc_tasks import dmc
from envs.custom_dmc_tasks.pixel_wrappers import RenderWrapper


# env = dmc.make(
#     "quadruped_run_forward_color",
#     obs_type="states",
#     frame_stack=1,
#     action_repeat=2,
#     seed=10,
# )
# env = RenderWrapper(env)
import d4rl

#
# env = gym.make("antmaze-medium-play-v0")
# env = gym.make("antmaze-large-playinthemiddle-v2")
env = gym.make("antmaze-large-play-v0")

CHANGE_FLOORS = False  # overwrite 2-D textures with an x/y colour gradient
CHANGE_WALLS = True  # give every maze wall geom its own colour

N_STEPS = 1000  # number of frames to record
EPISODE_LEN = 200  # reset the env after this many steps
FRAME_SIZE = 256  # rendered frame height and width, in pixels
VIDEO_FPS = 30
VIDEO_PATH = "vid.mp4"


def _recolor_floor_textures(model):
    """Overwrite every 2-D texture of ``model`` with an RGB gradient.

    Texel ``(x, y)`` of an ``H x W`` texture becomes
    ``(int(x / H * 255), int(y / W * 255), 128)``; afterwards every material
    repeats its texture once. ``model.tex_rgb`` is treated as a flat buffer
    where texture ``i`` starts at ``tex_adr[i]`` with 3 bytes per texel,
    stored row by row.
    """
    for tex_id in range(len(model.tex_type)):
        # tex_type 0 is mjTEXTURE_2D; cube/skybox textures are left untouched.
        if model.tex_type[tex_id] != 0:
            continue
        height = model.tex_height[tex_id]
        width = model.tex_width[tex_id]
        start = model.tex_adr[tex_id]
        gradient = np.empty((height, width, 3), dtype=np.uint8)
        gradient[..., 0] = (np.arange(height) / height * 255).astype(np.uint8)[:, None]
        gradient[..., 1] = (np.arange(width) / width * 255).astype(np.uint8)[None, :]
        gradient[..., 2] = 128
        model.tex_rgb[start : start + height * width * 3] = gradient.reshape(-1)
    model.mat_texrepeat[:, :] = 1


def _recolor_walls(model, cmap_name="tab20"):
    """Colour each geom whose name starts with ``block_`` from a colormap.

    Geom ``i`` receives ``cmap(i % cmap.N)`` (``i % 20`` for ``tab20``).
    Unnamed geoms are skipped: mujoco_py reports their name as ``None``.
    """
    cmap = plt.get_cmap(cmap_name)
    for geom_id in range(model.ngeom):
        geom_name = model.geom_id2name(geom_id)
        if geom_name and geom_name.startswith("block_"):
            model.geom_rgba[geom_id] = cmap(geom_id % cmap.N)


if CHANGE_FLOORS:
    _recolor_floor_textures(env.physics.model)

if CHANGE_WALLS:
    _recolor_walls(env.physics.model)
    # for index in range(9):
    #     wall_textures[str(index + 1)] = WallTexture(cmap(index)[:3])
env.reset()

tensor = []  # one rendered RGB frame per step
for step in range(N_STEPS):
    env.step(env.action_space.sample())
    # env.draw()

    frame = env.render(
        height=FRAME_SIZE, width=FRAME_SIZE, mode="rgb_array", camera_name="rotate"
    )
    tensor.append(frame)
    # ``step + 1`` so the env resets after each full EPISODE_LEN steps; testing
    # ``step`` itself would also reset right after the very first step.
    if (step + 1) % EPISODE_LEN == 0:
        env.reset()

from moviepy import editor as mpy

clip = mpy.ImageSequenceClip(tensor, fps=VIDEO_FPS)

# plot_path = (
#     pathlib.Path(runner._snapshotter.snapshot_dir)
#     / "plots"
#     # / f'{label}_{runner.step_itr}.gif')
#     / f"{label}_{runner.step_itr}.mp4"
# )
# plot_path.parent.mkdir(parents=True, exist_ok=True)

clip.write_videofile(VIDEO_PATH, audio=False, verbose=False, logger=None)
