from collections import defaultdict

import numpy as np
from gym.spaces import Box, Dict, Discrete
import pdb

from .flexible_replay_pool import FlexibleReplayPool


def normalize_observation_fields(observation_space, name='observations'):
    if isinstance(observation_space, Dict):
        fields = [
            normalize_observation_fields(child_observation_space, name)
            for name, child_observation_space
            in observation_space.spaces.items()
        ]
        fields = {
            'observations.{}'.format(name): value
            for field in fields
            for name, value in field.items()
        }
    elif isinstance(observation_space, (Box, Discrete)):
        fields = {
            name: {
                'shape': observation_space.shape,
                'dtype': observation_space.dtype,
            }
        }
    else:
        raise NotImplementedError(
            "Observation space of type '{}' not supported."
            "".format(type(observation_space)))

    return fields


class SimpleReplayPool(FlexibleReplayPool):
    def __init__(self, observation_space, action_space, *args, **kwargs):
        self._observation_space = observation_space
        self._action_space = action_space
        add_relabel_mask = kwargs.pop('add_relabel_mask', False)
        num_tasks = kwargs.pop('num_tasks', 1)
        hipi = kwargs.pop('hipi', False)
        if hipi:
            assert not add_relabel_mask
        goal_conditioned = kwargs.pop('goal_conditioned', False)

        # observation_fields = normalize_observation_fields(observation_space)
        # active_size = sum(
        #     np.prod(observation_space.spaces[key].shape)
        #     for key in list(observation_space.spaces.keys()))

        # active_observation_shape = (active_size, )
        # # It's a bit memory inefficient to save the observations twice,
        # # but it makes the code *much* easier since you no longer have
        # # to worry about termination conditions.
        # observation_fields.update({
        #     'next_' + key: value
        #     for key, value in observation_fields.items()
        # })

        # fields = {
        #     **observation_fields,
        #     **{
        #         'actions': {
        #             'shape': self._action_space.shape,
        #             'dtype': 'float32'
        #         },
        #         'rewards': {
        #             'shape': (1, ),
        #             'dtype': 'float32'
        #         },
        #         # self.terminals[i] = a terminal was received at time i
        #         'terminals': {
        #             'shape': (1, ),
        #             'dtype': 'bool'
        #         },
        #         'observations': {
        #             'shape': active_observation_shape,
        #             'dtype': 'float32'
        #         },
        #         'next_observations': {
        #             'shape': active_observation_shape,
        #             'dtype': 'float32'
        #         },
        #     }
        # }
        active_size = sum(
            np.prod(observation_space.spaces[key].shape)
            for key in list(observation_space.spaces.keys()))

        active_observation_shape = (active_size, )

        fields = {
            'actions': {
                'shape': self._action_space.shape,
                'dtype': 'float32'
            },
            'rewards': {
                'shape': (1, ),
                'dtype': 'float32'
            },
            # self.terminals[i] = a terminal was received at time i
            'terminals': {
                'shape': (1, ),
                'dtype': 'bool'
            },
            'observations': {
                'shape': active_observation_shape,
                'dtype': 'float32'
            },
            'next_observations': {
                'shape': active_observation_shape,
                'dtype': 'float32'
            },
        }
        if add_relabel_mask:
            fields['relabel_masks'] = {
                'shape': (1, ),
                'dtype': 'float32'
            }
            # fields['original_task_id_onehot'] = {
            #     'shape': (num_tasks, ),
            #     'dtype': 'float32'
            # }
        if hipi and not goal_conditioned:
            fields['rewards'] = {
                'shape': (num_tasks, ),
                'dtype': 'float32'
            }
            fields['terminals'] = {
                'shape': (num_tasks, ),
                'dtype': 'bool'
            }


        super(SimpleReplayPool, self).__init__(
            *args, fields_attrs=fields, relabel=add_relabel_mask, **kwargs)

    def add_samples(self, samples):
        if not isinstance(self._observation_space, Dict):
            return super(SimpleReplayPool, self).add_samples(samples)

        dict_observations = defaultdict(list)
        # for observation in samples['observations']:
        #     for key, value in observation.items():
        #         dict_observations[key].append(value)

        # dict_next_observations = defaultdict(list)
        # for next_observation in samples['next_observations']:
        #     for key, value in next_observation.items():
        #         dict_next_observations[key].append(value)

        # samples.update(
        #    **{
        #        f'observations.{observation_key}': np.array(values)
        #        for observation_key, values in dict_observations.items()
        #    },
        #    **{
        #        f'next_observations.{observation_key}': np.array(values)
        #        for observation_key, values in dict_next_observations.items()
        #    },
        # )

        # del samples['observations']
        # del samples['next_observations']

        return super(SimpleReplayPool, self).add_samples(samples)

    # def add_model_samples(self, samples):
    #     field_names = list(samples.keys())
    #     num_samples = samples[field_names[0]].shape[0]

    #     index = np.arange(
    #         self._pointer, self._pointer + num_samples) % self._max_size

    #     for field_name in self.field_names:
    #         values = samples[field_name] 
    #         assert values.shape[0] == num_samples
    #         self.fields[field_name][index] = values

    #     self._advance(num_samples)
    #     pdb.set_trace()
        # field_names = samples.keys()
        # num_samples = samples['observations'].shape[0]
        # for i in range(num_samples):
        #     sample = {field: samples[field][i] for field in field_names}
        #     self.add_model_sample(sample)
        #     pdb.set_trace()
        #     # self.fields


    # def add_model_sample(self, sample):
    #     # self._size
    #     pass
    #     self._advance()

    def batch_by_indices(self,
                         indices,
                         field_name_filter=None,
                         observation_keys=None):
        if not isinstance(self._observation_space, Dict):
            return super(SimpleReplayPool, self).batch_by_indices(
                indices, field_name_filter=field_name_filter)

        batch = {
            field_name: self.fields[field_name][indices]
            for field_name in self.field_names
        }

        # if observation_keys is None:
        #     observation_keys = tuple(self._observation_space.spaces.keys())

        # observations = np.concatenate([
        #     batch['observations.{}'.format(key)]
        #     for key in observation_keys
        # ], axis=-1)

        # next_observations = np.concatenate([
        #     batch['next_observations.{}'.format(key)]
        #     for key in observation_keys
        # ], axis=-1)

        # batch['observations'] = observations
        # batch['next_observations'] = next_observations

        if field_name_filter is not None:
            filtered_fields = self.filter_fields(
                batch.keys(), field_name_filter)
            batch = {
                field_name: batch[field_name]
                for field_name in filtered_fields
            }

        return batch

    # def return_all_samples(self):
    #     all_samples = {
    #         field_name: self.fields[field_name][:self.size]
    #         for field_name in self.field_names
    #     }
    #     observation_keys = tuple(self._observation_space.spaces.keys())
    #     observations = np.concatenate([
    #         all_samples['observations.{}'.format(key)]
    #         for key in observation_keys
    #     ], axis=-1)

    #     next_observations = np.concatenate([
    #         all_samples['next_observations.{}'.format(key)]
    #         for key in observation_keys
    #     ], axis=-1)

    #     all_samples['observations'] = observations
    #     all_samples['next_observations'] = next_observations
    #     return all_samples

    def terminate_episode(self):
        pass
