from __future__ import annotations

import gzip
import pickle
import random
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch_geometric

from rl.agents.segment_tree import MinSegmentTree, SumSegmentTree


class BipartiteNodeData(torch_geometric.data.Data):

    """
    Store node observation.

    Bipartite graph of one B&B node: ``edge_index`` row 0 indexes constraint
    nodes and row 1 indexes variable nodes (see ``__inc__``).

    All constructor arguments default to ``None`` so the class can be
    instantiated without arguments, as PyG recommends for custom ``Data``
    subclasses used in mini-batching.
    """

    def __init__(
        self,
        constraint_features=None,
        edge_indices=None,
        edge_features=None,
        variable_features=None,
        candidates=None,
        action=None,
    ):
        super().__init__()
        self.constraint_features = constraint_features
        self.edge_index = edge_indices
        self.edge_attr = edge_features
        self.variable_features = variable_features
        self.nb_variables = variable_features.size(0) if variable_features is not None else 0
        # Guard on constraint_features itself (the check previously tested
        # variable_features and crashed when only constraint_features was None).
        self.nb_constraints = constraint_features.size(0) if constraint_features is not None else 0
        self.candidates = candidates
        self.nb_candidates = candidates.size(0) if candidates is not None else 0
        self.actions = action if action is not None else 0

    def __inc__(self, key, value, store, *args, **kwargs):
        """Return the offset added to ``key`` when graphs are concatenated into a batch.

        ``edge_index`` rows are shifted by the number of constraints (row 0) and
        variables (row 1). ``candidates`` and ``actions`` are treated as variable
        indices and shifted by the variable count. Every other key falls back to
        the default PyG rule.
        """
        if key == "edge_index":
            return torch.tensor([[self.constraint_features.size(0)], [self.variable_features.size(0)]])
        elif key == "candidates" or key == "actions":
            return self.variable_features.size(0)
        else:
            # Forward ``store`` too so the parent receives the full call signature.
            return super().__inc__(key, value, store, *args, **kwargs)


