import torch
from torch import Tensor
import torch.nn.functional as F
from typing import Callable, List, Union
from torch_geometric.typing import PairTensor
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.nn import (
    GINConv,
    GINEConv,
    MLP,
    DenseGINConv,
    GCNConv,
    TopKPooling,
)
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    to_torch_csr_tensor,
)
from torch_geometric.utils.repeat import repeat
from tgp.poolers import get_pooler
from tgp.src import PoolingOutput


class AutoEncoderModel(torch.nn.Module):
    """
    Autoencoder model with the following architecture:
    [MP encoder] ➡️ Pooling ⬇️ [MP bottleneck] ➡️ Unpooling ⬆️ [MP decoder] ➡️ Readout

    Args:
        in_channels: Size of node features.
        out_channels: Number of classes (output size of the readout MLP).
        edge_channels: Size of edge features. GINE is only used in the
            encoder when this is not None *and* ``use_gine_enc`` is set.
        num_mp_layers: Number of message-passing layers in each block.
        hidden_channels: Dimensionality of node embeddings.
        activation: Activation name, resolved with ``activation_resolver`` and
            also passed to every MLP.
        pooler: Pooling method, forwarded to ``get_pooler``.
        pool_kwargs: Extra keyword arguments for ``get_pooler``; ``None`` is
            treated as "no extra arguments".
        pooled_nodes: Number of nodes after pooling (passed as ``k``).
        use_gine_enc: Use GINE instead of GIN in the encoder.
        use_gine_bottleneck: Use GINE instead of GIN in the bottleneck.
            Ignored when the pooler is dense, because the bottleneck is then
            built from ``DenseGINConv`` layers.
        res_connect: Residual connection around the bottleneck: ``'sum'``,
            ``'cat'``, or anything else for no residual connection.
        dropout: Dropout rate used by the MLPs and the optional final dropout.
        dropout_decoder: Apply dropout to the decoder output before readout.
    """

    def __init__(
        self,
        in_channels,  # Size of node features
        out_channels,  # Number of classes
        edge_channels=None,  # Size of edge features
        num_mp_layers=1,  # Number of MP layers in each block
        hidden_channels=64,  # Dimensionality of node embeddings
        activation="ELU",  # Activation of the MLP in GIN
        pooler=None,  # Pooling method
        pool_kwargs=None,  # Pooling method kwargs
        pooled_nodes=None,  # Number of nodes after pooling
        use_gine_enc=False,  # Use GINE instead of GIN in the Encoder
        use_gine_bottleneck=False,  # Use GINE instead of GIN in the bottleneck
        res_connect=None,  # Residual connections (None, 'sum', 'cat')
        dropout=0.1,  # Dropout rate
        dropout_decoder=False,  # Use dropout in the decoder
    ):
        super().__init__()

        # ``pool_kwargs`` is unpacked with ``**`` below, so the ``None``
        # default must be normalised to an empty mapping first.
        if pool_kwargs is None:
            pool_kwargs = {}

        self.num_classes = out_channels
        self.act = activation_resolver(activation)
        # GINE needs edge features, so it is only enabled when they exist.
        self.using_gine_enc = bool(edge_channels is not None and use_gine_enc)
        self.res_connect = res_connect
        self.dropout = dropout
        self.dropout_decoder = dropout_decoder

        ### Encoder MP block
        # NOTE(review): the encoder MLPs keep the MLP's default ``norm`` while
        # the other blocks pass ``norm=None`` -- confirm this is intended.
        self.encoder_mp_layers = torch.nn.ModuleList()
        for _ in range(num_mp_layers):
            mlp = MLP(
                [in_channels, hidden_channels, hidden_channels],
                act=activation,
                dropout=dropout,
            )
            if self.using_gine_enc:
                self.encoder_mp_layers.append(
                    GINEConv(nn=mlp, train_eps=False, edge_dim=edge_channels)
                )
            else:
                self.encoder_mp_layers.append(GINConv(nn=mlp, train_eps=False))
            # Only the first layer consumes the raw feature size.
            in_channels = hidden_channels

        self.pooler = get_pooler(
            pooler, in_channels=in_channels, k=pooled_nodes, **pool_kwargs
        )

        # The bottleneck of a dense pooler is made of DenseGINConv layers,
        # which are called without edge attributes. Keep the flag consistent
        # with the layers actually built so ``forward`` picks the right call.
        self.using_gine_bottleneck = bool(
            use_gine_bottleneck and not self.pooler.is_dense
        )

        ### Bottleneck MP block
        self.bottleneck_mp_layers = torch.nn.ModuleList()
        for _ in range(num_mp_layers):
            mlp = MLP(
                [hidden_channels, hidden_channels, hidden_channels],
                act=activation,
                norm=None,
                dropout=dropout,
            )
            if self.pooler.is_dense:
                self.bottleneck_mp_layers.append(DenseGINConv(nn=mlp, train_eps=False))
            elif self.using_gine_bottleneck:
                # Pooled edge weights are fed as 1-dimensional edge features.
                self.bottleneck_mp_layers.append(
                    GINEConv(nn=mlp, train_eps=False, edge_dim=1)
                )
            else:
                self.bottleneck_mp_layers.append(GINConv(nn=mlp, train_eps=False))

        ### Decoder MP block
        self.decoder_mp_layers = torch.nn.ModuleList()
        # With 'cat' the lifted and encoder features are concatenated, so the
        # first decoder layer sees twice the hidden size.
        in_channels = 2 * hidden_channels if res_connect == "cat" else hidden_channels
        for _ in range(num_mp_layers):
            mlp = MLP(
                [in_channels, hidden_channels, hidden_channels],
                act=activation,
                norm=None,
                dropout=dropout,
            )
            if self.pooler.is_dense:
                self.decoder_mp_layers.append(DenseGINConv(nn=mlp, train_eps=False))
            else:
                self.decoder_mp_layers.append(GINConv(nn=mlp, train_eps=False))
            in_channels = hidden_channels

        ### Readout
        self.mlp = MLP(
            [hidden_channels, out_channels], act=activation, norm=None, dropout=dropout
        )

    def forward(self, data):
        """
        Encode, pool, process the pooled graph, lift back and read out.

        Args:
            data: Object exposing ``x``, ``edge_index``, ``edge_attr`` and
                ``batch`` attributes.

        Returns:
            Tuple ``(x_out, pool_out)``: the squeezed readout of the decoder
            features and the raw output of the pooling operator.
        """
        x = data.x
        adj = data.edge_index
        batch = data.batch

        # Unit edge weights, allocated directly on the graph's device.
        ew = torch.ones(adj.size(1), device=adj.device)

        # Fall back to the unit weights when no edge attributes are provided.
        ea = data.edge_attr
        if ea is None:
            ea = ew

        ### Encoder MP block ➡️
        for layer in self.encoder_mp_layers:
            if self.using_gine_enc:
                x = self.act(layer(x, adj, edge_attr=ea))
            else:
                x = self.act(layer(x, adj))

        ### Pooling block ⬇️
        x, adj, mask = self.pooler.preprocessing(
            x=x,
            edge_index=adj,
            edge_weight=ew,
            batch=batch,
            use_cache=True,
        )

        pool_out = self.pooler(x=x, adj=adj, edge_weight=ew, batch=batch, mask=mask)
        x_pool, adj_pool = pool_out.x, pool_out.edge_index

        ### Bottleneck MP block ➡️
        # The pooled edge attributes do not change between layers, so they
        # are prepared once instead of on every iteration.
        ea_pool = None
        if self.using_gine_bottleneck:
            # Use pooled edge weights if available, otherwise create new ones
            ea_pool = getattr(pool_out, "edge_weight", None)
            if ea_pool is None:
                ea_pool = torch.ones(adj_pool.size(1), device=adj_pool.device)
            ea_pool = ea_pool.unsqueeze(-1)

        for layer in self.bottleneck_mp_layers:
            if self.using_gine_bottleneck:
                x_pool = self.act(layer(x_pool, adj_pool, edge_attr=ea_pool))
            else:
                x_pool = self.act(layer(x_pool, adj_pool))

        ### Lifting (unpooling) ⬆️
        x_lift = self.pooler(x=x_pool, so=pool_out.so, lifting=True)

        ### Residual connection ⏭️
        if self.res_connect == "sum":
            x_lift = x_lift + x
        elif self.res_connect == "cat":
            x_lift = torch.cat([x_lift, x], dim=-1)

        ### Decoder MP block ➡️
        # The decoder runs on the (preprocessed) pre-pooling connectivity.
        x_dec = x_lift
        for layer in self.decoder_mp_layers:
            x_dec = self.act(layer(x_dec, adj))

        ### Dropout
        if self.dropout_decoder:
            x_dec = F.dropout(x_dec, p=self.dropout, training=self.training)

        ### Readout ▶️
        x_out = self.mlp(x_dec).squeeze()

        return x_out, pool_out


