# This section is required to find the sconepy module and DLLs
# Change SCONE_BIN_DIR if SCONE was installed to a different folder
SCONE_BIN_DIR = 'C:/Program Files/SCONE/bin'
import sys
sys.path.append( SCONE_BIN_DIR )

# Import the sconepy module (should be in the right path)
import sconepy
import math
import numpy as np
import gym
from abc import ABC, abstractmethod
from typing import Optional
import tonic


class DummySconePy(gym.Env, ABC):
    """Stand-in environment that exposes the gym interface without a model.

    Observations are fixed arrays, every step yields a reward of 0.1 and
    ``done`` is always ``False``.
    """

    def __init__(self, *args, **kwargs):
        # The keyword arguments must be unpacked here: passing the dict
        # positionally put it into ``*args`` of ``_setup``, so 'identifier'
        # was never found and every keyword was silently discarded.
        kwargs = self._setup(**kwargs)
        super().__init__(*args, **kwargs)
        self._setup_action_observation_spaces()
        self._max_episode_steps = 1000

    def _setup_action_observation_spaces(self):
        """Create a 5-dimensional [0, 1] action space and the observation space."""
        self.action_space = gym.spaces.Box(low=np.zeros(shape=(5,)),
                                           high=np.ones(shape=(5,)), dtype=np.float32)
        # The observation shape is taken from an actual observation.
        self.observation_space = gym.spaces.Box(low=-10000, high=10000, shape=self._get_obs().shape, dtype=np.float32)

    def _setup(self, *args, **kwargs):
        """Store an optional 'identifier' keyword and return the remaining kwargs."""
        if 'identifier' in kwargs:
            self.identifier = kwargs.pop('identifier')
        return kwargs

    def step(self, action):
        """Ignore *action* and return ``(obs, 0.1, False, {})``."""
        return self._get_obs(), 0.1, False, {}

    def _get_obs(self):
        """Return the constant observation, typed like ``observation_space``."""
        return np.array([0, 1, 2], dtype=np.float32)

    def reset(
            self,
            *,
            seed: Optional[int] = None,
            return_info: bool = False,
            options: Optional[dict] = None,
    ):
        """Return the constant initial observation.

        ``seed`` and ``options`` are accepted for API compatibility and are
        not used. With ``return_info=True`` the result is ``(obs, info)``
        where ``info`` is an empty dict.
        """
        obs = np.array([3, 4, 5], dtype=np.float32)
        if return_info:
            # The info element has to be a dict, not a nested (obs, {}) tuple.
            return obs, {}
        return obs


