#MIT-BIH data preprocessing code taken from https://github.com/eddymina/ECG_Classification_Pytorch.git

import warnings

import importlib
import random 
import time 
import pandas as pd 
import numpy as np
from collections import Counter
from scipy import signal
from scipy.signal import find_peaks, resample
import matplotlib.pyplot as plt 
import os 
from os import listdir
import seaborn as sns
from os.path import isfile, join
import sys
import warnings 
from tqdm import tqdm 
import pickle as pkl

# Default directories used when this module is run as a script.
DATADIR = "mitbih_database"  # input: <record>.csv and <record>annotations.txt files
OUTDIR = "out"  # output: get_data() writes mitbih.pkl here
MODELDIR = "models"  # NOTE(review): not referenced in this file; presumably for saved models

# Every MIT-BIH annotation symbol mapped to its human-readable description.
ALL_MITBIHCLASSES = dict([
    ('N', 'Normal beat'),
    ('L', 'Left bundle branch block beat'),
    ('R', 'Right bundle branch block beat'),
    ('A', 'Atrial premature beat'),
    ('a', 'Aberrated atrial premature beat'),
    ('J', 'Nodal (junctional) premature beat'),
    ('S', 'Supraventricular premature beat'),
    ('V', 'Premature ventricular contraction'),
    ('F', 'Fusion of ventricular and normal beat'),
    ('[', 'Start of ventricular flutter/fibrillation'),
    ('!', 'Ventricular flutter wave'),
    (']', 'End of ventricular flutter/fibrillation'),
    ('e', 'Atrial escape beat'),
    ('j', 'Nodal (junctional) escape beat'),
    ('E', 'Ventricular escape beat'),
    ('/', 'Paced beat'),
    ('f', 'Fusion of paced and normal beat'),
    ('x', 'Non-conducted P-wave (blocked APB)'),
    ('Q', 'Unclassifiable beat'),
    ('|', 'Isolated QRS-like artifact'),
])

# The five target classes keyed by integer label id, plus the reverse lookup.
# The grouping appears to follow the AAMI EC57 convention (N, SVEB, VEB, F, Q)
# -- TODO confirm.
RELEVANT_CLASS_IDS_TO_NAMES = dict(enumerate(['N', 'S', 'V', 'F', 'Q']))
RELEVANT_CLASS_NAMES_TO_IDS = {
    name: label_id for label_id, name in RELEVANT_CLASS_IDS_TO_NAMES.items()
}

# Target class -> raw annotation symbols that are folded into it.
CLASSES_REDUCER = {
    'N': ['N', 'L', 'R', 'e', 'j'],
    'S': ['S', 'A', 'a', 'J'],
    'V': ['V', 'E'],
    'F': ['F'],
    'Q': ['/', 'Q', 'f'],
}


def z_norm(result):
    """
    Min-max normalize a signal so its values span [0, 1].

    Despite the name this is not a z-score: the output is
    (x - min) / (max - min).

    Input:
        result:: numpy array or pandas Series of numeric samples
    Output:
        Object of the same kind (a Series keeps its index) rescaled to [0, 1].
        A constant signal (max == min) maps to all zeros instead of NaN.
    """
    # Vectorized reductions, each computed once (builtin min/max would walk
    # the whole signal element by element in Python, twice for min).
    lo = np.min(result)
    span = np.max(result) - lo
    if span == 0:
        # Flat signal: every sample equals the minimum, so avoid 0/0 -> NaN.
        span = 1
    return (result - lo) / span


# PATIENT DETAILS AT https://archive.physionet.org/physiobank/database/html/mitdbdir/records.htm

