# Skeleton taken from: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

from __future__ import print_function, division
import numpy as np
from torch.utils.data import Dataset, DataLoader

from data.dataset_utils import create_GSO, connected_sparse_gen, \
    add_edge_weights, all_pair_shortest_path_lengths, calc_sparsity, sbm_constructor
from data.summary_stats import compute_diffusion_summary_stats, compute_analytic_summary_stats, \
    compute_diffusion_summary_stats_individual

from utils.util_funcs import normalize_slices, correlation_from_covariance


from typing import Optional

import os
import sys, torch
from pathlib import Path

from pathlib import Path
# Resolve this file's location. parents[0] is this file's directory, so with
# hops = 1 path2project is the directory one level above it.
file = Path(__file__).resolve()
hops = 1
path2project = str(file.parents[hops])
path2currDir = str(Path.cwd())
# Directory used to build the dataset cache paths in the classes below.
path2CreatedData = path2project + '/data/created_data'
# NOTE(review): this append runs after the `data.*` / `utils.*` imports at the top
# of the file, so it cannot help resolve those; it only affects later imports.
sys.path.append(path2project) # add top level directory -> geom_dl/

# Ignore warnings
import warnings
# NOTE(review): this silences *all* warnings process-wide for any program that
# imports this module.
warnings.filterwarnings("ignore")

# Print numpy arrays with 10 decimals in fixed-point (non-scientific) notation.
np.set_printoptions(precision=10)
np.set_printoptions(suppress=True)

# Debug switch; in this file it is only read by the exploration code under __main__.
DEBUG = False#True
from os import path


