import copy
import inspect

import torch.nn as nn
import torch

from typing import (
    Union, Optional, Dict, List,
    Any, Callable, Final, Tuple
)
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.nn.models import JumpingKnowledge

from .conv import GCNConv, GraphSAGEConv
from .graph import GraphBuilder
from .utils import BinaryConcrete


class BasicGNN(nn.Module):
    """
    Multi-layer GNN that runs on a graph sampled at every forward pass.

    Based on pytorch-geometric BasicGNN and GCN. Instead of receiving an
    ``edge_index``, :meth:`forward` asks ``self.graph_builder`` for candidate
    edges and their probabilities, discretizes them with
    ``BinaryConcrete.hard_binary_concrete`` (see :meth:`get_discrete_graph`)
    and feeds the resulting symmetric sparse adjacency to every convolution.
    Subclasses choose the convolution type by implementing :meth:`init_conv`.
    """
    supports_norm_batch: Final[bool]

    def __init__(self,
                 in_channels: Union[int, Tuple[int, int]],
                 hidden_channels: int,
                 out_channels: int,
                 num_layers: int,
                 graph_builder_kwargs: Optional[Dict[str, Any]] = None,
                 dropout: float = 0.,
                 act: Union[str, Callable, None] = "relu",
                 act_kwargs: Optional[Dict[str, Any]] = None,
                 norm: Optional[str] = None,
                 norm_kwargs: Optional[Dict[str, Any]] = None,
                 jk: Optional[str] = None,
                 res: bool = False,
                 ignore_norm_params: bool = False,
                 bin_conc_kwargs: Optional[Dict[str, Any]] = None,
                 **kwargs,
                 ):
        """
        Parameters
        ----------
        in_channels : int or tuple of int
            Number of input features. A tuple/list is propagated as
            ``(hidden_channels, hidden_channels)`` to later layers; note that
            ``res=True`` builds ``nn.Linear(in_channels, ...)`` and therefore
            presumably needs an int here.
        hidden_channels : int
            Number of hidden features.
        out_channels : int
            Number of output features.
        num_layers : int
            Number of convolution layers.
        graph_builder_kwargs : dict or None
            Keyword arguments forwarded to ``GraphBuilder``.
        dropout : float
            Dropout probability applied after each normalized/activated layer.
        act : str or callable or None
            Activation function, resolved with ``activation_resolver``.
            Default is 'relu'.
        act_kwargs : dict or None
            Additional arguments for the activation function.
        norm : str or None
            Normalization type: None, "layer" or "batch". Default is None.
        norm_kwargs : dict or None
            Additional arguments for ``normalization_resolver``.
        jk : str or None
            Jumping knowledge mode: None, "last", "cat", "max" or "lstm".
            When set, the last convolution outputs ``hidden_channels`` and a
            final ``nn.Linear`` maps to ``out_channels``. Default is None.
        res : bool
            If True, every layer computes ``conv(x) + nn.Linear(x)``.
        ignore_norm_params : bool
            Exclude normalization parameters from :meth:`parameters`
            (used for Hessian calculation).
        bin_conc_kwargs : dict or None
            Keyword arguments forwarded to ``BinaryConcrete``.
        **kwargs
            Forwarded to :meth:`init_conv` for every convolution.
        """
        super().__init__()

        assert norm in [None, "layer", "batch"], \
            f"Unknown normalization type: {norm}"
        assert jk in [None, "last", "cat", "max", "lstm"], \
            f"Unknown jumping knowledge type: {jk}"

        self._ignore_norm_params = ignore_norm_params

        # Graph builder: called in forward() to produce candidate edges and
        # their probabilities.
        self.graph_builder = GraphBuilder(**(graph_builder_kwargs or {}))

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.output_size = self.out_channels = out_channels
        self.num_layers = num_layers

        self.dropout = nn.Dropout(p=dropout)
        self.act = activation_resolver(act, **(act_kwargs or {}))

        # Most recently sampled graph and its edge probabilities; both are
        # overwritten on every forward() call.
        self.discrete_adj = None
        self.discrete_adj_prob = None

        # Convolutions (and optional residual projections). The first
        # num_layers - 1 map to hidden_channels; the last one maps to
        # out_channels, or to hidden_channels when a JK head + Linear follows.
        self.res = res
        self.res_lins = nn.ModuleList()
        self.convs = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(
                self.init_conv(in_channels, hidden_channels, **kwargs))
            if res:
                self.res_lins.append(nn.Linear(in_channels, hidden_channels))
            if isinstance(in_channels, (tuple, list)):
                in_channels = (hidden_channels, hidden_channels)
            else:
                in_channels = hidden_channels
        self._is_conv_to_out = jk is None
        last_channels = out_channels if jk is None else hidden_channels
        self.convs.append(
            self.init_conv(in_channels, last_channels, **kwargs))
        if res:
            self.res_lins.append(nn.Linear(in_channels, last_channels))

        # Normalization: one independent copy per layer. The last layer gets
        # a real norm only when a JK head follows; otherwise its output is
        # the model output and is left untouched (Identity).
        self.norms = nn.ModuleList()
        norm_layer = normalization_resolver(
            norm,
            hidden_channels,
            **(norm_kwargs or {}),
        )
        if norm_layer is None:
            norm_layer = nn.Identity()

        # Binary concrete used for discrete edge sampling.
        self.bin_conc = BinaryConcrete(**(bin_conc_kwargs or {}))

        # Record whether the norm's forward accepts a `batch` argument
        # (forward() below does not pass one).
        self.supports_norm_batch = False
        if hasattr(norm_layer, 'forward'):
            norm_params = inspect.signature(norm_layer.forward).parameters
            self.supports_norm_batch = 'batch' in norm_params

        for _ in range(num_layers - 1):
            self.norms.append(copy.deepcopy(norm_layer))
        if jk is not None:
            self.norms.append(copy.deepcopy(norm_layer))
        else:
            self.norms.append(nn.Identity())

        # Jumping knowledge: 'last' needs no aggregation module, only the
        # final Linear below.
        self.jk_mode = jk
        if jk is not None and jk != 'last':
            self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)

        if jk is not None:
            lin_in = (num_layers * hidden_channels if jk == 'cat'
                      else hidden_channels)
            self.lin = nn.Linear(lin_in, self.out_channels)  # last layer

    def reset_parameters(self):
        """
        Reset learnable parameters of the convolutions, residual projections,
        norms, JK head and output Linear.

        ``graph_builder`` and ``bin_conc`` are not reset here.
        """
        for conv in self.convs:
            conv.reset_parameters()
        for res_lin in self.res_lins:
            res_lin.reset_parameters()
        for norm in self.norms:
            if hasattr(norm, 'reset_parameters'):
                norm.reset_parameters()
        if hasattr(self, 'jk'):
            self.jk.reset_parameters()
        if hasattr(self, 'lin'):
            self.lin.reset_parameters()

    def get_discrete_graph(self, return_discrete_prob_adj: bool = False) -> \
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Sample a discrete, symmetric sparse adjacency from the graph builder.

        Edges whose hard binary-concrete sample is > 0 are kept; each kept
        edge is added in both directions and the result is coalesced. Note
        that ``coalesce()`` sums duplicate indices, so if the builder returns
        a self-loop or both directions of an edge, the stored value is the
        sum of the duplicates.

        Parameters
        ----------
        return_discrete_prob_adj : bool
            If True, also return an adjacency with the same sparsity pattern
            holding the builder's probabilities of the kept edges.

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            ``discrete_adj`` or ``(discrete_adj, discrete_prob_adj)``; both
            are coalesced sparse COO tensors of shape
            ``self.graph_builder.shape``.
        """
        edge_index, prob_vals = self.graph_builder()
        # Discretize adjacency using hard binary concrete sampling.
        sampled_vals = self.bin_conc.hard_binary_concrete(prob_vals)
        keep = sampled_vals > 0
        edge_index = edge_index[:, keep]
        # Symmetrize: every kept edge is also added reversed.
        sym_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        shape = self.graph_builder.shape

        def _symmetric_adj(values: torch.Tensor) -> torch.Tensor:
            # Build a coalesced sparse matrix over sym_index.
            return torch.sparse_coo_tensor(
                sym_index, torch.cat([values, values]), shape
            ).coalesce()

        discrete_adj = _symmetric_adj(sampled_vals[keep])
        if not return_discrete_prob_adj:
            return discrete_adj
        # Probabilities of exactly the edges sampled into discrete_adj.
        discrete_prob_adj = _symmetric_adj(prob_vals[keep])
        return discrete_adj, discrete_prob_adj

    @property
    def ignore_norm_params(self) -> bool:
        """Whether :meth:`parameters` skips normalization parameters."""
        return self._ignore_norm_params

    @ignore_norm_params.setter
    def ignore_norm_params(self, value: bool):
        if value is True:
            print("Ignore normalization parameters")
        self._ignore_norm_params = value

    def parameters(self, recurse: bool = True):
        """
        Overwrite learnable parameters to exclude graph_builder.

        Exclude graph_builder parameters (and, when ``ignore_norm_params`` is
        set, parameters whose name contains 'norms') for Hessian calculation.
        """
        for name, param in self.named_parameters(recurse=recurse):
            if self.ignore_norm_params and 'norms' in name:
                continue
            if 'graph_builder' not in name:
                yield param

    def modules(self):
        """
        Overwrite modules to exclude graph_builder.

        Yields every module from ``named_modules()`` (including ``self``)
        whose qualified name does not contain 'graph_builder'. Used for
        Hessian calculation with the ASDL library.
        """
        for name, module in self.named_modules():
            if 'graph_builder' not in name:
                yield module

    def zero_grad(self, set_to_none: bool = False) -> None:
        """
        Reset gradients of *all* parameters, including graph_builder and norms.

        Iterates ``named_parameters()`` instead of the overridden
        :meth:`parameters`. That method hides graph_builder parameters and,
        when ``ignore_norm_params`` is set, normalization parameters. Those
        gradients would otherwise never be cleared and would keep
        accumulating.

        Parameters
        ----------
        set_to_none : bool
            If True, drop gradients (``p.grad = None``) instead of zeroing.
        """
        for _, param in self.named_parameters():
            if param.grad is None:
                continue
            if set_to_none:
                param.grad = None
                continue
            # Mirror torch.nn.Module.zero_grad: detach the gradient from any
            # autograd graph (e.g. after create_graph=True) before zeroing it
            # in place.
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()

    def disable_graph_builder_grad(self):
        """Freeze graph_builder parameters (``requires_grad = False``)."""
        for p in self.graph_builder.parameters():
            p.requires_grad = False

    def enable_graph_builder_grad(self):
        """Unfreeze graph_builder parameters (``requires_grad = True``)."""
        for p in self.graph_builder.parameters():
            p.requires_grad = True

    def init_conv(self, in_channels: Union[int, Tuple[int, int]],
                  out_channels: int, **kwargs) -> nn.Module:
        """Create one convolution layer; must be implemented by subclasses."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Sample a graph and run all convolution layers on it.

        Parameters
        ----------
        x : torch.Tensor
            Node features; presumably ``(num_nodes, in_channels)`` matching
            ``self.graph_builder.shape`` — TODO confirm.

        Returns
        -------
        torch.Tensor
            Output of the last convolution, or of the JK head + Linear when
            ``jk`` is set.
        """
        assert len(self.convs) == len(self.norms)

        # Fresh sample every call; stored on the module for later inspection.
        self.discrete_adj, self.discrete_adj_prob = \
            self.get_discrete_graph(return_discrete_prob_adj=True)

        xs: List[torch.Tensor] = []
        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):
            if self.res:
                x = conv(x, self.discrete_adj) + self.res_lins[i](x)
            else:
                x = conv(x, self.discrete_adj)
            # The final conv output is left raw unless a JK head follows.
            if i < self.num_layers - 1 or self.jk_mode is not None:
                x = norm(x)
                x = self.act(x)
                x = self.dropout(x)
                if hasattr(self, 'jk'):
                    xs.append(x)

        x = self.jk(xs) if hasattr(self, 'jk') else x
        x = self.lin(x) if hasattr(self, 'lin') else x
        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers})')


class GCN(BasicGNN):
    """
    BasicGNN whose message-passing layers are ``GCNConv`` convolutions.

    Reference: `"Semi-supervised Classification with Graph Convolutional
    Networks" <https://arxiv.org/abs/1609.02907>`_.
    """
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels, out_channels, **kwargs) -> nn.Module:
        """Build one ``GCNConv`` layer, forwarding any extra keyword args."""
        layer = GCNConv(in_channels, out_channels, **kwargs)
        return layer

class GraphSAGE(BasicGNN):
    """
    BasicGNN whose message-passing layers are ``GraphSAGEConv`` convolutions.

    Reference: `"Inductive Representation Learning on Large Graphs"
    <https://arxiv.org/abs/1706.02216>`_.
    """
    supports_norm_batch: Final[bool]

    def init_conv(self, in_channels, out_channels, **kwargs) -> nn.Module:
        """Build one ``GraphSAGEConv`` layer, forwarding any extra keyword args."""
        layer = GraphSAGEConv(in_channels, out_channels, **kwargs)
        return layer