import numpy as np
import scipy.sparse as sp
import tensorflow as tf
from spektral.data import Loader
from spektral.data.utils import to_disjoint
from spektral.layers import ops
from spektral.utils import sp_matrix_to_sp_tensor

class MFreeLoader(Loader):
    """
    Disjoint-mode loader that takes as input a MFreeDataset.

    Each graph in the dataset must expose ``x``, ``a``, ``a_1``, ``y`` and
    ``s`` attributes. A batch of graphs is merged into one disjoint graph:
    ``a``, ``a_1`` and ``s`` are stacked block-diagonally and converted to
    float32 ``tf.SparseTensor`` objects. ``batch_idx`` assigns each row of
    the stacked ``a_1`` matrix to the graph it came from.
    """

    def __init__(self, dataset, batch_size=1, epochs=None, shuffle=False):
        # Sizes used by tf_signature().
        self.F = dataset.n_node_features
        self.n_out = dataset.n_labels
        # dtypes of the first graph's arrays (not used elsewhere in this
        # class; collate() casts a, a_1, s and y to float32 regardless).
        self.a_dtype = tf.as_dtype(dataset[0].a.dtype)
        self.a_1_dtype = tf.as_dtype(dataset[0].a_1.dtype)
        self.s_dtype = tf.as_dtype(dataset[0].s.dtype)
        self.y_dtype = tf.as_dtype(dataset[0].y.dtype)
        super().__init__(dataset, batch_size=batch_size, epochs=epochs, shuffle=shuffle)

    def _pack(self, batch):
        """
        Transpose a list of graphs into per-attribute lists.

        Returns ``[xs, as, a_1s, ys, ss]``, where each element is a list
        holding that attribute for every graph in ``batch``, in batch order.
        """
        return [
            list(elem) for elem in zip(*[[g.x, g.a, g.a_1, g.y, g.s] for g in batch])
        ]

    def collate(self, batch):
        """
        Merge a list of graphs into a single disjoint batch.

        Returns ``((x, a, a_1, batch_idx, s), y)``:
        - ``x``: node features concatenated by ``to_disjoint``.
        - ``a``, ``a_1``, ``s``: block-diagonal stacks as float32
          ``tf.SparseTensor`` objects.
        - ``batch_idx``: int64 array with one graph index per row of the
          stacked ``a_1`` matrix.
        - ``y``: float32 array of the stacked labels.
        """
        x_list, a_list, a_1_list, y_list, s_list = self._pack(batch)

        # to_disjoint also returns a node->graph index for the ``a`` level.
        # It is discarded; batch_idx below is built for the ``a_1`` level.
        x, a, _ = to_disjoint(x_list, a_list)
        a_1 = sp.block_diag(a_1_list)
        y = np.array(y_list).astype("f4")
        s = sp.block_diag(s_list)

        # One graph index per row of each graph's a_1 matrix. Force int64 so
        # the array matches the tf.int64 spec in tf_signature() on every
        # platform (the default integer is int32 on Windows with NumPy < 2).
        n_nodes = np.array([a_1_graph.shape[0] for a_1_graph in a_1_list])
        batch_idx = np.repeat(np.arange(len(n_nodes), dtype=np.int64), n_nodes)

        a = sp_matrix_to_sp_tensor(a.astype("f4"))
        a_1 = sp_matrix_to_sp_tensor(a_1.astype("f4"))
        s = sp_matrix_to_sp_tensor(s.astype("f4"))
        return (x, a, a_1, batch_idx, s), y

    def tf(self):
        """Overridden as a no-op: returns None instead of a ``tf.data`` pipeline."""
        return None

    def tf_signature(self):
        """
        Return the TensorSpec structure of the batches produced by ``collate``.

        NOTE(review): ``collate`` does not cast ``x``, but the spec below uses
        TensorSpec's default float32 dtype. This assumes the dataset's node
        features are already float32; confirm.
        """
        return (
            (
                tf.TensorSpec((None, self.F)),
                tf.SparseTensorSpec((None, None), dtype=tf.float32),
                tf.SparseTensorSpec((None, None), dtype=tf.float32),
                tf.TensorSpec((None,), dtype=tf.int64),
                tf.SparseTensorSpec((None, None), dtype=tf.float32),
            ),
            tf.TensorSpec((None, self.n_out), dtype=tf.float32),
        )


def downsampling(inputs):
    """
    Pool node features with a selection matrix.

    ``inputs`` is a pair ``(X, S)``. The result is ``S^T @ X``, computed by
    ``ops.modal_dot`` with ``transpose_a=True``.
    NOTE(review): presumably S is (n_nodes, n_pooled) and X is
    (n_nodes, F), giving (n_pooled, F); confirm against callers.
    """
    node_features, selection = inputs
    pooled = ops.modal_dot(selection, node_features, transpose_a=True)
    return pooled
