#*----------------------------------------------------------------------------*
#* Copyright (C) 2020 ETH Zurich, Switzerland                                 *
#* SPDX-License-Identifier: Apache-2.0                                        *
#*                                                                            *
#* Licensed under the Apache License, Version 2.0 (the "License");            *
#* you may not use this file except in compliance with the License.           *
#* You may obtain a copy of the License at                                    *
#*                                                                            *
#* http://www.apache.org/licenses/LICENSE-2.0                                 *
#*                                                                            *
#* Unless required by applicable law or agreed to in writing, software        *
#* distributed under the License is distributed on an "AS IS" BASIS,          *
#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.   *
#* See the License for the specific language governing permissions and        *
#* limitations under the License.                                             *
#*                                                                            *
#* Authors: Batuhan Toemekce, Burak Kaya, Michael Hersche                     *
#*----------------------------------------------------------------------------*

#!/usr/bin/env python3

"""
Loads motor-imagery (MI) EEG data from the PhysioNet '.edf' recordings and exports per-trial segments.
"""


import csv
import os
import random
import shutil
import statistics

import numpy as np
# pyEDFlib is a python library to read/write EDF+/BDF+ files based on EDFlib.
import pyedflib as edf


__author__ = "Batuhan Tomekce, Burak Alp Kaya, Michael Hersche"
__email__ = "tbatuhan@ethz.ch, bukaya@ethz.ch, herschmi@ethz.ch"

def get_data(path, long = False, normalization = 0,subjects_list=range(1,110), n_classes=4, context=False):
    '''
    Load the MI trials of the requested subjects and return them as one array

    Parameters:
    -----------
    path:   string
        root directory of the .edf data (one 'SXXX' folder per subject)
    long:    bool
        length of read time window
        True: Trials of length 6s returned; False: Trials of length 3s returned
    normalization   int {0,1}
        normalization per trial
        0: no normalization; 1: every channel of every trial is z-scored over
        time (zero mean, unit sample standard deviation)
    subjects_list   iterable of int in [1, .. , 109]
        subject numbers to be loaded; subjects 88, 92, 100 and 104 are always skipped
    n_classes:      int
        number of classes
        2: L/R, 3: L/R/rest, 4: L/R/rest/feet
        (any value other than 2 or 3 yields the 4-class data)
    context:    bool
        if True, every trial is extended by 2 s of signal before and 3 s after
        its window (see read_data)

    Return: X:  numpy array (n_trials, 64, n_samples)
                EEG data
            y:  numpy array (n_trials,)
                labels: 0 left fist, 1 right fist, 2 rest, 3 both feet
    '''
    # Subjects whose recordings are never loaded (details in the separate data
    # tester); subject 106 was re-admitted for analysis.
    excluded_subjects = {88, 92, 100, 104}
    subjects = [x for x in subjects_list if x not in excluded_subjects]

    # 1: baseline (rest), 4/8/12: imagined left/right fist, 6/10/14: imagined feet
    mi_runs = [1, 4, 6, 8, 10, 12, 14]
    # Extract only the requested number of classes
    if n_classes in (2, 3):
        print(f'Returning {n_classes} Class data')
        feet_runs = (6, 10, 14)
        mi_runs = [run for run in mi_runs if run not in feet_runs]
        if n_classes == 2:
            mi_runs.remove(1)  # rest
    print(f'Data from runs: {mi_runs}')

    X, y = read_data(subjects = subjects,runs = mi_runs, path=path, long=long, context=context)

    # do normalization if wanted
    if normalization == 1:
        print("NORMALIZING")
        #TODO: declare std_dev, mean arrays to return
        # Vectorized per-trial, per-channel z-score along the time axis;
        # ddof=1 gives the sample standard deviation.
        mean = X.mean(axis=-1, keepdims=True)
        std_dev = X.std(axis=-1, ddof=1, keepdims=True)
        X = (X - mean) / std_dev

    return X, y
    

    
