import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import new_layer_config, MLP
from torch_geometric.graphgym.register import register_head, pooling_dict

from torch_geometric.utils import to_dense_batch

from graphgym.utils import eigen_mask, virtual_node_mask

def _pad_and_stack(x1: torch.Tensor, x2: torch.Tensor, pad1: int, pad2: int):
    padded_x1 = nn.functional.pad(x1, (0, pad2))
    padded_x2 = nn.functional.pad(x2, (pad1, 0))
    return torch.vstack([padded_x1, padded_x2])



def concat_forward(linear, h, batch, max_nodes):
    """Apply ``linear`` to each graph's concatenated node features.

    The node features are packed into a dense, padded per-graph layout with
    ``to_dense_batch``, flattened to one row per graph, passed through
    ``linear``, viewed back as per-node outputs and finally filtered with
    the dense mask so that padded slots are dropped.

    Args:
        linear: Callable applied to the ``[B, N * F]`` flattened tensor. Its
            output must be viewable as ``[B, N, C]``.
        h: Node features, ``[total_nodes, F]``.
        batch: Graph index of every node, ``[total_nodes]``.
        max_nodes: Forwarded to ``to_dense_batch`` as ``max_num_nodes``; the
            number of node slots ``N`` each graph is padded to.

    Returns:
        Per-node outputs for the slots marked valid by the dense mask,
        ``[num_valid_nodes, C]``.
    """
    # Dense layout: features [B, N, F] plus a [B, N] mask of real-node slots.
    dense_feats, valid_slots = to_dense_batch(h, batch,
                                              max_num_nodes=max_nodes)
    num_graphs, slots_per_graph, feat_dim = dense_feats.shape

    # One long feature vector per graph: [B, N * F].
    per_graph_input = dense_feats.view(num_graphs, slots_per_graph * feat_dim)
    per_graph_output = linear(per_graph_input)

    # Back to one row per node slot, then keep only the real nodes.
    per_slot_output = per_graph_output.view(num_graphs, slots_per_graph, -1)
    return per_slot_output[valid_slots]





def _apply_index(batch, virtual_node: bool, pad_node: int, pad_graph: int):
    """Build the stacked prediction and target tensors for a batch.

    Node-level rows are stacked on top of graph-level rows via
    ``_pad_and_stack``, so node and graph values occupy disjoint columns.

    Args:
        batch: Object exposing ``node_feature``, ``y``, ``graph_feature``,
            ``y_graph`` and ``batch`` (graph index per node).
        virtual_node: When true, the last node of every graph is dropped
            from the node-level predictions and targets (presumably the
            virtual node; this relies on it being stored last).
        pad_node: Number of node-level columns; graph rows are left-padded
            by this amount.
        pad_graph: Number of graph-level columns; node rows are right-padded
            by this amount.

    Returns:
        Tuple ``(pred, true)`` of stacked predictions and stacked targets.
    """
    graph_pred = batch.graph_feature
    graph_true = batch.y_graph
    node_pred = batch.node_feature
    node_true = batch.y

    if virtual_node:
        # Collect, graph by graph, every node index except the last one.
        num_graphs = batch.batch.max().item() + 1
        kept_per_graph = []
        for graph_id in range(num_graphs):
            members = (batch.batch == graph_id).nonzero(as_tuple=True)[0]
            kept_per_graph.append(members[:-1])
        idx = torch.concat(kept_per_graph)
        node_pred = node_pred[idx]
        node_true = node_true[idx]

    # Node rows on top, graph rows below, zero-padded into disjoint columns.
    pred = _pad_and_stack(node_pred, graph_pred, pad_node, pad_graph)
    true = _pad_and_stack(node_true, graph_true, pad_node, pad_graph)
    return pred, true


