import gzip
import pickle
from collections import OrderedDict

import numpy as np

from .replay_pool import ReplayPool
from . import (
    simple_replay_pool,
    extra_policy_info_replay_pool)


# Registry of the replay-pool classes that can back a single goal; looked up
# by the ``sub_pool_type`` string given to ``ContinuousGoalPool``.
POOL_CLASSES = dict(
    SimpleReplayPool=simple_replay_pool.SimpleReplayPool,
    ExtraPolicyInfoReplayPool=(
        extra_policy_info_replay_pool.ExtraPolicyInfoReplayPool),
)

class ContinuousGoalPool(ReplayPool):
    """Replay pool that keeps a separate sub-pool for every distinct goal.

    The goal of a transition is read from the last ``goal_dim`` entries of its
    observation. Sub-pools are stored in ``self.pools`` keyed by ``str(goal)``
    and are created lazily from ``POOL_CLASSES[sub_pool_type]``.

    NOTE(review): the keys are the printed form of a numpy array, so two goals
    that print identically under the active numpy print options end up in the
    same sub-pool.
    """

    def __init__(self, observation_space, action_space, max_size, sub_pool_type=None, goal_dim=1, meta_batch_size=10, **kwargs):
        """Store the configuration; no sub-pool is created yet.

        Args:
            observation_space: Forwarded to every sub-pool constructor.
            action_space: Forwarded to every sub-pool constructor.
            max_size: ``max_size`` given to *each* sub-pool (not a global cap).
            sub_pool_type: Key into ``POOL_CLASSES`` selecting the sub-pool
                class. The default ``None`` is not a registered key, so a
                real type must be given before samples are added.
            goal_dim: Number of trailing observation entries forming the goal.
            meta_batch_size: Number of sub-pools ``random_batch`` draws from.
            **kwargs: ``relabel_balance_batch`` (default ``False``) is consumed
                here; everything else is forwarded to the sub-pool
                constructors.
        """
        self.pools = OrderedDict()
        self.observation_space = observation_space
        self.action_space = action_space
        self.max_size = max_size
        self.goal_dim = goal_dim
        self.sub_pool_type = sub_pool_type
        # Until the first `add_sample`, the "current" pool id is the all-zero goal.
        self.curr_pool_id = str(np.zeros(self.goal_dim))
        self.meta_batch_size = meta_batch_size
        self.relabel_balance_batch = kwargs.pop('relabel_balance_batch', False)
        self.pool_kwargs = kwargs

    def _get_or_create_pool(self, pool_id):
        """Return the sub-pool stored under ``pool_id``, creating it if absent."""
        if pool_id not in self.pools:
            self.pools[pool_id] = POOL_CLASSES[self.sub_pool_type](
                observation_space=self.observation_space,
                action_space=self.action_space,
                max_size=self.max_size,
                **self.pool_kwargs)
        return self.pools[pool_id]

    def add_sample(self, **kwargs):
        """Add one transition to the sub-pool of its goal.

        The goal is the last ``goal_dim`` entries of ``kwargs['observations']``;
        that sub-pool also becomes the current pool (``self.curr_pool_id``).
        """
        assert 'observations' in kwargs
        self.curr_pool_id = str(kwargs['observations'][-self.goal_dim:])
        self._get_or_create_pool(self.curr_pool_id).add_sample(**kwargs)

    def add_samples(self, samples):
        """Add a batch of transitions, routing each row to its goal's sub-pool.

        Args:
            samples: Dict of arrays indexed by sample along axis 0; must
                contain a 2-D ``'observations'`` entry whose last ``goal_dim``
                columns are the goal.

        Rows are appended to an already existing sub-pool for the same goal
        instead of replacing it, which would drop previously stored data.
        ``self.curr_pool_id`` is left untouched.
        """
        assert 'observations' in samples
        goal_columns = samples['observations'][:, -self.goal_dim:]
        for unique_goal in np.unique(goal_columns, axis=0):
            mask = np.all(goal_columns == unique_goal, axis=1)
            data = {key: values[mask] for key, values in samples.items()}
            self._get_or_create_pool(str(unique_goal)).add_samples(data)

    def terminate_episode(self):
        """No-op."""
        pass

    @property
    def size(self):
        """Total number of samples held across all sub-pools."""
        # `_total_size` is kept as an attribute in case outside code reads it.
        self._total_size = sum(pool.size for pool in self.pools.values())
        return self._total_size

    def add_path(self, **kwargs):
        """Not supported by this pool."""
        raise NotImplementedError

    def last_n_batch(self, last_n, field_name_filter=None, **kwargs):
        """Return the most recent samples.

        Args:
            last_n: If ``None``, every sub-pool contributes its own
                ``_samples_since_save`` newest samples and the results are
                concatenated along axis 0 (``None`` is returned when there
                are no sub-pools). Otherwise the request is delegated to the
                current sub-pool only.
            field_name_filter: Forwarded to the sub-pool(s).
            **kwargs: Forwarded to the sub-pool(s).
        """
        if last_n is not None:
            return self.pools[self.curr_pool_id].last_n_batch(
                last_n, field_name_filter=field_name_filter, **kwargs)

        per_pool = [
            pool.last_n_batch(
                pool._samples_since_save,
                field_name_filter=field_name_filter,
                **kwargs)
            for pool in self.pools.values()
        ]
        if not per_pool:
            return None
        if len(per_pool) == 1:
            return per_pool[0]
        # Concatenate once into a fresh dict rather than growing (and
        # mutating) the first sub-pool's result pool by pool.
        return {
            key: np.concatenate([batch[key] for batch in per_pool], axis=0)
            for key in per_pool[0]
        }

    def random_batch(self, batch_size, field_name_filter=None, task_id=None, use_hipi=False, relabel_all=False, **kwargs):
        """Sample a batch spread over ``meta_batch_size`` randomly chosen sub-pools.

        Args:
            batch_size: Total number of samples requested.
            field_name_filter: Forwarded to the sub-pools.
            task_id: If given, the whole batch is taken from slot ``task_id``
                of the sub-pools drawn for this call. NOTE(review): that slot
                is a position in a fresh random draw, not a fixed task.
            use_hipi: Recompute ``'rewards'`` of every sample against every
                drawn goal, tile ``'terminals'`` to match, and add the drawn
                goals under ``'tasks'``. Ignored for the relabelling step when
                ``relabel_all`` is set, but ``'tasks'`` is still added.
            relabel_all: Build each slot's batch from the first half of its
                own samples plus samples of the other slots relabelled with
                this slot's goal; adds a ``'relabel_masks'`` entry (0 for own
                samples, 1 for relabelled ones). Requires
                ``relabel_balance_batch``.
            **kwargs: Forwarded to the sub-pools.

        Returns:
            Dict mapping each field to the per-slot results concatenated along
            axis 0.

        Raises:
            ValueError: From ``np.random.choice`` when fewer than
                ``meta_batch_size`` sub-pools exist (drawn without
                replacement).
            AssertionError: If ``relabel_all`` is set without
                ``relabel_balance_batch``.
        """
        # TODO: Hack
        # Even split of the batch; slot 0 absorbs the rounding remainder.
        partial_batch_sizes = np.array(
            [float(batch_size) / self.meta_batch_size] * self.meta_batch_size)
        partial_batch_sizes = partial_batch_sizes.astype(int)
        partial_batch_sizes[0] = batch_size - sum(partial_batch_sizes[1:])

        if task_id is not None:
            partial_batch_sizes = np.zeros(self.meta_batch_size).astype(int)
            partial_batch_sizes[task_id] = batch_size

        # only sample from pushing tasks
        # partial_batch_sizes[2] = batch_size // len(self.pools)
        # partial_batch_sizes[:2] = 0
        pools = list(np.random.choice(
            list(self.pools.values()), size=self.meta_batch_size, replace=False))
        partial_batches = [
            pool.random_batch(
                partial_batch_size, field_name_filter=field_name_filter, **kwargs)
            for pool, partial_batch_size in zip(pools, partial_batch_sizes)
        ]

        # Goals are read from row 0 of each partial batch, which fails on an
        # empty partial batch (e.g. the unused slots when `task_id` is set).
        # Only extract them when a branch below actually needs them.
        goals = None
        if relabel_all or use_hipi:
            goals = np.array([
                partial_batch['observations'][0, -self.goal_dim:]
                for partial_batch in partial_batches
            ])

        if relabel_all:
            assert self.relabel_balance_batch
            is_last_slot_index = len(pools) - 1
            partial_batches_relabeled = []
            for i in range(len(pools)):
                own_half = partial_batch_sizes[i] // 2

                # NOTE: divides by (meta_batch_size - 1), so relabelling needs
                # at least two slots.
                relabeled_sizes = np.array(
                    [float(own_half) / (self.meta_batch_size - 1)] * self.meta_batch_size)
                relabeled_sizes = relabeled_sizes.astype(int)
                relabeled_sizes[i] = own_half
                if i != 0:
                    relabeled_sizes[0] = partial_batch_sizes[i] - sum(relabeled_sizes[1:])
                else:
                    relabeled_sizes[1] = own_half - sum(relabeled_sizes[2:])

                # Rows borrowed from every other slot j. They are taken from
                # the second half of slot j's batch; the chunk offset is this
                # slot's index, except for the last slot which uses j itself.
                donor_slices = []
                for j in range(self.meta_batch_size):
                    if j == i:
                        continue
                    chunk = i if i < is_last_slot_index else j
                    start = partial_batch_sizes[j] // 2 + relabeled_sizes[j] * chunk
                    donor_slices.append((j, slice(start, start + relabeled_sizes[j])))

                partial_batch_relabeled = {}
                for key in partial_batches[0].keys():
                    partial_batch_relabeled[key] = np.concatenate(
                        [partial_batches[i][key][:own_half]]
                        + [partial_batches[j][key][rows] for j, rows in donor_slices])
                    if 'observations' in key:
                        # Overwrite the goal part so every row carries slot i's goal.
                        partial_batch_relabeled[key][:, -self.goal_dim:] = goals[i]

                # NOTE(review): observation columns 3:5 and goal entries -3:-1
                # are hard-coded; presumably the positions compared by the
                # environment's reward (this also implies goal_dim >= 3) - confirm.
                partial_batch_relabeled['rewards'] = np.expand_dims(
                    -np.linalg.norm(
                        partial_batch_relabeled['observations'][:, 3:5] - goals[i, -3:-1],
                        axis=1),
                    axis=1)
                partial_batch_relabeled['relabel_masks'] = np.ones_like(
                    partial_batch_relabeled['rewards'])
                partial_batch_relabeled['relabel_masks'][:own_half] = 0.0
                partial_batches_relabeled.append(partial_batch_relabeled)
            partial_batches = partial_batches_relabeled
        elif use_hipi:
            for i in range(self.meta_batch_size):
                # One copy of every observation per drawn goal.
                observations_tiled = np.tile(
                    np.expand_dims(partial_batches[i]['observations'], axis=1),
                    (1, self.meta_batch_size, 1))
                # Same hard-coded indices as in the relabelling branch.
                partial_batches[i]['rewards'] = np.reshape(
                    -np.linalg.norm(
                        observations_tiled[:, :, 3:5] - goals[None, :, -3:-1], axis=2),
                    (-1, 1))
                partial_batches[i]['terminals'] = np.tile(
                    partial_batches[i]['terminals'],
                    (1, self.meta_batch_size)).reshape(-1, 1)
                assert partial_batches[i]['rewards'].shape[0] == self.meta_batch_size * partial_batches[i]['observations'].shape[0]

        batch = {
            key: np.concatenate(
                [partial_batch[key] for partial_batch in partial_batches], axis=0)
            for key in partial_batches[0].keys()
        }
        if use_hipi:
            batch['tasks'] = goals
        return batch

    def return_all_samples(self, return_list=False):
        """Return everything stored in the sub-pools.

        Args:
            return_list: If true, return a list with one entry per sub-pool
                (in insertion order). Otherwise return a single dict with the
                sub-pools' arrays concatenated along axis 0, or ``None`` when
                there are no sub-pools.
        """
        per_pool = [pool.return_all_samples() for pool in self.pools.values()]
        if return_list:
            return per_pool
        if not per_pool:
            return None
        if len(per_pool) == 1:
            return per_pool[0]
        # Single concatenation into a fresh dict; the first sub-pool's result
        # is not mutated.
        return {
            key: np.concatenate([samples[key] for samples in per_pool], axis=0)
            for key in per_pool[0]
        }

    def save_latest_experience(self, pickle_path):
        """Pickle (gzip-compressed) the samples added since the last save.

        This pool keeps no ``_samples_since_save`` counter of its own, so the
        per-sub-pool counters are used via ``last_n_batch(None)`` and reset
        once the file has been written.

        Args:
            pickle_path: Destination file path.
        """
        latest_samples = self.last_n_batch(None)

        with gzip.open(pickle_path, 'wb') as f:
            pickle.dump(latest_samples, f)

        for pool in self.pools.values():
            pool._samples_since_save = 0