class GraphDataset(torch_geometric.data.Dataset):

    """
    Dataset of B&B node observations stored as gzip-compressed pickle files.

    Each item is the bipartite graph observation (``BipartiteNodeData``) built
    from the first element of the 4-tuple pickled in one sample file; the other
    three elements are ignored here.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files
        # Indices of sample files that failed to load (EOFError), for diagnostics.
        self.idx_problems = []

    def len(self):
        return len(self.sample_files)

    def get(self, index):
        """Load sample ``index`` and convert its observation into a graph.

        A file raising ``EOFError`` while being read (e.g. truncated) is reported,
        recorded in ``idx_problems`` and skipped. The next file is tried instead,
        wrapping around to the start of the list. Each file is tried at most once.

        Raises:
            IndexError: if ``index`` is outside the range of ``sample_files``.
            RuntimeError: if no sample file can be loaded.
        """
        n_files = len(self.sample_files)
        if not -n_files <= index < n_files:
            raise IndexError(f"Sample index {index} out of range for {n_files} sample files.")

        # Bounded retry instead of recursing on index + 1: the old recursion ran
        # past the end of the list (IndexError) and never terminated if every file was broken.
        for offset in range(n_files):
            file_idx = (index + offset) % n_files
            try:
                # NOTE: pickle.load can execute arbitrary code; only load trusted sample files.
                with gzip.open(self.sample_files[file_idx], "rb") as f:
                    sample = pickle.load(f)
                break
            except EOFError:
                print(f"Sample {self.sample_files[file_idx]} is raising an issue.")
                self.idx_problems += [file_idx]
        else:
            raise RuntimeError("None of the sample files could be loaded.")

        sample_observation, _, _, _ = sample

        constraint_features = sample_observation.row_features
        edge_indices = sample_observation.edge_features.indices.astype(np.int32)
        # Add a trailing feature axis so each edge carries a 1-element feature vector.
        edge_features = np.expand_dims(sample_observation.edge_features.values, axis=-1)
        variable_features = sample_observation.variable_features  # nan can be there

        graph = BipartiteNodeData(
            torch.FloatTensor(constraint_features),
            torch.LongTensor(edge_indices),
            torch.FloatTensor(edge_features),
            torch.FloatTensor(variable_features),
        )

        # Total node count of the bipartite graph (constraints + variables).
        graph.num_nodes = constraint_features.shape[0] + variable_features.shape[0]

        return graph


class ReplayBuffer:
    """A simple numpy replay buffer.

    Transitions live in pre-allocated numpy arrays of length ``size`` that are
    written as a ring buffer. ``ptr`` is the next slot to write and wraps back
    to 0 once ``max_size`` is reached; ``size`` counts the filled slots.

    Args:
        size: Capacity (maximum number of stored transitions).
        batch_size: Number of transitions drawn by ``sample_batch`` when no
            explicit indices are given.
        n_step: N-step horizon. When ``> 1``, ``sample_batch`` also pads the
            stored N-step next states.
    """

    def __init__(self, size: int, batch_size: int = 32, n_step: int = 1):
        self.max_size, self.batch_size = size, batch_size
        self._allocate_buffers()
        # Padding state for variable-length children lists; built lazily from
        # the state stored in slot 0 (see ``initialize_fill_state``).
        self.fill_state_initialized = False

        # for N-step Learning
        self.n_step = n_step

    def _allocate_buffers(self):
        """Create fresh storage arrays of length ``max_size`` and rewind the pointers."""
        size = self.max_size
        self.state_buf = np.empty(size, dtype=object)
        self.next_state_buf = np.empty(size, dtype=object)
        # int64 rather than int16: action values (presumably variable indices)
        # above 32767 would not fit in int16 on large instances.
        self.action_buf = np.zeros([size], dtype=np.int64)
        self.action_idx_buf = np.zeros([size], dtype=np.int64)
        self.reward_buf = np.zeros([size], dtype=np.float32)
        # np.empty: contents are arbitrary until ``store`` writes the slot.
        self.done_buf = np.empty(size, dtype=np.bool_)
        self.n_step_next_state_buf = np.empty(size, dtype=object)
        self.n_step_reward_buf = np.zeros([size], dtype=np.float32)
        self.n_step_done_buf = np.empty(size, dtype=np.bool_)
        self.depth_buf = np.zeros([size], dtype=np.int16)
        self.ptr, self.size = 0, 0

    def _transition_at(self, idx):
        """Return the transition stored in slot ``idx`` as a dict."""
        return {
            "state": self.state_buf[idx],
            "action": self.action_buf[idx],
            "action_idx": self.action_idx_buf[idx],
            "reward": self.reward_buf[idx],
            "next_state": self.next_state_buf[idx],
            "done": self.done_buf[idx],
            "n_step_next_state": self.n_step_next_state_buf[idx],
            "n_step_reward": self.n_step_reward_buf[idx],
            "n_step_done": self.n_step_done_buf[idx],
            "depth": self.depth_buf[idx],
        }

    def __iter__(self):
        """Iterate over the filled slots in storage order (0 .. size-1)."""
        self._iter_idx = 0
        return self

    def __next__(self):
        if self._iter_idx >= self.size:
            raise StopIteration
        transition = self._transition_at(self._iter_idx)
        self._iter_idx += 1
        return transition

    def reset(self):
        """Discard all stored transitions by re-allocating storage; return ``self``."""
        self._allocate_buffers()
        return self

    def store(
        self,
        state: Tuple[np.ndarray],
        action: int,
        action_idx: int,
        reward: float,
        next_state: object,
        done: bool,
        depth: int,
        n_step_next_state: object = None,
        n_step_reward: float = None,
        n_step_done: bool = None,
        **kwargs,
    ):
        """Write one transition at ``ptr`` (overwriting the oldest once full).

        The ``n_step_*`` values default to ``None``; numpy coerces those into the
        float32/bool buffers. Extra keyword arguments are accepted and ignored.
        """
        self.state_buf[self.ptr] = state
        self.next_state_buf[self.ptr] = next_state
        self.action_buf[self.ptr] = action
        self.action_idx_buf[self.ptr] = action_idx
        self.reward_buf[self.ptr] = reward
        self.done_buf[self.ptr] = done
        self.n_step_next_state_buf[self.ptr] = n_step_next_state
        self.n_step_reward_buf[self.ptr] = n_step_reward
        self.n_step_done_buf[self.ptr] = n_step_done
        self.depth_buf[self.ptr] = depth
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self, indices: np.ndarray = None) -> Dict[str, np.ndarray]:
        """Gather a batch of transitions.

        Args:
            indices: Slots to read. When ``None``, ``batch_size`` distinct slots
                are drawn uniformly from the filled part of the buffer.

        Returns:
            Dict of per-field arrays read at ``indices``. If next states are
            stored as lists of children, ``next_state`` is padded by
            ``_preprocess_next_states`` and ``children_mask`` is filled in. When
            ``n_step > 1``, ``n_step_next_state`` and ``n_step_children_mask`` are
            treated the same way. Otherwise the masks are ``None``.
            ``weights`` is all ones (float), since sampling is uniform.
        """
        if indices is None:
            assert len(self) >= self.batch_size
            indices = np.random.choice(self.size, size=self.batch_size, replace=False)

        next_state = self.next_state_buf[indices]
        n_step_next_state = self.n_step_next_state_buf[indices]
        children_mask, n_step_children_mask = None, None
        # Float weights, matching the importance weights of prioritized sampling.
        weights = np.ones_like(indices, dtype=np.float64)

        if isinstance(next_state[0], list):
            next_state, children_mask = self._preprocess_next_states(next_state)
            if self.n_step > 1:
                n_step_next_state, n_step_children_mask = self._preprocess_next_states(n_step_next_state)

        return {
            "state": self.state_buf[indices],
            "next_state": next_state,
            "action": self.action_buf[indices],
            "action_idx": self.action_idx_buf[indices],
            "reward": self.reward_buf[indices],
            "done": self.done_buf[indices],
            "depth": self.depth_buf[indices],
            "n_step_next_state": n_step_next_state,
            "n_step_reward": self.n_step_reward_buf[indices],
            "n_step_done": self.n_step_done_buf[indices],
            "weights": weights,
            "indices": indices,
            "children_mask": children_mask,
            "n_step_children_mask": n_step_children_mask,
        }

    def __len__(self) -> int:
        return self.size

    def _preprocess_next_states(self, next_states: np.ndarray) -> Tuple[List[np.ndarray], np.ndarray]:
        """Pad each sample's children list to a common length and regroup by child position.

        Args:
            next_states: Length-``batch_size`` sequence whose items are lists of
                child states, possibly of different lengths. It is not modified.

        Returns:
            ``(columns, mask)``. ``columns[k]`` is an object array holding the
            k-th child of every sample, with missing children replaced by
            ``self.fill_state``. ``mask`` has shape
            ``(max_children, batch_size)`` and is 0 at padded positions.
            Returns ``(None, None)`` when no sample has any child.
        """
        batch_size = len(next_states)
        max_children_number = max(len(children) for children in next_states)
        if max_children_number == 0:
            return None, None

        if not self.fill_state_initialized:
            self.initialize_fill_state()

        children_mask = np.ones((max_children_number, batch_size))
        padded = []
        for col, children in enumerate(next_states):
            missing = max_children_number - len(children)
            padded.append(list(children) + [self.fill_state] * missing)
            children_mask[len(children) :, col] = 0

        # NOTE(review): numpy infers the depth of each object array from the shapes
        # of the tuple members; construction kept as before for downstream compatibility.
        preprocessed_next_states = [
            np.array([row[k] for row in padded], dtype=object) for k in range(max_children_number)
        ]
        return preprocessed_next_states, children_mask

    def initialize_fill_state(self):
        """Build the dummy state used to pad missing children.

        The template is the state stored in slot 0, with its constraint and
        variable features zeroed. Its graph structure and action set are reused.
        """
        constraint_features, edge_indices, edge_features, variable_features, action_set = self.state_buf[0]
        # Multiply out of place: the previous in-place ``*= 0`` zeroed the features
        # of the transition actually stored in slot 0, corrupting replay data.
        fill_tuple_state = (
            constraint_features * 0,
            edge_indices,
            edge_features,
            variable_features * 0,
            action_set,
        )
        self.fill_state_initialized = True
        self.fill_state = fill_tuple_state


class PrioritizedReplayBuffer(ReplayBuffer):
    """
    Replay buffer that samples slots proportionally to ``priority ** alpha``.

    Attributes:
        max_priority (float): largest priority seen so far; assigned to new transitions
        tree_ptr (int): next index of tree (advanced in step with each ``store``)
        alpha (float): alpha parameter for prioritized replay buffer
        sum_tree (SumSegmentTree): sum tree for prior
        min_tree (MinSegmentTree): min tree for min prior to get max weight
    """

    def __init__(
        self,
        size: int,
        batch_size: int = 32,
        alpha: float = 0.6,
        n_step: int = 1,
    ):
        """Initialization."""
        assert alpha >= 0

        super(PrioritizedReplayBuffer, self).__init__(size, batch_size, n_step)
        self.alpha = alpha

        # capacity must be positive and a power of 2.
        tree_capacity = 1
        while tree_capacity < self.max_size:
            tree_capacity *= 2
        self._tree_capacity = tree_capacity

        self._reset_priorities()

    def _reset_priorities(self):
        """(Re)create empty priority trees and restore the initial max priority and tree pointer."""
        self.max_priority, self.tree_ptr = 1.0, 0
        self.sum_tree = SumSegmentTree(self._tree_capacity)
        self.min_tree = MinSegmentTree(self._tree_capacity)

    def reset(self):
        """Discard all transitions and their priorities; return ``self``.

        The inherited ``reset`` only cleared the storage arrays. Stale priorities
        and ``tree_ptr`` stayed behind and pointed at empty slots.
        """
        super().reset()
        self._reset_priorities()
        return self

    def store(
        self,
        state: Tuple[np.ndarray],
        action: int,
        action_idx: int,
        reward: float,
        next_state: object,
        done: bool,
        depth: int,
        n_step_next_state: object = None,  # n_step
        n_step_reward: float = None,  # n_step
        n_step_done: bool = None,  # n_step
        **kwargs,
    ):
        """Store experience and give it the current max priority."""
        super().store(
            state=state,
            action=action,
            action_idx=action_idx,
            reward=reward,
            next_state=next_state,
            done=done,
            depth=depth,
            n_step_next_state=n_step_next_state,
            n_step_reward=n_step_reward,
            n_step_done=n_step_done,
            **kwargs,
        )

        priority = self.max_priority**self.alpha
        self.sum_tree[self.tree_ptr] = priority
        self.min_tree[self.tree_ptr] = priority
        self.tree_ptr = (self.tree_ptr + 1) % self.max_size

    def sample_batch(self, indices: np.ndarray = None, beta: float = 0.4) -> Dict[str, np.ndarray]:
        """Sample a batch of experiences with importance-sampling weights.

        Args:
            indices: Slots to read; drawn proportionally to priority when ``None``.
            beta: Exponent of the importance-sampling correction (must be > 0).

        Returns:
            The same dict as ``ReplayBuffer.sample_batch``, with ``weights``
            replaced by the normalised importance-sampling weights.
        """
        assert len(self) >= self.batch_size
        assert beta > 0

        indices = self._sample_proportional() if indices is None else indices
        batch = super().sample_batch(indices)
        batch["weights"] = self._calculate_weights(indices, beta)
        return batch

    def update_priorities(self, indices: List[int], priorities: np.ndarray):
        """Update priorities of sampled transitions."""
        assert len(indices) == len(priorities)

        for idx, priority in zip(indices, priorities):
            assert priority > 0
            assert 0 <= idx < len(self)

            self.sum_tree[idx] = priority**self.alpha
            self.min_tree[idx] = priority**self.alpha

            self.max_priority = max(self.max_priority, priority)

    def _sample_proportional(self) -> List[int]:
        """Sample indices based on proportions (one draw per equal-mass segment)."""
        indices = []
        # Use the whole-tree sum, which is also the normaliser in the weight
        # computation. Unwritten leaves hold no priority, so this equals the
        # total over stored slots. The previous ``sum(0, len(self) - 1)``
        # depended on the end-inclusivity of the segment tree and, with an
        # exclusive end, left the last slot out.
        p_total = self.sum_tree.sum()
        segment = p_total / self.batch_size

        for i in range(self.batch_size):
            a = segment * i
            b = segment * (i + 1)
            upperbound = random.uniform(a, b)
            idx = self.sum_tree.retrieve(upperbound)
            indices.append(idx)

        return indices

    def _calculate_weights(self, indices, beta: float) -> np.ndarray:
        """Importance-sampling weights for ``indices``, normalised by the maximum weight.

        Equivalent to ``_calculate_weight`` applied per index. The tree-wide min
        and sum are computed once per batch instead of once per index.
        """
        total = self.sum_tree.sum()
        n = len(self)
        max_weight = (self.min_tree.min() / total * n) ** (-beta)
        return np.array([((self.sum_tree[i] / total) * n) ** (-beta) / max_weight for i in indices])

    def _calculate_weight(self, idx: int, beta: float):
        """Calculate the weight of the experience at idx."""
        # get max weight
        p_min = self.min_tree.min() / self.sum_tree.sum()
        max_weight = (p_min * len(self)) ** (-beta)

        # calculate weights
        p_sample = self.sum_tree[idx] / self.sum_tree.sum()
        weight = (p_sample * len(self)) ** (-beta)
        weight = weight / max_weight

        return weight
