from typing import Callable, Iterable, NamedTuple, Optional, Tuple

import torch


class ReplayData(NamedTuple):
    """Log weights and samples generated by annealed importance sampling.

    This is the storage record used by `PrioritisedReplayBuffer`: each field is a
    pre-allocated tensor whose leading axis indexes the buffer slot.
    """

    # Samples; allocated with shape (max_length, dim) by `PrioritisedReplayBuffer.initialize`.
    x: torch.Tensor
    # Log weight of each slot; allocated with shape (max_length,). Used as the sampling
    # logits when prioritised sampling is enabled and shifted in place by `adjust`.
    log_w: torch.Tensor
    # `log q` value recorded for each slot; allocated with shape (max_length,) and
    # overwritten by `adjust`.
    log_q_old: torch.Tensor


class SimpleReplayData(NamedTuple):
    """Samples and their energies, as stored by `SimpleBuffer`.

    Each field is a pre-allocated tensor whose leading axis indexes the buffer slot.
    """

    # Samples; allocated with shape (max_length, dim) by `SimpleBuffer.initialize`.
    x: torch.Tensor
    # One scalar per slot; allocated with shape (max_length,). Used directly as the
    # sampling logits when prioritised sampling is enabled.
    # NOTE(review): presumably the energy of the matching row of `x` - confirm with callers.
    energy: torch.Tensor


def sample_without_replacement(logits: torch.Tensor, n: int) -> torch.Tensor:
    """Pick `n` distinct indices into `logits` using Gumbel-perturbed top-k selection.

    Standard Gumbel noise is added to the logits and the `n` largest perturbed
    entries are kept, so entries with larger logits are more likely to be chosen.
    The selected indices are shuffled before being returned.

    Args:
        logits: unnormalised log-probabilities.
            NOTE(review): written for 1-D input - top-k acts on the last axis while
            the final shuffle indexes the first axis; confirm no caller passes rank > 1.
        n: number of indices to draw; `torch.topk` requires it to be at most the
            number of entries along the last axis.

    Returns:
        Tensor of `n` selected indices on the same device as `logits`.
    """
    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
    standard_gumbel = torch.distributions.Gumbel(torch.tensor(0.0), torch.tensor(1.0))
    noise = standard_gumbel.sample(logits.shape).to(logits.device)
    perturbed = noise + logits
    chosen = torch.topk(perturbed, n, sorted=False).indices
    # Shuffle so the order of the returned indices carries no information.
    shuffle = torch.randperm(n).to(chosen.device)
    return chosen[shuffle]


