import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
import numpy as np
from sklearn.model_selection import GroupShuffleSplit, GroupKFold
from sklearn.utils import shuffle

from data.brain_data.matlab_to_python_brain_data import load_brain_data, subnetwork_masks
from data.dataset import PureSyntheticDiffusionDataset, PsuedoSyntheticDiffusionDataset, BrainDataWrapper
from utils.util_funcs import filter_repeats, sparsity, upper_tri_as_vec_batch, add_bool_arg, correlation_from_covariance

from typing import Optional
import warnings


class SyntheticDataModule(pl.LightningDataModule):
    """Lightning data module that builds purely synthetic train/val/test datasets.

    Each split is a ``PureSyntheticDiffusionDataset`` constructed in ``setup``
    from the same generation settings but with a different sample count and a
    different random seed.

    Args:
        num_vertices, num_signals, graph_gen, r, sum_stat, sparse_thresh_low,
        sparse_thresh_high, norm_before_diffusion, binarize_labels_for_train,
        fc_norm, fc_norm_val, coeffs: forwarded unchanged to
            ``PureSyntheticDiffusionDataset`` (see ``setup`` for the keyword
            names they are mapped to).
        batch_size: batch size of the training loader.
        val_batch_size, test_batch_size: batch sizes of the val/test loaders;
            ``None`` means "the whole split in a single batch".
        num_samples_train, num_samples_val, num_samples_test: number of samples
            per split; a split with 0 samples is not constructed.
        transform: stored on the module; not used by the code in this class.
        rand_seed: base seed. ``None`` lets ``seed_everything`` pick one.
        num_train_workers, num_val_workers, num_test_workers: ``num_workers``
            of the respective ``DataLoader``.

    NOTE(review): ``fc_norm_val`` is annotated ``Optional[float]`` but defaults
    to the string ``'symeig'`` -- confirm the intended type/default.
    """

    def __init__(self,
                 num_vertices=68,
                 num_signals=50,
                 batch_size=512,
                 val_batch_size=None,
                 test_batch_size=None,
                 num_samples_train=1000,
                 num_samples_val=500,
                 num_samples_test=500,
                 sum_stat='sample_cov',
                 graph_gen='geom',
                 r=0.56,
                 sparse_thresh_low=0.5,
                 sparse_thresh_high=0.6,
                 norm_before_diffusion=False,
                 binarize_labels_for_train=False,
                 fc_norm: Optional[str] = 'max_eig',
                 fc_norm_val: Optional[float] = 'symeig',
                 transform=None,
                 rand_seed=None,
                 num_train_workers=0,
                 num_val_workers=0,
                 num_test_workers=0,
                 coeffs=np.array([0.5, 0.5, 0.2])):  # alternative once used: np.array([-.18, 2.446, -.0211])
        super().__init__()
        self.num_vertices = num_vertices
        self.num_signals = num_signals
        self.num_samples_train, self.num_samples_val, self.num_samples_test = \
            num_samples_train, num_samples_val, num_samples_test
        # Splits with identical sizes "will load identical datasets otherwise"
        # (original author's note), so equal non-zero sizes are nudged apart by one.
        if self.num_samples_train == self.num_samples_val and self.num_samples_train > 0:
            self.num_samples_val += 1
        if self.num_samples_val == self.num_samples_test and self.num_samples_val > 0:
            self.num_samples_test += 1

        self.graph_gen = graph_gen
        self.r = r
        self.sparse_thresh_low, self.sparse_thresh_high = sparse_thresh_low, sparse_thresh_high

        self.batch_size, self.val_batch_size, self.test_batch_size = batch_size, val_batch_size, test_batch_size

        self.sum_stat = sum_stat

        self.norm_before_diffusion, self.binarize_labels_for_train = norm_before_diffusion, binarize_labels_for_train
        self.fc_norm, self.fc_norm_val = fc_norm, fc_norm_val

        self.coeffs = coeffs

        self.transform = transform
        self.rand_seed = rand_seed
        self.num_train_workers = num_train_workers
        self.num_val_workers = num_val_workers
        self.num_test_workers = num_test_workers
        self.synth_ds, self.train_ds, self.val_ds, self.test_ds = None, None, None, None

        # Single all-True mask of shape (1, num_vertices, num_vertices).
        self.subnetwork_masks = {'full': torch.ones((num_vertices, num_vertices), dtype=torch.bool).view(1, num_vertices, num_vertices)}

    def setup(self, stage=None):
        """Build the train/val/test datasets for the 'fit' stage (or ``stage=None``).

        Any other stage is a no-op.
        """
        if not ((stage == 'fit') or (stage is None)):
            return

        # seed_everything returns the seed it actually applied; when rand_seed
        # is None it chooses one. Use that as the base for the val/test offsets
        # so that `rand_seed=None` no longer fails on `None + 17`.
        applied_seed = seed_everything(self.rand_seed, workers=True)
        base_seed = self.rand_seed if self.rand_seed is not None else applied_seed

        kwargs_dict = {
            'num_vertices': self.num_vertices,
            'num_signals': self.num_signals,
            'graph_gen': self.graph_gen,
            'r': self.r,
            'sum_stat': self.sum_stat,
            'sparse_thresh_low': self.sparse_thresh_low,
            'sparse_thresh_high': self.sparse_thresh_high,
            'norm_before_diff': self.norm_before_diffusion,
            'bin_labels_after_diff': self.binarize_labels_for_train,
            'coeffs': self.coeffs,
            'fc_norm': self.fc_norm,
            'fc_norm_val': self.fc_norm_val,
        }

        if self.num_samples_train > 0:
            # The training split keeps receiving the user-supplied seed verbatim
            # (including None) so existing behaviour is unchanged.
            self.train_ds = PureSyntheticDiffusionDataset(num_samples=self.num_samples_train,
                                                          rand_seed=self.rand_seed,
                                                          **kwargs_dict)
        if self.num_samples_val > 0:
            self.val_ds = PureSyntheticDiffusionDataset(num_samples=self.num_samples_val,
                                                        rand_seed=base_seed + 17,
                                                        **kwargs_dict)
        if self.num_samples_test > 0:
            self.test_ds = PureSyntheticDiffusionDataset(num_samples=self.num_samples_test,
                                                         rand_seed=base_seed + 50,
                                                         **kwargs_dict)

    def train_dataloader(self):
        """Return a shuffling loader over the training split."""
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_train_workers)

    def val_dataloader(self):
        """Return the validation loader; one full-size batch unless ``val_batch_size`` is set."""
        bs = self.num_samples_val if (self.val_batch_size is None) else self.val_batch_size
        # original note: cannot go into debug mode if num_workers > 0
        return DataLoader(self.val_ds, batch_size=bs, num_workers=self.num_val_workers)

    def test_dataloader(self):
        """Return the test loader; one full-size batch unless ``test_batch_size`` is set."""
        bs = self.num_samples_test if (self.test_batch_size is None) else self.test_batch_size
        # Fixed: this previously used num_train_workers. Original note: a worker
        # count > 0 leads to timeouts for large graphs.
        return DataLoader(self.test_ds, batch_size=bs, num_workers=self.num_test_workers)


