import os, numpy as np, torch, sys, pytorch_lightning as pl, warnings
from torch.utils.data import Dataset, DataLoader
from typing import Optional, Dict
import matplotlib.pyplot as plt
from sklearn.model_selection import GroupShuffleSplit, GroupKFold
from sklearn.utils import shuffle
from pathlib import Path
file = Path(__file__).resolve()
# Repository root: the directory two levels above the one containing this file.
path2project = f'{file.parents[2]}/'
path2BrainData = f'{path2project}data/brain_data/'
path2currDir = f'{Path.cwd()}/'
# Make top-level project packages (utils, data, ...) importable -> geom_dl/
sys.path.append(path2project)

from utils import correlation_from_covariance, normalize_slices, edge_density
from data.brain_data.matlab_to_python_brain_data import load_brain_data, subnetwork_masks
from data.brain_data.brain_data_utils import matlab_struct_to_tensors
from data.network_diffusion.diffused_signals import diffusion_summary_stat, analytic_summary_stats


def construct_gso(adj, gso):
    """Build a graph shift operator (GSO) from adjacency matrix/matrices.

    Args:
        adj: adjacency tensor, a single ``(N, N)`` matrix or a batch ``(..., N, N)``.
        gso: ``'adjacency'`` returns ``adj`` unchanged; ``'laplacian'`` returns the
            combinatorial Laplacian ``L = D - A`` with ``D = diag(row sums of A)``.

    Returns:
        Tensor with the same shape as ``adj``.

    Raises:
        ValueError: if ``gso`` is not a supported name.
    """
    if gso == 'adjacency':
        return adj
    if gso == 'laplacian':
        # Sum over the LAST axis so single matrices and batches both use row sums
        # (dim=1 is the column axis for a batch but the row axis for a single matrix).
        degrees = adj.sum(dim=-1)
        return torch.diag_embed(degrees) - adj.to(torch.float32)
    raise ValueError(f'unknown GSO {gso} given')


