import torch, sys, numpy as np, pytorch_lightning as pl, networkx as nx
from os import path
from typing import Optional, Dict
import matplotlib.pyplot as plt

from pytorch_lightning.utilities.types import EVAL_DATALOADERS
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
# Resolve project-relative data directories from this file's location.
# `file.parents[2]` is the directory two levels above the folder containing this file;
# every path string below is built by plain concatenation and ends with a trailing '/'.
file = Path(__file__).resolve()
path2project = str(file.parents[2]) + '/'
path2SmoothData = path2project + 'data/smooth_signals/'
path2BrainData = path2project + 'data/brain_data/'
path2DiffData = path2project + 'data/network_diffusion/'
path2DiffCreatedData = path2DiffData + 'created_data/'
# Directory the process was started from (not necessarily this file's directory).
path2currDir = str(Path.cwd()) + '/'
# Must run before the project-local imports that follow, so they can be resolved.
sys.path.append(path2project) # add top level directory -> geom_dl/

from utils import correlation_from_covariance, matrix_polynomial, normalize_slices, edge_density
from unroll.glad.matrix_sqrt_utils import scipy_sqrtm


def analytic_summary_stats(S, coeffs, filters=None):
    """Return the closed-form covariance and correlation for a batch of graph filters.

    Args:
        S: batch of graph shift operators; must be 3-dimensional (a single
            matrix is rejected).
        coeffs: coefficients handed to `matrix_polynomial` when `filters`
            is not supplied.
        filters: optional precomputed batch of graph filters H. When given,
            `S` and `coeffs` are only used for the shape check.

    Returns:
        Tuple `(cov_analytic, corr_analytic)` where `cov_analytic = H @ H`
        (batched) and `corr_analytic` is derived from it with
        `correlation_from_covariance`.
    """
    assert S.ndim == 3, f'must have batch, not single matrix'

    graph_filters = matrix_polynomial(S, coeffs) if filters is None else filters

    # White signals diffused through H have covariance H^2.
    cov_analytic = graph_filters.bmm(graph_filters)
    return cov_analytic, correlation_from_covariance(cov_analytic)


def sample_white_signals(num_vertices, num_matrices, signals_per_matrix, dtype=torch.float32, which='single'):
    """Sample zero-mean, identity-covariance Gaussian signals for a batch of graphs.

    Args:
        num_vertices: dimension of every sampled vector.
        num_matrices: number of graphs (batch size) to sample signals for.
        signals_per_matrix: number of vectors sampled per graph.
        dtype: floating dtype used for the distribution parameters and hence
            for the returned samples.
        which: 'single' samples everything from one shared distribution;
            'many' builds one distribution per graph (a batch of
            `num_matrices` identical distributions).

    Returns:
        Tensor of shape `[num_matrices, num_vertices, signals_per_matrix]`,
        e.g. 1064 x 68 x 500; each column is one sampled vector.

    Raises:
        ValueError: if `which` is neither 'single' nor 'many'.
    """
    identity = torch.eye(num_vertices, dtype=dtype)

    if which == 'single':
        mvn = torch.distributions.MultivariateNormal(torch.zeros(num_vertices, dtype=dtype),
                                                     covariance_matrix=identity)
        # [num_matrices, signals_per_matrix, num_vertices] -> rows are sampled vectors
        draws = mvn.sample((num_matrices, signals_per_matrix))
    elif which == 'many':
        # `loc` must share the covariance dtype; previously it was always created
        # with the default dtype, which broke any non-default `dtype` request.
        mvn = torch.distributions.MultivariateNormal(
            loc=torch.zeros(num_matrices, num_vertices, dtype=dtype),
            covariance_matrix=identity.expand(num_matrices, num_vertices, num_vertices))
        # [signals_per_matrix, num_matrices, num_vertices] -> put the graph axis first
        draws = mvn.sample([signals_per_matrix]).transpose(dim0=0, dim1=1)
    else:
        raise ValueError('uneexpected arg in white signals sampling')

    # Swap the last two axes so that columns (not rows) are the sampled vectors.
    return draws.transpose(dim0=-1, dim1=-2)


def glasso_filter(s):
    """Build the graph filter H = (delta*I + S)^(-1/2) for every matrix in a batch.

    This is the filter of white signals associated with the GLASSO
    maximum-likelihood setting: the covariance of the diffused signals is
    H^2 = (delta*I + S)^-1.

    Args:
        s: 3-dimensional batch of symmetric matrices `[num_matrices, n, n]`.

    Returns:
        Batch of filters with the same leading shape as `s`.
    """
    assert s.ndim == 3
    batch_size, n = s.shape[:2]

    # Smallest eigenvalue of each matrix in the batch.
    smallest = torch.linalg.eigvalsh(s).min(dim=1)[0]
    # Shift only the matrices whose smallest eigenvalue is below 1, by exactly
    # the amount needed to bring it up to 1; the others get a zero shift.
    shift = torch.where(smallest < 1, (1.0 - smallest), torch.zeros_like(smallest))
    identity_stack = torch.eye(n).expand(batch_size, n, n)
    K = shift.view(batch_size, 1, 1) * identity_stack + s

    smallest_shifted = torch.linalg.eigvalsh(K).min(dim=1)[0].min()
    assert torch.all(smallest_shifted > (1.0 - 1e-3)), f'must be PSD to invert. min_eig(S) =  {smallest.min():.3f}. min_eig(dI+S) {smallest_shifted:.3f}'

    # H = (dI + S)^-1/2, so that filters^-2 recovers K.
    K_sqrt = scipy_sqrtm(K)[0]
    return torch.inverse(K_sqrt)


