"""Environment definition for the Ant environment.

Robot definitions are from DeepMind's PyMJCF tutorial:
https://colab.research.google.com/github/deepmind/dm_control/blob/main/tutorial.ipynb#scrollTo=UAMItwu8e1WR
"""
from dataclasses import dataclass
from typing import Union
from dm_control import mjcf
import numpy as np


# Radius of the creature's torso. The rest of this module also uses it as the
# base length for legs, buttons, the goal and the walls.
BODY_RADIUS = 0.1
# Size triple given to the torso ellipsoid geom: the z component is half of
# the x/y components, i.e. the torso is flattened vertically.
BODY_SIZE = (BODY_RADIUS, BODY_RADIUS, BODY_RADIUS / 2)
# Module-wide RNG with a fixed seed; `make_creature` draws the creature colour
# from it, so the colour sequence depends on how many creatures were built.
random_state = np.random.RandomState(42)


class Leg:
    """Two-joint leg (hip + knee hinges), each joint position-actuated.

    The leg is built as its own MJCF model (`self.model`) so that it can be
    attached to a site of another model.

    Adapted from the dm_control PyMJCF tutorial:
    https://colab.research.google.com/github/deepmind/dm_control/blob/main/tutorial.ipynb#scrollTo=gKny5EJ4uVzu

    Attributes:
        model: root MJCF element of the leg.
        thigh: body attached to the leg's worldbody.
        hip: joint on the thigh, axis (0, 0, 1).
        shin: body nested in the thigh, offset by `length` along x.
        knee: joint on the shin, axis (0, 1, 0).
    """

    def __init__(self, length, rgba):
        """Build the leg.

        Args:
            length: length of the thigh and of the shin capsules.
            rgba: colour used as the default for every geom of the leg.
        """
        model = mjcf.RootElement()
        self.model = model

        # Model-wide defaults shared by both joints and both geoms.
        defaults = model.default
        defaults.joint.damping = 2
        defaults.joint.type = 'hinge'
        defaults.geom.type = 'capsule'
        defaults.geom.rgba = rgba

        # Thigh: capsule running along +x, rotating about z at the hip.
        thigh = model.worldbody.add('body')
        self.thigh = thigh
        self.hip = thigh.add('joint', axis=[0, 0, 1])
        thigh.add(
            'geom',
            fromto=[0, 0, 0, length, 0, 0],
            size=[length / 4],
        )

        # Shin: capsule hanging along -z from the end of the thigh, rotating
        # about y at the knee.
        shin = thigh.add('body', pos=[length, 0, 0])
        self.shin = shin
        self.knee = shin.add('joint', axis=[0, 1, 0])
        shin.add(
            'geom',
            fromto=[0, 0, 0, 0, 0, -length],
            size=[length / 5],
        )

        # One position actuator per joint, hip first and knee second.
        for joint in (self.hip, self.knee):
            model.actuator.add('position', joint=joint, kp=10)


def make_creature(num_legs: int) -> mjcf.RootElement:
    """Build a creature: an ellipsoid torso with `num_legs` legs around it.

    The colour is drawn once per creature from the module-level
    `random_state` (random RGB, alpha fixed to 1) and shared by the torso and
    all legs.

    Adapted from the dm_control PyMJCF tutorial:
    https://colab.research.google.com/github/deepmind/dm_control/blob/main/tutorial.ipynb#scrollTo=SESlL_TidKHx

    Args:
        num_legs: number of legs, spread at equal angles around the torso.

    Returns:
        The root MJCF element of the creature.
    """
    creature = mjcf.RootElement()
    creature.compiler.angle = 'radian'  # Angles below are given in radians.
    colour = random_state.uniform([0, 0, 0, 1], [1, 1, 1, 1])

    # Torso geom; named so that it can be looked up later.
    creature.worldbody.add(
        'geom',
        name='torso',
        type='ellipsoid',
        size=BODY_SIZE,
        rgba=colour,
    )

    # One attachment site per leg on the circle of radius BODY_RADIUS, each
    # rotated about z so that the leg points outwards.
    for leg_index in range(num_legs):
        angle = 2 * leg_index * np.pi / num_legs
        direction = np.array([np.cos(angle), np.sin(angle), 0])
        site = creature.worldbody.add(
            'site',
            pos=BODY_RADIUS * direction,
            euler=[0, 0, angle],
        )
        site.attach(Leg(length=BODY_RADIUS, rgba=colour).model)
    return creature


