import random
import numpy as np
from operator import itemgetter

class ReplayMemory:
    """Fixed-capacity ring buffer of experience records with uniform sampling.

    Items are written in insertion order.  Once ``capacity`` items are stored,
    each new write overwrites the oldest slot.  The sampling methods split the
    sampled records column-wise, so stored items must be indexable records:
    at least 5 fields for ``sample``/``sample_topk`` and 6 for
    ``sample_all_batch``.

    Args:
        capacity: Maximum number of stored items.
        topk: Stored on the instance; not otherwise used inside this class.
    """

    def __init__(self, capacity, topk=0):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index of the next slot to (over)write

        # NOTE(review): nothing in this class fills normal_buffer/top_buffer,
        # so sample_topk draws uniformly from self.buffer unless a caller
        # populates top_buffer directly.
        self.topk = topk
        self.normal_buffer = []
        self.top_buffer = []

    def reward_sort(self, x):
        """Sort key returning field 2 of a record (the reward field)."""
        return x[2]

    def push(self, data, reward=None):
        """Store one item, overwriting the oldest one when the buffer is full.

        ``reward`` is accepted for interface compatibility and is ignored.
        """
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)

        self.buffer[self.position] = data
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, batch):
        """Store every item of ``batch``, as if ``push`` were called on each.

        Writes wrap around the end of the buffer.  When ``batch`` holds more
        than ``capacity`` items, only the last ``capacity`` of them survive,
        and ``position`` ends where sequential pushes would leave it.
        """
        items = list(batch)
        count = len(items)
        if count == 0:
            return

        if count > self.capacity:
            # Earlier items would be overwritten by later ones anyway; skip
            # them but advance the write position as if they were written.
            skip = count - self.capacity
            self.position = (self.position + skip) % self.capacity
            items = items[skip:]
            count = self.capacity

        if len(self.buffer) < self.capacity:
            self.buffer.extend([None] * min(self.capacity - len(self.buffer), count))

        end = self.position + count
        if end <= self.capacity:
            self.buffer[self.position:end] = items
        else:
            # Split the write: fill to the end, then wrap to the front.
            head = self.capacity - self.position
            self.buffer[self.position:] = items[:head]
            self.buffer[:count - head] = items[head:]
        self.position = end % self.capacity

    def sample(self, batch_size, printer=False):
        """Sample up to ``batch_size`` stored items uniformly, without replacement.

        Args:
            batch_size: Requested number of items; clamped to ``len(self)``.
            printer: If true, print the shape and contents of the sampled array.

        Returns:
            A tuple of five object arrays holding fields 0-4 of the sampled items.
        """
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        batch = np.asarray(random.sample(self.buffer, int(batch_size)), dtype=object)
        if printer:
            print(batch.shape)
            print(batch)

        return batch[:, 0], batch[:, 1], batch[:, 2], batch[:, 3], batch[:, 4]

    def sample_topk(self, batch_size, top_frac):
        """Sample a batch mixing ``top_buffer`` items with uniform buffer draws.

        ``ceil(top_frac * batch_size)`` items are requested from ``top_buffer``.
        Any shortfall, plus the remainder of the batch, is drawn from
        ``self.buffer``.  Each draw is without replacement within its source.

        Returns:
            A tuple of five object arrays holding fields 0-4 of the sampled
            items; the top-buffer items come first.
        """
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)

        topk_samples = int(np.ceil(top_frac * batch_size))
        normal_samples = int(batch_size - topk_samples)

        # Move whatever top_buffer cannot supply over to the uniform draw.
        if topk_samples > len(self.top_buffer):
            normal_samples += (topk_samples - len(self.top_buffer))
            topk_samples = len(self.top_buffer)

        if topk_samples == 0:
            batch = np.asarray(random.sample(self.buffer, normal_samples), dtype=object)
        elif normal_samples == 0:
            batch = np.asarray(random.sample(self.top_buffer, topk_samples), dtype=object)
        else:
            batch_topk = np.asarray(random.sample(self.top_buffer, topk_samples), dtype=object)
            batch_normal = np.asarray(random.sample(self.buffer, normal_samples), dtype=object)
            batch = np.concatenate((batch_topk, batch_normal), axis=0)

        return batch[:, 0], batch[:, 1], batch[:, 2], batch[:, 3], batch[:, 4]

    def sample_all_batch(self, batch_size):
        """Sample ``batch_size`` stored items uniformly, with replacement.

        Returns:
            A tuple of six object arrays holding fields 0-5 of the sampled items.
        """
        idxes = np.random.randint(0, len(self.buffer), batch_size)
        # A list comprehension always yields a list of records.  operator.itemgetter
        # returns a bare item for one index and raises for zero indices.
        batch = np.asarray([self.buffer[i] for i in idxes], dtype=object)

        return batch[:, 0], batch[:, 1], batch[:, 2], batch[:, 3], batch[:, 4], batch[:, 5] #raw_state

    def return_all(self):
        """Return the underlying storage list (not a copy)."""
        return self.buffer

    def __len__(self):
        return len(self.buffer)


