import copy
from .basic_controller import BasicMAC
import torch as th
import torch.nn as nn
import numpy as np
import contextlib
import itertools
import torch_scatter as ts
from math import factorial
from random import randrange
from einops import repeat as rep
from einops import rearrange as rea
from utils.rl_utils import ncr
from utils.mst_constructor import MSTConstructor

from modules.agents import REGISTRY as agent_REGISTRY
from modules.action_encoders import REGISTRY as action_encoder_REGISTRY


class DeepCoordinationGraphMAC(BasicMAC):
    """ Multi-agent controller for a Deep Coordination Graph (DCG, Boehmer et al., 2020).

    Besides the dense DCG, this controller can sparsify the coordination graph at every step. Each edge is scored by
    its "q-influence": the largest masked variance of the pairwise payoff along one agent's actions, maximised over
    the other agent's actions and over both directions. Then either the top round((1 - sp_coeff) * E) edges are kept,
    or (with cyclicity_weighting) edges are removed iteratively, dividing their scores by a resistance-distance
    based deficit. Greedy actions are computed by max-sum message passing restricted to the active edges.
    """

    # ================================ Constructors ===================================================================

    def __init__(self, scheme, groups, args):
        super().__init__(scheme, groups, args)
        self.n_actions = args.n_actions
        self.payoff_rank = args.cg_payoff_rank
        self.payoff_decomposition = isinstance(self.payoff_rank, int) and self.payoff_rank > 0
        self.iterations = args.msg_iterations
        self.normalized = args.msg_normalized
        self.anytime = args.msg_anytime

        # action representation
        self.use_action_repr = args.use_action_repr
        if self.use_action_repr:
            self.action_encoder = action_encoder_REGISTRY[args.action_encoder](args)
            self.action_repr = th.ones(self.n_actions, self.args.action_latent_dim).to(args.device)
            self.p_action_repr = self._pairwise_action_repr(self.action_repr)

        # Create neural networks for utilities and payoff functions
        self.utility_fun = self._mlp(self.args.rnn_hidden_dim, args.cg_utilities_hidden_dim, self.n_actions)
        if self.args.rho_formulation:
            input_shape = self._get_input_shape(scheme)
            self.pair_encoder = agent_REGISTRY[self.args.pair_encoder](input_shape, self.args, pairwise=True)
            self.pair_hidden_states = None
            payoff_in = 2 * self.args.pair_encoder_hidden_dim
        else:
            payoff_in = 2 * self.args.rnn_hidden_dim
        if self.use_action_repr:
            # the payoff net emits a key that payoffs() multiplies with the pairwise action representations
            payoff_out = 2 * self.args.action_latent_dim
        elif self.payoff_decomposition:
            payoff_out = 2 * self.payoff_rank * self.n_actions
        else:
            payoff_out = self.n_actions ** 2
        # Under rho_formulation, payoffs() is fed the pair-encoder states, so the input size must follow payoff_in
        # in every configuration (including use_action_repr).
        self.payoff_fun = self._mlp(payoff_in, args.cg_payoffs_hidden_dim, payoff_out)
        # Create the edge information of the CG
        self._set_edges(self._edge_list(args.cg_edges))
        self.edges_from = rep(self.edges_from, 'p -> b p', b=args.batch_size)
        self.edges_to = rep(self.edges_to, 'p -> b p', b=args.batch_size)
        # sparsity config
        self.sp_coeff = args.sp_coeff
        self.cyclicity_weighting = args.cyclicity_weighting
        n_pairs = ncr(self.n_agents, 2)
        self.sp_cut = args.sp_iters is not None
        if args.sp_iters == 0 and args.cyclicity_weighting:
            print("When sp_iters is set to 0, we are pruning down to spanning tree, setting sp_coeff correspondingly:")
            self.sp_coeff = 1 - 2 / self.n_agents
            print("sp_coeff set to: {}".format(self.sp_coeff))
        self.sp_iters = n_pairs - round((1 - self.sp_coeff) * n_pairs) if args.sp_iters is None else args.sp_iters
        self.EPS = th.tensor(1e-6)
        # flat indices (into an n_agents x n_agents matrix) of the strictly upper-triangular entries, row-major
        triu = th.ones(self.n_agents, self.n_agents).triu(diagonal=1).flatten(start_dim=-2)
        self.triu_ids = rep(triu.nonzero()[:, -1], 'k -> b k', b=args.batch_size)
        self.mst_constructor = MSTConstructor()

    # ================== Core Methods =============================================================================

    def resistance_deficit(self, adjacency_matrix):
        """ Returns max(A_ij - R_ij, EPS) for every upper-triangular agent pair (i, j).

        R is the effective-resistance matrix of the graph with adjacency A. The result has shape (batch, n_pairs).
        """
        n = adjacency_matrix.size(-1)
        gamma = self._compute_gamma(adjacency_matrix)
        gamma_diag = th.diagonal(gamma, dim1=1, dim2=2)
        gamma_ii = rep(gamma_diag, 'b n1 -> b n1 n2', n2=n)
        gamma_jj = rep(gamma_diag, 'b n1 -> b n2 n1', n2=n)
        # effective resistance R_ij = Gamma_ii + Gamma_jj - 2 * Gamma_ij
        effective_resistance = gamma_ii + gamma_jj - 2 * gamma
        resistance_deficit = th.max(adjacency_matrix - effective_resistance, self.EPS)
        return self.triu_gather(resistance_deficit)

    def iterative_removal(self, q_influences):
        """ Prunes edges of the complete graph iteratively, weighting q-influences by the resistance deficit.

        If sp_iters was configured (sp_cut), a maximum spanning tree of the q-influences is built first.
        - sp_iters == 0: the graph is pruned directly to that tree.
        - sp_iters > 0: the tree edges are protected (their score set to inf), and chunks of edges are removed in
          each of the sp_iters rounds.
        Without sp_cut, one edge is removed per round for self.sp_iters rounds.

        Args:
            q_influences: edge scores of shape (batch, n_pairs), in upper-triangular order.
        Returns:
            a float mask of shape (batch, n_pairs) with 1 for edges that remain active.
        """
        b, n = q_influences.size(0), self.n_agents
        adjacency = th.logical_not(th.eye(n)).float().to(q_influences.device)
        adjacency = rep(adjacency, 'n1 n2 -> b n1 n2', b=b)
        chunk_size = 1
        if self.sp_cut:
            q_influences_full = self.triu_scatter(adjacency.clone(), value=q_influences)
            adjacency_tree = self.mst_constructor.construct(q_influences_full, q_influences.device)
            if self.sp_iters > 0:
                maximum_spanning_tree_mask = self.triu_gather(adjacency_tree)
                maximum_spanning_tree_ids = maximum_spanning_tree_mask.nonzero()[:, -1]
                maximum_spanning_tree_ids = rea(maximum_spanning_tree_ids, '(b x) -> b x', b=b)
                # scatter infs to mst elements so they are never selected for removal
                q_influences = th.scatter(q_influences, dim=1, index=maximum_spanning_tree_ids, value=th.inf)
                # NOTE(review): int() truncation can make chunk_size 0 (nothing is pruned) when
                # sp_coeff * n_pairs < sp_iters -- confirm this is intended.
                chunk_size = int(self.sp_coeff * ncr(n, 2) / self.sp_iters)
            else:
                adjacency = adjacency_tree
        for _ in range(self.sp_iters):
            # update the resistance distance matrix
            resistance_deficits = self.resistance_deficit(adjacency)
            # weight the q_influence scores by the cyclicity and pick the worst edge indices to be pruned
            scores = q_influences / resistance_deficits
            worst_chunk_ids = th.topk(-scores, dim=-1, k=chunk_size)[1]
            # scatter inf values to sparsified edge ids to make sure they are never picked again
            q_influences = th.scatter(q_influences, dim=1, index=worst_chunk_ids, value=th.inf)
            # sparsify the adjacency matrix
            adjacency = self.triu_scatter(adjacency, id=worst_chunk_ids, value=0)

        return self.triu_gather(adjacency)

    def create_sparse_graph(self, f_ij, available_actions):
        """ Builds a per-sample mask of the edges that stay active in the sparse coordination graph.

        Each edge is scored by its q-influence: the largest masked variance of f_ij along one agent's actions,
        maximised over the other agent's actions and over both directions.
        - Without cyclicity_weighting, the round((1 - sp_coeff) * E) best edges are kept.
        - With cyclicity_weighting, pruning is delegated to iterative_removal.

        Args:
            f_ij: pairwise payoffs indexed as (batch, edge, action_i, action_j).
            available_actions: availability mask indexed as (batch, agent, action).
        Returns:
            a long tensor of shape (batch, edge, 1, 1) with 1 for active and 0 for pruned edges.
        """
        b, p, a = f_ij.size(0), f_ij.size(1), self.n_actions
        # available actions for agents i and j, and for every joint action (a_i, a_j) on every edge
        edges_from = self._batched(self.edges_from, b)
        edges_to = self._batched(self.edges_to, b)
        aa_i = th.gather(available_actions, dim=1, index=rep(edges_from, 'b k -> b k a', a=a))
        aa_j = th.gather(available_actions, dim=1, index=rep(edges_to, 'b k -> b k a', a=a))
        aa_ij = rea(aa_i, 'b k a -> b k a 1') * rea(aa_j, 'b k a -> b k 1 a')
        # variance of agent j's payoff as a function of i's action, and vice versa
        varj_of_i, vari_of_j = self._masked_variance(f_ij, aa_ij)
        # 1. influence of i on j and 2. influence of j on i
        infj_by_i, infi_by_j = varj_of_i.max(-1)[0], vari_of_j.max(-1)[0]
        # importance of each edge is the max of the 2 influences defined above
        q_influences = th.stack([infj_by_i, infi_by_j], dim=-1).max(-1)[0]
        # padded episodes must still have non-zero q-influences; otherwise the graph can become disconnected,
        # which breaks the resistance deficit computations later on
        q_influences = th.max(q_influences, self.EPS)
        if self.cyclicity_weighting:
            # iterative sparsification based on the graph cyclicity
            active_edge_masks = self.iterative_removal(q_influences)
        else:
            # keep the k edges with the largest q-influence (ids in [0, E-1], E = ncr(N, 2))
            k = round((1 - self.sp_coeff) * p)
            topk_ids = th.topk(q_influences, dim=-1, k=k)[1]
            active_edge_masks = th.zeros(b, p, device=f_ij.device)
            active_edge_masks = th.scatter(active_edge_masks, dim=-1, index=topk_ids, value=1)
        return rea(active_edge_masks, 'b p -> b p 1 1').long()

    def annotations(self, ep_batch, t, compute_grads=False, actions=None):
        """ Returns all outputs of the utility and payoff functions (Algorithm 1 in Boehmer et al., 2020).

        Returns:
            f_i: utilities of shape (batch, n_agents, n_actions).
            f_ij: payoffs of shape (batch, edge, n_actions, n_actions).
            active_edge_masks: tensor of shape (batch, edge, 1, 1) marking the edges used by max-sum.
        `actions` is accepted for interface compatibility and is not used here.
        """
        grad_context = contextlib.nullcontext() if compute_grads else th.no_grad()
        with grad_context:
            agent_inputs = self._build_inputs(ep_batch, t)
            hidden_states = self.agent(agent_inputs, self.hidden_states)[1]
            self.hidden_states = hidden_states.view(ep_batch.batch_size, self.n_agents, -1)
            f_i = self.utilities(self.hidden_states)
            # compute the pairwise payoffs
            if self.args.rho_formulation:
                b, a = ep_batch.batch_size, self.n_actions
                pair_hidden_states = self.pair_encoder(agent_inputs, self.pair_hidden_states)[1]
                self.pair_hidden_states = pair_hidden_states.view(b, self.n_agents, -1)
                rho_ij = self.payoffs(self.pair_hidden_states)
                # f_i(a_i) + f_j(a_j) for every agent pair, restricted to the upper triangle; the row-major triangle
                # order matches the edge order produced by cg_edges='full'
                triu_ids = rep(self._batched(self.triu_ids, b), 'b p -> b p a1 a2', a1=a, a2=a)
                f_i_f_j_full = rea(f_i, 'b n a -> b n 1 a 1') + rea(f_i, 'b n a -> b 1 n 1 a')
                f_i_f_j_full = rea(f_i_f_j_full, 'b n1 n2 a1 a2 -> b (n1 n2) a1 a2')
                f_i_f_j_triu = th.gather(f_i_f_j_full, dim=1, index=triu_ids)
                # with the individual utilities expanded, we can now add them to rho to get the payoffs
                f_ij = f_i_f_j_triu + rho_ij
            else:
                f_ij = self.payoffs(self.hidden_states)
            # When sp_coeff is non-zero, create a sparse edge mask to be used in max-sum. self.sp_coeff (not
            # args.sp_coeff) is used because __init__ may rewrite it when pruning to a spanning tree.
            if self.sp_coeff > 0:
                active_edge_masks = self.create_sparse_graph(f_ij, available_actions=ep_batch['avail_actions'][:, t])
            else:
                active_edge_masks = th.ones(f_ij.size(0), f_ij.size(1), 1, 1, device=f_ij.device)
        return f_i, f_ij, active_edge_masks

    def utilities(self, hidden_states):
        """ Computes the utilities for a given batch of hidden states. """
        return self.utility_fun(hidden_states)

    def payoffs(self, hidden_states):
        """ Computes all payoffs for a given batch of hidden states.

        Each edge (i, j) is evaluated on the input [h_i, h_j] and on the flipped input [h_j, h_i]. The transposed
        flipped payoff matrix is averaged with the forward one, giving a tensor of shape
        (batch, edge, n_actions, n_actions).
        """
        # Construct the inputs for all edges' payoff functions and their flipped counterparts
        n = self.n_actions
        # all rows of the batched edge buffers are identical, so the first one suffices
        edges_from, edges_to = self.edges_from[0], self.edges_to[0]
        inputs = th.stack([th.cat([hidden_states[:, edges_from], hidden_states[:, edges_to]], dim=-1),
                           th.cat([hidden_states[:, edges_to], hidden_states[:, edges_from]], dim=-1)], dim=0)
        # Compute the payoff matrices for all edges (and flipped counterparts)
        if self.use_action_repr:
            key = self.payoff_fun(inputs).view(-1, len(edges_from), 2 * self.args.action_latent_dim)
            output = th.bmm(key, self.p_action_repr.repeat(key.shape[0], 1, 1)) / self.args.action_latent_dim / 2
            output = output.view(inputs.shape[0], inputs.shape[1], len(edges_from), -1)
        else:
            output = self.payoff_fun(inputs)
        if self.payoff_decomposition:
            # If the payoff matrix is decomposed, we need to de-decompose it here: ...
            dim = list(output.shape[:-1])
            # ... reshape output into left and right bases of the matrix, ...
            output = output.view(*[np.prod(dim) * self.payoff_rank, 2, n])
            # ... outer product between left and right bases, ...
            output = th.bmm(output[:, 0, :].unsqueeze(dim=-1), output[:, 1, :].unsqueeze(dim=-2))
            # ... and finally sum over the above outer products of payoff_rank base-pairs.
            output = output.view(*(dim + [self.payoff_rank, n, n])).sum(dim=-3)
        else:
            # Without decomposition, the payoff_fun output must only be reshaped
            output = output.view(*(list(output.shape[:-1]) + [n, n]))
        # The output of the backward messages must be transposed
        output[1] = output[1].clone().transpose(dim0=-2, dim1=-1)
        # Compute the symmetric average of each edge with its flipped counterpart
        return output.mean(dim=0)

    def q_values(self, f_i, f_ij, actions):
        """ Computes the Q-values for given utilities, payoffs and actions (Algorithm 2 in Boehmer et al., 2020).

        Q = mean_i f_i(a_i) + mean_(i,j) f_ij(a_i, a_j), taken over all edges of the CG (not only active ones).
        """
        n_batches = actions.shape[0]
        # all rows of the batched edge buffers are identical, so the first one suffices
        edges_from, edges_to = self.edges_from[0], self.edges_to[0]
        # Use the utilities for the chosen actions
        values = f_i.gather(dim=-1, index=actions).squeeze(dim=-1).mean(dim=-1)
        # Use the payoffs for the chosen actions (if the CG contains edges)
        if len(edges_from) > 0:
            n_actions = f_i.size(-1)
            f_ij = f_ij.view(n_batches, len(edges_from), n_actions * n_actions)
            # flat index a_i * |A| + a_j into each edge's payoff matrix
            edge_actions = actions.gather(dim=-2, index=edges_from.view(1, -1, 1).expand(n_batches, -1, 1)) \
                           * n_actions + actions.gather(dim=-2, index=edges_to.view(1, -1, 1).expand(n_batches, -1, 1))
            values = values + f_ij.gather(dim=-1, index=edge_actions).squeeze(dim=-1).mean(dim=-1)
        # Return the Q-values for the given actions
        return values

    def greedy(self, f_i, f_ij, aem, available_actions=None):
        """ Computes greedy joint actions by max-sum message passing over the edges that are active in `aem`.

        Args:
            f_i: utilities (batch, n_agents, n_actions).
            f_ij: payoffs (batch, edge, n_actions, n_actions).
            aem: active edge mask (batch, edge, 1, 1); every sample must have the same number of active edges.
            available_actions: optional mask; unavailable actions get a utility of -inf.
        Returns:
            int64 action indices of shape (batch, n_agents, 1).
        """
        b, n, p, a = f_i.size(0), f_i.size(1), f_ij.size(1), f_i.size(-1)
        # cast to double to reduce numerical errors; dividing by n and p turns the sums of max-sum into the means
        # used by q_values
        in_f_i, f_i = f_i, f_i.double() / n
        in_f_ij, f_ij = f_ij, f_ij.double() / p
        # Unavailable actions have a utility of -inf, which propagates throughout message passing
        if available_actions is not None:
            f_i = f_i.masked_fill(available_actions == 0, -float('inf'))
        # convert the active edge mask into per-sample edge ids (requires equal edge counts across the batch)
        aem_id = rea(aem.nonzero()[:, 1], '(b k) -> b k', b=b)
        # reduce f_ij to only contain non-masked edges, reduces the computation required in the message passing
        f_ij = th.gather(f_ij, dim=1, index=rep(aem_id, 'b k -> b k a1 a2', a1=a, a2=a))
        # endpoints of the edges that are still active for message passing
        edges_from = th.gather(self._batched(self.edges_from, b), dim=1, index=aem_id)
        edges_to = th.gather(self._batched(self.edges_to, b), dim=1, index=aem_id)
        # Without edges (or iterations), CG would be the same as VDN: mean(f_i)
        utils = f_i.clone()
        # Initialize best seen value and actions for anytime-extension
        best_value = in_f_i.new_empty(b).fill_(-float('inf'))
        best_actions = f_i.new_empty(b, self.n_agents, 1, dtype=th.int64, device=f_i.device)
        passes_messages = self.iterations > 0 and n > 1
        # Perform message passing for self.iterations: [0] are messages to *edges_to*, [1] are messages to *edges_from*
        if passes_messages:
            messages = f_i.new_zeros(2, b, edges_from.size(1), a)
            gt_from = rep(edges_from, 'b k -> b k a', a=a)
            gt_to = rep(edges_to, 'b k -> b k a', a=a)
            for _ in range(self.iterations):
                # Recompute messages: joint utility for each edge: "sender Q-value"-"message from receiver"+payoffs/E
                joint0 = (th.gather(utils, dim=1, index=gt_from) - messages[1]).unsqueeze(dim=-1) + f_ij
                joint1 = (th.gather(utils, dim=1, index=gt_to) - messages[0]).unsqueeze(dim=-1) \
                    + f_ij.transpose(dim0=-2, dim1=-1)
                # Maximize the joint Q-value over the action of the sender
                messages[0] = joint0.max(dim=-2)[0]
                messages[1] = joint1.max(dim=-2)[0]
                if self.normalized:
                    # Normalization as in Kok and Vlassis (2006) and Wainwright et al. (2004)
                    messages -= messages.mean(dim=-1, keepdim=True)
                # Create the current utilities of all agents, based on the messages
                msg = ts.scatter_add(src=messages[0], index=edges_to, dim=1, dim_size=n)
                msg += ts.scatter_add(src=messages[1], index=edges_from, dim=1, dim_size=n)
                utils = f_i + msg
                if self.anytime:
                    # Find currently best actions and the (true) value of these actions
                    actions = utils.max(dim=-1, keepdim=True)[1]
                    value = self.q_values(in_f_i, in_f_ij, actions)
                    # Update best_actions only for the batches that have a higher value than best_value
                    change = value > best_value
                    best_value[change] = value[change]
                    best_actions[change] = actions[change]
        if not self.anytime or not passes_messages:
            # Without the anytime extension, or when no message passing ran (best_actions would still be
            # uninitialised), read the greedy actions off the final utilities.
            best_actions = utils.max(dim=-1, keepdim=True)[1]
        return best_actions

    # ================== Override methods of BasicMAC to integrate DCG into PyMARL ====================================

    def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False):
        """ Selects actions for the batch elements in bs with the configured action selector. """
        avail_actions = ep_batch["avail_actions"][:, t_ep]
        agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode)
        chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env,
                                                            test_mode=test_mode)
        return chosen_actions

    def forward(self, ep_batch, t, aem_tgt=None, actions=None, policy_mode=True, test_mode=False, compute_grads=False):
        """ This is the main function that is called by learner and runner.
            If policy_mode=True,    returns the greedy policy (for controller) for the given ep_batch at time t.
            If policy_mode=False,   returns either the Q-values for given 'actions'
                                            or the actions of the greedy policy for 'actions==None'.
            aem_tgt, if given, replaces the active edge mask computed from this batch during greedy selection. """
        # Get the utilities and payoffs after observing time step t
        f_i, f_ij, aem = self.annotations(ep_batch, t, compute_grads, actions)
        # We either return the values for the given batch and actions...
        if actions is not None and not policy_mode:
            return self.q_values(f_i, f_ij, actions)
        # ... or greedy actions  ... or the computed Q-values (for the learner)
        aem_greedy = aem if aem_tgt is None else aem_tgt
        best_actions = self.greedy(f_i, f_ij, aem_greedy, available_actions=ep_batch['avail_actions'][:, t])
        if policy_mode:  # ... either as policy tensor for the runner ...
            policy = f_i.new_zeros(ep_batch.batch_size, self.n_agents, self.n_actions)
            policy.scatter_(dim=-1, index=best_actions, src=policy.new_ones(1, 1, 1).expand_as(best_actions))
            return policy
        # ... or as action tensor for the learner
        return best_actions

    def init_hidden(self, batch_size):
        """ Initializes the agent (and, under rho_formulation, the pair-encoder) hidden states. """
        BasicMAC.init_hidden(self, batch_size)
        if self.args.rho_formulation:
            pair_hidden_states = self.pair_encoder.init_hidden()
            self.pair_hidden_states = rep(pair_hidden_states, '1 h -> b n h', b=batch_size, n=self.n_agents)

    def update_action_repr(self):
        """ Refreshes the (detached) action representations from the encoder and rebuilds their pairwise matrix. """
        self.action_repr = self.action_encoder().detach().clone()
        # Pairwise Q (|A|, al) -> (|A|, |A|, 2*al)
        self.p_action_repr = self._pairwise_action_repr(self.action_repr)

    def cuda(self):
        """ Moves this controller to the GPU, if one exists. """
        self.agent.cuda()
        self.utility_fun.cuda()
        self.payoff_fun.cuda()
        if self.edges_from is not None:
            self.edges_from = self.edges_from.cuda()
            self.edges_to = self.edges_to.cuda()
            self.edges_n_in = self.edges_n_in.cuda()
        self.EPS = self.EPS.cuda()
        self.triu_ids = self.triu_ids.cuda()
        if self.args.rho_formulation:
            self.pair_encoder.cuda()
        if self.use_action_repr:
            self.action_encoder.cuda()

    def parameters(self):
        """ Returns a generator for all parameters of the controller (excluding the action encoder). """
        param_list = [BasicMAC.parameters(self), self.utility_fun.parameters(), self.payoff_fun.parameters()]
        if self.args.rho_formulation:
            param_list.append(self.pair_encoder.parameters())
        return itertools.chain(*param_list)

    def load_state(self, other_mac):
        """ Overwrites the parameters with those from other_mac. """
        BasicMAC.load_state(self, other_mac)
        self.utility_fun.load_state_dict(other_mac.utility_fun.state_dict())
        self.payoff_fun.load_state_dict(other_mac.payoff_fun.state_dict())
        if self.args.rho_formulation:
            self.pair_encoder.load_state_dict(other_mac.pair_encoder.state_dict())
        if self.args.use_action_repr:
            self.action_repr = copy.deepcopy(other_mac.action_repr)
            self.p_action_repr = copy.deepcopy(other_mac.p_action_repr)

    def save_models(self, path):
        """ Saves parameters to the disc. """
        BasicMAC.save_models(self, path)
        th.save(self.utility_fun.state_dict(), "{}/utilities.th".format(path))
        th.save(self.payoff_fun.state_dict(), "{}/payoffs.th".format(path))
        if self.args.rho_formulation:
            th.save(self.pair_encoder.state_dict(), "{}/pair_encoder.th".format(path))
        if self.args.use_action_repr:
            th.save(self.action_repr, "{}/action_repr.pt".format(path))
            th.save(self.p_action_repr, "{}/p_action_repr.pt".format(path))

    def load_models(self, path):
        """ Loads parameters from the disc. """
        BasicMAC.load_models(self, path)
        self.utility_fun.load_state_dict(
            th.load("{}/utilities.th".format(path), map_location=lambda storage, loc: storage))
        self.payoff_fun.load_state_dict(
            th.load("{}/payoffs.th".format(path), map_location=lambda storage, loc: storage))
        if self.args.rho_formulation:
            self.pair_encoder.load_state_dict(
                th.load("{}/pair_encoder.th".format(path), map_location=lambda storage, loc: storage))
        if self.args.use_action_repr:
            self.action_repr = th.load("{}/action_repr.pt".format(path),
                                       map_location=lambda storage, loc: storage).to(self.args.device)
            self.p_action_repr = th.load("{}/p_action_repr.pt".format(path),
                                         map_location=lambda storage, loc: storage).to(self.args.device)

    def action_encoder_params(self):
        """ Returns the action encoder's parameters as a list. """
        return list(self.action_encoder.parameters())

    def action_repr_forward(self, ep_batch, t):
        """ Runs the action encoder's prediction on the observations and one-hot actions at time t. """
        return self.action_encoder.predict(ep_batch["obs"][:, t], ep_batch["actions_onehot"][:, t])

    # ================== Private helper methods ======================================================

    def _pairwise_action_repr(self, action_repr):
        """ Builds the (1, 2 * latent_dim, |A| * |A|) matrix whose columns are [repr(a_i), repr(a_j)] pairs. """
        n = self.n_actions
        input_i = action_repr.unsqueeze(1).repeat(1, n, 1)
        input_j = action_repr.unsqueeze(0).repeat(n, 1, 1)
        return th.cat([input_i, input_j], dim=-1).view(n * n, -1).t().unsqueeze(0)

    @staticmethod
    def _batched(ids, b):
        """ Returns the batch-replicated id tensor `ids` (all rows identical) with exactly b rows.

        The edge and triu id buffers are replicated for args.batch_size at construction time, but the controller is
        also called with other batch sizes (e.g. 1). Indexing with a mismatched batch dimension makes th.gather
        fail, so the first row is re-tiled whenever the sizes differ.
        """
        if ids.size(0) == b:
            return ids
        return ids[:1].expand(b, -1).contiguous()

    @staticmethod
    def _masked_variance(f_ij, aa_ij):
        """ Returns the variances of f_ij along its last and second-to-last axes over available joint actions.

        aa_ij masks the unavailable joint actions. Denominators (counts minus one) are clamped to at least 1.
        Returns (row_variance, col_variance), each of shape (batch, edge, n_actions).
        """
        f_ij_masked = f_ij * aa_ij
        denom_min = th.ones(1, ).to(f_ij.device)
        row_denominator = th.max(aa_ij.sum(dim=-1, keepdim=True), denom_min)
        col_denominator = th.max(aa_ij.sum(dim=-2, keepdim=True), denom_min)
        row_means = f_ij_masked.sum(dim=-1, keepdim=True) / row_denominator
        col_means = f_ij_masked.sum(dim=-2, keepdim=True) / col_denominator
        row_denominator = th.max(row_denominator - 1, denom_min)
        col_denominator = th.max(col_denominator - 1, denom_min)
        row_variance = ((f_ij_masked - row_means).pow(2) * aa_ij).sum(dim=-1, keepdim=True) / row_denominator
        col_variance = ((f_ij_masked - col_means).pow(2) * aa_ij).sum(dim=-2, keepdim=True) / col_denominator
        row_variance = rea(row_variance, 'b p a 1 -> b p a')
        col_variance = rea(col_variance, 'b p 1 a -> b p a')
        return row_variance, col_variance

    def triu_scatter(self, matrix, value, id=None):
        """ Writes `value` into the upper triangle of `matrix` and returns the symmetrized (batch, n, n) result.

        The input is first reduced to its strict upper triangle. `id` selects positions within the triu ordering;
        if it is None, all upper-triangular entries are written.
        """
        matrix = matrix.triu(diagonal=1)
        b, n = matrix.size(0), self.n_agents
        # convert the triu ids into total adjacency ids
        triu_ids = self._batched(self.triu_ids, b)
        id_in_total = th.gather(triu_ids, dim=-1, index=id) if id is not None else triu_ids
        matrix_flat = rea(matrix, 'b n1 n2 -> b (n1 n2)')
        if isinstance(value, th.Tensor):
            matrix_flat = th.scatter(matrix_flat, dim=-1, index=id_in_total, src=value)
        else:
            matrix_flat = th.scatter(matrix_flat, dim=-1, index=id_in_total, value=value)
        matrix = rea(matrix_flat, 'b (n1 n2) -> b n1 n2', n1=n)
        return matrix + matrix.transpose(-2, -1)

    def triu_gather(self, matrix):
        """ Returns the strictly upper-triangular entries of a (batch, n, n) matrix as (batch, n_pairs). """
        b = matrix.size(0)
        triu_ids = self._batched(self.triu_ids, b)
        matrix_flat = rea(matrix, 'b n1 n2 -> b (n1 n2)')
        return th.gather(matrix_flat, dim=-1, index=triu_ids)

    def _compute_gamma(self, adjacency_matrix):
        """ Returns Gamma = (L + c * J)^-1 for the graph Laplacian L, the all-ones matrix J and c = 1 + 1/n.

        For a connected graph, any c > 0 yields the same effective resistances Gamma_ii + Gamma_jj - 2 * Gamma_ij,
        because the J term cancels in that expression.
        """
        n = adjacency_matrix.size(-1)
        eye = th.eye(n).to(adjacency_matrix.device)
        laplacian = eye * adjacency_matrix.sum(dim=-1, keepdim=True) - adjacency_matrix
        phi = 1 / n + th.ones_like(eye)
        return th.linalg.inv(laplacian + phi)

    @staticmethod
    def _mlp(input, hidden_dims, output):
        """ Creates an MLP with the specified input and output dimensions and (optional) hidden layers. """
        hidden_dims = [] if hidden_dims is None else hidden_dims
        hidden_dims = [hidden_dims] if isinstance(hidden_dims, int) else hidden_dims
        dim = input
        layers = []
        for d in hidden_dims:
            layers.append(nn.Linear(dim, d))
            layers.append(nn.ReLU())
            dim = d
        layers.append(nn.Linear(dim, output))
        return nn.Sequential(*layers)

    def _edge_list(self, arg):
        """ Specifies edges for the supported topologies.

        Args:
            arg: 'full' (fully connected CG), an int k for k distinct random undirected edges, or a list of
                 int-tuples / two-int lists giving the edges directly.
        Returns:
            a list of (from, to) agent index pairs.
        """
        n = self.n_agents
        max_edges = n * (n - 1) // 2
        edges = []
        wrong_arg = "Parameter cg_edges must be either the string 'full', " \
                    "an int for the number of random edges (<= n_agents * (n_agents - 1) / 2), " \
                    "or a list of either int-tuple or list-with-two-int-each for direct specification."
        # Parameter cg_edges must be either the string 'full', ...
        if isinstance(arg, str):
            assert arg == 'full', wrong_arg
            # fully connected CG, in row-major upper-triangular order
            edges = [(j, k) for j in range(n - 1) for k in range(j + 1, n)]
        # ... an int for the number of random edges (a simple graph has at most n * (n - 1) / 2 of them; asking for
        # more would make the rejection sampling below loop forever), ...
        if isinstance(arg, int):
            assert 0 <= arg <= max_edges, wrong_arg
            for _ in range(arg):
                while True:
                    e = (randrange(n), randrange(n))
                    if e[0] != e[1] and e not in edges and (e[1], e[0]) not in edges:
                        edges.append(e)
                        break
        # ... or a list of either int-tuple or list-with-two-int-each for direct specification.
        if isinstance(arg, list):
            assert all(isinstance(e, (list, tuple)) and len(e) == 2 and all(isinstance(i, int) for i in e)
                       for e in arg), wrong_arg
            edges = arg
        return edges

    def _set_edges(self, edge_list):
        """ Takes a list of tuples [0..n_agents)^2 and constructs the internal CG edge representation.

        Sets edges_from / edges_to (long, one entry per edge) and edges_n_in (float, the number of edges incident
        to each agent).
        """
        self.edges_from = th.tensor([edge[0] for edge in edge_list], dtype=th.long)
        self.edges_to = th.tensor([edge[1] for edge in edge_list], dtype=th.long)
        ones = self.edges_to.new_ones(len(self.edges_to))
        self.edges_n_in = ts.scatter_add(src=ones, index=self.edges_to, dim=0, dim_size=self.n_agents) \
                          + ts.scatter_add(src=ones, index=self.edges_from, dim=0, dim_size=self.n_agents)
        self.edges_n_in = self.edges_n_in.float()

