# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import os
from math import isnan
import copy

def load_data(root_path, winsize, overlap):
    """Load every recording under ``Protocol`` and ``Optional`` and window it.

    Each file is read with ``np.loadtxt``; column 1 is used as the per-sample
    label and 27 feature columns (indices 7-15, 21-23, 27-32, 38-40, 44-49)
    are kept. The samples are then cut into windows by ``getwin_replace``.

    Args:
        root_path: Directory containing the ``Protocol`` and ``Optional``
            sub-directories.
        winsize: Number of samples per window.
        overlap: Fraction of ``winsize`` shared by consecutive windows.

    Returns:
        Tuple ``(x_all, y_all, s_all)`` with the windows, labels and subject
        ids of all files stacked along the first axis. Files are processed in
        ``os.listdir`` order, ``Protocol`` first. If no file produces a
        window, empty arrays are returned.
    """
    # List both folders up front so a missing directory fails before any
    # (slow) file parsing starts.
    sources = [(folder, filename)
               for folder in ('Protocol', 'Optional')
               for filename in os.listdir(os.path.join(root_path, folder))]

    x_parts, y_parts, s_parts = [], [], []
    for folder, filename in sources:
        data_i = np.loadtxt(os.path.join(root_path, folder, filename))
        # Subject id = text between the first '0' and the following '.'
        # (presumably names look like 'subject101.dat' -> 1; confirm).
        subject = int(filename.split('0')[1].split('.')[0])
        print('subject:', subject)
        x_i = np.hstack((data_i[:, 7:10], data_i[:, 10:16], data_i[:, 21:24],
                         data_i[:, 27:33], data_i[:, 38:41], data_i[:, 44:50]))
        y_i = data_i[:, 1]
        tx, ty, ts = getwin_replace(x_i, y_i, subject,
                                    winsize=winsize, overlap=overlap)
        if len(ty) == 0:
            # A recording without any usable window has nothing to stack and
            # may not share the windows' dimensionality.
            continue
        x_parts.append(tx)
        y_parts.append(ty)
        s_parts.append(ts)

    if not x_parts:
        return np.array([]), np.empty((0, 1)), np.empty((0, 1))
    # Concatenate once instead of re-copying the growing arrays per file.
    return (np.concatenate(x_parts, axis=0),
            np.concatenate(y_parts, axis=0),
            np.concatenate(s_parts, axis=0))


def getwin_replace(x, y, s, winsize, overlap):
    """Cut a recording into fixed-size, single-label windows.

    Samples labelled 0 are skipped. A window is kept only when all of its
    samples share one label; otherwise the window start jumps to the first
    sample of the next label run. NaN handling of each kept window is
    delegated to ``replace_nan``.

    Args:
        x: 2-D array of samples (rows) by features (columns).
        y: Per-sample labels, same length as ``x``.
        s: Subject id attached to every produced window.
        winsize: Number of samples per window.
        overlap: Fraction of ``winsize`` shared by consecutive windows.

    Returns:
        Tuple ``(windows, labels, subjects)`` shaped
        ``(n, winsize, n_features)``, ``(n, 1)`` and ``(n, 1)``. With no
        windows, ``n`` is 0 and the trailing dimensions are preserved.

    Raises:
        ValueError: If ``winsize`` and ``overlap`` give a non-positive step,
            which would prevent the window from ever advancing.
    """
    x = np.asarray(x)
    data_num = len(x)
    overlap_size = int(winsize*overlap)
    stepsize = winsize-overlap_size
    if stepsize <= 0:
        raise ValueError(
            'winsize=%r with overlap=%r gives a non-positive step (%r); '
            'use winsize > 0 and overlap < 1' % (winsize, overlap, stepsize))

    head, tail = 0, winsize
    xx, yy = [], []
    while tail <= data_num:
        # Skip over samples labelled 0.
        while head < data_num and y[head] == 0:
            head += 1
            tail = head + winsize
        if tail > data_num:
            break

        ry = np.unique(y[head:tail])
        if len(ry) == 1:
            xx.append(replace_nan(x[head:tail, :]))
            yy.append(y[head])
            head += stepsize
            tail += stepsize
        else:
            print('ry:', ry)
            # The window mixes labels, so a label change lies strictly inside
            # it and y[head+1] is always in range here. Restart the window at
            # the first sample of the next label run.
            while y[head] == y[head+1]:
                head += 1
            head += 1
            tail = head + winsize

    ss = np.ones(len(yy)) * s
    if xx:
        windows = np.array(xx)
    else:
        # Keep the trailing dimensions so empty results still stack with
        # non-empty ones from other recordings.
        windows = np.empty((0, int(winsize)) + x.shape[1:], dtype=x.dtype)
    return windows, np.array(yy).reshape(-1, 1), np.array(ss).reshape(-1, 1)


def replace_nan(x_win):
    """Return a copy of ``x_win`` with NaNs replaced by their column mean.

    The mean of each column is taken over its non-NaN entries only. A column
    consisting solely of NaNs has no mean and is left unchanged.

    The input array is not modified, so overlapping windows sliced from the
    same parent array are each imputed from their own samples rather than
    from values written by an earlier call.

    Args:
        x_win: 2-D array of samples (rows) by features (columns).

    Returns:
        A new array with the same shape and dtype as ``x_win``.
    """
    x_new = x_win.copy()
    for col in range(x_new.shape[1]):
        column = x_new[:, col]  # view into the copy, never into x_win
        nan_mask = np.isnan(column)
        if not nan_mask.any():
            continue
        valid = column[~nan_mask]
        if valid.size:
            column[nan_mask] = np.mean(valid)
    return x_new


