import numpy as np

class ReplayBuffer:
    """FIFO experience buffer with TD-error-prioritised sampling of states and features.

    Items passed to ``push`` are stored as-is. ``sample_states`` and
    ``sample_features`` read the keys ``'obs_t'``, ``'td_error_squared'`` and
    ``'feature_tp1'`` from stored items, so items must be dict-like with those
    keys for those methods to work.
    """

    def __init__(self, capacity, obs_size, batch_size, replay_start_size,
                 HER_sample_state_constant, feature_size=None,
                 tau_prioritized_selection=1):
        """Create an empty buffer.

        Args:
            capacity: maximum number of stored items; when the buffer already
                holds ``capacity`` items, a push evicts the oldest one first.
            obs_size: kept as ``self.obs_size``; not otherwise used here.
            batch_size: number of distinct items drawn by ``sample``.
            replay_start_size: minimum number of stored items before
                ``replay_start`` returns True.
            HER_sample_state_constant: number of leading ``obs_t`` entries
                that identify a state in ``sample_states``.
            feature_size: length of the one-hot vectors built by
                ``sample_features``. If None, it is inferred at call time.
            tau_prioritized_selection: multiplier applied to the scores before
                the softmax used for prioritised selection (larger = greedier).
        """
        self.capacity = capacity
        self.obs_size = obs_size
        self.batch_size = batch_size
        self.replay_start_size = replay_start_size
        self.is_full = False
        self.buffer = []
        self.tau_prioritized_selection = tau_prioritized_selection
        self.HER_constant = HER_sample_state_constant
        self.feature_size = feature_size

    def push(self, input):  # parameter name kept for keyword-argument callers
        """Append ``input``, evicting the oldest item first if at capacity."""
        if len(self.buffer) == self.capacity:
            # Set on the first eviction and never reset.
            self.is_full = True
            del self.buffer[0]
        self.buffer.append(input)

    def sample(self):
        """Return ``batch_size`` distinct stored items chosen uniformly at random.

        Raises:
            ValueError: from ``np.random.choice`` if fewer than ``batch_size``
                items are stored.
        """
        indices = np.random.choice(len(self.buffer), self.batch_size, replace=False)
        return [self.buffer[i] for i in indices]

    def _selection_probs(self, scores):
        """Softmax of ``tau * (scores - max(scores))``; the shift avoids overflow."""
        scores = np.asarray(scores, dtype=float)
        interims = (scores - np.max(scores)) * self.tau_prioritized_selection
        exps = np.exp(interims)
        return exps / np.sum(exps)

    def _draw_prioritized(self, k, candidates, scores, existing):
        """Draw up to ``k`` candidates without replacement, weighted by softmax(score).

        Every drawn candidate is removed from the pool. It is kept only if it
        does not already occur (per ``is_in``) in ``existing``. Drawing stops
        once ``k`` candidates are kept or the pool is exhausted.

        Returns:
            ``np.asarray`` of the kept candidates (shape ``(0,)`` if none).
        """
        candidates = list(candidates)
        scores = list(scores)
        chosen = []
        while len(chosen) < k and candidates:
            probs = self._selection_probs(scores)
            index = np.random.choice(len(candidates), 1, p=probs)[0]
            candidate = candidates.pop(index)
            del scores[index]
            if self.is_in(candidate, existing) == 0:
                chosen.append(candidate)
        return np.asarray(chosen)

    def sample_states(self, k, existing_states):
        """Sample up to ``k`` distinct states not in ``existing_states``, favouring high TD error.

        A state is the first ``HER_constant`` entries of an item's ``obs_t``.
        Each distinct state is scored by the largest ``td_error_squared`` seen
        for it in the buffer.

        Args:
            k: maximum number of states to return.
            existing_states: 2-D array; only its first ``HER_constant`` columns
                are compared against candidate states.

        Returns:
            ``(new_states, count)`` where ``new_states`` has one row per sampled
            state (shape ``(0,)`` when none) and ``count == new_states.shape[0]``.
        """
        # Per distinct state key, keep the maximum squared TD error seen.
        state_scores = {}
        for dic in self.buffer:
            key = tuple(dic['obs_t'])[:self.HER_constant]
            td_error = dic['td_error_squared']
            if key in state_scores:
                state_scores[key] = np.maximum(td_error, state_scores[key])
            else:
                state_scores[key] = td_error

        candidates = [np.asarray(state).flatten() for state in state_scores]
        existing = existing_states[:, :self.HER_constant]  # invariant: slice once
        new_states = self._draw_prioritized(
            k, candidates, list(state_scores.values()), existing)
        return new_states, new_states.shape[0]

    def _target_feature_size(self, existing_target_features):
        """Length of the one-hot target vectors built by ``sample_features``."""
        feature_size = getattr(self, 'feature_size', None)
        if feature_size is not None:
            return feature_size
        existing = np.asarray(existing_target_features)
        if existing.ndim == 2:
            # Candidates are compared row-wise with these, so widths must match.
            return existing.shape[1]
        if self.buffer:
            # NOTE(review): assumes targets live in the same space as feature_tp1.
            return len(self.buffer[0]['feature_tp1'])
        raise ValueError('feature_size is unknown; pass feature_size to ReplayBuffer')

    def sample_features(self, k, existing_target_features, main_features=4):
        """Sample up to ``k`` one-hot target features not already present, favouring high TD error.

        Feature ``j`` (for ``j < main_features``) is scored by the sum over
        stored items of ``feature_tp1[j] * td_error_squared``. Each candidate is
        a zero vector of length ``feature_size`` with a 1 at index ``j``.

        Args:
            k: maximum number of target features to return.
            existing_target_features: rows that must not be sampled again.
            main_features: number of leading features eligible for sampling.

        Returns:
            ``(new_target_features, count)`` where ``count`` is the number of
            rows in ``new_target_features`` (shape ``(0,)`` when none).
        """
        scores = np.zeros(main_features)
        for dic in self.buffer:
            weights = np.asarray(dic['feature_tp1'][:main_features], dtype=float)
            scores += weights * dic['td_error_squared']

        feature_size = self._target_feature_size(existing_target_features)
        candidates = []
        for j in range(main_features):
            one_hot = np.zeros(feature_size)
            one_hot[j] = 1
            candidates.append(one_hot)

        new_target_features = self._draw_prioritized(
            k, candidates, scores.tolist(), existing_target_features)
        return new_target_features, new_target_features.shape[0]

    def replay_start(self):
        """Return True once at least ``replay_start_size`` items are stored."""
        return len(self.buffer) >= self.replay_start_size

    def is_in(self, new_state, states_list):
        """Count rows of ``states_list`` whose difference from ``new_state`` is all close to 0.

        Returns 0 when ``states_list`` is empty.
        """
        if len(states_list) == 0:
            return 0
        remainder = np.asarray(states_list) - np.asarray(new_state)
        return int(np.sum(np.all(np.isclose(remainder, 0), axis=1)))