def get_patient_data(datadir,patient,norm=True, sample_plot=False):
    """
    Load one MIT-BIH record: its signal column and its annotations.

    Reads ``<datadir>/<patient>.csv`` (three columns, renamed to
    samp_num / signal / V) and ``<datadir>/<patient>annotations.txt``
    (fixed-width text parsed with the column ``widths`` below).

    Input:
        datadir:: directory containing the record files [str]
        patient:: Patient Number [Str or Int]
        norm:: (optional) =True --> min-max normalize the signal to [0, 1]
        sample_plot:: (optional) plot a random 5000-sample window of the
            signal with the annotated sample positions marked [True or False]
    Output:
        (Ecg_Data, Ecg_Notes) when sample_plot is False, otherwise None
            Ecg_Data:: pandas Series holding the 'signal' column
            Ecg_Notes:: DataFrame with columns sample_num, type, aux
    """
    # Column widths of the fixed-width annotation file; its first, unnamed
    # column is dropped right after parsing.
    widths = [4, 8, 11, 6, 3, 5, 5, 8]

    patient = str(patient)
    notes_path = os.path.join(datadir, '{}annotations.txt'.format(patient))
    data_path = os.path.join(datadir, '{}.csv'.format(patient))

    ecg_notes = pd.read_fwf(notes_path, widths=widths).drop(['Unnamed: 0'], axis=1)
    ecg_data = pd.read_csv(data_path)
    ecg_data.columns = ['samp_num', 'signal', 'V']
    ecg_notes = ecg_notes[['Sample #', 'Type', 'Aux']]
    ecg_notes.columns = ['sample_num', 'type', 'aux']

    if norm:
        ecg_data['signal'] = z_norm(ecg_data['signal'])

    if sample_plot:
        peaklist = ecg_notes.sample_num.astype(int).values
        plt.figure()
        b = np.random.choice(len(ecg_data.signal))
        plt.plot(ecg_data.signal)
        plt.xlim(b, b + 5000)
        # Annotation sample numbers are used as x positions, so index the
        # signal positionally to match.
        plt.plot(peaklist, ecg_data.signal.iloc[peaklist], "x")
        plt.title('  Sample Patient {} data'.format(patient))
        # Plot-only mode: nothing is returned (kept for existing callers).
        return None

    return ecg_data.signal, ecg_notes


def hr_sample_len(HR,fs=360):
    """
    Return how many samples one beat spans at heart rate ``HR``.

    Input:
        HR:: heart rate in beats per minute
        fs:: (optional) sampling frequency in samples per second (default 360)
    Output:
        int, samples per beat, truncated toward zero
    """
    samples_per_minute = fs * 60  # fs samples/s * 60 s/min
    return int(samples_per_minute / HR)

def get_HR(peaklist,fs):
    """
    Estimate the mean heart rate in beats per minute from R-peak positions.

    Input:
        peaklist:: ordered sample indices of the R peaks (list or 1-D array)
        fs:: sampling rate in samples per second
    Output:
        HR (float) = 60 / mean R-R interval in seconds. With fewer than two
        peaks there is no interval and the result is NaN.
    """
    # Consecutive differences are the R-R intervals in samples; dividing by
    # fs turns them into seconds.
    rr_seconds = np.diff(peaklist) / fs
    return 60 / np.mean(rr_seconds)

def zero_pad(lst):
    """
    Stack variable-length rows into a 2-D array, right-padding each shorter
    row with zeros up to the length of the longest row.

    Ex.
        [[a1], [b1, b2], [c1, c2, c3]] --> [[a1, 0,  0 ],
                                            [b1, b2, 0 ],
                                            [c1, c2, c3]]

    Input: sequence of rows (lists, tuples or 1-D arrays)
    Output: np.array of shape [len(lst), len_longest_row]; an empty input
        gives an array of shape (0, 0).
    """
    # list(row) lets tuples and ndarrays be concatenated with the zero list
    # (ndarray + list would broadcast-add instead of concatenating).
    rows = [list(row) for row in lst]
    if not rows:
        return np.empty((0, 0))
    pad = max(len(row) for row in rows)
    return np.array([row + [0] * (pad - len(row)) for row in rows])

