from abc import ABC, abstractmethod

import torch
from torch import nn

from typing import List
from torch._vmap_internals import vmap
from ...utils.torch_utils import generate_fully_connected


class ContractiveInvertibleGNN(nn.Module):
    """
    Additive-noise SEM X = f(X; W_adj) + Z with a learnable adjacency matrix W and a function f
    selected by `mode_f_sem`.

    Given x, we can easily compute the exog noise z that generates it: z = x - f(x; W_adj)
    (see `invert_GNN`).
    """

    # Modes accepted by `mode_f_sem`; listed here so an unknown mode produces an informative error.
    _SUPPORTED_MODES = ("linear", "lrelu", "gnn_i")

    def __init__(
        self, input_dim: int, device: torch.device, mode_f_sem: str = "linear",
    ):
        """
        Args:
            input_dim: Number of nodes.
            device: Device used.
            mode_f_sem: Mode used for function. Admits {"linear", "lrelu", "gnn_i"}. The first one
                        is just a linear function, the second leaky relu. The third described in pdf.

        Raises:
            NotImplementedError: If `mode_f_sem` is not one of the admitted modes.
        """
        super().__init__()
        self.input_dim = input_dim
        self._device = device
        self.mode_f_sem = mode_f_sem
        self.W = self._initialize_W()
        self.f = self._initialize_function()

    def _initialize_function(self):
        """
        Builds the function f selected by `self.mode_f_sem`.

        Returns:
            A FunctionSEM subclass instance exposing `feed_forward(X, W_adj)`.

        Raises:
            NotImplementedError: If `self.mode_f_sem` is not a supported mode.
        """
        if self.mode_f_sem == "linear":
            return FLinear(self.input_dim, self._device)
        if self.mode_f_sem == "lrelu":
            return FLRelu(self.input_dim, self._device)
        if self.mode_f_sem == "gnn_i":
            return FGNNI(self.input_dim, self._device)
        raise NotImplementedError(
            f"Unsupported mode_f_sem {self.mode_f_sem!r}; expected one of {self._SUPPORTED_MODES}."
        )

    def _initialize_W(self) -> torch.Tensor:
        """
        Creates and initializes the weight matrix for adjacency.

        Returns:
            Trainable parameter of size (input_dim, input_dim) initialized with zeros.

        NOTE(review): whether zero initialization is the right choice is still an open question.
        """
        W = torch.zeros(self.input_dim, self.input_dim, device=self._device)
        return nn.Parameter(W, requires_grad=True)

    def get_weighted_adjacency(self) -> torch.Tensor:
        """
        Returns the weights of the adjacency matrix with the diagonal zeroed out,
        so that no node is its own parent (no self-loops).

        Returns:
            Tensor of size (input_dim, input_dim).
        """
        # Multiplying by (1 - I) masks the diagonal while keeping gradients for off-diagonal entries.
        off_diagonal_mask = 1.0 - torch.eye(self.input_dim, device=self._device)
        return self.W * off_diagonal_mask  # Shape (input_dim, input_dim)

    def invert_GNN(self, X: torch.Tensor, W_adj: torch.Tensor) -> torch.Tensor:
        """
        Given the output X of the fixed point equation, computes the exogenous noise Z that
        generates it, i.e. Z = X - f(X; W_adj).

        Args:
            X: Output of the GNN after reaching a fixed point, batched. Array of size (batch_size, input_dim).
            W_adj: Weighted adjacency matrix, possibly normalized.

        Returns:
            Z: Exogenous noise vector, batched, of size (batch_size, input_dim) that generated the output X.
        """
        aux = self.f.feed_forward(X, W_adj)  # Shape (batch_size, input_dim)
        return X - aux  # Shape (batch_size, input_dim)