class PrioritisedReplayBuffer:
    def __init__(
        self,
        dim: int,
        max_length: int,
        min_sample_length: int,
        initial_sampler: Callable[[], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        device: str = "cpu",
        sample_with_replacement: bool = False,
        fill_buffer_during_init: bool = True,
        prioritize=False,
    ):
        """
        Create prioritised replay buffer for batched sampling and adding of data.
        Args:
            dim: dimension of x data
            max_length: maximum length of the buffer
            min_sample_length: minimum length of buffer required for sampling
            initial_sampler: sampler producing x, log_w and log q, used to fill the buffer up to
                the min sample length. The initialised flow + AIS may be used here,
                or we may desire to use AIS with more distributions to give the flow a "good start".
            device: replay buffer device
            sample_with_replacement: Whether to sample from the buffer with replacement.
            fill_buffer_during_init: Whether to use `initial_sampler` to fill the buffer initially.
                If a checkpoint is going to be loaded then this should be set to False.

        The `max_length` and `min_sample_length` should be sufficiently long to prevent overfitting
        to the replay data. For example, if `min_sample_length` is equal to the
        sampling batch size, then we may overfit to the first batch of data, as we would update
        on it many times during the start of training.
        """
        assert min_sample_length < max_length
        self.dim = dim
        self.max_length = max_length
        self.min_sample_length = min_sample_length

        self.is_full = False  # whether the buffer is full
        self.can_sample = False  # whether the buffer is full enough to begin sampling
        self.sample_with_replacement = sample_with_replacement
        self.prioritize = prioritize
        self.fill_buffer_during_init = fill_buffer_during_init
        self.initial_sampler = initial_sampler
        self.initizalied = False

    def initialize(self, device):
        if not self.initizalied:
            self.buffer = ReplayData(
                x=torch.zeros(self.max_length, self.dim).to(device),
                log_w=torch.zeros(
                    self.max_length,
                ).to(device),
                log_q_old=torch.zeros(
                    self.max_length,
                ).to(device),
            )
            self.possible_indices = torch.arange(self.max_length).to(device)
            self.device = device
            self.current_index = 0
            self.sample_with_replacement = self.sample_with_replacement
            self.prioritize = self.prioritize

            if self.fill_buffer_during_init:
                while self.can_sample is False:
                    # fill buffer up minimum length
                    x, log_w, log_q_old = self.initial_sampler()
                    self.add(x, log_w, log_q_old)
            else:
                print("Buffer not initialised, expected that checkpoint will be loaded.")


    @torch.no_grad()
    def add(self, x: torch.Tensor, log_w: torch.Tensor, log_q_old: torch.Tensor) -> None:
        """Add a new batch of generated data to the replay buffer."""
        batch_size = x.shape[0]
        x = x.to(self.device)
        log_w = log_w.to(self.device)
        log_q_old = log_q_old.to(self.device)
        indices = (torch.arange(batch_size) + self.current_index).to(self.device) % self.max_length
        self.buffer.x[indices] = x
        self.buffer.log_w[indices] = log_w
        self.buffer.log_q_old[indices] = log_q_old
        new_index = self.current_index + batch_size
        if not self.is_full:
            self.is_full = new_index >= self.max_length
            self.can_sample = new_index >= self.min_sample_length
        self.current_index = new_index % self.max_length

    @torch.no_grad()
    def sample(
        self, batch_size: int, prioritize: Optional[bool] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return a batch of sampled data, if the batch size is specified then the batch will have
        a leading axis of length batch_size, otherwise the default self.batch_size will be used."""
        if not self.can_sample:
            raise Exception("Buffer must be at minimum length before calling sample")

        if prioritize is None:
            prioritize = self.prioritize

        max_index = self.max_length if self.is_full else self.current_index
        if self.sample_with_replacement:
            if prioritize:
                indices = torch.distributions.Categorical(
                    logits=self.buffer.log_w[:max_index]
                ).sample((batch_size,))
            else:
                indices = torch.randint(max_index, (batch_size,)).to(self.device)
        else:
            if prioritize:
                indices = sample_without_replacement(self.buffer.log_w[:max_index], batch_size).to(
                    self.device
                )
            else:
                indices = torch.randperm(max_index)[:batch_size].to(self.device)
        x, log_w, log_q_old, indices = (
            self.buffer.x[indices],
            self.buffer.log_w[indices],
            self.buffer.log_q_old[indices],
            indices,
        )
        return x, log_w, log_q_old, indices

    def sample_n_batches(
        self, batch_size: int, n_batches: int
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Returns a list of batches."""
        x, log_w, log_q_old, indices = self.sample(batch_size * n_batches)
        x_batches = torch.chunk(x, n_batches)
        log_w_batches = torch.chunk(log_w, n_batches)
        log_q_old_batches = torch.chunk(log_q_old, n_batches)
        indices_batches = torch.chunk(indices, n_batches)
        dataset = [
            (x, log_w, log_q_old, indxs)
            for x, log_w, log_q_old, indxs in zip(
                x_batches, log_w_batches, log_q_old_batches, indices_batches
            )
        ]
        return dataset

    @torch.no_grad()
    def adjust(self, log_w_adjustment, log_q, indices):
        """Adjust log weights and log q to match new value of theta, this is typically performed
        over minibatches, rather than over the whole dataset at once."""
        valid_adjustment = torch.isfinite(log_w_adjustment) & torch.isfinite(log_q)
        log_w_adjustment, log_q, valid_indices = (
            log_w_adjustment[valid_adjustment],
            log_q[valid_adjustment],
            indices[valid_adjustment],
        )
        valid_indices = valid_indices.to(self.device)
        self.buffer.log_w[valid_indices] += log_w_adjustment.to(self.device)
        self.buffer.log_q_old[valid_indices] = log_q.to(self.device)

        # Kill samples in the buffer for which the `log_w_adjustment` is invalid.
        # A common reason this can occur is if AIS discovers a point far outside the reasonable range of the problem.
        # Which causes nan log probs under the flow.
        invalid_indices = indices[~valid_adjustment].to(self.device)
        self.buffer.log_w[invalid_indices] = -torch.ones_like(
            self.buffer.log_w[invalid_indices]
        ) * (float("inf"))

    def save(self, path):
        """Save buffer to file."""
        to_save = {
            "x": self.buffer.x.detach().cpu(),
            "log_w": self.buffer.log_w.detach().cpu(),
            "log_q_old": self.buffer.log_q_old.detach().cpu(),
            "current_index": self.current_index,
            "is_full": self.is_full,
            "can_sample": self.can_sample,
        }
        torch.save(to_save, path)

    def load(self, path):
        """Load buffer from file."""
        old_buffer = torch.load(path)
        indices = torch.arange(self.max_length)
        self.buffer.x[indices] = old_buffer["x"].to(self.device)
        self.buffer.log_w[indices] = old_buffer["log_w"].to(self.device)
        self.buffer.log_q_old[indices] = old_buffer["log_q_old"].to(self.device)
        self.current_index = old_buffer["current_index"]
        self.is_full = old_buffer["is_full"]
        self.can_sample = old_buffer["can_sample"]


class SimpleBuffer:
    def __init__(
        self,
        dim: int,
        max_length: int,
        min_sample_length: int,
        initial_sampler: Callable[[], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        device: str = "cpu",
        sample_with_replacement: bool = False,
        fill_buffer_during_init: bool = True,
        prioritize=False,
    ):
        """
        Create prioritised replay buffer for batched sampling and adding of data.
        Args:
            dim: dimension of x data
            max_length: maximum length of the buffer
            min_sample_length: minimum length of buffer required for sampling
            initial_sampler: sampler producing x, log_w and log q, used to fill the buffer up to
                the min sample length. The initialised flow + AIS may be used here,
                or we may desire to use AIS with more distributions to give the flow a "good start".
            device: replay buffer device
            sample_with_replacement: Whether to sample from the buffer with replacement.
            fill_buffer_during_init: Whether to use `initial_sampler` to fill the buffer initially.
                If a checkpoint is going to be loaded then this should be set to False.

        The `max_length` and `min_sample_length` should be sufficiently long to prevent overfitting
        to the replay data. For example, if `min_sample_length` is equal to the
        sampling batch size, then we may overfit to the first batch of data, as we would update
        on it many times during the start of training.
        """
        assert min_sample_length < max_length
        self.dim = dim
        self.max_length = max_length
        self.min_sample_length = min_sample_length


        self.is_full = False  # whether the buffer is full
        self.can_sample = False  # whether the buffer is full enough to begin sampling
        self.sample_with_replacement = sample_with_replacement
        self.prioritize = prioritize
        self.fill_buffer_during_init = fill_buffer_during_init
        self.initial_sampler = initial_sampler
        self.initizalied = False


    def initialize(self, device):
        if not self.initizalied:
            self.buffer = SimpleReplayData(
                x=torch.zeros(self.max_length, self.dim).to(device),
                energy=torch.zeros(
                    self.max_length,
                ).to(device),
            )
            self.possible_indices = torch.arange(self.max_length).to(device)
            self.device = device
            self.current_index = 0

            if self.fill_buffer_during_init:
                while self.can_sample is False:
                    # fill buffer up minimum length
                    x, energy = self.initial_sampler()
                    self.add(x, energy)
            else:
                print("Buffer not initialised, expected that checkpoint will be loaded.")

            self.initizalied = True


    def __len__(self):
        if self.is_full:
            return self.max_length
        else:
            return self.current_index

    @torch.no_grad()
    def add(self, x: torch.Tensor, energy: torch.Tensor) -> None:
        """Add a new batch of generated data to the replay buffer."""
        batch_size = x.shape[0]
        x = x.to(self.device)
        energy = energy.to(self.device)
        indices = (torch.arange(batch_size) + self.current_index).to(self.device) % self.max_length
        self.buffer.x[indices] = x
        self.buffer.energy[indices] = energy
        new_index = self.current_index + batch_size
        if not self.is_full:
            self.is_full = new_index >= self.max_length
            self.can_sample = new_index >= self.min_sample_length
        self.current_index = new_index % self.max_length

    def get_last_n_inserted(self, num_to_get: int) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.is_full:
            assert num_to_get <= self.max_length
        else:
            assert num_to_get < self.current_index

        start_idx = self.current_index - num_to_get
        idxs = [torch.arange(max(start_idx, 0), self.current_index)]
        if start_idx < 0:
            idxs.append(torch.arange(self.max_length + start_idx, self.max_length))

        idx = torch.cat(idxs)

        return self.buffer.x[idx], self.buffer.energy[idx]

    @torch.no_grad()
    def sample(
        self, batch_size: int, prioritize: Optional[bool] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return a batch of sampled data, if the batch size is specified then the batch will have
        a leading axis of length batch_size, otherwise the default self.batch_size will be used."""
        if not self.can_sample:
            raise Exception("Buffer must be at minimum length before calling sample")

        if prioritize is None:
            prioritize = self.prioritize

        max_index = self.max_length if self.is_full else self.current_index
        if self.sample_with_replacement:
            if prioritize:
                indices = torch.distributions.Categorical(
                    logits=self.buffer.energy[:max_index]
                ).sample((batch_size,))
            else:
                indices = torch.randint(max_index, (batch_size,)).to(self.device)
        else:
            if prioritize:
                indices = sample_without_replacement(
                    self.buffer.energy[:max_index], batch_size
                ).to(self.device)
            else:
                indices = torch.randperm(max_index)[:batch_size].to(self.device)
        x, energy, indices = (
            self.buffer.x[indices],
            self.buffer.energy[indices],
            indices,
        )
        return x, energy, indices

    def sample_n_batches(
        self, batch_size: int, n_batches: int
    ) -> Iterable[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Returns a list of batches."""
        x, log_w, indices = self.sample(batch_size * n_batches)
        x_batches = torch.chunk(x, n_batches)
        log_w_batches = torch.chunk(log_w, n_batches)
        indices_batches = torch.chunk(indices, n_batches)
        dataset = [
            (x, log_w, indxs) for x, log_w, indxs in zip(x_batches, log_w_batches, indices_batches)
        ]
        return dataset

    def save(self, path):
        """Save buffer to file."""
        to_save = {
            "x": self.buffer.x.detach().cpu(),
            "log_w": self.buffer.log_w.detach().cpu(),
            "current_index": self.current_index,
            "is_full": self.is_full,
            "can_sample": self.can_sample,
        }
        torch.save(to_save, path)

    def load(self, path):
        """Load buffer from file."""
        old_buffer = torch.load(path)
        indices = torch.arange(self.max_length)
        self.buffer.x[indices] = old_buffer["x"].to(self.device)
        self.buffer.log_w[indices] = old_buffer["log_w"].to(self.device)
        self.current_index = old_buffer["current_index"]
        self.is_full = old_buffer["is_full"]
        self.can_sample = old_buffer["can_sample"]


if __name__ == "__main__":
    # Smoke test: check that the replay buffer can be filled, sampled and adjusted.
    dim = 5
    batch_size = 3
    n_batches_total_length = 2
    length = n_batches_total_length * batch_size
    min_sample_length = int(length * 0.5)

    def initial_sampler():
        """Return one constant batch of `(x, log_w, log_q_old)`."""
        return (
            torch.ones(batch_size, dim),
            torch.zeros(batch_size),
            torch.ones(batch_size),
        )

    buffer = PrioritisedReplayBuffer(dim, length, min_sample_length, initial_sampler)
    # The constructor does not allocate storage; `initialize` must run before `add`.
    buffer.initialize("cpu")
    for _ in range(100):
        buffer.add(torch.ones(batch_size, dim), torch.zeros(batch_size), torch.ones(batch_size))
        x, log_w, log_q_old, indices = buffer.sample(batch_size)
        buffer.adjust(log_w + 1, log_q_old + 0.1, indices)
