from tonic.replays import Buffer
from collections import deque
import numpy as np
import torch
# import tonic.logger


class DequeBuffer:
    """
    Bounded first-in-first-out store of transitions.

    Every transition is kept as the dict of keyword arguments handed to
    ``store``. The backing container is a ``collections.deque`` created with
    ``maxlen``, so once it is full each further ``store`` silently drops the
    oldest entry. ``data`` is a public attribute and may be indexed directly.
    """

    def __init__(self, maxlen):
        super().__init__()
        # maxlen=None gives an unbounded deque.
        self.data = deque(maxlen=maxlen)

    def __len__(self):
        """Return the number of transitions currently held."""
        return len(self.data)

    def store(self, **kwargs):
        """Append one transition (the keyword arguments) at the right end."""
        self.data.append(kwargs)

    def get_data(self):
        """
        Remove the oldest transition and return a shallow copy of it.

        Only the dict is copied; the values are the same objects that were
        passed to ``store``. Raises ``IndexError`` when the buffer is empty.
        """
        oldest = self.data.popleft()
        return dict(oldest)


class HERBuffer(Buffer):
    """
    Replay buffer that additionally stores hindsight-relabelled transitions.

    Incoming transitions are staged in a bounded ``DequeBuffer``. As soon as
    every worker has at least one reset among the staged transitions, they
    are moved, oldest first, into the underlying ``Buffer``. Each moved
    transition may be stored a second time with its goal replaced by a
    position found in a later staged transition.

    Observation layout, as used by the indexing in this class: columns
    ``-6:-3`` hold the goal and columns ``-3:`` of ``next_observations`` hold
    the achieved position.

    Keyword arguments consumed here (all others are forwarded to ``Buffer``):
        goal_threshold: distance below which a relabelled goal counts as
            reached (default 0.08).
        episode_length: ``maxlen`` of the staging deque (default 300). Once
            the deque is full, each further stored transition evicts the
            oldest staged one before it reaches the real buffer.
        relabel_probability: probability with which a moved transition is
            also stored in relabelled form (default 0.8).
    """

    def __init__(self, *args, **kwargs):
        self.goal_threshold = kwargs.pop('goal_threshold', 0.08)
        maxlen = kwargs.pop('episode_length', 300)
        self.relabel_probability = kwargs.pop('relabel_probability', 0.8)
        super().__init__(*args, **kwargs)
        self._buffer = DequeBuffer(maxlen=maxlen)
        print(f'{self.goal_threshold=}')
        print(f'{maxlen=}')

    def store(self, **kwargs):
        """Stage one transition and flush the staging buffer if it is ready."""
        self._buffer.store(**kwargs)
        if self.ready_her():
            self._add_to_real_buffer(kwargs)

    def _sample_new_goals(self, bounds, kwargs):
        """
        Pick one substitute goal per worker from the staged transitions.

        ``bounds`` is the per-worker result of ``_get_bounds``. A worker with
        bound ``b > 0`` gets a staging index drawn uniformly from ``1..b``; a
        worker with bound 0 (its oldest staged transition is itself a reset)
        gets index 0. Returns a list with one entry per worker, each the
        ``-3:`` columns of that worker's ``next_observations`` at the chosen
        index. ``bounds`` is not modified.
        """
        num_workers = kwargs['observations'].shape[0]
        at_reset = bounds == 0
        if len(self._buffer) > 2:
            # randint needs low < high for every element, so give the
            # bound-0 workers a placeholder upper limit; their draw is
            # overwritten with index 0 right below.
            upper = np.where(at_reset, 100, bounds)
            indices = np.random.randint(
                np.ones_like(upper), upper + 1, size=num_workers)
            indices[at_reset] = 0
        else:
            # One index per worker, so that every worker gets a goal.
            indices = np.zeros(num_workers, dtype=int)
        return [
            self._buffer.data[time_step]['next_observations'][worker, -3:]
            for worker, time_step in enumerate(indices)
        ]

    def _get_bounds(self):
        """
        Counts how many steps in each episode were played before it was
        reset. This function counts this for transitions from multiple workers
        where the workers can have unequal resets.

        Must only be called while ``ready_her()`` holds, i.e. every worker
        has a reset among the staged transitions.
        NOTE(review): the masking below assumes 'resets' is a boolean array
        with one entry per worker - confirm against the caller.
        """
        first_resets = self._buffer.data[0]['resets']
        counting = np.ones_like(first_resets)
        bounds = np.zeros(shape=first_resets.shape)
        for transition in self._buffer.data:
            # A worker stops counting from its first reset onwards.
            counting[transition['resets']] = 0
            bounds += counting
            if np.sum(counting) == 0:
                break
        return bounds

    def _add_to_real_buffer(self, kwargs):
        """
        Move staged transitions into the real buffer while every worker still
        has a reset staged.

        kwargs are most recent transition tuples; they are only used to read
        the number of workers.
        """
        while self.ready_her():
            bounds = self._get_bounds()
            # Goals are sampled before popping so that index 0 still refers
            # to the transition that is about to be moved.
            new_goals = self._sample_new_goals(bounds, kwargs)
            data = self._buffer.get_data()
            super().store(**data)
            # A uniform float draw; randint(0, 1) would always return 0.
            if np.random.random() < self.relabel_probability:
                relabelled_data = self._relabel(data, new_goals)
                super().store(**relabelled_data)

    def ready_her(self):
        """
        As long as each worker has at least one full episode.

        True when every worker has at least one reset among the staged
        transitions; False when nothing is staged.
        """
        return np.all(np.sum([x['resets'] for x in self._buffer.data], axis=0))

    def _relabel(self, data, new_goals):
        """
        Return a copy of ``data`` whose goals are replaced by ``new_goals``.

        For each worker the goal columns (``-6:-3``) of both observation
        arrays are overwritten. If the achieved position (``-3:`` of
        ``next_observations``) lies within ``goal_threshold`` of the new
        goal, the transition is marked as terminated and reset and its reward
        is increased by 10. The arrays in ``data`` are left untouched because
        they may still be referenced by whoever called ``store``.
        """
        # TODO check if DEP data is different across parallel envs
        relabelled = dict(data)
        for key in ('observations', 'next_observations', 'terminations',
                    'resets', 'rewards'):
            relabelled[key] = np.array(data[key], copy=True)
        for idx, goal in enumerate(new_goals):
            relabelled['next_observations'][idx, -6:-3] = goal
            relabelled['observations'][idx, -6:-3] = goal
            achieved = relabelled['next_observations'][idx, -3:]
            if np.linalg.norm(achieved - goal) < self.goal_threshold:
                relabelled['terminations'][idx] = 1
                relabelled['resets'][idx] = 1
                relabelled['rewards'][idx] += 10
        return relabelled

    def save(self, path):
        """Save the staging buffer with ``torch.save``, then the base buffer."""
        # tonic.logger.log('Saving HER Buffer')
        if hasattr(self, '_buffer'):
            save_path = self.get_path(path, '_buffer')
            torch.save(self._buffer, save_path)
        super().save(path)

    def load(self, load_fn, path):
        """
        Restore the staging buffer via ``load_fn``, then the base buffer.

        If the staging buffer cannot be loaded, the current one is kept and a
        message is printed.
        """
        # tonic.logger.log('Loading HER Buffer')
        if hasattr(self, '_buffer'):
            load_path = self.get_path(path, '_buffer')
            try:
                self._buffer = load_fn(load_path)
            except Exception:
                # Best effort: keep the current staging buffer. Narrower than
                # a bare except so KeyboardInterrupt/SystemExit propagate.
                print('Error in buffer loading, it is freshly initialized')
        super().load(load_fn, path)

