from scipy.signal import hilbert, chirp
from tqdm import tqdm
import os
import numpy as np
import pandas as pd

from typing import Optional, List, Tuple
from .utils import compute_m5_hash
from .h5_data_reader import H5DataReader

class TrialDataReader(H5DataReader):
    '''
        Aligns a trial's movie transcript to its neural recording and cuts the
        data returned by self.get_filtered_data() into word-aligned,
        sentence-aligned, or fixed-interval windows.
    '''

    def __init__(self, trial_data,
                 cached_transcript_aligns: str,
                 selected_electrodes: List[str],
                 selected_words: Optional[List[str]],
                 rereference=None,
                 normalization=None,
                 despike=None,
                 high_gamma=False) -> None:
        '''
            input:
                trial_data=ecog and word data to perform processing on
                cached_transcript_aligns=directory in which the aligned transcript is
                    cached as aligned_script.h5; falsy disables caching
                selected_electrodes=forwarded to H5DataReader
                selected_words=words kept by select_words; None or empty keeps every
                    word with non-empty text
                rereference, normalization, despike, high_gamma=forwarded to H5DataReader
        '''
        super().__init__(trial_data, selected_electrodes,
                         rereference=rereference,
                         normalization=normalization,
                         despike=despike,
                         high_gamma=high_gamma)

        # Column names used in the transcript and trigger dataframes.
        self.start_col = 'start'
        self.end_col = 'end'
        self.trig_time_col = 'movie_time'
        self.trig_idx_col = 'index'
        self.est_idx_col = 'est_idx'
        self.est_end_idx_col = 'est_end_idx'
        self.word_time_col = 'word_time'
        self.word_text_col = 'text'
        self.is_onset_col = 'is_onset'
        self.is_offset_col = 'is_offset'

        self.aligned_script_df = self.get_aligned_movie_transcript(cached_transcript_aligns)

        self.selected_words = selected_words

    def estimate_sample_index(self, t, near_t, near_trig):
        '''
            input:
                t=movie time to convert
                near_t=movie time of a nearby trigger
                near_trig=sample index of that trigger
            returns: near_trig shifted by (t - near_t) * samp_frequency samples,
                rounded to an integer (Python round: exact halves go to even).
                Presumably t is in seconds and samp_frequency in Hz -- TODO confirm.
        '''
        samp_frequency = self.trial_data.samp_frequency
        trig_diff = (t-near_t)*samp_frequency
        return round(near_trig+trig_diff)

    def add_estimated_sample_index(self, w_df: pd.DataFrame) -> pd.DataFrame:
        '''
            input: a dataframe of word features with movie-time start/end columns
            returns: a copy of w_df with est_idx / est_end_idx columns holding the
                estimated sample index of each word's start and end, interpolated from
                the trigger nearest in movie time. Processing stops at the first word
                whose start lies after the last trigger (words are assumed sorted by
                start time), so that word and all later ones keep NaN.
        '''
        tmp_w_df = w_df.copy(deep=True)
        # Create the output columns up front so they exist (as NaN) even if no
        # word gets aligned; downstream code indexes them unconditionally.
        tmp_w_df[self.est_idx_col] = np.nan
        tmp_w_df[self.est_end_idx_col] = np.nan

        trigs_df = self.trial_data.get_trigger_times()
        trig_times = trigs_df[self.trig_time_col]
        last_t = trig_times.iloc[-1]

        def estimate_from_nearest_trigger(time):
            # Interpolate from the trigger whose movie time is closest to `time`.
            nearest = (trig_times - time).abs().idxmin()
            return self.estimate_sample_index(time,
                                              trigs_df.loc[nearest, self.trig_time_col],
                                              trigs_df.loc[nearest, self.trig_idx_col])

        for i, t, endt in zip(w_df.index, w_df[self.start_col], w_df[self.end_col]):
            if t > last_t:
                break
            tmp_w_df.loc[i, self.est_idx_col] = estimate_from_nearest_trigger(t)
            tmp_w_df.loc[i, self.est_end_idx_col] = estimate_from_nearest_trigger(endt)
        return tmp_w_df

    def get_aligned_movie_transcript(self, cached_transcript_aligns: str) -> pd.DataFrame:
        '''
            returns the dataframe of word data for the trial, augmented with estimated
            onset and offset sample indices. When cached_transcript_aligns is set, the
            result is cached in <cached_transcript_aligns>/aligned_script.h5 and reused
            only if its stored transcript hash matches the current transcript file.
        '''
        save_path = None
        computed_hash = None
        if cached_transcript_aligns:
            save_path = os.path.join(cached_transcript_aligns, "aligned_script.h5")
            computed_hash = compute_m5_hash(self.trial_data.transcript_file)

        if save_path and os.path.exists(save_path):
            cached_df = pd.read_hdf(save_path, key='transcript_data')
            hash_col = 'orig_transcript_hash'
            # Positional lookup: the cached index is not guaranteed to contain label 0.
            # NOTE(review): only the transcript is hashed; a change in trigger times
            # would not invalidate this cache.
            if (not cached_df.empty and hash_col in cached_df.columns
                    and cached_df[hash_col].iloc[0] == computed_hash):
                return cached_df

        words_df = self.trial_data.get_movie_transcript()
        words_df = self.add_estimated_sample_index(words_df)

        if save_path:
            words_df['orig_transcript_hash'] = computed_hash #TODO eventually look into storing metadata
            words_df.to_hdf(save_path, key='transcript_data')

        return words_df

    def select_words(self, words_df: pd.DataFrame) -> pd.DataFrame:
        '''
            input:
                words_df=pandas dataframe of word features
            output:
                the rows of words_df whose text is non-empty and, when
                self.selected_words is non-empty, is one of self.selected_words.
                The original index labels are kept (not reset), so callers must not
                treat them as row positions.
        '''
        filtered_df = words_df[words_df[self.word_text_col] != '']

        if self.selected_words is None or len(self.selected_words) == 0:
            return filtered_df

        return filtered_df[filtered_df[self.word_text_col].isin(self.selected_words)]

    def get_aligned_non_words_matrix(self, duration: int=3, delta: int=-1, save_path: Optional[str]=None) -> Tuple[np.ndarray, pd.DataFrame]:
        '''
            Tile the recording into consecutive, non-overlapping windows and label
            each window by whether it overlaps any selected word.

            input:
                duration=window length, converted to int(duration*samp_frequency) samples
                delta=unused; kept for signature compatibility
                save_path=unused
            output:
                (array, labels_df): array has shape [n_electrodes, n_windows, n_samples];
                all word-overlapping windows come first, then the remaining windows,
                each group in recording order. labels_df has one boolean column
                "linguistic_content" that is True for the word-overlapping windows.
                Trailing samples that do not fill a whole window are ignored.
            raises:
                ValueError if the window length is less than one sample.
        '''
        print("Getting linguistic and non-linguistic intervals")
        filtered_data = np.asarray(self.get_filtered_data())

        w_df = self.select_words(self.aligned_script_df)
        # Boolean (label-aligned) mask. select_words keeps the original labels,
        # so passing those labels to .iloc would select the wrong rows.
        has_estimate = w_df[[self.est_idx_col, self.est_end_idx_col]].notna().all(axis=1)
        w_df = w_df[has_estimate]

        samp_frequency = self.trial_data.samp_frequency
        window_duration = int(duration*samp_frequency)
        if window_duration <= 0:
            raise ValueError(f'duration={duration} gives a window of {window_duration} samples')

        n_electrodes, total_length = filtered_data.shape[0], filtered_data.shape[-1]
        n_windows = total_length // window_duration

        # Window k spans [k*W, (k+1)*W) and overlaps a word [start, end) iff
        # k*W < end and (k+1)*W > start, i.e. start//W <= k <= ceil(end/W) - 1.
        # Marking that range per word avoids testing every (window, word) pair.
        has_word = np.zeros(n_windows, dtype=bool)
        word_starts = w_df[self.est_idx_col].astype(int)
        word_ends = w_df[self.est_end_idx_col].astype(int)
        for start, end in zip(word_starts, word_ends):
            first = max(start // window_duration, 0)
            last = min(-(-end // window_duration) - 1, n_windows - 1)
            if first <= last:
                has_word[first:last + 1] = True

        windows = filtered_data[:, :n_windows * window_duration].reshape(
            n_electrodes, n_windows, window_duration)
        # Boolean selection yields (n_electrodes, 0, n_samples) for an empty group,
        # so neither group being empty is an error.
        result = np.concatenate([windows[:, has_word], windows[:, ~has_word]], axis=1)
        n_word = int(has_word.sum())
        labels = np.repeat([True, False], [n_word, n_windows - n_word])

        return result, pd.DataFrame({"linguistic_content": labels})

    def get_aligned_predictor_matrix(self, duration: int=3, delta: int=-1, save_path: Optional[str]=None, alignment = 'language') -> Tuple[pd.DataFrame, np.ndarray]:
        '''
            Cut one fixed-length window of neural data per word.

            input:
                duration=window length, converted to int(duration*samp_frequency) samples
                delta=offset of the window start from the word's estimated onset,
                    converted the same way (negative = before the onset)
                save_path=unused
                alignment='language' keeps only the rows returned by select_words;
                    any other value uses every transcript row
            output:
                (w_df, word_window_arr): word_window_arr has shape
                [n_electrodes, n_words, n_samples], and the w_df row with ecog_idx == i
                describes word_window_arr[:, i, :]. Words without an estimated onset,
                or whose window would start before sample 0 or end past the recording,
                are dropped from both outputs.
        '''
        filtered_data = self.get_filtered_data()
        n_electrodes, total_length = filtered_data.shape[0], filtered_data.shape[-1]

        w_df = self.aligned_script_df
        if alignment == 'language':
            w_df = self.select_words(w_df)

        samp_frequency = self.trial_data.samp_frequency
        window_duration = int(duration*samp_frequency)
        window_onset = int(delta*samp_frequency)

        # Boolean (label-aligned) mask. After select_words the index labels are no
        # longer row positions, so they must not be passed to .iloc.
        w_df = w_df[w_df[self.est_idx_col].notna()]
        start_idxs = w_df[self.est_idx_col].astype(int).to_numpy() + window_onset
        end_idxs = start_idxs + window_duration

        # A negative start would wrap around in the slice, and a window past the end
        # would come back short. Both used to abort the loop and leave uninitialised
        # rows in the output, so keep only windows fully inside the recording.
        in_bounds = (start_idxs >= 0) & (end_idxs <= total_length)
        n_dropped = int((~in_bounds).sum())
        if n_dropped:
            print(f'Dropping {n_dropped} words whose window falls outside the neural recording')
        w_df = w_df[in_bounds].copy()  # copy: ecog_idx is added below
        start_idxs = start_idxs[in_bounds]
        end_idxs = end_idxs[in_bounds]

        shape = (n_electrodes, len(w_df.index), window_duration)
        print('Generating aligned predictor matrix of size {}'.format(shape))
        word_window_arr = np.empty(shape)
        for i, (start, end) in enumerate(tqdm(zip(start_idxs, end_idxs), total=len(start_idxs))):
            word_window_arr[:, i, :] = filtered_data[:, start:end]

        w_df['ecog_idx'] = range(word_window_arr.shape[1])
        return w_df, word_window_arr #TODO check that this w_df is the one being used at train time

    def get_aligned_sentence_data(self, save_path: Optional[str]=None, max_sentence_samples: int=20*10**3) -> Tuple[pd.DataFrame, List[np.ndarray]]:
        '''
            Cut the neural data spanning each sentence.

            A sentence starts at a row with idx_in_sentence == 0 and ends at the row
            immediately before the next such row, in the order of the select_words
            output. The last sentence has no following onset and is skipped.

            input:
                save_path=unused
                max_sentence_samples=sentences whose estimated span (end sample minus
                    start sample) is this many samples or more are dropped, as are
                    sentences missing either estimate
            output:
                (sentence_df, sentences): sentence_df holds the text and sentence
                columns of each kept sentence's first word; sentences[i] is
                filtered_data[:, start:end] for sentence i (lengths vary).
            raises:
                ValueError if a kept sentence ends past the end of the recording.
        '''
        filtered_data = self.get_filtered_data()
        n_samples = filtered_data.shape[-1]

        w_df = self.select_words(self.aligned_script_df)
        # Work with row positions, not labels. select_words may have removed rows,
        # so "onset label - 1" need not exist.
        onset_pos = np.flatnonzero((w_df['idx_in_sentence'] == 0).to_numpy())
        offset_pos = onset_pos[1:] - 1
        onset_pos = onset_pos[:-1]

        onset_idxs = w_df[self.est_idx_col].iloc[onset_pos].to_numpy(dtype=float)
        offset_idxs = w_df[self.est_end_idx_col].iloc[offset_pos].to_numpy(dtype=float)
        diffs = offset_idxs - onset_idxs

        good_samples = ~np.isnan(diffs) & (diffs < max_sentence_samples)
        sentence_df = w_df.iloc[onset_pos[good_samples]][[self.word_text_col, 'sentence']]
        onset_idxs = onset_idxs[good_samples]
        offset_idxs = offset_idxs[good_samples]

        print(f'Generating aligned sentence data of size {len(onset_idxs)}')
        # The slice end is exclusive, so an end equal to n_samples is still in range.
        if len(offset_idxs) and offset_idxs.max() > n_samples:
            raise ValueError(f'sentence ends at sample {offset_idxs.max()} but the '
                             f'recording has only {n_samples} samples')

        sentences = [filtered_data[:, int(start):int(end)]
                     for start, end in tqdm(zip(onset_idxs, offset_idxs), total=len(onset_idxs))]
        return sentence_df, sentences