import numpy as np
import pickle
from collections import OrderedDict

from typing import Optional, Tuple, List
# from . import register_env
from .half_cheetah import HalfCheetahEnv
# from .half_cheetah import HalfCheetahVelEnv_


# @register_env('cheetah-vel')
# @register_env('HalfCheetahVel-v0')


class HalfCheetahVelEnv_(HalfCheetahEnv):
    """HalfCheetah whose reward tracks a per-task target forward velocity.

    Each task is a dict whose ``'velocity'`` key holds the target forward
    velocity. The per-step reward is the negative absolute deviation of the
    measured forward velocity from that target, minus a quadratic control cost.
    """

    def __init__(self, tasks=None, num_tasks=20, num_train_tasks=15, randomize_tasks=True):
        """Store the task list and make the first task the active one.

        Args:
            tasks: list of task dicts. ``None`` (the default) means a single
                empty task, whose target velocity falls back to ``0.0``.
            num_tasks: stored as ``self.num_tasks``.
            num_train_tasks: stored as ``self.num_train_tasks``.
            randomize_tasks: accepted for interface compatibility; unused here.
        """
        # Build a fresh list per instance; a literal ``[{}]`` default would be
        # one list object shared by every instance created without ``tasks``.
        if tasks is None:
            tasks = [{}]
        self.tasks = tasks
        self.num_tasks = num_tasks
        self.num_train_tasks = num_train_tasks
        self._task = self.tasks[0]
        self._goal_vel = self._task.get('velocity', 0.0)
        self._goal = self._goal_vel
        # Goal attributes are set before the base constructor runs, presumably
        # because it may call step()/_get_obs(), which read them.
        super(HalfCheetahVelEnv_, self).__init__()

    def step(self, action):
        """Advance the simulation by one control step.

        Returns:
            ``(observation, reward, done, infos)`` where
            ``reward = -|forward_vel - goal_vel| - 0.05 * sum(action ** 2)``,
            ``done`` is always False and ``infos`` carries the two reward terms
            plus the active task dict.
        """
        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        xposafter = self.sim.data.qpos[0]

        # Finite-difference estimate of forward velocity over this step.
        forward_vel = (xposafter - xposbefore) / self.dt
        forward_reward = -1.0 * abs(forward_vel - self._goal_vel)
        ctrl_cost = 0.5 * 1e-1 * np.sum(np.square(action))

        observation = self._get_obs()
        reward = forward_reward - ctrl_cost
        done = False
        infos = dict(reward_forward=forward_reward,
                     reward_ctrl=-ctrl_cost, task=self._task)
        return (observation, reward, done, infos)

    def sample_tasks(self, num_tasks, seed: int = 1337):
        """Create ``num_tasks`` tasks with evenly spaced target velocities.

        Velocities are ``np.linspace(0.075, 3, num_tasks)`` shuffled with a
        fixed seed of 0, so every call yields the same task order. The task
        list is stored on ``self.tasks`` and returned.

        Args:
            num_tasks: number of tasks to create.
            seed: retained for interface compatibility; it does not influence
                the result (the shuffle always uses seed 0).

        Note:
            Reseeds NumPy's global RNG with ``np.random.seed(0)`` as a side
            effect.
        """
        velocities = np.linspace(0.075, 3, num_tasks)
        # The fixed seed pins the task order. Subclasses load per-task dataset
        # files by task index, so changing this order would presumably
        # misalign tasks and their datasets.
        np.random.seed(0)
        np.random.shuffle(velocities)
        tasks = [{'velocity': velocity} for velocity in velocities]
        self.tasks = tasks
        return tasks

    def get_all_task_idx(self):
        """Return the range of valid indices into ``self.tasks``."""
        return range(len(self.tasks))

    def reset_task(self, idx):
        """Activate ``self.tasks[idx]`` and reset the environment."""
        self._task = self.tasks[idx]
        self._goal_vel = self._task['velocity']
        self._goal = self._goal_vel
        self.reset()

    def print_task(self):
        """Print the active task's target velocity."""
        print(f'Task information: Goal velocity {self._goal}')