class SconePyGym(gym.Env, ABC):
    """Abstract gym environment that drives a model loaded through sconepy.

    A subclass has to assign ``self.model_file`` before calling this
    constructor and must implement ``_get_obs``, ``_get_rew`` and
    ``_get_done``.
    """

    def __init__(self, *args, **kwargs):
        # The keyword arguments must be unpacked here: passing the dict
        # positionally put it into ``*args`` of ``_setup``, so 'identifier'
        # was never found (and the RNG never seeded) and every keyword was
        # silently discarded.
        kwargs = self._setup(**kwargs)
        super().__init__(*args, **kwargs)
        if not hasattr(self, 'model_file'):
            raise Exception('The child class has to define a model_file')
        self.model = sconepy.load_model(self.model_file)
        sconepy.set_log_level(7)
        # Reference pose that reset() perturbs.
        self.init_pose = self.model.dof_values().copy()
        # control time step; physics time step is variable
        self.step_size = 0.01
        self.has_reset = False
        self.has_called_render = False
        self._setup_action_observation_spaces()
        self.manual_load_model = False

    def step(self, action):
        """Apply *action* for one control step.

        The action is clipped to [0, 1], passed to the model as actuator
        input and the simulation is advanced by ``step_size``.

        Returns:
            ``(obs, reward, done, info)`` with an empty ``info`` dict.

        Raises:
            Exception: if ``reset()`` has not been called yet.
        """
        if not self.has_reset:
            raise Exception('You have to call reset() once before step()')
        action = np.clip(action, 0.0, 1.0)
        self.model.set_actuator_inputs(action)
        self.model.advance_simulation_to(self.time + self.step_size)
        reward = self._get_rew()
        obs = self._get_obs()
        done = self._get_done()
        self.time += self.step_size
        return obs, reward, done, {}

    def reset(
            self,
            *,
            seed: Optional[int] = None,
            return_info: bool = False,
            options: Optional[dict] = None,
    ):
        """Start a new episode from a randomly perturbed initial state.

        Unless manual model loading was requested, the model is reloaded
        from ``model_file``. ``seed`` and ``options`` are accepted for API
        compatibility and are not used.

        Returns:
            The first observation, or ``(obs, info)`` with an empty ``info``
            dict when ``return_info`` is true.
        """
        self.has_reset = True
        if not self.manual_load_model:
            self.model = sconepy.load_model(self.model_file)
        # A single scalar sample is drawn and added to every entry of the pose.
        pose = self.init_pose + np.random.normal(0, 0.25)
        self.model.set_dof_values(pose)
        n_actuators = len(self.model.actuators())
        muscle_activations = 0.1 + 0.4 * np.random.normal(0, 0.1, size=n_actuators)
        self.model.init_muscle_activations(muscle_activations)
        self.time = 0
        obs = self._get_obs()
        if return_info:
            # The info element has to be a dict, not a nested (obs, {}) tuple.
            return obs, {}
        return obs

    def set_manually_load_model(self):
        """Stop reset() from reloading the model; the caller loads it instead."""
        self.manual_load_model = True

    def manually_load_model(self):
        """Reload the model from ``model_file`` with data storing switched on."""
        self.model = sconepy.load_model(self.model_file)
        self.model.set_store_data(True)

    def render(self, *args, **kwargs):
        """Do nothing; rendering is not implemented by this environment."""
        return

    def _setup_action_observation_spaces(self):
        """Create one [0, 1] action per model actuator and the observation space."""
        n_actuators = len(self.model.actuators())
        self.action_space = gym.spaces.Box(low=np.zeros(shape=(n_actuators,)),
                                           high=np.ones(shape=(n_actuators,)), dtype=np.float32)
        # The observation shape is taken from an actual observation.
        self.observation_space = gym.spaces.Box(low=-10000, high=10000, shape=self._get_obs().shape, dtype=np.float32)

    def _setup(self, *args, **kwargs):
        """Consume the optional 'identifier' keyword and return the other kwargs.

        When an identifier is given it is stored on the instance and used to
        seed numpy's global random generator.
        """
        if 'identifier' in kwargs:
            self.identifier = kwargs.pop('identifier')
            np.random.seed(self.identifier)
        return kwargs

    @abstractmethod
    def _get_obs(self):
        """Return the current observation."""
        pass

    @abstractmethod
    def _get_rew(self):
        """Return the reward for the current state."""
        pass

    @abstractmethod
    def _get_done(self):
        """Return whether the episode has terminated."""
        pass


# Maps a short model name to its .scone file path.
# The obstacle-course entries only differ in the step height, so they are
# generated from the list of heights (in cm) and appended in that order.
model_files = {
    '3d': 'H1622/H1622_hyfydy.scone',
    '3dx2': 'H1622x2/H1622x2_hyfydy.scone',
    '2d': 'H0918_hyfydy.scone',
    '2d_jump': 'H0918_hyfydy_jump.scone',
    '2d_jump_terrain': 'H0918_hyfydy_jump_terrain.scone',
    **{
        f'2d_{height}cm': f'obstacle_files/H0918_hyfydy_terrain_{height}cm.scone'
        for height in (5, 10, 15, 20, 30)
    },
}