class RealDataModule(pl.LightningDataModule):
    """Lightning data module over real brain data (FC/SC matrices per scan).

    ``setup`` loads the data via ``load_brain_data`` / ``matlab_struct_to_tensors``,
    drops filtered scans, shuffles, and splits by subject id (so a subject's
    scans are never spread over different splits). Two modes are supported:

    * single train/val/test split (``num_splits is None``): the sizes are given
      by ``num_patients_val`` / ``num_patients_test``;
    * cross validation (``num_splits`` set): no test set is held out and the
      data is divided into ``num_splits`` group folds; pick the active fold
      with ``set_split``.

    Args:
        rand_seed: seed for ``seed_everything``, the shuffle and the splitters.
        batch_size: training batch size; ``'all'`` means one batch per epoch.
        val_batch_size, test_batch_size: ``None`` means the whole split at once.
        num_patients_test, num_patients_val: passed as ``test_size`` to
            ``GroupShuffleSplit``; ignored (reset to None) in cross-validation mode.
        num_train_workers, num_val_workers, num_test_workers: loader workers.
        num_splits: number of cross-validation folds, or None.
        binarize_labels_for_train, remove_fc_diag, sum_stat, fc_norm,
        fc_norm_val: forwarded to ``BrainDataWrapper``.
        transform, include_subcortical, subset_construction: stored only; not
            used by the code in this class.
        fc_construction, scan_combination: forwarded to ``matlab_struct_to_tensors``.
        fc_filter: scans whose FC Frobenius norm exceeds this are removed.
        sc_filter: scans whose SC sparsity is below this are removed.
    """

    def __init__(self,
                 rand_seed=50,
                 batch_size: int = 256,
                 val_batch_size=None,
                 test_batch_size=None,
                 num_patients_test: Optional[int] = 5,
                 num_patients_val: Optional[int] = 70,
                 num_train_workers: int = 2,
                 num_val_workers: int = 2,
                 num_test_workers: int = 0,
                 num_splits: Optional[int] = None,
                 binarize_labels_for_train=False,
                 transform=None,
                 include_subcortical: bool = False,
                 subset_construction=None,
                 fc_construction: Optional[str] = 'concat_all_timeseries',
                 scan_combination: Optional[str] = None,
                 remove_fc_diag=False,
                 sum_stat='cov',
                 fc_norm: Optional[str] = None,
                 fc_norm_val: Optional[float] = None,
                 fc_filter: Optional[float] = None,
                 sc_filter: Optional[float] = .35
                 ):
        super().__init__()
        assert batch_size is not None, f'batch_size is None, likely error with CLI parser'
        self.rand_seed = rand_seed
        self.batch_size, self.val_batch_size, self.test_batch_size = batch_size, val_batch_size, test_batch_size
        self.binarize_labels_for_train = binarize_labels_for_train
        self.transform = transform

        # Fixed: the original condition tested num_patients_test twice and never
        # checked num_patients_val.
        if num_patients_val is not None and num_patients_test is not None:
            if num_patients_val == num_patients_test:
                num_patients_test += 1
                num_patients_val -= 1
                print(f'val/test sizes identical. Adjusted val/test size for even SUBJECT split (not scan split). For consistent loading.')

        self.num_train_workers = num_train_workers
        self.num_val_workers = num_val_workers
        self.num_test_workers = num_test_workers

        self.include_subcortical = include_subcortical
        self.metadata = None
        self.remove_fc_diag = remove_fc_diag
        self.subset_construction = subset_construction
        self.fc_construction, self.scan_combination = fc_construction, scan_combination
        self.sum_stat = sum_stat
        self.fc_norm, self.fc_norm_val = fc_norm, fc_norm_val
        self.fc_filter, self.sc_filter = fc_filter, sc_filter

        self.interesting_edges = None

        # data immediately split into train/test (populated by setup)
        self.all_train_idxs, self.test_idxs = None, None
        self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids, self.all_train_scan_dirs = None, None, None, None
        self.test_fcs, self.test_scs, self.test_subject_ids, self.test_scan_dirs = None, None, None, None
        # also populated by setup; declared here so the attributes always exist
        self.all_train_tasks, self.test_tasks = None, None

        # for cross validation
        self.num_splits, self.split = num_splits, 0
        if self.num_splits is not None:
            if (num_patients_val is not None):
                print(f'WARNING: Can either specify size of validation set for single train/val/test split OR '
                      f'num_splits for cross validation. Not both.')
            print(f'Using Cross Validation...')
            self.num_patients_val = self.num_patients_test = None
        else:
            self.num_patients_val, self.num_patients_test = num_patients_val, num_patients_test

        # length 1 list for simple train/val/test split
        # length num_splits list for cross validation
        self.train_val_splits = []

        # ensure fit only run once
        self.fit_already_run = False

        self.subnetwork_masks = None

    @staticmethod
    def add_module_specific_args(parser):
        """Register this module's command-line options on ``parser`` and return it."""
        # data splits
        parser.add_argument('--num_splits', type=int)
        parser.add_argument('--num_splits_train', type=int)
        # Fixed: `type=Optional[int]` is not a usable converter (a typing
        # construct cannot be called on the argument string), so any value
        # supplied on the command line was rejected. Omitting the flag still
        # yields None.
        parser.add_argument('--num_patients_val', type=int)
        parser.add_argument('--num_patients_test', type=int)

        # which data
        add_bool_arg(parser, 'include_subcortical', default=False)
        parser.add_argument('--scan_combination', type=str)

        parser.add_argument('--batch_size', type=int)
        parser.add_argument('--num_train_workers', type=int)

        return parser

    def full_ds(self):
        """Load the complete, unfiltered and unsplit dataset.

        Returns:
            ``(fcs, scs, subject_ids, scan_dirs)``; ``fcs`` is converted with
            ``correlation_from_covariance`` when ``sum_stat == 'corr'``.
        """
        # convert from .mat to python datastructures
        brain_data, self.metadata = load_brain_data()
        # extract relevant dataset
        fcs, scs, subject_ids, scan_dirs, tasks = matlab_struct_to_tensors(brain_data, fc_construction=self.fc_construction, scan_combination=self.scan_combination)
        fcs = (correlation_from_covariance(fcs) if self.sum_stat == 'corr' else fcs)

        return fcs, scs, subject_ids, scan_dirs

    def setup(self, stage=None):
        """Load, filter, shuffle and split the data (runs at most once).

        Only acts for ``stage`` 'fit' or None. A repeated 'fit' call just
        prints a notice.
        """
        if (stage == 'fit' or stage is None) and (not self.fit_already_run):
            self.fit_already_run = True
            if self.rand_seed is not None:
                seed_everything(self.rand_seed, workers=True)

            # convert from .mat to python datastructures
            brain_data, self.metadata = load_brain_data()
            self.subnetwork_masks = subnetwork_masks(self.metadata["idx2label_cortical"]["lobe"])

            # extract relevant dataset
            fcs, scs, subject_ids, scan_dirs, tasks = matlab_struct_to_tensors(brain_data, fc_construction=self.fc_construction, scan_combination=self.scan_combination)

            # remove outliers in the data
            remove_fc_idxs, remove_sc_idxs = self.filter_data(fcs, scs, subject_ids, scan_dirs, tasks)

            keep_idxs = torch.logical_not(remove_sc_idxs | remove_fc_idxs)
            fcs, scs, subject_ids, scan_dirs, tasks = \
                fcs[keep_idxs], scs[keep_idxs], subject_ids[keep_idxs], scan_dirs[keep_idxs], tasks[keep_idxs]

            # shuffle data
            fcs, scs, subject_ids, scan_dirs, tasks = shuffle(fcs, scs, subject_ids, scan_dirs, tasks, random_state=self.rand_seed)

            # NOTE: optional removal of subcortical nodes is disabled here
            # (original note: "only cortical by default now!"). The former code was:
            #   num_subcortical = len(self.metadata['subcortical_roi'])
            #   if not self.include_subcortical:
            #       fcs = fcs[:, num_subcortical:, num_subcortical:]
            #       scs = scs[:, num_subcortical:, num_subcortical:]

            # train/test split
            if self.num_patients_test is None:
                self.all_train_idxs = np.arange(0, len(fcs))
                self.test_idxs = []
            else:
                train_test_split_group = GroupShuffleSplit(n_splits=1, test_size=self.num_patients_test, random_state=self.rand_seed)
                self.all_train_idxs, self.test_idxs = next(train_test_split_group.split(fcs, scs, subject_ids))  # subject_ids are group

            self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids, self.all_train_scan_dirs, self.all_train_tasks = \
                fcs[self.all_train_idxs], scs[self.all_train_idxs], subject_ids[self.all_train_idxs], scan_dirs[self.all_train_idxs], tasks[self.all_train_idxs]
            self.test_fcs, self.test_scs, self.test_subject_ids, self.test_scan_dirs, self.test_tasks = \
                fcs[self.test_idxs], scs[self.test_idxs], subject_ids[self.test_idxs], scan_dirs[self.test_idxs], tasks[self.test_idxs]

            # Cross Validation
            if self.num_splits is not None:
                # further split train into num_splits folds
                val_split_iter = GroupKFold(n_splits=self.num_splits).split(self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids)

            # Straight train/val/test split
            else:
                val_split = GroupShuffleSplit(n_splits=1, test_size=self.num_patients_val, random_state=self.rand_seed)
                val_split_iter = val_split.split(self.all_train_fcs, self.all_train_scs, self.all_train_subject_ids)

            # Fixed: test_subject_ids is already restricted to the test rows;
            # indexing it again with test_idxs (positions in the full dataset)
            # was out of range / wrong. Count its unique ids directly.
            patients_in_test = len(np.unique(self.test_subject_ids))

            for fold, (train_fold_idxs, val_fold_idxs) in enumerate(val_split_iter):  # idxs for all_train_*
                patients_in_train_fold = len(np.unique(self.all_train_subject_ids[train_fold_idxs]))
                patients_in_val_fold = len(np.unique(self.all_train_subject_ids[val_fold_idxs]))
                split_info = {'train_fold_idxs': train_fold_idxs, 'patients_in_train_fold': patients_in_train_fold,
                              'val_fold_idxs': val_fold_idxs, 'patients_in_val_fold': patients_in_val_fold}
                print(f'\tpatients in train fold: {patients_in_train_fold}, val fold: {patients_in_val_fold}, test set {patients_in_test}')
                self.train_val_splits.append(split_info)

        elif (stage == 'fit') and self.fit_already_run:
            print(f'stage is {stage}. Already run fit?: {self.fit_already_run}. Prob being called again in trainer. Dont run twice...')

    def filter_data(self, fcs, scs, subject_ids, scan_dirs, tasks):
        """Compute boolean removal masks for outlier scans and print what is removed.

        Returns:
            ``(remove_fc_idxs, remove_sc_idxs)``: masks over the scans. A scan is
            flagged by the FC filter when its Frobenius norm exceeds
            ``fc_filter`` and by the SC filter when its sparsity is below
            ``sc_filter``. A filter set to None flags nothing.
        """
        if self.fc_filter is not None:
            remove_fc_idxs = torch.linalg.norm(fcs, ord='fro', dim=(1, 2)) > self.fc_filter
            flagged = np.where(remove_fc_idxs)
            print(f'FC filter: Removing following scans bc: ||Cov||_F > {self.fc_filter}')
            for subject_id, scan_dir, task in zip(subject_ids[flagged], scan_dirs[flagged], tasks[flagged]):
                print(f'\t{subject_id}-{scan_dir}-{task}')
        else:
            remove_fc_idxs = torch.zeros(len(fcs), dtype=torch.bool)

        if self.sc_filter is not None:
            remove_sc_idxs = sparsity(As=scs, directed=False, self_loops=False) < self.sc_filter
            flagged = np.where(remove_sc_idxs)
            print(f'SC filter: Removing following scans bc: Sparsity(SC) < {self.sc_filter}')
            for subject_id, scan_dir, task in zip(subject_ids[flagged], scan_dirs[flagged], tasks[flagged]):
                print(f'\t{subject_id}-{scan_dir}-{task}')
        else:
            remove_sc_idxs = torch.zeros(len(scs), dtype=torch.bool)

        return remove_fc_idxs, remove_sc_idxs

    ############
    # For Cross Validation
    def set_split(self, split):
        """Select which cross-validation fold the train/val loaders serve.

        Without cross validation only split 0 exists, so the call warns and
        leaves the split untouched.
        """
        if self.num_splits is None:
            # self.split must always be 0 when using train/val/test
            warnings.warn(f'Setting split Cross Validation split, but only train/val/test split was made (and thus only 1 CV split. Keeping Split set to 0.')
            return
        # Fixed: validate the requested value, not the currently stored one.
        assert 0 <= split < self.num_splits, f'split must be in [0, {self.num_splits-1}], got {split}'
        self.split = split
    ############

    def train_val_dl(self, idxs, num_patients, batch_size, shuffle, num_workers=0):
        """Build a ``DataLoader`` over the rows ``idxs`` of the all-train data.

        Args:
            idxs: positions into the ``all_train_*`` arrays.
            num_patients: number of distinct subjects among those rows.
            batch_size, shuffle, num_workers: passed to ``DataLoader``.
        """
        split_fcs, split_scs, split_subject_ids, split_scan_dirs, split_tasks = \
            self.all_train_fcs[idxs], self.all_train_scs[idxs], self.all_train_subject_ids[idxs], self.all_train_scan_dirs[idxs], self.all_train_tasks[idxs]

        ds = BrainDataWrapper(fcs=split_fcs,
                              adjs=split_scs,
                              subject_ids=split_subject_ids,
                              scan_dirs=split_scan_dirs,
                              tasks=split_tasks,
                              num_patients=num_patients,
                              binarize_labels=self.binarize_labels_for_train,
                              remove_fc_diag=self.remove_fc_diag,
                              sum_stat=self.sum_stat,
                              fc_norm=self.fc_norm,
                              fc_norm_val=self.fc_norm_val)

        # TODO
        # num_workers>0 causes this warning:
        #  "[W ParallelNative.cpp:206] Warning: Cannot set number of intraop threads
        #  after parallel work has started or after set_num_threads call when using
        #  native parallel backend (function set_num_threads)"
        # There does not seem to be a solution to this: https://stackoverflow.com/questions/64772335/pytorch-w-parallelnative-cpp206:w

        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=False)

    def train_dataloader(self):
        """Return the shuffling loader for the training rows of the active split."""
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        if self.batch_size == 'all' or self.batch_size is None:
            batch_size = len(split_info['train_fold_idxs'])
        else:
            batch_size = self.batch_size
        return self.train_val_dl(idxs=split_info['train_fold_idxs'],
                                 num_patients=split_info['patients_in_train_fold'],
                                 batch_size=batch_size, shuffle=True, num_workers=self.num_train_workers)

    def val_dataloader(self):
        """Return the loader for the validation rows of the active split."""
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        bs = len(split_info['val_fold_idxs']) if (self.val_batch_size is None) else self.val_batch_size
        return self.train_val_dl(idxs=split_info['val_fold_idxs'],
                                 num_patients=split_info['patients_in_val_fold'],
                                 batch_size=bs, shuffle=False, num_workers=self.num_val_workers)

    def test_dataloader(self):
        """Return the loader over the held-out test scans."""
        num_test_patients = len(np.unique(self.test_subject_ids))

        test_ds = \
            BrainDataWrapper(fcs=self.test_fcs,
                             adjs=self.test_scs,
                             subject_ids=self.test_subject_ids,
                             scan_dirs=self.test_scan_dirs,
                             tasks=self.test_tasks,
                             num_patients=num_test_patients,
                             binarize_labels=self.binarize_labels_for_train,
                             remove_fc_diag=self.remove_fc_diag,
                             sum_stat=self.sum_stat,
                             fc_norm=self.fc_norm, fc_norm_val=self.fc_norm_val)
        # Fixed: the default used `self.num_samples_test`, an attribute this
        # class never defines. Default to all test scans in one batch, mirroring
        # val_dataloader.
        bs = len(self.test_fcs) if (self.test_batch_size is None) else self.test_batch_size
        return DataLoader(test_ds, batch_size=bs, num_workers=self.num_test_workers)


