import torch
import torch_geometric
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.models.layer import (new_layer_config,
                                                   BatchNorm1dNode)
from torch_geometric.graphgym.register import register_network
import numpy as np 
from torch.nn import Tanh
from  torch_geometric.nn.conv import GATConv, SAGEConv, GCNConv, TransformerConv
from torch_geometric.nn.pool import global_mean_pool,SAGPooling
from torch.nn import Linear
from torch_geometric.utils import to_dense_adj
from grit.encoder.correlation_matrix import CorrelationMatrix, CorrelationMatrixBatched


class Param_encoder(torch.nn.Module):
    """Graph-convolution encoder feeding ``CorrelationMatrixBatched``.

    ``num_layers_encoder`` convolutions are applied to ``data_batch.dist``,
    the result is projected to ``3 * k`` channels and mean-pooled per graph,
    and the pooled tensor is handed to ``CorrelationMatrixBatched`` whose
    ``compute_correlation`` returns the (updated) batch.

    Args:
        conv_layer: Convolution class to instantiate (the caller in this
            file passes GATConv, SAGEConv, GCNConv or TransformerConv).
        max_dim: Input channel count of the first convolution.
        hidden_channels: Output channel count of every convolution.
        k: Forwarded to ``CorrelationMatrixBatched``; also the feature size
            of the optional LayerNorm / BatchNorm1d, and the linear
            projection has ``3 * k`` outputs.
        num_layers_encoder: Number of convolutions to build and apply.
        layer_norm: If True, apply LayerNorm(k) to ``qcorr``/``qcorr_val``.
        batch_norm: If True, apply BatchNorm1d(k) to ``qcorr``/``qcorr_val``.
    """

    def __init__(self, conv_layer: type, max_dim: int,
                 hidden_channels: int, k: int, num_layers_encoder: int,
                 layer_norm: bool, batch_norm: bool) -> None:
        super().__init__()
        self.num_layers = num_layers_encoder
        self.convs = torch.nn.ModuleList()
        # Not used by this class; kept so the attribute stays available.
        self.lins = torch.nn.ModuleDict()
        self.conv_layer = conv_layer
        self.k = k
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        if layer_norm:
            self.ln = torch.nn.LayerNorm(k)
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(k)

        # TransformerConv is the only layer here that is built with
        # concat=False; all other layers use their default constructor.
        conv_kwargs = {'concat': False} if conv_layer is TransformerConv else {}

        # Build exactly ``num_layers_encoder`` convolutions so that every
        # registered parameter is used in ``forward``.
        # NOTE: earlier versions always registered at least two
        # convolutions, so checkpoints saved with num_layers_encoder < 2
        # carry extra, never-used ``convs.*`` entries.
        in_channels = max_dim
        for _ in range(num_layers_encoder):
            self.convs.append(
                conv_layer(in_channels, hidden_channels, **conv_kwargs))
            in_channels = hidden_channels

        self.lin = Linear(hidden_channels, 3 * k)

    def forward(self, data_batch):
        """Encode ``data_batch`` and attach the correlation outputs.

        Reads ``dist``, ``edge_index``, ``batch`` and (for GATConv /
        TransformerConv) ``edge_attr`` from ``data_batch``.

        Returns:
            The object returned by
            ``CorrelationMatrixBatched.compute_correlation``, with ``qcorr``
            and ``qcorr_val`` normalised when the corresponding flags are set.
        """
        x = data_batch.dist.float()

        # Only these two layer types are called with edge attributes.
        # NOTE(review): the convolutions are built without ``edge_dim``, so
        # the layers may ignore ``edge_attr`` - confirm against PyG docs.
        uses_edge_attr = self.conv_layer in (GATConv, TransformerConv)
        for conv in self.convs:
            if uses_edge_attr:
                x = conv(x, data_batch.edge_index, data_batch.edge_attr)
            else:
                x = conv(x, data_batch.edge_index)
        x = self.lin(x.float())

        # One pooled row per graph in the batch.
        P = global_mean_pool(x, data_batch.batch)

        correlation = CorrelationMatrixBatched(P, self.k, x.device)
        data_batch = correlation.compute_correlation(data_batch)

        # qcorr / qcorr_val are presumably set by compute_correlation above
        # (not visible from this file) - TODO confirm.
        # The same norm module is shared by both tensors.
        if self.batch_norm:
            data_batch.qcorr = self.bn(data_batch.qcorr)
            data_batch.qcorr_val = self.bn(data_batch.qcorr_val)

        if self.layer_norm:
            data_batch.qcorr = self.ln(data_batch.qcorr)
            data_batch.qcorr_val = self.ln(data_batch.qcorr_val)

        return data_batch


