import mne
import torch
import copy
import bisect
import tqdm
import numpy as np

from abc import ABC
from collections import OrderedDict
from collections.abc import Iterable
from pathlib import Path
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import ConcatDataset, DataLoader

from baseline.bendr.model.transforms.instance import InstanceTransform, same_channel_sets
from baseline.bendr.model.transforms.preprocessors import Preprocessor
from baseline.bendr.model.utils import DN3atasetNanFound, unfurl, rand_split, DN3atasetException


class DN3ataset(TorchDataset):
    """
    Base class that specifies the interface shared by DN3 datasets.

    Subclasses are expected to implement `__getitem__`, `__len__`, `preprocess` and the `sfreq`, `channels` and
    `sequence_length` properties. This class supplies transform bookkeeping, cloning and conversion to numpy.
    """

    def __init__(self):
        """
        Start with no transforms and with safe mode switched off.
        """
        self._transforms = list()
        self._safe_mode = False
        # The "mutli" spelling is kept on purpose: these are existing attribute names.
        self._mutli_proc_start = None
        self._mutli_proc_end = None

    def __getitem__(self, item):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    @property
    def sfreq(self):
        """
        Returns
        -------
        sampling_frequency: float, list
                            The sampling frequencies employed by the dataset.
        """
        raise NotImplementedError

    @property
    def channels(self):
        """
        Returns
        -------
        channels: list
                  The channel sets used by the dataset.
        """
        raise NotImplementedError

    @property
    def sequence_length(self):
        """
        Returns
        -------
        sequence_length: int, list
                         The length of each instance in number of samples
        """
        raise NotImplementedError

    def clone(self):
        """
        A copy of this object to allow the repetition of recordings, thinkers, etc.

        Returns
        -------
        cloned : DN3ataset
                 A deep copy (`copy.deepcopy`) of this object.
        """
        return copy.deepcopy(self)

    def add_transform(self, transform):
        """
        Add a transformation that is applied to every fetched item in the dataset.

        Parameters
        ----------
        transform : InstanceTransform
                    Called on each item retrieved by `__getitem__`. Anything that is not an `InstanceTransform`
                    is silently ignored.
        """
        if isinstance(transform, InstanceTransform):
            self._transforms.append(transform)

    def _execute_transforms(self, *x):
        """
        Run every registered transform, in insertion order, over the tensors of a single item.

        Transforms flagged `only_trial_data` receive only the first tensor; whatever they return replaces it
        (a list/tuple result is spliced in front of the remaining tensors). Other transforms receive and
        return the full tuple.

        Raises
        ------
        DN3atasetNanFound
            In safe mode, when any tensor contains NaN after a transform has been applied.
        """
        for transform in self._transforms:
            assert isinstance(transform, InstanceTransform)
            if transform.only_trial_data:
                new_x = transform(x[0])
                if isinstance(new_x, (list, tuple)):
                    x = (*new_x, *x[1:])
                else:
                    x = (new_x, *x[1:])
            else:
                x = transform(*x)

            if self._safe_mode:
                for i, tensor in enumerate(x):
                    if torch.any(torch.isnan(tensor)):
                        # Report the transform that produced the NaN (previously the dataset was reported).
                        raise DN3atasetNanFound("NaN generated by transform {} for {}'th tensor".format(
                            transform, i))
        return x

    def clear_transforms(self):
        """
        Remove all added transforms from dataset.
        """
        self._transforms = list()

    def preprocess(self, preprocessor: Preprocessor, apply_transform=True):
        """
        Applies a preprocessor to the dataset

        Parameters
        ----------
        preprocessor : Preprocessor
                       A preprocessor to be applied
        apply_transform : bool
                          Whether to apply the transform to this dataset (and all members e.g thinkers or sessions)
                          after preprocessing them. Alternatively, the preprocessor is returned for manual
                          application of its transform through :meth:`Preprocessor.get_transform()`

        Returns
        ---------
        processed_data : ndarray
                         Data that has been modified by the preprocessor, should be in the shape of [*, C, T],
                         with C and T according with the `channels` and `sequence_length` properties respectively.
        """
        raise NotImplementedError

    def to_numpy(self, batch_size=64, batch_transforms: list = None, num_workers=4, **dataloader_kwargs):
        """
        Commits the dataset to numpy-formatted arrays. Useful for saving dataset to disk, or preparing for tools
        that expect numpy-formatted data rather than an iterable.

        Notes
        -----
        A pytorch :any:`DataLoader` is used to fetch the data, which conveniently provides batching and
        multiprocessing.

        Parameters
        ----------
        batch_size: int
                   The number of items per fetched batch. Ignored if `batch_size` is given in
                   `dataloader_kwargs`.
        num_workers: int
                     The number of loader worker processes. Ignored if given in `dataloader_kwargs`.
        batch_transforms: list
                          Callables applied, in order, to every batch before it is converted to numpy.
        dataloader_kwargs: dict
                          Keyword arguments for the pytorch :any:`DataLoader` that underpins the fetched data.
                          `shuffle` and `drop_last` default to False.

        Returns
        -------
        data: list
              A list of numpy arrays, one per tensor in a batch, concatenated along the first axis. `None` if
              the loader produced no batches.
        """
        dataloader_kwargs.setdefault('batch_size', batch_size)
        dataloader_kwargs.setdefault('num_workers', num_workers)
        dataloader_kwargs.setdefault('shuffle', False)
        dataloader_kwargs.setdefault('drop_last', False)

        batch_transforms = list() if batch_transforms is None else batch_transforms

        # One list of chunks per returned tensor; concatenated once at the end rather than after every batch,
        # which avoids re-copying everything loaded so far on each iteration.
        collected = None
        loader = DataLoader(self, **dataloader_kwargs)
        for batch in tqdm.tqdm(loader, desc="Loading Batches"):
            for xform in batch_transforms:
                assert callable(xform)
                batch = xform(batch)
            # cpu just to be certain, shouldn't affect things otherwise
            batch = [b.cpu().numpy() for b in batch]
            if collected is None:
                collected = [[b] for b in batch]
            else:
                for chunks, b in zip(collected, batch):
                    chunks.append(b)

        if collected is None:
            return None
        return [np.concatenate(chunks, axis=0) for chunks in collected]


class _Recording(DN3ataset, ABC):
    """
    Abstract base class for any supported recording
    """
    def __init__(self, info, session_id, person_id, tlen, ch_ind_picks=None):
        """
        Parameters
        ----------
        info : mapping
               Read for `info['chs']` (entries with `'ch_name'` and `'kind'`) and `info['sfreq']`.
        session_id, person_id
               Identifiers stored on the recording.
        tlen : float
               Multiplied by the sampling frequency to obtain the per-item length in samples.
        ch_ind_picks : list, None
               Channel indices to keep. Defaults to every channel in `info['chs']`.
        """
        super().__init__()
        self.info = info
        self.picks = ch_ind_picks if ch_ind_picks is not None else list(range(len(info['chs'])))
        self._recording_channels = [(ch['ch_name'], int(ch['kind'])) for idx, ch in enumerate(info['chs'])
                                    if idx in self.picks]
        self._recording_sfreq = info['sfreq']
        # Validate before the frequency is used; the check previously came after the multiplication below.
        assert self._recording_sfreq is not None
        self._recording_len = int(self._recording_sfreq * tlen)
        self.session_id = session_id
        self.person_id = person_id

    def get_all(self):
        """
        Fetch every item and stack them, returning one stacked tensor per position in an item.
        """
        all_recordings = [self[i] for i in range(len(self))]
        return [torch.stack(t) for t in zip(*all_recordings)]

    @property
    def sfreq(self):
        """Sampling frequency after each transform's `new_sfreq` has been applied in order."""
        sfreq = self._recording_sfreq
        for xform in self._transforms:
            sfreq = xform.new_sfreq(sfreq)
        return sfreq

    @property
    def channels(self):
        """Array of (name, kind) channel entries after each transform's `new_channels` has been applied."""
        channels = np.array(self._recording_channels)
        for xform in self._transforms:
            channels = xform.new_channels(channels)
        return channels

    @property
    def sequence_length(self):
        """Item length in samples after each transform's `new_sequence_length` has been applied."""
        sequence_length = self._recording_len
        for xform in self._transforms:
            sequence_length = xform.new_sequence_length(sequence_length)
        return sequence_length