class PsuedoSyntheticDataModule(RealDataModule):
    """Data module pairing real structural connectivity with synthetic signals.

    Loading, filtering and subject-wise splitting are inherited from
    ``RealDataModule``; the loaders wrap the split's SC matrices in
    ``PsuedoSyntheticDiffusionDataset`` instead of using the measured FCs.

    Args (beyond those forwarded to ``RealDataModule``):
        num_signals, sum_stat, diffuse_over_weighted, norm_before_diffusion,
        coeffs: forwarded to ``PsuedoSyntheticDiffusionDataset``.
    """

    def __init__(self,
                 rand_seed=50,
                 batch_size=128,
                 val_batch_size=None,
                 test_batch_size=None,
                 num_patients_test: int = 5,
                 num_patients_val: Optional[int] = 70,
                 num_train_workers: int = 0,
                 num_val_workers: int = 2,
                 num_test_workers: int = 0,
                 num_splits: Optional[int] = None,
                 binarize_labels_for_train=False,
                 transform=None,
                 include_subcortical=False,
                 num_signals=50,
                 sum_stat='sample_cov',
                 diffuse_over_weighted=True,
                 norm_before_diffusion=False,
                 remove_fc_diag=False,
                 fc_norm=None,
                 fc_norm_val=None,
                 sc_filter: Optional[float] = .35,
                 coeffs=np.array([0.5, 0.5, 0.2])):  # alternative once used: np.array([-.18, 2.446, -.0211])
        super().__init__(
            rand_seed=rand_seed,
            batch_size=batch_size,
            val_batch_size=val_batch_size,
            test_batch_size=test_batch_size,
            num_patients_test=num_patients_test,
            num_patients_val=num_patients_val,
            num_train_workers=num_train_workers,
            num_val_workers=num_val_workers,
            num_test_workers=num_test_workers,
            num_splits=num_splits,
            binarize_labels_for_train=binarize_labels_for_train,
            transform=transform,
            include_subcortical=include_subcortical,
            subset_construction=None,
            remove_fc_diag=remove_fc_diag,
            fc_construction='concat_all_timeseries',
            scan_combination=None,
            fc_norm=fc_norm,
            fc_norm_val=fc_norm_val,
            fc_filter=None,
            sc_filter=sc_filter
        )
        # variables unique to Psuedo-Synthetic
        self.num_signals = num_signals
        # sum_stat is assigned here, after the parent constructor stored its
        # own default for it.
        self.sum_stat, self.diffuse_over_weighted, self.norm_before_diffusion = sum_stat, diffuse_over_weighted, norm_before_diffusion
        self.coeffs = coeffs

        self.p_synth_ds, self.train_ds, self.val_ds, self.test_ds = None, None, None, None

    @staticmethod
    def add_module_specific_args(parser):
        """Register this module's command-line options on ``parser`` and return it."""
        # data splits
        parser.add_argument('--num_splits', type=int)
        parser.add_argument('--num_splits_train', type=int)
        # Fixed: `type=Optional[int]` is not a usable converter (a typing
        # construct cannot be called on the argument string), so any value
        # supplied on the command line was rejected. Omitting the flag still
        # yields None.
        parser.add_argument('--num_patients_val', type=int)
        parser.add_argument('--num_patients_test', type=int)

        # which data
        add_bool_arg(parser, 'include_subcortical', default=False)

        parser.add_argument('--batch_size', type=int)

        parser.add_argument('--coeffs', '--list', nargs='+', type=float, help='<Required> Set flag')
        parser.add_argument('--num_signals', type=int)
        parser.add_argument('--sum_stat', type=str,
                            choices=['sample_cov', 'analytic_cov', 'sample_corr', 'analytic_corr'])
        add_bool_arg(parser, 'diffuse_over_weighted', default=True)
        add_bool_arg(parser, 'norm_before_diffusion', default=False)

        return parser

    def train_val_dl(self, idxs, num_patients, batch_size, shuffle, num_workers=0):
        """Build a loader of pseudo-synthetic samples over rows ``idxs`` of the all-train data."""
        split_scs, split_subject_ids, split_scan_dirs = \
            self.all_train_scs[idxs], self.all_train_subject_ids[idxs], self.all_train_scan_dirs[idxs]

        ps_ds = PsuedoSyntheticDiffusionDataset(adjs=split_scs,
                                                subject_ids=split_subject_ids,
                                                scan_dirs=split_scan_dirs,
                                                num_patients=num_patients,
                                                num_signals=self.num_signals,
                                                sum_stat=self.sum_stat,
                                                diff_over_weighted=self.diffuse_over_weighted,
                                                norm_before_diff=self.norm_before_diffusion,
                                                bin_labels_after_diff=self.binarize_labels_for_train,
                                                fc_norm=self.fc_norm,
                                                fc_norm_val=self.fc_norm_val,
                                                coeffs=self.coeffs,
                                                split_id=str(self.split))
        # pin_memory makes it faster, but annoying warning printed (I think)
        return DataLoader(ps_ds, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=False)

    def train_dataloader(self):
        """Return the shuffling loader for the training rows of the active split."""
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        if self.batch_size == 'all':
            batch_size = len(split_info['train_fold_idxs'])
        else:
            batch_size = self.batch_size
        return self.train_val_dl(idxs=split_info['train_fold_idxs'],
                                 num_patients=split_info['patients_in_train_fold'],
                                 batch_size=batch_size, shuffle=True, num_workers=self.num_train_workers)

    def val_dataloader(self):
        """Return the loader for the validation rows of the active split (one batch).

        NOTE(review): this ignores ``val_batch_size`` and ``num_val_workers``;
        it uses one worker when CUDA is available and none otherwise -- confirm
        that is intended.
        """
        assert (self.split == 0) or (self.num_patients_val is None), f'self.split must always be 0 when using train/val/test. Is {self.split}'
        split_info = self.train_val_splits[self.split]
        return self.train_val_dl(idxs=split_info['val_fold_idxs'],
                                 num_patients=split_info['patients_in_val_fold'],
                                 batch_size=len(split_info['val_fold_idxs']), shuffle=False, num_workers=1 if torch.cuda.is_available() else 0)

    def test_dataloader(self):
        """Return the loader of pseudo-synthetic samples over the held-out test data.

        NOTE(review): the batch size is the number of distinct test subjects and
        ``test_batch_size`` / ``num_test_workers`` are ignored -- confirm.
        """
        num_test_patients = len(np.unique(self.test_subject_ids))

        ps_ds = PsuedoSyntheticDiffusionDataset(adjs=self.test_scs,
                                                subject_ids=self.test_subject_ids,
                                                scan_dirs=self.test_scan_dirs,
                                                num_patients=num_test_patients,
                                                num_signals=self.num_signals,
                                                sum_stat=self.sum_stat,
                                                diff_over_weighted=self.diffuse_over_weighted,
                                                norm_before_diff=self.norm_before_diffusion,
                                                bin_labels_after_diff=self.binarize_labels_for_train,
                                                fc_norm=self.fc_norm,
                                                fc_norm_val=self.fc_norm_val,
                                                coeffs=self.coeffs,
                                                split_id=str(self.split))
        return DataLoader(ps_ds, batch_size=num_test_patients, shuffle=False, num_workers=1 if torch.cuda.is_available() else 0)