class PureSyntheticDiffusionDataset(Dataset):
    """Diffused white signals over a sampled graph using graph filter H then
            compute summary statistics (Cov, Corr)
        Input: Covariance Cx
        Label: GSO S

    Items are ``(sum_mat, adj, id, "", "")``. Graphs are sampled with
    ``connected_sparse_gen`` unless ``input_adj`` is given. For "huge" graphs
    (more than 500 vertices when CUDA is available, more than 200 otherwise)
    the summary statistic is computed per item on demand (``huge_getitem``);
    otherwise the sample covariance of every graph is precomputed in
    ``__init__``.
    """

    # Accepted values of ``sum_stat``.
    _SUM_STATS = ('sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr')

    def __init__(self,
                 num_vertices: int = 68,
                 num_signals: int = 50,
                 num_samples: int = 963,
                 dtype=torch.float32,
                 sum_stat: str = 'sample_cov',
                 fc_norm=None,
                 fc_norm_val=None,
                 graph_gen: str = 'geom',
                 r: float = 0.56,
                 sparse_thresh_low: float = 0.5, sparse_thresh_high: float = 0.6,
                 norm_before_diff: bool = False,
                 bin_labels_after_diff: bool = False,
                 coeffs: Optional[np.ndarray] = None,
                 rand_seed=50,
                 input_adj: Optional[torch.Tensor] = None):
        """
        Args:
            num_vertices :: number of nodes per sampled graph; also decides whether
                the dataset counts as "huge" (see class docstring)
            num_signals :: number of white signals to generate
            num_samples :: number of (FC,S) samples to create
            dtype :: torch dtype of returned tensors
            sum_stat :: one of 'sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr'
            fc_norm, fc_norm_val :: forwarded to ``normalize_slices`` as ``which_norm`` / ``extra``
            graph_gen :: type of random graph to sample (random geometric, ER, sbm,
                preferential attachment)
            r :: edge probability (ER) or radius (random geometric); for 'sbm' a dict of
                keyword arguments forwarded to ``sbm_constructor``
            sparse_thresh_low/high :: range of allowable sparsity for sampled graphs
            norm_before_diff :: normalize graphs before diffusion. Not implemented:
                raises NotImplementedError when a dataset has to be built.
            bin_labels_after_diff :: return binarized adjacency labels from
                ``__getitem__`` / ``full_ds``
            coeffs :: coefficients used to diffuse white signals;
                defaults to np.array([.5, .5, .2])
            rand_seed :: value passed to ``np.random.seed`` before sampling graphs
            input_adj :: optional tensor of adjacency matrices used instead of sampling.
                When given, a cached dataset on disk is never loaded (the cache does
                not record which graphs it holds).
        """
        if coeffs is None:
            # Fresh array per instance instead of a shared mutable default.
            coeffs = np.array([.5, .5, .2])
        self.num_vertices = num_vertices
        # Graphs this large get their summary statistics computed per item.
        self.huge_graphs = num_vertices > (500 if torch.cuda.is_available() else 200)
        self.num_signals = num_signals
        self.num_samples = num_samples
        self.graph_gen = graph_gen
        self.r = r
        self.sparse_thresh_low, self.sparse_thresh_high = sparse_thresh_low, sparse_thresh_high
        self.norm_before_diff, self.bin_labels_after_diff = norm_before_diff, bin_labels_after_diff
        self.coeffs = coeffs
        self.dtype = dtype
        self.fc_norm, self.fc_norm_val = fc_norm, fc_norm_val
        self.rand_seed = rand_seed
        assert (sum_stat in self._SUM_STATS), \
            f'Dataset: sum_stat must be cov or corr, is {sum_stat}\n'
        self.sum_stat = sum_stat

        PATH = self.make_path_str()

        # Reuse a cached dataset only when graphs would otherwise be sampled here;
        # explicitly supplied graphs must never be replaced by cached ones.
        if input_adj is None and path.exists(PATH):
            self._load_cached(PATH)
        else:
            self._build(input_adj)

    def _load_cached(self, PATH):
        """Restore adjacency matrices (and, for non-huge graphs, sample covariances) from PATH."""
        # NOTE(review): newer torch versions default torch.load to weights_only=True,
        # which may reject the stored numpy coeffs -- confirm with the installed version.
        if self.huge_graphs:
            # Huge-graph cache names deliberately omit the coefficients so the same
            # graphs can be reused with new coeffs: keep the caller's coeffs.
            self.adj, self.ids, _ = torch.load(PATH)
        else:
            # Sample covariances come from random white signals, so they are loaded
            # rather than recomputed to keep results reproducible.
            self.sample_cov, self.adj, self.ids, self.coeffs = torch.load(PATH)
            self.sample_corr = correlation_from_covariance(self.sample_cov)
            # NOTE(review): sample_corr / analytic_* are only set on this cache path and
            # are not read by __getitem__ or full_ds in this class.
            self.analytic_cov, self.analytic_corr = compute_analytic_summary_stats(self.adj.to(self.dtype), self.coeffs)

    def _build(self, input_adj):
        """Adopt or sample the graphs and, for non-huge graphs, precompute sample covariances."""
        if self.norm_before_diff:
            # TODO: divide each graph by its largest-magnitude eigenvalue before diffusing.
            raise NotImplementedError('norm_before_diff is not yet implemented')

        adj = input_adj if input_adj is not None else self._sample_graphs()

        if self.huge_graphs:
            print(f"\nDONE SAMPLING GRAPHS OF SIZE {self.num_vertices}")
        else:
            # Batched over all graphs at once; only stored when graphs aren't huge.
            self.sample_cov = compute_diffusion_summary_stats_individual(S=adj.to(self.dtype), coeffs=self.coeffs,
                                                                         num_signals=self.num_signals,
                                                                         sum_stat='sample_cov', dtype=self.dtype)
        # NOTE(review): int16 ids wrap around beyond 32767 samples.
        self.ids = torch.arange(0, len(adj), 1).to(torch.int16)
        self.adj = adj
        # Saving to the cache path is currently disabled. To re-enable:
        #   huge graphs:  torch.save((self.adj, self.ids, self.coeffs), PATH)
        #   otherwise:    torch.save((self.sample_cov, self.adj, self.ids, self.coeffs), PATH)

    def _sample_graphs(self):
        """Sample ``num_samples`` graphs; return a bool tensor of shape (num_samples, N, N)."""
        adj = torch.zeros((self.num_samples, self.num_vertices, self.num_vertices), dtype=torch.bool)
        if self.huge_graphs:
            print(f"Creating Dataset with huge graph:s {self.num_vertices} nodes")
        np.random.seed(self.rand_seed)  # for reproducible graph sampling
        for i in range(self.num_samples):
            G, _attempts, sparsity = \
                connected_sparse_gen(self.num_vertices,
                                     r=self.r, dim=2,
                                     sparse_thresh_low=self.sparse_thresh_low,
                                     sparse_thresh_high=self.sparse_thresh_high,
                                     graph_gen=self.graph_gen)
            if self.huge_graphs:
                print(f"{i}, sparsity {sparsity:.4f}", end="")
            A, _ = create_GSO(G)
            adj[i] = torch.from_numpy(A).to(torch.bool)
        return adj

    def make_path_str(self):
        """Build the cache path for this configuration.

        Also sets ``self.coeffs_str`` and prints the coefficients. Huge-graph paths
        omit ``num_signals`` (analytic statistics are used) and the coefficients
        (so graphs can be reused with different coefficients).
        """
        self.coeffs_str = "".join(str(round(f, 2)) for f in self.coeffs)
        print(f'coefficient used in diffusion: {self.coeffs}')
        if isinstance(self.r, dict):  # sbm -> {'num_com':3, 'p_in': 0.8, 'p_out': 0.2}
            # e.g. "{'num_com': 3, 'p_in': 0.8}" -> "num_com=3_p_in=0.8"
            r_str = str(self.r).replace("'", '')[1:-1].replace(' ', '').replace(',', '_').replace(':', '=')
        else:
            r_str = str(self.r)

        if self.huge_graphs:
            unique = f'{self.graph_gen}_{self.num_vertices}_{self.num_samples}_{self.sparse_thresh_low}_{self.sparse_thresh_high}_{r_str}-HUGE'
        else:
            unique = f'{self.graph_gen}_{self.num_vertices}_{self.num_signals}_{self.num_samples}_{self.coeffs_str}_{self.sparse_thresh_low}_{self.sparse_thresh_high}_{r_str}'

        return path2CreatedData + '/Synthetic_' + unique

    def __len__(self):
        return self.adj.shape[0]

    def __getitem__(self, idx):
        """Return ``(sum_mat, adj, id, "", "")`` for item ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.numpy()
        if self.huge_graphs:
            # compute the summary statistic on demand
            return self.huge_getitem(idx)
        # sample_cov already precomputed
        return self.standard_getitem(idx)

    def standard_getitem(self, idx):
        """Item lookup for non-huge datasets (sample covariances precomputed)."""
        if self.sum_stat == 'sample_cov':
            sum_mat = self.sample_cov[idx]
        elif self.sum_stat == 'sample_corr':
            sum_mat = correlation_from_covariance(self.sample_cov[idx])
        elif self.sum_stat in ['analytic_cov', 'analytic_corr']:
            sum_mat = compute_diffusion_summary_stats_individual(S=self.adj[idx].to(self.dtype),
                                                                 coeffs=self.coeffs,
                                                                 num_signals=None,
                                                                 sum_stat=self.sum_stat,
                                                                 dtype=self.dtype)
        else:
            raise ValueError(f"unrecognized summary statistic string {self.sum_stat} used")

        sum_mat = normalize_slices(sum_mat, which_norm=self.fc_norm, extra=self.fc_norm_val)

        adj = self.adj[idx]
        if self.bin_labels_after_diff:
            adj = (adj > 0) + 0.0

        sum_mat = sum_mat.squeeze()
        return sum_mat.to(self.dtype), adj.to(self.dtype), self.ids[idx], "", ""

    # For very large datasets, precomputing everything is not feasible.
    # ex: num_samples=100, N=6,800.
    #   storing adjs with binary dtype: num_samples * N^2 ~= 4.6 GB.
    #   sample_cov is tensor of shape [num_samples, N, N]. Using a 32 bit float tensor ~150 GB!
    #   But if we compute only one slice at a time ==>  1.5 GB/slice ==> 3 GB for adj and sum_stat
    def huge_getitem(self, idx):
        """Item lookup for huge datasets: diffuse over a single graph on demand."""
        n = self.adj.shape[-1]  # actual graph size (may differ from num_vertices for input_adj)
        adj_in = self.adj[idx].view(1, n, n).to(self.dtype)
        sum_mat = compute_diffusion_summary_stats_individual(S=adj_in, coeffs=self.coeffs, num_signals=self.num_signals,
                                                             sum_stat=self.sum_stat, dtype=self.dtype)
        sum_mat = normalize_slices(sum_mat, which_norm=self.fc_norm, extra=self.fc_norm_val)

        # Return the binarized label when requested (it was previously computed and discarded).
        label = self.adj[idx]
        if self.bin_labels_after_diff:
            label = (label > 0) + 0.0

        sum_mat = sum_mat.squeeze()
        return sum_mat.to(self.dtype), label.to(self.dtype), self.ids[idx], "", ""

    def full_ds(self):
        """Return ``(sum_mats, adjs, ids, scan_dirs, tasks)`` for the whole (non-huge) dataset."""
        assert not self.huge_graphs, f'cant get full ds of a huge graph!'
        if self.sum_stat == 'sample_cov':
            sum_mat = self.sample_cov
        elif self.sum_stat == 'sample_corr':
            sum_mat = correlation_from_covariance(self.sample_cov)
        elif self.sum_stat in ['analytic_cov', 'analytic_corr']:
            sum_mat = compute_diffusion_summary_stats_individual(S=self.adj.to(self.dtype), coeffs=self.coeffs,
                                                                 num_signals=self.num_signals,
                                                                 sum_stat=self.sum_stat, dtype=self.dtype)
        else:
            raise ValueError(f"Unrecognized summary statistic string {self.sum_stat} used")

        sum_mat = normalize_slices(sum_mat, which_norm=self.fc_norm, extra=self.fc_norm_val)

        adj = (self.adj > 0) + 0.0 if self.bin_labels_after_diff else self.adj

        # Size metadata by the graphs actually held (input_adj may differ from num_samples).
        n = len(self.adj)
        ids = [f'{i}' for i in range(n)]
        scan_dirs = [""] * n
        tasks = [""] * n

        return sum_mat.to(self.dtype), adj.to(self.dtype), ids, scan_dirs, tasks

    def prob_matrix(self):
        """Matrix of edgewise probabilities the graphs are sampled from (ER and SBM only)."""
        if self.graph_gen == 'ER':
            # isinstance also accepts numpy float scalars (np.float64 subclasses float).
            assert isinstance(self.r, float), f'r {self.r} assumed to sampling prob'
            return torch.full((self.num_vertices, self.num_vertices), float(self.r))
        elif self.graph_gen == 'sbm':
            assert isinstance(self.r, dict)
            _sizes, probs = sbm_constructor(self.num_vertices, **self.r)
            return probs
        else:
            raise ValueError(f'Unrecognized graph gen OR have not implimented prob_matrix')



class PsuedoSyntheticDiffusionDataset(Dataset):
    """Diffused white signals over a given brain graphs using graph filter H then
            compute summary statistics (Cov, Corr)
        Input: Covariance Cx
        Label: GSO S

    All four summary statistics are computed once in ``__init__`` and normalized
    with ``normalize_slices``; ``sum_stat`` selects which one is returned.
    """

    # Valid ``sum_stat`` values; each is also the attribute name holding that statistic.
    _SUM_STATS = ('sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr')

    def __init__(self,
                 adjs,
                 num_signals=50,
                 subject_ids=None, scan_dirs=None, num_patients=None,
                 sum_stat='sample_cov',
                 scale_label=9.9,
                 bin_labels_after_diff=False,
                 diff_over_weighted=True,
                 norm_before_diff=False,
                 transform=None,
                 remove_fc_diag=False,
                 fc_norm=None,
                 fc_norm_val=None,
                 coeffs=None,
                 dtype=torch.float32,
                 split_id=''
                 ):
        """
        Args:
            adjs :: numpy array or tensor of adjacency matrices, unpacked as
                (num_scans, N, N); numpy input is converted to ``dtype``
            num_signals :: int number of white signals to generate
            subject_ids :: per-scan ids returned by ``__getitem__`` ("" when None) and ``full_ds``
            scan_dirs, num_patients :: stored only
            sum_stat :: one of 'sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr'
            scale_label :: float to divide each label by. For weighted graphs. For brain data, make 10
            bin_labels_after_diff :: return binarized labels instead of scaled ones
            diff_over_weighted :: diffuse over the weighted graphs; if False, graphs are
                binarized before diffusion
            norm_before_diff :: normalize graphs before diffusion. Not implemented:
                raises NotImplementedError.
            transform :: optional callable (sum_mat, adj) -> (sum_mat, adj)
            remove_fc_diag :: zero the diagonal of returned summary matrices
            fc_norm, fc_norm_val :: forwarded to ``normalize_slices`` as ``which_norm`` / ``extra``
            coeffs :: coefficients used to diffuse white signals; defaults to np.array([.5, .5, .2])
            dtype :: torch dtype of returned tensors
            split_id :: tag included in the cache file name
        """
        if coeffs is None:
            # Fresh array per instance instead of a shared mutable default.
            coeffs = np.array([.5, .5, .2])

        self.num_scans, self.N, _ = adjs.shape
        self.adjs = adjs if torch.is_tensor(adjs) else torch.from_numpy(adjs).to(dtype)

        assert (sum_stat in self._SUM_STATS), f'Dataset: sum_stat must be cov or corr, is {sum_stat}\n'
        self.num_signals = num_signals
        self.sum_stat = sum_stat
        self.scale_label = scale_label
        self.transform = transform
        self.diff_over_weighted = diff_over_weighted
        self.bin_labels_after_diff = bin_labels_after_diff
        self.norm_before_diff = norm_before_diff
        self.remove_fc_diag = remove_fc_diag
        self.fc_norm, self.fc_norm_val = fc_norm, fc_norm_val
        self.coeffs = coeffs
        self.num_patients = num_patients
        self.subject_ids = subject_ids
        self.scan_dirs = scan_dirs
        self.dtype = dtype
        self.split_id = split_id

        self.coeffs_str = "".join(str(round(f, 2)) for f in self.coeffs)
        print(f'coefficient used in diffusion: {coeffs}')
        unique = (f'{self.N}_{self.num_scans}_{num_signals}_{fc_norm}_{fc_norm_val}_{sum_stat}_{self.coeffs_str}_'
                  f'{diff_over_weighted}_{norm_before_diff}_{bin_labels_after_diff}_{split_id}')
        # Cache location for this configuration. Caching is currently disabled:
        # nothing is loaded from or saved to this path.
        self.cache_path = path2CreatedData + '/Psuedo_Synthetic_' + unique

        print(f'data -  # patients {self.num_patients}, # scans {self.num_scans}, # signals {self.num_signals}')
        print('\tCreating dataset...')

        if self.norm_before_diff:
            # TODO: divide each graph by its largest-magnitude eigenvalue before diffusing.
            raise NotImplementedError('norm_before_diff is not yet implemented')

        # Diffuse over binarized graphs unless diffusing over weighted ones.
        adjs = self.adjs if self.diff_over_weighted else (self.adjs > 0) + 0.0
        (analytic_cov, analytic_corr), (sample_cov, sample_corr) = \
            compute_diffusion_summary_stats(adjs, self.coeffs, num_signals=self.num_signals, dtype=self.dtype)

        # Graphs actually diffused over; labels still come from self.adjs.
        self.adj = adjs
        self.sample_cov = normalize_slices(sample_cov, which_norm=self.fc_norm, extra=self.fc_norm_val)
        self.sample_corr = normalize_slices(sample_corr, which_norm=self.fc_norm, extra=self.fc_norm_val)
        self.analytic_cov = normalize_slices(analytic_cov, which_norm=self.fc_norm, extra=self.fc_norm_val)
        self.analytic_corr = normalize_slices(analytic_corr, which_norm=self.fc_norm, extra=self.fc_norm_val)

    @staticmethod
    def _zero_diagonal(mat):
        """Return ``mat`` with the diagonal of its last two axes set to zero."""
        n = mat.shape[-1]
        mask = torch.ones((n, n)) - torch.eye(n)
        return mat * torch.broadcast_to(mask, mat.shape)

    def _labels(self, adjs):
        """Binarize labels, or divide them by ``scale_label`` (1 if not overwritten)."""
        if self.bin_labels_after_diff:
            return (adjs > 0) + 0.0
        return adjs / self.scale_label

    def __len__(self):
        return self.adjs.shape[0]

    def __getitem__(self, idx):
        """Return ``(sum_mat, adj, subject_id, "", "")`` for scan ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.numpy()

        if self.sum_stat not in self._SUM_STATS:
            raise ValueError(f'ps-ds: getitem: unrecognized summary statistic string {self.sum_stat} used')
        sum_mat = getattr(self, self.sum_stat)[idx]
        adj = self._labels(self.adjs[idx])

        if self.transform:
            sum_mat, adj = self.transform(sum_mat, adj)

        if self.remove_fc_diag:
            sum_mat = self._zero_diagonal(sum_mat)

        # subject_ids defaults to None; return a placeholder instead of crashing.
        subject_id = "" if self.subject_ids is None else self.subject_ids[idx]
        return sum_mat.to(self.dtype), adj.to(self.dtype), subject_id, "", ""

    def full_ds(self):
        """Return ``(sum_mats, adjs, subject_ids, scan_dirs, tasks)`` for all scans."""
        if self.sum_stat not in self._SUM_STATS:
            raise ValueError(f'ps-ds: full_ds: unrecognized summary statistic string {self.sum_stat} used')
        sum_mat = getattr(self, self.sum_stat)
        adjs = self._labels(self.adjs)

        if self.transform:
            sum_mat, adjs = self.transform(sum_mat, adjs)

        if self.remove_fc_diag:
            sum_mat = self._zero_diagonal(sum_mat)

        scan_dirs = [""] * self.num_scans
        tasks = [""] * self.num_scans
        return sum_mat.to(self.dtype), adjs.to(self.dtype), self.subject_ids, scan_dirs, tasks