class OstrichHER(HERBuffer):
    """
    ``HERBuffer`` variant with a different observation layout and defaults.

    Here the achieved position is read from columns ``-9:-6`` of
    ``next_observations``; the goal is still written to columns ``-6:-3``.
    Defaults: ``goal_threshold=5e-2`` and ``episode_length=400``.
    """

    def __init__(self, *args, **kwargs):
        # Hand the defaults to the parent instead of popping them here:
        # the parent pops both keys itself, so popping first made it fall
        # back to its own defaults and overwrite self.goal_threshold.
        kwargs.setdefault('goal_threshold', 5e-2)
        kwargs.setdefault('episode_length', 400)
        super().__init__(*args, **kwargs)

    def _relabel(self, data, new_goals):
        """
        Return a copy of ``data`` whose goals are replaced by ``new_goals``.

        For each worker the goal columns (``-6:-3``) of both observation
        arrays are overwritten. If the achieved position (``-9:-6`` of
        ``next_observations``) lies within ``self.goal_threshold`` of the new
        goal, the transition is marked as terminated and reset and its reward
        is increased by 10. The arrays in ``data`` are left untouched because
        they may still be referenced by whoever called ``store``.
        """
        relabelled = dict(data)
        for key in ('observations', 'next_observations', 'terminations',
                    'resets', 'rewards'):
            relabelled[key] = np.array(data[key], copy=True)
        for idx, goal in enumerate(new_goals):
            relabelled['next_observations'][idx, -6:-3] = goal
            relabelled['observations'][idx, -6:-3] = goal
            achieved = relabelled['next_observations'][idx, -9:-6]
            if np.linalg.norm(achieved - goal) < self.goal_threshold:
                relabelled['terminations'][idx] = 1
                relabelled['resets'][idx] = 1
                relabelled['rewards'][idx] += 10
        return relabelled

    def _sample_new_goals(self, bounds, kwargs):
        """
        Pick one substitute goal per worker from the staged transitions.

        ``bounds`` holds, per worker, the staging index of that worker's
        first reset. A worker with bound ``b > 0`` gets an index drawn
        uniformly from ``1..b``; a worker with bound 0 gets index 0. Returns
        a list with one entry per worker, each the ``-9:-6`` columns of that
        worker's ``next_observations`` at the chosen index. ``bounds`` is not
        modified.
        """
        num_workers = kwargs['observations'].shape[0]
        at_reset = bounds == 0
        if len(self._buffer) > 2:
            # randint needs low < high for every element, so give the
            # bound-0 workers a placeholder upper limit; their draw is
            # overwritten with index 0 right below.
            upper = np.where(at_reset, 100, bounds)
            indices = np.random.randint(
                np.ones_like(upper), upper + 1, size=num_workers)
            indices[at_reset] = 0
        else:
            # One index per worker, so that every worker gets a goal.
            indices = np.zeros(num_workers, dtype=int)
        return [
            self._buffer.data[time_step]['next_observations'][worker, -9:-6]
            for worker, time_step in enumerate(indices)
        ]
