"""Off-policy buffer."""
import numpy as np
import torch
from mas3ac.common.buffers.off_policy_buffer_base import OffPolicyBufferBase


class OffPolicyBufferFP(OffPolicyBufferBase):
    """Off-policy buffer that uses Feature-Pruned (FP) state.

    When FP state is used, the critic takes a different global state as input
    for each actor. The share-observation, reward, cost, done and termination
    buffers allocated here therefore carry an extra agent axis:
    ``(buffer_size, num_agents, ...)``.

    The per-agent buffers (``obs``, ``actions``, ``valid_transitions``,
    ``available_actions`` and their ``next_*`` counterparts) as well as
    ``buffer_size``, ``batch_size``, ``cur_size``, ``idx``, ``n_step``,
    ``gamma`` and ``n_rollout_threads`` are not defined in this class; they
    are presumably provided by ``OffPolicyBufferBase`` (verify against the
    base class).
    """

    def __init__(self, args, share_obs_space, num_agents, obs_spaces, act_spaces):
        """Initialize off-policy buffer.

        Args:
            args: (dict) arguments
            share_obs_space: (gym.Space or list) share observation space
            num_agents: (int) number of agents
            obs_spaces: (gym.Space or list) observation spaces
            act_spaces: (gym.Space) action spaces
        """
        super().__init__(args, share_obs_space, num_agents, obs_spaces, act_spaces)

        share_obs_buffer_shape = (self.buffer_size, self.num_agents, *self.share_obs_shape)
        per_agent_scalar_shape = (self.buffer_size, self.num_agents, 1)

        # Buffers for share observations and next share observations
        self.share_obs = np.zeros(share_obs_buffer_shape, dtype=np.float32)
        self.next_share_obs = np.zeros(share_obs_buffer_shape, dtype=np.float32)

        # Buffers for rewards and costs received by agents at each timestep
        self.rewards = np.zeros(per_agent_scalar_shape, dtype=np.float32)
        self.costs = np.zeros(per_agent_scalar_shape, dtype=np.float32)
        self.stability_costs = np.zeros(per_agent_scalar_shape, dtype=np.float32)

        # Buffers for done and termination flags
        self.dones = np.full(per_agent_scalar_shape, False)
        self.terms = np.full(per_agent_scalar_shape, False)

    def _gather_agent_major(self, buffer, indices):
        """Gather per-agent entries from a (buffer_size, num_agents, ...) array.

        Args:
            buffer: (np.ndarray) array whose first two axes are (slot, agent)
            indices: (np.ndarray) integer slots of shape (batch, num_agents)
        Returns:
            (np.ndarray) array of shape (num_agents, batch, ...) where entry
            ``[a, b]`` equals ``buffer[indices[b, a], a]``.
        """
        agent_column = np.arange(self.num_agents)[:, np.newaxis]  # (num_agents, 1)
        return buffer[indices.T, agent_column]

    def sample(self):
        """Sample data for training.

        At most ``batch_size`` distinct slots are drawn from the first
        ``cur_size`` slots; if fewer transitions are stored, the returned
        batch is correspondingly smaller. Arrays built from the buffers
        allocated in this class are agent-major, i.e. shaped
        ``(num_agents, batch, ...)``.

        Returns:
            sp_share_obs
            sp_obs
            sp_actions
            sp_available_actions (None unless the action space is Discrete)
            sp_reward (n-step discounted sum)
            sp_cost (n-step discounted sum)
            sp_stability_cost (n-step discounted sum)
            sp_done
            sp_valid_transitions
            sp_term
            sp_next_share_obs
            sp_next_obs
            sp_next_available_actions (None unless the action space is Discrete)
            sp_gamma (discount to apply to the bootstrapped value)
        """
        self.update_end_flag()  # update the current end flag
        indice = torch.randperm(self.cur_size).numpy()[: self.batch_size]
        # The number of drawn slots is smaller than batch_size while the
        # buffer is still filling up; size everything from what was drawn.
        batch_size = indice.shape[0]
        agent_ids = range(self.num_agents)
        is_discrete = self.act_spaces[0].__class__.__name__ == 'Discrete'

        # (batch, num_agents, ...) -> (num_agents, batch, ...); swapaxes works
        # for share observations of any rank.
        sp_share_obs = np.swapaxes(self.share_obs[indice], 0, 1)
        sp_obs = np.array([self.obs[agent_id][indice] for agent_id in agent_ids])
        sp_actions = np.array([self.actions[agent_id][indice] for agent_id in agent_ids])
        sp_valid_transitions = np.array(
            [self.valid_transitions[agent_id][indice] for agent_id in agent_ids]
        )
        sp_available_actions = None
        if is_discrete:
            sp_available_actions = np.array(
                [self.available_actions[agent_id][indice] for agent_id in agent_ids]
            )

        # compute the indices along n steps, one column per agent
        indice = np.repeat(indice[:, np.newaxis], self.num_agents, axis=-1)  # (batch, n_agents)
        indices = [indice]
        for _ in range(self.n_step - 1):
            indices.append(self.next(indices[-1]))
        last = indices[-1]

        # "next" quantities are read at the last of the n steps
        sp_done = self._gather_agent_major(self.dones, last)
        sp_term = self._gather_agent_major(self.terms, last)
        sp_next_share_obs = self._gather_agent_major(self.next_share_obs, last)
        sp_next_obs = np.array(
            [self.next_obs[agent_id][last[:, agent_id]] for agent_id in agent_ids]
        )
        sp_next_available_actions = None
        if is_discrete:
            sp_next_available_actions = np.array(
                [self.next_available_actions[agent_id][last[:, agent_id]] for agent_id in agent_ids]
            )

        # gamma_powers[k] == gamma ** k, built by repeated multiplication
        gamma_powers = np.ones(self.n_step + 1)
        for i in range(1, self.n_step + 1):
            gamma_powers[i] = gamma_powers[i - 1] * self.gamma

        # compute accumulated rewards/costs and the corresponding gamma exponent
        accumulator_shape = (batch_size, self.num_agents, 1)
        sp_reward = np.zeros(accumulator_shape)
        sp_cost = np.zeros(accumulator_shape)
        sp_stability_cost = np.zeros(accumulator_shape)
        gammas = np.full((batch_size, self.num_agents), self.n_step)
        agent_columns = np.arange(self.num_agents)
        # Walk backwards so that the earliest end flag along the n steps wins:
        # anything accumulated from later steps is discarded there.
        for n in range(self.n_step - 1, -1, -1):
            now = indices[n]
            ended = self.end_flag[now, agent_columns] > 0  # (batch, n_agents)
            gammas[ended] = n + 1
            sp_reward[ended] = 0.0
            sp_cost[ended] = 0.0
            sp_stability_cost[ended] = 0.0
            # buffer[now, agent_columns] has shape (batch, n_agents, 1)
            sp_reward = self.rewards[now, agent_columns] + self.gamma * sp_reward
            sp_cost = self.costs[now, agent_columns] + self.gamma * sp_cost
            sp_stability_cost = (
                self.stability_costs[now, agent_columns] + self.gamma * sp_stability_cost
            )

        # (batch, n_agents, 1) -> (n_agents, batch, 1)
        sp_reward = np.swapaxes(sp_reward, 0, 1)
        sp_cost = np.swapaxes(sp_cost, 0, 1)
        sp_stability_cost = np.swapaxes(sp_stability_cost, 0, 1)

        # (n_agents, batch, 1): gamma ** (number of steps actually accumulated)
        sp_gamma = gamma_powers[gammas.T][..., np.newaxis]

        return (
            sp_share_obs,
            sp_obs,
            sp_actions,
            sp_available_actions,
            sp_reward,
            sp_cost,
            sp_stability_cost,
            sp_done,
            sp_valid_transitions,
            sp_term,
            sp_next_share_obs,
            sp_next_obs,
            sp_next_available_actions,
            sp_gamma,
        )

    def next(self, indices):
        """Get next indices.

        Args:
            indices: (np.ndarray) integer slots of shape (batch, num_agents)
        Returns:
            (np.ndarray) same shape; each slot is advanced by
            ``n_rollout_threads`` (modulo ``buffer_size``) unless its end flag
            is set, in which case it is left where it is.
        """
        end_flag = self.end_flag[indices, np.arange(self.num_agents)]  # (batch, n_agents)
        return (indices + (1 - end_flag) * self.n_rollout_threads) % self.buffer_size

    def update_end_flag(self):
        """Update current end flag for computing n-step return.

        End flag is True at the steps which are the end of an episode or the
        latest but unfinished steps.
        """
        # The n_rollout_threads slots immediately preceding self.idx,
        # wrapping around cur_size.
        self.unfinished_index = (
            self.idx - np.arange(self.n_rollout_threads) - 1 + self.cur_size
        ) % self.cur_size
        # Drop only the trailing singleton axis so the result stays
        # (buffer_size, n_agents) even with a single agent or a single slot.
        self.end_flag = self.dones.squeeze(axis=-1).copy()
        self.end_flag[self.unfinished_index, :] = True