def read_data(subjects , runs, path, long=False, context=False):
    '''
    Load the MI trials of the given subjects and runs and return them as one array

    Parameters:
    -----------
    subjects   list [1, .. , 109]
        list of subject numbers to be loaded
    runs    list
        runs to read from: 1 (baseline, rest), 4/8/12 (left/right fist),
        6/10/14 (both feet); any other run is read but yields no trials
    path:   string
        root directory of the .edf data; files are read from
        <path>/SXXX/SXXXRYY.edf
    long:    bool
        length of read time window
        True: Trials of length 6s returned; False: Trials of length 3s returned
    context: bool
        True: every trial is extended by 2 s before and 3 s after its window

    Return: X:  numpy array (n_trials, 64, n_samples [+ context samples])
                EEG data
            y:  numpy array (n_trials,), float
                labels: 0 left fist, 1 right fist, 2 rest, 3 both feet

    Data explanation (annotation labels):
        T0: rest (all task runs)
        runs 4, 8, 12  (imagined motion, first set):  T1 left fist,  T2 right fist
        runs 6, 10, 14 (imagined motion, second set): T1 both fists, T2 both feet
    From the task runs only T1/T2 of the first set and T2 of the second set are
    used (T0 events are skipped there); rest trials are cut from baseline run 1.
    Together this gives four classes: left fist, right fist, rest, both feet.
    '''
    base_file_name = 'S{:03d}R{:02d}.edf'
    base_subject_directory = 'S{:03d}'

    # Runs where the different sets of tasks were performed
    baseline = (1,)
    first_set = (4, 8, 12)
    second_set = (6, 10, 14)

    # Number of EEG channels
    NO_channels = 64
    # Trials kept per class and run: 7 x 3 runs = 21 per class, which matches
    # the 20 + 1 rest trials taken from the baseline run.
    NO_trials = 7
    NO_baseline_windows = 20

    # Sample counts assume 160 samples per second
    n_samples = 160 * (6 if long else 3)  # 3s: 480 samples, 6s: 960 samples
    context_before = 160 * 2 if context else 0
    context_after = 160 * 3 if context else 0

    # Collect trials in lists and stack once at the end; growing an array with
    # np.vstack/np.concatenate inside the loop copies everything each time.
    trials = []
    labels = []
    for subject in subjects:
        for run in runs:
            file_name = os.path.join(path,
                                     base_subject_directory.format(subject),
                                     base_file_name.format(subject, run))
            sigbufs, fs, annotations = _read_edf(file_name)

            if run in second_set:
                run_trials, run_labels = _extract_cued_trials(
                    sigbufs, fs, annotations, {'T2': 3}, NO_trials,
                    n_samples, context_before, context_after)
            elif run in first_set:
                run_trials, run_labels = _extract_cued_trials(
                    sigbufs, fs, annotations, {'T1': 0, 'T2': 1}, NO_trials,
                    n_samples, context_before, context_after)
            elif run in baseline:
                run_trials, run_labels = _extract_baseline_trials(
                    sigbufs, fs, NO_baseline_windows,
                    n_samples, context_before, context_after)
            else:
                continue
            trials.extend(run_trials)
            labels.extend(run_labels)

    X = np.empty((0, NO_channels, n_samples + context_before + context_after))
    if trials:
        # concatenate onto the typed empty array so a channel/length mismatch
        # still raises instead of silently producing a differently shaped X
        X = np.concatenate((X, np.stack(trials)))
    y = np.array(labels, dtype=float)
    return X, y


def _read_edf(file_name):
    '''Read all signals and annotations of one .edf file.

    Returns (sigbufs, fs, annotations): sigbufs is a float array of shape
    (n_signals, n_samples of channel 0), fs the sample frequency of channel 0,
    annotations whatever EdfReader.readAnnotations() returns (onsets at [0],
    descriptions at [2]). The reader is closed even if reading fails.
    '''
    f = edf.EdfReader(file_name)
    try:
        fs = f.getSampleFrequency(0)
        n_ch = f.signals_in_file
        # assumes every channel has as many samples as channel 0 — TODO confirm
        sigbufs = np.zeros((n_ch, f.getNSamples()[0]))
        for ch in range(n_ch):
            sigbufs[ch, :] = f.readSignal(ch)
        annotations = f.readAnnotations()
    finally:
        f.close()
    return sigbufs, fs, annotations


