import os
import numpy as np 
import pandas as pd 
import math

from tqdm import tqdm
from data_loading import SubjectData
from data_loading.trial_data import TrialData
from args import *

def sent_context_windows(movie_id, movie):
    """Build left-context strings that end at each word of a movie transcript.

    For every row of ``movie`` the ``sentence`` is split on spaces and then on
    apostrophes. ``idx_in_sentence`` is interpreted as a position in that
    flattened token list. The context is the sentence up to and including the
    space-separated chunk that contains that token.

    Note: ``movie['idx_in_sentence']`` is cast to ``int`` in place, so the
    caller's DataFrame is modified.

    Args:
        movie_id: Identifier used only in the progress-bar description.
        movie: DataFrame with ``sentence``, ``text`` and ``idx_in_sentence``
            columns, one row per word.

    Returns:
        A tuple ``(contexts, words, filtered_words, indices)``.
        ``contexts``, ``words`` and ``indices`` (row labels of ``movie``) hold
        one entry per kept row. ``filtered_words`` holds one bool per row of
        ``movie``; ``True`` means the row was dropped because its index did
        not point at the word.
    """
    movie['idx_in_sentence'] = movie['idx_in_sentence'].astype(int)
    contexts = []
    words = []
    filtered_words = []
    indices = []
    rows = zip(
        tqdm(movie.index, desc = f'Generating contexts for {movie_id}'),
        movie['sentence'],
        movie['text'],
        movie['idx_in_sentence'],
    )
    for row_label, sent, word, idx_in_sentence in rows:
        split_sent = [chunk.split('\'') for chunk in sent.split(' ')]
        # For each token in the split sentence, the index of the chunk it belongs to.
        # So [We've, got, him] -> [[We, ve], [got], [him]] -> [0, 0, 1, 2]
        chunk_of_token = [chunk_idx for chunk_idx, chunk in enumerate(split_sent) for _ in chunk]

        # idx_in_sentence is sometimes wrong in the source data. Drop rows whose
        # index is out of range (a negative index would otherwise silently wrap
        # around) or does not point at a chunk containing the word.
        out_of_range = not 0 <= idx_in_sentence < len(chunk_of_token)
        if out_of_range or word not in split_sent[chunk_of_token[idx_in_sentence]]:
            filtered_words.append(True)
            continue

        filtered_words.append(False)
        context_end_idx = chunk_of_token[idx_in_sentence] + 1
        context_chunks = ['\''.join(chunk) for chunk in split_sent[:context_end_idx]]
        contexts.append(' '.join(context_chunks))
        words.append(word)
        indices.append(row_label)
    return contexts, words, filtered_words, indices

def find_avg_response(y):
    """Return the mean of a 3-D array taken over its last axis.

    Raises:
        AssertionError: if ``y`` is not 3-dimensional.
    """
    assert y.ndim == 3
    return y.mean(axis=-1)

