import pandas as pd
import numpy as np
import wfdb
import ast

import matplotlib.pyplot as plt
from collections import Counter

import neurokit2 as nk
from tqdm import tqdm
import sys
import os
import scipy

def load_raw_data(df, sampling_rate, path):
    """Read every record listed in ``df`` with wfdb and stack the signals.

    Args:
        df: Table exposing ``filename_lr`` and ``filename_hr`` columns that
            hold record names relative to ``path``.
        sampling_rate: ``100`` selects the ``filename_lr`` column; any other
            value selects ``filename_hr``.
        path: Directory the record names are resolved against. A trailing
            separator is optional.

    Returns:
        A numpy array built from the per-record signal arrays, in the row
        order of ``df``. The metadata returned by ``wfdb.rdsamp`` is dropped.
    """
    filenames = df.filename_lr if sampling_rate == 100 else df.filename_hr
    # rdsamp returns a (signal, metadata) pair; only the signal is kept.
    # os.path.join avoids a silently wrong path when `path` has no trailing
    # separator.
    signals = [wfdb.rdsamp(os.path.join(path, f))[0] for f in filenames]
    return np.array(signals)

def aggregate_diagnostic(y_dic):
    """Map the codes of one record to their distinct diagnostic classes.

    Relies on the module-level ``agg_df`` table, which must be indexed by
    code and expose a ``diagnostic_class`` column. Keys of ``y_dic`` that are
    not present in ``agg_df.index`` are ignored.

    Args:
        y_dic: Mapping whose keys are the codes of a single record; the
            values are not used.

    Returns:
        A list of the distinct diagnostic classes in sorted order, so that
        equal label sets always produce equal lists (plain ``list(set(...))``
        order varies between interpreter runs because of string hash
        randomization). Empty when no key is found in ``agg_df``.
    """
    classes = {
        agg_df.loc[key].diagnostic_class
        for key in y_dic
        if key in agg_df.index
    }
    # key=str keeps the sort total even if a non-string value slips in.
    return sorted(classes, key=str)


# Directory holding ptbxl_database.csv and scp_statements.csv; it is also
# passed to load_raw_data as the prefix for the record file names.
path = './'
# Passed to load_raw_data: 100 selects `filename_lr`, anything else
# `filename_hr`.
sampling_rate = 500

# Load the annotation table and parse the `scp_codes` column from its string
# form into Python objects.
Y = pd.read_csv(os.path.join(path, 'ptbxl_database.csv'), index_col='ecg_id')
Y.scp_codes = Y.scp_codes.apply(ast.literal_eval)

# Load raw signal data, one entry per row of Y and in the same order.
X = load_raw_data(Y, sampling_rate, path)

# Load scp_statements.csv for diagnostic aggregation and keep only the rows
# flagged as diagnostic; aggregate_diagnostic reads this table as a global.
agg_df = pd.read_csv(os.path.join(path, 'scp_statements.csv'), index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

# Apply diagnostic superclass
Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)

# Split data into train, validation and test by stratification fold. All
# three masks derive from the two fold constants, so changing a fold number
# cannot leak held-out records into the training set.
val_fold = 9
test_fold = 10
val_mask = (Y.strat_fold == val_fold).to_numpy()
test_mask = (Y.strat_fold == test_fold).to_numpy()
train_mask = ~(val_mask | test_mask)

# Train
X_train = X[train_mask]
y_train = Y.loc[train_mask, 'diagnostic_superclass']
# Val
X_val = X[val_mask]
y_val = Y.loc[val_mask, 'diagnostic_superclass']
# Test
X_test = X[test_mask]
y_test = Y.loc[test_mask, 'diagnostic_superclass']


def _select_single_label(signals, labels):
    """Keep only the records that carry exactly one label.

    Args:
        signals: Array indexed along its first axis by record.
        labels: pandas Series of label lists, aligned row-for-row with
            ``signals``.

    Returns:
        A ``(mask, kept_signals, kept_labels)`` tuple. ``mask`` is a list of
        booleans (True where the record has exactly one label),
        ``kept_signals`` is ``signals`` restricted to those records, and
        ``kept_labels`` is an array holding the single label of each kept
        record.
    """
    mask = [len(y) == 1 for y in labels.values]
    selector = np.asarray(mask, dtype=bool)
    # NOTE(review): each one-element label list is flattened to its single
    # label here — confirm this is the format downstream consumers expect.
    kept_labels = np.array([y[0] for y in labels.values[selector]])
    return mask, signals[selector], kept_labels


# Records with zero or several diagnostic superclasses are dropped. The mask
# is applied to signals and labels together so both stay aligned; previously
# the `y_*_single` names were never defined and the masks were never used.
train_index, X_train_single, y_train_single = _select_single_label(X_train, y_train)
val_index, X_val_single, y_val_single = _select_single_label(X_val, y_val)
test_index, X_test_single, y_test_single = _select_single_label(X_test, y_test)

train_dict = {'ECG_signal': X_train_single, 'label': y_train_single}
val_dict = {'ECG_signal': X_val_single, 'label': y_val_single}
test_dict = {'ECG_signal': X_test_single, 'label': y_test_single}

# np.save appends ".npy" and pickles the dict as an object array; read it
# back with np.load(..., allow_pickle=True).item().
np.save("train_path", train_dict)
np.save("val_path", val_dict)
np.save("test_path", test_dict)