class ReplayMemoryPER:
    """Prioritized experience replay buffer backed by a sum tree.

    Experiences are stored in a fixed-size ring buffer.  Each slot's raw
    priority is transformed by ``adjust_priority`` and written to the matching
    leaf of a sum tree, so ``sample`` draws slots with probability
    proportional to their adjusted priority.

    Args:
        capacity: Number of slots in the ring buffer (and leaves in the tree).
    """

    def __init__(self, capacity):
        self.length = 0  # number of slots written so far, capped at size
        self.size = capacity
        self.curr_write_idx = 0  # next slot to (over)write
        # Saturates at size - 1; sample() no longer uses it, retained for compatibility.
        self.available_samples = 0
        self.buffer = [None] * self.size
        self.base_node, self.leaf_nodes = create_tree([0] * self.size)
        # Field positions inside each stored experience.
        self.state_idx = 0
        self.action_idx = 1
        self.reward_idx = 2
        self.next_state_idx = 3
        self.done_idx = 4
        self.beta = 0.4  # exponent of the importance-sampling weights
        self.alpha = 0.6  # exponent applied to priorities
        self.min_priority = 0.01  # offset added to every raw priority

    def push(self, experience: tuple, priority: float):
        """Write ``experience`` at the current slot with raw ``priority``.

        When the buffer is full, the oldest slot is overwritten.
        """
        self.buffer[self.curr_write_idx] = experience
        self.update(self.curr_write_idx, priority)
        self.curr_write_idx += 1
        if self.curr_write_idx > self.length:
            self.length = self.curr_write_idx
        # wrap the write position once the end of the buffer is reached
        if self.curr_write_idx >= self.size:
            self.length = self.size
            self.curr_write_idx = 0
        # available_samples saturates at size - 1
        if self.available_samples + 1 < self.size:
            self.available_samples += 1
        else:
            self.available_samples = self.size - 1

    def push_batch(self, batch):
        """Push every non-``None`` experience of ``batch``.

        The raw priority of each experience is read from its second-to-last field.
        """
        for experience in batch:
            # `is None`, not `== None`: if an experience is itself a numpy array
            # (e.g. a row of an object array), `==` compares element-wise and
            # raises in a boolean context.
            if experience is None:
                continue
            priority = experience[-2]
            self.push(experience, priority)

    def update(self, idx: int, priority: float):
        """Set the tree leaf of slot ``idx`` to the adjusted ``priority``."""
        update(self.leaf_nodes[idx], self.adjust_priority(priority))

    def adjust_priority(self, priority: float):
        """Return ``(priority + min_priority) ** alpha``."""
        return np.power(priority + self.min_priority, self.alpha)

    def sample(self, num_samples: int):
        """Draw ``num_samples`` slots (with replacement) proportionally to priority.

        Returns:
            ``(states, actions, rewards, next_states, done, sampled_idxs,
            is_weights)``.  ``rewards``, ``done`` and ``is_weights`` are numpy
            arrays; the others are lists.  ``is_weights`` are the
            importance-sampling weights ``(N * P(i)) ** -beta``, normalised so
            that their maximum is 1.

        Raises:
            ValueError: if nothing has been stored yet or the total priority
                is not positive.
        """
        total = self.base_node.value
        if self.length == 0 or not total > 0:
            raise ValueError("cannot sample: memory is empty or has zero total priority")

        sampled_idxs = []
        probs = []
        while len(sampled_idxs) < num_samples:
            sample_val = np.random.uniform(0, total)
            samp_node = retrieve(sample_val, self.base_node)
            # Leaves at idx >= length were never written and hold value 0.  They
            # can only be reached through floating-point drift in the tree sums,
            # so reject them and draw again.
            if samp_node.idx < self.length:
                sampled_idxs.append(samp_node.idx)
                probs.append(samp_node.value / total)

        # Apply beta and normalise so the largest weight is 1.  The constant
        # N (= self.length) cancels in the normalisation.
        is_weights = np.power(self.length * np.array(probs, dtype=float), -self.beta)
        if is_weights.size:
            is_weights = is_weights / np.max(is_weights)

        experiences = [self.buffer[idx] for idx in sampled_idxs]
        states = [exp[self.state_idx] for exp in experiences]
        actions = [exp[self.action_idx] for exp in experiences]
        rewards = [exp[self.reward_idx] for exp in experiences]
        next_states = [exp[self.next_state_idx] for exp in experiences]
        done = [exp[self.done_idx] for exp in experiences]
        return states, actions, np.array(rewards), next_states, np.array(done), sampled_idxs, is_weights

    def sample_all_batch(self, batch_size):
        """Sample ``batch_size`` written slots uniformly, with replacement.

        Returns:
            A tuple of six object arrays holding fields 0-5 of the sampled experiences.
        """
        idxes = np.random.randint(0, self.length, batch_size)
        # A list comprehension always yields a list of records.  operator.itemgetter
        # returns a bare item for one index and raises for zero indices.
        batch = np.asarray([self.buffer[i] for i in idxes], dtype=object)
        return batch[:, 0], batch[:, 1], batch[:, 2], batch[:, 3], batch[:, 4], batch[:, 5] #raw_state

    def __len__(self):
        return self.length