def scan_composition_text(lr, rl):
    """Describe which phase-encoding scans are present.

    Returns 'ave' when both ``lr`` and ``rl`` are truthy, 'lr' or 'rl' when
    only that one is, and None when neither is.
    """
    if not lr:
        return 'rl' if rl else None
    return 'ave' if rl else 'lr'


def parse_scandir(exist_LR, exist_RL):
    """Return the scan-direction label for the directions that exist.

    Yields 'LR', 'RL', or 'LR+RL' when both flags are truthy.

    Raises:
        ValueError: if neither direction exists.
    """
    present = [label for label, exists in (('LR', exist_LR), ('RL', exist_RL)) if exists]
    if not present:
        raise ValueError('Must have some scan dir')
    return '+'.join(present)


def parse_task(exist_REST1, exist_REST2):
    """Return the task label for the resting-state sessions that exist.

    Yields '1', '2', or '1+2' when both flags are truthy.

    Raises:
        ValueError: if neither session exists.
    """
    present = [label for label, exists in (('1', exist_REST1), ('2', exist_REST2)) if exists]
    if not present:
        raise ValueError('Must have some task')
    return '+'.join(present)


def find_scandir_and_task(which_scans_exist, scan_names, fc_construction='concat_all_timeseries', task=None, scan_dir=None):
    """Work out which scan directions and tasks contribute to an FC.

    Args:
        which_scans_exist: mapping-like, indexed by scan name, whose values are
            truthy when that scan exists.
        scan_names: scan names to inspect; the original note gives
            ['REST1_LR', 'REST1_RL', 'REST2_LR', 'REST2_RL'].
        fc_construction: one of
            'concat_all_timeseries'     - consider every existing scan;
            'concat_task_timeseries'    - consider existing scans whose name
                                          contains ``task`` (must be given);
            'concat_scandir_timeseries' - consider existing scans whose name
                                          contains ``scan_dir`` (must be given).
        task: task substring, used and returned as-is in the task mode.
        scan_dir: direction substring, used and returned as-is in the scandir mode.

    Returns:
        ``(scan_dir, task)`` as produced by ``parse_scandir`` / ``parse_task``
        (or the value passed in, for the dimension that was fixed by the caller).

    Raises:
        ValueError: if no matching scan exists (raised by the parse helpers),
            or if ``fc_construction`` is not one of the supported modes.
    """
    if fc_construction == 'concat_all_timeseries':
        existing = [scan for scan in scan_names if which_scans_exist[scan]]
        scan_dir = parse_scandir(any('LR' in scan for scan in existing),
                                 any('RL' in scan for scan in existing))
        task = parse_task(any('REST1' in scan for scan in existing),
                          any('REST2' in scan for scan in existing))
        return scan_dir, task

    if fc_construction == 'concat_task_timeseries':
        matching = [scan for scan in scan_names if which_scans_exist[scan] and (task in scan)]
        exist_LR = any('LR' in scan for scan in matching)
        exist_RL = any('RL' in scan for scan in matching)
        if not (exist_LR or exist_RL):
            print('error, must be one')
        return parse_scandir(exist_LR, exist_RL), task

    if fc_construction == 'concat_scandir_timeseries':
        matching = [scan for scan in scan_names if which_scans_exist[scan] and (scan_dir in scan)]
        exist_REST1 = any('REST1' in scan for scan in matching)
        exist_REST2 = any('REST2' in scan for scan in matching)
        if not (exist_REST1 or exist_REST2):
            print('error, must be one')
        return scan_dir, parse_task(exist_REST1, exist_REST2)

    # Previously an unrecognised mode fell through and returned None, which
    # surfaced later as an opaque tuple-unpacking error in the caller.
    raise ValueError(f'Unknown fc_construction: {fc_construction!r}')