class BrainDataWrapper(Dataset):
    """Map-style dataset of real brain scans pairing functional and structural connectivity.

    Each item is ``(x, y, subject_id, scan_dir, task)``:
    ``x`` is the normalized FC (converted to a correlation when the summary statistic
    is a correlation), and ``y`` is the GSO (adjacency or Laplacian) of the scaled SC,
    optionally shifted by ``delta * I`` (see ``label_norm``).
    """

    def __init__(self,
                 fcs,
                 adjs,
                 subject_ids,
                 scan_dirs,
                 tasks,
                 num_patients: Optional[int] = None,
                 sc_info: Optional[Dict] = None,
                 fc_info: Optional[Dict] = None,
                 label: str = 'adjacency',
                 label_norm: Optional[Dict] = None,  # e.g. {'normalization': 'min_eig', 'min_eig': 1.0}
                 transform=None,
                 dtype=torch.float32):
        """
        Args:
            fcs: FC tensor; presumably (num_samples, N, N) -- TODO confirm.
            adjs: SC tensor whose first dim indexes samples; divided by ``sc_info['scaling']``.
            subject_ids, scan_dirs, tasks: per-sample metadata, indexed with the item index.
            num_patients: number of distinct subjects in this split (stored only).
            sc_info: needs ``'scaling'``. Defaults to ``{'scaling': 9.9, 'edge_density_low': 0.35}``.
            fc_info: needs ``'summary_statistic'`` (cov / corr / sample_cov / sample_corr),
                ``'normalization'`` and ``'normalization_value'`` (forwarded to ``normalize_slices``).
                Defaults to sample_cov without normalization.
            label: ``'adjacency'`` or ``'laplacian'``: which GSO is returned as label.
            label_norm: optional dict. When it holds ``'min_eig'``, a per-sample ``delta`` is
                computed so that labels whose smallest eigenvalue is below 1 get that
                eigenvalue moved to ``label_norm['min_eig']``. The shift is applied when
                ``label_norm['normalization'] == 'min_eig'``.
            transform: stored but not applied.
            dtype: dtype of the stored and returned tensors.

        Raises:
            ValueError: if min-eig normalization is requested without a ``'min_eig'`` value.
        """
        # fresh dicts per instance instead of shared mutable default arguments
        if sc_info is None:
            sc_info = {'scaling': 9.9, 'edge_density_low': 0.35}
        if fc_info is None:
            fc_info = {'summary_statistic': 'sample_cov', 'normalization': None,
                       'normalization_value': None, 'remove_diag': False}
        assert fc_info['summary_statistic'] in ['cov', 'corr', 'sample_cov', 'sample_corr']
        assert label in ['adjacency', 'laplacian']

        self.num_samples = adjs.shape[0]
        self.num_patients = num_patients
        self.subject_ids = subject_ids
        self.scan_dirs = scan_dirs
        self.tasks = tasks
        self.transform = transform
        self.sc_info = sc_info
        self.fc_info = fc_info
        self.dtype = dtype
        self.label = label
        self.label_norm = label_norm

        # Normalize a detached copy (never the caller's tensor), then apply the dtype cast.
        self.fcs = normalize_slices(fcs.clone().detach(), which_norm=fc_info['normalization'],
                                    extra=fc_info['normalization_value']).to(dtype)
        self.adjs = adjs.clone().detach().to(dtype) / sc_info['scaling']

        # Smallest eigenvalue of each scaled adjacency and of its Laplacian (one value per sample).
        self.min_eigs = {
            'adjacency': torch.linalg.eigvalsh(self.adjs).min(dim=-1)[0],
            'laplacian': torch.linalg.eigvalsh(construct_gso(self.adjs, gso='laplacian')).min(dim=-1)[0],
        }

        # Per-sample diagonal shift used by min-eig label normalization.
        self.delta = None
        if (label_norm is not None) and ('min_eig' in label_norm):
            set_min_eig, current_min_eigs = label_norm['min_eig'], self.min_eigs[label]
            # Only labels whose min eigenvalue is below 1 are shifted; the rest get delta = 0.
            self.delta = torch.where(current_min_eigs < 1, set_min_eig - current_min_eigs,
                                     torch.zeros_like(current_min_eigs))
        elif self._min_eig_normalized():
            raise ValueError("label_norm['normalization'] == 'min_eig' requires a 'min_eig' value")

    def _min_eig_normalized(self):
        """Return True when labels must be shifted by ``delta * I``."""
        return self.label_norm is not None and self.label_norm.get('normalization') == 'min_eig'

    def __len__(self):
        # required by DataLoader samplers (e.g. shuffle=True)
        return self.num_samples

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.numpy()
        n = self.adjs.shape[-1]

        # label: GSO of the (already scaled) SC, optionally shifted on the diagonal
        y = construct_gso(self.adjs[idx], gso=self.label)
        if self._min_eig_normalized():
            # [..., None, None] broadcasts for both a scalar idx and an array of indices
            y = y + self.delta[idx][..., None, None] * torch.eye(n)

        if 'cov' in self.fc_info['summary_statistic']:
            x = self.fcs[idx]
        else:
            x = correlation_from_covariance(self.fcs[idx])

        return x.to(self.dtype), y.to(self.dtype), self.subject_ids[idx], self.scan_dirs[idx], self.tasks[idx]

    def full_ds(self):
        """Return every sample at once: ``(x, y, subject_ids, scan_dirs, tasks)``."""
        n = self.adjs.shape[-1]
        y = construct_gso(self.adjs, gso=self.label)
        if self._min_eig_normalized():
            # one shift per matrix: (B, 1, 1) * (n, n) -> (B, n, n)
            y = y + self.delta.view(-1, 1, 1) * torch.eye(n)

        x = self.fcs if 'cov' in self.fc_info['summary_statistic'] else correlation_from_covariance(self.fcs)

        return x.to(self.dtype), y.to(self.dtype), self.subject_ids, self.scan_dirs, self.tasks