def _extract_cued_trials(sigbufs, fs, annotations, label_map, max_per_label,
                         n_samples, context_before, context_after):
    '''Cut one window per annotated cue whose description is in label_map.

    At most max_per_label trials are kept per description (in annotation
    order); each window starts context_before samples before the cue onset and
    ends n_samples + context_after samples after it. Returns (trials, labels).
    '''
    onsets = fs * annotations[0]  # onset times converted to sample positions
    descriptions = annotations[2]
    counts = dict.fromkeys(label_map, 0)
    trials = []
    labels = []
    for onset, description in zip(onsets, descriptions):
        if description not in label_map or counts[description] >= max_per_label:
            continue  # e.g. T0 (rest) events of the task runs are skipped
        counts[description] += 1
        start = int(onset)
        # copy so the full-run signal buffer can be freed after this run
        trials.append(sigbufs[:, start - context_before:start + n_samples + context_after].copy())
        labels.append(label_map[description])
    return trials, labels


def _extract_baseline_trials(sigbufs, fs, n_windows, n_samples,
                             context_before, context_after, seed=7):
    '''Cut rest trials (label 2) from a baseline recording.

    Takes n_windows consecutive windows of n_samples, each widened by the
    context; a window whose context would start before sample 0 (or end after
    the last sample) is shifted so all of its context lies after (or before)
    it. One extra window at a seeded random offset is appended, so
    n_windows + 1 trials are returned as (trials, labels).

    NOTE(review): with long=True, 20 windows of 960 samples need 19200 samples;
    if the recording is shorter, later slices come back truncated and stacking
    them fails — confirm intended behavior.
    '''
    n_total = sigbufs.shape[1]
    trials = []
    for ii in range(n_windows):
        begin, end = ii * n_samples, (ii + 1) * n_samples
        if begin - context_before < 0:        # start too early
            lo, hi = begin, end + context_before + context_after
        elif end + context_after > n_total:   # end too late
            lo, hi = begin - context_before - context_after, end
        else:                                 # normal case
            lo, hi = begin - context_before, end + context_after
        trials.append(sigbufs[:, lo:hi].copy())

    # Randomly placed extra rest trial. A local generator seeded with `seed`
    # makes the offset reproducible without touching numpy's global RNG state;
    # the bounds are converted to int because fs may be a float.
    # The window start is drawn from [context_before, 57*fs - context_after]
    # (inclusive).
    rng = np.random.RandomState(seed)
    low = int(context_before)
    high = int(57 * fs - context_after)
    index = int(rng.randint(low, high + 1))
    trials.append(sigbufs[:, index - context_before:index + n_samples + context_after].copy())
    return trials, [2] * len(trials)

# Export script: writes every trial (with context) of subjects 1-84 to
# <out_root>/S<subj>/T<k>.npy and lists each file with its number of time
# samples in a tab-separated manifest.
# NOTE: this is top-level code, so it also runs when the module is imported.
orig_data_path = "/scratch/sem23h11/physionet.org/files/eegmmidb/1.0.0"
out_root = "/scratch/sem23h11/BrainBERT/pretrain_data_segments_context_allchns"
manifest_dir = os.path.join(out_root, "manifests")
manifest_path = os.path.join(manifest_dir, "manifest.tsv")

# Create the output root and manifest folder (no-op if they already exist)
os.makedirs(manifest_dir, exist_ok=True)

# Mode 'w' truncates any previous manifest; newline='' as the csv module requires
with open(manifest_path, 'wt', newline='') as out_file:
    tsv_writer = csv.writer(out_file, delimiter='\t')

    # iterate over the training subjects; subjects >= 85 are left out (test data)
    for subj in range(1, 85):
        data, _ = get_data(orig_data_path, normalization = 0, subjects_list = [subj], n_classes = 4, context = True)
        subj_dir = os.path.join(out_root, "S" + str(subj))
        # start from an empty folder so stale trials of a previous export do not linger
        if os.path.exists(subj_dir):
            shutil.rmtree(subj_dir)
        os.mkdir(subj_dir)
        # save each trial as its own .npy file (1-based trial numbering)
        for trial, trial_data in enumerate(data, start=1):
            path = os.path.join(subj_dir, "T" + str(trial) + ".npy")
            np.save(path, trial_data)
            tsv_writer.writerow([path, trial_data.shape[1]]) # use time series length as length
                





