import logging

from replay.per import ProportionalPER
from replay.seq import SequentialBase


logger = logging.getLogger(__name__)

# class SequentialPER(ProportionalPER):
#     def __init__(self, config, state_keys=None):
#         self._state_keys = state_keys or getattr(self, '_state_keys', [])
#         super().__init__(config)
#         self._memory = collections.deque(maxlen=self._capacity)
#         self._first = True
        
#     def _construct_temp_buff(self):
#         from replay.func import create_local_buffer
#         self._tmp_buf = create_local_buffer({
#             'replay_type': self._replay_type,
#             'n_envs': self._n_envs,
#             'sample_size': self._sample_size,
#             'state_keys': self._state_keys,
#             'extra_keys': self._extra_keys,
#         })

#     def __len__(self):
#         return len(self._memory)
        
#     def add(self, **data):
#         self._tmp_buf.add(**data)
#         if self._tmp_buf.is_full():
#             self.merge(self._tmp_buf.sample())
#             self._tmp_buf.reset()

#     def merge(self, local_buffer):
#         """ Add local_buffer to memory """
#         if isinstance(local_buffer, (list, tuple)):
#             length = len(local_buffer)
#             mem_idxes = np.arange(self._mem_idx, self._mem_idx + length) % self._capacity
#             priorities = np.array([b.pop('priority', self._top_priority) for b in local_buffer])
#             np.testing.assert_array_less(0, priorities)
#             self._data_structure.batch_update(mem_idxes, priorities)
#             self._memory.extend(local_buffer)
#             self._mem_idx = self._mem_idx + length
#         else:
#             priority = local_buffer.pop('priority', self._top_priority)
#             assert priority > 0, priority
#             self._data_structure.update(self._mem_idx, priority)
#             self._memory.append(local_buffer)
#             self._mem_idx = self._mem_idx + 1
#         if self._first:
#             do_logging('First sample', logger=logger)
#             for k, v in self._memory[0].items():
#                 do_logging(f'\t{k}, {v.shape}, {v.dtype}', logger=logger)
#             self._first = False
        
#         if not self._is_full and self._mem_idx >= self._capacity:
#             print(f'Memory is full({len(self)})')
#             self._is_full = True
#         self._mem_idx = self._mem_idx % self._capacity
        
#     def clear_temp_buffer(self):
#         self._tmp_buf.clear()

#     def _get_samples(self, idxes):
#         assert len(idxes) == self._batch_size, idxes
#         results = collections.defaultdict(list)
#         [results[k].append(v) for i in idxes for k, v in self._memory[i].items()]
#         results = {k: np.stack(v) for k, v in results.items()}
#         for k, v in results.items():
#             if k in self._state_keys:
#                 assert v.shape[0] == self._batch_size, (k, v.shape)
#             elif k in self._extra_keys:
#                 assert v.shape[:2] == (self._batch_size, self._sample_size+1), (k, v.shape)
#             else:
#                 assert v.shape[:2] == (self._batch_size, self._sample_size), (k, v.shape)

#         return results

class SequentialPER(SequentialBase, ProportionalPER):
    def __init__(self, config, state_keys=None):
        super().__init__(config, state_keys=state_keys)