def calculate_mean_value(x):
    """Return the mean of the entries of ``x`` that are not NaN.

    Args:
        x: Iterable of scalar numbers.

    Returns:
        The NumPy mean of the non-NaN entries (NaN when none remain).
    """
    kept = np.array([value for value in x if not isnan(value)])
    return np.mean(kept, axis=0)


def get_pamap_npy(root_path, save_path, winsize, overlap):
    """Window the raw recordings and store them in one ``.npz`` archive.

    The archive is written to ``save_path + 'pamap_processwin.npz'`` with the
    keys ``x`` (windows), ``y`` (labels) and ``s`` (subject ids). Because the
    path is built by plain string concatenation, ``save_path`` must already
    end with a path separator when it names a directory.

    Args:
        root_path: Dataset directory handed to ``load_data``.
        save_path: Prefix of the output archive path.
        winsize: Number of samples per window.
        overlap: Fraction of ``winsize`` shared by consecutive windows.
    """
    windows, labels, subjects = load_data(root_path, winsize, overlap)
    print("label num:", np.unique(labels))
    archive_path = save_path+'pamap_processwin.npz'
    np.savez(archive_path, x=windows, y=labels, s=subjects)
    print("npy saved")


def select_sub_act(x, y, s):
    """Keep windows of the selected subjects/activities and renumber labels.

    A window is kept when its subject id is in 1..9 and its label is one of
    the 18 ids in ``act_list``. The kept labels are then renumbered to the
    contiguous range 1..18, following the order of ``act_list``
    (e.g. 9 -> 8, 16 -> 13, 24 -> 18).

    Args:
        x: Windows, indexed by window along the first axis.
        y: Window labels, shape ``(n,)`` or ``(n, 1)``.
        s: Window subject ids, shape ``(n,)`` or ``(n, 1)``.

    Returns:
        Tuple ``(x_new, y_new, s_new)`` holding only the kept windows, with
        ``y_new`` carrying the renumbered labels.
    """
    sub_list = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    act_list = [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 24]

    x, y, s = np.asarray(x), np.asarray(y), np.asarray(s)
    # One boolean per window; ravel makes (n,) and (n, 1) layouts equivalent.
    keep = np.isin(np.ravel(s), sub_list) & np.isin(np.ravel(y), act_list)
    x_new, y_new, s_new = x[keep], y[keep], s[keep]

    # Renumber with boolean masks. Indexing with np.argwhere() on an (n, 1)
    # array would also hit row 0, because each result row is [row, column].
    # Masks are always computed on the untouched y_new so that an already
    # renumbered label is never renumbered a second time.
    y_nnew = y_new.copy()
    for new_label, raw_label in enumerate(act_list, start=1):
        if new_label != raw_label:
            y_nnew[y_new == raw_label] = new_label
    return x_new, y_nnew, s_new


def process_data_and_labels(hdir):
    """Turn the windowed archive into the final ``pamap_x``/``pamap_y`` files.

    Reads ``hdir + 'pamap_processwin.npz'``, filters and renumbers the windows
    with ``select_sub_act``, shifts labels and subject ids to start at 0 and
    writes:

    * ``hdir + 'pamap_x.npy'``: the windows with their last two axes swapped.
    * ``hdir + 'pamap_y.npy'``: three columns per window, namely the label,
      the subject id and a column of zeros.

    Args:
        hdir: Path prefix of the input archive and the output files. It is
            joined by plain string concatenation, so it must end with a path
            separator when it names a directory.
    """
    print("Processing data and labels")
    # Read each array exactly once and close the archive right away: every
    # data[...] access re-reads the array from the file.
    with np.load(hdir + 'pamap_processwin.npz') as data:
        x_raw, y_raw, s_raw = data['x'], data['y'], data['s']
    print('y.shape:', y_raw.shape)
    print("y label:", min(y_raw[:,0]), max(y_raw[:,0]))
    print("y nique", np.unique(y_raw[:,0], return_counts=True))
    x_new, y_new, s_new = select_sub_act(x_raw, y_raw, s_raw)
    print('y_new.shape:', y_new.shape)
    print("y_new label:", min(y_new), max(y_new))
    print("y_new nique", np.unique(y_new, return_counts=True))
    # Labels and subject ids are 1-based up to here; store them 0-based.
    s = s_new - 1
    label = y_new - 1
    print(min(label),max(label))
    print("x.shape:", x_new.shape)
    np.save(hdir + 'pamap_x.npy', x_new.transpose((0,2,1)))
    print(min(s),max(s))
    y = np.concatenate((label, s, np.zeros_like(label)), axis=1)
    print('y.shape:', y.shape)
    print("y label:", min(y[:,0]), max(y[:,0]))
    print("y nique", np.unique(y[:,0], return_counts=True))
    np.save(hdir + 'pamap_y.npy', y)


if __name__ == '__main__':
    # Raw label ids reported by get_pamap_npy on the full dataset:
    # label num: [ 1.  2.  3.  4.  5.  6.  7.  9. 10. 11. 12. 13. 16. 17. 18. 19. 20. 24.]
    dataset_dir = '../data/PAMAP/PAMAP2_Dataset/'
    output_dir = '../data/PAMAP/'
    # Step 1: cut the recordings into 200-sample windows with 50% overlap.
    get_pamap_npy(dataset_dir, output_dir, winsize=200, overlap=0.5)
    # Step 2: filter, renumber and export the final arrays.
    process_data_and_labels(output_dir)
