import numpy as np
import torch
from train.common.config import Config

class LessonBuffer:
    """Fixed-capacity ring buffer of zero-padded episodes.

    Episodes are stored in pre-allocated NumPy arrays of shape
    ``(size, max_time, ...)``. Once ``size`` episodes have been added, the
    oldest slot is overwritten next. Batches can be drawn uniformly at random
    or with probability given by a softmax over the per-episode values in
    ``self.loss``.

    ``self.loss`` is never written by this class after construction; it is
    presumably updated by the training loop (e.g. ``buffer.loss[indices] = ...``)
    -- verify against the caller.
    """

    # `full_enough` requires strictly more than this many stored episodes
    # (unless the buffer is already full).
    min_size = 256

    # Default softmax temperature used by `softmax` / `get_samples`. It can be
    # overridden per instance by assigning `buffer.temperature`.
    temperature = 1.0

    def __init__(self, size, max_time, state_shape, n_actions):
        """Allocate storage for ``size`` episodes of at most ``max_time`` steps.

        Args:
            size: Maximum number of episodes held at once.
            max_time: Number of time steps every stored episode is padded to.
            state_shape: Shape of a single state, as a sequence of ints.
            n_actions: Size of the last axis of the stored action arrays.
        """
        self.max_size = size
        self.curr_size = 0
        self.next_spot_to_add = 0

        # tuple(...) so a list-valued state_shape is accepted as well.
        self.states_buffer = np.empty(shape=((size, max_time) + tuple(state_shape)))
        self.actions_buffer = np.empty(shape=(size, max_time, n_actions))
        self.rewards_buffer = np.empty(shape=(size, max_time))
        self.lengths_buffer = np.empty(shape=(size, 1), dtype=np.int32)
        # Zero-initialised so that loss-weighted sampling never sees
        # uninitialised memory (NaN/inf) for slots whose loss was not set yet.
        self.loss = np.zeros(shape=(size, ))

    @property
    def is_full(self):
        """Whether the number of stored episodes has reached the capacity."""
        return self.curr_size >= self.max_size

    def different_returns_encountered(self):
        """
        Training only makes sense if the sequences in the buffer have
        different returns. Otherwise the LSTM could learn to ignore the input.

        Returns True when the stored episodes have more than one distinct
        reward value at the final time step.
        """
        # NOTE(review): only the last time step is inspected, and `add` pads
        # that step with 0 for episodes shorter than max_time -- confirm that
        # the reward of interest is always stored at index -1.
        final_rewards = self.rewards_buffer[:self.curr_size, -1]
        return np.unique(final_rewards).shape[0] > 1

    def full_enough(self):
        """There have to be enough samples before training makes sense."""
        return self.is_full or self.curr_size > self.min_size

    def add(self, states, actions, rewards, length):
        """Add an episode to the buffer, overwriting the oldest one when full.

        Args:
            states: Array whose first axis is the episode's time axis; it is
                squeezed before being stored.
            actions: Array assignable to ``actions_buffer[slot, :traj_length]``.
            rewards: Array assignable to ``rewards_buffer[slot, :traj_length]``.
            length: Value recorded in ``lengths_buffer`` for this episode.
        """
        slot = self.next_spot_to_add
        traj_length = states.shape[0]

        self.states_buffer[slot, :traj_length] = states.squeeze()
        self.actions_buffer[slot, :traj_length] = actions
        self.rewards_buffer[slot, :traj_length] = rewards

        # Zero the padding of every array, not only the rewards; otherwise the
        # tail keeps uninitialised memory or data of the overwritten episode.
        self.states_buffer[slot, traj_length:] = 0
        self.actions_buffer[slot, traj_length:] = 0
        self.rewards_buffer[slot, traj_length:] = 0
        self.lengths_buffer[slot] = length

        # Advance the ring-buffer bookkeeping only after all writes succeeded,
        # so a shape mismatch above does not register a half-written slot.
        self.next_spot_to_add = (slot + 1) % self.max_size
        self.curr_size = min(self.curr_size + 1, self.max_size)

    def get_samples(self, batch_size):
        """Draw episode indices with probability given by a softmax of the losses.

        Only the ``curr_size`` filled slots are candidates.

        Args:
            batch_size: Number of indices to draw (with replacement).

        Returns:
            Integer array of ``batch_size`` indices into the buffers.

        Raises:
            ValueError: If the buffer holds no episodes.
        """
        if self.curr_size == 0:
            raise ValueError("Cannot sample from an empty buffer.")

        probs = self.softmax(self.loss[:self.curr_size])
        # Small floor so that no stored episode has exactly zero probability.
        probs += 1e-5
        probs /= probs.sum()
        # Sample positions (0..curr_size-1), not the loss values themselves.
        return np.random.choice(self.curr_size, batch_size, replace=True, p=probs)

    def softmax(self, values):
        """Return the softmax of ``values`` at temperature ``self.temperature``."""
        # Subtracting the maximum keeps every exponent <= 0 (no overflow).
        e_values = np.exp((values - values.max()) / self.temperature)
        return e_values / e_values.sum()

    def sample(self, batch_size, randomize=True, balance=False):
        """Pick ``batch_size`` episodes uniformly at random or according to losses.

        Args:
            batch_size: Number of episodes to return (drawn with replacement).
            randomize: Draw indices uniformly from the filled slots.
            balance: Draw indices via `get_samples` at temperature 0.001, so
                higher-loss episodes are sampled more often. Takes precedence
                over ``randomize``.

        Returns:
            Tuple ``(states, actions, rewards, lengths, indices)`` where the
            first four entries are the buffers indexed by ``indices``.

        Raises:
            ValueError: If the buffer is empty, or if both ``randomize`` and
                ``balance`` are false (no sampling strategy selected).
        """
        if self.curr_size == 0:
            raise ValueError("Cannot sample from an empty buffer.")

        if balance:
            # sample higher loss samples more often
            previous_temperature = self.temperature
            self.temperature = 0.001
            try:
                indices = self.get_samples(batch_size)
            finally:
                # Restore the caller's temperature; the override is only
                # meant for this draw.
                self.temperature = previous_temperature
        elif randomize:
            indices = np.random.randint(0, self.curr_size, batch_size)
        else:
            raise ValueError(
                "No sampling strategy selected: set randomize=True or balance=True."
            )

        return (
            self.states_buffer[indices], self.actions_buffer[indices],
            self.rewards_buffer[indices], self.lengths_buffer[indices], indices
        )



if __name__ == "__main__":
    # Script entry point: build the configuration and the network, then
    # optionally load stored parameters into the network.
    # NOTE(review): `net` is not defined or imported in the visible part of
    # this file -- confirm where it is supposed to come from.
    args = Config()
    model = net()

    if args.model:
        checkpoint_path = args.model_path
        print("Loading behaviour cloned model!")
        state_dict = torch.load(checkpoint_path)
        model.load_state_dict(state_dict)
