import numpy as np
from itertools import combinations
from . import register_env
from .half_cheetah import HalfCheetahEnv


@register_env('cheetah-negated-joints')
class HalfCheetahModControlEnv(HalfCheetahEnv):
    """Half-cheetah environment with target velocity, as described in [1]. The
    code is adapted from
    https://github.com/cbfinn/maml_rl/blob/9c8e2ebd741cb0c7b8bf2d040c4caeeb8e06cc95/rllab/envs/mujoco/half_cheetah_env_rand.py
    The half-cheetah follows the dynamics from MuJoCo [2], and receives at each
    time step a reward composed of a control cost and a penalty equal to the
    difference between its current velocity and the target velocity. The tasks
    are generated by sampling the target velocities from the uniform
    distribution on [0, 2].
    [1] Chelsea Finn, Pieter Abbeel, Sergey Levine, "Model-Agnostic
        Meta-Learning for Fast Adaptation of Deep Networks", 2017
        (https://arxiv.org/abs/1703.03400)
    [2] Emanuel Todorov, Tom Erez, Yuval Tassa, "MuJoCo: A physics engine for
        model-based control", 2012
        (https://homes.cs.washington.edu/~todorov/papers/TodorovIROS12.pdf)
    """

    def __init__(self, type='mask', n_tasks=20):
        self.task_type = type

        if type == 'swp':
            # C(6, 2) = 15 tasks, each swapping one pair of joint controls
            self.tasks = gen_swp_tasks()
            self._joint_permutation = self.tasks[0].get('joint_permutation')
        elif type == 'mask':
            # 10 train tasks, 10 test tasks; the 6th joint is negated only in
            # the test tasks
            self.tasks = gen_neg_tasks()
            self.mask = self.tasks[0].get('mask')
        else:
            raise ValueError('Unknown task type: {}'.format(type))

        assert n_tasks == len(self.tasks), \
            'n_tasks must match the number of generated tasks ' \
            '(20 for "mask", 15 for "swp")'
        super(HalfCheetahModControlEnv, self).__init__()

    def step(self, action):
        # Apply the task's control modification before simulating
        if self.task_type == 'swp':
            action = action[self._joint_permutation]
        elif self.task_type == 'mask':
            action = self.mask * action

        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        xposafter = self.sim.data.qpos[0]
        ob = self._get_obs()
        reward_ctrl = -0.1 * np.square(action).sum()
        reward_run = (xposafter - xposbefore) / self.dt
        reward = reward_ctrl + reward_run
        done = False
        return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl)

    def get_all_task_idx(self):
        return range(len(self.tasks))

    def reset_task(self, idx):
        self._task = self.tasks[idx]
        if self.task_type == 'swp':
            self._joint_permutation = self._task['joint_permutation']
        elif self.task_type == 'mask':
            self.mask = self._task['mask']
        return self.reset()
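
# Minimal usage sketch (values illustrative, not from the original code;
# assumes the package's relative imports resolve so the env can be built):
# >>> env = HalfCheetahModControlEnv(type='swp', n_tasks=15)
# >>> _ = env.reset_task(0)
# >>> ob, reward, done, info = env.step(env.action_space.sample())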

def gen_swp_tasks():
    """Generate one task per pair of joints: C(6, 2) = 15 permutations, each
    swapping the controls of a single pair of joints."""
    all_tasks = []
    swp_idxs = list(combinations(np.arange(6), 2))
    orig_lst = np.arange(6)
    for a, b in swp_idxs:
        task_lst = orig_lst.copy()
        task_lst[a], task_lst[b] = task_lst[b], task_lst[a]
        all_tasks.append({'joint_permutation': task_lst})
    return all_tasks
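
# Doctest-style sketch (illustrative values, not from the original code) of
# how a 'swp' task reorders an action vector:
# >>> perm = gen_swp_tasks()[0]['joint_permutation']  # swaps joints 0 and 1
# >>> np.array([0.5, -0.5, 0.1, 0.2, 0.3, 0.4])[perm]
# array([-0.5,  0.5,  0.1,  0.2,  0.3,  0.4])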


def gen_neg_tasks():
    """Generate 10 train tasks followed by 10 test tasks. Train tasks negate
    3 of the first 5 joints (C(5, 3) = 10); test tasks negate the 6th joint
    plus 2 of the first 5 (C(5, 2) = 10), so the 6th joint is negated only at
    test time."""
    all_tasks = []
    all_train_neg_idxs = list(combinations(np.arange(5), 3))
    for neg_idxs in all_train_neg_idxs:
        mask = np.ones(6)
        for idx in neg_idxs:
            mask[idx] = -1
        all_tasks.append({'mask': mask})

    all_test_neg_idxs = list(combinations(np.arange(5), 2))
    for neg_idxs in all_test_neg_idxs:
        mask = np.ones(6)
        mask[-1] = -1
        for idx in neg_idxs:
            mask[idx] = -1
        all_tasks.append({'mask': mask})

    return all_tasks
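
# Doctest-style sketch (illustrative values) of how a 'mask' task flips the
# sign of the negated joints' controls:
# >>> mask = gen_neg_tasks()[0]['mask']  # first train task negates joints 0-2
# >>> mask * np.ones(6)
# array([-1., -1., -1.,  1.,  1.,  1.])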

if __name__ == '__main__':

    env = HalfCheetahModControlEnv()

    # Smoke test: roll out a zero action in every generated task
    for idx in env.get_all_task_idx():
        env.reset()
        env.reset_task(idx)
        print('task', idx)

        for step in range(100):
            env.step(np.zeros(env.action_space.shape))
            # env.render()
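
    # Sanity checks (a sketch based on the generators above): 'mask' yields
    # C(5, 3) + C(5, 2) = 20 tasks and 'swp' yields C(6, 2) = 15.
    assert len(gen_neg_tasks()) == 20
    assert len(gen_swp_tasks()) == 15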
