from peagang.data.dense.PEAWGANDenseData import PEAWGANDenseData
import torch as pt

from peagang.data.dense.utils.features import our_dom_symeig, approx_dom_symeig
from peagang.utils.utils import ensure_tensor


class PEAWGANDenseStructureData(PEAWGANDenseData):
    """
    Has access to the same datasets as PEAWGANDenseData, but replaces the node features with strictly graph-structural features.

    The node features are, in order along the last axis:
      * the eigenvector features returned by the configured eigensolver for ``k_eigenvals``
        eigenpairs (of the graph Laplacian ``D - A`` if ``use_laplacian``, else of ``-A``),
      * the node degree,
      * the graph size (number of nodes with at least one non-zero adjacency entry), written
        onto every node whose 1-based position is within that size and 0 elsewhere.

    Uses ``our_dom_symeig`` by default, with the option (``large_N_approx=True``) to switch to
    ``approx_dom_symeig`` for large graphs on which the exact decomposition would become infeasible.
    """
    def __init__(self, data_dir=None, filename=None, k_eigenvals=4, dataset="CommunitySmall", print_statistics=True, remove_zero_padding=None,
                 inner_kwargs=None, use_laplacian=True, large_N_approx=False, zero_pad=True, cut_train_size=False):
        """
        :param k_eigenvals: number of eigenpairs requested from the eigensolver.
        :param use_laplacian: if True, decompose the graph Laplacian ``D - A``; otherwise ``-A``.
        :param large_N_approx: use ``approx_dom_symeig`` instead of ``our_dom_symeig``.
        All remaining arguments are forwarded unchanged to ``PEAWGANDenseData.__init__``.
        """
        self.k_eigenvals = k_eigenvals
        self.use_laplacian = use_laplacian
        # get k smallest eig impl
        self.dominant_symeig = approx_dom_symeig if large_N_approx else our_dom_symeig
        # NOTE(review): these attributes are set before super().__init__(), presumably because the
        # parent constructor may already fetch items (e.g. for statistics) — keep this order.
        super().__init__(data_dir=data_dir, filename=filename, dataset=dataset, print_statistics=print_statistics,
                         remove_zero_padding=remove_zero_padding, inner_kwargs=inner_kwargs, zero_pad=zero_pad,
                         cut_train_size=cut_train_size)

    def __getitem__(self, idx):
        """
        Return ``(x, A)`` for sample ``idx`` with structural node features.

        The parent's node features are discarded; only its adjacency matrix is used.
        When ``zero_pad`` is enabled the outputs are checked to span ``max_N`` nodes.
        """
        _, A = super().__getitem__(idx)
        x, A = self.get_structural_node_features(A)
        if self.zero_pad:
            # internal consistency checks on the padded output
            assert A.shape[-1] == self.max_N
            assert A.shape[-2] == self.max_N
            assert x.shape[-2] == self.max_N
        return x, A

    def get_structural_node_features(self, A):
        """
        Compute structural node features for a single graph or a batch of graphs.

        Extracted so we can use it in the generator as well.

        :param A: adjacency matrix, shape [B,N,N] or [N,N]. Either a torch tensor or anything
            ``torch.tensor`` accepts (e.g. a numpy array).
        :return: ``(x, A)`` with ``x`` of shape [B,N,F] or [N,F]. Tensor input gives tensors;
            other input gives numpy arrays. Unbatched input is zero-padded along the node axes to
            ``self.max_N`` when ``self.zero_pad`` is set; batched input is returned unpadded.
        """
        tensor = pt.is_tensor(A)
        if not tensor:
            A = pt.tensor(A).float()

        batch = A.dim() == 3
        if not batch:
            # simplify dim thinking by always operating in batch mode
            A = A.unsqueeze(0)

        # undirected graphs: in- and out-degree are the same
        degrees = A.sum(-1, keepdim=False)  # [B,N]

        # Graph size: number of rows with at least one non-zero entry.
        num = A.max(dim=-1).values.sum(dim=-1)  # [B]
        # Reshape to [B,1,1] so each graph is compared against ITS OWN size. Comparing a
        # [B,N,1] tensor with a flat [B] tensor would broadcast to [B,N,B] and mix graphs.
        num_per_graph = num.reshape(-1, 1, 1)
        # 1-based node positions [1,N,1]; a node is counted if its position <= the graph size.
        # NOTE(review): assumes padding nodes come last in A — confirm against the data loader.
        positions = pt.arange(1, A.shape[1] + 1, device=A.device, dtype=num.dtype).reshape(1, -1, 1)
        node_mask = (positions <= num_per_graph).float()  # [B,N,1]
        N_app = node_mask * num_per_graph  # [B,N,1]

        if self.use_laplacian:
            # graph Laplacian L = D - A
            D = pt.diag_embed(ensure_tensor(degrees))
            L = D - A
            # NOTE(review): end_offset semantics live in dominant_symeig; presumably it skips the
            # trivial Laplacian eigenpair — confirm.
            _, k_eigen_feat = self.dominant_symeig(L, self.k_eigenvals, end_offset=1)
        else:
            # The solver presumably returns the k *smallest* eigenpairs ("get k smallest eig
            # impl"), so negating A targets A's *largest* eigenvalues, see
            # https://github.com/buwantaiji/DominantSparseEigenAD/blob/master/examples/TFIM_vumps/symmetric.py
            negA = -A
            _, k_eigen_feat = self.dominant_symeig(negA, self.k_eigenvals)

        # ensure_tensor: anu_graphs features may be arrays rather than torch tensors
        x = pt.cat([ensure_tensor(k_eigen_feat), ensure_tensor(degrees).unsqueeze(-1)], dim=-1)

        # append graph size as the last feature
        x = pt.cat([x, N_app], dim=-1)

        N_pad = self.max_N - x.shape[-2]
        if not batch and self.zero_pad and N_pad > 0:
            # zero-pad the node axis of x and both node axes of A up to max_N
            x = pt.nn.functional.pad(x, (0, 0, 0, N_pad), "constant", 0)
            A = pt.nn.functional.pad(A, (0, N_pad, 0, N_pad), "constant", 0)
            assert A.shape[-1] == self.max_N
        if not batch:
            # undo the unsqueeze(0) applied above
            A = A[0]
            x = x[0]
        if not tensor:
            x = x.numpy()
            A = A.numpy()
        return x, A


if __name__ == "__main__":
    # Smoke test: build a small egonet dataset and materialise every sample.
    ds = PEAWGANDenseStructureData(dataset="egonet", inner_kwargs={
        "num_graphs": 20
    })
    # Single pass over the dataset so each sample's eigen-features are computed only once.
    samples = list(ds)
    As = [A for _, A in samples]
    xs = [x for x, _ in samples]
