import random
import numpy as np


class ReplayBuffer:
    """Fixed-capacity ring buffer of ``(state, action, reward, next_state)`` tuples.

    Transitions are appended until ``capacity`` entries are stored; after
    that each new transition overwrites the oldest stored one.
    """

    def __init__(self, capacity=10000, batch_size=64):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept at any time.
            batch_size: Default number of transitions drawn by ``sample``.

        Raises:
            ValueError: If ``capacity`` is smaller than 1; such a buffer
                could never accept a transition.
        """
        if capacity < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity!r}")
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index of the slot the next push writes to
        self.batch_size = batch_size

    def push(self, state, action, reward, next_state):
        """Store one transition, overwriting the oldest once the buffer is full."""
        if len(self.buffer) < self.capacity:
            # Still filling up: grow the list so that ``position`` is valid.
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state)
        # Wrap around so the write index cycles through the slots.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size=None):
        """Draw distinct stored transitions uniformly at random.

        Args:
            batch_size: Number of transitions to draw. Defaults to the
                ``batch_size`` given to the constructor.

        Returns:
            An iterator yielding four tuples, in order: the sampled states,
            actions, rewards and next states.

        Raises:
            ValueError: If more transitions are requested than are stored
                (raised by ``random.sample``).
        """
        if batch_size is None:
            batch_size = self.batch_size
        return zip(*random.sample(self.buffer, batch_size))

def update_q(
    actions,
    t,
    last_observed_state,
    Q_tables,
    learning,
    N,
    gamma=0.95,
    alpha=0.1,
    replay_buffer=None,
):
    """Run one Q-learning update for ``N`` Q-tables at time index ``t``.

    The newly observed states are taken from column ``t`` of ``learning``.
    Without a replay buffer every table is updated directly from the current
    transition; with one, the update is delegated to the replay-based
    variant, which also stores the current transition in the buffer.

    Args:
        actions: Per-table action indices; element ``i`` must support
            ``.astype(int)``.
        t: Column index into ``learning``.
        last_observed_state: Per-table previous state indices; element ``i``
            must support ``.astype(int)``.
        Q_tables: Sequence indexed as ``Q_tables[i][state][action]``;
            modified in place.
        learning: 2-D array indexed as ``learning[i, t]`` (presumably one
            row per table -- TODO confirm against the caller).
        N: Number of tables to update.
        gamma: Discount factor applied to the best next-state value.
        alpha: Learning rate.
        replay_buffer: Optional buffer exposing ``batch_size``, ``buffer``,
            ``sample`` and ``push``. ``None`` disables experience replay.

    Returns:
        None. ``Q_tables`` (and ``replay_buffer``, if given) are updated in
        place.
    """
    states_new_observed = learning[:, t]

    # Compare against None rather than relying on truthiness: a buffer type
    # that defines __len__ is falsy while empty, which would silently skip
    # the replay path and never let the buffer fill up.
    if replay_buffer is not None:
        _update_q_with_buffer(
            N,
            actions,
            last_observed_state,
            states_new_observed,
            learning,
            gamma,
            Q_tables,
            alpha,
            replay_buffer,
            t,
        )
    else:
        _update_q(
            N,
            actions,
            last_observed_state,
            states_new_observed,
            learning,
            gamma,
            Q_tables,
            alpha,
            t,
        )


def _update_q(
    N,
    actions,
    last_observed_state,
    states_new_observed,
    learning,
    gamma,
    Q_tables,
    alpha,
    t,
):
    """Update each of the ``N`` Q-tables in place from the current transition.

    For table ``i`` the entry at (previous state, action) is moved towards
    ``learning[i, t] + gamma * max(Q_tables[i][new state])`` by a fraction
    ``alpha`` of the difference.

    Args:
        N: Number of tables to update.
        actions: Per-table action indices; element ``i`` is cast with
            ``.astype(int)``.
        last_observed_state: Per-table previous state indices, cast the
            same way.
        states_new_observed: Per-table new state indices, cast the same way.
        learning: 2-D array; ``learning[i, t]`` is used as the reward term.
        gamma: Discount factor.
        Q_tables: Sequence indexed as ``Q_tables[i][state][action]``.
        alpha: Learning rate.
        t: Column index into ``learning``.
    """
    for agent in range(N):
        table = Q_tables[agent]
        chosen = actions[agent].astype(int)
        prev_state = last_observed_state[agent].astype(int)
        new_state = states_new_observed[agent].astype(int)

        # Target: reward term plus the discounted best value of the new state.
        td_target = learning[agent, t] + gamma * table[new_state].max()

        # Row of the previous state; the chosen entry is adjusted in place.
        row = table[prev_state]
        row[chosen] += alpha * (td_target - row[chosen])


def _update_q_with_buffer(
    N,
    actions,
    last_observed_state,
    states_new_observed,
    learning,
    gamma,
    Q_tables,
    alpha,
    replay_buffer,
    t,
):
    """Update the Q-tables from replayed transitions, then store the current one.

    If the buffer holds strictly more than ``replay_buffer.batch_size``
    transitions, a batch is sampled and every sampled transition is applied
    to each of the ``N`` tables with the same update rule as the direct
    path. The current transition is pushed afterwards, so it is never
    replayed within the call that stores it.

    Args:
        N: Number of tables to update.
        actions: Per-table action indices for the current step.
        last_observed_state: Per-table previous state indices.
        states_new_observed: Per-table new state indices.
        learning: 2-D array; column ``t`` is stored as the reward vector.
        gamma: Discount factor.
        Q_tables: Sequence indexed as ``Q_tables[i][state][action]``;
            modified in place.
        alpha: Learning rate.
        replay_buffer: Object exposing ``batch_size``, ``buffer``,
            ``sample(batch_size)`` and ``push(state, action, reward,
            next_state)``.
        t: Column index into ``learning``.
    """
    batch_size = replay_buffer.batch_size
    if len(replay_buffer.buffer) > batch_size:
        states, sampled_actions, rewards, next_states = replay_buffer.sample(batch_size)
        # Each sampled transition holds one entry per table.
        for state_vec, action_vec, reward_vec, next_state_vec in zip(
            states, sampled_actions, rewards, next_states
        ):
            for i in range(N):
                state = state_vec[i].astype(int)
                action = action_vec[i].astype(int)
                next_state = next_state_vec[i].astype(int)
                q_target = reward_vec[i] + gamma * np.max(Q_tables[i][next_state])
                q_predict = Q_tables[i][state][action]
                Q_tables[i][state][action] += alpha * (q_target - q_predict)

    # Store copies of the current experience. The incoming arrays belong to
    # the caller and ``learning[:, t]`` is a view into ``learning``; storing
    # them directly would let any later in-place write silently rewrite
    # transitions that are already in the buffer.
    replay_buffer.push(
        np.copy(last_observed_state),
        np.copy(actions),
        np.copy(learning[:, t]),
        np.copy(states_new_observed),
    )
