import torch
import torch.nn as nn
from torch.distributions.normal import Normal


class RelationRouter(nn.Module):
    """Sparsely-gated mixture-of-experts router with graph convolutional experts.

    The ``src`` node features are layer-normalised and mean-pooled into a single
    vector. That vector goes through noisy top-k gating
    (https://arxiv.org/abs/1701.06538). The selected experts are run on the
    graph and their outputs are combined using the gate weights.

    Args:
        input_size: size of the node features used for gating.
        output_size: expected size of the experts' output features (stored on
            the module; not used by this class itself).
        experts: the expert modules; each is called as
            ``expert(x_dict, edge_index_dict)``.
        device: device on which the gating parameters are created. ``"gpu"``
            is accepted as an alias for ``"cuda"`` (falls back to ``"cpu"``
            when CUDA is unavailable).
        noisy_gating: whether to add learned Gaussian noise to the gating
            logits during training.
        k: how many experts to use; clamped to the number of experts.

    Raises:
        ValueError: if ``experts`` is empty or ``k < 1``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        experts: nn.ModuleList,
        device: str = "gpu",
        noisy_gating: bool = True,
        k: int = 1,
    ) -> None:
        super().__init__()

        if len(experts) == 0:
            raise ValueError("RelationRouter needs at least one expert")
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")

        torch_device = self._resolve_device(device)

        self.experts = experts
        self.input_size = input_size
        self.output_size = output_size
        self.device = str(torch_device)
        self.num_experts = len(experts)
        # Cannot select more experts than exist.
        self.k = min(k, self.num_experts)
        self.noisy_gating = noisy_gating

        self.w_gate = nn.Parameter(
            torch.zeros(input_size, self.num_experts, device=torch_device),
            requires_grad=True,
        )

        self.w_noise = nn.Parameter(
            torch.zeros(input_size, self.num_experts, device=torch_device),
            requires_grad=True,
        )

        self.softmax = nn.Softmax(1)
        self.softplus = nn.Softplus()
        # Created on the same device as the gating weights so forward() works
        # without a separate .to() call.
        self.normalize = nn.LayerNorm(input_size, device=torch_device)
        # Standard-normal parameters used by _prob_in_top_k; registered as
        # buffers so they follow .to()/.cuda() with the module.
        self.register_buffer("std", torch.tensor([1.0]))
        self.register_buffer("mean", torch.tensor([0.0]))

    @staticmethod
    def _resolve_device(device) -> torch.device:
        """Map ``"gpu"`` to an available accelerator, then build a ``torch.device``."""
        if isinstance(device, str) and device.lower() == "gpu":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device)

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.

        Useful as a loss to encourage a positive distribution to be more
        uniform. An epsilon is added to the denominator for numerical stability.

        Args:
            x: a 1-D `Tensor`.
        Returns:
            A float32 scalar `Tensor`. It is 0 when ``x`` has fewer than two
            elements (empty, or a single expert), where the variance is
            undefined.
        """
        eps = 1e-10
        x = x.float()
        if x.numel() < 2:
            return x.new_zeros(())
        return x.var() / (x.mean() ** 2 + eps)

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.

        The load is the number of examples for which the corresponding gate
        is > 0.

        Args:
            gates: a `Tensor` of shape [batch_size, n].
        Returns:
            A `Tensor` of shape [n] with the same dtype as ``gates``. This
            matches the float load produced by the noisy branch of
            ``noisy_top_k_gating``.
        """
        return (gates > 0).sum(0).to(gates.dtype)

    def _prob_in_top_k(
        self, clean_values, noisy_values, noise_stddev, noisy_top_values
    ):
        """Helper function to noisy_top_k_gating.

        Computes the probability that each value is in the top k, given
        different random noise. This gives a differentiable way to backpropagate
        from a loss that balances how often each expert is in the top k.

        Args:
            clean_values: a `Tensor` of shape [batch, n].
            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values
                plus normally distributed noise with standard deviation
                noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n].
            noisy_top_values: a `Tensor` of shape [batch, m], the sorted top-m
                values of ``noisy_values`` with m >= k + 1.
        Returns:
            A `Tensor` of shape [batch, n].
        """
        # A value currently in the top k stays there (when only its own noise
        # is resampled) if it beats the (k+1)-th largest value. A value
        # currently outside must beat the k-th largest.
        threshold_if_in = noisy_top_values[:, self.k : self.k + 1]
        threshold_if_out = noisy_top_values[:, self.k - 1 : self.k]

        # Is each value currently in the top k?
        is_in = noisy_values > threshold_if_in

        # Follow the inputs' device/dtype instead of a stored device string.
        normal = Normal(
            self.mean.to(device=clean_values.device, dtype=clean_values.dtype),
            self.std.to(device=clean_values.device, dtype=clean_values.dtype),
        )
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        return torch.where(is_in, prob_if_in, prob_if_out)

    def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2):
        """Noisy top-k gating.

        See paper: https://arxiv.org/abs/1701.06538.

        Args:
            x: input Tensor with shape [batch_size, input_size].
            train: a boolean - noise is only added at training time (and only
                when ``noisy_gating`` is enabled).
            noise_epsilon: a float added to the softplus noise std-dev.
        Returns:
            gates: a Tensor with shape [batch_size, num_experts].
            load: a Tensor with shape [num_experts].
        """
        clean_logits = x @ self.w_gate
        noisy_logits = None
        noise_stddev = None

        if self.noisy_gating and train:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (
                torch.randn_like(clean_logits) * noise_stddev
            )
            logits = noisy_logits
        else:
            logits = clean_logits

        # The (k+1)-th value is needed by _prob_in_top_k as the "stay in" threshold.
        top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1)
        top_k_logits = top_logits[:, : self.k]
        top_k_indices = top_indices[:, : self.k]
        top_k_gates = self.softmax(top_k_logits)

        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if noisy_logits is not None and self.k < self.num_experts:
            load = self._prob_in_top_k(
                clean_logits, noisy_logits, noise_stddev, top_logits
            ).sum(0)
        else:
            load = self._gates_to_load(gates)
        return gates, load

    def forward(
        self,
        x_dict,
        edge_index_dict,
        src,
        training,
        i,
        noise_epsilon,
        loss_coef=1e-2,
    ):
        """Route the graph to the top-k experts and combine their outputs.

        Args:
            x_dict: mapping from node type to feature tensor. ``x_dict[src]``
                is expected to be [number_of_cells, input_size].
            edge_index_dict: passed unchanged to every selected expert.
            src: key of ``x_dict`` whose features drive the gating.
            training: a boolean - whether to use noisy gating.
            i: layer index; not used by this method.
            noise_epsilon: a float added to the gating noise std-dev.
            loss_coef: a scalar - multiplier on the load-balancing loss.

        Returns:
            y: the gate-weighted sum of the selected experts' outputs. It is
                [number_of_cells, output_size] when each expert returns that
                shape (assumed - depends on the experts).
            extra_training_loss: a scalar. Add this to the overall training
                loss; its gradient pushes the experts toward roughly equal use.
        """
        # Use the gating weights' device so this stays correct after .to().
        x = self.normalize(x_dict[src].to(self.w_gate.device))

        # A single routing decision for the whole graph: pool over nodes.
        x_pooled = x.mean(dim=0, keepdim=True)

        gates, load = self.noisy_top_k_gating(x_pooled, training, noise_epsilon)

        if self.k < self.num_experts:
            loss = self.cv_squared(load)
        else:
            loss = self.cv_squared(gates.sum(0))
        loss = loss * loss_coef

        # gates is [1, num_experts]. Drop only the pooled dim, so a single
        # expert still yields a 1-D tensor.
        gates = gates.squeeze(0)
        selected = gates > 0
        expert_ids = selected.nonzero().flatten().tolist()

        # Compute outputs for the selected experts only.
        stacked = torch.stack(
            [self.experts[j](x_dict, edge_index_dict) for j in expert_ids], dim=1
        )
        # Broadcast each expert's weight along its slot in dim 1.
        weights = gates[selected].view((1, -1) + (1,) * (stacked.dim() - 2))
        # The top-k gates are softmax-normalised, so a weighted *sum* gives a
        # convex combination (a mean would shrink the output by 1/k).
        y = (weights * stacked).sum(dim=1)

        return y, loss