@dataclass(frozen=True)
class Cylinder:
    """Immutable description of a cylinder geom.

    `make_cylinder` turns it into a geom positioned at `(x, y, 0)` with
    size `(r, h)`.
    """
    x: float  # x coordinate of the geom position
    y: float  # y coordinate of the geom position
    r: float  # radius (first size parameter of the cylinder geom)
    # Second size parameter of the cylinder geom; MuJoCo documents this as the
    # half-height of the cylinder.
    h: float


@dataclass(frozen=True)
class Box:
    """Immutable description of a box geom.

    `make_box` turns it into a geom positioned at `(x, y, 0)` with size
    `(w, l, h)`; MuJoCo documents box sizes as half-extents along x, y, z.
    """
    x: float  # x coordinate of the geom position
    y: float  # y coordinate of the geom position
    w: float  # first size parameter (extent along x)
    l: float  # second size parameter (extent along y)
    h: float  # third size parameter (extent along z)


def make_cylinder(cylinder: Cylinder) -> mjcf.RootElement:
    """Build a standalone MJCF model holding one green cylinder geom.

    Args:
        cylinder: position and size of the cylinder.

    Returns:
        Root element of a new model whose worldbody contains the geom, placed
        at `(cylinder.x, cylinder.y, 0)`.
    """
    model = mjcf.RootElement()
    geom_spec = dict(
        type='cylinder',
        size=(cylinder.r, cylinder.h),
        rgba=[0.2, 0.5, 0.2, 1],
        pos=(cylinder.x, cylinder.y, 0),
    )
    model.worldbody.add('geom', **geom_spec)
    return model


def make_box(box: Box) -> mjcf.RootElement:
    """Build a standalone MJCF model holding one dark grey box geom.

    Args:
        box: position and size of the box.

    Returns:
        Root element of a new model whose worldbody contains the geom, placed
        at `(box.x, box.y, 0)`.
    """
    model = mjcf.RootElement()
    geom_spec = dict(
        type='box',
        size=(box.w, box.l, box.h),
        rgba=[0.2, 0.2, 0.2, 1],
        pos=(box.x, box.y, 0),
    )
    model.worldbody.add('geom', **geom_spec)
    return model


def make_arena() -> mjcf.RootElement:
    """Build a flat, chequered arena lit by two lights.

    Adapted from the dm_control PyMJCF tutorial:
    https://colab.research.google.com/github/deepmind/dm_control/blob/main/tutorial.ipynb#scrollTo=F7_Tx9P9U_VJ

    Returns:
        Root element of the arena model.
    """
    arena = mjcf.RootElement()
    assets = arena.asset
    world = arena.worldbody

    # Floor appearance: built-in checker texture wrapped in a material.
    checker_texture = assets.add(
        'texture',
        type='2d',
        builtin='checker',
        width=300,
        height=300,
        rgb1=[.2, .3, .4],
        rgb2=[.3, .4, .5],
    )
    floor_material = assets.add(
        'material',
        name='grid',
        texture=checker_texture,
        texrepeat=[5, 5],
        reflectance=.2,
    )

    # Floor plane.
    world.add(
        'geom',
        type='plane',
        size=[2, 2, .1],
        material=floor_material,
    )

    # Two mirrored lights, each aimed back across the arena.
    for light_x in (-2, 2):
        world.add('light', pos=[light_x, -1, 3], dir=[-light_x, 1, -2])
    return arena


