import numpy as np
import torch

from mas_sat.graph.base import BaseGraph

class ReplayBuffer(object):
    """Fixed-size circular buffer of transitions for off-policy training.

    Each slot ``i`` stores ``(state_i, action_i, reward_i, terminal_i)``. The
    successor state of slot ``i`` is read from slot ``(i + 1) % size``, so
    states are expected to be added in trajectory order.
    """

    def __init__(self, args) -> None:
        # hyper-parameters (read from ``args.buffer_size`` / ``args.batch_size``)
        self.size = args.buffer_size
        self.batch_size = args.batch_size

        # internal states
        self.index = 0      # next slot to write; after wrapping, the oldest slot
        self.counter = 0    # total number of transitions ever added
        self.ready = False  # True once enough transitions exist to sample a batch
        self.full = False   # True once every slot has been written at least once

        # buffers
        self.states = np.zeros(self.size, dtype=object)  # arbitrary state objects
        self.actions = torch.zeros(self.size, dtype=torch.long)
        self.rewards = torch.zeros(self.size)
        self.terminals = torch.zeros(self.size)  # 1.0 where the episode ended

    # basic get functions
    def get_counter(self) -> int:
        """Return the total number of transitions added so far."""
        return self.counter

    def is_ready(self) -> bool:
        """Return True when ``sample`` may be called."""
        return self.ready

    def is_full(self) -> bool:
        """Return True once the buffer has wrapped around at least once."""
        return self.full

    # main interfaces
    def add_transition(self, state, action, reward, terminal):
        """Store one transition, overwriting the oldest slot when full.

        ``state`` must support ``.detach().cpu().clone()``. It is copied so that
        later in-place changes by the caller do not alter the stored state.
        """
        self.states[self.index] = state.detach().cpu().clone()  # deep copy
        self.actions[self.index] = action
        self.rewards[self.index] = reward
        self.terminals[self.index] = float(terminal)

        self.index += 1
        self.counter += 1
        # At least two stored transitions are needed so that one of them has a
        # successor state; otherwise ``sample`` would draw from an empty range.
        if self.index >= max(self.batch_size, 2):
            self.ready = True
        if self.index >= self.size:
            self.index = 0
            self.full = True

    def _sample_indices(self):
        """Draw random slot indices together with their successor slots.

        Only slots whose successor holds a state written *after* them are
        eligible. Before the buffer wraps, these are ``0 .. index-2``. After it
        wraps, they are every slot except the newest one, ``(index - 1) % size``,
        whose successor slot still holds the oldest, unrelated transition.
        Successor indices wrap around the end of the storage.

        Returns:
            ``(indices, next_indices)``: two integer numpy arrays of length
            ``batch_size``, sampled uniformly with replacement.
        """
        if self.full:
            # The oldest transition lives at ``self.index``; walk forward from
            # it over the ``size - 1`` slots that already have a successor.
            start, count = self.index, self.size - 1
        else:
            start, count = 0, self.index - 1
        offsets = np.random.choice(count, self.batch_size)
        indices = (start + offsets) % self.size
        next_indices = (indices + 1) % self.size
        return indices, next_indices

    def sample(self):
        """Return a random batch of transitions.

        Returns:
            ``(state, action, reward, state_next, non_terminal)``, where the
            states are combined with ``BaseGraph.batch_graph`` and
            ``non_terminal`` is ``1 - terminal``. Callers presumably use it to
            mask out successors that cross an episode boundary.

        Raises:
            AssertionError: if the buffer is not ready yet.
        """
        assert self.is_ready()
        indices, next_indices = self._sample_indices()
        return (
            BaseGraph.batch_graph(self.states[indices]),
            self.actions[indices],
            self.rewards[indices],
            BaseGraph.batch_graph(self.states[next_indices]),
            1 - self.terminals[indices],
        )
