from typing import Union
import numpy as np

from abc import ABC, abstractmethod
import gym


def apply_wrapper(env):
    """Wrap ``env`` in the DEP wrapper that matches its simulator backend.

    Args:
        env (Env): The environment to wrap.

    Returns:
        A ``SconeWrapper`` when the type of ``env.unwrapped`` mentions
        ``sconegym``, otherwise a ``GymWrapper``.
    """
    # str(type(obj)) renders as "<class 'module.path.ClassName'>", so this
    # matches on the defining module path as well as the class name.
    if 'sconegym' in str(type(env.unwrapped)):
        return SconeWrapper(env)
    return GymWrapper(env)


class AbstractDEPWrapper(gym.Wrapper, ABC):
    """Base gym wrapper that exposes the muscle state consumed by DEP.

    Concrete subclasses implement ``muscle_lengths``, ``muscle_forces`` and
    ``muscles_dep`` for a specific simulator backend.
    """

    def __init__(
        self,
        env: gym.Env,
        # Other force_scale values used previously, kept for reference:
        #   2d:        3.275102e-5
        #   3d:        3.48662e-5
        #   2d new?:   2.255555e-5
        # Current default (jumping): 0.00054749
        force_scale: Union[float, int] = 0.00054749
    ):
        """
        Adds the state information for DEP as a property.

        Args:
            env (Env): The environment to apply the wrapper to.
            force_scale (float, int): The scaling factor for the muscle force
                input that is summed to the muscle lengths. Must be >= 0.

        Raises:
            ValueError: If ``force_scale`` is negative or NaN.
        """
        super().__init__(env)
        # Explicit check rather than ``assert`` so validation still runs under
        # ``python -O``. Written as ``not x >= 0`` so NaN is rejected too,
        # exactly like the previous assert condition.
        if not force_scale >= 0:
            raise ValueError(f"expected non-negative value, got {force_scale}")
        self.set_force_scale(float(force_scale))

    def set_force_scale(self, force_scale):
        """Store ``force_scale`` as the weight applied to muscle forces."""
        self.force_scale = force_scale

    def merge_args(self, *args, **kwargs):
        """No-op hook; accepts and ignores any arguments."""
        pass

    def apply_args(self, *args, **kwargs):
        """No-op hook; accepts and ignores any arguments."""
        pass

    def render(self, *args, **kwargs):
        """No-op by default; subclasses override to render their simulator."""
        pass

    @property
    @abstractmethod
    def muscle_lengths(self):
        """Current muscle lengths as provided by the backend."""
        pass

    @property
    @abstractmethod
    def muscle_forces(self):
        """Current muscle forces as provided by the backend."""
        pass

    @property
    @abstractmethod
    def muscles_dep(self):
        """Muscle state fed to DEP (lengths combined with scaled forces)."""
        pass


class GymWrapper(AbstractDEPWrapper):
    """DEP wrapper for environments whose unwrapped env exposes a ``sim``.

    Muscle quantities are read from ``sim.data`` and rendering is forwarded
    to ``sim.render``.
    """

    def render(self, *args, **kwargs):
        """Forward rendering to the simulator, forcing ``mode="window"``."""
        forwarded_kwargs = dict(kwargs, mode="window")
        simulator = self.unwrapped.sim
        simulator.render(*args, **forwarded_kwargs)

    @property
    def muscle_lengths(self):
        """Actuator lengths taken from ``sim.data.actuator_length``."""
        sim_data = self.unwrapped.sim.data
        return sim_data.actuator_length

    @property
    def muscle_forces(self):
        """Actuator forces taken from ``sim.data.actuator_force``."""
        sim_data = self.unwrapped.sim.data
        return sim_data.actuator_force

    @property
    def muscles_dep(self):
        """Muscle lengths plus the ``force_scale``-weighted muscle forces."""
        lengths = self.muscle_lengths
        scale = self.force_scale
        forces = self.muscle_forces
        return lengths + scale * forces


class SconeWrapper(AbstractDEPWrapper):
    """DEP wrapper for sconegym environments.

    Muscle quantities are read through ``self.model``.
    """

    def render(self, *args, **kwargs):
        """Rendering is a no-op for this wrapper."""
        pass

    @property
    def muscle_lengths(self):
        """Muscle fiber lengths from ``model.muscle_fiber_lengths()``."""
        # Alternative accessor, previously left as unreachable code:
        # self.model.muscle_fiber_length_array()
        return self.model.muscle_fiber_lengths()

    @property
    def muscle_forces(self):
        """Muscle forces from ``model.muscle_forces()``."""
        # Alternative accessor, previously left as unreachable code:
        # self.model.muscle_force_array()
        return self.model.muscle_forces()

    @property
    def muscles_dep(self):
        """Return ``lengths + force_scale * forces`` as a fresh float32 array.

        The inputs are coerced with ``np.asarray`` so the arithmetic also works
        when the model accessors return plain Python sequences.
        """
        lengths = np.asarray(self.muscle_lengths)
        forces = np.asarray(self.muscle_forces)
        muscle_state = lengths + self.force_scale * forces
        # astype copies by default, so the result never aliases model buffers.
        # NOTE: an earlier variant prepended the positions of the last three
        # DOFs (self.model.dofs()[-3:], via .pos()) to this state.
        return muscle_state.astype(np.float32)