class GCNModel(torch.nn.Module):
    """
    The standard GCN model for node classification.
    To be used as a baseline.

    Two ``GCNConv`` layers with dropout (p=0.5) before each of them and a
    ReLU in between.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution; also stored as
            ``num_classes``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(
            in_channels,
            hidden_channels,
        )
        self.conv2 = GCNConv(
            hidden_channels,
            out_channels,
        )
        self.num_classes = out_channels

    def forward(self, data):
        """
        Apply the two-layer GCN to ``data``.

        Args:
            data: Object exposing ``x`` and ``edge_index``; ``edge_attr`` is
                optional and, when present, is passed to both convolutions as
                the edge weight.

        Returns:
            A 6-tuple ``(x, adj_pool, edge_weight, batch, aux_loss, s)``.
            This model does no pooling, so ``adj_pool``, ``batch`` and ``s``
            are always ``None`` and ``aux_loss`` is always ``0``.
        """
        x = data.x
        edge_index = data.edge_index
        edge_weight = getattr(data, "edge_attr", None)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv1(x, edge_index, edge_weight).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)

        # Pooling-related slots are fixed placeholders for this baseline.
        adj_pool = None
        batch = None
        aux_loss = 0
        s = None

        return x, adj_pool, edge_weight, batch, aux_loss, s


class GraphUNet(torch.nn.Module):
    r"""The Graph U-Net model from the `"Graph U-Nets"
    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like
    architecture with graph pooling and unpooling operations.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden sample.
        out_channels (int): Size of each output sample.
        depth (int): The depth of the U-Net architecture (at least 1).
        pool_ratios (float or [float], optional): Graph pooling ratio for each
            depth. (default: :obj:`0.5`)
        res_connect (str, optional): How skip connections are merged with the
            unpooled features: :obj:`"sum"` adds them, :obj:`"cat"`
            concatenates them, and any other value (including :obj:`None`)
            uses the unpooled features alone. (default: :obj:`None`)
        act (str or Callable, optional): The nonlinearity to use, resolved
            with :obj:`activation_resolver`. (default: :obj:`"relu"`)
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        depth: int,
        pool_ratios: Union[float, List[float]] = 0.5,
        res_connect=None,  # Residual connections (None, 'sum', 'cat')
        act: Union[str, Callable] = "relu",
    ):
        super().__init__()
        assert depth >= 1, "depth must be at least 1"
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.num_classes = out_channels
        self.depth = depth
        self.pool_ratios = repeat(pool_ratios, depth)
        self.act = activation_resolver(act)
        self.res_connect = res_connect

        channels = hidden_channels

        # Contracting path: one input convolution, then (pool, conv) per level.
        self.down_convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.down_convs.append(GCNConv(in_channels, channels, improved=True))
        for i in range(depth):
            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))
            self.down_convs.append(GCNConv(channels, channels, improved=True))

        # Concatenated skip connections double the input of every up conv.
        in_channels = 2 * channels if res_connect == "cat" else channels

        # Expanding path: the last convolution maps to the output size.
        self.up_convs = torch.nn.ModuleList()
        for i in range(depth - 1):
            self.up_convs.append(GCNConv(in_channels, channels, improved=True))
        self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        for conv in self.down_convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, data) -> tuple:
        """Run the U-Net on ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``. A
                ``None`` batch is replaced by an all-zero assignment vector.

        Returns:
            A 6-tuple ``(x, adj_pool, edge_weight, batch, aux_loss, s)``:
            ``x`` is the output of the last up convolution, ``edge_weight``
            and ``batch`` are the values those local variables hold when the
            loops finish, ``adj_pool`` and ``s`` are always ``None`` and
            ``aux_loss`` is always ``0``.
        """
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        edge_weight = x.new_ones(edge_index.size(1))

        x = self.down_convs[0](x, edge_index, edge_weight)
        x = self.act(x)

        # Per-level state kept for the expanding path.
        xs = [x]
        edge_indices = [edge_index]
        edge_weights = [edge_weight]
        perms = []

        for i in range(1, self.depth + 1):
            edge_index, edge_weight = self.augment_adj(
                edge_index, edge_weight, x.size(0)
            )
            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](
                x, edge_index, edge_weight, batch
            )

            x = self.down_convs[i](x, edge_index, edge_weight)
            x = self.act(x)

            # The deepest level has no matching up step, so it is not stored.
            if i < self.depth:
                xs.append(x)
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
            perms.append(perm)

        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            # Unpool: scatter the coarse features back to the kept nodes,
            # leaving zeros for the nodes that were dropped.
            up = torch.zeros_like(res)
            up[perm] = x

            # Residual connection
            if self.res_connect == "sum":
                x = res + up
            elif self.res_connect == "cat":
                x = torch.cat((res, up), dim=-1)
            else:
                x = up

            x = self.up_convs[i](x, edge_index, edge_weight)
            # No activation after the final (output) convolution.
            x = self.act(x) if i < self.depth - 1 else x

        # Slots this model never produces are fixed placeholders.
        adj_pool = None
        aux_loss = 0
        s = None

        return x, adj_pool, edge_weight, batch, aux_loss, s

    def augment_adj(
        self, edge_index: Tensor, edge_weight: Tensor, num_nodes: int
    ) -> "PairTensor":
        """Return the edges of the squared adjacency, without self-loops.

        Existing self-loops are removed and fresh ones added before the
        adjacency is multiplied with itself; self-loops are stripped again
        from the result.
        """
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, edge_weight = add_self_loops(
            edge_index, edge_weight, num_nodes=num_nodes
        )
        adj = to_torch_csr_tensor(edge_index, edge_weight, size=(num_nodes, num_nodes))
        adj = (adj @ adj).to_sparse_coo()
        edge_index, edge_weight = adj.indices(), adj.values()
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}({self.in_channels}, "
            f"{self.hidden_channels}, {self.out_channels}, "
            f"depth={self.depth}, pool_ratios={self.pool_ratios})"
        )
