from math import inf
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from ..._layers import GCN_layer, GIN_layer
from ..._utils import log_max, debug_print
from .._EPM_VAE import EPM_VAE, EPM_VAE_Loss
from .._GIN import GIN, GIN_lite
from .._GCN import GCN

# Root tag used by the classes in this module to build their `debug_ckpt` identifiers.
global_debug_ckpt = '_ComGCN'


class Modules_E(EPM_VAE):
    """EPM-VAE extended with an edge-partition step.

    The parent's encoder yields ``phi`` (plus ``k`` and ``lbd``).  ``phi`` is
    scaled by ``self.r_rtsq``, cut into ``num_split`` column segments, and each
    segment scores every edge of ``batch_adj``.  A softmax with temperature
    ``tau`` across the segments turns those scores into per-edge weights that
    sum to one, giving one sparse matrix per segment.

    Extra keyword arguments:
        tau: softmax temperature for the edge partition (default ``1.``).
        lite_mode: if true, score edges from their end nodes only instead of
            forming the dense node-by-node product (default ``False``).

    NOTE(review): ``self.ft_h2``, ``self.dropout`` and ``self.r_rtsq`` are read
    below but never set here; presumably ``EPM_VAE.__init__`` provides them.
    """

    def __init__(self, ft_in: int, ft_h1: int, ft_h2: int, dropout=0., bias=False, act=nn.Softplus(), *args, **kwargs):
        super(Modules_E, self).__init__(ft_in, ft_h1, ft_h2, dropout=dropout, bias=bias, act=act)
        self.tau = kwargs.get('tau', 1.)
        self.lite_mode = kwargs.get('lite_mode', False)

    def infer_module(self, batch_adj, batch_fts):
        """Run the parent class's encoder and return its result unchanged."""
        return super(Modules_E, self).encoder(batch_adj, batch_fts)

    def edge_weight(self, phir_row, phir_col):
        """Return the inner product of each row of `phir_row` with the matching row of `phir_col`.

        Both inputs must have the same number of rows; the result is flattened
        to one value per row.
        """
        assert phir_row.shape[0] == phir_col.shape[0], f"num of rows = {phir_row.shape[0]}, num of cols = {phir_col.shape[0]}"
        # (E, 1, d) @ (E, d, 1) -> (E, 1, 1): a batched dot product.
        phir_row, phir_col = phir_row.unsqueeze(1), phir_col.unsqueeze(2)
        return torch.bmm(phir_row, phir_col).flatten()

    def _softmax_partition(self, batch_adj, values_unnorm_):
        """Softmax per-edge scores across segments and wrap each segment as a sparse matrix.

        Inputs:
            batch_adj: [torch.Tensor] sparse adjacency whose indices and shape are reused
            values_unnorm_: [list] one 1-D score tensor per segment, aligned with
                the edges of batch_adj
        Outputs:
            a list of sparse matrices (one per segment) sharing batch_adj's indices.
        """
        edge_index = batch_adj._indices()
        # dim 0 enumerates the segments, so every edge's weights sum to one.
        values_mat = F.softmax(torch.stack(values_unnorm_, dim=0) / self.tau, dim=0)
        return [
            torch.sparse_coo_tensor(edge_index, values, batch_adj.shape)
            for values in torch.unbind(values_mat, dim=0)
        ]

    def edge_part_1(self, batch_adj, G_unnorm_):
        """
        Inputs:
            batch_adj: [torch.Tensor] binary sparse adjacency matrix, arranging graphs in the batch on the main diagonal
            G_unnorm_: [list] a list of dense matrices, carrying the weights for edge partition
        Outputs:
            a list of sparse adjacency matrices, adding up to batch_adj.
        """
        # Index with explicit (row, col) tensors: a single (2, nnz) ndarray index
        # is not reliably interpreted as a coordinate pair by PyTorch.
        row_ind, col_ind = batch_adj._indices().unbind(0)
        values_unnorm_ = [G_unnorm[row_ind, col_ind] for G_unnorm in G_unnorm_]
        return self._softmax_partition(batch_adj, values_unnorm_)

    def edge_part_2(self, batch_adj, phir_):
        """
        Inputs:
            batch_adj: [torch.Tensor] binary sparse adjacency matrix, arranging graphs in the batch on the main diagonal
            phir_: [tuple] segments of phi (one 2-D tensor per segment, rows indexed by node)
        Outputs:
            a list of sparse adjacency matrices, adding up to batch_adj.
        """
        row_ind, col_ind = batch_adj._indices().unbind(0)
        # Score each edge from the segment rows of its two end nodes only,
        # so no dense node-by-node matrix is ever built.
        values_unnorm_ = [
            self.edge_weight(phir_seg[row_ind], phir_seg[col_ind]) for phir_seg in phir_
        ]
        return self._softmax_partition(batch_adj, values_unnorm_)

    def epart_module(self, batch_adj, phi, num_split: int = 1):
        """
        Split phi into `num_split` segments and decode each segment into a
        community-specific graph.

        Returns a list of sparse matrices, one per segment.
        """
        # Section sizes: [len+1, ..., len+1, len, ..., len], summing to ft_h2.
        base_len, remainder = divmod(self.ft_h2, num_split)
        sec_len = [base_len + 1] * remainder + [base_len] * (num_split - remainder)

        phi = F.dropout(phi, self.dropout, training=self.training)
        phir = phi.mul(self.r_rtsq.unsqueeze(0))
        phir_ = phir.split(sec_len, dim=-1)

        if self.lite_mode:
            return self.edge_part_2(batch_adj, phir_)

        # Each dense product is indexed immediately so that only one
        # node-by-node matrix is needed at a time.
        row_ind, col_ind = batch_adj._indices().unbind(0)
        values_unnorm_ = [torch.mm(x, x.t())[row_ind, col_ind] for x in phir_]
        return self._softmax_partition(batch_adj, values_unnorm_)

    def forward(self, batch_adj, batch_fts, num_split: int = 1, pretrain=False, lite_mode=False, *args, **kwargs):
        """Encode the batch, decode predictions and (unless pretraining) partition the edges.

        Inputs:
            batch_adj, batch_fts: passed to the parent's encoder
            num_split: number of segments used by the edge partition
            pretrain: if true, the edge partition is skipped and None is returned in its place
            lite_mode: if true, decode each graph separately; requires the keyword
                argument `num_nodes_` (per-graph node counts, must support `.tolist()`)
        Outputs:
            (preds_obj, G_norm_, phi, k, lbd), where preds_obj is the parent
            decoder's output (a list with one entry per graph in lite mode).

        NOTE(review): this `lite_mode` argument only selects the decoding path;
        the edge partition is governed by `self.lite_mode` — confirm callers
        keep the two in sync.
        """
        phi, k, lbd = self.infer_module(batch_adj, batch_fts)
        if not lite_mode:
            preds_obj = super(Modules_E, self).decoder(phi)
        else:
            num_nodes_ = kwargs['num_nodes_']
            phi_ = phi.split(num_nodes_.tolist(), dim=0)

            # Decode graph by graph and return a list of predictions.  (Padding
            # every graph to the largest node count and decoding in one batched
            # call was the alternative sketched previously.)
            # Explicit two-argument super(): the zero-argument form is not
            # available inside a comprehension on older Python versions.
            preds_obj = [super(Modules_E, self).decoder(phi_mat) for phi_mat in phi_]
            # restore phi
            phi = torch.cat(phi_, dim=0)

        G_norm_ = self.epart_module(batch_adj, phi, num_split) if not pretrain else None

        return preds_obj, G_norm_, phi, k, lbd



