import copy
import numpy as np
from dm_control.rl import control
from dm_control.utils import containers
from dm_control.utils import rewards
from dm_control.suite.fish import Swim, Physics, _DEFAULT_TIME_LIMIT, _CONTROL_TIMESTEP
from dm_control.suite import common
from lxml import etree
from .import utils

SUITE = containers.TaggedTasks()

def get_model_and_assets(dynamics_kwargs=None):
    """Builds the fish model XML and pairs it with the shared suite assets.

    Args:
        dynamics_kwargs: Optional value forwarded unchanged to `_make_model`.

    Returns:
        A `(model_xml, assets)` tuple, where `model_xml` is whatever
        `_make_model` produced and `assets` is `common.ASSETS`.
    """
    model_xml = _make_model(dynamics_kwargs)
    return model_xml, common.ASSETS


@SUITE.add('benchmarking')
def swim(
    time_limit=_DEFAULT_TIME_LIMIT,
    random=None,
    environment_kwargs=None,
    reward_kwargs=None,
    dynamics_kwargs=None,
):
    """Creates the Fish Swim environment using the `SwimReward` task.

    Args:
        time_limit: Passed to `control.Environment` as `time_limit`.
        random: Passed to `SwimReward` as its `random` argument.
        environment_kwargs: Optional dict of extra keyword arguments for
            `control.Environment`; any falsy value is treated as no extras.
        reward_kwargs: Passed to `SwimReward` as its `reward_kwargs` argument.
        dynamics_kwargs: Passed to `get_model_and_assets` to build the model.

    Returns:
        The constructed `control.Environment`.
    """
    model_xml, assets = get_model_and_assets(dynamics_kwargs)
    physics = Physics.from_xml_string(model_xml, assets)
    task = SwimReward(random=random, reward_kwargs=reward_kwargs)

    if not environment_kwargs:
        environment_kwargs = {}

    return control.Environment(
        physics,
        task,
        control_timestep=_CONTROL_TIMESTEP,
        time_limit=time_limit,
        **environment_kwargs
    )


def _make_model(dynamics_kwargs=None):
    """Loads the fish model XML, optionally resizing the two fin geoms.

    Args:
        dynamics_kwargs: Either None, in which case the XML read from
            'fish.xml' is returned untouched, or a dict. If the dict has a
            'length' entry, the `size` and `pos` attributes of both fin
            geoms are rewritten as multiples of that value.

    Returns:
        The unmodified result of `common.read_model('fish.xml')` when
        `dynamics_kwargs` is None; otherwise the re-serialized
        (pretty-printed) XML tree.
    """
    model_xml = common.read_model('fish.xml')
    if dynamics_kwargs is None:
        return model_xml

    assert isinstance(dynamics_kwargs, dict)

    root = etree.fromstring(model_xml)

    # Bodies nested one level below the top-level body; indices 1 and 2 are
    # taken to be the right and left fins (assumes the layout of fish.xml --
    # TODO confirm against the model file).
    child_bodies = root.findall('./worldbody/body/body')
    right_fin_body, left_fin_body = child_bodies[1], child_bodies[2]
    right_fin_geom = right_fin_body.findall('./geom')[0]
    left_fin_geom = left_fin_body.findall('./geom')[0]

    if 'length' in dynamics_kwargs:
        length = dynamics_kwargs['length']
        offset = length * 15
        # Both fins share the same size; only the sign of the x offset differs.
        fin_size = f"{length * 20} {offset} {length}"
        right_fin_geom.set('size', fin_size)
        right_fin_geom.set('pos', f"{offset} 0 0")
        left_fin_geom.set('size', fin_size)
        left_fin_geom.set('pos', f"-{offset} 0 0")

    return etree.tostring(root, pretty_print=True)


class SwimReward(Swim):
    """Fish `Swim` task variant whose reward is smooth and configurable."""

    def __init__(self, random=None, reward_kwargs=None):
        """Initializes an instance of `SwimReward`.

        Args:
            random: Optional, either a `numpy.random.RandomState` instance, an
                integer seed for creating a new `RandomState`, or None to
                select a seed automatically (default).
            reward_kwargs: Optional reward settings. A deep copy is handed,
                together with the defaults below, to
                `utils.set_reward_parameters`; the caller's object is never
                passed on directly.
        """
        super().__init__(random=random)

        # Default tolerance settings for the 'swim' reward term.
        target_radius = 0.045
        defaults = {
            'swim': {
                'bounds': [0, target_radius],
                'margin': 2 * target_radius,
            },
        }

        # Work on a private copy so the caller's kwargs stay untouched.
        overrides = copy.deepcopy(reward_kwargs)
        self.reward_parameters = utils.set_reward_parameters(defaults, overrides)

    def get_reward(self, physics):
        """Returns a smooth reward.

        The reward is a weighted mix: 7/8 from a tolerance on the norm of
        `physics.mouth_to_target()` (configured by
        `self.reward_parameters['swim']`) and 1/8 from an upright term
        computed as `0.5 * (physics.upright() + 1)`.
        """
        distance = np.linalg.norm(physics.mouth_to_target())
        swim_params = self.reward_parameters['swim']
        in_target = rewards.tolerance(distance, **swim_params)

        is_upright = 0.5 * (physics.upright() + 1)

        weighted_sum = 7 * in_target + is_upright
        return weighted_sum / 8