def unpack_individual(subject_data_struct_list, scan_combination, scan_names, key='individual'):
    """Collect per-scan (or per-subject mean) FCs with their SCs and labels.

    With ``scan_combination == 'mean'`` each subject contributes one entry: the
    mean of the FCs of its existing scans, labelled with '|'-joined scan
    directions and tasks. Otherwise every existing scan contributes its own
    entry.

    Returns:
        ``(fcs, scs, subject_ids, scan_dirs, tasks)`` as parallel lists; the
        tensors in ``fcs`` and ``scs`` carry a leading batch dimension of 1.
    """
    fcs, scs = [], []
    subject_ids, scan_dirs, tasks = [], [], []

    def _with_batch_dim(values):
        # leading dim of 1 so the entries can be concatenated later
        return torch.tensor(values).repeat(1, 1, 1)

    def _scan_fc(subject, scan):
        return _with_batch_dim(subject['fcs'][key].item()[scan].item())

    if scan_combination == 'mean':
        for subject in subject_data_struct_list:
            dir_parts, task_parts, subject_fcs = [], [], []
            for scan in scan_names:
                if not subject['which_scans_exist'][scan]:  # we should never hit NaN fc now
                    continue
                subject_fcs.append(_scan_fc(subject, scan))
                dir_parts.append('LR' if 'LR' in scan else 'RL')
                task_parts.append('1' if 'REST1' in scan else '2')

            if not any(subject['which_scans_exist'].item()):
                continue
            scs.append(_with_batch_dim(subject['sc']))
            subject_ids.append(subject['subject_id'])
            stacked = torch.cat(subject_fcs, dim=0)
            fcs.append(stacked.mean(dim=0).repeat(1, 1, 1))
            scan_dirs.append('|'.join(dir_parts))
            tasks.append('|'.join(task_parts))
        return fcs, scs, subject_ids, scan_dirs, tasks

    for subject in subject_data_struct_list:
        for scan in scan_names:
            if not subject['which_scans_exist'][scan]:  # we should never hit NaN fc now
                continue
            scs.append(_with_batch_dim(subject['sc']))
            subject_ids.append(subject['subject_id'])
            fcs.append(_scan_fc(subject, scan))
            scan_dirs.append('LR' if 'LR' in scan else 'RL')
            tasks.append('1' if 'REST1' in scan else '2')

    return fcs, scs, subject_ids, scan_dirs, tasks