class Node:
    """Node of a binary sum tree.

    Leaves carry a value and their position ``idx``.  Internal nodes carry the
    sum of their two children's values.  Constructing a node links its
    children back to it through their ``parent`` attribute.
    """

    def __init__(self, left, right, is_leaf: bool = False, idx = None):
        self.left = left
        self.right = right
        self.is_leaf = is_leaf
        # Internal nodes cache the sum of their children.  Leaves start at 0
        # (create_leaf assigns the real value), so a leaf built directly still
        # has a readable `value`.
        self.value = 0 if self.is_leaf else self.left.value + self.right.value
        self.parent = None
        self.idx = idx  # this value is only set for leaf nodes
        if left is not None:
            left.parent = self
        if right is not None:
            right.parent = self

    @classmethod
    def create_leaf(cls, value, idx):
        """Return a leaf node holding ``value`` at position ``idx``."""
        leaf = cls(None, None, is_leaf=True, idx=idx)
        leaf.value = value
        return leaf

def create_tree(input: list):
    """Build a sum tree over ``input`` and return ``(root, leaves)``.

    ``leaves[i]`` is the leaf holding ``input[i]`` (with ``idx == i``).  Nodes
    are paired left to right, level by level.  When a level has an odd count,
    the last node is carried up unchanged, so every leaf stays connected to
    the root for any input length, not only powers of two.

    Raises:
        ValueError: if ``input`` is empty.
    """
    leaf_nodes = [Node.create_leaf(v, i) for i, v in enumerate(input)]
    if not leaf_nodes:
        raise ValueError("create_tree needs at least one leaf value")

    nodes = leaf_nodes
    while len(nodes) > 1:
        pairs = iter(nodes)
        parents = [Node(left, right) for left, right in zip(pairs, pairs)]
        if len(nodes) % 2:
            # zip() drops the unpaired last node; promote it instead so its
            # subtree still contributes to the root sum and has a parent chain.
            parents.append(nodes[-1])
        nodes = parents

    return nodes[0], leaf_nodes


def retrieve(value: float, node: Node):
    """Walk down the sum tree from ``node`` and return the leaf that covers ``value``.

    At each internal node the walk goes left when the left subtree's total is
    at least the remaining value.  Otherwise the left total is subtracted and
    the walk continues on the right.
    """
    current = node
    remaining = value
    while not current.is_leaf:
        left_child = current.left
        if left_child.value >= remaining:
            current = left_child
        else:
            remaining -= left_child.value
            current = current.right
    return current


def update(node: Node, new_value: float):
    """Set ``node.value`` to ``new_value`` and add the difference to every ancestor.

    A node without a parent (for example the root of a single-leaf tree) is
    updated in place only.
    """
    change = new_value - node.value

    node.value = new_value
    # propagate_changes dereferences its node argument, so skip it for the root.
    if node.parent is not None:
        propagate_changes(change, node.parent)


def propagate_changes(change: float, node: "Node | None"):
    """Add ``change`` to ``node`` and to each of its ancestors up to the root.

    Passing ``None`` is a no-op, so callers may pass ``leaf.parent`` without
    first checking whether the leaf is the root.
    """
    while node is not None:
        node.value += change
        node = node.parent