class FunctionSEM(ABC, nn.Module):
    """
    Abstract base for the (possibly nonlinear) function f of the additive noise SEM.

    Concrete subclasses must implement `feed_forward`. A subclass that calls
    `initialize_embeddings` must set `self.embedding_size` first; this base class
    never defines that attribute.
    """

    def __init__(
        self, input_dim: int, device: torch.device,
    ):
        """
        Args:
            input_dim: Number of nodes.
            device: Device on which parameters are created.
        """
        super().__init__()
        self._device = device
        self.input_dim = input_dim

    @abstractmethod
    def feed_forward(self, X: torch.Tensor, W_adj: torch.Tensor) -> torch.Tensor:
        """
        Evaluates f(X) for a batch of inputs under the given weighted adjacency matrix.

        Args:
            X: Batched inputs, size (B, n).
            W_adj: Weighted adjacency matrix, size (n, n).
        """
        raise NotImplementedError()

    def initialize_embeddings(self) -> torch.Tensor:
        """
        Creates learnable node embeddings with small random entries.

        Returns:
            nn.Parameter of shape (input_dim, embedding_size) holding standard-normal
            samples scaled by 0.01.
        """
        embedding_shape = (self.input_dim, self.embedding_size)  # (N, E)
        small_noise = 0.01 * torch.randn(*embedding_shape, device=self._device)
        return nn.Parameter(small_noise, requires_grad=True)


class FLinear(FunctionSEM):
    """
    Defines the function f for the linear SEM: each sample x is mapped to W_adj^T x.
    """

    def __init__(
        self, input_dim: int, device: torch.device,
    ):
        """
        Args:
            input_dim: Number of nodes.
            device: Device used.
        """
        super().__init__(input_dim, device)

    def feed_forward(self, X: torch.Tensor, W_adj: torch.Tensor) -> torch.Tensor:
        """
        Computes linear function W^T X using the given weighted adjacency matrix.

        For every sample x (a row of X) the result is W_adj^T x, so entry j equals
        sum_i W_adj[i, j] * x_i.

        Args:
            X: Batched inputs, size (B, n). Extra leading batch dimensions, or a single
               unbatched sample of size (n,), are also accepted.
            W_adj: Weighted adjacency matrix, size (n, n).

        Returns:
            Tensor with the same shape as X.
        """
        # Row-wise W_adj^T x is exactly the product X @ W_adj. matmul replaces the legacy private
        # `torch._vmap_internals.vmap` over `torch.mv` and broadcasts over any batch dimensions.
        return torch.matmul(X, W_adj)  # Shape (batch_size, input_dim)


class FLRelu(FunctionSEM):
    """
    Defines the function f as a leaky relu for the SEM: each sample x is mapped to
    leaky_relu(W_adj^T x) with negative slope 0.1.
    """

    def __init__(
        self, input_dim: int, device: torch.device,
    ):
        """
        Args:
            input_dim: Number of nodes.
            device: Device used.
        """
        super().__init__(input_dim, device)
        self.f = nn.LeakyReLU(0.1)

    def feed_forward(self, X: torch.Tensor, W_adj: torch.Tensor) -> torch.Tensor:
        """
        Computes non-linear function lrelu(W^T X) using the given weighted adjacency matrix.

        Args:
            X: Batched inputs, size (B, n). Extra leading batch dimensions, or a single
               unbatched sample of size (n,), are also accepted.
            W_adj: Weighted adjacency matrix, size (n, n).

        Returns:
            Tensor with the same shape as X.
        """
        # Row-wise W_adj^T x is exactly X @ W_adj. matmul replaces the legacy private
        # `torch._vmap_internals.vmap` over `torch.mv` and broadcasts over batch dimensions.
        pre_activation = torch.matmul(X, W_adj)  # Shape (batch_size, input_dim)
        return self.f(pre_activation)