class Walking2D(SconePyGym):
    """Walking task using the model registered as ``model_files['2d']``.

    The reward is the x component of the model's centre-of-mass velocity.
    An episode terminates once the centre-of-mass y coordinate falls below
    0.5.
    """

    def __init__(self, *args, **kwargs):
        # Both attributes are assigned before the base constructor runs,
        # because that constructor loads the model from self.model_file.
        self._max_episode_steps = 1000
        self.model_file = model_files['2d']
        super().__init__(*args, **kwargs)

    def set_obstacle_course(self, name: str) -> None:
        """Select the model file registered under *name* and reset."""
        self.model_file = model_files[name]
        self.reset()

    @staticmethod
    def _xyz_array(vec):
        """Pack the x, y and z attributes of *vec* into an array."""
        return np.array([vec.x, vec.y, vec.z])

    def _concat_over_bodies(self, extract):
        """Apply *extract* to every model body and join the results."""
        return np.concatenate([extract(body) for body in self.model.bodies()])

    def _get_obs(self):
        """Assemble the flat float32 observation vector.

        Layout: muscle fibre lengths, fibre velocities, forces and
        activations, then per-body positions, orientations, angular and
        linear velocities, the masked joint values, the centre-of-mass
        velocity, the torso angle and the centre-of-mass y coordinate.
        """
        model = self.model
        parts = [
            model.muscle_fiber_lengths(),
            model.muscle_fiber_velocities(),
            model.muscle_forces(),
            model.muscle_activations(),
            self.get_body_pos(),
            self.get_body_orientation(),
            self.get_body_ang_vel(),
            self.get_body_vel(),
            self.joints(),
            self.get_com_vel(),
            np.array([self.get_torso_y_angle()]),
            np.array([model.com_pos().y]),
        ]
        return np.concatenate(parts, dtype=np.float32).copy()

    def get_com_vel(self):
        """Return the model's centre-of-mass velocity as ``[x, y, z]``."""
        return self._xyz_array(self.model.com_vel())

    def joints(self):
        """Return the DOF values with one entry set to zero.

        Entry 3 is zeroed when the model file name contains 'H16' (3d),
        entry 1 otherwise (2d).
        """
        dof_values = self.model.dof_values()
        masked_index = 3 if 'H16' in self.model_file else 1
        # NOTE(review): the entry is zeroed on the array returned by
        # dof_values() before it is copied; confirm that array is not a
        # view into the model's state.
        dof_values[masked_index] = 0
        return dof_values.copy()

    def get_body_ang_vel(self):
        """Return the angular velocities of all bodies, concatenated."""
        return self._concat_over_bodies(self._ang_vel)

    def get_body_pos(self):
        """Return the reported position components of all bodies, concatenated."""
        return self._concat_over_bodies(self._pos)

    def get_body_orientation(self):
        """Return the orientation quaternions of all bodies, concatenated."""
        return self._concat_over_bodies(self._orientation)

    def get_body_vel(self):
        """Return the centre-of-mass velocities of all bodies, concatenated."""
        return self._concat_over_bodies(self._vel)

    def _ang_vel(self, body):
        """Return the angular velocity of *body* as ``[x, y, z]``."""
        return self._xyz_array(body.ang_vel())

    def get_torso_y_angle(self):
        """Return the third Euler angle (rotation about z) of the last body.

        NOTE(review): presumably the last entry of ``model.bodies()`` is the
        torso - confirm against the model definition.
        """
        quat = self.model.bodies()[-1].orientation()
        _, _, yaw_z = euler_from_quaternion(quat.x, quat.y, quat.z, quat.w)
        return yaw_z

    def _pos(self, body):
        """Return the centre-of-mass position components used as observation.

        Models whose file name contains 'H16' (3d) contribute y and z, all
        other models (2d) only y.
        """
        com = body.com_pos()
        if 'H16' in self.model_file:
            return np.array([com.y, com.z])
        return np.array([com.y])

    def _vel(self, body):
        """Return the centre-of-mass velocity of *body* as ``[x, y, z]``."""
        return self._xyz_array(body.com_vel())

    def _orientation(self, body):
        """Return the orientation of *body* as ``[x, y, z, w]``."""
        quat = body.orientation()
        return np.array([quat.x, quat.y, quat.z, quat.w])

    def _get_rew(self):
        """Reward is the x component of the centre-of-mass velocity."""
        return self.model.com_vel().x

    def _get_done(self) -> bool:
        """Terminate when the centre-of-mass y coordinate is below 0.5."""
        return bool(self.model.com_pos().y < 0.5)

    @property
    def horizon(self):
        """Fixed episode horizon of 1000."""
        return 1000