class FeatureEncoder(torch.nn.Module):
    """
    Encoding node and edge features

    The encoders (and their optional batch norms) are selected through the
    global ``cfg`` and looked up in the GraphGym registries.

    Args:
        dim_in (int): Input feature dimension
    """
    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in
        dataset_cfg = cfg.dataset

        if dataset_cfg.node_encoder:
            # Node encoder class comes from the registry, keyed by name.
            node_encoder_cls = register.node_encoder_dict[
                dataset_cfg.node_encoder_name]
            self.node_encoder = node_encoder_cls(cfg.gnn.dim_inner)
            if dataset_cfg.node_encoder_bn:
                node_bn_cfg = new_layer_config(
                    cfg.gnn.dim_inner, -1, -1, has_act=False,
                    has_bias=False, cfg=cfg)
                self.node_encoder_bn = BatchNorm1dNode(node_bn_cfg)
            # Node features now have the inner dimension.
            self.dim_in = cfg.gnn.dim_inner

        if dataset_cfg.edge_encoder:
            # Edge dim is capped at 128 for PNA layer types; this writes
            # back into the global config.
            cfg.gnn.dim_edge = (min(128, cfg.gnn.dim_inner)
                                if 'PNA' in cfg.gt.layer_type
                                else cfg.gnn.dim_inner)
            # Edge encoder class comes from the registry, keyed by name.
            edge_encoder_cls = register.edge_encoder_dict[
                dataset_cfg.edge_encoder_name]
            self.edge_encoder = edge_encoder_cls(cfg.gnn.dim_edge)
            if dataset_cfg.edge_encoder_bn:
                edge_bn_cfg = new_layer_config(
                    cfg.gnn.dim_edge, -1, -1, has_act=False,
                    has_bias=False, cfg=cfg)
                self.edge_encoder_bn = BatchNorm1dNode(edge_bn_cfg)

    def forward(self, batch):
        """Pass ``batch`` through every child module in registration order."""
        out = batch
        for child in self.children():
            out = child(out)
        return out


@register_network('GritTransformerQuantum')
class GritTransformerQuantum(torch.nn.Module):
    '''
        The Quantum-positional encoding version of GritTransformer

        Submodules are registered in the order they must run, because
        ``forward`` simply calls every child module in registration order:
        feature encoder, parameter encoder, optional RRWP encoders,
        optional pre-message-passing layers, transformer layers, head.

        Args:
            dim_in: Input feature dimension, forwarded to FeatureEncoder.
            dim_out: Output dimension, forwarded to the prediction head.

        Raises:
            ValueError: If ``cfg.gnn.param_encoder_layer`` is not one of
                'GAT', 'SAGE', 'GCN', 'Transformer', or if transformer
                layers are requested but the configured layer type is not
                registered.
    '''

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        self.ablation = False

        # Map the configured name to the convolution class used by the
        # parameter encoder.
        conv_layers = {
            'GAT': GATConv,
            'SAGE': SAGEConv,
            'GCN': GCNConv,
            'Transformer': TransformerConv,
        }
        conv_layer = conv_layers.get(cfg.gnn.param_encoder_layer)
        if conv_layer is None:
            raise ValueError('The convolution layer chosen is not implemented')

        self.param_encoder = Param_encoder(
            conv_layer, cfg.dataset.max_dim, cfg.posenc_QCorr.param_hidden,
            cfg.posenc_QCorr.n_quantum, cfg.posenc_QCorr.param_num_layer,
            cfg.posenc_QCorr.layer_norm, cfg.posenc_QCorr.batch_norm)

        if cfg.posenc_RRWP.enable:
            rrwp_node_encoder = register.node_encoder_dict["rrwp_linear"]
            self.rrwp_abs_encoder = rrwp_node_encoder(
                cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner)
            rel_pe_dim = cfg.posenc_RRWP.ksteps
            rrwp_edge_encoder = register.edge_encoder_dict["rrwp_linear"]
            self.rrwp_rel_encoder = rrwp_edge_encoder(
                rel_pe_dim, cfg.gnn.dim_edge,
                pad_to_full_graph=cfg.gt.attn.full_attn,
                add_node_attr_as_self_loop=False,
                fill_value=0.,
            )

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        global_model_type = cfg.gt.get('layer_type', "GritTransformer")

        TransformerLayer = register.layer_dict.get(global_model_type)
        # Fail with a clear message instead of "'NoneType' object is not
        # callable" when the layer type is unknown. With zero layers the
        # class is never instantiated, so nothing is raised in that case.
        if TransformerLayer is None and cfg.gt.layers > 0:
            raise ValueError(
                f"Transformer layer type '{global_model_type}' is not "
                f"registered in register.layer_dict")

        layers = [
            TransformerLayer(
                in_dim=cfg.gt.dim_hidden,
                out_dim=cfg.gt.dim_hidden,
                num_heads=cfg.gt.n_heads,
                dropout=cfg.gt.dropout,
                act=cfg.gnn.act,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                residual=True,
                norm_e=cfg.gt.attn.norm_e,
                O_e=cfg.gt.attn.O_e,
                cfg=cfg.gt,
            )
            for _ in range(cfg.gt.layers)
        ]
        self.layers = torch.nn.Sequential(*layers)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        """Run ``batch`` through every child module in registration order."""
        for module in self.children():
            batch = module(batch)
        return batch