def _class_sample_locations(ecg_notes, class_name, classes_reducer=None):
    """
    Return the set of annotated sample numbers that belong to ``class_name``.

    With ``classes_reducer`` the class expands to every raw annotation symbol
    listed under it; without it only an exact ``type`` match counts.
    """
    if classes_reducer is not None:
        symbols = classes_reducer[class_name]
    else:
        symbols = [class_name]
    matches = ecg_notes.loc[ecg_notes.type.isin(symbols), 'sample_num']
    # int() mirrors the astype(int) applied to the peak list by the caller.
    return {int(s) for s in matches}


def isolate_patient_data(data_dir, patients, classes, classes_further,
                         classes_reducer=None, min_HR=40, max_HR=140, fs=360,
                         verbose=False, plot_figs=False):
    """
    Cut every annotated beat of the requested classes out of each patient's
    normalized ECG signal and stack the beats into zero-padded arrays.

    A beat window runs from halfway to the previous annotation to halfway to
    the next one. Windows whose length falls outside the beat lengths that
    correspond to [min_HR, max_HR] at ``fs`` are discarded.

    Input:
        data_dir:: directory holding the record files read by get_patient_data
        patients:: Patient Numbers list of Patient numbers [list]
        classes:: {id: class name} dict; its values are the classes extracted
        classes_further:: {class name: description} used in the summary printout
        classes_reducer:: (optional) {class name: [raw annotation symbols]};
            if None the raw annotation type must equal the class name
        min_HR:: (optional) lowest HR accepted (longest window)
        max_HR:: (optional) highest HR accepted (shortest window)
        fs:: (optional) sampling frequency --> 360 for this database
        verbose:: (optional) prints out some information per patient if true [boolean]
        plot_figs:: (optional) plots beat-length and HR distributions [boolean]

    Output:
        X:: float array [n_beats, longest_beat] of zero-padded signal windows
        y:: array [n_beats, 3] of (patient, patient HR, class); numpy coerces
            the mixed rows to a string dtype
        isolated_beat:: list of unpadded rows [patient, HR, class, *samples]

    Raises:
        ValueError if no beat passes the class and HR filters.
    """
    n_meta = 3  # every beat row starts with [patient, HR, class]

    # Accepted window lengths in samples at this fs: a faster HR means a
    # shorter beat, so max_HR gives the lower bound and min_HR the upper.
    min_len = hr_sample_len(max_HR, fs=fs)
    max_len = hr_sample_len(min_HR, fs=fs)

    isolated_beat = []
    start = time.time()
    print('Examining {} patients...'.format(len(patients)))

    for patient in tqdm(patients):
        ecg_signal, ecg_notes = get_patient_data(data_dir, patient)
        signal_values = np.asarray(ecg_signal)
        peaklist = ecg_notes.sample_num.astype(int).values
        # The HR depends only on this patient's peaks, so compute it once.
        patient_hr = get_HR(peaklist, fs=fs)

        for c in classes.values():
            # Build the membership set once per class, not once per peak.
            targets = _class_sample_locations(ecg_notes, c, classes_reducer)
            # First/last peaks are skipped: a window needs both neighbours.
            for n in range(1, len(peaklist) - 1):
                if peaklist[n] not in targets:
                    continue
                # Half the distance to each neighbouring annotation.
                delta1 = int(np.round((peaklist[n + 1] - peaklist[n]) / 2))
                delta2 = int(np.round((peaklist[n] - peaklist[n - 1]) / 2))
                peak_data = signal_values[peaklist[n] - delta2:peaklist[n] + delta1]
                if min_len <= len(peak_data) <= max_len:
                    isolated_beat.append([patient, patient_hr, c] + peak_data.tolist())

        if verbose:
            print('\nPatient {}...'.format(patient))
            print('Normalizing --> [0 1]')
            print('Patient HR', patient_hr)

    if not isolated_beat:
        raise ValueError(
            'No beats found for classes {} within HR bounds [{}, {}] bpm'.format(
                list(classes.values()), min_HR, max_HR))

    print('\nPadding...\n')
    isolated_beats = zero_pad(isolated_beat)
    X = isolated_beats[:, n_meta:].astype(float)
    y = isolated_beats[:, :n_meta]
    # Signal length of each beat, excluding the metadata fields.
    beat_lens = [len(beat) - n_meta for beat in isolated_beat]
    avg_samp = np.mean(beat_lens)
    print('\nAverage HR Sample Len: {:.2f} samples ({:.2f}s per beat)'.format(avg_samp, avg_samp / fs))
    print('Average HR: {:.2f} bpm'.format(y[:, 1].astype(float).mean()))

    if plot_figs:
        if len(patients) == 1:
            print('\n*****Error Will Arise with Plot because only single sample used*****\n')
        print('Plotting...\n')
        # catch_warnings restores the previous filters on exit, whereas
        # warnings.resetwarnings() would wipe every filter process-wide.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plt.figure(figsize=(20, 10))
            plt.subplot(121)
            sns.distplot(beat_lens, rug=True)
            plt.title('Heart Rate RR Width')
            plt.xlabel('HR Sample Interval Length')
            plt.subplot(122)
            sns.distplot(y[:, 1].astype(float), rug=True)
            plt.title('Heart Rate Distribution')
            plt.xlabel('HR [bpm]')
            plt.show()

    print('Data Loaded | Shape:{}\n'.format(isolated_beats.shape))
    for label, count in Counter(y[:, 2].tolist()).items():
        print('    {} cases of {}\n'.format(count, classes_further[label]))
    print('{:.2f}min Runtime'.format((time.time() - start) / 60))
    return X, y, isolated_beat