def diffusion_summary_stat(S, coeffs, num_signals, sum_stat='sample_cov', dtype=torch.float32, use_gauss_mle=False):
    """Compute a covariance/correlation summary statistic of graph-diffused white signals.

    The graph filter H is `glasso_filter(S)` when `use_gauss_mle` is set and
    `matrix_polynomial(S, coeffs)` otherwise.

    Args:
        S: graph shift operator(s); a single 2-D matrix is promoted to a batch of one.
        coeffs: filter coefficients (unused by the `glasso_filter` path).
        num_signals: number of white signals sampled per graph; only read for
            the 'sample_*' statistics.
        sum_stat: one of 'sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr'.
        dtype: dtype of the sampled white signals.
        use_gauss_mle: select the GLASSO filter instead of the polynomial filter.

    Returns:
        Batch of matrices: H^2 for the analytic statistics, or the biased
        (correction=0) sample covariance of the diffused signals for the sample
        statistics; the '*_corr' variants pass the result through
        `correlation_from_covariance`.
    """
    assert sum_stat in ['sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr']

    batch = S.unsqueeze(dim=0) if S.ndim == 2 else S
    num_matrices, n = batch.shape[:2]

    if use_gauss_mle:
        filters = glasso_filter(batch)
    else:
        filters = matrix_polynomial(batch, coeffs)

    if sum_stat in ('analytic_cov', 'analytic_corr'):
        # Cov = H^2
        covariances = torch.bmm(filters, filters)
    else:
        white_signals = sample_white_signals(num_vertices=n, num_matrices=num_matrices,
                                             signals_per_matrix=num_signals, dtype=dtype, which='single')
        diffused_signals = torch.bmm(filters, white_signals)
        # torch.cov is not batched, so estimate one covariance per graph.
        per_graph = [torch.cov(signals, correction=0) for signals in diffused_signals]
        covariances = torch.stack(per_graph, dim=0)

    if sum_stat in ('analytic_corr', 'sample_corr'):
        return correlation_from_covariance(covariances)
    return covariances


def create_GSO(G):
    """Return the adjacency and combinatorial Laplacian matrices of a graph.

    Args:
        G: networkx graph.

    Returns:
        Tuple `(A, L)` of numpy arrays: `A` is the adjacency matrix of `G` and
        `L = D - A`, where `D` is the diagonal matrix of column sums of `A`.
    """
    # `nx.to_numpy_matrix` was removed in networkx 3.0; `to_numpy_array`
    # returns the same values directly as an ndarray.
    A = nx.to_numpy_array(G)
    D = np.diag(A.sum(axis=0))
    return A, D - A


def sbm_constructor(num_vertices, num_communities, p_in, p_out):
    """Build community sizes and the edge-probability matrix of a stochastic block model.

    Vertices are split as evenly as possible: every community gets
    `num_vertices // num_communities` vertices and the leftover vertices are
    handed out one each to the first communities.

    Args:
        num_vertices: total number of vertices.
        num_communities: number of communities; must be at least 1.
        p_in: edge probability between vertices of the same community.
        p_out: edge probability between vertices of different communities;
            must satisfy `0 <= p_out <= p_in <= 1`.

    Returns:
        Tuple `(sizes, prob_matrix)`: `sizes` is the list of community sizes
        and `prob_matrix` is a `[num_vertices, num_vertices]` tensor holding
        `p_in` inside the diagonal blocks and `p_out` everywhere else.

    Raises:
        ValueError: if `num_communities` is smaller than 1.
    """
    assert 0 <= p_out <= p_in <= 1, f'invalid p_in, p_out = {p_in}, {p_out}'
    if num_communities < 1:
        raise ValueError(f'num_communities must be >= 1, got {num_communities}')

    base, leftover = divmod(num_vertices, num_communities)
    sizes = [base + 1 if i < leftover else base for i in range(num_communities)]

    assert sum(sizes) == num_vertices
    # The previous check wrapped a list of lists in all(), which is always true;
    # compare the extremes so the invariant is actually enforced.
    assert max(sizes) - min(sizes) <= 1, f'difference bewtween community sizes can be at most one, {sizes}'

    # 1 inside a community block, 0 across communities.
    membership = torch.block_diag(*[torch.ones(size, size) for size in sizes])
    # Blend explicitly rather than overwriting entries equal to 0, which would
    # depend on p_in never being the sentinel value.
    prob_matrix = membership * p_in + (1 - membership) * p_out
    return sizes, prob_matrix