class BrainDataWrapper(Dataset):
    """Dataset over precomputed FC matrices (``fcs``) paired with adjacency labels (``adjs``).

    Items are ``(fc, adj, subject_id, scan_dir, task)``. Labels are either
    binarized or divided by ``scale_label``. With ``sum_stat='corr'``, each FC
    matrix is converted with ``correlation_from_covariance`` at access time.
    """

    def __init__(self,
                 fcs,
                 adjs,
                 subject_ids,
                 scan_dirs,
                 tasks,
                 num_patients: Optional[int] = None,
                 binarize_labels: bool = False,
                 scale_label: float = 9.9,
                 sum_stat = 'cov',
                 remove_fc_diag = False,
                 fc_norm=None,
                 fc_norm_val=None,
                 transform=None,
                 dtype=torch.float32):
        """
        Args:
            fcs :: array-like of FC matrices; converted to a ``dtype`` tensor, then
                normalized once with ``normalize_slices(which_norm=fc_norm, extra=fc_norm_val)``
            adjs :: array-like of adjacency labels; the first axis indexes samples
            subject_ids, scan_dirs, tasks :: per-sample metadata indexed alongside fcs/adjs
            num_patients :: stored only
            binarize_labels :: return (adj > 0) labels instead of adj / scale_label
            scale_label :: divisor applied to labels when not binarizing
            sum_stat :: 'cov' or None returns fcs unchanged; 'corr' converts to correlation
            remove_fc_diag :: zero the diagonal of returned FC matrices
            transform :: optional callable (fc, adj) -> (fc, adj); applied in __getitem__ only
            dtype :: torch dtype of stored and returned tensors
        """
        self.num_samples = adjs.shape[0]
        self.dtype = dtype
        self.adjs = torch.tensor(adjs).to(dtype)
        self.num_patients = num_patients

        self.subject_ids = subject_ids
        self.scan_dirs = scan_dirs
        self.tasks = tasks

        self.binarize_labels = binarize_labels
        self.scale_label = scale_label
        self.transform = transform

        self.remove_fc_diag = remove_fc_diag
        self.sum_stat = sum_stat
        self.fc_norm, self.fc_norm_val = fc_norm, fc_norm_val
        # Normalize the converted tensor (not the raw input) so self.fcs is always a
        # `dtype` tensor and supports `.to(...)` in __getitem__ / full_ds.
        self.fcs = normalize_slices(torch.tensor(fcs).to(dtype), which_norm=self.fc_norm, extra=self.fc_norm_val)

    def _labels(self, adjs):
        """Binarize labels, or rescale them by ``scale_label``."""
        if self.binarize_labels:
            return (adjs > 0) + 0.0
        return adjs / self.scale_label

    def _finalize_fcs(self, fcs):
        """Apply the ``sum_stat`` conversion and optional diagonal removal."""
        if self.sum_stat == 'corr':
            fcs = correlation_from_covariance(fcs)
        elif self.sum_stat not in ('cov', None):
            raise ValueError(f'sum_stat {self.sum_stat} not recognized')

        if self.remove_fc_diag:
            n = fcs.shape[-1]
            mask = torch.ones((n, n)) - torch.eye(n)
            fcs = fcs * torch.broadcast_to(mask, fcs.shape)
        return fcs

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(fc, adj, subject_id, scan_dir, task)`` for sample ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.numpy()

        fcs = self.fcs[idx]
        adjs = self._labels(self.adjs[idx])

        if self.transform:
            fcs, adjs = self.transform(fcs, adjs)

        fcs = self._finalize_fcs(fcs)

        return (fcs.to(self.dtype), adjs.to(self.dtype),
                self.subject_ids[idx], self.scan_dirs[idx], self.tasks[idx])

    def full_ds(self):
        """Return ``(fcs, adjs, subject_ids, scan_dirs, tasks)`` for all samples.

        NOTE(review): unlike __getitem__, ``transform`` is not applied here.
        """
        adjs = self._labels(self.adjs)
        fcs = self._finalize_fcs(self.fcs)
        return fcs.to(self.dtype), adjs.to(self.dtype), self.subject_ids, self.scan_dirs, self.tasks


