import gzip
import pickle

import numpy as np

from .replay_pool import ReplayPool


class FlexibleReplayPool(ReplayPool):
    """Circular replay buffer that stores each field in a preallocated array.

    Every field named in ``fields_attrs`` is backed by a numpy array of shape
    ``(max_size, *shape)``. Samples are written starting at ``self._pointer``,
    which wraps modulo ``max_size``. Once the buffer is full, new samples
    overwrite the oldest ones.
    """

    def __init__(self, max_size, fields_attrs, obs_filter=False, modify_rew=False):
        """Create an empty pool.

        Args:
            max_size: Capacity of the pool (converted with ``int``).
            fields_attrs: Mapping of field name to a dict with ``'shape'`` and
                ``'dtype'`` keys. Optional keys are ``'initializer'`` (array
                factory, default ``np.zeros``) and ``'default_value'`` (used
                by ``add_samples`` when a batch omits the field; default
                ``0.0``).
            obs_filter: Stored as ``self.obs_filter``. Not read by this class;
                presumably consumed by subclasses or callers.
            modify_rew: Stored as ``self.modify_rew``. Not read by this class;
                presumably consumed by subclasses or callers.
        """
        super(FlexibleReplayPool, self).__init__()

        max_size = int(max_size)
        self._max_size = max_size

        self.fields = {}
        self.fields_attrs = {}

        self.add_fields(fields_attrs)

        self.obs_filter = obs_filter
        self.modify_rew = modify_rew

        self._pointer = 0  # next slot to write
        self._size = 0  # number of filled slots, capped at max_size
        self._samples_since_save = 0  # samples added since last save/load

    @property
    def size(self):
        """Number of valid samples currently stored."""
        return self._size

    @property
    def field_names(self):
        """Names of all stored fields, in insertion order."""
        return list(self.fields.keys())

    def add_fields(self, fields_attrs):
        """Allocate storage for each field described in ``fields_attrs``.

        An existing field with the same name is replaced by a freshly
        initialized array, which discards its contents.
        """
        self.fields_attrs.update(fields_attrs)

        for field_name, field_attrs in fields_attrs.items():
            field_shape = (self._max_size, *field_attrs['shape'])
            initializer = field_attrs.get('initializer', np.zeros)
            self.fields[field_name] = initializer(
                field_shape, dtype=field_attrs['dtype'])

    def _advance(self, count=1):
        """Move the write pointer forward by ``count`` slots, wrapping around."""
        self._pointer = (self._pointer + count) % self._max_size
        self._size = min(self._size + count, self._max_size)
        self._samples_since_save += count

    def add_sample(self, sample):
        """Add one sample by giving each of its values a leading axis of 1."""
        samples = {
            key: value[None, ...]
            for key, value in sample.items()
        }
        self.add_samples(samples)

    def add_samples(self, samples):
        """Write a batch of samples into the pool and advance the pointer.

        Args:
            samples: Mapping from field name to an array whose leading
                dimension is the batch size. A pool field that is missing
                from ``samples`` is handled as follows. If
                ``samples['infos']`` exists and its first entry has that key,
                the value is gathered from the per-sample info dicts.
                Otherwise the field's ``default_value`` (``0.0`` if unset) is
                written to every new slot.

        Raises:
            ValueError: If a field's leading dimension does not match the
                batch size, or its values cannot be stored in the field array.
        """
        if not samples:
            return

        first_key = list(samples.keys())[0]
        num_samples = samples[first_key].shape[0]
        if num_samples == 0:
            return

        index = np.arange(
            self._pointer, self._pointer + num_samples) % self._max_size
        infos = samples.get('infos')

        # Resolve and validate every field before writing any of them.
        resolved = {
            field_name: self._resolve_field_values(
                field_name, samples, infos, num_samples)
            for field_name in self.field_names
        }

        for field_name, values in resolved.items():
            try:
                self.fields[field_name][index] = values
            except (ValueError, TypeError) as err:
                raise ValueError(
                    "Could not store values for field '{}' of shape {}"
                    .format(field_name, self.fields[field_name].shape)
                ) from err

        self._advance(num_samples)

    def _resolve_field_values(self, field_name, samples, infos, num_samples):
        """Return the values to write for ``field_name`` in the current batch."""
        default_value = (
            self.fields_attrs[field_name].get('default_value', 0.0))

        if field_name in samples:
            values = samples[field_name]
        elif infos is not None and field_name in infos[0]:
            # Gather the per-sample entries from the info dicts as a column.
            values = np.expand_dims(np.array([
                info.get(field_name, default_value)
                for info in infos[:num_samples]
            ]), axis=1)
        else:
            # The field is absent from this batch. Numpy broadcasts the scalar
            # (or field-shaped) default into every written row.
            return default_value

        shape = np.shape(values)
        if not shape or shape[0] != num_samples:
            raise ValueError(
                "Field '{}' has leading shape {}, expected {} samples"
                .format(field_name, shape, num_samples))

        if isinstance(values[0], dict):
            # Flatten each dict-valued sample by concatenating its entries
            # along the last axis.
            values = np.stack([
                np.concatenate([value[key] for key in value.keys()], axis=-1)
                for value in values
            ])
        return values

    def random_indices(self, batch_size):
        """Return ``batch_size`` uniform random indices in ``[0, size)``.

        Returns an empty index array when the pool is empty.
        """
        if self._size == 0:
            return np.arange(0, 0)
        return np.random.randint(0, self._size, batch_size)

    def random_batch(self, batch_size, field_name_filter=None, **kwargs):
        """Return a batch sampled uniformly with replacement."""
        random_indices = self.random_indices(batch_size)
        return self.batch_by_indices(
            random_indices, field_name_filter=field_name_filter, **kwargs)

    def last_n_batch(self, last_n, field_name_filter=None, **kwargs):
        """Return the ``min(size, last_n)`` most recently written samples."""
        last_n_indices = np.arange(
            self._pointer - min(self.size, last_n), self._pointer
        ) % self._max_size
        return self.batch_by_indices(
            last_n_indices, field_name_filter=field_name_filter, **kwargs)

    def filter_fields(self, field_names, field_name_filter):
        """Return the subset of ``field_names`` accepted by the filter.

        ``field_name_filter`` can be a single name, a list, tuple, set or
        frozenset of names, or a predicate taking a field name.
        """
        if isinstance(field_name_filter, str):
            field_name_filter = [field_name_filter]

        if isinstance(field_name_filter, (list, tuple, set, frozenset)):
            allowed_names = set(field_name_filter)

            def filter_fn(field_name):
                return field_name in allowed_names

        else:
            filter_fn = field_name_filter

        return [
            field_name for field_name in field_names
            if filter_fn(field_name)
        ]

    def batch_by_indices(self, indices, field_name_filter=None):
        """Return ``{field_name: values[indices]}`` for the selected fields.

        Raises:
            ValueError: If any index, taken modulo ``max_size``, points at a
                slot that has not been filled yet.
        """
        # Filled slots are 0 .. size - 1, so an index equal to size is invalid.
        if np.any(indices % self._max_size >= self.size):
            raise ValueError(
                "Tried to retrieve batch with indices greater than current"
                " size")

        field_names = self.field_names
        if field_name_filter is not None:
            field_names = self.filter_fields(
                field_names, field_name_filter)

        return {
            field_name: self.fields[field_name][indices]
            for field_name in field_names
        }

    def save_latest_experience(self, pickle_path):
        """Gzip-pickle the samples added since the last save/load.

        The number of samples saved is capped at the current pool size.
        """
        latest_samples = self.last_n_batch(self._samples_since_save)

        with gzip.open(pickle_path, 'wb') as f:
            pickle.dump(latest_samples, f)

        self._samples_since_save = 0

    def load_experience(self, experience_path):
        """Append samples from a file written by ``save_latest_experience``.

        Raises:
            ValueError: If the stored fields disagree on the number of samples.
        """
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # experience files from trusted sources.
        with gzip.open(experience_path, 'rb') as f:
            latest_samples = pickle.load(f)

        if latest_samples:
            key = list(latest_samples.keys())[0]
            num_samples = latest_samples[key].shape[0]
            for field_name, data in latest_samples.items():
                if data.shape[0] != num_samples:
                    raise ValueError(
                        "Field '{}' has shape {}, expected {} samples"
                        .format(field_name, data.shape, num_samples))

            self.add_samples(latest_samples)
        self._samples_since_save = 0

    def return_all_samples(self):
        """Return views of the first ``size`` slots of every field."""
        return {
            field_name: self.fields[field_name][:self.size]
            for field_name in self.field_names
        }

    def __getstate__(self):
        """Pickle only the filled portion of each field array."""
        state = self.__dict__.copy()
        state['fields'] = {
            field_name: self.fields[field_name][:self.size]
            for field_name in self.field_names
        }

        return state

    def __setstate__(self, state):
        """Restore state, zero-padding each field back up to ``max_size``."""
        if state['_size'] < state['_max_size']:
            pad_size = state['_max_size'] - state['_size']
            for field_name in state['fields'].keys():
                stored = state['fields'][field_name]
                field_shape = state['fields_attrs'][field_name]['shape']
                # Match the stored dtype; float64 padding would otherwise
                # promote int/bool/float32 fields on concatenation.
                state['fields'][field_name] = np.concatenate((
                    stored,
                    np.zeros((pad_size, *field_shape), dtype=stored.dtype)
                ), axis=0)

        self.__dict__ = state