"""
Fixed-size buffers for storing the data collected during the training of the agent.
"""

from typing import Tuple

import torch


class RolloutBuffer:
    """Fixed-capacity storage for rollout transitions.

    Storage is preallocated as zero-filled tensors on ``device`` with
    ``batch_size`` slots along the first dimension:

    * ``states``: ``(batch_size, 1, state_dim)``
    * ``actions``: ``(batch_size, 1, n_actions)``
    * ``logprobs``, ``rewards``, ``dones``, ``values``: ``(batch_size, 1)``

    ``step`` is the index of the next slot that ``add`` will write.
    """

    def __init__(
        self, state_dim: int, n_actions: int, batch_size: int, device: torch.device
    ) -> None:
        """Allocate zero-filled storage for ``batch_size`` transitions.

        Args:
            state_dim: Size of the last dimension of each stored state.
            n_actions: Size of the last dimension of each stored action.
            batch_size: Number of transitions the buffer can hold.
            device: Device on which the storage tensors are created.
        """
        # Allocate directly on the target device rather than creating the
        # tensors on the CPU first and copying them over.
        self.states = torch.zeros((batch_size, 1, state_dim), device=device)
        self.actions = torch.zeros((batch_size, 1, n_actions), device=device)
        self.logprobs = torch.zeros((batch_size, 1), device=device)
        self.rewards = torch.zeros((batch_size, 1), device=device)
        self.dones = torch.zeros((batch_size, 1), device=device)
        self.values = torch.zeros((batch_size, 1), device=device)

        self.device = device
        self.step = 0

    def add(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        logprob: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        """Write one transition into slot ``step`` and advance ``step``.

        Each tensor is moved to the buffer's device before being stored.

        Raises:
            IndexError: If the buffer already holds ``batch_size`` transitions.
        """
        capacity = self.states.shape[0]
        # Check before writing anything so a full buffer fails with an
        # actionable message instead of a generic out-of-bounds index error.
        if self.step >= capacity:
            raise IndexError(
                f"{type(self).__name__} is full (capacity {capacity}); "
                "call reset() before adding more transitions."
            )

        self.states[self.step] = state.to(self.device)
        self.actions[self.step] = action.to(self.device)
        self.logprobs[self.step] = logprob.to(self.device)
        self.rewards[self.step] = reward.to(self.device)
        self.dones[self.step] = done.to(self.device)
        self.values[self.step] = value.to(self.device)

        self.step += 1

    def reset(self) -> None:
        """Rewind the write position to slot 0.

        Stored values are not cleared; later ``add`` calls overwrite them.
        """
        self.step = 0

    def get_data(
        self,
    ) -> Tuple[torch.Tensor, ...]:
        """Return the storage tensors themselves (not copies).

        The tensors always span all ``batch_size`` slots, including slots at
        index ``step`` or later that have not been written since the last reset.

        Returns:
            ``(states, actions, logprobs, rewards, dones, values)``.
        """
        return (
            self.states,
            self.actions,
            self.logprobs,
            self.rewards,
            self.dones,
            self.values,
        )


class DQNBuffer:
    """Fixed-capacity storage for (state, next_state, action, reward, done).

    Storage is preallocated as zero-filled tensors on ``device`` with
    ``batch_size`` slots along the first dimension:

    * ``states``, ``next_states``: ``(batch_size, 1, state_dim)``
    * ``actions``: ``(batch_size, 1, n_actions)``
    * ``rewards``, ``dones``: ``(batch_size, 1)``

    ``step`` is the index of the next slot that ``add`` will write.
    """

    def __init__(
        self, state_dim: int, n_actions: int, batch_size: int, device: torch.device
    ) -> None:
        """Allocate zero-filled storage for ``batch_size`` transitions.

        Args:
            state_dim: Size of the last dimension of each stored state.
            n_actions: Size of the last dimension of each stored action.
            batch_size: Number of transitions the buffer can hold.
            device: Device on which the storage tensors are created.
        """
        # Allocate directly on the target device rather than creating the
        # tensors on the CPU first and copying them over.
        self.states = torch.zeros((batch_size, 1, state_dim), device=device)
        self.next_states = torch.zeros((batch_size, 1, state_dim), device=device)
        self.actions = torch.zeros((batch_size, 1, n_actions), device=device)
        self.rewards = torch.zeros((batch_size, 1), device=device)
        self.dones = torch.zeros((batch_size, 1), device=device)

        self.device = device
        self.step = 0

    def add(
        self,
        state: torch.Tensor,
        next_state: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
    ) -> None:
        """Write one transition into slot ``step`` and advance ``step``.

        Each tensor is moved to the buffer's device before being stored.

        Raises:
            IndexError: If the buffer already holds ``batch_size`` transitions.
        """
        capacity = self.states.shape[0]
        # Check before writing anything so a full buffer fails with an
        # actionable message instead of a generic out-of-bounds index error.
        if self.step >= capacity:
            raise IndexError(
                f"{type(self).__name__} is full (capacity {capacity}); "
                "call reset() before adding more transitions."
            )

        self.states[self.step] = state.to(self.device)
        self.next_states[self.step] = next_state.to(self.device)
        self.actions[self.step] = action.to(self.device)
        self.rewards[self.step] = reward.to(self.device)
        self.dones[self.step] = done.to(self.device)

        self.step += 1

    def reset(self) -> None:
        """Rewind the write position to slot 0.

        Stored values are not cleared; later ``add`` calls overwrite them.
        """
        self.step = 0

    def get_data(
        self,
    ) -> Tuple[torch.Tensor, ...]:
        """Return the storage tensors themselves (not copies).

        The tensors always span all ``batch_size`` slots, including slots at
        index ``step`` or later that have not been written since the last reset.

        Returns:
            ``(states, next_states, actions, rewards, dones)``.
        """
        return (
            self.states,
            self.next_states,
            self.actions,
            self.rewards,
            self.dones,
        )


class BisimulatorBuffer:
    """Fixed-capacity storage for (state, action, group, reward, next_state).

    Storage is preallocated as zero-filled tensors on ``device`` with
    ``batch_size`` slots along the first dimension:

    * ``states``: ``(batch_size, 1, state_without_group_dim)``
    * ``actions``: ``(batch_size, 1)``
    * ``groups``: ``(batch_size, 1, group_dim)``
    * ``rewards``: ``(batch_size, 1)``
    * ``next_states``: ``(batch_size, 1, actual_next_state_dim)``

    ``step`` is the index of the next slot that ``add`` will write.
    """

    def __init__(
        self,
        state_without_group_dim: int,
        actual_next_state_dim: int,
        group_dim: int,
        batch_size: int,
        device: torch.device,
    ) -> None:
        """Allocate zero-filled storage for ``batch_size`` transitions.

        Args:
            state_without_group_dim: Size of the last dimension of each
                stored state.
            actual_next_state_dim: Size of the last dimension of each stored
                next state.
            group_dim: Size of the last dimension of each stored group.
            batch_size: Number of transitions the buffer can hold.
            device: Device on which the storage tensors are created.
        """
        # Allocate directly on the target device rather than creating the
        # tensors on the CPU first and copying them over.
        self.states = torch.zeros(
            (batch_size, 1, state_without_group_dim), device=device
        )
        self.actions = torch.zeros((batch_size, 1), device=device)
        self.groups = torch.zeros((batch_size, 1, group_dim), device=device)
        self.rewards = torch.zeros((batch_size, 1), device=device)
        self.next_states = torch.zeros(
            (batch_size, 1, actual_next_state_dim), device=device
        )

        self.device = device
        self.step = 0

    def add(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        group: torch.Tensor,
        reward: torch.Tensor,
        next_state: torch.Tensor,
    ) -> None:
        """Write one transition into slot ``step`` and advance ``step``.

        Each tensor is moved to the buffer's device before being stored.

        Raises:
            IndexError: If the buffer already holds ``batch_size`` transitions.
        """
        capacity = self.states.shape[0]
        # Check before writing anything so a full buffer fails with an
        # actionable message instead of a generic out-of-bounds index error.
        if self.step >= capacity:
            raise IndexError(
                f"{type(self).__name__} is full (capacity {capacity}); "
                "call reset() before adding more transitions."
            )

        self.states[self.step] = state.to(self.device)
        self.actions[self.step] = action.to(self.device)
        self.groups[self.step] = group.to(self.device)
        self.rewards[self.step] = reward.to(self.device)
        self.next_states[self.step] = next_state.to(self.device)

        self.step += 1

    def reset(self) -> None:
        """Rewind the write position to slot 0.

        Stored values are not cleared; later ``add`` calls overwrite them.
        """
        self.step = 0

    def get_data(self) -> Tuple[torch.Tensor, ...]:
        """Return the storage tensors themselves (not copies).

        The tensors always span all ``batch_size`` slots, including slots at
        index ``step`` or later that have not been written since the last reset.

        Returns:
            ``(states, actions, groups, rewards, next_states)``.
        """
        return (
            self.states,
            self.actions,
            self.groups,
            self.rewards,
            self.next_states,
        )


class APPORolloutBuffer:
    """Fixed-capacity rollout storage with extra delta and delta_delta fields.

    Storage is preallocated as zero-filled tensors on ``device`` with
    ``batch_size`` slots along the first dimension:

    * ``states``: ``(batch_size, 1, state_dim)``
    * ``actions``: ``(batch_size, 1, n_actions)``
    * ``logprobs``, ``rewards``, ``dones``, ``values``, ``deltas``,
      ``delta_deltas``: ``(batch_size, 1)``

    ``step`` is the index of the next slot that ``add`` will write. The
    buffer only stores the delta values; how they are computed is up to the
    caller.
    """

    def __init__(
        self, state_dim: int, n_actions: int, batch_size: int, device: torch.device
    ) -> None:
        """Allocate zero-filled storage for ``batch_size`` transitions.

        Args:
            state_dim: Size of the last dimension of each stored state.
            n_actions: Size of the last dimension of each stored action.
            batch_size: Number of transitions the buffer can hold.
            device: Device on which the storage tensors are created.
        """
        # Allocate directly on the target device rather than creating the
        # tensors on the CPU first and copying them over.
        self.states = torch.zeros((batch_size, 1, state_dim), device=device)
        self.actions = torch.zeros((batch_size, 1, n_actions), device=device)
        self.logprobs = torch.zeros((batch_size, 1), device=device)
        self.rewards = torch.zeros((batch_size, 1), device=device)
        self.dones = torch.zeros((batch_size, 1), device=device)
        self.values = torch.zeros((batch_size, 1), device=device)
        self.deltas = torch.zeros((batch_size, 1), device=device)
        self.delta_deltas = torch.zeros((batch_size, 1), device=device)

        self.device = device
        self.step = 0

    def add(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        logprob: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        value: torch.Tensor,
        delta: torch.Tensor,
        delta_delta: torch.Tensor,
    ) -> None:
        """Write one transition into slot ``step`` and advance ``step``.

        Each tensor is moved to the buffer's device before being stored.

        Raises:
            IndexError: If the buffer already holds ``batch_size`` transitions.
        """
        capacity = self.states.shape[0]
        # Check before writing anything so a full buffer fails with an
        # actionable message instead of a generic out-of-bounds index error.
        if self.step >= capacity:
            raise IndexError(
                f"{type(self).__name__} is full (capacity {capacity}); "
                "call reset() before adding more transitions."
            )

        self.states[self.step] = state.to(self.device)
        self.actions[self.step] = action.to(self.device)
        self.logprobs[self.step] = logprob.to(self.device)
        self.rewards[self.step] = reward.to(self.device)
        self.dones[self.step] = done.to(self.device)
        self.values[self.step] = value.to(self.device)
        self.deltas[self.step] = delta.to(self.device)
        self.delta_deltas[self.step] = delta_delta.to(self.device)

        self.step += 1

    def reset(self) -> None:
        """Rewind the write position to slot 0.

        Stored values are not cleared; later ``add`` calls overwrite them.
        """
        self.step = 0

    def get_data(
        self,
    ) -> Tuple[torch.Tensor, ...]:
        """Return the storage tensors themselves (not copies).

        The tensors always span all ``batch_size`` slots, including slots at
        index ``step`` or later that have not been written since the last reset.

        Returns:
            ``(states, actions, logprobs, rewards, dones, values, deltas,
            delta_deltas)``.
        """
        return (
            self.states,
            self.actions,
            self.logprobs,
            self.rewards,
            self.dones,
            self.values,
            self.deltas,
            self.delta_deltas,
        )


class ELBERTRolloutBuffer:
    """Fixed-capacity rollout storage with per-group supply/demand fields.

    Storage is preallocated as zero-filled tensors on ``device`` with
    ``batch_size`` slots along the first dimension:

    * ``states``: ``(batch_size, 1, state_dim)``
    * ``actions``: ``(batch_size, 1, n_actions)``
    * ``logprobs``, ``rewards``, ``dones``, ``values``: ``(batch_size, 1)``
    * ``supply_rewards``, ``demand_rewards``, ``supply_values``,
      ``demand_values``: ``(batch_size, 1, group_dim)``

    ``step`` is the index of the next slot that ``add`` will write.
    """

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        batch_size: int,
        group_dim: int,
        device: torch.device,
    ) -> None:
        """Allocate zero-filled storage for ``batch_size`` transitions.

        Args:
            state_dim: Size of the last dimension of each stored state.
            n_actions: Size of the last dimension of each stored action.
            batch_size: Number of transitions the buffer can hold.
            group_dim: Size of the last dimension of the supply/demand
                reward and value tensors.
            device: Device on which the storage tensors are created.
        """
        # Allocate directly on the target device rather than creating the
        # tensors on the CPU first and copying them over.
        self.states = torch.zeros((batch_size, 1, state_dim), device=device)
        self.actions = torch.zeros((batch_size, 1, n_actions), device=device)
        self.logprobs = torch.zeros((batch_size, 1), device=device)
        self.rewards = torch.zeros((batch_size, 1), device=device)
        self.dones = torch.zeros((batch_size, 1), device=device)
        self.values = torch.zeros((batch_size, 1), device=device)
        self.supply_rewards = torch.zeros((batch_size, 1, group_dim), device=device)
        self.demand_rewards = torch.zeros((batch_size, 1, group_dim), device=device)
        self.supply_values = torch.zeros((batch_size, 1, group_dim), device=device)
        self.demand_values = torch.zeros((batch_size, 1, group_dim), device=device)

        self.device = device
        self.step = 0

    def add(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        logprob: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        value: torch.Tensor,
        supply_reward: torch.Tensor,
        demand_reward: torch.Tensor,
        supply_value: torch.Tensor,
        demand_value: torch.Tensor,
    ) -> None:
        """Write one transition into slot ``step`` and advance ``step``.

        Each tensor is moved to the buffer's device before being stored.

        Raises:
            IndexError: If the buffer already holds ``batch_size`` transitions.
        """
        capacity = self.states.shape[0]
        # Check before writing anything so a full buffer fails with an
        # actionable message instead of a generic out-of-bounds index error.
        if self.step >= capacity:
            raise IndexError(
                f"{type(self).__name__} is full (capacity {capacity}); "
                "call reset() before adding more transitions."
            )

        self.states[self.step] = state.to(self.device)
        self.actions[self.step] = action.to(self.device)
        self.logprobs[self.step] = logprob.to(self.device)
        self.rewards[self.step] = reward.to(self.device)
        self.dones[self.step] = done.to(self.device)
        self.values[self.step] = value.to(self.device)
        self.supply_rewards[self.step] = supply_reward.to(self.device)
        self.demand_rewards[self.step] = demand_reward.to(self.device)
        self.supply_values[self.step] = supply_value.to(self.device)
        self.demand_values[self.step] = demand_value.to(self.device)

        self.step += 1

    def reset(self) -> None:
        """Rewind the write position to slot 0.

        Stored values are not cleared; later ``add`` calls overwrite them.
        """
        self.step = 0

    def get_data(
        self,
    ) -> Tuple[torch.Tensor, ...]:
        """Return the storage tensors themselves (not copies).

        The tensors always span all ``batch_size`` slots, including slots at
        index ``step`` or later that have not been written since the last reset.

        Returns:
            ``(states, actions, logprobs, rewards, dones, values,
            supply_rewards, demand_rewards, supply_values, demand_values)``.
        """
        return (
            self.states,
            self.actions,
            self.logprobs,
            self.rewards,
            self.dones,
            self.values,
            self.supply_rewards,
            self.demand_rewards,
            self.supply_values,
            self.demand_values,
        )
