#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 19 15:50:14 2017

@author: ly
"""
import numpy as np
import scipy.io as sio
import os


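# IMPORTANT: check whether the raw labels are already normalized to the range 0-2.
# The loaders below add 1 to every label, so applying them to already-shifted
# labels would push the values out of range.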

def load_data(eeg_dir):
    """Load train/test EEG features and labels from a single .mat file.

    Reads ``train_de`` / ``test_de`` (features) and ``train_label_eeg`` /
    ``test_label_eeg`` (labels), stacks the train rows above the test rows,
    min-max normalizes every feature column over the combined rows, and adds
    1 to every label.

    Args:
        eeg_dir: path to the .mat file (despite the name, a file, not a
            directory).

    Returns:
        (eeg_data, label): the normalized features (train rows first, then
        test rows) and the shifted labels in the same row order. Feature
        columns that are constant across all rows are mapped to 0.
    """
    # Parse the file once rather than once per variable.
    mat = sio.loadmat(eeg_dir)
    eeg_data = np.concatenate((mat['train_de'], mat['test_de']), axis=0)

    col_min = np.min(eeg_data, axis=0)
    col_span = np.max(eeg_data, axis=0) - col_min
    # A constant column would otherwise produce 0/0 = NaN; divide by 1 instead
    # so it normalizes to 0.
    col_span = np.where(col_span == 0, 1, col_span)
    eeg_data = (eeg_data - col_min) / col_span

    # Shift labels by +1 (presumably -1/0/1 -> 0/1/2; confirm against the data).
    label = np.concatenate((mat['train_label_eeg'], mat['test_label_eeg']), axis=0) + 1

    return eeg_data, label

class DataHandler(object):
    """Build sliding-window EEG samples split into source and target subjects.

    Loads per-subject feature matrices from ``EEG_X.mat`` (variable ``X``) and
    labels from ``EEG_Y.mat`` (variable ``Y``), min-max normalizes each
    subject's features per column, and cuts every trial segment into
    overlapping windows of ``time_step`` consecutive rows. Subject ``num``
    becomes the target; all other subjects form the source.

    Attributes:
        src_data: array of shape (n_source_subjects, n_windows, time_step,
            n_features).
        src_label: array of shape (n_source_subjects, n_windows, 1).
        trg_data: array of shape (n_windows, time_step, n_features), or ``[]``
            if ``num`` does not index any subject.
        trg_label: array of shape (n_windows, 1), or ``[]`` likewise.
    """

    # Trial boundary row indices used when ``emotion_edge`` is not given; the
    # final boundary ``len(labels) - 1`` is appended at load time. These are
    # specific to the original recording layout.
    DEFAULT_EMOTION_EDGE = (0, 234, 467, 674, 911, 1096, 1291, 1528, 1744,
                            2010, 2246, 2481, 2714, 2949, 3187, 3393)

    def __init__(self, num, time_step=15, eeg_dir='', emotion_edge=None):
        """Load, normalize and window the data.

        Args:
            num: index of the subject used as the target domain.
            time_step: number of consecutive rows per window.
            eeg_dir: directory containing ``EEG_X.mat`` and ``EEG_Y.mat``
                (default: current directory).
            emotion_edge: increasing boundary row indices. Segment i covers
                rows ``emotion_edge[i] + 1 .. emotion_edge[i + 1]``. Defaults
                to ``DEFAULT_EMOTION_EDGE`` followed by ``len(labels) - 1``.

        Raises:
            ValueError: if a window spans rows with different labels, which
                means the segment boundaries do not match the labels.
        """
        xdata = sio.loadmat(os.path.join(eeg_dir, 'EEG_X.mat'))['X'].flatten()
        ydata = sio.loadmat(os.path.join(eeg_dir, 'EEG_Y.mat'))['Y'].flatten()
        self.src_data = []
        self.src_label = []
        self.trg_data = []
        self.trg_label = []

        # Only the first subject's labels are used, for every subject.
        # Shift by +1 (presumably -1/0/1 -> 0/1/2; confirm against the data).
        labels = ydata[0] + 1

        if emotion_edge is None:
            edges = list(self.DEFAULT_EMOTION_EDGE) + [len(labels) - 1]
        else:
            edges = list(emotion_edge)
        segments = list(zip(edges[:-1], edges[1:]))

        # Segment (start, end] yields windows whose first row runs from
        # start + 1 to end - time_step + 1, i.e. exactly as many windows as
        # range(start, end + 1 - time_step) iterates. Clamp at 0 for segments
        # shorter than one window. (The previous formula over-counted by one
        # per segment and left trailing all-zero windows labelled 0.)
        total_size = sum(max(0, end + 1 - time_step - start)
                         for start, end in segments)

        for subject, feats in enumerate(xdata):
            feats = self._minmax(feats)
            cur_data = np.zeros([total_size, time_step, feats.shape[1]])
            cur_label = np.zeros([total_size, 1])
            row = 0
            for start, end in segments:
                for j in range(start, end + 1 - time_step):
                    window_labels = labels[j + 1:j + time_step + 1, 0]
                    # Every row in a window must share one label; otherwise
                    # the segment boundaries are misaligned with the labels.
                    if np.any(window_labels != window_labels[0]):
                        raise ValueError(
                            'window starting at row %d spans multiple labels'
                            % (j + 1))
                    cur_data[row, :, :] = feats[j + 1:j + time_step + 1, :]
                    cur_label[row] = np.mean(window_labels)
                    row += 1
            if subject == num:
                self.trg_data = cur_data
                self.trg_label = cur_label
            else:
                self.src_data.append(cur_data)
                self.src_label.append(cur_label)

        self.src_data = np.array(self.src_data)
        self.src_label = np.array(self.src_label)

    @staticmethod
    def _minmax(data):
        """Scale each column of ``data`` to [0, 1]; constant columns become 0."""
        col_min = np.min(data, axis=0)
        col_span = np.max(data, axis=0) - col_min
        # Avoid 0/0 = NaN for constant columns.
        col_span = np.where(col_span == 0, 1, col_span)
        return (data - col_min) / col_span