def test_rescale_frob(coeffs, scaling, dtype = torch.float32,
                      num_vertices=68, num_samples=3, num_signals=10000):
    """Check how analytic covariances respond to scaling the diffusion coefficients.

    Builds two PureSyntheticDiffusionDataset instances over the same graphs, the
    second with ``scaling * coeffs``, and prints whether
      * multiplying the second set of covariances by ``1 / scaling**2`` recovers the first, and
      * Frobenius-normalizing both sets makes them equal,
    together with the maximum absolute difference in each case.

    Args:
        coeffs: diffusion coefficients for the first dataset.
        scaling: scalar applied to ``coeffs`` for the second dataset.
        dtype: torch dtype used by both datasets.
        num_vertices, num_samples, num_signals: dataset sizes (defaults reproduce
            the previously hard-coded 68 / 3 / 10000).

    Raises:
        AssertionError: if the two datasets do not use the same adjacency matrices.
    """
    synth_ds_1 = PureSyntheticDiffusionDataset(num_vertices=num_vertices, num_signals=num_signals,
                                               num_samples=num_samples, sum_stat='analytic_cov',
                                               coeffs=coeffs, input_adj=None, dtype=dtype)
    # full_ds recomputes the analytic covariances on every call, so call it once.
    covs_1, adjs_1 = synth_ds_1.full_ds()[:2]

    synth_ds_2 = PureSyntheticDiffusionDataset(num_vertices=num_vertices, num_signals=num_signals,
                                               num_samples=num_samples, sum_stat='analytic_cov',
                                               coeffs=scaling*coeffs, input_adj=adjs_1, dtype=dtype)
    covs_2, adjs_2 = synth_ds_2.full_ds()[:2]

    # used same adjs
    assert np.allclose(adjs_1, adjs_2)

    # covs should be equal after rescaling
    rescaled_covs_2 = 1/(scaling**2) * covs_2
    rescale_close = np.allclose(covs_1, rescaled_covs_2)

    # covs should be equal after taking frob norm
    covs_1_normed = normalize_slices(covs_1, which_norm='frob')
    covs_2_normed = normalize_slices(covs_2, which_norm='frob')
    frob_close = np.allclose(covs_1_normed, covs_2_normed)

    print(f'coeffs {coeffs} | scaling: {scaling}: \n\tpost rescale same? {rescale_close}, \n\tpost frob norm same? {frob_close}')
    max_diff_rescale = torch.max(torch.abs(covs_1 - rescaled_covs_2).view(-1))
    max_diff_normed = torch.max(torch.abs(covs_1_normed - covs_2_normed).view(-1))
    print(f'\trescale max diff: {max_diff_rescale} || normed max diff {max_diff_normed}\n\n')