class Jumping2D(Walking2D):
    """Jumping task: reward 1.0 while the centre-of-mass y exceeds 1.08."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): the parent constructor has already loaded a model
        # from Walking2D's model file at this point; the jump model is only
        # loaded by the next reset() - confirm this ordering is intended.
        self.model_file = model_files['2d_jump']

    def _get_rew(self) -> float:
        """Return 1.0 if the centre-of-mass y coordinate is above 1.08, else 0.0."""
        return 1.0 if self.model.com_pos().y > 1.08 else 0.0

    def reset(
            self,
            *,
            seed: Optional[int] = None,
            return_info: bool = False,
            options: Optional[dict] = None,
    ):
        """Start a new episode from a slightly perturbed pose.

        Unless manual model loading was requested, the model is reloaded
        from ``model_file``. The third pose entry is overwritten with 1.3.
        ``seed`` and ``options`` are accepted for API compatibility and are
        not used.

        Returns:
            The first observation, or ``(obs, info)`` with an empty ``info``
            dict when ``return_info`` is true.
        """
        self.has_reset = True
        if not self.manual_load_model:
            self.model = sconepy.load_model(self.model_file)
        # A single scalar sample is drawn and added to every entry of the pose.
        pose = self.init_pose + np.random.normal(0, 0.01)
        pose[2] = 1.3
        self.model.set_dof_values(pose)
        n_actuators = len(self.model.actuators())
        muscle_activations = 0.1 + 0.4 * np.random.normal(0, 0.1, size=n_actuators)
        self.model.init_muscle_activations(muscle_activations)
        self.time = 0
        obs = self._get_obs()
        if return_info:
            # The info element has to be a dict, not a nested (obs, {}) tuple.
            return obs, {}
        return obs

    #def step(self, action):
    #    if np.random.uniform() < 0.2:
    #        action += np.random.uniform(0, 2)
    #    return super().step(action)


def euler_from_quaternion(x, y, z, w):
    """
    Convert the quaternion (x, y, z, w) to Euler angles in radians.

    Returns the tuple (roll_x, pitch_y, yaw_z): the counterclockwise
    rotations about the x, y and z axis respectively.
    """
    # Roll: rotation about x.
    sin_roll = 2.0 * (w * x + y * z)
    cos_roll = 1.0 - 2.0 * (x * x + y * y)
    roll_x = math.atan2(sin_roll, cos_roll)

    # Pitch: rotation about y. The sine is clamped to [-1, 1] so that
    # asin() stays defined when the value slightly overshoots.
    sin_pitch = 2.0 * (w * y - z * x)
    if sin_pitch > 1.0:
        sin_pitch = 1.0
    elif sin_pitch < -1.0:
        sin_pitch = -1.0
    pitch_y = math.asin(sin_pitch)

    # Yaw: rotation about z.
    sin_yaw = 2.0 * (w * z + x * y)
    cos_yaw = 1.0 - 2.0 * (y * y + z * z)
    yaw_z = math.atan2(sin_yaw, cos_yaw)

    return roll_x, pitch_y, yaw_z
