import numpy as np
from rand_param_envs.half_cheetah_vel_rand_params import CheetahVelRandParamsEnv

from . import register_env

@register_env('cheetah-vel-sparse-rand-params')
class CheetahVelRandParamsWrappedEnv(CheetahVelRandParamsEnv):
    """Half-cheetah target-velocity task with randomized dynamics parameters.

    Task index ``i`` pairs a dynamics sample ``self.tasks[i]`` (from
    ``sample_tasks``, defined outside this class) with a target forward
    velocity ``self.goals[i]["velocity"]`` drawn uniformly from ``[0, 3)``.
    The forward reward is passed through ``sparsify_rewards`` (also defined
    outside this class) before the control term is added.
    """

    def __init__(self, n_tasks=2, randomize_tasks=True):
        """Sample ``n_tasks`` (dynamics, goal-velocity) pairs and activate task 0.

        Args:
            n_tasks: number of tasks to sample.
            randomize_tasks: accepted for interface compatibility; not used here.
        """
        super(CheetahVelRandParamsWrappedEnv, self).__init__()
        # Seeds NumPy's *global* RNG so the sampled goals are reproducible.
        # Note that this also affects any other code using np.random.
        np.random.seed(1)
        self.tasks = self.sample_tasks(n_tasks)
        self.goals = self.sample_goals(n_tasks)
        # Default must be set *before* reset_task(0). Assigning it afterwards
        # would overwrite task 0's target velocity with 0.
        self._goal_vel = 0.
        self.reset_task(0)

    def get_all_task_idx(self):
        """Return the indices of all sampled tasks as a ``range``."""
        return range(len(self.tasks))

    def reset_task(self, idx):
        """Activate task ``idx`` and reset the simulation.

        Sets ``self._task`` (dynamics sample), ``self._goal`` (its flattened
        parameter vector) and ``self._goal_vel`` (target velocity). It then
        applies the dynamics with ``set_task`` and calls ``reset()``.
        """
        self._task = self.tasks[idx]
        # `_goal` holds the flattened dynamics vector, not the velocity target.
        self._goal = self.get_train_dynamic(self._task)
        self._goal_vel = self.goals[idx]["velocity"]
        self.set_task(self._task)
        self.reset()

    def get_train_goals(self, n_train_tasks):
        """Return the target velocities of the first ``n_train_tasks`` tasks."""
        return [goal["velocity"] for goal in self.goals[:n_train_tasks]]

    # dynamics
    def get_train_dynamics(self, n_train_tasks):
        """Return flattened dynamics vectors of the first ``n_train_tasks`` tasks."""
        return [self.get_train_dynamic(task) for task in self.tasks[:n_train_tasks]]

    def reward(self, info, goal):
        """Compute ``(reward, done)`` for a transition described by ``info``.

        Args:
            info: dict of step information. If it contains the key
                ``'predicted_next_obs'``, the reward is 0. Otherwise it must
                provide ``'reward_ctrl'`` and ``'velocity'``.
            goal: target forward velocity.

        Returns:
            Tuple ``(reward, done)``, where ``done`` is always ``False`` and
            ``reward = sparsify_rewards(-|velocity - goal|) + reward_ctrl``.
        """
        if 'predicted_next_obs' in info:
            return 0, False
        reward_ctrl, velocity = info["reward_ctrl"], info["velocity"]
        forward_reward = -1.0 * abs(velocity - goal)
        sparse_reward = self.sparsify_rewards(forward_reward)
        return sparse_reward + reward_ctrl, False

    def sample_goals(self, num_tasks):
        """Sample ``num_tasks`` goals as ``{'velocity': v}`` dicts, ``v ~ U[0, 3)``.

        Draws from NumPy's global RNG.
        """
        velocities = np.random.uniform(0.0, 3.0, size=(num_tasks,))
        return [{'velocity': velocity} for velocity in velocities]

    def reset_model(self):
        """Reset to the initial state plus small noise and return the observation.

        Positions get ``U[-0.1, 0.1]`` noise and velocities get
        ``0.1 * N(0, 1)`` noise, both drawn from ``self.np_random``.
        """
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.1, high=0.1, size=self.model.nq)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def get_train_dynamic(self, task):
        """Flatten the randomized dynamics parameters of ``task`` into one vector.

        Each of ``body_mass``, ``body_inertia``, ``dof_damping`` and
        ``geom_friction`` that appears in ``self.rand_params`` is flattened
        from ``task[key]``. The pieces are concatenated in that fixed order.

        Returns:
            1-D NumPy array. It is empty if none of those keys is randomized.
        """
        dynamics_keys = ('body_mass', 'body_inertia', 'dof_damping',
                         'geom_friction')
        parts = [np.asarray(task[key]).ravel()
                 for key in dynamics_keys if key in self.rand_params]
        if not parts:
            # np.concatenate raises ValueError on an empty sequence.
            return np.zeros(0)
        return np.concatenate(parts)