def connected_sparse_graph_sampling(params, max_attempts=20):
    """Sample a connected graph without self loops whose edge density lies in a target range.

    The random seed is assumed to be set by the caller.

    Args:
        params: dict with the keys
            'num_vertices', 'graph_sampling', 'edge_density_low',
            'edge_density_high', plus the generator-specific keys:
            'p' for 'ER'; 'r' and 'dim' for 'geom'; 'm' for
            'pref_attach' / 'BA'; 'num_communities', 'p_in', 'p_out' for any
            generator name containing 'sbm'.
        max_attempts: number of graphs sampled before giving up.

    Returns:
        Tuple `(G, attempts, ed)`: the accepted networkx graph, the number of
        graphs sampled to obtain it, and its edge density as returned by
        `edge_density`.

    Raises:
        ValueError: if the generator name is unknown, if a preferential
            attachment graph misses the density range, or if no sampled graph
            satisfies the constraints within `max_attempts` attempts.
    """
    num_vertices, graph_sampling = params['num_vertices'], params['graph_sampling']
    low_density, high_density = params['edge_density_low'], params['edge_density_high']
    assert num_vertices > 2, f'number of nodes {num_vertices} must be > 2 for sensical graph'

    pref_attach_names = ('pref_attach', 'BA')
    is_sbm = 'sbm' in graph_sampling
    if not (graph_sampling in ('ER', 'geom') or graph_sampling in pref_attach_names or is_sbm):
        # Fail fast: this used to block on input() and then crash on an empty graph.
        raise ValueError(f"connected_sparse_gen: No valid graph generator given ({graph_sampling}).")

    attempts = 1
    edge_densities = []
    while True:
        if graph_sampling == 'ER':
            G = nx.fast_gnp_random_graph(n=num_vertices, p=params['p'])
        elif graph_sampling == 'geom':
            G = nx.random_geometric_graph(n=num_vertices, radius=params['r'], dim=params['dim'])
        elif graph_sampling in pref_attach_names:
            # num_vertices = 68, m = 28 -> sparsity of ~1/2
            G = nx.barabasi_albert_graph(n=num_vertices, m=params['m'])
        else:  # sbm
            num_communities, p_in, p_out = params['num_communities'], params['p_in'], params['p_out']
            sizes, _ = sbm_constructor(num_vertices, num_communities, p_in, p_out)
            probs = [[(p_in if i == j else p_out) for j in range(num_communities)]
                     for i in range(num_communities)]
            G = nx.stochastic_block_model(sizes=sizes, p=probs,
                                          nodelist=range(sum(sizes)),  # This should ensure consistent node labeling
                                          directed=False, selfloops=False)

        connected = nx.is_connected(G)
        self_loops = any(G.has_edge(i, i) for i in range(num_vertices))
        # `nx.to_numpy_matrix` was removed in networkx 3.0.
        ed = edge_density(torch.tensor(nx.to_numpy_array(G)).unsqueeze(dim=0))
        edge_densities.append(ed.item())
        sparse = low_density <= ed <= high_density
        if connected and sparse and (not self_loops):
            return G, attempts, ed
        if graph_sampling in pref_attach_names:
            # The number of edges of a Barabasi-Albert graph is fixed by (n, m),
            # so resampling cannot change the edge density.
            raise ValueError(f'Preferential Attachment model was not able to meet sparsity requirements '
                             f'(edge density {edge_densities[-1]:.3f}, range [{low_density:.3f}, {high_density:.3f}]). '
                             f'Either m is too low, or it is not possible.')
        attempts += 1
        if attempts > max_attempts:
            raise ValueError(f"Sampled > {max_attempts} graphs without connectivity/sparsity constraints satisfied. Ave edge density {sum(edge_densities)/len(edge_densities):.3f}. Range [{low_density:.3f}, {high_density:.3f}]. Adjust parameters.")


