"""
Modified from https://arxiv.org/pdf/2406.08993
"""

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn.resolver import activation_resolver

from .conv import GCNConv, GraphSAGEConv
from .graph import GraphBuilder
from .utils import BinaryConcrete

from typing import Optional, Dict, Any, Union, Tuple, Final
    

class MPNNs(torch.nn.Module):
    """Stack of GCN / GraphSAGE layers run on a graph sampled at every forward pass.

    The adjacency is not an input: ``forward`` asks ``self.graph_builder`` for
    candidate edges and their probabilities, discretises them with
    ``self.bin_conc`` and feeds the resulting sparse adjacency to every
    convolution.

    ``parameters()`` and ``modules()`` are overridden to hide the graph
    builder (and, optionally, the normalisation layers); use
    ``named_parameters()`` / ``named_modules()`` to see everything.
    """

    supports_norm_batch: Final[bool]

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3, 
                 graph_builder_kwargs: Optional[Dict[str, Any]] = None,
                 dropout=0.5, heads=1, pre_ln=False, pre_linear=False, res=False,
                 ignore_norm_params=False, act='relu',
                 act_kwargs: Optional[Dict[str, Any]] = None,
                 ln=False, bn=False, jk=False,
                 bin_conc_kwargs: Optional[Dict[str, Any]] = None,
                 gnn='gcn'):
        """Build the layer stack.

        Args:
            in_channels: Size of the input node features.
            hidden_channels: Width of every hidden layer.
            out_channels: Size of the prediction produced by ``pred_local``.
            num_layers: Number of convolution layers.
            graph_builder_kwargs: Keyword arguments forwarded to ``GraphBuilder``.
            dropout: Dropout probability used in ``forward``.
            heads: Accepted for interface compatibility; unused by this class.
            pre_ln: If true, one ``LayerNorm`` per layer is created in
                ``self.pre_lns``. ``forward`` in this class does not apply them.
            pre_linear: If true, a linear projection ``lin_in`` maps the input
                to ``hidden_channels`` before the first convolution.
            res: If true, each layer adds a linear skip connection.
            ignore_norm_params: If true, ``parameters()`` skips every parameter
                whose name contains ``'norms'``.
            act: Activation name handed to ``activation_resolver``.
            act_kwargs: Keyword arguments for the activation.
            ln: Apply ``LayerNorm`` after each convolution (wins over ``bn``).
            bn: Apply ``BatchNorm1d`` after each convolution when ``ln`` is off.
            jk: If true, the prediction head receives the sum of all layer
                outputs instead of only the last one.
            bin_conc_kwargs: Keyword arguments forwarded to ``BinaryConcrete``.
            gnn: ``'gcn'`` or ``'graphsage'``.

        Raises:
            NotImplementedError: If ``gnn`` is not one of the supported types.
        """
        super(MPNNs, self).__init__()

        if gnn not in ['gcn', 'graphsage']:
            raise NotImplementedError(f"Not implemented for {gnn} GNN type. ")

        self._ignore_norm_params = ignore_norm_params

        # Graph builder
        self.graph_builder = GraphBuilder(**(graph_builder_kwargs or {}))
        # binary concrete for discrete sampling
        self.bin_conc = BinaryConcrete(**(bin_conc_kwargs or {}))

        self.dropout = dropout
        self.pre_ln = pre_ln
        self.pre_linear = pre_linear
        self.res = res
        self.ln = ln
        self.bn = bn
        self.jk = jk

        self.act = activation_resolver(act, **(act_kwargs or {}))
        self.h_lins = torch.nn.ModuleList()
        self.local_convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        self.layer_norms = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        if self.pre_ln:
            self.pre_lns = torch.nn.ModuleList()

        # Input width of every convolution. Without the input projection the
        # first convolution consumes the raw features, so at least one layer
        # is always created in that case.
        if self.pre_linear:
            self.lin_in = torch.nn.Linear(in_channels, hidden_channels)
            layer_inputs = [hidden_channels] * num_layers
        else:
            layer_inputs = [in_channels] + [hidden_channels] * (num_layers - 1)

        conv_cls = GraphSAGEConv if gnn == 'graphsage' else GCNConv
        # Modules are created in the same order for every layer so that the
        # random initialisation is reproducible under a fixed seed.
        for width_in in layer_inputs:
            self.local_convs.append(conv_cls(width_in, hidden_channels))
            if self.res:
                self.lins.append(torch.nn.Linear(width_in, hidden_channels))
            if self.ln:
                self.layer_norms.append(torch.nn.LayerNorm(hidden_channels))
            if self.bn:
                self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))
            if self.pre_ln:
                self.pre_lns.append(torch.nn.LayerNorm(width_in))

        self.pred_local = torch.nn.Linear(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the convolutions, skips, norms, projection and head.

        The graph builder is left untouched.
        """
        groups = [self.local_convs, self.lins, self.layer_norms, self.batch_norms]
        if self.pre_ln:
            groups.append(self.pre_lns)
        for group in groups:
            for module in group:
                module.reset_parameters()
        # ``lin_in`` only exists when the input projection was requested.
        if self.pre_linear:
            self.lin_in.reset_parameters()
        self.pred_local.reset_parameters()

    def get_discrete_graph(self, return_discrete_prob_adj: bool=False) -> \
            Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Sample a discrete, symmetrised sparse adjacency from the graph builder.

        Args:
            return_discrete_prob_adj: If true, also return a sparse tensor with
                the same sparsity pattern that holds the builder's
                probabilities of the sampled edges.

        Returns:
            The coalesced sparse COO adjacency, or a tuple
            ``(adjacency, probability_adjacency)`` when
            ``return_discrete_prob_adj`` is true.
        """
        edge_index, prob_vals = self.graph_builder()
        # discretize adjacency using hard binary concrete sampling
        sampled_vals = self.bin_conc.hard_binary_concrete(prob_vals)
        keep = sampled_vals > 0
        edge_index = edge_index[:, keep]
        sampled_vals = sampled_vals[keep]

        # Symmetrize: every kept edge is listed in both directions.
        # NOTE(review): ``coalesce`` sums duplicate indices, so an edge present
        # in both directions (or a self-loop) ends up with a doubled value --
        # confirm the builder never emits such pairs.
        sym_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        shape = self.graph_builder.shape
        discrete_adj = torch.sparse_coo_tensor(
            sym_index, torch.cat([sampled_vals, sampled_vals]), shape
        ).coalesce()
        if not return_discrete_prob_adj:
            return discrete_adj

        kept_probs = prob_vals[keep]  # probabilities of sampled edges in discrete_adj
        discrete_prob_adj = torch.sparse_coo_tensor(
            sym_index, torch.cat([kept_probs, kept_probs]), shape
        ).coalesce()
        return discrete_adj, discrete_prob_adj

    @property
    def ignore_norm_params(self) -> bool:
        """Whether ``parameters()`` hides parameters whose name contains ``'norms'``."""
        return self._ignore_norm_params

    @ignore_norm_params.setter
    def ignore_norm_params(self, value: bool):
        if value is True:
            print("Ignore normalization parameters")
        self._ignore_norm_params = value

    def parameters(self, recurse: bool = True):
        """
        Overwrite learnable parameters to exclude graph_builder.
        Exclude graph_builder parameters for Hessian calculation.

        When ``ignore_norm_params`` is set, parameters whose name contains
        ``'norms'`` are skipped as well.
        """
        for name, param in self.named_parameters(recurse=recurse):
            if 'graph_builder' in name:
                continue
            if self.ignore_norm_params and 'norms' in name:
                continue
            yield param

    def modules(self):
        """
        Overwrite modules to exclude graph_builder.
        For Hessian calculation with ASDL library.
        """
        for name, module in self.named_modules():
            if 'graph_builder' not in name:
                yield module

    @staticmethod
    def _clear_grad(param, set_to_none: bool) -> None:
        """Reset the gradient of one parameter, detaching it before zeroing."""
        grad = param.grad
        if grad is None:
            return
        if set_to_none:
            param.grad = None
            return
        # Same treatment as torch.nn.Module.zero_grad: cut the gradient out of
        # any autograd graph so the in-place zeroing is not recorded.
        if grad.grad_fn is not None:
            grad.detach_()
        else:
            grad.requires_grad_(False)
        grad.zero_()

    def zero_grad(self, set_to_none: bool = False):
        """Clear the gradients of every parameter, graph builder included.

        The filtered ``parameters()`` is deliberately not used here: with
        ``ignore_norm_params`` enabled it would skip the normalisation layers
        and leave their gradients accumulating.

        Args:
            set_to_none: Drop the gradient tensors instead of zeroing them.
        """
        for name, param in self.named_parameters():
            if 'graph_builder' not in name:
                self._clear_grad(param, set_to_none)
        for param in self.graph_builder.parameters():
            self._clear_grad(param, set_to_none)

    def disable_graph_builder_grad(self):
        """Freeze the graph builder by turning off ``requires_grad`` on its parameters."""
        for p in self.graph_builder.parameters():
            p.requires_grad = False

    def enable_graph_builder_grad(self):
        """Unfreeze the graph builder by turning ``requires_grad`` back on."""
        for p in self.graph_builder.parameters():
            p.requires_grad = True

    def forward(self, x):
        """Sample a graph, propagate ``x`` through the layers and predict.

        The sampled adjacency and its probability counterpart are kept on the
        instance as ``self.discrete_adj`` and ``self.discrete_adj_prob``.

        Args:
            x: Node feature tensor.

        Returns:
            Output of ``pred_local`` applied to the last layer's activations,
            or to the sum over all layers when ``jk`` is enabled.
        """
        self.discrete_adj, self.discrete_adj_prob = \
            self.get_discrete_graph(return_discrete_prob_adj=True)

        if self.pre_linear:
            x = self.lin_in(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x_final = 0

        # NOTE(review): ``self.pre_lns`` is never applied below even when
        # ``pre_ln`` is enabled -- confirm whether that is intended.
        for i, local_conv in enumerate(self.local_convs):
            h = local_conv(x, self.discrete_adj)
            if self.res:
                h = h + self.lins[i](x)
            # LayerNorm takes precedence over BatchNorm when both are enabled.
            if self.ln:
                h = self.layer_norms[i](h)
            elif self.bn:
                h = self.batch_norms[i](h)

            h = self.act(h)
            x = F.dropout(h, p=self.dropout, training=self.training)
            x_final = x_final + x if self.jk else x

        return self.pred_local(x_final)