class HalfCheetahVelEnv(HalfCheetahVelEnv_):
    """Velocity-tracking HalfCheetah with optional goal-conditioned observations.

    Also provides helpers that load pickled per-task offline datasets from
    ``datasets/HalfCheetahVel-v0/``.
    """

    # Per-task offline dataset file, formatted with the integer task id.
    _DATASET_PATH = "datasets/HalfCheetahVel-v0/dataset_task_{}.pkl"

    def __init__(self, tasks: Optional[List[dict]] = None, include_goal: bool = False, one_hot_goal: bool = False, n_tasks: int = 50, train_task_ids: Optional[List[int]] = None):
        """Build the task set and activate task 0.

        Args:
            tasks: explicit task dicts, each with a ``'velocity'`` key. When
                None, ``n_tasks`` tasks are generated with ``sample_tasks``.
            include_goal: append the goal to every observation.
            one_hot_goal: with ``include_goal``, encode the goal as a one-hot
                vector over tasks instead of the raw target velocity.
            n_tasks: number of tasks to generate when ``tasks`` is None.
            train_task_ids: task indices used by ``get_dataset``; defaults to
                every task except the last 5.

        Raises:
            ValueError: if both ``tasks`` and ``n_tasks`` are None.
        """
        # These must exist before the base constructor, which may call
        # _get_obs() (presumably to size the observation space).
        self.include_goal = include_goal
        self.one_hot_goal = one_hot_goal
        if tasks is None:
            if n_tasks is None:
                raise ValueError("Either tasks or n_tasks must be non-None")
            tasks = self.sample_tasks(n_tasks)
        # Used for task lookup (set_task_idx) and one-hot goal encoding.
        self.tasks = tasks
        self.n_tasks = len(tasks)
        self.train_task_ids = (train_task_ids if train_task_ids is not None
                               else list(range(self.n_tasks - 5)))
        super().__init__(tasks, num_tasks=self.n_tasks, num_train_tasks=len(self.train_task_ids))
        self.set_task_idx(0)
        self._max_episode_steps = 200

    def _get_obs(self):
        """Return the base observation, optionally with the goal appended."""
        obs = super()._get_obs()
        if not self.include_goal:
            return obs
        if self.one_hot_goal:
            goal = np.zeros((self.n_tasks,))
            # list.index raises ValueError if the active task dict is not an
            # element of self.tasks (e.g. an ad-hoc dict given to set_goal()).
            goal[self.tasks.index(self._task)] = 1
        else:
            goal = np.array([self._goal_vel])
        return np.concatenate([obs, goal])

    def set_task(self, task):
        """Make ``task`` (a dict with a ``'velocity'`` key) active and reset."""
        self._task = task
        self._goal_vel = self._task['velocity']
        # Keep _goal in sync, as reset_task() does, so print_task() is accurate.
        self._goal = self._goal_vel
        self.reset()

    def set_task_idx(self, idx):
        """Activate ``self.tasks[idx]`` and remember ``idx`` as ``task_idx``."""
        self.task_idx = idx
        self.set_task(self.tasks[idx])

    def _load_task_dataset(self, task_id):
        """Load and return the pickled offline dataset for ``task_id``."""
        # NOTE(review): pickle can execute arbitrary code on load; only use
        # dataset files from a trusted source.
        with open(self._DATASET_PATH.format(task_id), "rb") as f:
            return pickle.load(f)

    def get_dataset(self):
        """Concatenate the offline datasets of all training tasks.

        Returns:
            OrderedDict mapping every key of the first task's dataset to that
            key's arrays concatenated along axis 0 across
            ``self.train_task_ids``, followed by ``'task_id'``: an int32 array
            naming the source task of every row.

        Raises:
            ValueError: if ``self.train_task_ids`` is empty.
        """
        if not self.train_task_ids:
            raise ValueError("train_task_ids is empty; cannot build a training dataset")

        train_dataset = OrderedDict()
        for position, task_id in enumerate(self.train_task_ids):
            dataset = self._load_task_dataset(task_id)
            if position == 0:
                # The first task fixes the key order; 'task_id' comes last.
                for key in dataset:
                    train_dataset[key] = []
                train_dataset['task_id'] = []
            for key, value in dataset.items():
                train_dataset[key].append(value)
            num_rows = dataset['observations'].shape[0]
            train_dataset['task_id'].append(np.full(num_rows, task_id, dtype=np.int32))

        for key, value in train_dataset.items():
            train_dataset[key] = np.concatenate(value, axis=0)

        num_samples = train_dataset['observations'].shape[0]

        print('================================================')
        print(f'Successfully constructed the training dataset from {self.num_train_tasks} tasks...')
        print(f'Number of training samples: {num_samples}')
        print('================================================')

        return train_dataset

    def get_prompt_trajectories(self, num_demos=5):
        """Return the highest-return trajectories of every task.

        Returns:
            A list with one entry per task index in ``range(self.num_tasks)``.
            Each entry is a list of up to ``num_demos`` trajectories, sorted by
            descending total reward. A task with fewer trajectories contributes
            all of them.
        """
        prompt_trajectories = []
        for task_id in range(self.num_tasks):
            trajectories = convert_data_to_trajectories(self._load_task_dataset(task_id))
            # sorted() is stable, so equal returns keep their dataset order.
            ranked = sorted(trajectories, key=lambda traj: traj['rewards'].sum(), reverse=True)
            prompt_trajectories.append(ranked[:num_demos])
        return prompt_trajectories

    def set_goal(self, goal):
        """Make ``goal`` (a task dict with a ``'velocity'`` key) active and reset."""
        self._task = goal
        self._goal_vel = self._task['velocity']
        # Keep _goal in sync so print_task() reports the new target.
        self._goal = self._goal_vel
        self.reset()

    def get_normalized_score(self, values):
        """Return ``values`` unchanged; no score normalization is applied."""
        return values


def convert_data_to_trajectories(data):
    """Split a flat transition dataset into per-episode trajectories.

    A trajectory ends at every index where ``data['terminals']`` or
    ``data['timeouts']`` is truthy. Each trajectory is an ``OrderedDict`` that
    holds, for every key of ``data``, the slice of that value from the
    episode's first index through its terminating index (inclusive).
    Transitions after the last terminal/timeout flag belong to no returned
    trajectory.

    Args:
        data: mapping of sliceable sequences; must contain ``'terminals'`` and
            ``'timeouts'``.

    Returns:
        List of ``OrderedDict`` trajectories in dataset order (empty when there
        are no transitions or no episode end).
    """
    terminals = data['terminals']
    timeouts = data['timeouts']
    trajectories = []
    start_ind = 0
    for ind, terminal in enumerate(terminals):
        timeout = timeouts[ind]
        if terminal or timeout:
            traj = OrderedDict(
                (key, value[start_ind:ind + 1]) for key, value in data.items())
            trajectories.append(traj)
            start_ind = ind + 1

    # Count with len(): the loop index is one short of the transition count
    # and is never bound when the dataset is empty.
    print(f'Convert {len(terminals)} transitions to {(len(trajectories))} trajectories.')
    return trajectories
