"""Independent and Centralised Critic Implementations for TD3 based systems."""
import tensorflow as tf
import sonnet as snt

from og_marl.utils.trainer_utils import batch_concat_agent_id_to_obs

class StateAndActionCritic(snt.Module):

    def __init__(self, num_agents, num_actions):
        self.N = num_agents
        self.A = num_actions

        self._critic_network = snt.Sequential(
            [
                snt.Linear(128),
                tf.keras.layers.ReLU(),
                snt.Linear(128),
                tf.keras.layers.ReLU(),
                snt.Linear(1)
            ]
        )

        super().__init__()

    def initialise(self, observation, state, action):
        """ A method to initialise the parameters in the critic network.

        observation: a dummy observation with no batch dimension. We assume
            all agent's observations have the same shape.
        state: a dummy environment state with no batch dimension.
        action: a dummy action with no batch dimension. We assume all agents have 
            the same action shape.
        """
        state = tf.reshape(state, (1,1) + state.shape) # add time and batch dim
        actions = tf.stack([action]*self.N, axis=0) # action for each agent
        actions = tf.reshape(actions, (1,1) + actions.shape) # add time and batch dim

        self(None, state, actions, actions) # __call__ with dummy inputs

    def __call__(self, observations, states, agent_actions, other_actions, stop_other_actions_gradient=True):
        """Forward pass of critic network.
        
        observations [T,B,N,O]
        states [T,B,S]
        agent_actions [T,B,N,A]: the actions the agent took.
        other_actions [T,B,N,A]: the actions the other agents took.
        """
        # Repeat states for each agent
        states = tf.stack([states]*self.N, axis=2)

        # Concat states and joint actions
        critic_input = tf.concat([states, agent_actions], axis=-1)

        # Concat agent IDs to critic input
        critic_input = batch_concat_agent_id_to_obs(critic_input)

        q_values = self._critic_network(critic_input)
        return q_values


class StateAndJointActionCritic(snt.Module):

    def __init__(self, num_agents, num_actions):
        self.N = num_agents
        self.A = num_actions

        self._critic_network = snt.Sequential(
            [
                snt.Linear(128),
                tf.keras.layers.ReLU(),
                snt.Linear(128),
                tf.keras.layers.ReLU(),
                snt.Linear(1)
            ]
        )

        super().__init__()

    def initialise(self, observation, state, action):
        """ A method to initialise the parameters in the critic network.

        observation: a dummy observation with no batch dimension. We assume
            all agent's observations have the same shape.
        state: a dummy environment state with no batch dimension.
        action: a dummy action with no batch dimension. We assume all agents have 
            the same action shape.
        """
        state = tf.reshape(state, (1,1) + state.shape) # add time and batch dim
        actions = tf.stack([action]*self.N, axis=0) # action for each agent
        actions = tf.reshape(actions, (1,1) + actions.shape) # add time and batch dim

        self(None, state, actions, actions) # __call__ with dummy inputs

    def __call__(self, observations, states, agent_actions, other_actions, stop_other_actions_gradient=True):
        """Forward pass of critic network.

        observations [T,B,N,O]
        states [T,B,S]
        agent_actions [T,B,N,A]: the actions the agent took.
        other_actions [T,B,N,A]: the actions the other agents took.
        """
        if stop_other_actions_gradient:
            other_actions = tf.stop_gradient(other_actions)

        # Make joint action
        joint_actions = make_joint_action(agent_actions, other_actions)

        # Repeat states for each agent
        states = tf.stack([states]*self.N, axis=2) # [T,B,S] -> [T,B,N,S]

        # Concat states and joint actions
        critic_input = tf.concat([states, joint_actions], axis=-1)

        # Concat agent IDs to critic input
        critic_input = batch_concat_agent_id_to_obs(critic_input)

        q_values = self._critic_network(critic_input)
        return q_values


def make_joint_action(agent_actions, other_actions):
    """Method to construct the joint action.
    
    agent_actions [T,B,N,A]: tensor of actions the agent took. Usually
        the actions from the learnt policy network.
    other_actions [[T,B,N,A]]: tensor of actions the agent took. Usually
        the actions from the replay buffer.
    """
    N = agent_actions.shape[2]
    joint_actions = []
    for i in range(N):
        if N > 2:
            if i > 0 and i < N - 1:
                joint_action = tf.concat(
                    [
                        other_actions[:, :, :i],
                        tf.expand_dims(agent_actions[:, :, i], axis=2),
                        other_actions[:, :, i + 1 :],
                    ],
                    axis=2,  # along agent dim
                )
            elif i == 0:
                joint_action = tf.concat(
                    [
                        tf.expand_dims(agent_actions[:, :, i], axis=2),
                        other_actions[:, :, i + 1 :],
                    ],
                    axis=2,  # along agent dim
                )
            else:
                joint_action = tf.concat(
                    [
                        other_actions[:, :, :i],
                        tf.expand_dims(agent_actions[:, :, i], axis=2),
                    ],
                    axis=2,  # along agent dim
                )
        elif N == 2:
            if i == 0:
                joint_action = tf.concat(
                    [
                        tf.expand_dims(agent_actions[:, :, i], axis=2),
                        tf.expand_dims(other_actions[:, :, i + 1], axis=2),
                    ],
                    axis=2,  # along agent dim
                )
            else:
                joint_action = tf.concat(
                    [
                        tf.expand_dims(other_actions[:, :, i], axis=2),
                        tf.expand_dims(agent_actions[:, :, i], axis=2),
                    ],
                    axis=2,  # along agent dim
                )
        else:
            joint_action = agent_actions

        joint_action = tf.reshape(
            joint_action,
            (
                *joint_action.shape[:2],
                joint_action.shape[2] * joint_action.shape[3],
            ),
        )

        joint_actions.append(joint_action)
    joint_actions = tf.stack(joint_actions, axis=2)

    return joint_actions