def unpack_task_concat(subject_data_struct_list, scan_combination, scan_names, key, grouped_fc_names, fc_construction):
    """Collect the task-grouped fcs ('REST1' / 'REST2') of every subject with the matching sc.

    A task-level fc is read only when at least one of that task's two scans
    ('<task>_LR' or '<task>_RL') is flagged in ``s['which_scans_exist']``.

    Args:
        subject_data_struct_list: iterable of per-subject records indexed with
            'which_scans_exist', 'fcs', 'sc' and 'subject_id'.
        scan_combination: 'mean' averages the available task-level fcs of a subject into
            one sample; any other value emits one sample per available task.
        scan_names: forwarded to ``find_scandir_and_task``.
        key: entry of ``s['fcs']`` that holds the task-grouped matrices.
        grouped_fc_names: task names to iterate over; each must be 'REST1' or 'REST2'.
        fc_construction: forwarded to ``find_scandir_and_task``.

    Returns:
        Tuple ``(fcs, scs, subject_ids, scan_dirs, tasks)`` of equal-length lists. In 'mean'
        mode the per-task labels are joined with '|'.
    """
    fcs, scs = [], []
    subject_ids, scan_dirs, tasks = [], [], []
    average_tasks = scan_combination == 'mean'

    for subject in subject_data_struct_list:
        scan_flags = subject['which_scans_exist']
        task_present = {
            'REST1': scan_flags['REST1_LR'].item() or scan_flags['REST1_RL'].item(),
            'REST2': scan_flags['REST2_LR'].item() or scan_flags['REST2_RL'].item(),
        }

        if average_tasks:
            member_fcs, dir_labels, task_labels = [], [], []
            for fc_name in grouped_fc_names:
                if not task_present[fc_name]:
                    continue  # reading a missing grouped fc would raise
                grouped_fc = torch.tensor(subject['fcs'][key].item()[fc_name].item())
                member_fcs.append(grouped_fc.repeat(1, 1, 1))
                dir_label, task_label = find_scandir_and_task(scan_flags, scan_names,
                                                              fc_construction=fc_construction, task=fc_name)
                dir_labels.append(dir_label)
                task_labels.append(task_label)

            if not any(task_present.values()):
                continue
            scs.append(torch.tensor(subject['sc']).repeat(1, 1, 1))
            # average over the task axis, then restore a leading batch dim for cat() later
            fcs.append(torch.cat(member_fcs, dim=0).mean(dim=0).repeat(1, 1, 1))
            scan_dirs.append('|'.join(dir_labels))
            tasks.append('|'.join(task_labels))
            subject_ids.append(subject['subject_id'])
        else:
            for fc_name in grouped_fc_names:
                if not task_present[fc_name]:
                    continue  # reading a missing grouped fc would raise
                # make sure batch dim is nonzero for cat() later
                scs.append(torch.tensor(subject['sc']).repeat(1, 1, 1))
                subject_ids.append(subject['subject_id'])
                grouped_fc = torch.tensor(subject['fcs'][key].item()[fc_name].item())
                fcs.append(grouped_fc.repeat(1, 1, 1))
                dir_label, task_label = find_scandir_and_task(scan_flags, scan_names,
                                                              fc_construction=fc_construction, task=fc_name)
                scan_dirs.append(dir_label)
                tasks.append(task_label)

    return fcs, scs, subject_ids, scan_dirs, tasks


def unpack_scandir_concat(subject_data_struct_list, scan_combination, scan_names, key, grouped_fc_names, fc_construction):
    """Collect the scan-direction-grouped fcs ('LR' / 'RL') of every subject with the matching sc.

    A direction-level fc is read only when at least one of that direction's two scans
    ('REST1_<dir>' or 'REST2_<dir>') is flagged in ``s['which_scans_exist']``.

    Args:
        subject_data_struct_list: iterable of per-subject records indexed with
            'which_scans_exist', 'fcs', 'sc' and 'subject_id'.
        scan_combination: 'mean' averages the available direction-level fcs of a subject
            into one sample; any other value emits one sample per available direction.
        scan_names: forwarded to ``find_scandir_and_task``.
        key: entry of ``s['fcs']`` that holds the direction-grouped matrices.
        grouped_fc_names: direction names to iterate over; each must be 'LR' or 'RL'.
        fc_construction: forwarded to ``find_scandir_and_task``.

    Returns:
        Tuple ``(fcs, scs, subject_ids, scan_dirs, tasks)`` of equal-length lists. In 'mean'
        mode the per-direction labels are joined with '|'.
    """
    fcs, scs = [], []
    subject_ids, scan_dirs, tasks = [], [], []
    combine_by_mean = scan_combination == 'mean'

    for subject in subject_data_struct_list:
        exists = subject['which_scans_exist']
        direction_present = {
            'LR': exists['REST1_LR'].item() or exists['REST2_LR'].item(),
            'RL': exists['REST1_RL'].item() or exists['REST2_RL'].item(),
        }

        if not combine_by_mean:
            for fc_name in grouped_fc_names:
                if not direction_present[fc_name]:
                    continue  # reading a missing grouped fc would raise
                # make sure batch dim is nonzero for cat() later
                scs.append(torch.tensor(subject['sc']).repeat(1, 1, 1))
                subject_ids.append(subject['subject_id'])
                direction_fc = torch.tensor(subject['fcs'][key].item()[fc_name].item())
                fcs.append(direction_fc.repeat(1, 1, 1))
                found_dir, found_task = find_scandir_and_task(exists, scan_names,
                                                              fc_construction=fc_construction, scan_dir=fc_name)
                scan_dirs.append(found_dir)
                tasks.append(found_task)
            continue

        # 'mean': gather everything available for this subject, then emit one averaged sample
        gathered, dir_parts, task_parts = [], [], []
        for fc_name in grouped_fc_names:
            if not direction_present[fc_name]:
                continue  # reading a missing grouped fc would raise
            direction_fc = torch.tensor(subject['fcs'][key].item()[fc_name].item())
            gathered.append(direction_fc.repeat(1, 1, 1))
            found_dir, found_task = find_scandir_and_task(exists, scan_names,
                                                          fc_construction=fc_construction, scan_dir=fc_name)
            dir_parts.append(found_dir)
            task_parts.append(found_task)

        if any(direction_present.values()):
            scs.append(torch.tensor(subject['sc']).repeat(1, 1, 1))
            # average over the direction axis, then restore a leading batch dim for cat() later
            fcs.append(torch.cat(gathered, dim=0).mean(dim=0).repeat(1, 1, 1))
            scan_dirs.append('|'.join(dir_parts))
            tasks.append('|'.join(task_parts))
            subject_ids.append(subject['subject_id'])

    return fcs, scs, subject_ids, scan_dirs, tasks