@register_head('inductive_hybrid')
class GNNInductiveHybridHead(nn.Module):
    """
    Prediction head producing node-level and graph-level outputs together.

    Node targets are predicted by one MLP applied to the node embeddings;
    graph targets are predicted by a second MLP applied to the pooled node
    embeddings. Output widths come from ``cfg.share.num_node_targets`` and
    ``cfg.share.num_graph_targets``.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. Not used. Use share.num_node_targets
            and share.num_graph_targets instead.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.node_target_dim = cfg.share.num_node_targets
        self.graph_target_dim = cfg.share.num_graph_targets
        self.virtual_node = cfg.virtual_node
        depth = cfg.gnn.layers_post_mp

        # Node-level head.
        node_head_config = new_layer_config(
            dim_in, self.node_target_dim, depth,
            has_act=False, has_bias=True, cfg=cfg)
        self.node_post_mp = MLP(node_head_config)

        # Graph-level head: pooling followed by an MLP.
        self.graph_pooling = pooling_dict[cfg.model.graph_pooling]
        graph_head_config = new_layer_config(
            dim_in, self.graph_target_dim, depth,
            has_act=False, has_bias=True, cfg=cfg)
        self.graph_post_mp = MLP(graph_head_config)

    def forward(self, batch):
        """Store node/graph predictions on ``batch`` and return the stacked
        ``(pred, true)`` pair built by ``_apply_index``."""
        batch.node_feature = self.node_post_mp(batch.x)

        pooled = self.graph_pooling(batch.x, batch.batch)
        batch.graph_feature = self.graph_post_mp(pooled)

        return _apply_index(
            batch,
            self.virtual_node,
            self.node_target_dim,
            self.graph_target_dim,
        )


# MODIFIED FOR EIGENVECTOR-LEARNING
@register_head('inductive_hybrid_multi')
class GNNInductiveHybridMultiHead(nn.Module):
    """
    GNN prediction head for inductive node and graph prediction tasks using
    individual MLP for each task.

    One single-output MLP is built per node target; the graph targets share
    one MLP applied to the pooled node embeddings. When
    ``cfg.posenc_LapPE.separated`` is set, an additional eigenvector head is
    built according to ``cfg.posenc_LapPE.MLP_style``:

    * ``"per_node_per_dim"``: one single-output MLP per eigenvector
      dimension.
    * ``"per_node"``: a single MLP emitting all eigenvector dimensions.
    * ``"concat"``: a single MLP over the concatenated node features of each
      graph (padded to ``cfg.posenc_LapPE.concat_max_nodes`` nodes),
      optionally preceded by a per-node "pre-concat" MLP.

    Args:
        dim_in (int): Input dimension
        dim_out (int): Output dimension. Not used. Use share.num_node_targets
            and share.num_graph_targets instead.

    Raises:
        ValueError: If ``cfg.posenc_LapPE.separated`` is set and
            ``cfg.posenc_LapPE.MLP_style`` is not one of the styles above.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.node_target_dim = cfg.share.num_node_targets
        self.graph_target_dim = cfg.share.num_graph_targets
        self.virtual_node = cfg.virtual_node
        num_layers = cfg.gnn.layers_post_mp

        layer_config = new_layer_config(dim_in, 1, num_layers,
                                        has_act=False, has_bias=True, cfg=cfg)
        if cfg.gnn.multi_head_dim_inner is not None:
            layer_config.dim_inner = cfg.gnn.multi_head_dim_inner
        # An individual MLP PER target dimension, not per PSE.
        self.node_post_mps = nn.ModuleList(
            [MLP(layer_config) for _ in range(self.node_target_dim)])

        self.graph_pooling = pooling_dict[cfg.model.graph_pooling]
        self.graph_post_mp = MLP(
            new_layer_config(dim_in, self.graph_target_dim, num_layers,
                             has_act=False, has_bias=True, cfg=cfg))

        if cfg.posenc_LapPE.separated:
            self._init_eigvec_head(dim_in, num_layers)

    @staticmethod
    def _eigvec_layer_config(dim_in, dim_out, num_layers):
        """Build an eigenvector-head layer config, applying the optional
        ``cfg.gnn.eigvec_head_dim_inner`` override."""
        eigvec_layer_config = new_layer_config(
            dim_in, dim_out, num_layers,
            has_act=False, has_bias=True, cfg=cfg)
        if cfg.gnn.eigvec_head_dim_inner is not None:
            eigvec_layer_config.dim_inner = cfg.gnn.eigvec_head_dim_inner
        return eigvec_layer_config

    def _init_eigvec_head(self, dim_in, num_layers):
        """Create the eigenvector prediction modules for the configured
        ``cfg.posenc_LapPE.MLP_style``."""
        self.eigvec_dim_out = cfg.posenc_LapPE.eigen.max_freqs

        # The eigenvector head may use its own depth; otherwise it falls
        # back to the depth of the regular post-MP heads.
        if cfg.gnn.layers_post_mp_eigvec is not None:
            eigvec_num_layers = cfg.gnn.layers_post_mp_eigvec
        else:
            eigvec_num_layers = num_layers

        mlp_style = cfg.posenc_LapPE.MLP_style
        if mlp_style == "per_node_per_dim":
            print("Using per-node per-dim MLP for eigvec prediction")
            eigvec_layer_config = self._eigvec_layer_config(
                dim_in, 1, eigvec_num_layers)
            self.eigvec_node_post_mps = nn.ModuleList(
                [MLP(eigvec_layer_config)
                 for _ in range(self.eigvec_dim_out)])

        elif mlp_style == "per_node":
            print("Using per-node MLP for eigvec prediction")
            eigvec_layer_config = self._eigvec_layer_config(
                dim_in, self.eigvec_dim_out, eigvec_num_layers)
            # Kept as a one-element ModuleList so forward() can treat both
            # per-node styles uniformly.
            self.eigvec_node_post_mps = nn.ModuleList(
                [MLP(eigvec_layer_config)])

        elif mlp_style == "concat":
            self.max_nodes = cfg.posenc_LapPE.concat_max_nodes
            self.eigvec_pre_concat_mp = None
            print("Using concat MLP for eigvec prediction. Max nodes to pad to:", self.max_nodes)

            pre_concat_dim_out = cfg.gnn.pre_concat_head_dim_out
            if pre_concat_dim_out is not None:
                eigvec_dim_in = pre_concat_dim_out
            else:
                eigvec_dim_in = dim_in

            # The concat MLP sees all (padded) nodes of a graph at once, so
            # both its input and output widths scale with max_nodes.
            eigvec_layer_config = self._eigvec_layer_config(
                eigvec_dim_in * self.max_nodes,
                self.eigvec_dim_out * self.max_nodes,
                eigvec_num_layers)

            if pre_concat_dim_out is not None:
                pre_concat_dim_inner = cfg.gnn.pre_concat_head_dim_inner
                # A hidden width only makes sense with a two-layer MLP.
                pre_concat_num_layers = (
                    2 if pre_concat_dim_inner is not None else 1)
                pre_concat_config = new_layer_config(
                    dim_in, pre_concat_dim_out, pre_concat_num_layers,
                    has_act=False, has_bias=True, cfg=cfg)
                if pre_concat_dim_inner is not None:
                    pre_concat_config.dim_inner = pre_concat_dim_inner

                pre_concat_mlp = MLP(pre_concat_config)
                if cfg.gnn.pre_concat_head_has_act:
                    self.eigvec_pre_concat_mp = nn.Sequential(
                        pre_concat_mlp, nn.ReLU())
                else:
                    self.eigvec_pre_concat_mp = pre_concat_mlp

            self.eigvec_node_post_mps = MLP(eigvec_layer_config)

        else:
            raise ValueError("invalid MLP style")

    def forward(self, batch):
        """Compute node, graph and (optionally) eigenvector predictions.

        Side effects on ``batch``: ``node_feature`` and ``graph_feature``
        are always set. With ``cfg.posenc_LapPE.separated``, ``EigVals`` is
        replaced by its masked subset and ``eigvec_feature`` is set; with
        ``cfg.posenc_LapPE.use_base_loss_for_eigval`` additionally
        ``EigVecs`` is zeroed outside the mask and both ``node_feature`` and
        ``y`` are extended with the eigenvector columns.

        Returns:
            ``(pred, true)`` when ``cfg.posenc_LapPE.separated`` is false,
            otherwise ``(pred, true, eigvec_pred, eigval_true)``.
        """
        batch.node_feature = torch.hstack([m(batch.x)
                                           for m in self.node_post_mps])

        graph_emb = self.graph_pooling(batch.x, batch.batch)
        batch.graph_feature = self.graph_post_mp(graph_emb)

        if not cfg.posenc_LapPE.separated:
            return _apply_index(batch, self.virtual_node,
                                self.node_target_dim, self.graph_target_dim)

        # Node mask from graphgym.utils.eigen_mask; presumably selects the
        # nodes that take part in eigenvector prediction -- confirm there.
        keep = eigen_mask(cfg, batch.batch)

        x = batch.x[keep]
        batch_idx = batch.batch[keep]
        batch.EigVals = batch.EigVals[keep]
        if cfg.posenc_LapPE.MLP_style == "concat":
            if self.eigvec_pre_concat_mp is not None:
                x = self.eigvec_pre_concat_mp(x)
            batch.eigvec_feature = concat_forward(
                self.eigvec_node_post_mps, x, batch_idx, self.max_nodes)
        else:
            batch.eigvec_feature = torch.hstack(
                [m(x) for m in self.eigvec_node_post_mps])

        node_pad = self.node_target_dim
        if cfg.posenc_LapPE.use_base_loss_for_eigval:
            # Scatter the masked predictions back to one row per node
            # (zeros for masked-out nodes) so they can be appended to the
            # regular node predictions; targets are zeroed to match.
            full_eigvec_feature = torch.zeros(
                keep.shape[0], batch.eigvec_feature.shape[1],
                device=batch.x.device)
            full_eigvec_feature[keep] = batch.eigvec_feature
            batch.EigVecs[~keep] = 0
            batch.node_feature = torch.hstack(
                [batch.node_feature, full_eigvec_feature])
            batch.y = torch.hstack([batch.y, batch.EigVecs])
            # The node block is now wider by the eigenvector columns.
            node_pad += self.eigvec_dim_out

        pred, true = _apply_index(batch, self.virtual_node, node_pad,
                                  self.graph_target_dim)

        # Second mask from graphgym.utils.virtual_node_mask, applied to the
        # already-masked nodes; presumably drops virtual nodes -- confirm.
        keep2 = virtual_node_mask(cfg, batch_idx)
        eigvec_pred = batch.eigvec_feature[keep2]
        eigval_true = batch.EigVals[keep2]

        if cfg.posenc_LapPE.use_random_outputs:
            # Replace the predictions with uniform noise of the same shape.
            eigvec_pred = torch.rand_like(eigvec_pred)

        return pred, true, eigvec_pred, eigval_true
        


        
