import numpy as np
from rand_param_envs.ant_goal_rand_params import AntGoalRandParamsEnv

from . import register_env


@register_env('ant-goal-rand-params')
class AntGoalRandParamsWrappedEnv(AntGoalRandParamsEnv):
    """Ant environment where each task pairs a dynamics sample with a 2-D goal.

    A task index selects both an entry of ``self.tasks`` (handed to
    ``set_task``) and an entry of ``self.goals`` (an ``[x, y]`` target that
    callers pass to :meth:`reward`).
    """

    def __init__(self, n_tasks=2, randomize_tasks=True):
        """Sample ``n_tasks`` tasks and goals, then activate task 0.

        Args:
            n_tasks: Number of entries requested from ``sample_tasks`` and
                ``sample_goals``.
            randomize_tasks: Accepted for interface compatibility; this class
                does not read it.
        """
        super(AntGoalRandParamsWrappedEnv, self).__init__()
        # Fixed seed so the goal shuffle below (and presumably the inherited
        # ``sample_tasks``) is reproducible. NOTE: this reseeds NumPy's
        # *global* RNG, which affects any other code sharing it.
        np.random.seed(1)
        self.tasks = self.sample_tasks(n_tasks)
        self.goals = self.sample_goals(n_tasks)
        self.reset_task(0)

    def get_all_task_idx(self):
        """Return the valid task indices as ``range(len(self.tasks))``."""
        return range(len(self.tasks))

    def reset_task(self, idx):
        """Switch to task ``idx`` and reset the environment.

        Note the naming: ``self._goal`` holds the flattened *dynamics* vector
        of the task (see :meth:`get_train_dynamic`), while ``self._goal_pos``
        holds the ``[x, y]`` target position.
        """
        self._task = self.tasks[idx]
        self._goal = self.get_train_dynamic(self._task)
        self._goal_pos = self.goals[idx]
        self.set_task(self._task)
        self.reset()

    def get_train_goals(self, n_train_tasks):
        """Return a new list with the first ``n_train_tasks`` goal positions."""
        return list(self.goals[:n_train_tasks])

    # dynamics
    def get_train_dynamics(self, n_train_tasks):
        """Return the dynamics vectors of the first ``n_train_tasks`` tasks."""
        return [self.get_train_dynamic(task)
                for task in self.tasks[:n_train_tasks]]

    def reward(self, info, goal):
        """Compute ``(reward, done)`` from a step's ``info`` dict.

        Two modes, selected by the contents of ``info``:

        * If ``info`` contains ``'predicted_next_obs'``, only termination is
          evaluated: the reward is ``0`` and ``done`` is true unless the
          observation is entirely finite and its element at index 2 lies in
          ``[0.2, 1.0]`` (presumably the torso height -- confirm against the
          observation layout).
        * Otherwise the sparse reward is returned. ``info`` must provide
          ``reward_ctrl``, ``reward_contact``, ``reward_survive``,
          ``xposafter`` and ``done``.

        Args:
            info: Step information dict as described above.
            goal: ``[x, y]`` target position.

        Returns:
            Tuple ``(reward, done)``.
        """
        if 'predicted_next_obs' in info:
            next_obs = info['predicted_next_obs']
            alive = (np.isfinite(next_obs).all()
                     and 0.2 <= next_obs[2] <= 1.0)
            return 0, not alive

        reward_ctrl, reward_contact = info["reward_ctrl"], info["reward_contact"]
        reward_survive, xposafter = info["reward_survive"], info["xposafter"]
        done = info["done"]

        xy_error = xposafter[:2] - goal
        # Dense goal term: negative L1 distance to the goal plus a constant.
        goal_reward = -np.sum(np.abs(xy_error)) + 4.0
        if np.linalg.norm(xy_error) > 0.8:
            # Outside the 0.8 (Euclidean) radius the goal term is frozen at
            # the value the dense term takes at the origin, so it carries no
            # position-dependent signal.
            sparse_goal_reward = -np.sum(np.abs(goal)) + 4.0
        else:
            sparse_goal_reward = goal_reward

        sparse_reward = sparse_goal_reward + reward_ctrl + reward_contact + \
            reward_survive
        return sparse_reward, done

    def sample_goals(self, num_tasks):
        """Return ``num_tasks`` shuffled goals on a half circle of radius 2.

        Angles are evenly spaced over ``[0, pi]`` (both ends included); with
        ``num_tasks == 1`` the single goal is ``[2.0, 0.0]``. The order is
        shuffled in place with NumPy's global RNG.

        Returns:
            List of ``[x, y]`` lists of floats.
        """
        radius = 2.0
        angles = np.linspace(0, np.pi, num=num_tasks)
        xs = radius * np.cos(angles)
        ys = radius * np.sin(angles)
        goals = np.stack([xs, ys], axis=1)
        np.random.shuffle(goals)
        return goals.tolist()

    def get_train_dynamic(self, task):
        """Flatten the randomized parameters of ``task`` into one 1-D array.

        Only keys listed in ``self.rand_params`` are included, always in the
        order body_mass, body_inertia, dof_damping, geom_friction.

        Args:
            task: Mapping from parameter name to an array exposing
                ``flatten()``.

        Returns:
            1-D array with the selected parameters concatenated; an empty
            array when none of the four keys is in ``self.rand_params``.
        """
        dynamics_keys = ('body_mass', 'body_inertia', 'dof_damping',
                         'geom_friction')
        parts = [task[key].flatten()
                 for key in dynamics_keys if key in self.rand_params]
        if not parts:
            # np.concatenate raises ValueError on an empty sequence.
            return np.zeros(0)
        return np.concatenate(parts)