# util funcs
def matlab_struct_to_tensors(subject_data_struct_list, fc_construction='concat_all_timeseries', scan_combination='mean', datatype=torch.float32):
    """Map the list of per-subject records to stacked fc/sc tensors plus per-sample metadata.

    Builds a dataset of inputs (fcs) and labels (scs). Each patient can have multiple
    fcs (REST1/2 and scan direction LR vs RL), so the subject id of every sample is
    returned as well to allow a proper train/validation/test split by subject.

    Args:
        subject_data_struct_list: iterable of per-subject records indexed with
            'which_scans_exist', 'fcs', 'sc' and 'subject_id'.
        fc_construction: one of 'individual_timeseries', 'concat_all_timeseries',
            'concat_task_timeseries', 'concat_scandir_timeseries'; selects which fcs are read.
        scan_combination: forwarded to the unpack helpers ('mean' averages a subject's fcs).
            Ignored for 'concat_all_timeseries', which always yields one fc per subject.
        datatype: torch dtype of the returned fc and sc tensors.

    Returns:
        Tuple ``(fcs, scs, subject_ids, scan_dirs, tasks)``: fcs and scs are tensors of
        dtype ``datatype`` concatenated along dim 0, subject_ids is an int32 tensor, and
        scan_dirs / tasks are numpy arrays built from the per-sample labels.

    Raises:
        AssertionError: if ``fc_construction`` is not one of the supported values, or if
            the collected lists end up with different lengths.
    """
    assert fc_construction in ['individual_timeseries', 'concat_all_timeseries', 'concat_task_timeseries', 'concat_scandir_timeseries']

    scan_names = ['REST1_LR', 'REST1_RL', 'REST2_LR', 'REST2_RL']
    if fc_construction == 'individual_timeseries':
        fcs, scs, subject_ids, scan_dirs, tasks = \
            unpack_individual(subject_data_struct_list, scan_combination=scan_combination, scan_names=scan_names)

    elif fc_construction == 'concat_all_timeseries':
        key = 'fcs_all'
        fcs, scs = [], []
        subject_ids, scan_dirs, tasks = [], [], []
        for s in subject_data_struct_list:
            scs.append(torch.tensor(s['sc']).repeat(1, 1, 1))  # make sure batch dim is nonzero for cat() later
            subject_ids.append(s['subject_id'])
            fc = torch.tensor(s['fcs'][key].item().item()[0])
            fcs.append(fc.repeat(1, 1, 1))
            scan_dir, task = find_scandir_and_task(s['which_scans_exist'], scan_names)

            scan_dirs.append(scan_dir)
            tasks.append(task)

    elif fc_construction == 'concat_task_timeseries':
        key = 'task_grouped'
        grouped_fc_names = ['REST1', 'REST2']
        fcs, scs, subject_ids, scan_dirs, tasks = \
            unpack_task_concat(subject_data_struct_list, scan_combination=scan_combination, scan_names=scan_names,
                               key=key, grouped_fc_names=grouped_fc_names, fc_construction=fc_construction)
    elif fc_construction == 'concat_scandir_timeseries':
        key = 'scandir_grouped'
        grouped_fc_names = ['LR', 'RL']
        fcs, scs, subject_ids, scan_dirs, tasks = \
            unpack_scandir_concat(subject_data_struct_list, scan_combination=scan_combination, scan_names=scan_names,
                                  key=key, grouped_fc_names=grouped_fc_names, fc_construction=fc_construction)
    else:
        # only reachable when asserts are disabled (python -O)
        raise ValueError(f'fc_construction {fc_construction} not recognized')

    assert len(fcs) == len(scs) == len(subject_ids) == len(scan_dirs) == len(tasks), f'all lens should be same: fcs {len(fcs)}, scs {len(scs)}, ids {len(subject_ids)}, sd {len(scan_dirs)} tasks {len(tasks)}'

    fcs_tensor = torch.cat(fcs, dim=0)
    scs_tensor = torch.cat(scs, dim=0)

    # Convert with Tensor.to(): wrapping an existing tensor in torch.tensor() copy-constructs
    # it and makes PyTorch emit a UserWarning. Both tensors are fresh outputs of cat().
    return fcs_tensor.to(dtype=datatype), \
           scs_tensor.to(dtype=datatype), \
           torch.tensor(subject_ids, dtype=torch.int32), \
           np.array(scan_dirs), np.array(tasks)


def print_batch_info(batch):
    """Print the shapes of the first three batch entries and the length of the fourth.

    ``batch`` must unpack into exactly four items: fcs, scs, subject ids (each exposing
    ``.shape``) and the scan-direction labels (anything supporting ``len``).
    """
    fc_batch, sc_batch, id_batch, dir_labels = batch
    print(f'fcs: shape {fc_batch.shape}')
    print(f'scs: shape {sc_batch.shape}')
    print(f'ids: shape {id_batch.shape}')
    print(f'scan_dirs len: {len(dir_labels)}')


def fc_metric_distrib(coeffs, percentile, base=2):
    """Plot histograms of summary statistics for real, pseudo-synthetic and synthetic fcs.

    One subplot per statistic (Frobenius norm, 99th percentile of entry magnitudes, max
    eigenvalue magnitude, max entry magnitude, max off-diagonal entry magnitude, medians),
    drawn on a log-scaled x axis. Where applicable the raw, Frobenius-normalised and
    percentile-normalised versions of each data source are overlaid. Blocks on
    ``plt.show()``.

    Args:
        coeffs: passed to the synthetic and pseudo-synthetic dataset constructors.
        percentile: used in the legend labels of the percentile-normalised histograms.
            NOTE(review): the normalisation (``extra=99``), the quantile (.99) and the
            subplot title are hard-coded to 99 regardless of this value - confirm intent.
        base: base of the logarithmic bins and x axis; must be 2 or 10.

    Raises:
        ValueError: if ``base`` is neither 2 nor 10.
    """
    # Fail fast: bins are only defined for these two bases. Previously any other value
    # surfaced as a NameError on `bins`, and only after all datasets had been built.
    if base not in (2, 10):
        raise ValueError(f'base must be 2 or 10, got {base!r}')

    N, num_signals, num_samples = 68, 1000, 1200
    synth_ds_1 = PureSyntheticDiffusionDataset(num_vertices=N, num_signals=num_signals, num_samples=num_samples,
                                               sum_stat='sample_cov', coeffs=coeffs,
                                               input_adj=None)
    ps_ds = PsuedoSyntheticDataModule(num_signals=num_signals, num_patients_val=1, num_patients_test=1,
                                      sum_stat='sample_cov', coeffs=coeffs, sc_filter=None)
    real_ds = RealDataModule(fc_filter=None, num_patients_val=1, num_patients_test=1, sc_filter=None)

    ps_ds.setup('fit')
    real_ds.setup('fit')

    synth_fcs = synth_ds_1.sample_cov.numpy()
    ps_fcs = ps_ds.train_dataloader().dataset.full_ds()[0].numpy()
    real_fcs = real_ds.all_train_fcs.numpy()

    fs = 10
    import matplotlib.pyplot as plt
    from utils.util_funcs import normalize_slices
    axes_titles = ['Frob Norm', '99th percentile of Entry Magnitudes', 'Max Eigval Magnitude', 'Max Entry Magnitude', 'Max Entry Off Diag', 'Medians']
    fig, axes = plt.subplots(nrows=len(axes_titles), ncols=1, constrained_layout=True, sharex='col')
    for ax, title in zip(axes, axes_titles):
        ax.set_title(title, fontsize=fs)

    if base == 2:
        bins = np.logspace(-12, 26, base=2, num=100)
    elif base == 10:
        bins = np.logspace(-4, 8, base=10, num=100)

    alpha = .5
    fc_list = [real_fcs, ps_fcs, synth_fcs]
    frob_normed_fc_list = [normalize_slices(torch.tensor(fc), which_norm='frob').numpy() for fc in fc_list]
    perc_normed_fc_list = [normalize_slices(torch.tensor(fc), which_norm='percentile', extra=99).numpy() for fc in fc_list]
    fc_tags = ['real', 'pseudo-synthetic', 'synthetic']
    raw_colors, frob_normed_colors, perc_normed_colors = ['green', 'red', 'blue'], ['lime', 'darkred', 'navy'], ['olivedrab', 'orangered', 'indigo']

    for fcs, frob_normed_fcs, perc_normed_fcs, fc_tag, raw_color, frob_normed_color, perc_normed_color in \
            zip(fc_list, frob_normed_fc_list, perc_normed_fc_list, fc_tags,
                raw_colors, frob_normed_colors, perc_normed_colors):

        # zero-diagonal copies, used for the off-diagonal maxima
        fcs_zd, frob_normed_fcs_zd, perc_normed_fcs_zd = np.copy(fcs), np.copy(frob_normed_fcs), np.copy(perc_normed_fcs)
        for raw_zd, frob_zd, perc_zd in zip(fcs_zd, frob_normed_fcs_zd, perc_normed_fcs_zd):
            # each item is a view into the copy, so filling it edits the copy in place
            np.fill_diagonal(raw_zd, 0)
            np.fill_diagonal(frob_zd, 0)
            np.fill_diagonal(perc_zd, 0)

        # matrix 2-norm per slice
        eigs = np.linalg.norm(fcs, ord=2, axis=(1, 2))
        frob_normed_eigs = np.linalg.norm(frob_normed_fcs, ord=2, axis=(1, 2))
        perc_normed_eigs = np.linalg.norm(perc_normed_fcs, ord=2, axis=(1, 2))

        maxes, medians = np.max(np.abs(fcs), axis=(1, 2)), np.median(upper_tri_as_vec_batch(fcs), axis=1)
        frob_normed_maxes, frob_normed_medians = np.max(np.abs(frob_normed_fcs), axis=(1, 2)), np.median(frob_normed_fcs, axis=(1, 2))
        perc_normed_maxes, perc_normed_medians = np.max(np.abs(perc_normed_fcs), axis=(1, 2)), np.median(perc_normed_fcs, axis=(1, 2))
        zd_maxes = np.max(np.abs(fcs_zd), axis=(1, 2))
        zd_frob_normed_maxes = np.max(np.abs(frob_normed_fcs_zd), axis=(1, 2))
        zd_perc_normed_maxes = np.max(np.abs(perc_normed_fcs_zd), axis=(1, 2))

        values_per_matrix = upper_tri_as_vec_batch(torch.abs(torch.tensor(fcs)), offset=0)
        percentiles = torch.quantile(values_per_matrix, .99, dim=1).view(-1, 1).numpy().flatten()

        # p walks through the subplots in the same order as axes_titles
        p = 0
        axes[p].hist(np.linalg.norm(fcs, ord='fro', axis=(1, 2)), bins=bins, label=fc_tag, color=raw_color, alpha=alpha)
        p += 1
        axes[p].hist(percentiles, bins=bins, label=fc_tag, color=raw_color, alpha=alpha)
        p += 1
        axes[p].hist(eigs, bins=bins, label=fc_tag, color=raw_color, alpha=alpha)
        axes[p].hist(frob_normed_eigs, bins=bins, label=fc_tag + ' Frob normed', color=frob_normed_color, alpha=alpha)
        axes[p].hist(perc_normed_eigs, bins=bins, label=fc_tag + f' {percentile}% normed', color=perc_normed_color,
                     alpha=alpha)
        p += 1
        axes[p].hist(maxes, bins=bins, label=fc_tag, color=raw_color, alpha=alpha)
        axes[p].hist(frob_normed_maxes, bins=bins, label=fc_tag + ' Frob normed', color=frob_normed_color, alpha=alpha)
        axes[p].hist(perc_normed_maxes, bins=bins, label=fc_tag + f' {percentile}% normed', color=perc_normed_color, alpha=alpha)
        p += 1
        axes[p].hist(zd_maxes, bins=bins, label=fc_tag, color=raw_color, alpha=alpha)
        axes[p].hist(zd_frob_normed_maxes, bins=bins, label=fc_tag + ' Frob normed', color=frob_normed_color, alpha=alpha)
        axes[p].hist(zd_perc_normed_maxes, bins=bins, label=fc_tag + f' {percentile}% normed', color=perc_normed_color, alpha=alpha)
        p += 1
        axes[p].hist(medians, bins=bins, label=fc_tag, color=raw_color, alpha=alpha)
        axes[p].hist(frob_normed_medians, bins=bins, label=fc_tag + ' Frob normed', color=frob_normed_color, alpha=alpha)
        axes[p].hist(perc_normed_medians, bins=bins, label=fc_tag + f' {percentile}% normed', color=perc_normed_color, alpha=alpha)

    for ax in axes:
        ax.set_xscale('log', base=base)
        ax.set_yticklabels([])
        ax.set_yticks([])

    if base == 2:
        axes[-1].set_xticks([2**(i) for i in range(-8, 26)])
        axes[-1].set_xticklabels([f'2^{i}' if i % 2 == 0 else '' for i in range(-8, 26)])
    elif base == 10:
        axes[-1].set_xticks([10**(i) for i in range(-2, 9)])

    # the last subplot carries raw + both normalised histograms, so its handles cover every label
    hists, labels = axes[-1].get_legend_handles_labels()
    fig.legend(hists, labels, loc='upper left')
    fig.set_constrained_layout_pads(w_pad=0, h_pad=0, hspace=0, wspace=0)
    plt.show()


