import collections
import os
import random
import xml.etree.ElementTree as ET

from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from dm_control.suite import common
from dm_control.suite.utils import randomizers
from dm_control.utils import containers
from dm_control.utils import rewards
from lxml import etree
import numpy as np
from six.moves import range

from dmc_remastered import SUITE_DIR, register, DMCR_VARY
from dmc_remastered.rng import dmcr_random
from .generate_visuals import get_assets

_DEFAULT_TIME_LIMIT = 30
_CONTROL_TIMESTEP = 0.03  # (Seconds)


def get_model(visual_seed, vary=("camera", "light")):
    """Load the swimmer MJCF template, optionally randomizing its camera.

    Args:
      visual_seed: Visual randomization seed. When it equals ``0`` the template
        is returned unmodified; otherwise it seeds ``dmcr_random`` and a new
        camera position is drawn.
      vary: Container of aspect names to randomize. Only ``"camera"`` has an
        effect; light randomization is currently disabled. The default is an
        immutable tuple so it cannot be mutated and shared across calls.

    Returns:
      The (possibly modified) model serialized by ``ET.tostring`` with
      ``encoding="utf8"``, i.e. as ``bytes``.
    """
    model_path = os.path.join(SUITE_DIR, "assets", "swimmer.xml")
    with open(model_path, "r", encoding="utf-8") as f:
        xml = ET.fromstring(f.read())
    if visual_seed != 0:
        with dmcr_random(visual_seed):
            camera_x = random.uniform(-0.05, 0.05)
            camera_y = random.uniform(-0.25, -0.15)
            camera_z = random.uniform(0.45, 0.55)
            # The light position is still drawn (though currently unused) so the
            # number of draws made inside the seeded context stays the same.
            _light_pos = (
                random.uniform(-1, -0.8),
                random.uniform(-0.1, 0.1),
                random.uniform(1.4, 1.6),
            )
        if "camera" in vary:
            # NOTE(review): positional path into the asset; presumably the head
            # body's tracking camera -- verify against assets/swimmer.xml.
            xml[5][1][4].attrib["pos"] = f"{camera_x} {camera_y} {camera_z}"
        # Light randomization is intentionally disabled. To re-enable it:
        #     if "light" in vary:
        #         xml[5][1][0].attrib["pos"] = " ".join(map(str, _light_pos))
    return ET.tostring(xml, encoding="utf8", method="xml")


@register("swimmer", "swimmer6")
def swimmer6(
    time_limit=_DEFAULT_TIME_LIMIT, dynamics_seed=None, visual_seed=None, vary=DMCR_VARY
):
    """Return the swimmer environment built with 6 joints.

    Registered under ("swimmer", "swimmer6"); all arguments are forwarded
    unchanged to `_make_swimmer`.
    """
    joint_count = 6
    return _make_swimmer(
        n_joints=joint_count,
        time_limit=time_limit,
        dynamics_seed=dynamics_seed,
        visual_seed=visual_seed,
        vary=vary,
    )


@register("swimmer", "swimmer15")
def swimmer15(
    time_limit=_DEFAULT_TIME_LIMIT, dynamics_seed=None, visual_seed=None, vary=DMCR_VARY
):
    """Return the swimmer environment built with 15 joints.

    Registered under ("swimmer", "swimmer15"); all arguments are forwarded
    unchanged to `_make_swimmer`.
    """
    joint_count = 15
    return _make_swimmer(
        n_joints=joint_count,
        time_limit=time_limit,
        dynamics_seed=dynamics_seed,
        visual_seed=visual_seed,
        vary=vary,
    )


def _make_swimmer(
    n_joints,
    time_limit=_DEFAULT_TIME_LIMIT,
    dynamics_seed=None,
    visual_seed=None,
    vary=DMCR_VARY,
):
    """Build a swimmer `control.Environment` with `n_joints` actuated joints.

    Args:
      n_joints: Number of joints. The generated swimmer has `n_joints + 1`
        bodies: the head plus one segment (with one joint and motor) each.
      time_limit: Forwarded to `control.Environment` as `time_limit`.
      dynamics_seed: Forwarded to `Swimmer` as its `random` argument.
      visual_seed: Seed for the visual randomization of the model
        (`get_model`) and of the assets (`get_assets`).
      vary: Visual aspects to randomize; forwarded to `get_model`.

    Returns:
      A `control.Environment` wrapping the swimmer physics and task.
    """
    # `get_model` only loads (and visually randomizes) the template; the body
    # segments, motors and per-segment sensors are appended by `_make_model`.
    # NOTE(review): assumes assets/swimmer.xml is the segment-free template
    # (like dm_control's swimmer.xml); a pre-built swimmer would get duplicate
    # segment names here.
    template = get_model(visual_seed, vary)
    model = _make_model(n_joints + 1, template)
    assets, _ = get_assets(visual_seed)
    physics = Physics.from_xml_string(model, assets)
    task = Swimmer(random=dynamics_seed)
    return control.Environment(
        physics, task, time_limit=time_limit, control_timestep=_CONTROL_TIMESTEP,
    )


