import os
import numpy as np
import pandas as pd
import os
import joblib
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from utils.timefeatures import time_features
import warnings

warnings.filterwarnings('ignore')


class Dataset_ETT_hour(Dataset):
	"""Sliding-window dataset over an hourly ETT-style CSV file.

	The CSV must contain a ``date`` column. For ``features`` 'M'/'MS' every
	column after the first is used as data; for 'S' only ``target`` is used.
	Rows are split by fixed counts (12*30*24 train, then 4*30*24 validation,
	then 4*30*24 test). The validation and test slices start ``seq_len``
	rows early so their first window has a complete input history.
	"""

	def __init__(self, root_path, flag='train', size=None,
			features='S', data_path='ETTh1.csv',
			target='OT', scale=True, timeenc=0, freq='h'):
		"""Store the configuration and load the requested split.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of 'train', 'val', 'test'.
			size: ``[seq_len, label_len, pred_len]``; ``[384, 96, 96]``
				when None.
			features: 'M' or 'MS' (all value columns) or 'S' (target only).
			data_path: CSV file name inside ``root_path``.
			target: column selected when ``features == 'S'``.
			scale: if true, standardize with statistics of the train rows.
			timeenc: 0 for integer calendar fields, 1 for ``time_features``.
			freq: frequency string forwarded to ``time_features``.
		"""
		# size [seq_len, label_len, pred_len]
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]

		assert flag in ['train', 'test', 'val']
		type_map = {'train': 0, 'val': 1, 'test': 2}
		self.set_type = type_map[flag]

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq

		self.root_path = root_path
		self.data_path = data_path
		self.__read_data__()

	def _time_marks(self, dates):
		"""Encode a Series of dates into the per-row time-feature matrix.

		Returns an int64 array with columns month, day, weekday, hour when
		``timeenc == 0``; otherwise the transposed ``time_features`` output.

		Raises:
			ValueError: if ``timeenc`` is neither 0 nor 1.
		"""
		stamps = pd.to_datetime(dates)
		if self.timeenc == 0:
			# Build a fresh frame from the .dt accessors: no chained
			# assignment on a slice and no positional ``axis`` argument
			# (``DataFrame.drop(labels, 1)`` fails on pandas >= 2.0).
			fields = pd.DataFrame({
				'month': stamps.dt.month,
				'day': stamps.dt.day,
				'weekday': stamps.dt.weekday,
				'hour': stamps.dt.hour,
			})
			return fields.values.astype(np.int64)
		if self.timeenc == 1:
			marks = time_features(pd.to_datetime(stamps.values), freq=self.freq)
			return marks.transpose(1, 0)
		raise ValueError(f"timeenc must be 0 or 1, got {self.timeenc!r}")

	def __read_data__(self):
		"""Read the CSV, optionally scale it and keep this split's rows.

		Raises:
			ValueError: if ``features`` is not 'M', 'MS' or 'S'.
		"""
		self.scaler = StandardScaler()
		df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

		train_rows = 12 * 30 * 24
		eval_rows = 4 * 30 * 24
		# Start offsets of val/test are pulled back by seq_len rows.
		border1s = [0, train_rows - self.seq_len, train_rows + eval_rows - self.seq_len]
		border2s = [train_rows, train_rows + eval_rows, train_rows + 2 * eval_rows]
		border1 = border1s[self.set_type]
		border2 = border2s[self.set_type]

		if self.features in ('M', 'MS'):
			df_data = df_raw[df_raw.columns[1:]]
		elif self.features == 'S':
			df_data = df_raw[[self.target]]
		else:
			raise ValueError(f"features must be 'M', 'MS' or 'S', got {self.features!r}")

		if self.scale:
			# Fit on the training rows only, then transform every row.
			train_data = df_data[border1s[0]:border2s[0]]
			self.scaler.fit(train_data.values)
			data = self.scaler.transform(df_data.values)
		else:
			data = df_data.values

		self.data_stamp = self._time_marks(df_raw['date'].iloc[border1:border2])
		self.data_x = data[border1:border2]
		self.data_y = data[border1:border2]

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for one window.

		``seq_x`` covers ``seq_len`` rows starting at ``index``; ``seq_y``
		covers the last ``label_len`` rows of that input plus the following
		``pred_len`` rows. The ``*_mark`` arrays are the matching time features.
		"""
		s_begin = index
		s_end = s_begin + self.seq_len
		r_begin = s_end - self.label_len
		r_end = r_begin + self.label_len + self.pred_len

		seq_x = self.data_x[s_begin:s_end]
		seq_y = self.data_y[r_begin:r_end]
		seq_x_mark = self.data_stamp[s_begin:s_end]
		seq_y_mark = self.data_stamp[r_begin:r_end]

		return seq_x, seq_y, seq_x_mark, seq_y_mark

	def __len__(self):
		"""Number of windows that fit inside this split."""
		return len(self.data_x) - self.seq_len - self.pred_len + 1

	def inverse_transform(self, data):
		"""Undo the standardization applied by the fitted scaler."""
		return self.scaler.inverse_transform(data)


class Dataset_ETT_minute(Dataset):
	"""Sliding-window dataset over a 15-minute ETT-style CSV file.

	The CSV must contain a ``date`` column. For ``features`` 'M'/'MS' every
	column after the first is used as data; for 'S' only ``target`` is used.
	Rows are split by fixed counts (12*30*24*4 train, then 4*30*24*4
	validation, then 4*30*24*4 test). The validation and test slices start
	``seq_len`` rows early so their first window has a complete history.
	"""

	def __init__(self, root_path, flag='train', size=None,
			features='S', data_path='ETTm1.csv',
			target='OT', scale=True, timeenc=0, freq='t'):
		"""Store the configuration and load the requested split.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of 'train', 'val', 'test'.
			size: ``[seq_len, label_len, pred_len]``; ``[384, 96, 96]``
				when None.
			features: 'M' or 'MS' (all value columns) or 'S' (target only).
			data_path: CSV file name inside ``root_path``.
			target: column selected when ``features == 'S'``.
			scale: if true, standardize with statistics of the train rows.
			timeenc: 0 for integer calendar fields, 1 for ``time_features``.
			freq: frequency string forwarded to ``time_features``.
		"""
		# size [seq_len, label_len, pred_len]
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]

		assert flag in ['train', 'test', 'val']
		type_map = {'train': 0, 'val': 1, 'test': 2}
		self.set_type = type_map[flag]

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq

		self.root_path = root_path
		self.data_path = data_path
		self.__read_data__()

	def _time_marks(self, dates):
		"""Encode a Series of dates into the per-row time-feature matrix.

		Returns an int64 array with columns month, day, weekday, hour and
		quarter-hour index (``minute // 15``) when ``timeenc == 0``;
		otherwise the transposed ``time_features`` output.

		Raises:
			ValueError: if ``timeenc`` is neither 0 nor 1.
		"""
		stamps = pd.to_datetime(dates)
		if self.timeenc == 0:
			# Build a fresh frame from the .dt accessors: no chained
			# assignment on a slice and no positional ``axis`` argument
			# (``DataFrame.drop(labels, 1)`` fails on pandas >= 2.0).
			fields = pd.DataFrame({
				'month': stamps.dt.month,
				'day': stamps.dt.day,
				'weekday': stamps.dt.weekday,
				'hour': stamps.dt.hour,
				'minute': stamps.dt.minute // 15,
			})
			return fields.values.astype(np.int64)
		if self.timeenc == 1:
			marks = time_features(pd.to_datetime(stamps.values), freq=self.freq)
			return marks.transpose(1, 0)
		raise ValueError(f"timeenc must be 0 or 1, got {self.timeenc!r}")

	def __read_data__(self):
		"""Read the CSV, optionally scale it and keep this split's rows.

		Raises:
			ValueError: if ``features`` is not 'M', 'MS' or 'S'.
		"""
		self.scaler = StandardScaler()
		df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

		train_rows = 12 * 30 * 24 * 4
		eval_rows = 4 * 30 * 24 * 4
		# Start offsets of val/test are pulled back by seq_len rows.
		border1s = [0, train_rows - self.seq_len, train_rows + eval_rows - self.seq_len]
		border2s = [train_rows, train_rows + eval_rows, train_rows + 2 * eval_rows]
		border1 = border1s[self.set_type]
		border2 = border2s[self.set_type]

		if self.features in ('M', 'MS'):
			df_data = df_raw[df_raw.columns[1:]]
		elif self.features == 'S':
			df_data = df_raw[[self.target]]
		else:
			raise ValueError(f"features must be 'M', 'MS' or 'S', got {self.features!r}")

		if self.scale:
			# Fit on the training rows only, then transform every row.
			train_data = df_data[border1s[0]:border2s[0]]
			self.scaler.fit(train_data.values)
			data = self.scaler.transform(df_data.values)
		else:
			data = df_data.values

		self.data_stamp = self._time_marks(df_raw['date'].iloc[border1:border2])
		self.data_x = data[border1:border2]
		self.data_y = data[border1:border2]

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for one window.

		``seq_x`` covers ``seq_len`` rows starting at ``index``; ``seq_y``
		covers the last ``label_len`` rows of that input plus the following
		``pred_len`` rows. The ``*_mark`` arrays are the matching time features.
		"""
		s_begin = index
		s_end = s_begin + self.seq_len
		r_begin = s_end - self.label_len
		r_end = r_begin + self.label_len + self.pred_len

		seq_x = self.data_x[s_begin:s_end]
		seq_y = self.data_y[r_begin:r_end]
		seq_x_mark = self.data_stamp[s_begin:s_end]
		seq_y_mark = self.data_stamp[r_begin:r_end]

		return seq_x, seq_y, seq_x_mark, seq_y_mark

	def __len__(self):
		"""Number of windows that fit inside this split."""
		return len(self.data_x) - self.seq_len - self.pred_len + 1

	def inverse_transform(self, data):
		"""Undo the standardization applied by the fitted scaler."""
		return self.scaler.inverse_transform(data)


class Dataset_Custom(Dataset):
	"""Sliding-window dataset over an arbitrary CSV with a ``date`` column.

	Columns are reordered to ``['date', ...other features, target]``. Rows
	are split 70% train / 20% test with the remainder used for validation;
	the validation and test slices start ``seq_len`` rows early so their
	first window has a complete input history. First differences of the
	split's data are stored in ``vel_x`` / ``vel_y`` with their per-column
	mean and standard deviation in ``vel_mean`` / ``vel_std``.
	"""

	def __init__(self, root_path, flag='train', size=None,
			features='S', data_path='ETTh1.csv',
			target='OT', scale=True, timeenc=0, freq='h'):
		"""Store the configuration and load the requested split.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of 'train', 'val', 'test'.
			size: ``[seq_len, label_len, pred_len]``; ``[384, 96, 96]``
				when None.
			features: 'M' or 'MS' (all value columns) or 'S' (target only).
			data_path: CSV file name inside ``root_path``.
			target: name of the target column (moved to the last position).
			scale: if true, standardize with statistics of the train rows.
			timeenc: 0 for integer calendar fields, 1 for ``time_features``.
			freq: frequency string forwarded to ``time_features``.
		"""
		# size [seq_len, label_len, pred_len]
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]

		assert flag in ['train', 'test', 'val']
		type_map = {'train': 0, 'val': 1, 'test': 2}
		self.set_type = type_map[flag]

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq

		self.root_path = root_path
		self.data_path = data_path
		self.__read_data__()

	def _time_marks(self, dates):
		"""Encode a Series of dates into the per-row time-feature matrix.

		Returns an int64 array with columns month, day, weekday, hour when
		``timeenc == 0``; otherwise the transposed ``time_features`` output.

		Raises:
			ValueError: if ``timeenc`` is neither 0 nor 1.
		"""
		stamps = pd.to_datetime(dates)
		if self.timeenc == 0:
			# Build a fresh frame from the .dt accessors: no chained
			# assignment on a slice and no positional ``axis`` argument
			# (``DataFrame.drop(labels, 1)`` fails on pandas >= 2.0).
			fields = pd.DataFrame({
				'month': stamps.dt.month,
				'day': stamps.dt.day,
				'weekday': stamps.dt.weekday,
				'hour': stamps.dt.hour,
			})
			return fields.values.astype(np.int64)
		if self.timeenc == 1:
			marks = time_features(pd.to_datetime(stamps.values), freq=self.freq)
			return marks.transpose(1, 0)
		raise ValueError(f"timeenc must be 0 or 1, got {self.timeenc!r}")

	def __read_data__(self):
		"""Read the CSV, optionally scale it and keep this split's rows.

		Raises:
			ValueError: if ``features`` is not 'M', 'MS' or 'S', or if the
				``target`` / ``date`` column is missing (from ``list.remove``).
		"""
		self.scaler = StandardScaler()
		df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))

		# Reorder to: ['date', ...(other features), target feature]
		other_cols = list(df_raw.columns)
		other_cols.remove(self.target)
		other_cols.remove('date')
		df_raw = df_raw[['date'] + other_cols + [self.target]]

		num_train = int(len(df_raw) * 0.7)
		num_test = int(len(df_raw) * 0.2)
		num_vali = len(df_raw) - num_train - num_test
		# Start offsets of val/test are pulled back by seq_len rows.
		border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
		border2s = [num_train, num_train + num_vali, len(df_raw)]
		border1 = border1s[self.set_type]
		border2 = border2s[self.set_type]

		if self.features in ('M', 'MS'):
			df_data = df_raw[df_raw.columns[1:]]
		elif self.features == 'S':
			df_data = df_raw[[self.target]]
		else:
			raise ValueError(f"features must be 'M', 'MS' or 'S', got {self.features!r}")

		if self.scale:
			# Fit on the training rows only, then transform every row.
			train_data = df_data[border1s[0]:border2s[0]]
			self.scaler.fit(train_data.values)
			data = self.scaler.transform(df_data.values)
		else:
			data = df_data.values

		self.data_stamp = self._time_marks(df_raw['date'].iloc[border1:border2])
		self.data_x = data[border1:border2]
		self.data_y = data[border1:border2]

		# First differences along time; row 0 has no predecessor and stays 0.
		vel_x = np.zeros_like(self.data_x)
		vel_x[1:] = self.data_x[1:] - self.data_x[:-1]
		vel_y = np.zeros_like(self.data_y)
		vel_y[1:] = self.data_y[1:] - self.data_y[:-1]
		self.vel_x = vel_x
		self.vel_y = vel_y

		diff = data[border1 + 1:border2] - data[border1:border2 - 1]
		self.vel_mean = np.mean(diff, axis=0)
		self.vel_std = np.std(diff, axis=0)

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for one window.

		``seq_x`` covers ``seq_len`` rows starting at ``index``; ``seq_y``
		covers the last ``label_len`` rows of that input plus the following
		``pred_len`` rows. The ``*_mark`` arrays are the matching time features.
		"""
		s_begin = index
		s_end = s_begin + self.seq_len
		r_begin = s_end - self.label_len
		r_end = r_begin + self.label_len + self.pred_len

		seq_x = self.data_x[s_begin:s_end]
		seq_y = self.data_y[r_begin:r_end]
		seq_x_mark = self.data_stamp[s_begin:s_end]
		seq_y_mark = self.data_stamp[r_begin:r_end]

		return seq_x, seq_y, seq_x_mark, seq_y_mark

	def __len__(self):
		"""Number of windows that fit inside this split."""
		return len(self.data_x) - self.seq_len - self.pred_len + 1

	def inverse_transform(self, data):
		"""Undo the standardization applied by the fitted scaler."""
		return self.scaler.inverse_transform(data)


def data_time_split(trajs, dates, input_time_step, output_time_step, label_shift):
	"""Cut a trajectory into overlapping (input, output) windows with stride 1.

	Every window starts one step after the previous one. The input part
	covers ``input_time_step`` steps. The output part starts ``label_shift``
	steps before the end of the input and runs to the end of the window, so
	it has ``output_time_step + label_shift`` steps. Velocities are first
	differences of ``trajs`` along axis 0 (the first row is zero).

	Args:
		trajs: array-like whose first axis is time; trailing axes are
			flattened into one feature axis.
		dates: array-like aligned with ``trajs`` along the first axis.
		input_time_step: number of steps in each input window.
		output_time_step: number of steps after the input in each window.
		label_shift: number of trailing input steps repeated at the start
			of the output window.

	Returns:
		``(x_traj, y_traj, x_date, y_date, x_vel, y_vel)``, each of shape
		``(n_windows, steps, features)``. When the series is too short for
		a single window the arrays keep that rank with ``n_windows == 0``,
		so they can still be concatenated with results of other series.

	Raises:
		ValueError: if ``label_shift`` exceeds ``input_time_step`` (the
			output window would start before the input window).
	"""
	if label_shift > input_time_step:
		raise ValueError(
			f"label_shift ({label_shift}) must not exceed "
			f"input_time_step ({input_time_step})")

	trajs = np.asarray(trajs)
	dates = np.asarray(dates)
	vels = np.zeros_like(trajs)
	vels[1:] = trajs[1:] - trajs[:-1]

	window = input_time_step + output_time_step
	n_windows = max(len(trajs) - window + 1, 0)

	# Row indices of every window: one row of indices per window start.
	starts = np.arange(n_windows)[:, None]
	in_idx = starts + np.arange(input_time_step)
	out_idx = starts + np.arange(input_time_step - label_shift, window)

	def _gather(arr, idx):
		# Flatten trailing axes to a single feature axis, then pick rows.
		flat = arr.reshape(len(arr), int(np.prod(arr.shape[1:])))
		return flat[idx]

	x_traj, y_traj = _gather(trajs, in_idx), _gather(trajs, out_idx)
	x_date, y_date = _gather(dates, in_idx), _gather(dates, out_idx)
	x_vel, y_vel = _gather(vels, in_idx), _gather(vels, out_idx)

	return x_traj, y_traj, x_date, y_date, x_vel, y_vel


class Dataset_Thor(Dataset):
	"""Windowed trajectory dataset read from a joblib pickle.

	The pickle holds a ``'data'`` list of records with ``body_id``,
	``run_id`` and ``traj``. Records are assigned to train/val/test by their
	``body{body_id}_run{run_id}`` tag. Windows are pre-computed at load
	time and stored un-normalized; use ``normalize`` / ``denormalize`` on
	batches. Statistics come from the training records only.
	"""

	def __init__(self, root_path, flag='train', size=None,
			features='M', data_path='thor_data_prep.pkl',
			target='OT', scale=True, timeenc=0, freq='h', pred_diff=False, frame_skip=5):
		"""Store the configuration and build all windows of the split.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of 'train', 'val', 'test'.
			size: ``[seq_len, label_len, pred_len]``; ``[384, 96, 96]``
				when None.
			features, target, scale, timeenc: stored on the instance but
				not read by this class.
			data_path: pickle file name inside ``root_path``.
			freq: frequency string forwarded to ``time_features``.
			pred_diff: if true, batches are normalized as step differences
				and ``seq_len`` is increased by one so that ``seq_len - 1``
				differences remain.
			frame_skip: keep every ``frame_skip``-th frame of a trajectory.
		"""
		# size [seq_len, label_len, pred_len]
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]
		if pred_diff:
			self.seq_len += 1

		assert flag in ['train', 'test', 'val']
		train_flags = ['body0_run0', 'body0_run1', 'body0_run2',]
		val_flags = ['body1_run0', 'body1_run1']
		# NOTE(review): 'body1_run1' is listed in both val_flags and
		# test_flags - confirm that this overlap is intended.
		test_flags = ['body1_run2', 'body1_run1', 'body3_run2', 'body3_run1', 'body9_run1', 'body9_run2']
		type_map = {'train': train_flags, 'val': val_flags, 'test': test_flags}
		self.set_type = type_map[flag]
		self.flag = flag
		self.train_flags = train_flags

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq
		self.frame_skip = frame_skip  # original fps=100

		self.root_path = root_path
		self.data_path = data_path

		self.pred_diff = pred_diff

		self.__read_data__()

	def get_mean_std(self, data_list_raw):
		"""Compute per-feature statistics over the training records.

		Sets ``self.stats_pt`` (float tensors of shape ``(1, 1, D)``) and
		``self.train_traj`` (all sub-sampled training frames stacked).

		Returns:
			dict with numpy arrays ``traj_mean``, ``traj_std``,
			``vel_mean``, ``vel_std``, each of shape ``(1, D)``.
		"""
		total_traj = []
		total_vels = []
		for raw in data_list_raw:
			mark = f"body{raw['body_id']}_run{raw['run_id']}"
			if mark not in self.train_flags:
				continue
			traj = np.array(raw['traj'])[::self.frame_skip]
			total_traj.append(traj)
			total_vels.append(traj[1:] - traj[:-1])
		total_traj = np.concatenate(total_traj, axis=0)
		total_vels = np.concatenate(total_vels, axis=0)
		stats = {
			'traj_mean': np.mean(total_traj, axis=0, keepdims=True),
			'traj_std': np.std(total_traj, axis=0, keepdims=True),
			'vel_mean': np.mean(total_vels, axis=0, keepdims=True),
			'vel_std': np.std(total_vels, axis=0, keepdims=True),
		}
		self.stats_pt = {key: torch.tensor(val).unsqueeze(0).float() for key, val in stats.items()}
		self.train_traj = total_traj
		return stats

	def get_full_traj(self):
		"""Return every record's sub-sampled frames stacked along axis 0."""
		# joblib.load unpickles arbitrary objects: only load trusted files.
		pkl = joblib.load(os.path.join(self.root_path, self.data_path))
		total_traj = [np.array(raw['traj'])[::self.frame_skip] for raw in pkl['data']]
		return np.concatenate(total_traj, axis=0)

	def __read_data__(self):
		"""Load the pickle and build all windows of the selected split.

		Raises:
			ValueError: if no selected trajectory yields a window.
		"""
		# joblib.load unpickles arbitrary objects: only load trusted files.
		pkl = joblib.load(os.path.join(self.root_path, self.data_path))
		data_list_raw = pkl['data']
		# Called for its side effects (self.stats_pt, self.train_traj).
		self.get_mean_std(data_list_raw)

		x_trajs, y_trajs = [], []
		x_vels, y_vels = [], []
		x_dates, y_dates = [], []
		t0 = 946684800  # epoch seconds of the first synthetic timestamp
		for raw in data_list_raw:
			mark = f"body{raw['body_id']}_run{raw['run_id']}"
			if mark not in self.set_type:
				continue
			traj = np.array(raw['traj'][::self.frame_skip])
			if len(traj) == 0:
				continue
			# Synthetic clock: one frame per minute; consecutive
			# trajectories are separated by an extra day.
			ds = np.arange(len(traj)) * 60
			data_stamp = time_features(pd.to_datetime(t0 + ds, unit='s'), freq=self.freq)
			data_stamp = data_stamp.transpose(1, 0)
			t0 = t0 + ds[-1] + 60 * 60 * 24

			x_traj, y_traj, x_date, y_date, x_vel, y_vel = data_time_split(
				traj, data_stamp, self.seq_len, self.pred_len, self.label_len)
			if len(x_traj) == 0:
				# Too short for one window; its empty result could not be
				# concatenated with the windows of other trajectories.
				continue

			x_trajs.append(x_traj)
			y_trajs.append(y_traj)
			x_vels.append(x_vel)
			y_vels.append(y_vel)
			x_dates.append(x_date)
			y_dates.append(y_date)

		if not x_trajs:
			raise ValueError(
				f"no windows could be built for flag {self.flag!r} from {self.data_path!r}")

		self.data_x = np.concatenate(x_trajs, axis=0)
		self.data_y = np.concatenate(y_trajs, axis=0)
		self.stamp_x = np.concatenate(x_dates, axis=0)
		self.stamp_y = np.concatenate(y_dates, axis=0)

		self.x_vels = np.concatenate(x_vels, axis=0)
		self.y_vels = np.concatenate(y_vels, axis=0)

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` of one window."""
		return (self.data_x[index], self.data_y[index],
				self.stamp_x[index], self.stamp_y[index])

	def __len__(self):
		"""Number of pre-computed windows."""
		return len(self.data_x)

	def normalize(self, data_x, data_y):
		"""Standardize a batch; tensors are indexed (batch, time, feature).

		The input and the last ``pred_len`` steps of ``data_y`` are joined
		into one trajectory, kept in ``self.tmp_full_traj`` for a later
		``denormalize`` call. Without ``pred_diff`` positions are
		standardized; with it, step differences are standardized instead.

		Returns:
			``(inp_x, inp_y)`` - the input part and the label+prediction
			part of the standardized sequence.
		"""
		full_traj = torch.cat([data_x, data_y[:, -self.pred_len:, :]], dim=1)
		self.tmp_full_traj = full_traj
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			data = (full_traj - mn) / std
			inp_x = data[:, :self.seq_len, :]
			inp_y = data[:, self.seq_len - self.label_len:, :]
		else:
			full_speed = full_traj[:, 1:, ] - full_traj[:, :-1, ]
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			data = (full_speed - mn) / std
			# Differencing drops one step, hence the "- 1" offsets.
			inp_x = data[:, :self.seq_len - 1, :]
			inp_y = data[:, self.seq_len - self.label_len - 1:, :]
		return inp_x, inp_y

	def denormalize(self, data_x, data_y, output, u=None):
		"""Map standardized tensors back to trajectory coordinates.

		With ``pred_diff`` the inputs are treated as standardized step
		differences and integrated with ``cumsum``, starting from the first
		frame cached by the most recent ``normalize`` call.

		Args:
			data_x, data_y, output: standardized tensors.
			u: optional tensor that is only rescaled by the std.

		Returns:
			``(inp_x, inp_y, out)``, plus ``new_u`` when ``u`` is given.
		"""
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			inp_x = data_x * std + mn
			inp_y = data_y * std + mn
			out = output * std + mn
		else:
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			vel_x = data_x * std + mn
			vel_y = data_y * std + mn
			vel_out = output * std + mn
			x0 = self.tmp_full_traj[:, :1, :]

			inp_x = torch.cumsum(vel_x, dim=1) + x0
			# NOTE(review): both sums start from the last input position;
			# if data_y still contains the label_len overlap this base is
			# label_len steps too late - confirm what callers pass.
			inp_y = torch.cumsum(vel_y, dim=1) + inp_x[:, -1:, :]
			out = torch.cumsum(vel_out, dim=1) + inp_x[:, -1:, :]
		if u is not None:
			new_u = u.to(device) * std
			return inp_x, inp_y, out, new_u

		return inp_x, inp_y, out


class Dataset_Assembly(Dataset):
	"""Windowed trajectory dataset read from a joblib pickle.

	The pickle is a list of records with ``task``, ``human``, ``trial`` and
	``traj``. Records are assigned to train/val/test by their
	``task{task}_human{human}_trial{trial}`` tag, and only the first
	``n_dim`` (12) trajectory columns are used. Windows are pre-computed at
	load time and stored un-normalized; use ``normalize`` / ``denormalize``
	on batches. Statistics come from the training records only.
	"""

	def __init__(self, root_path, flag='train', size=None,
			features='M', data_path='human_assembly.pkl',
			target='OT', scale=True, timeenc=0, freq='h', pred_diff=False, frame_skip=1):
		"""Store the configuration and build all windows of the split.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of 'train', 'val', 'test'.
			size: ``[seq_len, label_len, pred_len]``; ``[384, 96, 96]``
				when None.
			features, target, scale, timeenc: stored on the instance but
				not read by this class.
			data_path: pickle file name inside ``root_path``.
			freq: frequency string forwarded to ``time_features``.
			pred_diff: if true, batches are normalized as step differences
				and ``seq_len`` is increased by one so that ``seq_len - 1``
				differences remain.
			frame_skip: keep every ``frame_skip``-th frame of a trajectory.
		"""
		# size [seq_len, label_len, pred_len]
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]
		if pred_diff:
			self.seq_len += 1

		assert flag in ['train', 'test', 'val']
		train_flags = ['task1_human1_trial1', 'task2_human1_trial1', 'task3_human1_trial1', ]
		val_flags = ['task1_human1_trial2',]
		# NOTE(review): 'task3_human1_trial1' is listed in both train_flags
		# and test_flags - confirm that this overlap is intended.
		test_flags = ['task3_human1_trial2', 'task3_human1_trial1', 'task4_human1_trial1',
				'task5_human1_trial1', 'task5_human1_trial2', ]
		# task3_human2_trial2 is Nan
		type_map = {'train': train_flags, 'val': val_flags, 'test': test_flags}
		self.set_type = type_map[flag]
		self.flag = flag
		self.train_flags = train_flags

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq
		self.frame_skip = frame_skip

		self.root_path = root_path
		self.data_path = data_path

		self.pred_diff = pred_diff
		self.n_dim = 12  # number of leading trajectory columns that are used

		self.__read_data__()

	def get_mean_std(self, data_list_raw):
		"""Compute per-feature statistics over the training records.

		Sets ``self.stats_pt`` (float tensors of shape ``(1, 1, D)``) and
		``self.train_traj`` (all sub-sampled training frames stacked).

		Returns:
			dict with numpy arrays ``traj_mean``, ``traj_std``,
			``vel_mean``, ``vel_std``, each of shape ``(1, D)``.
		"""
		total_traj = []
		total_vels = []
		for raw in data_list_raw:
			mark = f"task{raw['task']}_human{raw['human']}_trial{raw['trial']}"
			if mark not in self.train_flags:
				continue
			traj = np.array(raw['traj'])[:, :self.n_dim][::self.frame_skip]
			total_traj.append(traj)
			total_vels.append(traj[1:] - traj[:-1])
		total_traj = np.concatenate(total_traj, axis=0)
		total_vels = np.concatenate(total_vels, axis=0)
		stats = {
			'traj_mean': np.mean(total_traj, axis=0, keepdims=True),
			'traj_std': np.std(total_traj, axis=0, keepdims=True),
			'vel_mean': np.mean(total_vels, axis=0, keepdims=True),
			'vel_std': np.std(total_vels, axis=0, keepdims=True),
		}
		self.stats_pt = {key: torch.tensor(val).unsqueeze(0).float() for key, val in stats.items()}
		self.train_traj = total_traj
		return stats

	def get_full_traj(self):
		"""Return every record's sub-sampled frames stacked along axis 0.

		NOTE(review): unlike the other methods this keeps all columns, not
		only the first ``n_dim`` - confirm that this is intended.
		"""
		# joblib.load unpickles arbitrary objects: only load trusted files.
		pkl = joblib.load(os.path.join(self.root_path, self.data_path))
		total_traj = [np.array(raw['traj'])[::self.frame_skip] for raw in pkl]
		return np.concatenate(total_traj, axis=0)

	def __read_data__(self):
		"""Load the pickle and build all windows of the selected split.

		Raises:
			ValueError: if no selected trajectory yields a window.
		"""
		# joblib.load unpickles arbitrary objects: only load trusted files.
		data_list_raw = joblib.load(os.path.join(self.root_path, self.data_path))
		# Called for its side effects (self.stats_pt, self.train_traj).
		self.get_mean_std(data_list_raw)

		x_trajs, y_trajs = [], []
		x_vels, y_vels = [], []
		x_dates, y_dates = [], []
		t0 = 946684800  # epoch seconds of the first synthetic timestamp
		for raw in data_list_raw:
			mark = f"task{raw['task']}_human{raw['human']}_trial{raw['trial']}"
			if mark not in self.set_type:
				continue
			traj = np.array(raw['traj'])[:, :self.n_dim][::self.frame_skip]
			if len(traj) == 0:
				continue
			# Synthetic clock: one frame per minute; consecutive
			# trajectories are separated by an extra day.
			ds = np.arange(len(traj)) * 60
			data_stamp = time_features(pd.to_datetime(t0 + ds, unit='s'), freq=self.freq)
			data_stamp = data_stamp.transpose(1, 0)
			t0 = t0 + ds[-1] + 60 * 60 * 24

			x_traj, y_traj, x_date, y_date, x_vel, y_vel = data_time_split(
				traj, data_stamp, self.seq_len, self.pred_len, self.label_len)
			if len(x_traj) == 0:
				# Too short for one window; its empty result could not be
				# concatenated with the windows of other trajectories.
				continue

			x_trajs.append(x_traj)
			y_trajs.append(y_traj)
			x_vels.append(x_vel)
			y_vels.append(y_vel)
			x_dates.append(x_date)
			y_dates.append(y_date)

		if not x_trajs:
			raise ValueError(
				f"no windows could be built for flag {self.flag!r} from {self.data_path!r}")

		self.data_x = np.concatenate(x_trajs, axis=0)
		self.data_y = np.concatenate(y_trajs, axis=0)
		self.stamp_x = np.concatenate(x_dates, axis=0)
		self.stamp_y = np.concatenate(y_dates, axis=0)

		self.x_vels = np.concatenate(x_vels, axis=0)
		self.y_vels = np.concatenate(y_vels, axis=0)

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` of one window."""
		return (self.data_x[index], self.data_y[index],
				self.stamp_x[index], self.stamp_y[index])

	def __len__(self):
		"""Number of pre-computed windows."""
		return len(self.data_x)

	def normalize(self, data_x, data_y):
		"""Standardize a batch; tensors are indexed (batch, time, feature).

		The input and the last ``pred_len`` steps of ``data_y`` are joined
		into one trajectory, kept in ``self.tmp_full_traj`` for a later
		``denormalize`` call. Without ``pred_diff`` positions are
		standardized; with it, step differences are standardized instead.

		Returns:
			``(inp_x, inp_y)`` - the input part and the label+prediction
			part of the standardized sequence.
		"""
		full_traj = torch.cat([data_x, data_y[:, -self.pred_len:, :]], dim=1)
		self.tmp_full_traj = full_traj
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			data = (full_traj - mn) / std
			inp_x = data[:, :self.seq_len, :]
			inp_y = data[:, self.seq_len - self.label_len:, :]
		else:
			full_speed = full_traj[:, 1:, ] - full_traj[:, :-1, ]
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			data = (full_speed - mn) / std
			# Differencing drops one step, hence the "- 1" offsets.
			inp_x = data[:, :self.seq_len - 1, :]
			inp_y = data[:, self.seq_len - self.label_len - 1:, :]
		return inp_x, inp_y

	def denormalize(self, data_x, data_y, output, u=None):
		"""Map standardized tensors back to trajectory coordinates.

		With ``pred_diff`` the inputs are treated as standardized step
		differences and integrated with ``cumsum``, starting from the first
		frame cached by the most recent ``normalize`` call.

		Args:
			data_x, data_y, output: standardized tensors.
			u: optional tensor that is only rescaled by the std.

		Returns:
			``(inp_x, inp_y, out)``, plus ``new_u`` when ``u`` is given.
		"""
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			inp_x = data_x * std + mn
			inp_y = data_y * std + mn
			out = output * std + mn
		else:
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			vel_x = data_x * std + mn
			vel_y = data_y * std + mn
			vel_out = output * std + mn
			x0 = self.tmp_full_traj[:, :1, :]

			inp_x = torch.cumsum(vel_x, dim=1) + x0
			# NOTE(review): both sums start from the last input position;
			# if data_y still contains the label_len overlap this base is
			# label_len steps too late - confirm what callers pass.
			inp_y = torch.cumsum(vel_y, dim=1) + inp_x[:, -1:, :]
			out = torch.cumsum(vel_out, dim=1) + inp_x[:, -1:, :]
		if u is not None:
			new_u = u.to(device) * std
			return inp_x, inp_y, out, new_u

		return inp_x, inp_y, out


class Dataset_NGSIM(Dataset):
	"""Windowed trajectory dataset read from a joblib pickle.

	The pickle maps the split names 'train', 'valid' and 'test' to dicts
	whose ``'traj'`` entry is a list of trajectories. Windows are
	pre-computed at load time and stored un-normalized; use ``normalize`` /
	``denormalize`` on batches. Statistics come from the first 20 training
	trajectories, while the train split itself uses only the first 10.
	"""

	def __init__(self, root_path, flag='train', size=None,
			features='M', data_path='vehicle_ngsim.pkl',
			target='OT', scale=True, timeenc=0, freq='h', pred_diff=False, frame_skip=1):
		"""Store the configuration and build all windows of the split.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of 'train', 'val', 'test'.
			size: ``[seq_len, label_len, pred_len]``; ``[384, 96, 96]``
				when None.
			features, target, scale, timeenc: stored on the instance but
				not read by this class.
			data_path: pickle file name inside ``root_path``.
			freq: frequency string forwarded to ``time_features``.
			pred_diff: if true, batches are normalized as step differences
				and ``seq_len`` is increased by one so that ``seq_len - 1``
				differences remain.
			frame_skip: keep every ``frame_skip``-th frame of a trajectory.
		"""
		# size [seq_len, label_len, pred_len]
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]
		if pred_diff:
			self.seq_len += 1

		assert flag in ['train', 'test', 'val']
		# Maps the public flag to the key used inside the pickle.
		type_map = {'train': 'train', 'val': 'valid', 'test': 'test'}
		self.set_type = type_map[flag]
		self.flag = flag

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq
		self.frame_skip = frame_skip

		self.root_path = root_path
		self.data_path = data_path
		self.pred_diff = pred_diff
		self.__read_data__()

	def get_mean_std(self, data_list_raw):
		"""Compute per-feature statistics over the given trajectories.

		Sets ``self.stats_pt`` (float tensors of shape ``(1, 1, D)``) and
		``self.train_traj`` (all sub-sampled frames stacked).

		Returns:
			dict with numpy arrays ``traj_mean``, ``traj_std``,
			``vel_mean``, ``vel_std``, each of shape ``(1, D)``.
		"""
		total_traj = []
		total_vels = []
		for raw in data_list_raw:
			traj = np.array(raw)[::self.frame_skip]
			total_traj.append(traj)
			total_vels.append(traj[1:] - traj[:-1])
		total_traj = np.concatenate(total_traj, axis=0)
		total_vels = np.concatenate(total_vels, axis=0)
		stats = {
			'traj_mean': np.mean(total_traj, axis=0, keepdims=True),
			'traj_std': np.std(total_traj, axis=0, keepdims=True),
			'vel_mean': np.mean(total_vels, axis=0, keepdims=True),
			'vel_std': np.std(total_vels, axis=0, keepdims=True),
		}
		self.stats_pt = {key: torch.tensor(val).unsqueeze(0).float() for key, val in stats.items()}
		self.train_traj = total_traj
		return stats

	def get_full_traj(self):
		"""Return all training trajectories' sub-sampled frames stacked."""
		# joblib.load unpickles arbitrary objects: only load trusted files.
		pkl = joblib.load(os.path.join(self.root_path, self.data_path))
		# Each list item is the trajectory itself (as in get_mean_std and
		# __read_data__), not a record with a 'traj' key.
		total_traj = [np.array(raw)[::self.frame_skip] for raw in pkl['train']['traj']]
		return np.concatenate(total_traj, axis=0)

	def __read_data__(self):
		"""Load the pickle and build all windows of the selected split.

		Raises:
			ValueError: if no selected trajectory yields a window.
		"""
		# joblib.load unpickles arbitrary objects: only load trusted files.
		pkl = joblib.load(os.path.join(self.root_path, self.data_path))
		data_list_raw = pkl[self.set_type]['traj']
		# Called for its side effects (self.stats_pt, self.train_traj).
		self.get_mean_std(pkl['train']['traj'][:20])

		x_trajs, y_trajs = [], []
		x_dates, y_dates = [], []
		t0 = 946684800  # epoch seconds of the first synthetic timestamp
		if self.flag == 'train':
			data_list_raw = data_list_raw[:10]
		if self.flag in ['val', 'test']:
			# Fixed clock offsets for the val and test splits.
			# NOTE(review): the 88 / 6 / 400 constants presumably encode
			# trajectory counts and length - confirm against the data.
			t0 = t0 + 88 * (400 * 60 + 60 * 60 * 24)
			if self.flag == 'test':
				t0 = t0 + 6 * (400 * 60 + 60 * 60 * 24)
		for raw in data_list_raw:
			traj = raw[::self.frame_skip]
			if len(traj) == 0:
				continue
			# Synthetic clock: one frame per minute; consecutive
			# trajectories are separated by an extra day.
			ds = np.arange(len(traj)) * 60
			data_stamp = time_features(pd.to_datetime(t0 + ds, unit='s'), freq=self.freq)
			data_stamp = data_stamp.transpose(1, 0)
			t0 = t0 + ds[-1] + 60 * 60 * 24

			x_traj, y_traj, x_date, y_date, _, _ = data_time_split(
				traj, data_stamp, self.seq_len, self.pred_len, self.label_len)
			if len(x_traj) == 0:
				# Too short for one window; its empty result could not be
				# concatenated with the windows of other trajectories.
				continue

			x_trajs.append(x_traj)
			y_trajs.append(y_traj)
			x_dates.append(x_date)
			y_dates.append(y_date)

		if not x_trajs:
			raise ValueError(
				f"no windows could be built for flag {self.flag!r} from {self.data_path!r}")

		self.data_x = np.concatenate(x_trajs, axis=0)
		self.data_y = np.concatenate(y_trajs, axis=0)
		self.stamp_x = np.concatenate(x_dates, axis=0)
		self.stamp_y = np.concatenate(y_dates, axis=0)

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` of one window."""
		return (self.data_x[index], self.data_y[index],
				self.stamp_x[index], self.stamp_y[index])

	def __len__(self):
		"""Number of pre-computed windows."""
		return len(self.data_x)

	def normalize(self, data_x, data_y):
		"""Standardize a batch; tensors are indexed (batch, time, feature).

		The input and the last ``pred_len`` steps of ``data_y`` are joined
		into one trajectory, kept in ``self.tmp_full_traj`` for a later
		``denormalize`` call. Without ``pred_diff`` positions are
		standardized; with it, step differences are standardized instead.

		Returns:
			``(inp_x, inp_y)`` - the input part and the label+prediction
			part of the standardized sequence.
		"""
		full_traj = torch.cat([data_x, data_y[:, -self.pred_len:, :]], dim=1)
		self.tmp_full_traj = full_traj
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			data = (full_traj - mn) / std
			inp_x = data[:, :self.seq_len, :]
			inp_y = data[:, self.seq_len - self.label_len:, :]
		else:
			full_speed = full_traj[:, 1:, ] - full_traj[:, :-1, ]
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			data = (full_speed - mn) / std
			# Differencing drops one step, hence the "- 1" offsets.
			inp_x = data[:, :self.seq_len - 1, :]
			inp_y = data[:, self.seq_len - self.label_len - 1:, :]
		return inp_x, inp_y

	def denormalize(self, data_x, data_y, output, u=None):
		"""Map standardized tensors back to trajectory coordinates.

		With ``pred_diff`` the inputs are treated as standardized step
		differences and integrated with ``cumsum``, starting from the first
		frame cached by the most recent ``normalize`` call.

		Args:
			data_x, data_y, output: standardized tensors.
			u: optional tensor that is only rescaled by the std.

		Returns:
			``(inp_x, inp_y, out)``, plus ``new_u`` when ``u`` is given.
		"""
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			inp_x = data_x * std + mn
			inp_y = data_y * std + mn
			out = output * std + mn
		else:
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			vel_x = data_x * std + mn
			vel_y = data_y * std + mn
			vel_out = output * std + mn
			x0 = self.tmp_full_traj[:, :1, :]

			inp_x = torch.cumsum(vel_x, dim=1) + x0
			# NOTE(review): both sums start from the last input position;
			# if data_y still contains the label_len overlap this base is
			# label_len steps too late - confirm what callers pass.
			inp_y = torch.cumsum(vel_y, dim=1) + inp_x[:, -1:, :]
			out = torch.cumsum(vel_out, dim=1) + inp_x[:, -1:, :]
		if u is not None:
			new_u = u.to(device) * std
			return inp_x, inp_y, out, new_u

		return inp_x, inp_y, out

class Dataset_Kinova(Dataset):
	"""Windowed trajectory dataset read from a joblib pickle of recorded trials.

	Every record of the pickle is read through its ``task``, ``trial`` and
	``traj`` entries. A record belongs to a split according to its mark
	``test{task}_{trial}``. Trajectories are subsampled every ``frame_skip``
	frames, restricted to the columns listed in ``val_inds`` and cut into
	windows by ``data_time_split``. The stored windows are NOT normalized;
	``normalize``/``denormalize`` apply the training-split statistics to
	batched tensors.
	"""

	def __init__(self, root_path, flag='train', size=None,
	             features='M', data_path='kinova_monitor.pkl',
	             target='OT', scale=True, timeenc=0, freq='h', pred_diff=False,frame_skip=40 ):
		"""Configure the split and immediately load the data.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of ``'train'``, ``'val'``, ``'test'``.
			size: ``[seq_len, label_len, pred_len]``; defaults are used when None.
			features, target, scale, timeenc: stored on the instance only.
			freq: frequency string forwarded to ``time_features``.
			pred_diff: when true, normalize/denormalize work on first differences.
			frame_skip: keep every ``frame_skip``-th frame of each trajectory.
		"""
		# `is None` instead of `== None`: the latter is evaluated elementwise
		# (and then fails in `if`) when `size` is a numpy array.
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]
		if pred_diff:
			# Differencing drops one step (normalize() returns seq_len - 1
			# history steps), so one extra input frame is requested.
			self.seq_len += 1
		# init
		assert flag in ['train', 'test', 'val']

		train_flags = ['test1_1', 'test1_2',]
		val_flags = ['test2_1']
		test_flags = ['test2_2', 'test3_1', 'test3_2','test4_1', 'test4_2']
		type_map = {'train': train_flags, 'val': val_flags, 'test': test_flags}
		self.set_type = type_map[flag]
		self.flag=flag
		self.train_flags = train_flags

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq
		self.frame_skip=frame_skip # original fps=1000

		self.root_path = root_path
		self.data_path = data_path

		self.pred_diff = pred_diff
		# Columns of the raw trajectory that are kept.
		self.val_inds = [1,3,5]

		self.__read_data__()

	@staticmethod
	def _mark(raw):
		"""Return the split identifier ``test{task}_{trial}`` of a record."""
		return f"test{raw['task']}_{raw['trial']}"

	def _extract_traj(self, raw):
		"""Return the record's trajectory as an array, subsampled and column-filtered."""
		# Convert first so that list-based trajectories support 2-D indexing.
		traj = np.asarray(raw['traj'])
		return traj[::self.frame_skip][:, self.val_inds]

	def get_mean_std(self, data_list_raw):
		"""Compute position/velocity statistics over the training trials.

		Side effects: sets ``self.stats_pt`` (float tensors with one extra
		leading axis, for use on batches) and ``self.train_traj`` (all
		training positions concatenated).

		Returns:
			dict with numpy arrays ``traj_mean``, ``traj_std``, ``vel_mean``
			and ``vel_std`` (computed over axis 0 with ``keepdims=True``).
		"""
		train_trajs = [
			self._extract_traj(raw)
			for raw in data_list_raw
			if self._mark(raw) in self.train_flags
		]
		total_traj = np.concatenate(train_trajs, axis=0)
		# Velocities are frame-to-frame differences within each trial.
		total_vels = np.concatenate([t[1:] - t[:-1] for t in train_trajs], axis=0)

		stats = {
			'traj_mean': np.mean(total_traj, axis=0, keepdims=True),
			'traj_std': np.std(total_traj, axis=0, keepdims=True),
			'vel_mean': np.mean(total_vels, axis=0, keepdims=True),
			'vel_std': np.std(total_vels, axis=0, keepdims=True),
		}
		self.stats_pt = {
			key: torch.tensor(value).unsqueeze(0).float()
			for key, value in stats.items()
		}
		self.train_traj = total_traj
		return stats

	def get_full_traj(self):
		"""Reload the pickle and return every trial's processed trajectory, concatenated."""
		# NOTE(security): joblib.load unpickles arbitrary objects - only use
		# it on trusted files.
		data_list_raw = joblib.load(os.path.join(self.root_path,
		                                         self.data_path))
		return np.concatenate(
			[self._extract_traj(raw) for raw in data_list_raw], axis=0)

	def __read_data__(self):
		"""Load the pickle and build the window arrays of the selected split.

		Sets ``data_x``, ``data_y``, ``stamp_x``, ``stamp_y``, ``x_vels`` and
		``y_vels`` by concatenating the per-trial outputs of
		``data_time_split``.

		Raises:
			ValueError: if no record matches the selected split.
		"""
		# NOTE(security): joblib.load unpickles arbitrary objects - only use
		# it on trusted files.
		data_list_raw = joblib.load(os.path.join(self.root_path,
		                                         self.data_path))
		# Called for its side effects (self.stats_pt, self.train_traj).
		self.get_mean_std(data_list_raw)

		x_trajs, y_trajs = [], []
		x_vels, y_vels = [], []
		x_dates, y_dates = [], []
		t0 = 946684800  # Unix time of 2000-01-01 00:00:00 UTC
		for raw in data_list_raw:
			if self._mark(raw) not in self.set_type:
				continue
			traj = self._extract_traj(raw)
			# Synthetic time axis: one stamp every 60 s, starting at t0.
			ds = np.arange(len(traj), dtype=np.int64) * 60
			dates = pd.to_datetime(t0 + ds, unit='s')
			data_stamp = time_features(dates, freq=self.freq)
			data_stamp = data_stamp.transpose(1, 0)
			x_traj, y_traj, x_date, y_date, x_vel, y_vel = data_time_split(
				traj, data_stamp, self.seq_len, self.pred_len, self.label_len)

			x_trajs.append(x_traj)
			y_trajs.append(y_traj)
			x_vels.append(x_vel)
			y_vels.append(y_vel)
			x_dates.append(x_date)
			y_dates.append(y_date)
			# The next trial starts one day after the end of this one.
			t0 = t0 + ds[-1] + 60 * 60 * 24

		if not x_trajs:
			raise ValueError(
				f"no trials matching {self.set_type} found in {self.data_path}")

		self.data_x = np.concatenate(x_trajs, axis=0)
		self.data_y = np.concatenate(y_trajs, axis=0)
		self.stamp_x = np.concatenate(x_dates, axis=0)
		self.stamp_y = np.concatenate(y_dates, axis=0)

		self.x_vels = np.concatenate(x_vels, axis=0)
		self.y_vels = np.concatenate(y_vels, axis=0)

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for window ``index``."""
		seq_x = self.data_x[index]
		seq_y = self.data_y[index]
		seq_x_mark = self.stamp_x[index]
		seq_y_mark = self.stamp_y[index]

		return seq_x, seq_y, seq_x_mark, seq_y_mark

	def __len__(self):
		"""Return the number of windows."""
		return len(self.data_x)

	def normalize(self,data_x,data_y):
		"""Standardize a batch with the training-split statistics.

		``data_x`` is concatenated along dim 1 with the last ``pred_len`` steps
		of ``data_y`` and cached in ``tmp_full_traj``. Positions (``traj_*``
		statistics) are standardized unless ``pred_diff`` is set, in which case
		first differences along dim 1 are standardized with ``vel_*``.

		Returns:
			``(inp_x, inp_y)``: the history part and the part starting
			``label_len`` steps before the end of the history.
		"""
		full_traj = torch.cat([data_x, data_y[:, -self.pred_len:, :]], dim=1)
		# Cached because denormalize() needs the first position again.
		self.tmp_full_traj = full_traj
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			data = (full_traj - mn) / std
			hist_len = self.seq_len
		else:
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			full_speed = full_traj[:, 1:] - full_traj[:, :-1]
			data = (full_speed - mn) / std
			# Differencing shortens the sequence by one step.
			hist_len = self.seq_len - 1
		inp_x = data[:, :hist_len, :]
		inp_y = data[:, hist_len - self.label_len:, :]
		return inp_x,inp_y

	def denormalize(self,data_x,data_y, output, u = None):
		"""Undo ``normalize`` for inputs, targets and model output.

		With ``pred_diff`` the tensors are un-scaled with the ``vel_*``
		statistics and cumulatively summed along dim 1; ``data_x`` is anchored
		at the first step of ``tmp_full_traj`` (set by the last ``normalize``
		call) and ``data_y``/``output`` at the last reconstructed step of
		``data_x``. Otherwise the ``traj_*`` statistics are applied directly.

		Args:
			u: optional tensor that is only multiplied by the std in use.

		Returns:
			``(inp_x, inp_y, out)``, plus the rescaled ``u`` when it is given.
		"""
		device = data_x.device
		if not self.pred_diff:
			mn = self.stats_pt['traj_mean'].to(device)
			std = self.stats_pt['traj_std'].to(device)
			inp_x = data_x * std + mn
			inp_y = data_y * std + mn
			out = output * std + mn
		else:
			mn = self.stats_pt['vel_mean'].to(device)
			std = self.stats_pt['vel_std'].to(device)
			x0 = self.tmp_full_traj[:, :1, :]

			inp_x = torch.cumsum(data_x * std + mn, dim=1) + x0
			last_x = inp_x[:, -1:, :]
			inp_y = torch.cumsum(data_y * std + mn, dim=1) + last_x
			out = torch.cumsum(output * std + mn, dim=1) + last_x
		if u is not None:
			new_u = u.to(device) * std
			return inp_x,inp_y, out, new_u

		return inp_x,inp_y, out


class Dataset_Ford(Dataset):
	"""Windowed trajectory dataset read from the prepared Ford joblib pickle.

	The pickle is read through its ``data`` (list of records with
	``task_id``, ``trial_id`` and ``traj``) and ``stats`` (with
	``traj_mean``) entries. Records are selected by the mark
	``{task_id}_{trial_id}``, centered with ``traj_mean``, scaled by 4 and cut
	into windows by ``data_time_split``.
	"""

	def __init__(self, root_path, flag='train', size=None,
	             features='S', data_path='ford_data_prep.pkl',
	             target='OT', scale=True, timeenc=0, freq='h'):
		"""Configure the split and immediately load the data.

		Args:
			root_path: directory that contains ``data_path``.
			flag: one of ``'train'``, ``'val'``, ``'test'``.
			size: ``[seq_len, label_len, pred_len]``; defaults are used when None.
			features, target, scale, timeenc: stored on the instance only.
			freq: frequency string forwarded to ``time_features``.
		"""
		# `is None` instead of `== None`: the latter is evaluated elementwise
		# (and then fails in `if`) when `size` is a numpy array.
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]
		# init
		assert flag in ['train', 'test', 'val']
		# NOTE(review): trial '2_1' is used for train AND val and is also part
		# of the test list - confirm that this overlap is intended.
		train_flags =['2_1']  #['body0_run0', 'body0_run1', 'body0_run2', 'body1_run0', 'body1_run1','body2_run0', 'body2_run1']
		val_flags = [ '2_1']
		test_flags = ['1_1','1_2', '1_3','2_1','2_2','2_3','3_1','3_2','3_3','4_1','4_2','4_3']#['body7_run1','body8_run1']  # ['body3_run0','body4_run0','body7_run0','body8_run0'],
		type_map = {'train': train_flags, 'val': val_flags, 'test': test_flags}
		self.set_type = type_map[flag]
		self.flag=flag

		self.features = features
		self.target = target
		self.scale = scale
		self.timeenc = timeenc
		self.freq = freq
		self.frame_skip=1

		self.root_path = root_path
		self.data_path = data_path
		self.__read_data__()

	def __read_data__(self):
		"""Load the pickle and build the window arrays of the selected split.

		Sets ``data_x``, ``data_y``, ``stamp_x`` and ``stamp_y`` by
		concatenating the per-trial outputs of ``data_time_split``.

		Raises:
			ValueError: if no record matches the selected split.
		"""
		# NOTE(security): joblib.load unpickles arbitrary objects - only use
		# it on trusted files.
		pkl = joblib.load(os.path.join(self.root_path,
		                               self.data_path))
		data_list_raw = pkl['data']
		traj_mean = pkl['stats']['traj_mean']

		x_trajs, y_trajs = [], []
		x_dates, y_dates = [], []
		t0 = 946684800  # Unix time of 2000-01-01 00:00:00 UTC
		for raw in data_list_raw:
			mark = f"{raw['task_id']}_{raw['trial_id']}"
			if mark not in self.set_type:
				continue
			traj = raw['traj'][::self.frame_skip]
			# Only centered and scaled by a fixed factor; no division by std.
			norm_traj = (traj - traj_mean) * 4
			# Synthetic time axis: one stamp every 60 s, starting at t0.
			ds = np.arange(len(traj), dtype=np.int64) * 60
			dates = pd.to_datetime(t0 + ds, unit='s')
			data_stamp = time_features(dates, freq=self.freq)
			data_stamp = data_stamp.transpose(1, 0)
			# The velocity outputs of data_time_split are not used here.
			x_traj, y_traj, x_date, y_date, _, _ = data_time_split(
				norm_traj, data_stamp, self.seq_len, self.pred_len, self.label_len)

			x_trajs.append(x_traj)
			y_trajs.append(y_traj)
			x_dates.append(x_date)
			y_dates.append(y_date)
			# The next trial starts one day after the end of this one.
			t0 = t0 + ds[-1] + 60 * 60 * 24

		if not x_trajs:
			raise ValueError(
				f"no trials matching {self.set_type} found in {self.data_path}")

		self.data_x = np.concatenate(x_trajs, axis=0)
		self.data_y = np.concatenate(y_trajs, axis=0)
		self.stamp_x = np.concatenate(x_dates, axis=0)
		self.stamp_y = np.concatenate(y_dates, axis=0)

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for window ``index``."""
		seq_x = self.data_x[index]
		seq_y = self.data_y[index]
		seq_x_mark = self.stamp_x[index]
		seq_y_mark = self.stamp_y[index]

		return seq_x, seq_y, seq_x_mark, seq_y_mark

	def __len__(self):
		"""Return the number of windows."""
		return len(self.data_x)


class Dataset_Pred(Dataset):
	"""Prediction-time dataset: the last ``seq_len`` rows of a CSV file.

	The time stamps cover those rows plus ``pred_len`` future steps generated
	with ``freq``.
	"""

	def __init__(self, root_path, flag='pred', size=None,
	             features='S', data_path='ETTh1.csv',
	             target='OT', scale=True, inverse=False, timeenc=0, freq='15min', cols=None):
		"""Configure the dataset and immediately load the CSV.

		Args:
			root_path: directory that contains ``data_path``.
			flag: must be ``'pred'``.
			size: ``[seq_len, label_len, pred_len]``; defaults are used when None.
			features: ``'M'``/``'MS'`` (all non-date columns) or ``'S'`` (target only).
			target: name of the target column.
			scale: standardize the values with a ``StandardScaler``.
			inverse: keep ``data_y`` unscaled.
			timeenc: 0 for calendar fields, 1 for ``time_features``.
			freq: frequency of the generated future stamps.
			cols: optional explicit list of columns (must include the target).
		"""
		# `is None` instead of `== None`: the latter is evaluated elementwise
		# (and then fails in `if`) when `size` is a numpy array.
		if size is None:
			self.seq_len = 24 * 4 * 4
			self.label_len = 24 * 4
			self.pred_len = 24 * 4
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]
		# init
		assert flag in ['pred']

		self.features = features
		self.target = target
		self.scale = scale
		self.inverse = inverse
		self.timeenc = timeenc
		self.freq = freq
		self.cols = cols
		self.root_path = root_path
		self.data_path = data_path
		self.__read_data__()

	@staticmethod
	def _calendar_stamp(dates):
		"""Encode a datetime Series as month/day/weekday/hour/quarter-hour columns."""
		stamp = pd.DataFrame({
			'month': dates.apply(lambda ts: ts.month),
			'day': dates.apply(lambda ts: ts.day),
			'weekday': dates.apply(lambda ts: ts.weekday()),
			'hour': dates.apply(lambda ts: ts.hour),
			# Minutes are bucketed into quarter hours (0..3).
			'minute': dates.apply(lambda ts: ts.minute // 15),
		})
		return stamp.values

	def __read_data__(self):
		"""Read the CSV and set ``data_x``, ``data_y`` and ``data_stamp``.

		Raises:
			ValueError: for an unsupported ``features`` or ``timeenc`` value.
		"""
		self.scaler = StandardScaler()
		df_raw = pd.read_csv(os.path.join(self.root_path,
		                                  self.data_path))
		# Reorder to: ['date', ...(other features), target feature]
		if self.cols:
			cols = self.cols.copy()
			cols.remove(self.target)
		else:
			cols = list(df_raw.columns)
			cols.remove(self.target)
			cols.remove('date')
		df_raw = df_raw[['date'] + cols + [self.target]]
		# Only the last seq_len rows are used as model input.
		border1 = len(df_raw) - self.seq_len
		border2 = len(df_raw)

		if self.features in ('M', 'MS'):
			df_data = df_raw[df_raw.columns[1:]]
		elif self.features == 'S':
			df_data = df_raw[[self.target]]
		else:
			raise ValueError(f"unsupported features value: {self.features!r}")

		if self.scale:
			# In prediction mode the scaler is fitted on the whole file.
			self.scaler.fit(df_data.values)
			data = self.scaler.transform(df_data.values)
		else:
			data = df_data.values

		# Converting the slice directly avoids assigning into a view of df_raw.
		hist_dates = pd.to_datetime(df_raw['date'][border1:border2])
		pred_dates = pd.date_range(hist_dates.values[-1], periods=self.pred_len + 1, freq=self.freq)
		# pred_dates[0] repeats the last known date, so it is skipped.
		df_stamp = pd.DataFrame(
			{'date': list(hist_dates.values) + list(pred_dates[1:])})

		if self.timeenc == 0:
			data_stamp = self._calendar_stamp(df_stamp['date'])
		elif self.timeenc == 1:
			data_stamp = time_features(pd.to_datetime(df_stamp['date'].values), freq=self.freq)
			data_stamp = data_stamp.transpose(1, 0)
		else:
			raise ValueError(f"unsupported timeenc value: {self.timeenc!r}")

		self.data_x = data[border1:border2]
		if self.inverse:
			self.data_y = df_data.values[border1:border2]
		else:
			self.data_y = data[border1:border2]
		self.data_stamp = data_stamp

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for window ``index``.

		``seq_y`` only holds the ``label_len`` known steps, while
		``seq_y_mark`` also covers the ``pred_len`` future stamps.
		"""
		s_begin = index
		s_end = s_begin + self.seq_len
		r_begin = s_end - self.label_len
		r_end = r_begin + self.label_len + self.pred_len

		seq_x = self.data_x[s_begin:s_end]
		if self.inverse:
			seq_y = self.data_x[r_begin:r_begin + self.label_len]
		else:
			seq_y = self.data_y[r_begin:r_begin + self.label_len]
		seq_x_mark = self.data_stamp[s_begin:s_end]
		seq_y_mark = self.data_stamp[r_begin:r_end]

		return seq_x, seq_y, seq_x_mark, seq_y_mark

	def __len__(self):
		"""Return the number of full ``seq_len`` windows in ``data_x``."""
		return len(self.data_x) - self.seq_len + 1

	def inverse_transform(self, data):
		"""Map scaled values back with the fitted scaler."""
		return self.scaler.inverse_transform(data)


class Dataset_sin(Dataset):
	"""Synthetic dataset: 200 samples of a 5-period sine wave plus Gaussian noise."""

	def __init__(self, root_path=None, flag='train', size=None,
	             features='S', data_path=None,
	             target='OT', scale=True, timeenc=0, freq='h',
	             noise=0.01):
		"""Create the series.

		Args:
			size: ``[seq_len, label_len, pred_len]``; ``[5, 1, 5]`` when None.
			noise: factor applied to the standard-normal noise added to the sine.

		The remaining parameters are accepted for signature compatibility with
		the other datasets and are not used.
		"""
		# `is None` instead of `== None`: the latter is evaluated elementwise
		# (and then fails in `if`) when `size` is a numpy array.
		if size is None:
			self.seq_len = 5
			self.label_len = 1
			self.pred_len = 5
		else:
			self.seq_len = size[0]
			self.label_len = size[1]
			self.pred_len = size[2]
		self.noise = noise
		# init
		self.__read_data__()

	def __read_data__(self):
		"""Generate the noisy sine; sets ``data_x``, ``data_y`` and ``data_stamp``."""
		n = 200
		num_periods = 5
		x = np.arange(n)
		# Angular step such that num_periods full periods span the n samples.
		T = (2 * np.pi) / (n / num_periods)
		y = np.sin(x * T)
		# Drawn from the global numpy RNG; seed it outside for reproducibility.
		noise = np.random.normal(loc=0, scale=1, size=n) * self.noise
		y = (y + noise).reshape(n, 1)

		self.data_x = y
		self.data_y = y
		# The sample index itself serves as the time stamp.
		self.data_stamp = x

	def __getitem__(self, index):
		"""Return ``(seq_x, seq_y, seq_x_mark, seq_y_mark)`` for the window starting at ``index``."""
		s_begin = index
		s_end = s_begin + self.seq_len
		# The target overlaps the last label_len input steps.
		r_begin = s_end - self.label_len
		r_end = r_begin + self.label_len + self.pred_len

		seq_x = self.data_x[s_begin:s_end]
		seq_y = self.data_y[r_begin:r_end]
		seq_x_mark = self.data_stamp[s_begin:s_end]
		seq_y_mark = self.data_stamp[r_begin:r_end]

		return seq_x, seq_y, seq_x_mark, seq_y_mark

	def __len__(self):
		"""Return the number of windows with a full input and prediction span."""
		return len(self.data_x) - self.seq_len - self.pred_len + 1