if __name__ == "__main__":
    # Manual smoke test: load the brain data, plot the fc statistics, then build the
    # real-data module (k-fold and straight split) and the pseudo-synthetic module,
    # printing one batch from each dataloader.

    scan_combination = 'separate'
    brain_data, metadata = load_brain_data()
    fcs, scs, subject_ids, scan_dirs, tasks = matlab_struct_to_tensors(brain_data, fc_construction='concat_scandir_timeseries', scan_combination=scan_combination)

    coeffs = np.array([.5, .5, .2])
    fc_metric_distrib(coeffs, percentile=99, base=10)
    #fc_metric_distrib(coeffs, which_norm='percentile')

    rand_seed = 50  # NOTE(review): not used anywhere below
    from argparse import ArgumentParser, ArgumentTypeError
    parser = ArgumentParser()

    def _str_to_bool(text):
        # argparse hands over the raw string and bool('False') is True, so `type=bool`
        # can never switch a flag off; parse the text explicitly instead.
        lowered = text.lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ArgumentTypeError(f'expected a boolean value, got {text!r}')

    # k-fold split
    dm = RealDataModule(num_splits=5, num_patients_val=None, num_patients_test=3)#.from_argparse_args(args)
    dm.setup(stage='fit')  # setup called by trainer
    for fold in range(dm.num_splits):
        print(f'fold {fold}')
        dm.set_split(fold)
        train_dl = dm.train_dataloader()
        val_dl = dm.val_dataloader()
        print(f'# scans    train dl: {len(train_dl.dataset)}, size val dl: {len(val_dl.dataset)}')
        print(f'# patients train dl: {train_dl.dataset.num_patients}, size val dl: {val_dl.dataset.num_patients}')
        for batch in train_dl:
            print_batch_info(batch)  # only the first batch is inspected
            break
        for batch in val_dl:
            print_batch_info(batch)
            break
    #dm.setup(stage='fit')
    test_dl = dm.test_dataloader()
    print(f'# scans test dl: {len(test_dl.dataset)}')
    print(f'# patients test dl: {test_dl.dataset.num_patients}')
    for batch in test_dl:
        print_batch_info(batch)
        break

    # straight train/val/test split
    dm = RealDataModule(num_splits=None, num_patients_val=50, num_patients_test=3)  # .from_argparse_args(args)
    dm.setup(stage='fit')  # setup called by trainer
    train_dl = dm.train_dataloader()
    val_dl = dm.val_dataloader()
    print(f'size train dl: {len(train_dl.dataset)}, size val dl: {len(val_dl.dataset)}')
    for batch in train_dl:
        print_batch_info(batch)
        break
    for batch in val_dl:
        print_batch_info(batch)
        break

    test_dl = dm.test_dataloader()
    print(f'size test dl: {len(test_dl.dataset)}')  # was mislabelled as the train loader
    print(f'# patients test dl: {test_dl.dataset.num_patients}')
    for batch in test_dl:
        print_batch_info(batch)
        break

    #######################
    # PS dm
    parser.add_argument('--num_signals', type=int, default=200)
    parser.add_argument('--sum_stat', type=str, choices=['sample_cov', 'analytic_cov', 'sample_corr', 'analytic_corr'],
                        default='sample_cov')
    parser.add_argument('--diffuse_over_weighted', type=_str_to_bool, default=True)
    parser.add_argument('--norm_before_diffusion', type=_str_to_bool, default=False)
    parser.add_argument('--binarize_labels_for_train', type=_str_to_bool, default=False)
    parser.add_argument('--num_samples_val', type=int, default=50)
    parser.add_argument('--num_samples_test', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=36)
    args = parser.parse_args()
    dm = PsuedoSyntheticDataModule.from_argparse_args(args)