def make_y_targets(y, args, step=0.025, sample_frequency=2048):
    """Average ``y`` over fixed-length windows placed on a regular lag grid.

    Window starts run from 0 to ``args.context_duration - 0.5`` seconds
    (inclusive) after the first sample of ``y``, every ``step`` seconds. Each
    window is ``args.time_window`` ms long and is averaged over its samples
    with ``find_avg_response``.

    Args:
        y: 3-D array whose last axis is time sampled at ``sample_frequency`` Hz.
        args: Namespace providing ``time_window`` (ms), ``context_duration``
            (s) and ``context_delta`` (s; added to every lag label).
        step: Spacing between window starts in seconds (25 ms from Goldstein).
        sample_frequency: Sampling rate of ``y`` in Hz.

    Returns:
        ``(targets, lags)``. ``targets`` is a float32 array of shape
        ``y.shape[:2] + (n_lags,)``. ``lags`` holds one int per window: the
        window start in ms plus ``context_delta`` in ms.

    Raises:
        ValueError: if the lag grid is empty (negative upper bound).
    """
    duration = args.time_window/1000.0 #200 ms from Goldstein
    upper_bound = args.context_duration - 0.5
    lower_bound = int(args.context_delta * (10**3))
    # Derive each lag from an integer step count rather than repeatedly adding
    # `step` to a float. Accumulated rounding error could push the last lag
    # past `upper_bound` (silently dropping it), or nudge lag*sample_frequency
    # just above an integer so that ceil() shifts the window by one sample.
    num_iters = int(math.floor(upper_bound / step + 1e-9)) + 1
    if num_iters <= 0:
        raise ValueError(f'No lags to compute: upper bound {upper_bound} s is negative')
    # round() strips float noise before ceil() so exact sample boundaries stay exact.
    window_len = math.ceil(round(duration * sample_frequency, 6))
    lags = []
    targets = []
    for k in tqdm(range(num_iters), desc = f'Averaging SEEG response over {duration}-second windows'):
        lag = k * step
        lags.append(lower_bound + round(lag*(10**3)))
        start = math.ceil(round(lag * sample_frequency, 6))
        end = start + window_len
        targets.append(find_avg_response(y[:, :, start: end]))
    targets = np.float32(np.stack(targets, axis = 2))

    return targets, lags

def safe_remove(l, x):
    """Remove the first occurrence of ``x`` from ``l`` in place; no-op if absent."""
    if x not in l:
        return
    l.remove(x)

def get_electrodes(args):
    """Return the subject's electrodes as a list, excluding trigger channels.

    The localization is read from the first trial in ``args.trial_list``.
    ``DC4``, ``DC10`` and ``TRIG4`` are removed (in place, via ``safe_remove``)
    from the container the localization call returned.
    """
    first_trial = TrialData(args.subject, args.trial_list[0], args.dataset_dir, args)
    electrodes = first_trial.get_brain_region_localization()
    # Trigger channels: remove from any further processing and analysis.
    for trigger_channel in ('DC4', 'DC10', 'TRIG4'):
        safe_remove(electrodes, trigger_channel)
    return list(electrodes)

def make_lang_stim(args, subject_data):
    """Build the word-aligned stimulus dataframe for ``subject_data``.

    For each trial, every transcript word of that trial's movie is matched to
    the frame ``<frames_path>/<movie_id>-images/<movie_id>-<row>.png``, where
    ``row`` is the word's position within the movie. A left context is built
    with ``sent_context_windows``. Words without a valid context are dropped
    only after the image columns have been attached.

    Note: ``image_path``, ``image_name`` and ``ecog_idx`` are assigned on
    ``subject_data.words`` itself, so the caller's transcript is modified.

    Args:
        args: Namespace providing ``frames_path``.
        subject_data: Object exposing ``words`` (DataFrame with a ``movie_id``
            column), ``trials`` (each with ``movie_id``) and ``neural_data``.

    Returns:
        The filtered transcript (fresh RangeIndex) with ``image_path``,
        ``image_name``, ``ecog_idx`` and ``context`` columns. ``ecog_idx`` is
        the row position before filtering and ranges over axis 1 of
        ``neural_data``.

    Raises:
        ValueError: if the number of frames found differs from the number of
            transcript words.
    """
    print('Making language stimulus dataframe')
    transcript_df = subject_data.words
    image_path = []
    image_name = []
    contexts = []
    filtered_words = []
    for trial in subject_data.trials:
        movie_id = trial.movie_id
        movie_words = transcript_df[transcript_df['movie_id'] == movie_id].reset_index(drop=True)
        # List the same directory that the stored image paths point into.
        frames_dir = os.path.join(args.frames_path, f'{movie_id}-images')
        # Set for O(1) membership tests below; skip hidden files and non-PNGs.
        available_frames = {
            x for x in os.listdir(frames_dir) if x[-3:] == 'png' and x[0] != '.'
        }
        movie_frames = [
            f'{movie_id}-{i}.png' for i in movie_words.index
            if f'{movie_id}-{i}.png' in available_frames
        ]
        for img in tqdm(movie_frames, desc = f'Loading images for {movie_id}'):
            image_path.append(os.path.join(frames_dir, img))
            image_name.append(img)
        trial_contexts, _, trial_filtered_words, _ = sent_context_windows(movie_id, movie_words)
        contexts += trial_contexts
        filtered_words += trial_filtered_words
    assert len(filtered_words) == len(transcript_df.index)
    if len(image_path) != len(transcript_df.index):
        raise ValueError(
            f'Found {len(image_path)} frames for {len(transcript_df.index)} words; '
            'every word needs a matching <movie_id>-<row>.png frame'
        )
    # Positional assignment: assumes transcript rows are grouped by movie in the
    # same order as subject_data.trials -- TODO confirm upstream.
    transcript_df['image_path'] = image_path
    transcript_df['image_name'] = image_name
    transcript_df['ecog_idx'] = range(subject_data.neural_data.shape[1])
    transcript_df = transcript_df[~np.array(filtered_words)].reset_index(drop=True) #Filter out all no-context words after we match words to images!
    transcript_df['context'] = contexts
    return transcript_df