class Modules_M(nn.Module):
    """Module M for graph classification.

    When `num_layers_Bcat >= 2`, one `GIN_lite` head per community graph runs
    first; the heads' hidden features are concatenated and passed to a
    `GIN_lite` tail together with the full adjacency.  When
    `num_layers_Bcat == 1` only the tail is built and it consumes the input
    features directly.

    Extra keyword arguments:
        final_dropout: forwarded to every `GIN_lite` (default ``0.5``).
        aggregation: forwarded to the heads as `neighbor_pooling_type`
            (default ``'sum'``).
    """

    def __init__(self, ft_in, ft_hid, ft_out, num_layers_Bcat, num_layers_Acat, num_split: int = 1, *args, **kwargs):
        super(Modules_M, self).__init__()
        self.num_layers_Bcat = num_layers_Bcat  # num of layers running before feature concatenation
        self.num_layers_Acat = num_layers_Acat + 1  # num of layers running after feature concatenation (include the catted feature)
        self.num_layers_mlps = 2
        self.ft_in = ft_in
        self.ft_hid = ft_hid
        self.ft_out = ft_out
        self.debug_ckpt = global_debug_ckpt + '.Modules_M'
        # Hidden width of each head; the widths add up to ft_hid.
        self.hid_list = self.feature_split(num_split)

        self.final_dropout = kwargs.get('final_dropout', 0.5)
        self.aggregation = kwargs.get('aggregation', 'sum')

        # NOTE(review): forward() runs the heads only when num_layers_Bcat >= 2,
        # so a value below 1 builds heads here that are never used and a tail
        # that is presumably sized for ft_hid-wide inputs while receiving the
        # raw features — confirm such values are never passed.
        if num_layers_Bcat == 1:
            self.model_tail = GIN_lite(self.num_layers_Acat, self.num_layers_mlps, ft_in, ft_hid, ft_out,
                                       final_dropout=self.final_dropout)
        else:
            self.model_head = nn.ModuleList(
                [
                    GIN_lite(
                        num_layers_Bcat,
                        self.num_layers_mlps,
                        ft_in,
                        nh,
                        ft_out,
                        final_dropout=self.final_dropout,
                        neighbor_pooling_type=self.aggregation,
                    )
                    for nh in self.hid_list
                ]
            )
            self.model_tail = GIN_lite(self.num_layers_Acat, self.num_layers_mlps, ft_hid, ft_hid, ft_out,
                                       final_dropout=self.final_dropout, add_input_score=False)

    def feature_split(self, num_split: int):
        """Split `self.ft_hid` into `num_split` integer section lengths.

        The lengths differ by at most one, larger ones first:
        [len+1, ..., len+1, len, ..., len].
        """
        base_len, remainder = divmod(self.ft_hid, num_split)
        return [base_len + 1] * remainder + [base_len] * (num_split - remainder)

    def forward(self, G_norm_, batch_adj, batch_fts, num_nodes_):
        """Score a batch of graphs.

        Inputs:
            G_norm_: iterable of community graphs, paired one-to-one with the
                heads (only read when num_layers_Bcat >= 2)
            batch_adj: adjacency passed to the tail
            batch_fts: input node features
            num_nodes_: forwarded unchanged to every `GIN_lite`
        Outputs:
            (score, features): the tail's score plus the summed head scores,
            and the features fed to the tail (concatenated head outputs, or
            `batch_fts` itself when no head ran).
        """
        score_over_layers = 0.

        if self.num_layers_Bcat >= 2:
            h_com_ = []
            for G_norm, gin in zip(G_norm_, self.model_head):
                score, h_com = gin(G_norm, batch_fts, num_nodes_)
                score_over_layers += score
                h_com_.append(h_com)
            batch_fts = torch.cat(h_com_, dim=-1)

        score, _ = self.model_tail(batch_adj, batch_fts, num_nodes_)

        return score_over_layers + score, batch_fts


