import numpy as np

from .replay_pool import ReplayPool


class UnionPool(ReplayPool):
    """Replay pool that routes samples to one of several sub-pools.

    Each transition is assigned to a sub-pool by taking the argmax over the
    last ``len(pools)`` entries of its observation (presumably a one-hot task
    encoding appended to the observation -- confirm against the env wrapper).
    Batches are drawn from every sub-pool and concatenated along axis 0.
    """

    def __init__(self, pools, adaptive_sampling=False):
        """
        Args:
            pools: Sequence of sub-pools. They are used through ``size``,
                ``add_sample``, ``add_samples``, ``random_batch``,
                ``last_n_batch`` and ``return_all_samples``.
            adaptive_sampling: If True, ``random_batch`` requires a
                ``sample_probs`` keyword giving each pool's share of the
                batch; otherwise the batch is split evenly across pools.
        """
        self._total_size = sum(pool.size for pool in pools)

        self.pools = pools
        # Index of the pool that received the most recent ``add_sample``;
        # ``last_n_batch`` reads from this pool.
        self.curr_pool_id = 0
        self.adaptive_sampling = adaptive_sampling

    def add_sample(self, **kwargs):
        """Add one transition to the sub-pool selected by its observation."""
        assert 'observations' in kwargs
        task_encoding = kwargs['observations'][-len(self.pools):]
        self.curr_pool_id = task_encoding.argmax()
        self.pools[self.curr_pool_id].add_sample(**kwargs)

    def add_samples(self, samples):
        """Split a batch of transitions by pool index and add each part.

        Args:
            samples: Dict of arrays whose first axis indexes transitions;
                ``samples['observations']`` is expected to be 2-D
                (rows x features).
        """
        assert 'observations' in samples
        # The pool index of each row does not depend on the pool being
        # filled, so compute it once rather than once per pool.
        pool_ids = samples['observations'][:, -len(self.pools):].argmax(axis=1)
        for i, pool in enumerate(self.pools):
            mask = pool_ids == i
            # Every pool is called, even when it receives zero rows.
            pool.add_samples({key: value[mask] for key, value in samples.items()})

    def terminate_episode(self):
        """No-op: episode boundaries are not tracked by the union."""
        pass

    @property
    def size(self):
        """Total number of samples across all sub-pools (also cached)."""
        self._total_size = sum(pool.size for pool in self.pools)
        return self._total_size

    def add_path(self, **kwargs):
        raise NotImplementedError

    def last_n_batch(self, last_n, field_name_filter=None, **kwargs):
        """Return the last ``last_n`` samples of the most recently written pool."""
        return self.pools[self.curr_pool_id].last_n_batch(last_n,
                                                        field_name_filter=field_name_filter,
                                                        **kwargs)

    def random_batch(self, batch_size, field_name_filter=None, task_id=None, use_hipi=False, **kwargs):
        """Sample a batch by drawing from every sub-pool and concatenating.

        Args:
            batch_size: Total number of rows to draw across all pools.
            field_name_filter: Forwarded to each pool's ``random_batch``.
            task_id: If given, the whole batch is drawn from
                ``pools[task_id]``; the other pools are still called with a
                size of 0.
            use_hipi: If True, reshape ``rewards`` and ``terminals`` to column
                vectors and check that each pool returned ``len(pools)``
                reward rows per observation (presumably rewards relabelled for
                every task -- confirm against the sub-pool implementation).
            **kwargs: ``sample_probs`` is required and consumed when
                ``adaptive_sampling`` is True. ``random_order`` shuffles the
                order in which pool batches are concatenated. All remaining
                keywords, including ``random_order``, are forwarded to each
                pool's ``random_batch``.

        Returns:
            Dict mapping each key of the first pool's batch to the
            concatenation (along axis 0) of that key across all pools.
        """
        num_pools = len(self.pools)

        # Per-pool sizes are floored; pool 0 absorbs the remainder so the
        # sizes sum to ``batch_size``.
        # NOTE(review): if ``sample_probs`` sums to more than 1, pool 0's size
        # can become negative.
        if self.adaptive_sampling:
            assert 'sample_probs' in kwargs
            sample_probs = kwargs.pop('sample_probs')
            partial_batch_sizes = np.full(num_pools, float(batch_size)) * sample_probs
        else:
            partial_batch_sizes = np.full(num_pools, float(batch_size) / num_pools)
        partial_batch_sizes = partial_batch_sizes.astype(int)
        partial_batch_sizes[0] = batch_size - partial_batch_sizes[1:].sum()

        if task_id is not None:
            partial_batch_sizes = np.zeros(num_pools, dtype=int)
            partial_batch_sizes[task_id] = batch_size

        partial_batches = [
            pool.random_batch(partial_size, field_name_filter=field_name_filter, **kwargs)
            for pool, partial_size in zip(self.pools, partial_batch_sizes)
        ]

        # Output keys come from pool 0's batch, captured before any shuffle.
        keys = list(partial_batches[0].keys())

        if kwargs.get('random_order', False):
            import random
            random.shuffle(partial_batches)

        if use_hipi:
            for partial_batch in partial_batches:
                partial_batch['rewards'] = partial_batch['rewards'].reshape(-1, 1)
                partial_batch['terminals'] = partial_batch['terminals'].reshape(-1, 1)
                assert partial_batch['rewards'].shape[0] == num_pools * partial_batch['observations'].shape[0]

        return {
            key: np.concatenate([batch[key] for batch in partial_batches], axis=0)
            for key in keys
        }

    def return_all_samples(self, return_list=False):
        """Collect every stored sample from all sub-pools.

        Args:
            return_list: If True, return a list with one dict per pool, in
                pool order. Otherwise merge them into a single dict.

        Returns:
            A list of per-pool dicts, or a single dict mapping each key of the
            first pool's samples to the concatenation (along axis 0) of that
            key across pools. Returns ``None`` when there are no pools and
            ``return_list`` is False.
        """
        per_pool_samples = [pool.return_all_samples() for pool in self.pools]
        if return_list:
            return per_pool_samples
        if not per_pool_samples:
            return None
        # Build a fresh dict so no pool's returned dict is mutated. Also
        # concatenate each key once instead of re-copying a growing array for
        # every additional pool.
        return {
            key: np.concatenate([samples[key] for samples in per_pool_samples], axis=0)
            for key in per_pool_samples[0]
        }