def add_sent_info(vision_stimulus, language_stimulus):
    """Attach to each vision row the sentence of the nearest-in-time word.

    For every ``start`` in ``vision_stimulus``, the ``language_stimulus`` row
    with the smallest absolute ``start`` difference is chosen (the first one on
    ties, per ``Series.idxmin``). Its ``sentence`` is stored in a new
    ``context`` column. ``vision_stimulus`` is modified in place and returned.

    NOTE(review): start times are compared across all language rows regardless
    of movie/trial -- confirm that ``start`` is comparable across movies.
    """
    word_starts = language_stimulus['start']

    def _nearest_sentence(scene_time):
        nearest_label = (scene_time - word_starts).abs().idxmin()
        return language_stimulus.loc[nearest_label, 'sentence']

    vision_stimulus['context'] = [_nearest_sentence(t) for t in vision_stimulus['start']]
    return vision_stimulus

def make_vis_stim(args, subject_data):
    """Build the scene-aligned stimulus dataframe for ``subject_data``.

    Each transcript row of a trial's movie is mapped to the scene image
    ``<scenes_path>/<movie_id>/<movie_id>-<Scene Number>.jpg``. A ``context``
    sentence is then attached via ``add_sent_info``, using the word-level
    stimulus CSV read from ``data-by-subject/<subject>/<trials>_word_stimulus_metadata.csv``.

    Note: columns are assigned on ``subject_data.words`` itself, and that same
    object is returned.
    """
    print('Making vision stimulus dataframe')
    words_df = subject_data.words
    names = []
    paths = []
    for trial in subject_data.trials:
        movie_id = trial.movie_id
        in_movie = words_df['movie_id'] == movie_id
        for scene_num in words_df.loc[in_movie, 'Scene Number'].tolist():
            name = f'{movie_id}-{scene_num}.jpg'
            paths.append(os.path.join(args.scenes_path, f'{movie_id}', name))
            names.append(name)
    # Positional assignment: relies on rows being grouped by movie in trial
    # order -- TODO confirm upstream.
    words_df['image_path'] = paths
    words_df['image_name'] = names
    words_df['ecog_idx'] = range(subject_data.neural_data.shape[1])
    trial_tag = ''.join(args.trial_list)
    word_stimulus = pd.read_csv(f'data-by-subject/{args.subject}/{trial_tag}_word_stimulus_metadata.csv')
    return add_sent_info(words_df, word_stimulus)