class DiffusionDataset(Dataset):
    """Synthetic dataset of (summary statistic, graph label) pairs.

    White signals are diffused over sampled graphs with a graph filter H and a
    summary statistic (covariance or correlation, sampled or analytic) of the
    diffused signals is used as the input; the label is the adjacency,
    Laplacian, or precision matrix of the graph.
    """

    def __init__(self,
                 num_signals: int = 50,
                 num_samples: int = 963,
                 gso: str = 'adjacency',
                 label: str = 'adjacency',
                 label_norm: Optional[Dict] = None,
                 normal_mle: bool = False,
                 dtype=torch.float32,
                 sum_stat: str = 'sample_cov',
                 sum_stat_norm=None,
                 sum_stat_norm_val=None,
                 graph_sampling_params: Optional[Dict] = None,
                 coeffs: Optional[np.ndarray] = None,
                 seed=50,
                 input_adj: Optional[torch.Tensor] = None):
        """
        Args:
            num_signals :: number of white signals to generate per graph
            num_samples :: number of (summary statistic, label) samples to create
            gso :: graph shift operator used for diffusion ('adjacency' or 'laplacian')
            label :: which label to return ('adjacency', 'laplacian' or 'precision')
            label_norm :: label normalization settings; defaults to
                {'normalization': 'min_eig', 'min_eig': 1.0}
            normal_mle :: use `glasso_filter` instead of the polynomial filter
            dtype :: dtype of the returned tensors
            sum_stat :: 'sample_cov', 'sample_corr', 'analytic_cov' or 'analytic_corr'
            sum_stat_norm, sum_stat_norm_val :: forwarded to `normalize_slices`
            graph_sampling_params :: info needed to sample graphs; 'num_vertices'
                is always required (a missing key raises KeyError)
            coeffs :: filter coefficients; defaults to [0.5, 0.5, 0.2]
            seed :: seed for graph sampling and signal sampling
            input_adj :: optional batch of adjacency matrices used instead of sampling
        """
        # Build per-instance defaults instead of sharing mutable default arguments.
        if label_norm is None:
            label_norm = {'normalization': 'min_eig', 'min_eig': 1.0}
        if graph_sampling_params is None:
            graph_sampling_params = {}
        if coeffs is None:
            coeffs = np.array([.5, .5, .2])

        self.huge_graphs = graph_sampling_params['num_vertices'] > (500 if torch.cuda.is_available() else 200)
        self.num_signals = num_signals
        assert label in ['adjacency', 'laplacian', 'precision']
        assert gso in ['adjacency', 'laplacian']
        assert (sum_stat in ['sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr']), \
            f'Dataset: sum_stat must be cov or corr, is {sum_stat}\n'
        self.gso = gso
        self.label = label
        self.label_norm = label_norm
        self.normal_mle = normal_mle  # HACK: whether to use special filters for normal mle
        self.num_samples = num_samples
        self.graph_sampling_params = graph_sampling_params
        self.sum_stat = sum_stat
        self.coeffs = coeffs
        self.dtype = dtype
        self.sum_stat_norm, self.sum_stat_norm_val = sum_stat_norm, sum_stat_norm_val
        self.seed = seed

        # On-disk caching is currently disabled: PATH is empty, so the load
        # branch below is never taken and nothing is saved.
        PATH = ""  # self.make_path_str()

        if path.exists(PATH):
            # NOTE(review): torch.load unpickles the file; only load trusted files.
            # NOTE(review): this branch does not set adj_min_eigs / precision /
            # pr_min_eigs, which standard_getitem reads; fix before re-enabling the cache.
            if not self.huge_graphs:
                # sample_cov depends on randomly drawn white signals, so it is
                # stored rather than recomputed. Huge graphs use analytic_cov.
                self.sample_cov, self.adj, self.ids, self.coeffs = torch.load(PATH)
                self.sample_corr = correlation_from_covariance(self.sample_cov)
                self.analytic_cov, self.analytic_corr = analytic_summary_stats(self.adj.to(self.dtype), self.coeffs)
            else:
                self.adj, self.ids, self.coeffs = torch.load(PATH)
        else:
            # if graphs are given, use them, otherwise sample them
            self.adj = self.sample_graphs() if (input_adj is None) else input_adj
            self.adj_min_eigs = torch.linalg.eigvalsh(self.adj.to(torch.float32)).min(dim=1)[0]

            if self.huge_graphs:
                print(f"\nDONE SAMPLING GRAPHS OF SIZE {graph_sampling_params['num_vertices']}")
            else:
                # The summary statistic and precision are only precomputed when
                # the graphs are small enough to store them.
                pl.seed_everything(self.seed)
                s = self.construct_gso(self.adj, gso=gso).to(self.dtype)
                # NOTE: despite its name this holds whichever statistic `sum_stat` selects.
                self.sample_cov = diffusion_summary_stat(S=s, coeffs=self.coeffs,
                                                         num_signals=self.num_signals, sum_stat=self.sum_stat,
                                                         dtype=self.dtype,
                                                         use_gauss_mle=self.normal_mle)

                # true covariance = H^2 -> precision = H^-2
                H = glasso_filter(s) if self.normal_mle else matrix_polynomial(s, self.coeffs)
                self.precision = torch.inverse(H.bmm(H))
                self.pr_min_eigs = torch.linalg.eigvalsh(self.precision).min(dim=1)[0]
            # NOTE(review): int16 ids wrap around beyond 32767 samples.
            self.ids = torch.arange(0, len(self.adj), 1).to(torch.int16)

    def construct_gso(self, adj, gso):
        """Return the requested graph shift operator for one adjacency matrix or a batch.

        'adjacency' returns `adj` unchanged; 'laplacian' returns `D - adj` with
        `D` the diagonal matrix of column sums. Any other name raises ValueError.
        """
        if gso == 'adjacency':
            return adj
        elif gso == 'laplacian':
            # Sum over the second-to-last axis so that a single matrix and a
            # batch use the same (column-sum) degree; `dim=1` meant row sums for
            # 2-D input but column sums for 3-D input.
            D = torch.diag_embed(adj.sum(dim=-2))
            return D - adj.to(self.dtype)
        else:
            raise ValueError(f'unknown GSO {gso} given')

    def sample_graphs(self):
        """Sample `num_samples` graphs and return their adjacencies as a bool tensor `[num_samples, n, n]`."""
        n = self.graph_sampling_params['num_vertices']
        adj = torch.zeros((self.num_samples, n, n), dtype=torch.bool)
        if self.huge_graphs:
            print(f"Creating Dataset with huge graph:s {n} nodes")
        pl.seed_everything(seed=self.seed)  # for reproducible graph sampling
        for i in range(self.num_samples):
            G, _, density = connected_sparse_graph_sampling(params=self.graph_sampling_params)
            if self.huge_graphs:
                # float() so the format spec also works for a one-element tensor
                print(f"{i}, edge density {float(density):.4f}", end="")
            A, _ = create_GSO(G)
            adj[i] = torch.from_numpy(A).to(torch.bool)

        return adj

    def make_path_str(self):
        """Build the cache file path that identifies this dataset configuration.

        Previously this read attributes that are never set on the instance
        (`self.r`, `self.graph_sampling`, `self.edge_density_low`, ...); the
        values are now taken from `graph_sampling_params`.
        """
        self.coeffs_str = "".join(str(round(f, 2)) for f in self.coeffs)
        print(f'coefficient used in diffusion: {self.coeffs}')

        params = self.graph_sampling_params
        n = params['num_vertices']
        graph_sampling = params['graph_sampling']
        low, high = params['edge_density_low'], params['edge_density_high']
        # Everything else is generator specific (p, r, dim, m, num_communities, ...).
        shared_keys = ('num_vertices', 'graph_sampling', 'edge_density_low', 'edge_density_high')
        r_str = '_'.join(f'{key}={value}' for key, value in params.items() if key not in shared_keys)

        if self.huge_graphs:
            # no num_signals: always using analytic_cov
            # no coeffs_str: the graphs are independent of the coefficients
            unique = f'{graph_sampling}_{n}_{self.num_samples}_{low}_{high}_{r_str}-HUGE'
        else:
            unique = f'{graph_sampling}_{n}_{self.num_signals}_{self.num_samples}_{self.coeffs_str}_{low}_{high}_{r_str}'

        PATH = path2DiffCreatedData + '/Synthetic_' + unique
        return PATH

    def __len__(self):
        return self.adj.shape[0]

    def __getitem__(self, idx):
        """Return one sample; dispatches on whether the statistic was precomputed."""
        if torch.is_tensor(idx):
            idx = idx.numpy()
        if self.huge_graphs:
            # compute the summary statistic on demand
            return self.huge_getitem(idx)
        else:
            # summary statistic already computed
            return self.standard_getitem(idx)

    def standard_getitem(self, idx):
        """Return `(sum_mat, y, id, "", "")` for sample `idx` from the precomputed tensors."""
        if self.sum_stat == 'sample_cov':
            sum_mat = self.sample_cov[idx]
        elif self.sum_stat == 'sample_corr':
            sum_mat = correlation_from_covariance(self.sample_cov[idx])
        elif self.sum_stat in ['analytic_cov', 'analytic_corr']:
            sum_mat = diffusion_summary_stat(S=self.construct_gso(self.adj[idx], gso=self.gso).to(self.dtype),
                                             coeffs=self.coeffs,
                                             num_signals=None,
                                             sum_stat=self.sum_stat,
                                             dtype=self.dtype,
                                             use_gauss_mle=self.normal_mle)
        else:
            raise ValueError(f"unrecognized summary statistic string {self.sum_stat} used")

        sum_mat = normalize_slices(sum_mat, which_norm=self.sum_stat_norm, extra=self.sum_stat_norm_val)

        y = self.precision[idx] if self.label == 'precision' else self.construct_gso(self.adj[idx], gso=self.label)
        if self.label_norm['normalization'] == 'min_eig':
            # Only this sample's eigenvalue is needed; avoid recomputing the
            # shift for the whole dataset on every item.
            if 'adj' in self.label:
                min_eig = self.adj_min_eigs[idx]
            elif 'laplacian' in self.label:
                # the smallest Laplacian eigenvalue is taken to be 0
                min_eig = torch.zeros_like(self.adj_min_eigs[idx])
            elif 'precision' in self.label:
                min_eig = self.pr_min_eigs[idx]
            n = self.adj.shape[-1]
            target = self.label_norm['min_eig']
            # Where the smallest eigenvalue is below the target, add the gap to
            # the diagonal so it equals the target; leave other labels untouched.
            delta = torch.where(min_eig < target, (target - min_eig), torch.zeros_like(min_eig))
            y = y + delta * torch.eye(n)

        return sum_mat.squeeze().to(self.dtype), y.to(self.dtype), self.ids[idx], "", ""

    # for very large datasets, precomputing everything is not feasible.
    # ex: num_samples=100, N=6,800.
    #   storing adjs with binary dtype: num_samples * N^2 ~= 4.6 GB.
    #   sample_cov is tensor of shape [num_samples, N, N]. Using a 32 bit float tensor ~150 GB!
    #   But if we compute only one slice at a time ==>  1.5 GB/slice ==> 3 GB for adj and sum_stat
    def huge_getitem(self, idx):
        """On-demand item computation for huge graphs (not finished).

        Raises:
            NotImplementedError: always. The draft below has not been aligned
                with `standard_getitem` (it diffuses over `label` and applies no
                label normalization), so it must not run. An explicit raise is
                used because a bare `assert False` disappears under `python -O`.
        """
        raise NotImplementedError('huge_getitem: walk through - copy standard_getitem')
        # Unreachable draft, kept so the intended flow is not lost.
        n = self.graph_sampling_params['num_vertices']
        adj = self.adj[idx].view(1, n, n).to(self.dtype)
        s = self.construct_gso(adj, self.label)
        sum_mat = diffusion_summary_stat(S=s, coeffs=self.coeffs, num_signals=self.num_signals,
                                         sum_stat=self.sum_stat, dtype=self.dtype, use_gauss_mle=self.normal_mle)
        sum_mat = normalize_slices(sum_mat, which_norm=self.sum_stat_norm, extra=self.sum_stat_norm_val)

        sum_mat = sum_mat.squeeze()
        return sum_mat.to(self.dtype), s.to(self.dtype), self.ids[idx], "", ""

    def full_ds(self):
        """Return the whole dataset at once as `(sum_mat, y, ids, scan_dirs, tasks)`.

        NOTE(review): unlike `standard_getitem`, no `label_norm` shift is applied
        to the labels here; confirm whether that is intended.
        """
        assert not self.huge_graphs, f'cant get full ds of a huge graph!'
        if self.sum_stat == 'sample_cov':
            sum_mat = self.sample_cov
        elif self.sum_stat == 'sample_corr':
            sum_mat = correlation_from_covariance(self.sample_cov)
        elif self.sum_stat in ['analytic_cov', 'analytic_corr']:
            # Diffuse over the configured GSO (as everywhere else), not over the label.
            diffusion_gso = self.construct_gso(self.adj, self.gso).to(self.dtype)
            sum_mat = diffusion_summary_stat(S=diffusion_gso, coeffs=self.coeffs,
                                             num_signals=self.num_signals, sum_stat=self.sum_stat, dtype=self.dtype,
                                             use_gauss_mle=self.normal_mle)
        else:
            raise ValueError(f"Unrecognized summary statistic string {self.sum_stat} used")

        sum_mat = normalize_slices(sum_mat, which_norm=self.sum_stat_norm, extra=self.sum_stat_norm_val)

        if self.label == 'precision':
            # construct_gso has no 'precision' operator (it raised here before);
            # reuse the precision matrices computed in __init__.
            y = self.precision
        else:
            y = self.construct_gso(self.adj, self.label)

        # Size by the stored graphs: `input_adj` may differ from `num_samples`.
        num_graphs = len(self.adj)
        ids = [f'{i}' for i in range(num_graphs)]
        scan_dirs = [""] * num_graphs
        tasks = [""] * num_graphs

        return sum_mat.to(self.dtype), y.to(self.dtype), ids, scan_dirs, tasks

    def prob_matrix(self):
        """Return the matrix of edge probabilities the graphs are sampled from (ER and SBM only)."""
        n = self.graph_sampling_params['num_vertices']
        graph_sampling = self.graph_sampling_params['graph_sampling']
        if graph_sampling == 'ER':
            return torch.ones(n, n) * self.graph_sampling_params['p']
        elif 'sbm' in graph_sampling:
            # same 'sbm' substring match as connected_sparse_graph_sampling
            num_communities = self.graph_sampling_params['num_communities']
            p_in, p_out = self.graph_sampling_params['p_in'], self.graph_sampling_params['p_out']
            _, prob_matrix = sbm_constructor(num_vertices=n, num_communities=num_communities, p_in=p_in, p_out=p_out)
            return prob_matrix
        else:
            raise ValueError(f'Unrecognized graph gen OR have not implimented prob_matrix')