if __name__ == "__main__":
    # Ad-hoc manual checks. Needs matplotlib and the project's `data` package, and
    # pauses on input() for visual inspection.


    # Sample 100 SBM graphs (3 communities) and view the first three
    # adjacency / summary-statistic matrices.
    ps = PureSyntheticDiffusionDataset(num_vertices=68, num_samples=100, graph_gen='sbm',
                                       r={'num_communities': 3, 'p_in': .8, 'p_out': 0.2},
                                       sparse_thresh_low=.35, sparse_thresh_high=.5, rand_seed=50)

    fcs, adjs, _, _, _ = ps.full_ds()

    # viz to confirm sbms labeled consistently
    # NOTE(review): nothing calls plt.show()/plt.pause(); depending on the matplotlib
    # backend, the images may not appear while input() is waiting.
    import matplotlib.pyplot as plt
    plt.figure()
    for i in range(3):
        plt.imshow(adjs[i])
        input('inspect')
        plt.imshow(fcs[i])
        input('inspect')


    # NOTE(review): these imports are not used by the code below; the brain-data
    # import raises ImportError if that module is unavailable.
    import sys
    from data.brain_data.matlab_to_python_brain_data import load_old_corr_brain_data, load_new_brain_data, load_old_brain_data

    # synthetic: compare analytic covariances for scalings 1, 1e-1, ..., 1e-19
    c1 = np.array([.5, .5, .2])
    for i in range(20):
        scaling = 10**(-i)
        test_rescale_frob(coeffs=c1, scaling=scaling, dtype=torch.float32)

    print(f'which worked?')


    # Are they ~equivalent after taking frob norm?





    ## Exploration of Sparsity and Connectivity of Random Graphs
    # Geometric:: Find which radius r gives many connected and sparse graphs
    #  -> radius of 0.25 gives many sparse and connected graphs (10.7%) with parameters:
    #       -N, dim, num_graphs, sparse_thresh = 30, 2, 1000, 0.15
    #  -> radius of 0.55 gives sparsity of ~0.55 which is the mean of the distrib of sMRI's
    #       -N, dim, num_graphs, sparse_thresh = 68, 2, 1000, 0.56
    # ER:: Find which prob r gives many connected and sparse graphs
    #  -> prob of 0.17 gives many sparse and connected graphs (14%) with parameters:
    #       -N, dim, num_graphs, sparse_thresh = 30, 2, 1000, 0.15

    # Disabled exploration loop. NOTE(review): `nx` (networkx) and `plot_graph` are not
    # imported in the visible file, so enabling this block as-is raises NameError.
    if False:
        N, dim, num_graphs, sparse_thresh = 68, 2, 1000, 0.15
        sparse_thresh_low, sparse_thresh_high = 0.5, 0.6
        poss_edges = (N*(N-1))/2
        con, sp, num_con_sp, m = 0,0, 0, 0
        # NOTE(review): `diag` is unused; with dim=2 the unit-square diagonal is sqrt(2).
        diag = np.sqrt(3) #diagonal of unit cube = furthest distance possible for 2 points

        generator = 'geom'              #.2 -> .45
        # NOTE(review): r_arr is unused; the sweep below uses its own linspace.
        r_arr = np.linspace(0.05,.35,num=10)
        for i, r in enumerate(np.linspace(0.55,.7,num=10)):
                for j in range(num_graphs):
                    if generator == 'ER':
                        G = nx.fast_gnp_random_graph(N,r)
                    elif generator == 'geom':
                        G = nx.random_geometric_graph(N, r, dim=dim)#L2 norm used

                    sparsity   = (G.number_of_edges()/poss_edges)
                    #sparse    = (G.number_of_edges()/poss_edges) < sparse_thresh
                    #print(f'r: {r}, sparsity: {sparsity}')
                    sparse     = sparse_thresh_low <= sparsity <= sparse_thresh_high
                    connected  = nx.is_connected(G)
                    if connected:
                        con += 1
                    if sparse:
                        sp  += 1
                    if connected and sparse:
                        #take measurement of qualifying graph
                        num_con_sp = num_con_sp + 1 #number connected and sparse
                        m = m + G.number_of_edges() #add total edges

                p_con = num_con_sp/num_graphs #fraction of graphs that are connected and sparse

                # ave number of edges in connected and sparse graphs
                m_ave        = (m/num_con_sp if (num_con_sp>0) else 0)
                sparsity_ave = m_ave/poss_edges
                print(f'Graph: {generator}. N: {N}, dim: {dim}, r: {r}')
                print(f'\t% connected: {con/num_graphs}, % sparse: {sp/num_graphs}, % co & sp?: {p_con}')
                print(f'\tAve # Edges: {m_ave}, All poss edges: {poss_edges}: m/#edges: {sparsity_ave}\n')
                #reset for next set of parameters
                num_con_sp,m,con,sp=0,0,0,0
                # G / connected / sparse here refer to the last graph of the inner loop.
                if DEBUG and connected and sparse:
                    plot_graph(G, np.ones(N))