class Modules_M_node(nn.Module):
    """Module M for node classification, built from `GCN` heads and a `GCN` tail.

    Three layouts are possible:
        * num_layers_Bcat == 0: tail only, applied to the input features.
        * num_layers_Acat == 0: heads only, one per community graph; their
          outputs are summed.
        * otherwise: heads whose hidden widths add up to ft_hid, followed by a
          tail on the concatenated head outputs.

    Extra keyword arguments:
        dropout: dropout rate (default ``.5``).
        act: activation function (default ``F.relu``).
        bias: forwarded to every `GCN` (default ``True``).
    """

    def __init__(self, ft_in, ft_hid, ft_out, num_layers_Bcat, num_layers_Acat, num_split: int = 1, *args, **kwargs):
        super(Modules_M_node, self).__init__()
        assert num_layers_Bcat + num_layers_Acat > 0, 'num_layers error: Bcat = {}, Acat = {}'.format(num_layers_Bcat, num_layers_Acat)
        self.num_layers_Bcat = num_layers_Bcat  # number of layers in model head (exclude input)
        self.num_layers_Acat = num_layers_Acat  # number of layers in model tail (exclude input)
        self.ft_in = ft_in
        self.ft_hid = ft_hid
        self.ft_out = ft_out
        self.debug_ckpt = global_debug_ckpt + '.Modules_M_node'
        # Hidden width of each head in the head+tail layout; the widths add up to ft_hid.
        self.hid_list = self.feature_split(num_split)

        self.dropout = kwargs.get('dropout', .5)
        self.act = kwargs.get('act', F.relu)
        self.bias = kwargs.get('bias', True)

        if self.num_layers_Bcat == 0:
            self.model_tail = GCN(self.num_layers_Acat, self.ft_in, self.ft_hid, self.ft_out,
                                  dropout=self.dropout, act=self.act, bias=self.bias)
        elif self.num_layers_Acat == 0:
            # Build one independent GCN per split.  Repeating a single instance
            # with `[gcn] * num_split` would make every head share its weights.
            self.model_head = nn.ModuleList(
                [
                    GCN(self.num_layers_Bcat, self.ft_in, self.ft_hid, self.ft_out,
                        dropout=self.dropout, act=self.act, bias=self.bias)
                    for _ in range(num_split)
                ]
            )
        else:
            self.model_head = nn.ModuleList(
                [
                    GCN(self.num_layers_Bcat, self.ft_in, nh, nh,
                        dropout=self.dropout, act=self.act, bias=self.bias)
                    for nh in self.hid_list
                ]
            )
            self.model_tail = GCN(self.num_layers_Acat, self.ft_hid, self.ft_hid, self.ft_out,
                                  dropout=self.dropout, act=self.act, bias=self.bias)

    def feature_split(self, num_split: int):
        """Split `self.ft_hid` into `num_split` integer section lengths.

        The lengths differ by at most one, larger ones first:
        [len+1, ..., len+1, len, ..., len].
        """
        base_len, remainder = divmod(self.ft_hid, num_split)
        return [base_len + 1] * remainder + [base_len] * (num_split - remainder)

    def forward(self, G_norm_, adj, fts):
        """codgnn for node classification, feedforward. (based on vanilla gcn)

        Inputs:
            G_norm_: iterable of community graphs, paired one-to-one with the
                heads (only read when num_layers_Bcat > 0)
            adj: adjacency passed to the tail
            fts: input node features
        Outputs:
            the tail's output, or the sum of the head outputs when there is no tail.
        """
        h = fts
        if self.num_layers_Bcat > 0:
            h_com_ = [gcn_head(G_norm, h) for G_norm, gcn_head in zip(G_norm_, self.model_head)]

            if self.num_layers_Acat == 0:
                # No model tail: the output is the sum of the heads' outputs.
                return torch.stack(h_com_).sum(0)

            h = torch.cat(h_com_, dim=1)
            h = F.dropout(self.act(h), self.dropout, self.training)

        return self.model_tail(adj, h)