class DiffusionDataModule(pl.LightningDataModule):
    """Lightning data module wrapping train/val/test `DiffusionDataset` splits.

    Each split is an independently sampled `DiffusionDataset` that shares all
    settings except its size and its seed.
    """

    def __init__(self,
                 num_signals=50,
                 gso: str = 'adjacency', # the thing performing diffusions
                 label: str = 'adjacency',
                 label_norm: Optional[Dict] = None,
                 normal_mle: bool = False,
                 batch_size=512,
                 train_size=1000,
                 val_size=500,
                 test_size=500,
                 sum_stat='sample_cov',
                 graph_sampling_params: Optional[Dict] = None,
                 sum_stat_norm: Optional[str] = 'max_eig',
                 sum_stat_norm_val: Optional[float] = 'symeig',
                 transform=None,
                 seed=None,
                 num_workers=0,
                 coeffs=None):
        """
        Args:
            num_signals: white signals sampled per graph.
            gso: operator performing the diffusion ('adjacency' or 'laplacian').
            label: what the dataset returns as `y` ('adjacency', 'laplacian' or 'precision').
            label_norm: label normalization settings; defaults to
                {'normalization': 'min_eig', 'min_eig': 1.0}.
            normal_mle: forwarded to `DiffusionDataset`.
            batch_size: batch size of every dataloader.
            train_size, val_size, test_size: number of samples per split; a
                split of size 0 is not created. Equal positive sizes are bumped
                so that all created splits differ in size.
            sum_stat, sum_stat_norm, sum_stat_norm_val: forwarded to `DiffusionDataset`.
                NOTE(review): `sum_stat_norm_val` is annotated as float but
                defaults to the string 'symeig'.
            graph_sampling_params: graph sampling settings; 'num_vertices' is
                required (a missing key raises KeyError).
            transform: stored but not used by this class.
            seed: base seed; val/test use `seed + 17` / `seed + 50`. `None`
                is passed through unchanged to every split.
            num_workers: dataloader worker count.
            coeffs: filter coefficients; defaults to [0.5, 0.5, 0.2].
        """
        super().__init__()
        # Build per-instance defaults instead of sharing mutable default arguments.
        if label_norm is None:
            label_norm = {'normalization': 'min_eig', 'min_eig': 1.0}
        if graph_sampling_params is None:
            graph_sampling_params = {}
        if coeffs is None:
            coeffs = np.array([0.5, 0.5, 0.2])

        # Splits are told apart by size (to ensure they dont load the same
        # saved dataset), so every split that is actually created must have a
        # distinct size. Zero-size splits are skipped in setup() and therefore
        # never collide; this also covers train_size == test_size, which
        # previously tripped an assertion.
        taken = set()
        distinct_sizes = []
        for size in (train_size, val_size, test_size):
            if size > 0:
                while size in taken:
                    size += 1
                taken.add(size)
            distinct_sizes.append(size)
        self.train_size, self.val_size, self.test_size = distinct_sizes

        self.num_signals = num_signals
        # gso is the object performing diffusions
        self.gso = gso
        assert gso in ['adjacency', 'laplacian']
        # label is the thing being returned as 'y' in dataset
        self.label = label
        assert label in ['adjacency', 'laplacian', 'precision']
        self.label_norm = label_norm
        self.normal_mle = normal_mle
        self.graph_sampling_params = graph_sampling_params

        self.batch_size = batch_size

        self.sum_stat = sum_stat
        self.sum_stat_norm, self.sum_stat_norm_val = sum_stat_norm, sum_stat_norm_val

        self.coeffs = coeffs

        self.transform = transform
        self.seed = seed
        self.num_workers = num_workers
        self.synth_ds, self.train_ds, self.val_ds, self.test_ds = None, None, None, None

        n = graph_sampling_params['num_vertices']
        self.subnetwork_masks = {'full': torch.ones((n, n), dtype=torch.bool).view(1, n, n)}
        self.non_neg_labels = (label in ['adjacency'])
        self.self_loops = (label in ['laplacian', 'precision'])

    def setup(self, stage=None):
        """Create the train/val/test datasets for the 'fit' stage (or when no stage is given)."""
        if not ((stage == 'fit') or (stage is None)):
            # other stages are not implemented
            return

        kwargs_dict = {
            'num_signals': self.num_signals,
            'gso': self.gso,
            'label': self.label,
            'label_norm': self.label_norm,
            'normal_mle': self.normal_mle,
            'graph_sampling_params': self.graph_sampling_params,
            'coeffs': self.coeffs,
            'sum_stat': self.sum_stat,
            'sum_stat_norm': self.sum_stat_norm,
            'sum_stat_norm_val': self.sum_stat_norm_val,
        }

        def split_seed(offset):
            # `seed` defaults to None; adding an offset to None raised TypeError.
            return None if self.seed is None else self.seed + offset

        if self.train_size > 0:
            self.train_ds = DiffusionDataset(num_samples=self.train_size, seed=self.seed, **kwargs_dict)
        if self.val_size > 0:
            self.val_ds = DiffusionDataset(num_samples=self.val_size, seed=split_seed(17), **kwargs_dict)
        if self.test_size > 0:
            self.test_ds = DiffusionDataset(num_samples=self.test_size, seed=split_seed(50), **kwargs_dict)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers) #cant go into debug mode if >0

    def test_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers) # if this is >0 for large graphs => timeout

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        # Prediction is not supported; returns None.
        pass


