from dm_control.suite.point_mass import SUITE as point_mass_suite
from dm_control.suite.point_mass import PointMass, get_model_and_assets, Physics
from dm_control.rl import control
import numpy as np


def make(task,
         task_kwargs=None,
         environment_kwargs=None,
         visualize_reward=False):
    """Builds the environment registered under `task` in `point_mass_suite`.

    Args:
      task: Key used to look up the task constructor in `point_mass_suite`.
      task_kwargs: Optional dict of keyword arguments passed to the task
        constructor. The caller's dict is never mutated.
      environment_kwargs: Optional value forwarded to the task constructor
        under the `environment_kwargs` keyword. Omitted entirely when `None`.
      visualize_reward: Value assigned to `env.task.visualize_reward` on the
        constructed environment.

    Returns:
      The environment returned by the task constructor.
    """
    constructor_kwargs = task_kwargs or {}
    if environment_kwargs is not None:
        # Work on a copy so the caller's dict is left untouched.
        constructor_kwargs = constructor_kwargs.copy()
        constructor_kwargs.update(environment_kwargs=environment_kwargs)
    build_task = point_mass_suite[task]
    environment = build_task(**constructor_kwargs)
    environment.task.visualize_reward = visualize_reward
    return environment


class RandomPointMass(PointMass):
    """Point-mass task whose target is placed at a randomly sampled position.

    The target position is sampled once, during the first call to
    `initialize_episode`, and the same position is re-applied on every later
    episode of this task instance.
    """

    def __init__(self, randomize_gains, random=None):
        """Initializes the task.

        Args:
          randomize_gains: Forwarded unchanged to `PointMass.__init__`.
          random: Forwarded unchanged to `PointMass.__init__`.
        """
        super().__init__(randomize_gains, random)
        # Sampled lazily on the first episode; `None` means "not sampled yet".
        self.target_geom_pos = None

    def initialize_episode(self, physics: Physics):
        """Runs the parent initialization, then moves the target geom.

        Args:
          physics: Physics instance whose `target` geom position is
            overwritten with the (cached) sampled position.
        """
        super().initialize_episode(physics)
        if self.target_geom_pos is None:
            # First two coordinates are drawn uniformly from [-0.1, 0.1); the
            # last coordinate of the model's existing target position is kept.
            sampled_xy = self.random.uniform(low=-0.1, high=0.1, size=2)
            last_coord = physics.named.model.geom_pos["target"][-1]
            self.target_geom_pos = np.concatenate([sampled_xy, [last_coord]])
        physics.named.model.geom_pos["target"] = self.target_geom_pos


# Default `time_limit` that the `random` task factory below passes to
# `control.Environment` (presumably in seconds — confirm against dm_control).
_DEFAULT_TIME_LIMIT = 20


@point_mass_suite.add()
def random(time_limit=_DEFAULT_TIME_LIMIT, random=None, environment_kwargs=None):
    """Returns the point_mass task with a randomly placed target.

    Builds a `RandomPointMass` task (gain randomization disabled) on a fresh
    `Physics` instance and wraps both in a `control.Environment`.

    Args:
      time_limit: Passed to `control.Environment` as `time_limit`.
      random: Passed to `RandomPointMass` as its `random` argument.
      environment_kwargs: Optional dict of extra keyword arguments for
        `control.Environment`; treated as empty when falsy.

    Returns:
      The constructed `control.Environment`.
    """
    physics = Physics.from_xml_string(*get_model_and_assets())
    task = RandomPointMass(randomize_gains=False, random=random)
    environment_kwargs = environment_kwargs or {}
    return control.Environment(
        physics, task, time_limit=time_limit, **environment_kwargs)