def resample_vals(X,samp_len=187):
    """
    Resample every row of ``X`` to ``samp_len`` points with scipy's resample.

    Input:
        X:: 2-D array, one signal per row
        samp_len:: (optional) target number of samples per row (default 187)
    Output:
        float array of shape [X.shape[0], samp_len]
    """
    out = np.zeros((X.shape[0], samp_len))
    # Iterating a 2-D array yields row views, so writing through `dst`
    # fills `out` in place.
    for dst, src in zip(out, X):
        dst[:] = resample(src, samp_len)
    return out


def get_data(data_dir,out_dir,resample_len = None):
    """
    Build the beat dataset and pickle it to ``<out_dir>/mitbih.pkl``.

    Every ``*.csv`` file in ``data_dir`` is treated as one patient record.
    Beats are isolated into the five classes of RELEVANT_CLASS_IDS_TO_NAMES
    (via CLASSES_REDUCER) and are optionally resampled to a fixed length.

    Input:
        data_dir:: directory containing the record ``.csv`` and
            ``annotations.txt`` files
        out_dir:: output directory, created if missing
        resample_len:: (optional) if given, every beat is resampled to this
            many samples
    Output:
        None. Writes the tuple (X, y, isolated_beat) with pickle.
    """
    # Sorted so patient (and therefore beat row) order is reproducible;
    # iterating a set of strings depends on the per-process hash seed.
    patient_names = sorted({
        os.path.splitext(name)[0]
        for name in listdir(data_dir)
        if name.endswith('.csv') and isfile(join(data_dir, name))
    })
    print(patient_names)
    X, y, isolated_beat = isolate_patient_data(
        data_dir=data_dir, patients=patient_names,
        classes=RELEVANT_CLASS_IDS_TO_NAMES, classes_further=ALL_MITBIHCLASSES,
        classes_reducer=CLASSES_REDUCER, min_HR=40, max_HR=140, fs=360,
        verbose=False, plot_figs=False)

    print("X shape:", X.shape)
    print("y shape:", y.shape)
    print("sample y vals:: [patient#, HR, Condition Class]:", y[0])

    os.makedirs(out_dir, exist_ok=True)

    if resample_len is not None:
        X = resample_vals(X, samp_len=resample_len)

    with open(os.path.join(out_dir, 'mitbih.pkl'), 'wb') as fout:
        pkl.dump((X, y, isolated_beat), fout)

if __name__ == "__main__":
    # Script entry point: build the dataset from the default directories,
    # resampling every beat to 187 samples.
    get_data(DATADIR, OUTDIR, resample_len=187)