### tests ###
def test_sample_white_signals():
    """Check that both sampling modes produce white signals: zero mean, identity covariance."""
    pl.seed_everything(50)
    n, num_matrices, num_signals = 10, 5, 1000000
    # With 1e6 samples the standard error of a mean / off-diagonal covariance
    # entry is ~1e-3 and of a variance ~1.4e-3; the tolerances sit at roughly
    # 5 and 7 standard errors so the check does not hinge on a lucky seed.
    mean_atol, cov_atol = 5e-3, 1e-2
    identity = torch.eye(n).expand(num_matrices, n, n)

    # The 'many' variant used to be sampled with which='single', so that
    # branch was never exercised.
    for which in ('single', 'many'):
        white_signals = sample_white_signals(num_vertices=n, num_matrices=num_matrices,
                                             signals_per_matrix=num_signals, which=which)

        means = white_signals.mean(dim=-1)
        assert torch.allclose(means, torch.zeros_like(means), atol=mean_atol), f'{which}: mean is not zero'

        covs = torch.stack([torch.cov(signals, correction=0) for signals in white_signals], dim=0)
        assert torch.allclose(covs, identity, atol=cov_atol), f'{which}: covariance is not identity'


def test_diffusion_dataset(gso, label, normal_mle=False):
    """Check dataset reproducibility and plot sample vs. analytic covariance.

    Two datasets built with the same seed and parameters must be identical.
    Afterwards datasets with a growing number of signals are created and the
    maximum absolute gap between sample and analytic covariance is printed and
    plotted for each.
    """
    seed = 45
    # Alternative configuration (random geometric graphs):
    #   num_samples 10, graph_sampling 'geom', num_vertices 68, r .56, dim 2,
    #   edge density in [.5, .6]
    er_params = {'graph_sampling': 'ER', 'num_vertices': 20, 'p': .2,
                 'edge_density_low': .15, 'edge_density_high': .25}
    dm_args = {'seed': seed, 'num_samples': 5,
               'gso': gso,
               'label': label,
               'normal_mle': normal_mle,
               'sum_stat': 'sample_cov',
               'graph_sampling_params': er_params,
               'coeffs': np.array([0.5, 0.5, 0.2])}

    # confirm that with same seed, we get same outputs
    ds1 = DiffusionDataset(**dm_args, num_signals=50)
    ds2 = DiffusionDataset(**dm_args, num_signals=50)
    assert torch.allclose(ds1.sample_cov, ds2.sample_cov) and torch.allclose(ds1.adj, ds2.adj), f'same seed + params should create identical datasets'

    # as number of signals increases, the sample covariance should approach the analytic covariance
    print('maximum difference between sample covariance and analytic covariance with...')
    powers = 5
    fig, axs = plt.subplots(nrows=2, ncols=powers, sharex=True)
    axs[0, 0].set_ylabel('estimated cov')
    axs[1, 0].set_ylabel('true cov')
    num_signals_arr = [50*10**exponent for exponent in range(powers)]
    errors_compute = []
    for col, num_signals in enumerate(num_signals_arr):
        ds = DiffusionDataset(**dm_args, num_signals=num_signals)
        s = ds.construct_gso(ds.adj.to(torch.float32), ds.gso)
        analytic_cov = diffusion_summary_stat(s, num_signals=num_signals, coeffs=ds.coeffs, sum_stat='analytic_cov',
                                              use_gauss_mle=normal_mle)

        # largest absolute entry-wise gap per sample, then over the dataset
        abs_gap = (ds.sample_cov - analytic_cov).abs()
        diffs = abs_gap.view(ds.sample_cov.shape[0], -1).max(dim=1)[0]
        worst = diffs.max()
        errors_compute.append(worst)
        print(f'\t{num_signals} signals -> {worst:.5f}')

        top, bottom = axs[0, col], axs[1, col]
        top.imshow(ds.sample_cov[0].numpy())
        bottom.imshow(analytic_cov[0].numpy())
        top.set_title(f'{num_signals}')
        bottom.set_xlabel(f'max abs diff {worst:.5f}')
    fig.show()


if __name__ == "__main__":
    # Run the script-style checks, then build a small data module as a smoke test.
    test_sample_white_signals()
    #test_diffusion_dataset(gso='adjacency', label='adjacency')
    #test_diffusion_dataset(gso='laplacian', label='laplacian')
    test_diffusion_dataset(gso='adjacency', label='adjacency', normal_mle=True)

    geom_params = {'graph_sampling': 'geom', 'num_vertices': 68,
                   'r': 0.56, 'dim': 2,
                   'edge_density_low': 0.5, 'edge_density_high': 0.6}
    dm = DiffusionDataModule(train_size=20, val_size=25, test_size=30,
                             gso='laplacian',
                             label='precision',
                             normal_mle=False,
                             graph_sampling_params=geom_params,
                             seed=50)
    dm.setup('fit')
    train_dl = dm.train_dataloader()