def make_regression_dataframes(args):
    """Build and save the stimulus-metadata and neural-response dataframes.

    Writes, under ``data-by-subject/<subject>/``:
      * ``<trials>_<word|scene>_stimulus_metadata.csv``: one row per stimulus;
      * ``<trials>_<word|scene>_response_data-<time_window>mswindow.parquet.gzip``:
        rows indexed by (Electrode, times), one ``stimuli_<i>`` column per stimulus;
      * ``electrode_metadata.csv``, only if it does not already exist.

    Note: overwrites ``args.context_duration``, ``args.context_delta`` and
    ``args.electrode_list`` on the passed namespace.

    Args:
        args: Namespace (e.g. from ``data_args()``). Reads ``subject``,
            ``trial_list``, ``alignment``, ``dataset_dir``,
            ``cached_transcript_aligns`` and ``time_window``, plus whatever
            the helper functions read.

    Returns:
        ``(response_df, transcript_df)``.
    """
    print(f'Making regression dataframes for {args.subject}, {args.trial_list} with {args.alignment} alignment')
    # 4.5 s (4500 ms) of data per stimulus; with a -2.0 s delta this presumably
    # spans -2000 ms .. +2500 ms around each stimulus. Hardcoded for now.
    args.context_duration=4.5
    args.context_delta=-2.0
    electrodes = get_electrodes(args)
    args.electrode_list = electrodes
    subject_data = SubjectData(args.subject, args.trial_list, cached_transcript_aligns = args.cached_transcript_aligns, data_dir = args.dataset_dir,
                                data_params = None, duration = args.context_duration, delta = args.context_delta, electrodes = args.electrode_list, alignment = args.alignment)

    out_dir = os.path.join('data-by-subject', args.subject)
    # Check and create the same directory; exist_ok makes reruns safe.
    os.makedirs(out_dir, exist_ok=True)

    if args.alignment == 'language':
        transcript_df = make_lang_stim(args, subject_data)
        align_str = 'word'
    else:
        transcript_df = make_vis_stim(args, subject_data)
        align_str = 'scene'
    trials = ''.join(args.trial_list)
    response_data_columns = [f'stimuli_{i}' for i in range(len(transcript_df.index))]
    transcript_df['index'] = response_data_columns
    print('Saving stimulus dataframe to csv')
    transcript_df.to_csv(os.path.join(out_dir, f'{trials}_{align_str}_stimulus_metadata.csv'))

    print('Making neural response dataframe')
    neural_data = subject_data.neural_data
    # Language alignment drops words without a valid context, while ecog_idx
    # records each surviving row's pre-filter position along axis 1. Select
    # those stimulus columns so the data lines up with transcript_df. When
    # nothing was dropped, ecog_idx is the identity, so skip the copy.
    ecog_idx = transcript_df['ecog_idx'].to_numpy()
    if len(ecog_idx) != neural_data.shape[1]:
        neural_data = neural_data[:, ecog_idx]
    assert neural_data.shape[1] == len(transcript_df.index)

    targets, lag_ms = make_y_targets(neural_data, args)
    # (axis 0, stimuli, lags) -> (axis 0, lags, stimuli) -> 2-D rows, so rows
    # run lag-fastest within each axis-0 entry, matching the index built below.
    targets = np.transpose(targets, axes = (0, 2, 1))
    targets = np.reshape(targets, (targets.shape[0]*targets.shape[1], targets.shape[2]))
    elec_lag_pairs = [(elec, lag) for elec in args.electrode_list for lag in lag_ms]
    elecs = np.array([elec for elec, _ in elec_lag_pairs])
    lag_labels = np.array([str(lag) for _, lag in elec_lag_pairs])
    response_df = pd.DataFrame(targets, index=[elecs, lag_labels])
    response_df.index.names = ['Electrode', 'times']
    response_df.columns = response_data_columns
    print('Saving neural response dataframe to parquet')
    response_df.to_parquet(os.path.join(out_dir, f'{trials}_{align_str}_response_data-{args.time_window}mswindow.parquet.gzip'), compression = 'gzip')

    metadata_path = os.path.join(out_dir, 'electrode_metadata.csv')
    if not os.path.exists(metadata_path):
        print('Making electrode metadata dataframe')
        subject_data.regions_df.to_csv(metadata_path)
    return response_df, transcript_df

if __name__ == '__main__':
    # Build and save the regression dataframes using the arguments from data_args().
    make_regression_dataframes(data_args())