class RealDataModule(pl.LightningDataModule):
    """Lightning data module for real brain scans: FC matrices as inputs, SC-derived GSOs as labels.

    ``setup('fit')`` loads every scan, removes outliers (``filter_data``), shuffles, and then
    splits by subject. A test set of ``num_patients_test`` subjects is held out first. The
    remainder goes either into one train/val split (``num_patients_val`` subjects in val) or
    into ``num_splits`` GroupKFold cross-validation folds (select one with ``set_split``).
    """

    def __init__(self,
                 seed=50,
                 batch_size: int = 256,
                 val_batch_size=None,
                 test_batch_size=None,
                 num_patients_test: Optional[int] = 5,
                 num_patients_val: Optional[int] = 70,
                 num_workers: int = 2,
                 num_splits: Optional[int] = None,
                 transform=None,
                 # brain info
                 include_subcortical: bool = False,
                 # fc info
                 fc_info: Optional[Dict] = None,
                 # sc info
                 sc_info: Optional[Dict] = None,
                 label: str = 'adjacency',
                 label_norm=None
                 ):
        """
        Args:
            seed: used for ``pl.seed_everything``, the data shuffle and GroupShuffleSplit.
            batch_size, val_batch_size, test_batch_size: loader batch sizes. ``None`` for
                val/test means one batch holding the whole split.
            num_patients_test, num_patients_val: number of SUBJECTS in the test / val split.
                If both are given and equal, test is grown by one and val shrunk by one.
            num_workers: DataLoader worker count.
            num_splits: number of cross-validation folds. When set, both num_patients_val and
                num_patients_test are replaced by None, so no test set is held out.
            transform: stored only.
            include_subcortical: stored only (only cortical nodes are used).
            fc_info: FC options. Defaults to
                ``{'construction': 'concat_all_timeseries', 'scan_combination': 'mean',
                'remove_diag': False, 'summary_statistic': 'sample_cov', 'normalization': None,
                'normalization_value': None, 'frob_norm_high': None}``.
                'summary_statistic' and 'normalization' are required. Other missing keys are
                filled IN PLACE on the given dict, and subclasses rely on that.
            sc_info: SC options. Defaults to ``{'edge_density_low': 0.35}``. A missing 'scaling'
                (9.9) or 'edge_density_low' (0.35) is filled in place.
            label: 'adjacency' or 'laplacian'.
            label_norm: forwarded to the datasets (optional min-eig label shift).

        Raises:
            ValueError: if fc_info lacks 'summary_statistic' or 'normalization'.
        """
        super().__init__()
        assert batch_size is not None, f'batch_size is None, likely error with CLI parser'
        # fresh dicts per instance instead of shared mutable default arguments
        if fc_info is None:
            fc_info = {'construction': 'concat_all_timeseries', 'scan_combination': 'mean',
                       'remove_diag': False, 'summary_statistic': 'sample_cov',
                       'normalization': None, 'normalization_value': None,
                       'frob_norm_high': None}
        if sc_info is None:
            sc_info = {'edge_density_low': 0.35}

        self.seed = seed
        self.batch_size, self.val_batch_size, self.test_batch_size = batch_size, val_batch_size, test_batch_size
        self.transform = transform

        if num_patients_val is not None and num_patients_test is not None:
            if num_patients_val == num_patients_test:
                num_patients_test += 1
                num_patients_val -= 1
                print(f'val/test sizes identical. Adjusted val/test size for even SUBJECT split (not scan split). For consistent loading.')

        self.num_workers = num_workers

        self.include_subcortical = include_subcortical
        self.metadata = None
        self.fc_info, self.sc_info = fc_info, sc_info
        # Fill optional keys in place. self.fc_info / self.sc_info are these same dict objects.
        fc_info.setdefault('construction', 'concat_all_timeseries')
        fc_info.setdefault('scan_combination', 'mean')
        fc_info.setdefault('remove_diag', False)
        if 'summary_statistic' not in fc_info:
            raise ValueError(f"Must specify summary_statistic in fc_info")
        if 'normalization' not in fc_info:
            raise ValueError(f"Must specify normalization in fc_info")
        fc_info.setdefault('normalization_value', 'symeig' if fc_info['normalization'] == 'max_eig' else None)
        fc_info.setdefault('frob_norm_high', None)

        sc_info.setdefault('scaling', 9.9)
        sc_info.setdefault('edge_density_low', 0.35)

        self.interesting_edges = None
        assert label in ['adjacency', 'laplacian']
        self.label, self.label_norm = label, label_norm

        self.non_neg_labels = (label in ['adjacency'])
        self.self_loops = (label in ['laplacian', 'precision'])

        # data immediately split into train/test
        self.all_train_idxs, self.test_idxs = None, None
        self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids, self.all_train_scan_dirs = None, None, None, None
        self.all_train_tasks = None
        self.test_fcs, self.test_scs, self.test_subject_ids, self.test_scan_dirs = None, None, None, None
        self.test_tasks = None

        # for cross validation
        self.num_splits, self.split = num_splits, 0
        if self.num_splits is not None:
            if num_patients_val is not None:
                print(f'WARNING: Can either specify size of validation set for single train/val/test split OR '
                      f'num_splits for cross validation. Not both.')
            print(f'Using Cross Validation...')
            self.num_patients_val = self.num_patients_test = None
        else:
            self.num_patients_val, self.num_patients_test = num_patients_val, num_patients_test

        # one entry for a simple train/val/test split, num_splits entries for cross validation
        self.train_val_splits = []

        # ensure fit only run once
        self.fit_already_run = False

        self.subnetwork_masks = None
        self.num_vertices = None  # cached during setup

    def full_ds(self):
        """Load all scans without filtering or splitting.

        Returns:
            (fcs, scs, subject_ids, scan_dirs). FCs are converted to correlations when the
            summary statistic is a correlation.
        """
        # convert from .mat to python datastructures
        brain_data, self.metadata = load_brain_data()
        fcs, scs, subject_ids, scan_dirs, tasks \
            = matlab_struct_to_tensors(brain_data, fc_construction=self.fc_info['construction'],
                                       scan_combination=self.fc_info['scan_combination'])
        fcs = (correlation_from_covariance(fcs) if 'corr' in self.fc_info['summary_statistic'] else fcs)

        return fcs, scs, subject_ids, scan_dirs

    def setup(self, stage=None):
        """Load, filter, shuffle and split the data. The work runs only on the first fit/None stage."""
        if (stage == 'fit' or stage is None) and (not self.fit_already_run):
            self.fit_already_run = True
            if self.seed is not None:
                pl.seed_everything(self.seed, workers=True)

            # convert from .mat to python datastructures
            brain_data, self.metadata = load_brain_data()
            self.subnetwork_masks = subnetwork_masks(self.metadata["idx2label_cortical"]["lobe"])

            # extract relevant dataset
            fcs, scs, subject_ids, scan_dirs, tasks \
                = matlab_struct_to_tensors(brain_data, fc_construction=self.fc_info['construction'],
                                           scan_combination=self.fc_info['scan_combination'])
            # remove outliers in the data
            remove_fc_idxs, remove_sc_idxs = self.filter_data(fcs, scs, subject_ids, scan_dirs, tasks)

            keep_idxs = torch.logical_not(remove_sc_idxs | remove_fc_idxs)
            fcs, scs, subject_ids, scan_dirs, tasks = \
                fcs[keep_idxs], scs[keep_idxs], subject_ids[keep_idxs], scan_dirs[keep_idxs], tasks[keep_idxs]

            # shuffle data (sklearn.utils.shuffle, deterministic for a given seed)
            fcs, scs, subject_ids, scan_dirs, tasks = shuffle(fcs, scs, subject_ids, scan_dirs, tasks, random_state=self.seed)

            # Only cortical nodes are used now. Former subcortical removal, kept for reference:
            #   num_subcortical = len(self.metadata['subcortical_roi'])
            #   if not self.include_subcortical:
            #       fcs = fcs[:, num_subcortical:, num_subcortical:]
            #       scs = scs[:, num_subcortical:, num_subcortical:]
            self.num_vertices = fcs.shape[-1]

            # train/test split, grouped by subject so no subject appears in both
            if self.num_patients_test is None:
                self.all_train_idxs = np.arange(0, len(fcs))
                self.test_idxs = []
            else:
                train_test_split_group = GroupShuffleSplit(n_splits=1, test_size=self.num_patients_test, random_state=self.seed)
                self.all_train_idxs, self.test_idxs = next(train_test_split_group.split(fcs, scs, subject_ids))  # subject_ids are group

            self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids, self.all_train_scan_dirs, self.all_train_tasks = \
                fcs[self.all_train_idxs], scs[self.all_train_idxs], subject_ids[self.all_train_idxs], scan_dirs[self.all_train_idxs], tasks[self.all_train_idxs]
            self.test_fcs, self.test_scs, self.test_subject_ids, self.test_scan_dirs, self.test_tasks = \
                fcs[self.test_idxs], scs[self.test_idxs], subject_ids[self.test_idxs], scan_dirs[self.test_idxs], tasks[self.test_idxs]

            if self.num_splits is not None:
                # cross validation: further split train into num_splits folds
                val_split_iter = GroupKFold(n_splits=self.num_splits).split(self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids)
            else:
                # straight train/val/test split
                val_split = GroupShuffleSplit(n_splits=1, test_size=self.num_patients_val, random_state=self.seed)
                val_split_iter = val_split.split(self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids)

            # test_subject_ids is already restricted to the test set; don't index it with test_idxs again
            num_test_patients = len(np.unique(self.test_subject_ids))
            for train_fold_idxs, val_fold_idxs in val_split_iter:  # idxs for all_train_*
                patients_in_train_fold = len(np.unique(self.all_train_subject_ids[train_fold_idxs]))
                patients_in_val_fold = len(np.unique(self.all_train_subject_ids[val_fold_idxs]))
                split_info = {'train_fold_idxs': train_fold_idxs, 'patients_in_train_fold': patients_in_train_fold,
                              'val_fold_idxs': val_fold_idxs, 'patients_in_val_fold': patients_in_val_fold}
                print(f'\tpatients in train fold: {patients_in_train_fold}, val fold: {patients_in_val_fold}, test set {num_test_patients}')
                self.train_val_splits.append(split_info)

        elif (stage == 'fit') and self.fit_already_run:
            print(f'stage is {stage}. Already run fit?: {self.fit_already_run}. Prob being called again in trainer. Dont run twice...')

    @staticmethod
    def _print_removed(mask, subject_ids, scan_dirs, tasks):
        """Print ``subject-scan_dir-task`` for every scan flagged in boolean ``mask``."""
        where = np.where(mask)
        for subject_id, scan_dir, task in zip(subject_ids[where], scan_dirs[where], tasks[where]):
            print(f'\t{subject_id}-{scan_dir}-{task}')

    def filter_data(self, fcs, scs, subject_ids, scan_dirs, tasks):
        """Flag outlier scans.

        An FC is flagged when its Frobenius norm exceeds ``fc_info['frob_norm_high']``. An SC is
        flagged when its edge density is below ``sc_info['edge_density_low']``. A ``None``
        threshold disables that filter.

        Returns:
            (remove_fc_idxs, remove_sc_idxs): boolean masks with one entry per scan.
        """
        if self.fc_info['frob_norm_high'] is not None:
            remove_fc_idxs = torch.linalg.norm(fcs, ord='fro', dim=(1, 2)) > self.fc_info['frob_norm_high']
            print(f"FC filter: Removing following scans bc: ||Cov||_F > {self.fc_info['frob_norm_high']}")
            self._print_removed(remove_fc_idxs, subject_ids, scan_dirs, tasks)
        else:
            remove_fc_idxs = torch.zeros(len(fcs), dtype=torch.bool)

        if self.sc_info['edge_density_low'] is not None:
            remove_sc_idxs = edge_density(A=scs) < self.sc_info['edge_density_low']
            print(f"SC filter: Removing following scans bc: edge_density(SC) < {self.sc_info['edge_density_low']}")
            self._print_removed(remove_sc_idxs, subject_ids, scan_dirs, tasks)
        else:
            remove_sc_idxs = torch.zeros(len(scs), dtype=torch.bool)

        return remove_fc_idxs, remove_sc_idxs

    ############
    # For Cross Validation
    def set_split(self, split):
        """Select which cross-validation fold the train/val loaders use.

        Raises:
            ValueError: if ``split`` is outside ``[0, num_splits - 1]``.
        """
        if self.num_splits is None:
            # self.split must always be 0 when using train/val/test
            warnings.warn(f'Setting split Cross Validation split, but only train/val/test split was made (and thus only 1 CV split. Keeping Split set to 0.')
            return
        # validate the requested split (previously the current self.split was checked)
        if not 0 <= split < self.num_splits:
            raise ValueError(f'split must be in [0, {self.num_splits - 1}], got {split}')
        self.split = split
    ############

    def train_val_dl(self, idxs, num_patients, batch_size, shuffle, seed=None, num_workers=0):
        """Build a DataLoader over the ``all_train_*`` samples selected by ``idxs``."""
        split_fcs, split_scs, split_subject_ids, split_scan_dirs, split_tasks = \
            self.all_train_fcs[idxs], self.all_train_scs[idxs], self.all_train_subject_ids[idxs], self.all_train_scan_dirs[idxs], self.all_train_tasks[idxs]

        ds = BrainDataWrapper(fcs=split_fcs,
                              adjs=split_scs,
                              subject_ids=split_subject_ids,
                              scan_dirs=split_scan_dirs,
                              tasks=split_tasks,
                              num_patients=num_patients,
                              fc_info=self.fc_info,
                              sc_info=self.sc_info,
                              label=self.label, label_norm=self.label_norm)
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=False)  # torch.cuda.is_available())

    def train_dataloader(self):
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        batch_size = len(split_info['train_fold_idxs']) if self.batch_size is None else self.batch_size
        return self.train_val_dl(idxs=split_info['train_fold_idxs'],
                                 num_patients=split_info['patients_in_train_fold'],
                                 batch_size=batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        batch_size = len(split_info['val_fold_idxs']) if (self.val_batch_size is None) else self.val_batch_size
        return self.train_val_dl(idxs=split_info['val_fold_idxs'],
                                 num_patients=split_info['patients_in_val_fold'],
                                 batch_size=batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        num_test_patients = len(np.unique(self.test_subject_ids))

        test_ds = \
            BrainDataWrapper(fcs=self.test_fcs,
                             adjs=self.test_scs,
                             subject_ids=self.test_subject_ids,
                             scan_dirs=self.test_scan_dirs,
                             tasks=self.test_tasks,
                             num_patients=num_test_patients,
                             fc_info=self.fc_info,
                             sc_info=self.sc_info,
                             label=self.label, label_norm=self.label_norm)
        # Default: a single batch with every test scan. The former self.num_samples_test was
        # never defined. max(1, ...) avoids an invalid batch_size of 0 for an empty test set.
        batch_size = max(1, len(self.test_fcs)) if (self.test_batch_size is None) else self.test_batch_size
        return DataLoader(test_ds, batch_size=batch_size, num_workers=self.num_workers)


class PsuedoSyntheticDiffusionDataset(Dataset):
    """Diffuse white signals over given brain graphs using graph filter H, then
    compute summary statistics (Cov, Corr).

        Input: sample or analytic covariance/correlation, chosen by fc_info['summary_statistic']
        Label: GSO S (adjacency or Laplacian of the scaled SC)
    """

    def __init__(self,
                 adjs,
                 fc_info: Dict,
                 sc_info: Dict,
                 gso: str = 'adjacency',
                 label: str = 'adjacency',
                 label_norm=None,
                 num_signals=50,
                 subject_ids=None, scan_dirs=None, num_patients=None,
                 transform=None,
                 dtype=torch.float32,
                 split_id='',
                 seed=50
                 ):
        """
        Args:
            adjs: raw SC matrices (tensor or numpy array), unpacked as (num_scans, N, N).
            fc_info: needs 'summary_statistic' (sample_cov / sample_corr / analytic_cov /
                analytic_corr), 'coeffs' (filter coefficients passed to the diffusion helpers),
                'normalization' and 'normalization_value' (passed to normalize_slices).
            sc_info: needs 'scaling'. Diffusion uses the raw adjs; labels use adjs / scaling.
            gso: GSO used for the diffusion; only 'adjacency' is currently accepted.
            label: GSO returned as label ('adjacency' or 'laplacian').
            label_norm: optional dict. When it holds 'min_eig', a per-sample delta is computed
                so that labels whose smallest eigenvalue is below 1 get it moved to
                label_norm['min_eig']. The shift is applied when
                label_norm['normalization'] == 'min_eig'.
            num_signals: number of diffused signals per graph for the sample statistics.
            subject_ids, scan_dirs, num_patients, transform, split_id: stored metadata.
            dtype: dtype of returned tensors.
            seed: passed to pl.seed_everything before the signals are sampled.

        Raises:
            ValueError: if min-eig normalization is requested without a 'min_eig' value.
        """
        self.num_scans, self.N, _ = adjs.shape
        self.num_patients = num_patients
        self.subject_ids = subject_ids
        self.scan_dirs = scan_dirs
        # gso is used to perform diffusions
        self.gso = gso
        assert gso in ['adjacency', 'laplacian']
        # label is what is returned for training
        self.label, self.label_norm = label, label_norm
        assert label in ['adjacency', 'laplacian']
        self.fc_info, self.sc_info = fc_info, sc_info
        assert (self.fc_info['summary_statistic'] in ['sample_cov', 'sample_corr', 'analytic_cov', 'analytic_corr']), f"Dataset: sum_stat must be cov or corr, is {self.fc_info['summary_statistic']}\n"
        self.num_signals = num_signals
        self.transform = transform
        self.seed = seed
        self.dtype = dtype
        self.split_id = split_id

        print(f'\tDoes NOT exist! Creating dataset...')
        adjs = adjs if torch.is_tensor(adjs) else torch.from_numpy(adjs).to(dtype)

        # this samples signals and performs diffusions
        pl.seed_everything(self.seed)
        assert gso == 'adjacency', f'need to think about how this changes if not adj'
        s = construct_gso(adjs, gso)  # diffuse with RAW
        sample_cov = diffusion_summary_stat(S=s, coeffs=self.fc_info['coeffs'], num_signals=num_signals, sum_stat=self.fc_info['summary_statistic'])
        sample_corr = correlation_from_covariance(sample_cov)
        analytic_cov, analytic_corr = analytic_summary_stats(S=adjs, coeffs=self.fc_info['coeffs'])

        which_norm, fc_norm_val = fc_info['normalization'], fc_info['normalization_value']
        self.sample_cov = normalize_slices(sample_cov, which_norm=which_norm, extra=fc_norm_val)
        self.sample_corr = normalize_slices(sample_corr, which_norm=which_norm, extra=fc_norm_val)
        self.analytic_cov = normalize_slices(analytic_cov, which_norm=which_norm, extra=fc_norm_val)
        self.analytic_corr = normalize_slices(analytic_corr, which_norm=which_norm, extra=fc_norm_val)

        self.adjs = adjs / sc_info['scaling']
        # Smallest eigenvalue per scaled adjacency and per Laplacian. The Laplacian entry
        # is always built as a Laplacian, independent of self.label.
        self.min_eigs = {'adjacency': torch.linalg.eigvalsh(self.adjs).min(dim=1)[0],
                         'laplacian': torch.linalg.eigvalsh(construct_gso(self.adjs, gso='laplacian')).min(dim=1)[0]
                         }
        # Per-sample diagonal shift used by min-eig label normalization.
        self.delta = None
        if (label_norm is not None) and ('min_eig' in self.label_norm):
            set_min_eig, current_min_eigs = self.label_norm['min_eig'], self.min_eigs[self.label]
            z = torch.zeros_like(current_min_eigs)
            # only labels whose min eigenvalue is below 1 are shifted
            self.delta = torch.where(current_min_eigs < 1, (set_min_eig - current_min_eigs), z)
        elif self._min_eig_normalized():
            raise ValueError("label_norm['normalization'] == 'min_eig' requires a 'min_eig' value")

    def _min_eig_normalized(self):
        """Return True when labels must be shifted by ``delta * I``."""
        return self.label_norm is not None and self.label_norm.get('normalization') == 'min_eig'

    def _summary_stat(self):
        """Return the full tensor of the summary statistic selected by fc_info['summary_statistic']."""
        stats = {'sample_cov': self.sample_cov, 'sample_corr': self.sample_corr,
                 'analytic_cov': self.analytic_cov, 'analytic_corr': self.analytic_corr}
        key = self.fc_info['summary_statistic']
        if key not in stats:
            raise ValueError(f"ps-ds: getitem: unrecognized summary statistic string {key} used")
        return stats[key]

    def __len__(self):
        return self.adjs.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.numpy()

        n = self.adjs.shape[-1]
        # scaling already applied
        y = construct_gso(self.adjs[idx], gso=self.label)
        if self._min_eig_normalized():
            # [..., None, None] broadcasts for both a scalar idx and an array of indices
            y = y + self.delta[idx][..., None, None] * torch.eye(n)

        x = self._summary_stat()[idx]
        return x.to(self.dtype), y.to(self.dtype), self.subject_ids[idx], "", ""

    def full_ds(self):
        """Return every sample at once: ``(x, y, subject_ids, "", "")``."""
        n = self.adjs.shape[-1]
        y = construct_gso(self.adjs, gso=self.label)
        if self._min_eig_normalized():
            # one shift per matrix: (B, 1, 1) * (n, n) -> (B, n, n)
            y = y + self.delta.view(-1, 1, 1) * torch.eye(n)

        x = self._summary_stat()
        return x.to(self.dtype), y.to(self.dtype), self.subject_ids, "", ""

class PsuedoSyntheticDataModule(RealDataModule):
    """Same subject splits as RealDataModule, but the inputs are synthesized.

    Inputs are produced by diffusing white signals over each split's SC graphs
    (PsuedoSyntheticDiffusionDataset). Train, val and test use different seeds.
    """

    def __init__(self,
                 seed=50,
                 batch_size: int = 256,
                 val_batch_size=None,
                 test_batch_size=None,
                 num_patients_test: Optional[int] = 5,
                 num_patients_val: Optional[int] = 70,
                 num_workers: int = 2,
                 num_splits: Optional[int] = None,
                 transform=None,
                 # brain info
                 include_subcortical: bool = False,
                 fc_info: Optional[Dict] = None,
                 sc_info: Optional[Dict] = None,
                 gso: str = 'adjacency',
                 label: str = 'adjacency',
                 label_norm=None,
                 # A_O construction
                 num_signals=50):
        """
        Args:
            fc_info: must additionally contain 'coeffs' (graph-filter coefficients) and, per
                the parent, 'summary_statistic'. Defaults to
                ``{'normalization': 'max_eig', 'normalization_value': 'symeig', 'coeffs': None}``.
            sc_info: defaults to ``{'scaling': 9.9, 'edge_density_low': 0.35}``.
            gso: GSO used for the diffusion.
            num_signals: signals diffused per graph for the sample statistics.
            Other arguments: as in RealDataModule.
        """
        # fresh dicts per instance instead of shared mutable default arguments
        if fc_info is None:
            fc_info = {'normalization': 'max_eig', 'normalization_value': 'symeig', 'coeffs': None}
        if sc_info is None:
            sc_info = {'scaling': 9.9, 'edge_density_low': 0.35}
        # the parent fills missing fc_info / sc_info keys in place on these same dicts
        super().__init__(
            seed=seed,
            batch_size=batch_size,
            val_batch_size=val_batch_size,
            test_batch_size=test_batch_size,
            num_patients_test=num_patients_test,
            num_patients_val=num_patients_val,
            num_workers=num_workers,
            num_splits=num_splits,
            transform=transform,
            include_subcortical=include_subcortical,
            fc_info=fc_info, sc_info=sc_info,
            label=label, label_norm=label_norm
        )
        # variables unique to Psuedo-Synthetic
        self.num_signals = num_signals
        assert 'coeffs' in fc_info, f'fc_info missing required coeffs'

        self.gso = gso
        self.label, self.label_norm = label, label_norm
        self.train_seed, self.val_seed, self.test_seed = seed, seed + 10, seed + 50

        self.p_synth_ds, self.train_ds, self.val_ds, self.test_ds = None, None, None, None
        self.train_dl, self.val_dl, self.test_dl = None, None, None
        # CV split each cached loader was built for. Loaders are rebuilt after set_split changes it.
        self._train_dl_split, self._val_dl_split = None, None

    def train_val_dl(self, idxs, num_patients, batch_size, shuffle, seed, num_workers=0):
        """Build a DataLoader of synthesized samples over the ``all_train_*`` SCs selected by ``idxs``."""
        split_scs, split_subject_ids, split_scan_dirs = \
            self.all_train_scs[idxs], self.all_train_subject_ids[idxs], self.all_train_scan_dirs[idxs]

        ps_ds = PsuedoSyntheticDiffusionDataset(adjs=split_scs,
                                                fc_info=self.fc_info,
                                                sc_info=self.sc_info,
                                                gso=self.gso,
                                                label=self.label, label_norm=self.label_norm,
                                                num_signals=self.num_signals,
                                                subject_ids=split_subject_ids,
                                                scan_dirs=split_scan_dirs,
                                                num_patients=num_patients,
                                                split_id=str(self.split),
                                                seed=seed)
        # pin_memory makes it faster, but annoying warning printed (I think)
        return DataLoader(ps_ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=False)  # torch.cuda.is_available())

    def train_dataloader(self):
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        batch_size = len(split_info['train_fold_idxs']) if self.batch_size is None else self.batch_size
        # Synthesis is expensive, so the loader is cached, but only per CV split.
        if self.train_dl is None or self._train_dl_split != self.split:
            self.train_dl = \
                self.train_val_dl(idxs=split_info['train_fold_idxs'],
                                  num_patients=split_info['patients_in_train_fold'],
                                  batch_size=batch_size, shuffle=True,
                                  seed=self.train_seed,
                                  num_workers=self.num_workers)
            self._train_dl_split = self.split
        return self.train_dl

    def val_dataloader(self):
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        batch_size = len(split_info['val_fold_idxs']) if self.val_batch_size is None else self.val_batch_size

        # cached per CV split, so cross validation gets the loader for the current fold
        if self.val_dl is None or self._val_dl_split != self.split:
            self.val_dl = self.train_val_dl(idxs=split_info['val_fold_idxs'],
                                            num_patients=split_info['patients_in_val_fold'],
                                            batch_size=batch_size, shuffle=False,
                                            seed=self.val_seed,
                                            num_workers=self.num_workers)
            self._val_dl_split = self.split
        return self.val_dl

    def test_dataloader(self):
        num_test_patients = len(np.unique(self.test_subject_ids))
        ps_ds = PsuedoSyntheticDiffusionDataset(adjs=self.test_scs,
                                                fc_info=self.fc_info,
                                                sc_info=self.sc_info,
                                                gso=self.gso,
                                                label=self.label, label_norm=self.label_norm,
                                                num_signals=self.num_signals,
                                                subject_ids=self.test_subject_ids,
                                                scan_dirs=self.test_scan_dirs,
                                                num_patients=num_test_patients,
                                                split_id=str(self.split),
                                                seed=self.test_seed)
        batch_size = num_test_patients if self.test_batch_size is None else self.test_batch_size
        return DataLoader(ps_ds, batch_size=batch_size, shuffle=False, num_workers=self.num_workers)



if __name__ == "__main__":
    """
    dm_args = {'num_splits': None,
               'num_patients_val': 100,
               'num_patients_test': 100,
               'num_workers': 4 if torch.cuda.is_available() else 0,
               'batch_size': 50,
               'seed': 50,
               'fc_info': {'construction': 'concat_all_timeseries', 'scan_combination': 'mean',
                           'remove_diag': False, 'summary_statistic': 'sample_cov',
                           'normalization': 'max_eig', 'normalization_value': 'symeig',
                           'frob_norm_high': None},
               'sc_info': {'scaling': 9.9, 'edge_density_low': 0.35}
               }
    dm = RealDataModule(**dm_args)
    dm.setup('fit')
    train_dl = dm.train_dataloader()
    print('num training samples: ', len(train_dl.dataset))
    """

    ## PS ##
    dm_args = {'seed': 50,
               'fc_info': {'summary_statistic': 'sample_cov', 'normalization': 'max_eig', 'normalization_value': 'symeig',
                           'coeffs': torch.tensor([0.5, 0.5, 0.2])},
               'sc_info': {'edge_density_low': 0.35},
               'label': 'laplacian',
               'label_norm': {'normalization': 'min_eig', 'min_eig': 1.0},
               'num_workers': 0,
               'num_signals': 50}
    dm = PsuedoSyntheticDataModule(**dm_args)
    dm.setup('fit')

    # viz some fcs/sc pairs from the first training batch
    train_dl = dm.train_dataloader()
    x, y = next(iter(train_dl))[:2]
    fig, axes = plt.subplots(nrows=2, ncols=5, constrained_layout=True)
    n = x.shape[-1]
    x_max = x[0:5].abs().max()
    import copy
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still available.
    inferno = copy.copy(plt.get_cmap("inferno"))
    raw_sc_cm = copy.copy(plt.get_cmap("gist_gray"))
    off_diag = torch.ones(n, n) - torch.eye(n)  # zero out the label's diagonal for display
    for i in range(5):
        axes[0, i].imshow(x[i], vmin=-x_max, vmax=x_max, cmap=inferno)
        axes[1, i].imshow(-y[i] * off_diag, vmin=0, vmax=1, cmap=raw_sc_cm)

    # len(DataLoader) is the number of batches; report the number of samples instead
    print('num training samples in train: ', len(train_dl.dataset), 'size of batch: ', len(x))
    plt.show()


