from scipy import signal, stats#TODO remove import
import os
import numpy as np

from torch.utils import data
from .trial_data import TrialData
from .trial_data_reader import TrialDataReader
from typing import Optional, List, Dict, Any, Tuple
import pandas as pd
from types import SimpleNamespace


class SubjectData():
    def __init__(self, subject: str, trials: List[str], data_dir: str, data_params: Optional[SimpleNamespace],
                 cached_transcript_aligns: Optional[str], save_path: Optional[str]=None,
                 duration: int=3, delta: int=-1, electrodes: Optional[List[str]]=None,
                 words: Optional[List[str]]=None, high_gamma=False,
                 rereference=None, despike=None, normalization=None, alignment='language') -> None:
        '''Load and align neural data for every requested trial of one subject.

        Args:
            subject: subject id.
            trials: list of trial ids to load; must be non-empty.
            data_dir: path to the ecog data.
            data_params: forwarded unchanged to every ``TrialData``.
            cached_transcript_aligns: root directory used to load/save aligned
                data. When truthy, each trial gets its own subdirectory
                ``<cached_transcript_aligns>/<subject>/<trial>`` (created if
                missing). When falsy, the value is passed through unchanged.
            save_path: path to a saved dataset. NOTE: currently accepted but
                not used anywhere in this class.
            duration: context window after word onset (units as interpreted
                by TrialDataReader -- presumably seconds, confirm).
            delta: context window start relative to word onset.
            electrodes: only take data from these electrodes. ``None`` is
                stored as an empty list, which is assumed to mean all the
                electrodes. NOTE: assumes that, if given, the electrodes are
                present across all trials.
            words: only take data for these words (``None`` -> empty list).
            high_gamma, rereference, despike, normalization: preprocessing
                options forwarded to TrialDataReader.
            alignment: alignment mode forwarded to TrialData and to the
                reader (e.g. 'language' or 'vision').

        Sets:
            self.words: DataFrame of aligned words across all trials, with a
                ``movie_id`` column. Its index is NOT unique across trials;
                row position is.
            self.neural_data: float16 array of shape
                [n_electrodes, n_words, n_samples] holding the aligned data
                across all trials (words concatenated along axis 1).
            self.trials: list of TrialData, one per trial id.
        '''
        # Normalise "no filter" to an empty list for the reader.
        self.selected_electrodes = [] if electrodes is None else electrodes
        self.selected_words = [] if words is None else words
        self.data_params = data_params

        self.high_gamma = high_gamma
        self.rereference = rereference
        self.normalization = normalization
        self.despike = despike

        self.words, self.neural_data, self.trials = self.get_subj_data(
            subject, trials, data_dir,
            cached_transcript_aligns=cached_transcript_aligns,
            duration=duration, delta=delta, alignment=alignment)

    def get_subj_data(self, subject: str, trial_ids: List[str], data_dir: str,
                      cached_transcript_aligns: Optional[str], duration: int=3, delta: int=-1,
                      alignment='language') -> Tuple[pd.DataFrame, np.ndarray, List[TrialData]]:
        '''Load, align and concatenate the data of each trial in ``trial_ids``.

        Args:
            subject: subject id.
            trial_ids: trial ids to load; must be non-empty.
            data_dir: path to the ecog data.
            cached_transcript_aligns: root cache directory (see ``__init__``).
            duration, delta: context window forwarded to
                ``TrialDataReader.get_aligned_predictor_matrix``.
            alignment: alignment mode forwarded to TrialData and the reader.

        Returns:
            (words_df, neural_data, trials):
            - words_df is the per-trial word frames concatenated, with a
              ``movie_id`` column (index not unique across trials).
            - neural_data is float16, with per-trial arrays concatenated
              along axis 1 (the word axis).
            - trials is the list of TrialData objects.

        Raises:
            ValueError: if ``trial_ids`` is empty.
            AssertionError: if a trial's word index does not line up with
                axis 1 of its neural data.
        '''
        if not trial_ids:
            # Fail early with a clear message; np.concatenate / pd.concat
            # would otherwise fail later with an opaque one.
            raise ValueError(f"No trials given for subject {subject!r}")

        words, seeg_data, trials = [], [], []
        for trial in trial_ids:
            # Derive each trial's cache dir from the *root* path. Never
            # reassign the root, or later trials end up nested inside the
            # previous trial's directory.
            trial_cache_dir = cached_transcript_aligns
            if cached_transcript_aligns:  # TODO: I want to make this automatic
                trial_cache_dir = os.path.join(cached_transcript_aligns, subject, trial)
                os.makedirs(trial_cache_dir, exist_ok=True)

            t = TrialData(subject, trial, data_dir, self.data_params, alignment=alignment)
            reader = TrialDataReader(t, trial_cache_dir, self.selected_electrodes,
                                     self.selected_words,
                                     high_gamma=self.high_gamma,
                                     normalization=self.normalization,
                                     despike=self.despike,
                                     rereference=self.rereference)

            trial_words, seeg_trial_data = reader.get_aligned_predictor_matrix(
                duration=duration, delta=delta, alignment=alignment)
            # Row i of trial_words must describe column i (axis 1) of the
            # neural array; np.array_equal also catches length mismatches.
            assert np.array_equal(np.asarray(trial_words.index),
                                  np.arange(seeg_trial_data.shape[1])), \
                f"word index does not match neural data axis 1 for trial {trial!r}"
            trial_words['movie_id'] = t.movie_id
            trials.append(t)
            words.append(trial_words)
            # Downcast per trial so the full-precision data is never held for
            # all trials at once.
            seeg_data.append(np.float16(seeg_trial_data[:, trial_words.ecog_idx, :]))

        # NOTE: float16 is regression usage. We are loading too much data that
        # would overload the RAM. Every part is already float16, so the result
        # keeps that dtype without another full-size cast/copy.
        neural_data = np.concatenate(seeg_data, axis=1)
        words_df = pd.concat(words)  # NOTE: the index will not be unique, but the location will
        return words_df, neural_data, trials

if __name__ == '__main__':
    # Debugging entry point: load a single trial of one subject with vision
    # alignment and no transcript cache.
    # NOTE: this module uses relative imports (``from .trial_data ...``), so it
    # has to be run as part of its package (``python -m <package>.<module>``);
    # running the file directly as a script fails on those imports.
    debug_config = dict(
        subject='m00185',
        trials=['trial000'],
        data_dir='/storage/datasets/neuroscience/ecog',
        data_params=None,
        cached_transcript_aligns=None,
        save_path=None,
        duration=3,
        delta=-1,
        alignment='vision',
    )
    SubjectData(**debug_config)