def _make_model(n_bodies, xml_string=None):
    """Generates an xml string defining a swimmer with `n_bodies` bodies.

    Args:
      n_bodies: Total number of bodies, head included. Must be at least 3.
      xml_string: Optional MJCF template to extend, preferably as ``bytes``
        (lxml rejects ``str`` input that carries an XML encoding declaration).
        Defaults to dm_control's stock ``swimmer.xml`` via ``common.read_model``.

    Returns:
      The extended model as pretty-printed XML ``bytes``.

    Raises:
      ValueError: If `n_bodies` is less than 3.
    """
    if n_bodies < 3:
        raise ValueError("At least 3 bodies required. Received {}".format(n_bodies))
    if xml_string is None:
        xml_string = common.read_model("swimmer.xml")
    mjcf = etree.fromstring(xml_string)
    head_body = mjcf.find("./worldbody/body")
    actuator = etree.SubElement(mjcf, "actuator")
    sensor = etree.SubElement(mjcf, "sensor")

    # Every joint shares the same symmetric range, so compute it once.
    joint_limit = 360.0 / n_bodies
    joint_range = "{} {}".format(-joint_limit, joint_limit)

    # Chain the segments: each new body is nested inside the previous one.
    parent = head_body
    for body_index in range(n_bodies - 1):
        site_name = "site_{}".format(body_index)
        child = _make_body(body_index=body_index)
        child.append(etree.Element("site", name=site_name))
        joint_name = "joint_{}".format(body_index)
        child.append(etree.Element("joint", {"name": joint_name, "range": joint_range}))
        motor_name = "motor_{}".format(body_index)
        actuator.append(etree.Element("motor", name=motor_name, joint=joint_name))
        # Per segment: a velocimeter followed by a gyro, both on the segment site.
        velocimeter_name = "velocimeter_{}".format(body_index)
        sensor.append(
            etree.Element("velocimeter", name=velocimeter_name, site=site_name)
        )
        gyro_name = "gyro_{}".format(body_index)
        sensor.append(etree.Element("gyro", name=gyro_name, site=site_name))
        parent.append(child)
        parent = child

    # Move tracking cameras further away from the swimmer according to its length.
    cameras = mjcf.findall("./worldbody/body/camera")
    scale = n_bodies / 6.0
    for cam in cameras:
        if cam.get("mode") == "trackcom":
            old_pos = cam.get("pos").split(" ")
            new_pos = " ".join([str(float(dim) * scale) for dim in old_pos])
            cam.set("pos", new_pos)

    return etree.tostring(mjcf, pretty_print=True)


def _make_body(body_index):
    """Create the lxml `<body>` element for segment `body_index`.

    The body is named ``segment_<body_index>``, offset by ``0 .1 0`` from its
    parent, and holds one ``visual`` and one ``inertial`` class geom.

    Returns:
      An ``lxml.etree`` Element (not a string).
    """
    body_name = "segment_{}".format(body_index)
    visual_name = "visual_{}".format(body_index)
    inertial_name = "inertial_{}".format(body_index)
    body = etree.Element("body", name=body_name)
    body.set("pos", "0 .1 0")
    etree.SubElement(body, "geom", {"class": "visual", "name": visual_name})
    etree.SubElement(body, "geom", {"class": "inertial", "name": inertial_name})
    return body


class Physics(mujoco.Physics):
    """`mujoco.Physics` extended with swimmer-specific accessors."""

    def nose_to_target(self):
        """Nose-to-target vector rotated into the head body's frame.

        Only the first two components (local x and y) are returned.
        """
        geom_positions = self.named.data.geom_xpos
        world_delta = geom_positions["target"] - geom_positions["nose"]
        head_rotation = self.named.data.xmat["head"].reshape(3, 3)
        local_delta = world_delta.dot(head_rotation)
        return local_delta[:2]

    def nose_to_target_dist(self):
        """Norm of the 2-D vector returned by `nose_to_target`."""
        planar_offset = self.nose_to_target()
        return np.linalg.norm(planar_offset)

    def body_velocities(self):
        """Per-segment local velocities, flattened as (vx, vy, wz) triples."""
        # NOTE(review): the first 12 sensordata values presumably belong to the
        # template's own sensors -- confirm. The remainder is read in rows of 6
        # (velocimeter xyz then gyro xyz per segment).
        per_segment = self.data.sensordata[12:].reshape((-1, 6))
        selected_columns = [0, 1, 5]  # linear x, linear y, angular z
        return per_segment[:, selected_columns].ravel()

    def joints(self):
        """Copy of qpos without its first three entries (the root joints)."""
        return np.copy(self.data.qpos[3:])


class Swimmer(base.Task):
    """Task rewarding the swimmer for bringing its nose to a target."""

    def __init__(self, random=None):
        """Create the task.

        Args:
          random: Optional `numpy.random.RandomState`, integer seed, or None to
            let a seed be chosen automatically. Forwarded to `base.Task`.
        """
        super().__init__(random=random)

    def initialize_episode(self, physics):
        """Randomize joints and the target location at the start of an episode.

        Joint angles are randomized by
        `randomizers.randomize_limited_and_rotational_joints`; the target (and
        the light tracking it) is then placed uniformly in a square centred at
        the origin.

        Args:
          physics: An instance of `Physics`.
        """
        rng = self.random
        randomizers.randomize_limited_and_rotational_joints(physics, rng)
        # One draw in five places the target in a small box near the origin.
        half_width = 0.3 if rng.rand() < 0.2 else 2
        target_xy = rng.uniform(-half_width, half_width, size=2)
        for axis, coordinate in zip(("x", "y"), target_xy):
            physics.named.model.geom_pos["target", axis] = coordinate
            physics.named.model.light_pos["target_light", axis] = coordinate

        super().initialize_episode(physics)

    def get_observation(self, physics):
        """Joint angles, nose-to-target vector and body velocities, in that order."""
        return collections.OrderedDict(
            [
                ("joints", physics.joints()),
                ("to_target", physics.nose_to_target()),
                ("body_velocities", physics.body_velocities()),
            ]
        )

    def get_reward(self, physics):
        """Smooth reward: 1 inside the target radius, long-tail decay outside."""
        radius = physics.named.model.geom_size["target", 0]
        distance = physics.nose_to_target_dist()
        return rewards.tolerance(
            distance,
            bounds=(0, radius),
            margin=5 * radius,
            sigmoid="long_tail",
        )