class FGNNI(FunctionSEM):
    """
    Defines the function f for the SEM. For each variable x_i we use
    f_i(x) = f(e_i, sum_{k in pa(i)} g(e_k, x_k)), where e_i is a learned embedding
    for node i.
    """

    def __init__(
        self,
        input_dim: int,
        device: torch.device,
        embedding_size: int = None,
        out_dim_g: int = None,
        layers_g: List[int] = None,
        layers_f: List[int] = None,
    ):
        """
        Args:
            input_dim: Number of nodes.
            device: Device used.
            embedding_size: Size of the embeddings used by each node. If None (or 0), default is input_dim.
            out_dim_g: Output dimension of the "inner" NN, g. If None (or 0), default is embedding size.
            layers_g: Sizes of the hidden layers of NN g. Does not include input nor output dim. If None
                      (or an empty list), default is [a], with a = max(2 * input_dim, embedding_size, 10).
            layers_f: Sizes of the hidden layers of NN f. Does not include input nor output dim. If None
                      (or an empty list), default is [a], with a = max(2 * input_dim, embedding_size, 10).
        """
        super().__init__(input_dim, device)
        # Initialize embeddings
        self.embedding_size = embedding_size or self.input_dim
        self.embeddings = self.initialize_embeddings()  # Shape (input_dim, embedding_size)
        # Set value for out_dim_g
        out_dim_g = out_dim_g or self.embedding_size
        # Set NNs sizes
        a = max(2 * self.input_dim, self.embedding_size, 10)
        layers_g = layers_g or [a]
        layers_f = layers_f or [a]
        # g sees one scalar x_k plus the node embedding e_k
        in_dim_g = self.embedding_size + 1
        # f sees the aggregated parent messages plus the node embedding e_i
        in_dim_f = self.embedding_size + out_dim_g
        self.g = generate_fully_connected(
            input_dim=in_dim_g,
            output_dim=out_dim_g,
            hidden_dims=layers_g,
            non_linearity=nn.LeakyReLU,
            activation=nn.Tanh,
            device=self._device,
        )
        self.f = generate_fully_connected(
            input_dim=in_dim_f,
            output_dim=1,
            hidden_dims=layers_f,
            non_linearity=nn.LeakyReLU,
            activation=nn.Identity,
            device=self._device,
        )

    def feed_forward(self, X: torch.Tensor, W_adj: torch.Tensor) -> torch.Tensor:
        """
        Computes non-linear function f(X, W) using the given weighted adjacency matrix.

        Args:
            X: Batched inputs, size (batch_size, input_dim).
            W_adj: Weighted adjacency matrix, size (input_dim, input_dim).

        Returns:
            Tensor of size (batch_size, input_dim) holding f_i(x) for every node i and sample x.
        """

        def sum_aggr(Wt: torch.Tensor, X_emb: torch.Tensor) -> torch.Tensor:
            """
            Applies aggregation operation sum with parents.

            Args:
                Wt: Transpose of adjacency W, shape (input_dim, input_dim).
                X_emb: Batched embedded data, shape (batch_size, input_dim, out_dim_g).

            Returns:
                Sum aggregation following adjacency W, of size (batch_size, input_dim, out_dim_g);
                row i is sum_k W[k, i] * X_emb[:, k].
            """
            # Use the matrix actually passed in (previously it was ignored in favour of the
            # enclosing W_adj). matmul broadcasts the 2-D matrix over the batch dimension.
            return torch.matmul(Wt, X_emb)  # Shape (batch_size, input_dim, out_dim_g)

        # g takes inputs of size (*, embedding_size+1) and outputs (*, out_dim_g)
        # f takes inputs of size (*, embedding_size+out_dim_g) and outputs (*, 1)
        # Generate required input for g (concatenate X and embeddings)
        batch_size = X.shape[0]
        X = X.unsqueeze(-1)  # Shape (batch_size, input_dim, 1)
        # The embeddings are shared across the batch; expand gives a broadcast view instead of a copy.
        E = self.embeddings.unsqueeze(0).expand(batch_size, -1, -1)  # Shape (batch_size, input_dim, embedding_size)
        X_in_g = torch.cat([X, E], dim=2)  # Shape (batch_size, input_dim, embedding_size+1)
        X_emb = self.g(X_in_g)  # Shape (batch_size, input_dim, out_dim_g)
        # Aggregate sum and generate input for f (concatenate X_aggr and embeddings)
        X_aggr_sum = sum_aggr(W_adj.t(), X_emb)  # Shape (batch_size, input_dim, out_dim_g)
        X_in_f = torch.cat([X_aggr_sum, E], dim=2)  # Shape (batch_size, input_dim, out_dim_g+embedding_size)
        # Run f
        X_rec = self.f(X_in_f)  # Shape (batch_size, input_dim, 1)
        return X_rec.squeeze(-1)  # Shape (batch_size, input_dim)