class RawTorchRecording(_Recording):
    """
    Interface for bridging mne Raw instances as PyTorch compatible "Dataset".

    Parameters
    ----------
    raw : mne.io.Raw
          Raw data, data does not need to be preloaded.
    tlen : float
          Length, in seconds, of each retrieved portion of the recording.
    session_id : (int, str, optional)
          A unique (with respect to a thinker within an eventual dataset) identifier for the current recording
          session. If not specified, defaults to 0.
    person_id : (int, str, optional)
          A unique (with respect to an eventual dataset) identifier for the particular person being recorded.
    stride : int
          The number of (decimated) samples between the starting offsets of consecutive items.
    ch_ind_picks : list[int]
          A list of channel indices that have been selected for.
    decimate : int
          Take every decimate'th sample.
    bad_spans: List[tuple], None
          Tuples of (start_seconds, end_seconds) of times that should be avoided. Any sequences that would
          overlap with these sections are excluded.
    kwargs
          Stored as attributes on the instance; `max` and `min` are read for the range check in
          `__getitem__`.
    """

    def __init__(self, raw: mne.io.Raw, tlen, session_id=0, person_id=0, stride=1, ch_ind_picks=None, decimate=1,
                 bad_spans=None, **kwargs):
        """
        See the class documentation for a description of the parameters.
        """
        super().__init__(raw.info, session_id, person_id, tlen, ch_ind_picks)
        self.filename = raw.filenames[0]
        self.decimate = int(decimate)
        self._recording_sfreq /= self.decimate
        self._recording_len = int(tlen * self._recording_sfreq)
        self.stride = stride
        # Implement my own (rather than mne's) in-memory buffer when there are savings
        self._stride_load = self.decimate > 1 and raw.preload
        self.max = kwargs.get('max', None)
        self.min = kwargs.get('min', 0)
        bad_spans = list() if bad_spans is None else bad_spans
        self.__dict__.update(kwargs)

        # Start offsets are expressed in *decimated* samples.
        self._decimated_sequence_starts = list(
            range(0, raw.n_times // self.decimate - self._recording_len, self.stride)
        )
        for start, stop in bad_spans:
            start = int(self._recording_sfreq * start)
            stop = int(stop * self._recording_sfreq)
            # Drop every window that touches the bad span. This also catches windows that fully contain the
            # span, which the previous start/end-point test let through.
            self._decimated_sequence_starts = [
                span_start for span_start in self._decimated_sequence_starts
                if not (span_start < stop and span_start + self._recording_len >= start)
            ]

        # When the stride is greater than the sequence length, preload savings can be found by chopping the
        # sequence into subsequences of length: sequence length. Also, if decimating, can significantly reduce
        # memory requirements not otherwise addressed with the Raw object.
        if self._stride_load and len(self._decimated_sequence_starts) > 0:
            x = raw.get_data(self.picks)
            # pre-decimate this data for more preload savings (and for the stride factors to be valid)
            x = x[:, ::self.decimate]
            self._x = np.empty([x.shape[0], self._recording_len, len(self._decimated_sequence_starts)],
                               dtype=x.dtype)
            for i, start in enumerate(self._decimated_sequence_starts):
                self._x[..., i] = x[:, start:start + self._recording_len]
        else:
            self._raw_workaround(raw)

    def _raw_workaround(self, raw):
        """Keep a reference to the Raw object so items can be read from it on demand."""
        self.raw = raw

    def __getitem__(self, index):
        """
        Fetch one window as a float tensor and pass it through the registered transforms.

        Negative indices count from the end. Windows containing NaN are replaced by random values of the same
        shape (a message is printed).
        """
        if index < 0:
            index += len(self)

        if self._stride_load:
            x = self._x[:, :, index]
        else:
            # Stored starts are in decimated samples; the Raw object is indexed in raw samples.
            start = self._decimated_sequence_starts[index] * self.decimate
            x = self.raw.get_data(self.picks, start=start, stop=start + self._recording_len * self.decimate)
            if self.decimate > 1:
                x = x[:, ::self.decimate]

        # Only a warning: the data itself is not rescaled here.
        scale = 1 if self.max is None else (x.max() - x.min()) / (self.max - self.min)
        if scale > 1 or np.isnan(scale):
            print('Warning: scale exeeding 1')

        x = torch.from_numpy(x).float()

        if torch.any(torch.isnan(x)):
            print("Nan found: raw {}, index {}".format(self.filename, index))
            print("Replacing with random values with same shape for now...")
            x = torch.rand_like(x)

        return self._execute_transforms(x)

    def __len__(self):
        return len(self._decimated_sequence_starts)

    def preprocess(self, preprocessor: Preprocessor, apply_transform=True, **kwargs):
        """
        Apply `preprocessor` to this recording.

        Parameters
        ----------
        preprocessor : Preprocessor
                       Called with `recording=self` plus any extra keyword arguments.
        apply_transform : bool
                          Whether to add `preprocessor.get_transform()` to this recording afterwards.
        kwargs
                          Forwarded to the preprocessor, in the same way `EpochTorchRecording.preprocess` does.
                          Accepting them lets `Thinker.preprocess` pass `session_id` without a TypeError.
        """
        processed = preprocessor(recording=self, **kwargs)
        # Keep the existing Raw object when the preprocessor does not hand back a replacement.
        if processed is not None:
            self.raw = processed
        if apply_transform:
            self.add_transform(preprocessor.get_transform())


class EpochTorchRecording(_Recording):
    """
    Wraps :any:`mne.Epochs` instances so that they conform to the :any:`Recording` API.
    """

    def __init__(self, epochs: mne.Epochs, session_id=0, person_id=0, force_label=None, cached=False,
                 ch_ind_picks=None, event_mapping=None, skip_epochs=None):
        """
        Parameters
        ----------
        epochs : mne.Epochs
                 The epochs to wrap.
        session_id, person_id
                 Identifiers stored on the recording.
        force_label : bool, Optional
                 If truthy, `event_mapping` is used as-is to look up the label of an epoch from its event code.
                 Otherwise (default) the event codes are sorted and relabelled 0 -> N-1.
        cached : bool
                 Whether to keep each loaded epoch tensor in memory after first access.
        ch_ind_picks : list, None
                 Channel indices to keep.
        event_mapping : dict, Optional
                 Mapping of human-readable names to numeric codes used by `epochs`. Defaults to
                 `epochs.event_id`.
        skip_epochs: List[int]
                 A list of epochs to skip
        """
        super().__init__(epochs.info, session_id, person_id, epochs.tmax - epochs.tmin + 1 / epochs.info['sfreq'],
                         ch_ind_picks)
        self.epochs = epochs
        # TODO scrap this cache option, it seems utterly redundant now
        self._cache = [None for _ in range(len(epochs.events))] if cached else None
        if event_mapping is None:
            # mne parses this for us
            event_mapping = epochs.event_id
        if force_label:
            self.epoch_codes_to_class_labels = event_mapping
        else:
            reverse_mapping = {v: k for k, v in event_mapping.items()}
            self.epoch_codes_to_class_labels = {v: i for i, v in enumerate(sorted(reverse_mapping.keys()))}
        skipped = set() if skip_epochs is None else set(skip_epochs)
        kept = [i for i in range(len(self.epochs.events)) if i not in skipped]
        # Maps the index exposed by this dataset -> index into `self.epochs`.
        self._skip_map = dict(enumerate(kept))

    def __getitem__(self, index):
        """
        Return the transformed (data, label) pair for the `index`'th non-skipped epoch.

        If an epoch cannot be loaded with the expected 3-dimensional, non-empty shape, the previous item of
        this dataset is returned in its place (a message is printed).
        """
        epoch_index = self._skip_map[index]
        ep = self.epochs[epoch_index]

        if self._cache is None or self._cache[epoch_index] is None:
            # TODO Could have a speedup if not using ep, but items, but would need to drop bads?
            x = ep.get_data(picks=self.picks)
            if len(x.shape) != 3 or 0 in x.shape:
                print("I don't know why: {} index{}/{}".format(self.epochs.filename, epoch_index, len(self)))
                print(self.epochs.info['description'])
                print("Using trial {} in place for now...".format(index - 1))
                # Recurse with the dataset-level index; the mapped epoch index must not go through
                # `_skip_map` a second time.
                return self.__getitem__(index - 1)
            x = torch.from_numpy(x.squeeze(0)).float()
            if self._cache is not None:
                self._cache[epoch_index] = x
        else:
            x = self._cache[epoch_index]

        y = torch.tensor(self.epoch_codes_to_class_labels[ep.events[0, -1]]).squeeze().long()

        return self._execute_transforms(x, y)

    def __len__(self):
        return len(self._skip_map)

    def preprocess(self, preprocessor: Preprocessor, apply_transform=True, **kwargs):
        """
        Apply `preprocessor` to this recording, forwarding `kwargs` to it.

        If the preprocessor returns something other than None, its rows (along the first axis) replace the
        cache. The preprocessor's transform is then added when `apply_transform` is set.
        """
        processed = preprocessor(self, **kwargs)
        if processed is not None:
            self._cache = [processed[i] for i in range(processed.shape[0])]
        if apply_transform:
            self.add_transform(preprocessor.get_transform())

    def event_mapping(self):
        """
        The lookup used to turn the event code of an epoch into the label returned by this object.

        Returns
        -------
        mapping : dict
                  Keys are the event codes looked up for each epoch, values are the labels this object returns.
        """
        return self.epoch_codes_to_class_labels

    def get_targets(self):
        """
        Labels of all non-skipped epochs, in dataset order.

        Returns
        -------
        targets : np.ndarray
                  Always one-dimensional, also for a single epoch, so results can be concatenated.
        """
        codes = self.epochs.events[list(self._skip_map.values()), -1]
        return np.array([self.epoch_codes_to_class_labels[code] for code in codes])


class Thinker(DN3ataset, ConcatDataset):
    """
    Collects multiple recordings of the same person, intended to be of the same task, at different times or conditions.
    """

    def __init__(self, sessions, person_id="auto", return_session_id=False, return_trial_id=False,
                 propagate_kwargs=False):
        """
        Collects multiple recordings of the same person, intended to be of the same task, at different times or
        conditions.

        Parameters
        ----------
        sessions : Iterable, dict
                   Either a sequence of recordings, or a mapping of session_ids to recordings. If the former, the
                   recording's session_id is preserved. If the latter, each recording's session_id is overwritten
                   with its key in the mapping.
        person_id : int, str
                    Label to be used for the thinker. If set to "auto" (default), will automatically pick the
                    person_id using the most common person_id in the recordings. Every session's person_id is
                    overwritten with the result.
        return_session_id : bool
                           Whether to return the (enumerated) session index with the data itself.
        return_trial_id : bool
                           Whether to return the index of the item within its session with the data itself.
        propagate_kwargs : bool
                           Accepted for interface compatibility; not used by this implementation.
        """
        DN3ataset.__init__(self)
        if not isinstance(sessions, dict) and isinstance(sessions, Iterable):
            self.sessions = OrderedDict()
            for r in sessions:
                self.__add__(r)
        elif isinstance(sessions, dict):
            self.sessions = OrderedDict(sessions)
        else:
            raise TypeError("Recordings must be iterable or already processed dict.")
        if person_id == 'auto':
            ids = [sess.person_id for sess in self.sessions.values()]
            person_id = max(set(ids), key=ids.count)
        self.person_id = person_id

        for sess in self.sessions.values():
            sess.person_id = person_id

        self._reset_dataset()
        self.return_session_id = return_session_id
        self.return_trial_id = return_trial_id

    def _reset_dataset(self):
        """Re-stamp session ids from the mapping keys and rebuild the underlying ConcatDataset."""
        for _id in self.sessions:
            self.sessions[_id].session_id = _id
        ConcatDataset.__init__(self, self.sessions.values())

    def __str__(self):
        return "Person {} - {} trials | {} transforms".format(self.person_id, len(self), len(self._transforms))

    @property
    def sfreq(self):
        """Common sampling frequency of the sessions (after thinker transforms), or all of them if they differ."""
        sfreq = set(self.sessions[s].sfreq for s in self.sessions)
        if len(sfreq) > 1:
            print("Warning: Multiple sampling frequency values found. Over/re-sampling may be necessary.")
            return unfurl(sfreq)
        sfreq = sfreq.pop()
        for xform in self._transforms:
            sfreq = xform.new_sfreq(sfreq)
        return sfreq

    @property
    def channels(self):
        """Channel set shared by all sessions (after thinker transforms); raises ValueError if they differ."""
        channels = [self.sessions[s].channels for s in self.sessions]
        if not same_channel_sets(channels):
            raise ValueError("Multiple channel sets found. A consistent mapping like Deep1010 is necessary to proceed.")
        channels = channels.pop()
        for xform in self._transforms:
            channels = xform.new_channels(channels)
        return channels

    @property
    def sequence_length(self):
        """Common sequence length of the sessions (after thinker transforms), or all of them if they differ."""
        sequence_length = set(self.sessions[s].sequence_length for s in self.sessions)
        if len(sequence_length) > 1:
            print("Warning: Multiple sequence lengths found. A cropping transformation may be in order.")
            return unfurl(sequence_length)
        sequence_length = sequence_length.pop()
        for xform in self._transforms:
            sequence_length = xform.new_sequence_length(sequence_length)
        return sequence_length

    def __add__(self, sessions):
        """
        Add a recording, or every session of another Thinker, to this thinker (in place).

        A session whose id already exists here is combined with the existing one using `+=`.

        Returns
        -------
        self : Thinker
               Returned so that `thinker += other` keeps referring to this thinker.
        """
        assert isinstance(sessions, (_Recording, Thinker))
        if isinstance(sessions, Thinker):
            # person_id is not set yet while __init__ is still collecting its sessions
            own_id = getattr(self, 'person_id', None)
            if own_id is not None and sessions.person_id != own_id:
                print("Person IDs don't match: adding {} to {}. Assuming latter...".format(
                    sessions.person_id, own_id))
            incoming = list(sessions.sessions.values())
        else:
            incoming = [sessions]

        for recording in incoming:
            if recording.session_id in self.sessions.keys():
                self.sessions[recording.session_id] += recording
            else:
                self.sessions[recording.session_id] = recording

        self._reset_dataset()
        return self

    def pop_session(self, session_id, apply_thinker_transform=True):
        """
        Remove and return a session, optionally adding this thinker's transforms to it first.
        """
        assert session_id in self.sessions.keys()
        sess = self.sessions.pop(session_id)
        if apply_thinker_transform:
            for xform in self._transforms:
                sess.add_transform(xform)
        self._reset_dataset()
        return sess

    def __getitem__(self, item, return_id=False):
        """
        Fetch an item from the concatenated sessions, optionally inserting session and trial indices.

        The ids are inserted after the first tensor, or after the second one when that is a boolean tensor,
        in the order session index, trial index. `return_id` is accepted but unused.
        """
        x = list(ConcatDataset.__getitem__(self, item))
        if item < 0:
            # ConcatDataset already validated the index; normalise it so the id bookkeeping below is correct
            item += len(self)
        session_idx = bisect.bisect_right(self.cumulative_sizes, item)
        idx_offset = 2 if len(x) > 1 and x[1].dtype == torch.bool else 1
        if self.return_trial_id:
            trial_id = item if session_idx == 0 else item - self.cumulative_sizes[session_idx-1]
            x.insert(idx_offset, torch.tensor(trial_id).long())
        if self.return_session_id:
            x.insert(idx_offset, torch.tensor(session_idx).long())
        return self._execute_transforms(*x)

    def __len__(self):
        return ConcatDataset.__len__(self)

    def _make_like_me(self, sessions: Iterable):
        """
        Build a new Thinker from `sessions` (ids of this thinker's sessions, or an id -> session mapping) that
        shares this thinker's person_id, id-return flags and transforms.
        """
        if not isinstance(sessions, dict):
            sessions = {s: self.sessions[s] for s in sessions}
        like_me = Thinker(sessions, self.person_id, self.return_session_id,
                          return_trial_id=self.return_trial_id)
        for x in self._transforms:
            like_me.add_transform(x)
        return like_me

    def split(self, training_sess_ids=None, validation_sess_ids=None, testing_sess_ids=None, test_frac=0.25,
              validation_frac=0.25):
        """
        Split the thinker's data into training, validation and testing sets.

        Parameters
        ----------
        test_frac : float
                    Proportion of the total data to use for testing, this is overridden by `testing_sess_ids`.
        validation_frac : float
                          Proportion of the data remaining - after removing test proportion/sessions - to use as
                          validation data. Likewise, `validation_sess_ids` overrides this value.
        training_sess_ids : : (Iterable, None)
                            The session ids to be explicitly used for training.
        validation_sess_ids : (Iterable, None)
                            The session ids to be explicitly used for validation.
        testing_sess_ids : (Iterable, None)
                           The session ids to be explicitly used for testing.

        Returns
        -------
        training : DN3ataset
                   The training dataset
        validation : DN3ataset
                   The validation dataset
        testing : DN3ataset
                   The testing dataset

        Notes
        -----
        An id listed for more than one split is reported and the session is shared by those splits. A split
        is `None` when nothing is assigned or left over for it.
        """
        training_sess_ids = set(training_sess_ids) if training_sess_ids is not None else set()
        validation_sess_ids = set(validation_sess_ids) if validation_sess_ids is not None else set()
        testing_sess_ids = set(testing_sess_ids) if testing_sess_ids is not None else set()
        # Pairwise check: an id present in any two of the three sets is a duplicate
        duplicated_ids = (training_sess_ids & validation_sess_ids) \
            | (training_sess_ids & testing_sess_ids) \
            | (validation_sess_ids & testing_sess_ids)
        if len(duplicated_ids) > 0:
            print("Ids duplicated across train/val/test split: {}".format(duplicated_ids))
        use_sessions = self.sessions.copy()
        explicit = list()
        for ids in (training_sess_ids, validation_sess_ids, testing_sess_ids):
            if len(ids):
                chosen = {s_id: self.sessions[s_id] for s_id in ids}
                for s_id in ids:
                    # tolerate ids already claimed by an earlier split
                    use_sessions.pop(s_id, None)
                explicit.append(self._make_like_me(chosen))
            else:
                explicit.append(None)
        training, validating, testing = explicit
        if training is not None and validating is not None and testing is not None:
            if len(use_sessions) > 0:
                print("Warning: sessions specified do not span all sessions. Skipping {} sessions.".format(
                    len(use_sessions)))
            return training, validating, testing

        # Split up the rest if there is anything left
        remainder = None
        if len(use_sessions) > 0:
            remainder = self._make_like_me(use_sessions.keys())
            # NOTE(review): the two rand_split calls below unpack their results in opposite orders; which
            # element carries the `frac` proportion is not visible here - confirm against rand_split.
            if testing is None:
                assert test_frac is not None and 0 < test_frac < 1
                remainder, testing = rand_split(remainder, frac=test_frac)
            if validating is None:
                assert validation_frac is not None and 0 <= validation_frac < 1
                if validation_frac > 0:
                    validating, remainder = rand_split(remainder, frac=validation_frac)

        training = remainder if training is None else training

        return training, validating, testing

    def preprocess(self, preprocessor: Preprocessor, apply_transform=True, sessions=None, **kwargs):
        """
        Applies a preprocessor to every session of this thinker.

        Parameters
        ----------
        preprocessor : Preprocessor
                       A preprocessor to be applied
        sessions : (None, Iterable)
                   Accepted for interface compatibility; currently every session is preprocessed regardless
                   of this value.
        apply_transform : bool
                          Passed on to each session's `preprocess`.
        kwargs
                          Passed on to each session's `preprocess`, together with `session_id` set to the
                          enumerated position of the session.

        Returns
        ---------
        preprocessor : Preprocessor
                       The preprocessor after application to all sessions
        """
        for sid, session in enumerate(self.sessions.values()):
            session.preprocess(preprocessor, session_id=sid, apply_transform=apply_transform, **kwargs)
        return preprocessor

    def clear_transforms(self, deep_clear=False):
        """Remove this thinker's transforms and, if `deep_clear`, those of every session too."""
        self._transforms = list()
        if deep_clear:
            for s in self.sessions.values():
                s.clear_transforms()

    def add_transform(self, transform, deep=False):
        """Add `transform` to every session if `deep`, otherwise to this thinker only."""
        if deep:
            for s in self.sessions.values():
                s.add_transform(transform)
        else:
            self._transforms.append(transform)

    def get_targets(self):
        """
        Collect all the targets (i.e. labels) that this Thinker's data is annotated with.

        Returns
        -------
        targets: np.ndarray
                 A numpy-formatted array of all the targets/label for this thinker, or None if no session
                 provides a `get_targets` method.
        """
        targets = list()
        for sess in self.sessions:
            if hasattr(self.sessions[sess], 'get_targets'):
                targets.append(self.sessions[sess].get_targets())
        if len(targets) == 0:
            return None
        return np.concatenate(targets)


class DatasetInfo(object):
    """
    Container for non-critical meta-data that may need to be tracked for :py:`Dataset` objects. Usually created
    by the configuratron rather than by hand, so that transforms and/or other downstream processes can be
    generated automatically.
    """
    def __init__(self, dataset_name, data_max=None, data_min=None, excluded_people=None, targets=None):
        # Plain attribute storage; nothing is validated or converted.
        self.dataset_name = dataset_name
        self.data_max = data_max
        self.data_min = data_min
        self.excluded_people = excluded_people
        self.targets = targets

    def __str__(self):
        summary = "{} | {} targets | Excluding {}"
        return summary.format(self.dataset_name, self.targets, self.excluded_people)


class Dataset(DN3ataset, ConcatDataset):
 """
 Collects thinkers, each of which may collect multiple recording sessions of the same tasks, into a dataset with
 (largely) consistent:
 - hardware:
 - channel number/labels
 - sampling frequency
 - annotation paradigm:
 - consistent event types
 """
 def __init__(self, thinkers, dataset_id=None, task_id=None, return_trial_id=False, return_session_id=False,
              return_person_id=False, return_dataset_id=False, return_task_id=False, dataset_info=None):
     """
     Collects recordings from multiple people, intended to be of the same task, at different times or
     conditions.
     Optionally, can specify whether to return person, session, dataset and task labels. Person and session ids
     will be converted to an enumerated set of integer ids, rather than those provided during creation of those
     datasets in order to make a minimal set of labels. e.g. if there are 3 thinkers, {A01, A02, and A05},
     specifying `return_person_id` will return an additional tensor with 0 for A01, 1 for A02 and 2 for A05
     respectively. To recover any original identifier, get_thinkers() returns a list of the original thinker ids
     such that the enumerated offset recovers the original identity. Extending the example above:
     ``self.get_thinkers()[1] == "A02"``

     .. warning:: The enumerated ids above are only ever used in the construction of model input tensors,
                  otherwise, anywhere where ids are required as API, the *human readable* version is used
                  (e.g. in our example above A02)

     Parameters
     ----------
     thinkers : Iterable, dict
                Either a sequence of `Thinker`, or a mapping of person_id to `Thinker`. If the latter, id's are
                overwritten by these id's. Thinkers are stored sorted by id.
     dataset_id : int
                  An identifier associated with data from the entire dataset. Unlike person and sessions, this
                  should simply be an integer for the sake of returning labels that can functionally be used
                  for learning.
     task_id : int
               An identifier associated with data from the entire dataset, and potentially others of the same
               task. Like dataset_id, this should simply be an integer.
     return_person_id : bool
                        Whether to return (enumerated - see above) person_ids with the data itself.
     return_session_id : bool
                        Whether to return (enumerated - see above) session_ids with the data itself.
     return_dataset_id : bool
                        Whether to return the dataset_id with the data itself.
     return_task_id : bool
                        Whether to return the task_id with the data itself.
     return_trial_id: bool
                     Whether to return the id of the trial (within the session)
     dataset_info : DatasetInfo, Optional
                    Additional, non-critical data that helps specify additional features of the dataset.

     Notes
     -----------
     When getting items from a dataset, the id return order is returned most general to most specific, wrapped
     by the actual raw data and (optionally, if epoch-variety recordings) the label for the raw data, thus:
     raw_data, task_id, dataset_id, person_id, session_id, *label
     """
     super().__init__()
     self.info = dataset_info

     if not isinstance(thinkers, Iterable):
         raise ValueError("Provided thinkers must be in an iterable container, e.g. list, tuple, dicts")

     # Overwrite thinker ids with those provided as dict argument and sort by ids
     if not isinstance(thinkers, dict):
         thinkers = {t.person_id: t for t in thinkers}

     self.thinkers = OrderedDict()
     for t in sorted(thinkers.keys()):
         self.__add__(thinkers[t], person_id=t, return_session_id=return_session_id,
                      return_trial_id=return_trial_id)
     self._reset_dataset()

     self.dataset_id = torch.tensor(dataset_id).long() if dataset_id is not None else None
     self.task_id = torch.tensor(task_id).long() if task_id is not None else None
     # Keywords are required here: update_id_returns takes (trial, session, person, task, dataset), so the
     # former positional call swapped the dataset and task flags.
     self.update_id_returns(trial=return_trial_id, session=return_session_id, person=return_person_id,
                            dataset=return_dataset_id, task=return_task_id)

 def update_id_returns(self, trial=None, session=None, person=None, task=None, dataset=None):
     """
     Change which ids accompany the data returned by the dataset. An argument left as `None` keeps the
     currently stored setting. The trial and session settings are also copied onto every thinker.

     Parameters
     ----------
     trial : None, bool
             Whether to return trial ids.
     session : None, bool
               Whether to return session ids.
     person : None, bool
              Whether to return person ids.
     task : None, bool
            Whether to return task ids.
     dataset : None, bool
               Whether to return dataset ids.
     """
     requested = (
         ('return_trial_id', trial),
         ('return_session_id', session),
         ('return_person_id', person),
         ('return_dataset_id', dataset),
         ('return_task_id', task),
     )
     for attr_name, new_value in requested:
         current = getattr(self, attr_name) if new_value is None else new_value
         setattr(self, attr_name, current)

     def push_flags_down(th_id, thinker: Thinker):
         # Thinkers only know about trial and session ids
         thinker.return_trial_id = self.return_trial_id
         thinker.return_session_id = self.return_session_id

     self._apply(push_flags_down)

 def _reset_dataset(self):
 for p_id in self.thinkers:
 self.thinkers[p_id].person_id = p_id
 for s_id in self.thinkers[p_id].sessions:
 self.thinkers[p_id].sessions[s_id].session_id = s_id
 self.thinkers[p_id].sessions[s_id].person_id = p_id
 ConcatDataset.__init__(self, self.thinkers.values())

 def _apply(self, lam_fn):
 for th_id, thinker in self.thinkers.items():
 lam_fn(th_id, thinker)

 def __str__(self):
     """One-line summary: name, dataset id, people, trials, channels, samples per trial, rate, transforms."""
     if self.info is None:
         ds_name = "Dataset-{}".format(self.dataset_id)
     else:
         ds_name = self.info.dataset_name
     template = ">> {} | DSID: {} | {} people | {} trials | {} channels | {} samples/trial | {}Hz | {} transforms"
     return template.format(ds_name, self.dataset_id, len(self.get_thinkers()), len(self), len(self.channels),
                            self.sequence_length, self.sfreq, len(self._transforms))

 def __add__(self, thinker, person_id=None, return_session_id=None, return_trial_id=None):
     """
     Add a thinker to this dataset (in place), merging its sessions into an existing thinker with the same
     person id.

     Parameters
     ----------
     thinker : Thinker
               The thinker to add.
     person_id : optional
               If given, overwrites the thinker's person_id before it is added.
     return_session_id, return_trial_id : None, bool
               Flags set on the added thinker; `None` uses this dataset's current setting.

     Returns
     -------
     self : Dataset
     """
     assert isinstance(thinker, Thinker)
     return_session_id = self.return_session_id if return_session_id is None else return_session_id
     return_trial_id = self.return_trial_id if return_trial_id is None else return_trial_id
     thinker.return_session_id = return_session_id
     thinker.return_trial_id = return_trial_id
     if person_id is not None:
         thinker.person_id = person_id

     if thinker.person_id in self.thinkers.keys():
         print("Warning. Person {} already in dataset... Merging sessions.".format(thinker.person_id))
         existing = self.thinkers[thinker.person_id]
         # Merge session by session and keep the existing thinker object in place. The previous `+=`
         # rebound the entry to whatever Thinker.__add__ returned.
         for session in list(thinker.sessions.values()):
             existing.__add__(session)
     else:
         self.thinkers[thinker.person_id] = thinker
     self._reset_dataset()
     return self

 def pop_thinker(self, person_id, apply_ds_transforms=False):
 assert person_id in self.get_thinkers()
 thinker = self.thinkers.pop(person_id)
 if apply_ds_transforms:
 for xform in self._transforms:
 thinker.add_transform(xform)
 self._reset_dataset()
 return thinker

 def __getitem__(self, item):
     """
     Fetch an item from the thinker that owns index `item`, insert the requested ids and apply the dataset's
     transforms.

     Ids are inserted after the first tensor (or after the second, when that is a boolean mask) so that the
     resulting order is task_id, dataset_id, person_id, followed by whatever the thinker returned.

     Raises
     ------
     ValueError
         For a negative index whose absolute value exceeds the dataset length.
     DN3atasetNanFound
         In safe mode, if NaN is found in the fetched tensors or after the transforms.
     """
     if item < 0:
         # Resolve negative indices against the whole dataset; bisect would otherwise always pick the
         # first thinker.
         if -item > len(self):
             raise ValueError("absolute value of index should not exceed dataset length")
         item += len(self)
     person_id = bisect.bisect_right(self.cumulative_sizes, item)
     person = self.thinkers[self.get_thinkers()[person_id]]
     if person_id == 0:
         sample_idx = item
     else:
         sample_idx = item - self.cumulative_sizes[person_id - 1]
     x = list(person.__getitem__(sample_idx))

     if self._safe_mode:
         for i in range(len(x)):
             if torch.any(torch.isnan(x[i])):
                 raise DN3atasetNanFound("Nan found at tensor offset {}. "
                                         "Loading data from person {} and sample {}".format(i, person, sample_idx))

     # Skip deep1010 mask
     idx_offset = 2 if len(x) > 1 and x[1].dtype == torch.bool else 1
     if self.return_person_id:
         x.insert(idx_offset, torch.tensor(person_id).long())

     if self.return_dataset_id:
         x.insert(idx_offset, self.dataset_id)

     if self.return_task_id:
         x.insert(idx_offset, self.task_id)

     if self._safe_mode:
         try:
             return self._execute_transforms(*x)
         except DN3atasetNanFound as e:
             # Add where the data came from, keeping the original error as the cause
             raise DN3atasetNanFound(
                 "Nan found after transform | {} | from person {} and sample {}".format(
                     e.args, person, sample_idx)) from e

     return self._execute_transforms(*x)

 def safe_mode(self, mode=True):
 """
 This allows switching *safe_mode* on or off. When safe_mode is on, if data is ever NaN, it is captured
 before being returned and a report is generated.

 Parameters
 ----------
 mode : bool
 The status of whether in safe mode or not.
 """
 self._safe_mode = mode

 def preprocess(self, preprocessor: Preprocessor, apply_transform=True, thinkers=None):
     """
     Run a preprocessor over every thinker of the dataset.

     Each thinker's own `preprocess` is called with the same preprocessor object, together with the position
     of that thinker in the iteration order of `self.thinkers` as its `thinker_id`.

     Parameters
     ----------
     preprocessor : Preprocessor
                    The preprocessor handed to every thinker.
     apply_transform : bool
                       Forwarded unchanged to each thinker's `preprocess`.
     thinkers : (None, Iterable)
                Accepted for interface compatibility. NOTE(review): this argument is currently not used by the
                implementation; all thinkers are always preprocessed.

     Returns
     ---------
     preprocessor : Preprocessor
                    The same preprocessor object, after it was passed to all thinkers.
     """
     for thinker_index, thinker_name in enumerate(self.thinkers):
         self.thinkers[thinker_name].preprocess(preprocessor, thinker_id=thinker_index,
                                                apply_transform=apply_transform)
     return preprocessor

 @property
 def sfreq(self):
     """
     Returns
     -------
     sampling_frequency: float, list
                         When all thinkers share one sampling frequency, that value after passing it through
                         `new_sfreq` of each dataset-level transform in order. Otherwise a warning is printed
                         and the distinct values are returned through `unfurl` (no transform is consulted).
     """
     rates = {thinker.sfreq for thinker in self.thinkers.values()}
     if len(rates) > 1:
         print("Warning: Multiple sampling frequency values found. Over/re-sampling may be necessary.")
         return unfurl(rates)
     rate = rates.pop()
     for transform in self._transforms:
         rate = transform.new_sfreq(rate)
     return rate

 @property
 def channels(self):
     """
     Returns
     -------
     channels: list
               The channel set of the last thinker, after passing it through `new_channels` of each
               dataset-level transform in order.

     Raises
     ------
     ValueError
               If `same_channel_sets` reports that the thinkers do not share a channel set.
     """
     channel_sets = [thinker.channels for thinker in self.thinkers.values()]
     if not same_channel_sets(channel_sets):
         raise ValueError("Multiple channel sets found. A consistent mapping like Deep1010 is necessary to proceed.")
     resolved = channel_sets.pop()
     for transform in self._transforms:
         resolved = transform.new_channels(resolved)
     return resolved

 @property
 def sequence_length(self):
     """
     Returns
     -------
     sequence_length: int, list
                      When all thinkers share one sequence length, that value after passing it through
                      `new_sequence_length` of each dataset-level transform in order. Otherwise a warning is
                      printed and the distinct values are returned through `unfurl` (no transform is
                      consulted).
     """
     lengths = {thinker.sequence_length for thinker in self.thinkers.values()}
     if len(lengths) > 1:
         print("Warning: Multiple sequence lengths found. A cropping transformation may be in order.")
         return unfurl(lengths)
     length = lengths.pop()
     for transform in self._transforms:
         length = transform.new_sequence_length(length)
     return length

 def get_thinkers(self):
     """
     List the names of all thinkers in the dataset, in the iteration order of the `thinkers` mapping. This is
     the order used by the automatic segmenting of :py:meth:`loso()` and :py:meth:`lmso()`.

     Returns
     -------
     thinker_names : list
     """
     return [*self.thinkers]

 def get_sessions(self):
     """
     Collect the sessions of every thinker into a nested dictionary.

     Returns
     -------
     session_dict: dict
                   Keys are the thinker names of :py:meth:`get_thinkers()`; each value is a (shallow) copy of
                   that thinker's `sessions` mapping, from session ids to :any:`_Recording`
     """
     return {name: thinker.sessions.copy() for name, thinker in self.thinkers.items()}

 def __len__(self):
 self._reset_dataset()
 return self.cumulative_sizes[-1]

 def _make_like_me(self, people: list):
 if len(people) == 1:
 like_me = self.thinkers[people[0]].clone()
 else:
 dataset_id = self.dataset_id.item() if self.dataset_id is not None else None
 task_id = self.task_id.item() if self.task_id is not None else None

 like_me = Dataset({p: self.thinkers[p] for p in people}, dataset_id, task_id,
 return_person_id=self.return_person_id, return_session_id=self.return_session_id,
 return_dataset_id=self.return_dataset_id, return_task_id=self.return_task_id,
 return_trial_id=self.return_trial_id, dataset_info=self.info)
 for x in self._transforms:
 like_me.add_transform(x)
 return like_me

 def _generate_splits(self, validation, testing):
 for val, test in zip(validation, testing):
 training = list(self.thinkers.keys())
 for v in val:
 training.remove(v)
 for t in test:
 training.remove(t)

 training = self._make_like_me(training)

 validating = self._make_like_me(val)
 _val_set = set(validating.get_thinkers()) if len(val) > 1 else {validating.person_id}

 testing = self._make_like_me(test)
 _test_set = set(testing.get_thinkers()) if len(test) > 1 else {testing.person_id}

 if len(_val_set.intersection(_test_set)) > 0:
 raise ValueError("Validation and test overlap with ids: {}".format(_val_set.intersection(_test_set)))

 print('Training: {}'.format(training))
 print('Validation: {}'.format(validating))
 print('Test: {}'.format(testing))

 yield training, validating, testing

 def loso(self, validation_person_id=None, test_person_id=None):
     """
     This *generates* a "Leave-one-subject-out" (LOSO) split. By default, tests each person one-by-one, and
     validates on the previous (the first is validated with the last).

     Parameters
     ----------
     validation_person_id : (int, str, list, optional)
                            A single id is used as the validation thinker of every split. A *list* must have
                            the same length as the list of test ids and is paired with it position-wise. If
                            None, each test thinker is validated with the one preceding it in the test list.
     test_person_id : (int, str, list, optional)
                      A *list* yields one split per entry. A single id (int or str) combined with a single
                      validation id yields exactly one split; a single id combined with anything else ignores
                      `validation_person_id` and yields one split for each other thinker, used as validation.
                      If None, every thinker (except a single validation id, when given) is tested in turn.

     Yields
     -------
     training : Dataset
                Another dataset that represents the training set
     validation : Thinker
                  The validation thinker
     test : Thinker
            The test thinker

     Raises
     ------
     ValueError
            If the test ids are not a list, or the validation and test lists differ in length.
     """
     single_id = (str, int)

     if isinstance(test_person_id, single_id):
         if isinstance(validation_person_id, single_id):
             yield from self._generate_splits([[validation_person_id]], [[test_person_id]])
             return
         # One fixed test thinker (int ids are handled like str ids): validate on each of the others in turn
         others = [v for v in self.get_thinkers() if v != test_person_id]
         yield from self._generate_splits([[v] for v in others], [[test_person_id] for _ in others])
         return

     # Testing is now either a sequence or nothing. Should loop over everyone (unless validation is a single id)
     if test_person_id is None:
         if isinstance(validation_person_id, single_id):
             test_person_id = [t for t in self.get_thinkers() if t != validation_person_id]
         else:
             test_person_id = list(self.get_thinkers())

     if not isinstance(test_person_id, list):
         raise ValueError("Test ids must be same length iterable as validation ids.")

     if isinstance(validation_person_id, single_id):
         # A single validation id is reused for every test thinker
         validation_person_id = [validation_person_id for _ in test_person_id]
     elif validation_person_id is None:
         # Index -1 wraps around, so the first test thinker is validated with the last
         validation_person_id = [test_person_id[i - 1] for i in range(len(test_person_id))]

     if len(test_person_id) != len(validation_person_id):
         raise ValueError("Test ids must be same length iterable as validation ids.")

     yield from self._generate_splits([[v] for v in validation_person_id], [[t] for t in test_person_id])

 def lmso(self, folds=10, test_splits=None, validation_splits=None):
 """
 This *generates* a "Leave-multiple-subject-out" (LMSO) split. In other words X-fold cross-validation, with
 boundaries enforced at thinkers (each person's data is not split into different folds).

 Parameters
 ----------
 folds : int
 If this is specified and `splits` is None, will split the subjects into this many folds, and then use
 each fold as a test set in turn (and the previous fold - starting with the last - as validation).
 test_splits : list, tuple
 This should be a list of tuples/lists of either:
 - The ids of the consistent test set. In which case, folds must be specified, or validation_splits
 is a nested list that .
 - Two sub lists, first testing, second validation ids

 Yields
 -------
 training : Dataset
 Another dataset that represents the training set
 validation : Dataset
 The validation people as a dataset
 test : Thinker
 The test people as a dataset
 """

 def is_nested(split: list):
 should_be_nested = isinstance(split[0], (list, tuple))
 for x in split[1:]:
 if (should_be_nested and not isinstance(x, (list, tuple))) or (isinstance(x, (list, tuple))
 and not should_be_nested):
 raise ValueError("Can't mix list/tuple and other elements when specifying ids.")
 if not should_be_nested and folds is None:
 raise ValueError("Can't infer folds from non-nested list. Specify folds, or nest ids")
 return should_be_nested

 def calculate_from_remainder(known_split):
 _folds = len(known_split) if is_nested(list(known_split)) else folds
 if folds is None:
 print("Inferred {} folds from test split.".format(_folds))
 remainder = list(set(self.get_thinkers()).difference(known_split))
 return [list(x) for x in np.array_split(remainder, _folds)], [known_split for _ in range(_folds)]

 if test_splits is None and validation_splits is None:
 if folds is None:
 raise ValueError("Must specify <folds> if not specifying ids.")
 folds = [list(x) for x in np.array_split(self.get_thinkers(), folds)]
 test_splits, validation_splits = zip(*[(folds[i], folds[i-1]) for i in range(len(folds))])
 elif validation_splits is None:
 validation_splits, test_splits = calculate_from_remainder(test_splits)
 elif test_splits is None:
 test_splits, validation_splits = calculate_from_remainder(validation_splits)

 yield from self._generate_splits(validation_splits, test_splits)

 def add_transform(self, transform, deep=False):
     """
     Register a transform, either on the dataset itself or on its thinkers.

     Parameters
     ----------
     transform
                The transform to add.
     deep : bool
            If False (default), the transform is appended to this dataset's own list of transforms. If True,
            it is instead handed to `add_transform` of every thinker (with the same `deep` value) and this
            dataset's own list is left untouched.
     """
     if not deep:
         self._transforms.append(transform)
         return
     for thinker in self.thinkers.values():
         thinker.add_transform(transform, deep=deep)

 def clear_transforms(self, deep_clear=False):
     """
     Remove all transforms registered on this dataset.

     Parameters
     ----------
     deep_clear : bool
                  If True, `clear_transforms` is additionally called on every thinker with the same
                  `deep_clear` value. Default is False.
     """
     self._transforms = []
     if not deep_clear:
         return
     for thinker in self.thinkers.values():
         thinker.clear_transforms(deep_clear=deep_clear)

 def get_targets(self):
     """
     Collect all the targets (i.e. labels) of the thinkers in this dataset that provide a `get_targets` method.

     Returns
     -------
     targets: np.ndarray, None
              The concatenation of the targets of each such thinker, in the iteration order of the thinkers.
              None if no thinker provides targets, or if the targets cannot be concatenated.
     """
     collected = [thinker.get_targets() for thinker in self.thinkers.values() if hasattr(thinker, 'get_targets')]
     if not collected:
         return None
     try:
         merged = np.concatenate(collected)
     except ValueError:
         # Targets that cannot be joined into one array (e.g. mismatched shapes) are reported as absent
         return None
     return merged

 def dump_dataset(self, toplevel, compressed=True, apply_transforms=True, summary_file='dataset-dump.npz',
                  chunksize=100):
     """
     Dumps the dataset to the directory specified by toplevel, as one summary file plus one numbered file per
     chunk of `chunksize` consecutive instances (the last chunk may be smaller).

     Parameters
     ----------
     toplevel : str
                The toplevel location to dump the dataset to. This folder (and path) will be created if it does
                not exist.
     compressed : bool
                  Whether the chunk files are written compressed. The summary file is always compressed.
     apply_transforms: bool
                       Whether to apply the transforms while preparing the data to be saved. Only True is
                       currently supported.
     summary_file : str
                    Name of the summary file created inside `toplevel`.
     chunksize : int
                 Number of instances stored in each chunk file.

     Raises
     ------
     NotImplementedError
                 If `apply_transforms` is False.
     """
     if apply_transforms is False:
         raise NotImplementedError

     toplevel = Path(toplevel)
     toplevel.mkdir(exist_ok=True, parents=True)

     # Map every thinker to the dataset-wide indices of its instances
     thinker_indices = self.thinkers.copy()
     start = 0
     for name, thinker in self.thinkers.items():
         stop = start + len(thinker)
         thinker_indices[name] = np.arange(start, stop)
         start = stop

     # Evaluate the length once instead of on every loop iteration
     total = len(self)

     np.savez_compressed(toplevel / summary_file, version='0.0.1', sfreq=self.sfreq, channels=self.channels,
                         sequence_length=self.sequence_length, chunksize=chunksize, name=self.info.dataset_name,
                         thinkers=thinker_indices, real_length=total)

     # Ceiling division: `round` would silently drop a trailing partial chunk (e.g. 120 instances in chunks of
     # 100 gave a single file), losing the last instances of the dataset.
     num_chunks = (total + chunksize - 1) // chunksize
     save = np.savez_compressed if compressed else np.savez

     for i in tqdm.trange(num_chunks, desc="Saving", unit='files'):
         first = i * chunksize
         last = min(first + chunksize, total)
         instances = [self[j] for j in range(first, last)]
         # Stack element-wise: one array per returned element, with the instance as leading axis
         stacked = [torch.stack(z) for z in zip(*instances)]
         save(toplevel / str(i), *[t.numpy() for t in stacked])

# TODO Convenience functions or classes for leave one and leave multiple datasets out.


class DumpedDataset(DN3ataset):
    """
    Dataset that reads instances back from a directory holding a summary file and numbered chunk files, each
    chunk file storing the stacked elements of up to `chunksize` consecutive instances.
    """

    def __init__(self, toplevel, cache_all=False, summary_file='dataset-dump.npz', info=None, cache_chunk_factor=0.1):
        """
        Parameters
        ----------
        toplevel : str, Path
                   Directory that contains the dumped dataset. Must exist.
        cache_all : bool
                    Whether to keep every instance in memory once it has been loaded for the first time.
        summary_file : str
                       Name of the summary file inside `toplevel`.
        info
                       Optional dataset information; only its `dataset_name` is used here, for printing.
        cache_chunk_factor : float
                             Fraction of the chunk size stored as `_num_per_cache`.
        """
        super(DumpedDataset, self).__init__()
        self.toplevel = Path(toplevel)
        assert self.toplevel.exists()
        summary_file = self.toplevel / summary_file
        self.info = info
        # Read every field once and close the archive, instead of keeping the file open and re-reading it
        # from disk on each property access. A missing summary file raises here.
        # NOTE: allow_pickle is needed for the pickled `thinkers` mapping; only load dumps from trusted sources.
        with np.load(summary_file, allow_pickle=True) as summary:
            self._summary = {key: summary[key] for key in summary.files}
        self.thinkers = self._summary['thinkers'].flat[0]
        # Plain ints, so that index arithmetic and len() do not operate on 0-d arrays
        self._chunksize = int(self._summary['chunksize'])
        self._num_per_cache = int(cache_chunk_factor * self._chunksize)
        self._len = int(self._summary['real_length'])
        # Only the numbered chunk files are data; anything else in the directory is ignored
        self.filenames = sorted([f for f in self.toplevel.iterdir()
                                 if f.name != summary_file.name and f.suffix == '.npz' and f.stem.isdigit()],
                                key=lambda f: int(f.stem))

        # Per-instance caches, filled lazily: the first element of an instance and its remaining elements
        self.cache = [None for _ in range(self._len)] if cache_all else None
        self.aux_cache = [None for _ in range(self._len)] if cache_all else None

    def __str__(self):
        ds_name = self.info.dataset_name if self.info is not None else "Dumped"
        return ">> {} | {} people | {} trials | {} channels | {} samples/trial | {}Hz | {} transforms | ". \
            format(ds_name, len(self.thinkers), len(self), len(self.channels),
                   self.sequence_length, self.sfreq, len(self._transforms))

    def __len__(self):
        return self._len

    @property
    def sfreq(self):
        """The sampling frequency stored in the summary file."""
        return self._summary['sfreq']

    @property
    def channels(self):
        """The channels stored in the summary file."""
        return self._summary['channels']

    @property
    def sequence_length(self):
        """The sequence length stored in the summary file."""
        return self._summary['sequence_length']

    def get_thinkers(self):
        """List the names of the thinkers recorded in the summary file."""
        return list(self.thinkers.keys())

    def preprocess(self, preprocessor: Preprocessor, apply_transform=True):
        """Not supported for dumped datasets; always raises DN3atasetException."""
        raise DN3atasetException("Can't preprocess dumped dataset. Load from original files to do this.")

    def __getitem__(self, index):
        """
        Load one instance, from the in-memory cache when available and from its chunk file otherwise, and
        pass its elements through the transforms of this dataset.

        Raises
        ------
        IndexError
            If `index` is outside of the dataset.
        """
        if index < 0:
            # Negative indices count from the end, like for sequences
            index += self._len
        if not 0 <= index < self._len:
            raise IndexError("Index {} is out of range for a dataset of length {}".format(index, self._len))

        if self.aux_cache is not None and self.aux_cache[index] is not None:
            data = [self.cache[index], *self.aux_cache[index]]
        else:
            chunk_idx, offset = divmod(index, self._chunksize)
            # Close the archive as soon as the instance has been extracted
            with np.load(self.filenames[chunk_idx], allow_pickle=True) as chunk:
                data = [torch.from_numpy(chunk[f])[offset] for f in chunk.files]
            if self.aux_cache is not None:
                # Clone so that the cached tensors do not keep the whole chunk alive in memory
                data = [element.clone() for element in data]
                self.cache[index] = data[0]
                self.aux_cache[index] = data[1:]

        return self._execute_transforms(*data)