class AntSimulation:
    """Encapsulates the simulator state of an ant simulation.

    The scene contains a row of buttons, a goal surrounded by four walls, and
    a legged creature spawned at the origin. Activating the buttons in the
    order given by `button_combination` removes the walls around the goal.

    Public attributes:
        buttons: `Cylinder` description of each button.
        button_roots: MJCF model of each button (same order as `buttons`).
        button_combination: ordered button indices that open the walls.
        activated_buttons: button indices activated since the last reset, in
            activation order and without duplicates.
        goal: `Cylinder` description of the goal.
        wall_bodies: MJCF models of the four walls around the goal.
        creature: MJCF model of the ant.
        physics: the `mjcf.Physics` instance driving the simulation.
        video: rendered frames (only filled when `animation_fps` is set).
        actuators, joints, torsos: MJCF elements of the creature.
    """

    def __init__(
            self,
            num_legs: int,
            num_buttons: int,
            button_combination: tuple[int, ...],
            animation_fps: Union[None, float],
            sub_step_s: float,  # how much time to advance for step()
            torso_positions_recording_fps: float,
            ):
        """Build the scene and bring the simulation to its initial state.

        Args:
            num_legs: number of legs of the creature.
            num_buttons: number of buttons placed in the arena.
            button_combination: ordered button indices that must be activated
                to remove the walls; an empty combination removes them
                immediately.
            animation_fps: frames rendered per simulated second, or `None` to
                disable rendering.
            sub_step_s: simulated seconds advanced by one call to `step`.
            torso_positions_recording_fps: torso positions recorded per
                simulated second during `step`.
        """
        self.torso_positions_recording_fps = torso_positions_recording_fps
        # Create an arena
        arena = make_arena()

        # Buttons sit at y=+distance and the goal at y=-distance.
        distance = 5*BODY_RADIUS

        # Place buttons in the arena, evenly spaced and centred on x=0
        button_distance = 3*BODY_RADIUS
        button_total_distance = button_distance*(num_buttons-1)
        self.buttons = [
            Cylinder(
                x=i*button_distance-button_total_distance/2,
                y=distance,
                r=BODY_RADIUS,
                h=BODY_RADIUS/10,
            )
            for i in range(num_buttons)
        ]
        self.button_roots = [make_cylinder(button) for button in self.buttons]
        for button_root in self.button_roots:
            button_spawn = arena.worldbody.add('site')
            button_spawn.attach(button_root)

        # Store button combination
        self.button_combination = button_combination
        self.activated_buttons: list[int] = []

        # Place goal in the arena
        self.goal = Cylinder(
            x=0,
            y=-distance,
            r=BODY_RADIUS,
            h=BODY_RADIUS/10,
        )
        goal = make_cylinder(self.goal)
        goal_spawn = arena.worldbody.add('site')
        goal_spawn.attach(goal)

        # Place walls around goal: two across y, then two across x
        walls = [
            Box(
                x=self.goal.x,
                y=self.goal.y-2*self.goal.r,
                w=BODY_RADIUS*2,
                l=BODY_RADIUS*0.1,
                h=BODY_RADIUS*2,
            ),
            Box(
                x=self.goal.x,
                y=self.goal.y+2*self.goal.r,
                w=BODY_RADIUS*2,
                l=BODY_RADIUS*0.1,
                h=BODY_RADIUS*2,
            ),
            Box(
                x=self.goal.x+2*self.goal.r,
                y=self.goal.y,
                w=BODY_RADIUS*0.2,
                l=BODY_RADIUS*2,
                h=BODY_RADIUS*2,
            ),
            Box(
                x=self.goal.x-2*self.goal.r,
                y=self.goal.y,
                w=BODY_RADIUS*0.2,
                l=BODY_RADIUS*2,
                h=BODY_RADIUS*2,
            ),
        ]
        self.wall_bodies = []
        for wall in walls:
            wall_body = make_box(wall)
            self.wall_bodies.append(wall_body)
            spawn = arena.worldbody.add('site')
            spawn.attach(wall_body)

        # Place creature in the arena, above the floor, on a free joint.
        self.creature = make_creature(num_legs=num_legs)
        height = BODY_RADIUS*1.5
        spawn_pos = (0, 0, height)
        spawn_site = arena.worldbody.add('site', pos=spawn_pos, group=3)
        spawn_site.attach(self.creature).add('freejoint')

        # Initiate simulation
        self.physics = mjcf.Physics.from_mjcf_model(arena)

        # `remove_walls` and `activate_button` overwrite model fields (geom
        # positions / colours). `physics.reset()` resets the simulation data
        # only, so keep a copy of the pristine values for `reset` to restore.
        self._wall_geoms = [
            geom
            for wall_body in self.wall_bodies
            for geom in wall_body.find_all('geom')
        ]
        self._wall_home_positions = [
            np.array(self.physics.bind(geom).pos)
            for geom in self._wall_geoms
        ]
        self._button_geoms = [
            geom
            for button_root in self.button_roots
            for geom in button_root.find_all('geom')
        ]
        self._button_home_colors = [
            np.array(self.physics.bind(geom).rgba)
            for geom in self._button_geoms
        ]

        # Prepare animation data structure
        self.animation_fps = animation_fps
        self.video = []

        # Keep track of the creature's torso, actuators and joints
        self.actuators = []
        self.torsos = []
        self.joints = []
        self.torsos.append(self.creature.find('geom', 'torso'))
        self.actuators.extend(self.creature.find_all('actuator'))
        self.joints.extend(self.creature.find_all('joint'))

        # Simulated duration of one call to `step`
        self.sub_step_s = sub_step_s

        self.reset()

    def _restore_scene(self):
        """Put walls and buttons back to the state they were built in."""
        for geom, pos in zip(self._wall_geoms, self._wall_home_positions):
            self.physics.bind(geom).pos = pos
        for geom, rgba in zip(self._button_geoms, self._button_home_colors):
            self.physics.bind(geom).rgba = rgba

    def reset(self):
        """Reset the simulation state to the beginning of the simulation.

        Walls and button colours are restored, the physics is reset and the
        list of activated buttons is cleared. With an empty button
        combination the walls are removed again straight away.
        """
        # Restore the model first so that the physics reset sees the
        # original scene.
        self._restore_scene()
        self.physics.reset()
        self.activated_buttons = []
        if len(self.button_combination) == 0:
            self.remove_walls()

    def sub_step(self, actions: np.ndarray):
        """Advance the physics by a single simulation step.

        `actions[0]` is clipped to [-1, 1] and written to the control input
        of the creature's actuators (two actuators per leg).

        NOTE(review): only `actions[0]` is used, so `actions` presumably
        carries a leading axis of size 1 in front of the `len(self.actuators)`
        control values -- confirm against the callers.
        """
        action = np.clip(actions[0], -1.0, 1.0)

        # Inject controls and step the physics.
        self.physics.bind(self.actuators).ctrl = action
        self.physics.step()

        # Record animation frame if appropriate
        # NOTE(review): the frame budget is derived from the absolute
        # simulation time while `self.video` is not cleared by `reset`; if
        # the time restarts after a reset, no frames are recorded until it
        # catches up with the frames already stored -- confirm this is
        # intended.
        if self.animation_fps is not None:
            target_video_n = self.physics.data.time*self.animation_fps
            if len(self.video) < target_video_n:
                pixels = self.physics.render()
                self.video.append(pixels.copy())

    def step(self, actions: np.ndarray) -> list[tuple[float, float]]:
        """Step the simulation until `self.sub_step_s` seconds have elapsed.

        The same `actions` are applied at every physics step (see
        `sub_step`).

        Returns:
            The torso horizontal positions sampled during this call, at
            `self.torso_positions_recording_fps` samples per simulated second.
        """
        start_t = self.physics.data.time
        torso_horizontal_positions = []
        while self.physics.data.time - start_t < self.sub_step_s:
            self.sub_step(actions)
            # The list is local to this call, so the number of samples due
            # must be measured from the start of the call, not from the
            # start of the simulation.
            elapsed_s = self.physics.data.time - start_t
            samples_due = elapsed_s*self.torso_positions_recording_fps
            if len(torso_horizontal_positions) < samples_due:
                torso_horizontal_positions.append(
                    self.torso_horizontal_position
                )
        return torso_horizontal_positions

    @property
    def simulation_time_s(self) -> float:
        """Return the current simulation time."""
        return self.physics.data.time

    @property
    def torso_horizontal_position(self) -> tuple[float, float]:
        """Return the (x, y) position of the torso geom of the ant."""
        xpos = self.physics.bind(self.torsos).xpos
        return (xpos[0, 0].copy(), xpos[0, 1].copy())

    @property
    def actuator_state(self) -> tuple[float, ...]:
        """Return the simulation state as a flat tuple.

        The tuple is the concatenation of `physics.data.qpos` and
        `physics.data.qvel`, i.e. positions followed by velocities of every
        degree of freedom in the model, not only of the actuated joints.
        """
        return (*self.physics.data.qpos, *self.physics.data.qvel)

    def remove_walls(self):
        """Remove the walls that surround the goal.

        The wall geoms are not deleted; they are moved to z=-5, keeping their
        x and y coordinates.
        """
        for wall_body in self.wall_bodies:
            for geom in wall_body.find_all('geom'):
                o = self.physics.bind(geom)
                o.pos = (o.pos[0], o.pos[1], -5)

    def activate_button(self, i: int):
        """Push button `i`.

        The button is recorded as activated (once), recoloured, and the walls
        are removed if the buttons activated so far match
        `button_combination` exactly.

        Raises:
            IndexError: if `i` is not a valid button index; in that case no
                state is modified.
        """
        # Look the button up first so that an invalid index cannot leave a
        # bogus entry in `activated_buttons`.
        button_root = self.button_roots[i]

        if i not in self.activated_buttons:
            self.activated_buttons.append(i)

        # Change button color
        for geom in button_root.find_all('geom'):
            self.physics.bind(geom).rgba = [0.5, 0.2, 0.2, 1]

        # Remove walls if reached correct button combination. Both sides are
        # converted so that a combination given as a list still matches.
        if tuple(self.activated_buttons) == tuple(self.button_combination):
            self.remove_walls()

    @property
    def is_success(self) -> bool:
        """Return whether the torso is within the goal's radius.

        The distance is measured horizontally between the torso position and
        the goal centre.
        """
        x, y = self.torso_horizontal_position
        target_x, target_y = self.goal.x, self.goal.y
        distance = ((x-target_x)**2 + (y-target_y)**2)**(1/2)
        # Convert explicitly: the comparison may yield a numpy bool.
        return bool(distance <= self.goal.r)


def get_simulation(
        animation_fps: Union[float, None],
        password: tuple[int, ...],
        num_buttons: int,
        sub_step_s: float,
        ) -> AntSimulation:
    """Return the simulation that defines the ant task.

    The ant has four legs and torso positions are recorded at 60/4 samples
    per simulated second.

    Args:
        animation_fps: frames rendered per simulated second, or `None` to
            disable rendering.
        password: ordered button indices that open the walls around the goal.
        num_buttons: currently ignored -- the value is overwritten with the
            largest index in `password` plus one (0 for an empty password).
        sub_step_s: simulated seconds advanced by one `step` call.
    """
    # The caller-supplied value is discarded: exactly as many buttons as the
    # password needs are created.
    num_buttons = 1 + max(password, default=-1)
    return AntSimulation(
        num_legs=4,
        num_buttons=num_buttons,
        button_combination=password,
        animation_fps=animation_fps,
        sub_step_s=sub_step_s,
        torso_positions_recording_fps=60/4,
    )
