from data_provider.data_factory import data_provider
from exp.exp_basic import Exp_Basic
from exp.replay import ReplayBuffer
from exp.nrls import NRLS
from models import Informer, Autoformer, Transformer, DLinear, Linear, NLinear, MLP
from utils.tools import EarlyStopping, adjust_learning_rate, visual, test_params_flop
from utils.metrics import metric, MSE, MAE, CORR

import numpy as np
import torch
import torch.nn as nn
from torch import optim
import joblib

import os
import time

import warnings
import matplotlib.pyplot as plt
import numpy as np

# Install a global "ignore" filter: once this module is imported, every warning
# category is suppressed for the whole process.
warnings.filterwarnings('ignore')


class Exp_Main(Exp_Basic):
	def __init__(self, args):
		"""Initialise the experiment through ``Exp_Basic`` and store differencing settings.

		``diff_mean`` / ``diff_std`` are the statistics ``test`` uses to
		standardise first differences when ``pred_diff`` is enabled; by default
		they form the identity transform (mean 0, std 1).
		"""
		super(Exp_Main, self).__init__(args)
		# Flag read by ``test``: forecast first differences instead of raw values.
		self.pred_diff = args.pred_diff
		# Identity standardisation by default.
		self.diff_mean, self.diff_std = 0, 1

	def _build_model(self):
		"""Instantiate the network selected by ``self.args.model``.

		Returns:
			The model cast to float32, wrapped in ``nn.DataParallel`` when both
			``use_multi_gpu`` and ``use_gpu`` are set.

		Raises:
			KeyError: if ``self.args.model`` is not one of the registered names.
		"""
		registry = {
			'Autoformer': Autoformer,
			'Transformer': Transformer,
			'Informer': Informer,
			'DLinear': DLinear,
			'NLinear': NLinear,
			'Linear': Linear,
			'MLP': MLP,
		}
		# Every registered module exposes a ``Model`` class built from the args.
		net = registry[self.args.model].Model(self.args).float()

		wants_parallel = self.args.use_multi_gpu and self.args.use_gpu
		if not wants_parallel:
			return net
		return nn.DataParallel(net, device_ids=self.args.device_ids)

	def _get_data(self, flag):
		"""Return the ``(dataset, loader)`` pair that ``data_provider`` builds for ``flag``."""
		dataset, loader = data_provider(self.args, flag)
		return dataset, loader

	def _select_optimizer(self):
		model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
		return model_optim

	def _select_criterion(self):
		criterion = nn.MSELoss()
		return criterion

	def vali(self, vali_data, vali_loader, criterion):
		total_loss = []
		self.model.eval()
		with torch.no_grad():
			for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
				batch_x = batch_x.float().to(self.device)
				batch_y = batch_y.float().to(self.device)

				batch_x_mark = batch_x_mark.float().to(self.device)
				batch_y_mark = batch_y_mark.float().to(self.device)

				batch_x, batch_y = vali_data.normalize(batch_x, batch_y)


				# encoder - decoder
				if 'Linear' in self.args.model or 'MLP' in self.args.model:
					outputs = self.model(batch_x)
				else:
					# decoder input
					dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
					dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float()
					if self.args.output_attention:
						outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
					else:
						outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
				f_dim = -1 if self.args.features == 'MS' else 0
				outputs = outputs[:, -self.args.pred_len:, f_dim:]
				batch_y = batch_y[:, -self.args.pred_len:, f_dim:]

				batch_x, batch_y,outputs = vali_data.denormalize(batch_x, batch_y,outputs)

				pred = outputs.detach().cpu()
				true = batch_y.detach().cpu()

				loss = criterion(pred, true)

				total_loss.append(loss)
		total_loss = np.average(total_loss)
		self.model.train()
		return total_loss

	def train(self, setting):
		"""Fit the model on the training split and return it with checkpoint weights loaded.

		For every epoch the model is optimised on normalised batches, then
		evaluated with ``vali`` on the validation and test splits. The
		validation loss is handed to ``EarlyStopping`` (presumably to keep the
		best checkpoint under ``<checkpoints>/<setting>/`` -- confirm in
		``utils.tools``); the early-stop ``break`` itself is disabled here, so
		all ``train_epochs`` epochs always run. The learning rate is adjusted
		after each epoch.

		Args:
			setting: Run identifier used as the checkpoint sub-directory name.

		Returns:
			``self.model`` after loading ``checkpoint.pth`` from the checkpoint directory.
		"""
		train_data, train_loader = self._get_data(flag='train')
		vali_data, vali_loader = self._get_data(flag='val')
		test_data, test_loader = self._get_data(flag='test')

		ckpt_dir = os.path.join(self.args.checkpoints, setting)
		if not os.path.exists(ckpt_dir):
			os.makedirs(ckpt_dir)

		# Reference time for the periodic speed report.
		tick = time.time()

		steps_per_epoch = len(train_loader)
		early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

		model_optim = self._select_optimizer()
		criterion = self._select_criterion()

		cfg = self.args
		for epoch in range(cfg.train_epochs):
			since_report = 0
			batch_losses = []

			self.model.train()
			epoch_start = time.time()
			for step, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader, start=1):
				since_report += 1
				model_optim.zero_grad()
				batch_x = batch_x.float().to(self.device)

				batch_y = batch_y.float().to(self.device)
				batch_x_mark = batch_x_mark.float().to(self.device)
				batch_y_mark = batch_y_mark.float().to(self.device)

				batch_x, batch_y = train_data.normalize(batch_x, batch_y)

				if 'Linear' in cfg.model or 'MLP' in cfg.model:
					# These models only consume the input window.
					outputs = self.model(batch_x)
				else:
					# Decoder input: label_len known steps followed by zeros for the horizon.
					placeholder = torch.zeros_like(batch_y[:, -cfg.pred_len:, :]).float()
					dec_inp = torch.cat([batch_y[:, :cfg.label_len, :], placeholder], dim=1).float().to(self.device)
					if cfg.output_attention:
						outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
					else:
						# Without attention output the target window is passed as a fifth argument.
						outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, batch_y)

				# 'MS' keeps only the last channel; any other mode keeps all channels.
				first_ch = -1 if cfg.features == 'MS' else 0
				outputs = outputs[:, -cfg.pred_len:, first_ch:]
				batch_y = batch_y[:, -cfg.pred_len:, first_ch:].to(self.device)
				loss = criterion(outputs, batch_y)
				batch_losses.append(loss.item())

				if step % 100 == 0:
					# Progress report every 100 iterations, with a remaining-time estimate.
					print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(step, epoch + 1, loss.item()))
					speed = (time.time() - tick) / since_report
					left_time = speed * ((cfg.train_epochs - epoch) * steps_per_epoch - (step - 1))
					print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
					since_report = 0
					tick = time.time()

				loss.backward()
				model_optim.step()

			print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_start))
			train_loss = np.average(batch_losses)
			vali_loss = self.vali(vali_data, vali_loader, criterion)
			test_loss = self.vali(test_data, test_loader, criterion)

			print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
				epoch + 1, steps_per_epoch, train_loss, vali_loss, test_loss))
			# The validation loss is still reported to EarlyStopping, but its
			# ``early_stop`` flag is not acted on: training never breaks early.
			early_stopping(vali_loss, self.model, ckpt_dir)

			adjust_learning_rate(model_optim, epoch + 1, self.args)

		best_model_path = ckpt_dir + '/' + 'checkpoint.pth'
		self.model.load_state_dict(torch.load(best_model_path))

		return self.model

	def test(self, setting, test=0):
		"""Evaluate the model on the test split and persist plots, metrics and predictions.

		Outputs are written to ``./results/test/<setting>/``: a PDF plot for
		every 20th batch, a ``result.txt`` line with the metrics (appended), and
		``results.pkl`` with predictions, labels, inputs and the metric vector.

		Fixes over the previous version:
		  * with ``args.use_amp`` set, the secondary prediction ``out2`` was
		    never assigned and the method crashed with ``NameError``; both code
		    paths now share one forward helper;
		  * the checkpoint path is built with ``os.path.join`` like ``train`` and
		    ``predict`` do, so it no longer depends on ``args.checkpoints``
		    ending with a path separator;
		  * the result file is closed even if writing fails.

		NOTE(review): unlike ``train`` / ``vali``, the batches here are not passed
		through ``test_data.normalize`` -- confirm that this is intended.

		Args:
			setting: Run identifier (checkpoint sub-directory and results folder name).
			test: When truthy, load ``checkpoint.pth`` for ``setting`` before evaluating.

		Returns:
			None.
		"""
		test_data, test_loader = self._get_data(flag='test')
		pred_diff = self.pred_diff
		if test:
			print('loading model')
			self.model.load_state_dict(
				torch.load(os.path.join(self.args.checkpoints, setting, 'checkpoint.pth')))

		pred_len = self.args.pred_len
		direct_model = 'Linear' in self.args.model or 'MLP' in self.args.model
		# 'MS' keeps only the last channel; any other mode keeps all channels.
		f_dim = -1 if self.args.features == 'MS' else 0

		def forward_pair(batch_x, batch_x_mark, dec_inp, batch_y_mark):
			"""Return ``(outputs, out2)`` for one batch; ``out2`` is the secondary prediction."""
			if direct_model:
				if not pred_diff:
					outputs = self.model(batch_x)
				else:
					# Predict standardised first differences, then integrate them
					# starting from the last observed input value.
					batch_diff = torch.zeros_like(batch_x)
					batch_diff[:, 1:, :] = batch_x[:, 1:, :] - batch_x[:, :-1, :]
					batch_diff[:, 0, :] = batch_diff[:, 1, :]
					batch_diff = (batch_diff - self.diff_mean) / self.diff_std
					output_diff = self.model(batch_diff)
					output_diff = output_diff * self.diff_std + self.diff_mean
					outputs = torch.cumsum(output_diff, dim=1)
					init_pos = batch_x[:, -1:, :].repeat(1, outputs.size(1), 1)
					outputs = outputs + init_pos
				# For these models the secondary prediction is the same tensor.
				return outputs, outputs

			def run_once():
				result = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
				return result[0] if self.args.output_attention else result

			# Encoder-decoder models are run twice; the second pass feeds ``preds2``.
			return run_once(), run_once()

		preds = []
		preds2 = []
		trues = []
		inputx = []
		folder_path = './results/test/' + setting + '/'
		if not os.path.exists(folder_path):
			os.makedirs(folder_path)

		self.model.eval()
		with torch.no_grad():
			for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
				batch_x = batch_x.float().to(self.device)
				batch_y = batch_y.float().to(self.device)

				batch_x_mark = batch_x_mark.float().to(self.device)
				batch_y_mark = batch_y_mark.float().to(self.device)

				# Decoder input: label_len known steps followed by zeros for the horizon.
				dec_inp = torch.zeros_like(batch_y[:, -pred_len:, :]).float()
				dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

				if self.args.use_amp:
					with torch.cuda.amp.autocast():
						outputs, out2 = forward_pair(batch_x, batch_x_mark, dec_inp, batch_y_mark)
				else:
					outputs, out2 = forward_pair(batch_x, batch_x_mark, dec_inp, batch_y_mark)

				outputs = outputs[:, -pred_len:, f_dim:]
				batch_y = batch_y[:, -pred_len:, f_dim:].to(self.device)
				pred = outputs.detach().cpu().numpy()
				true = batch_y.detach().cpu().numpy()
				out2 = out2[:, -pred_len:, f_dim:].detach().cpu().numpy()
				input_np = batch_x.detach().cpu().numpy()

				preds.append(pred)
				trues.append(true)
				inputx.append(input_np)
				preds2.append(out2)
				if i % 20 == 0:
					# Plot the last channel of the first sample: history + truth vs history + prediction.
					gt = np.concatenate((input_np[0, :, -1], true[0, :, -1]), axis=0)
					pd = np.concatenate((input_np[0, :, -1], pred[0, :, -1]), axis=0)
					visual(gt, pd, os.path.join(folder_path, str(i) + '.pdf'))

		if self.args.test_flop:
			# Uses ``batch_x`` from the last loop iteration, then terminates the process.
			test_params_flop((batch_x.shape[1], batch_x.shape[2]))
			exit()

		# Join the per-batch arrays along the sample axis.
		preds = np.concatenate(preds, axis=0)
		trues = np.concatenate(trues, axis=0)
		inputx = np.concatenate(inputx, axis=0)
		preds2 = np.concatenate(preds2, axis=0)

		mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
		mae2, mse2, _, _, _, _, _ = metric(preds2, trues)
		print('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
		print('mse2:{}, mae2:{}'.format(mse2, mae2))
		with open(folder_path + "result.txt", 'a') as f:
			f.write(setting + "  \n")
			f.write('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
			f.write('\n')
			f.write('\n')

		joblib.dump({'pred': preds, 'label': trues, 'input': inputx,
					'metrics': np.array([mae, mse, rmse, mape, mspe, rse, corr])}, folder_path + 'results.pkl')
		return

	def predict(self, setting, load=False):
		"""Run the model over the 'pred' split and save the raw predictions.

		Predictions are written to ``./results/<setting>/real_prediction.npy``.
		The model outputs are saved as returned (no cropping to ``pred_len``).

		Per-batch outputs are joined with ``np.concatenate`` (as ``test`` does)
		instead of ``np.array``, so a smaller final batch no longer produces a
		ragged array; for equally sized batches the saved array is unchanged.

		Args:
			setting: Run identifier (checkpoint sub-directory and results folder name).
			load: When true, load ``checkpoint.pth`` for ``setting`` first.

		Returns:
			None.
		"""
		pred_data, pred_loader = self._get_data(flag='pred')

		if load:
			path = os.path.join(self.args.checkpoints, setting)
			best_model_path = path + '/' + 'checkpoint.pth'
			self.model.load_state_dict(torch.load(best_model_path))

		direct_model = 'Linear' in self.args.model or 'MLP' in self.args.model

		def forward(batch_x, batch_x_mark, dec_inp, batch_y_mark):
			"""Single forward pass shared by the AMP and full-precision paths."""
			if direct_model:
				return self.model(batch_x)
			result = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
			return result[0] if self.args.output_attention else result

		preds = []

		self.model.eval()
		with torch.no_grad():
			for batch_x, batch_y, batch_x_mark, batch_y_mark in pred_loader:
				batch_x = batch_x.float().to(self.device)
				# batch_y stays on its original device until the decoder input is assembled.
				batch_y = batch_y.float()
				batch_x_mark = batch_x_mark.float().to(self.device)
				batch_y_mark = batch_y_mark.float().to(self.device)

				# Decoder input: label_len known steps followed by zeros for the horizon.
				dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[2]]).float().to(
					batch_y.device)
				dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)

				if self.args.use_amp:
					with torch.cuda.amp.autocast():
						outputs = forward(batch_x, batch_x_mark, dec_inp, batch_y_mark)
				else:
					outputs = forward(batch_x, batch_x_mark, dec_inp, batch_y_mark)
				preds.append(outputs.detach().cpu().numpy())

		# Join along the sample axis, then flatten any extra leading axes so the
		# result is (samples, steps, channels).
		preds = np.concatenate(preds, axis=0)
		preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])

		# result save
		folder_path = './results/' + setting + '/'
		if not os.path.exists(folder_path):
			os.makedirs(folder_path)

		np.save(folder_path + 'real_prediction.npy', preds)

		return

	def infer_sample(self, batch_x, batch_y, batch_x_mark, batch_y_mark, return_feature=False, pred_diff=None):


		# encoder - decoder
		if 'Linear'  in self.args.model or 'MLP'  in self.args.model:
			outputs = self.model(batch_x)
			feat = batch_x
		else:
			# decoder input
			dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()
			dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
			if return_feature:
				outputs, feat = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark, return_feature=True)
			else:
				outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
				feat = None

		f_dim = -1 if self.args.features == 'MS' else 0
		# print(outputs.shape,batch_y.shape)
		outputs = outputs[:, -self.args.pred_len:, f_dim:]
		batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
		if return_feature:
			return outputs, batch_y, batch_x, feat
		else:
			return outputs, batch_y, batch_x

	def adaptable_predict(self, setting, adapt_layers, load_path=None, stop_iter=-1):
		"""Predict over the test split with replay-based online adaptation and save the results.

		The test split is walked twice with batch size 1:

		1. A gradient-free pass records the prior MSE of every sample in order
		   to flag "abnormal" samples: those above ``args.prior_thresh`` when it
		   is set, otherwise the 1% of samples with the largest error.
		2. The main pass. For each normal sample, once the replay buffer holds
		   at least 5 entries, a stored sample is drawn and the selected
		   parameters are fitted on it; the current sample is then predicted in
		   eval mode together with an uncertainty vector, and pushed into the
		   replay buffer.

		Results (metrics, predictions, uncertainties, ...) are written to
		``./results/adaptation/<setting>/`` and passed to ``plot_figures``.

		Fix over the previous version: the top-1% selection used
		``pes[-ab_num:]``, which returns the *whole* list when ``ab_num`` is 0
		(fewer than 100 test samples), flagging every sample as abnormal and
		disabling adaptation entirely. The slice now uses an explicit start
		index, and abnormal indices are kept in a set for O(1) membership tests.

		Args:
			setting: Run identifier used as the results folder name.
			adapt_layers: Names (as yielded by ``model.named_parameters()``) of the
				parameters that may be adapted.
			load_path: Optional checkpoint to load before adapting.
			stop_iter: When positive, the main pass stops once the sample index exceeds it.

		Returns:
			None. Also returns early, without saving anything, if the prior error
			of a sample becomes NaN or infinite.

		Side effects:
			Sets ``self.args.batch_size`` to 1.
		"""
		self.args.batch_size = 1
		args = self.args
		test_data, test_loader = self._get_data(flag='test')

		if load_path is not None:
			print('loading model', load_path)
			self.model.load_state_dict(torch.load(load_path, map_location=self.device))

		replay_buffer = ReplayBuffer(max_size=self.args.buffer_size, pull_mode=self.args.replay_buffer,
									delay=self.args.update_step_replay, device=self.device)

		# Collect the parameters selected for adaptation.
		adapt_weights = []
		for name, p in self.model.named_parameters():
			if name in adapt_layers:
				adapt_weights.append(p)
				print(name, p.size())

		# ``optimizer`` belonged to the feedback-adaptation step, which is currently
		# disabled (see the note near the end of the main loop); only
		# ``optimizer_replay`` is stepped below.
		if self.args.adapt == 'nrls':
			optimizer = NRLS(adapt_weights, dim_out=self.args.update_step * self.args.c_out,
							p0=self.args.p0, eps=self.args.eps, lbd=self.args.lbd,)
			optimizer_replay = NRLS(adapt_weights, dim_out=self.args.update_step_replay * self.args.c_out,
									p0=self.args.p0_replay, eps=self.args.eps_replay, lbd=self.args.lbd_replay)
		elif self.args.adapt == 'sgd':
			optimizer = torch.optim.SGD(adapt_weights, lr=self.args.lr)
			optimizer_replay = torch.optim.SGD(adapt_weights, lr=self.args.lr_replay)
		elif self.args.adapt == 'adam':
			optimizer = torch.optim.Adam(adapt_weights, lr=self.args.lr)
			optimizer_replay = torch.optim.Adam(adapt_weights, lr=self.args.lr_replay)
		else:
			optimizer = None
			optimizer_replay = None

		preds = []
		trues = []
		inputx = []
		prior_errs_list = []
		traj_err_list, traj_preds, traj_inputs, traj_trues, traj_uncertains = [], [], [], [], []
		uncertainty_list, real_errs = [], []
		sample_sims = []
		indices_sims = []
		last_x = None
		K_list = []
		folder_path = './results/adaptation/' + setting + '/'
		if not os.path.exists(folder_path):
			os.makedirs(folder_path)

		uncertain_indices = []
		abnormal_points = []
		pes = []

		# ---- Pass 1: prior error of every sample, used to flag abnormal points ----
		self.model.eval()
		with torch.no_grad():
			for i, (data_x, data_y, data_x_mark, data_y_mark) in enumerate(test_loader):
				data_x = data_x.float().to(self.device)
				data_y = data_y.float().to(self.device)
				data_x_mark = data_x_mark.float().to(self.device)
				data_y_mark = data_y_mark.float().to(self.device)

				data_x, data_y = test_data.normalize(data_x, data_y)

				outputs, batch_y, _ = self.infer_sample(data_x, data_y, data_x_mark, data_y_mark)
				prior_e = ((outputs - batch_y) ** 2).mean().data.item()
				pes.append((i, prior_e))
				if args.prior_thresh is not None:
					if prior_e > self.args.prior_thresh:
						abnormal_points.append(i)
		if args.prior_thresh is None:
			# No threshold given: take the 1% of samples with the largest error.
			pes = sorted(pes, key=lambda x: x[1])
			ab_num = int(0.01 * len(pes))
			# Explicit start index: ``pes[-0:]`` would select every sample.
			abnormal_points = [pn[0] for pn in pes[len(pes) - ab_num:]]
		abnormal_points = set(abnormal_points)

		print('Abnormal points:', len(abnormal_points))

		# Initial uncertainty gain, shape (1, pred_len, c_out): grows linearly with the step index.
		Ks = []
		for iii in range(args.pred_len):
			Ks.append([args.K0 * args.K0_step * (iii + 1)] * args.c_out)
		K = np.array(Ks)[np.newaxis, :]
		delta = self.args.delta
		i_iter = 0
		rewards = 0
		non_sims = 0

		# ---- Pass 2: adapt on replayed samples and predict ----
		test_data, test_loader = self._get_data(flag='test')
		self.model.train()
		for i, (data_x, data_y, data_x_mark, data_y_mark) in enumerate(test_loader):
			if i > stop_iter and stop_iter > 0:
				break
			data_x = data_x.float().to(self.device)
			data_y = data_y.float().to(self.device)
			data_x_mark = data_x_mark.float().to(self.device)
			data_y_mark = data_y_mark.float().to(self.device)
			# ``i_iter`` counts only the samples not flagged as abnormal.
			if i not in abnormal_points:
				i_iter += 1

			data_x, data_y = test_data.normalize(data_x, data_y)

			# feedforward compensation
			data_rep = self.args.replay_buffer != 'none' and len(replay_buffer.data) >= 5 and i not in abnormal_points
			if data_rep:
				with torch.no_grad():
					outputs, batch_y, batch_x, batch_feat = self.infer_sample(data_x, data_y, data_x_mark, data_y_mark,
																			return_feature=True)
				if self.args.replay_feature == 'traj':
					current_feature = batch_x
				else:
					current_feature = batch_feat
				features_rep, labels_rep, inp_marks_rep, out_marks_rep, inp_trajs_rep, rewards = replay_buffer.get_sample(
					size=1, current_feature=current_feature)
				cur_sample_sims = replay_buffer.cur_sample_sims

				# Without a fitting threshold a single update is made; otherwise keep
				# fitting (up to max_fit_iter) until the replay error drops below it.
				fit_iters = 1 if args.fitting_thresh is None else args.max_fit_iter
				fit_err = 0
				for i_fit in range(fit_iters):
					outputs_rep, y_rep, x_rep, feat_rep = self.infer_sample(inp_trajs_rep, labels_rep, inp_marks_rep,
																			out_marks_rep,
																			return_feature=True)

					if cur_sample_sims[0] > self.args.sim_update_thresh or self.args.replay_buffer != 'feature_sim':
						# Only the first update_step_replay steps of the replayed sample drive the update.
						y_adapt_rep = y_rep[0, :self.args.update_step_replay].contiguous().view((-1, 1))
						y_hat_adapt_rep = outputs_rep[0, :self.args.update_step_replay].contiguous().view((-1, 1))
						err_adapt_rep = (y_adapt_rep - y_hat_adapt_rep).detach()
						fit_err = (y_adapt_rep - y_hat_adapt_rep).pow(2).mean().item()

						def nrls_closure_replay(index=0):
							# Back-propagates one output element at a time; the graph is
							# retained until the last element has been processed.
							optimizer_replay.zero_grad()
							dim_out = optimizer_replay.state['dim_out']
							retain = index < dim_out - 1
							y_hat_adapt_rep[index].backward(retain_graph=retain)
							return err_adapt_rep

						if self.args.adapt == 'nrls':
							optimizer_replay.step(nrls_closure_replay)
						elif self.args.adapt == 'none':
							fit_err = 0
						else:
							optimizer_replay.zero_grad()
							loss_rep = (y_adapt_rep - y_hat_adapt_rep).pow(2).mean()
							loss_rep.backward()
							optimizer_replay.step()

					else:
						non_sims += 1
					if args.fitting_thresh is None or fit_err < args.fitting_thresh:
						break

			# uncertainty estimation and prediction
			self.model.eval()
			with torch.no_grad():
				outputs, batch_y, batch_x, batch_feat = self.infer_sample(data_x, data_y, data_x_mark, data_y_mark,
																		return_feature=True)
				prior_e = ((outputs - batch_y) ** 2).mean().data
				if torch.isnan(prior_e) or torch.isinf(prior_e):
					# The model diverged: abort without saving any result.
					print('Non-finite prior error at step {}; aborting adaptation.'.format(i))
					return
				if data_rep:
					outputs_rep, y_rep, x_rep, feat_rep = self.infer_sample(inp_trajs_rep, labels_rep,
																			inp_marks_rep, out_marks_rep,
																			return_feature=True)

					if self.args.replay_feature == 'traj':
						rep_feature = x_rep
					else:
						rep_feature = feat_rep
					# Uncertainty = gain * feature distance + delta * time shift
					#               + replay fit error + disagreement with the replay prediction.
					z_err = torch.norm(current_feature[0] - rep_feature[0]).detach().cpu()
					post_err_rep = torch.abs(outputs_rep - y_rep).detach().cpu()
					pred_diff = torch.abs(outputs_rep - outputs).detach().cpu()
					time_shift = abs(i_iter - rewards[0])
					sigma_vec = torch.tensor(K) * z_err + delta * time_shift + post_err_rep + pred_diff
					uncertain_indices.append(i)
					z_t = current_feature.view((len(current_feature), -1))
					z_s = rep_feature.view((len(rep_feature), -1))
					# Root-mean-square distance between current and replayed features.
					sample_err = torch.sqrt(((z_t - z_s) ** 2).mean(axis=1)).detach().cpu().numpy()
					sample_sims.append(sample_err)
					indices_sims.append(rewards[0])
				else:
					sigma_vec = torch.zeros_like(outputs)
					indices_sims.append(-1)
					if last_x is not None:
						z_t = batch_x.view((len(batch_x), -1))
						z_s = last_x.view((len(last_x), -1)).to(z_t.device)
						sample_err = torch.sqrt(((z_t - z_s) ** 2).mean(axis=1)).detach().cpu().numpy()
						sample_sims.append(sample_err)
					else:
						sample_sims.append(np.zeros(len(batch_x)))

			preds.append(outputs.detach().cpu().numpy())
			trues.append(batch_y.detach().cpu().numpy())
			inputx.append(batch_x.detach().cpu().numpy())

			prior_errs_list.append(prior_e.item())
			uncertainty_list.append(self.args.uncertainty_ratio * sigma_vec.detach().cpu().numpy())
			real_errs.append(torch.abs(outputs - batch_y).detach().cpu().numpy())
			K_list.append(K)
			if data_rep:
				# Re-scale the gain depending on how the estimate compared with the real error.
				K = update_k(K, real_errs[-1], uncertainty_list[-1])

			if len(inputx) >= self.args.update_step_replay:
				last_x = torch.tensor(inputx[-self.args.update_step_replay])

			# save buffer
			if self.args.replay_buffer != 'none' and self.args.adapt != 'none':
				if self.args.replay_feature == 'traj':
					feat = batch_x
				else:
					feat = batch_feat
				replay_buffer.push(feat, data_y, data_x_mark, data_y_mark, data_x, reward=[i_iter] * len(feat),
									task=0)

			traj_x, traj_y, traj_outputs, traj_u = test_data.denormalize(batch_x, batch_y, outputs,
																		u=self.args.uncertainty_ratio * sigma_vec)
			traj_preds.append(traj_outputs.detach().cpu().numpy())
			traj_trues.append(traj_y.detach().cpu().numpy())
			traj_inputs.append(traj_x.detach().cpu().numpy())
			traj_err_list.append(((traj_outputs - traj_y) ** 2).mean().item())
			traj_uncertains.append(traj_u.detach().cpu().numpy())

			# NOTE: a feedback-adaptation step (fitting the first ``update_step``
			# outputs of the current sample with ``optimizer``) and a subsequent
			# posterior-error evaluation used to live here as a disabled string
			# literal; both remain switched off.
			self.model.train()
			if i_iter % 10 == 0:
				print(f'step:{i}, prior_error:{prior_e}')

		print('non_sims:', non_sims)
		preds = np.concatenate(preds, axis=0)
		trues = np.concatenate(trues, axis=0)
		inputx = np.concatenate(inputx, axis=0)
		uncertainty_list = np.concatenate(uncertainty_list, axis=0)
		real_errs = np.concatenate(real_errs, axis=0)

		traj_preds = np.concatenate(traj_preds, axis=0)
		traj_trues = np.concatenate(traj_trues, axis=0)
		traj_inputs = np.concatenate(traj_inputs, axis=0)
		traj_uncertains = np.concatenate(traj_uncertains, axis=0)

		mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
		print('mse:{}, mae:{}'.format(mse, mae))
		traj_mae, traj_mse, traj_rmse, traj_mape, traj_mspe, traj_rse, traj_corr = metric(traj_preds, traj_trues)
		print('traj_mse:{}, traj_mae:{}'.format(traj_mse, traj_mae))

		with open(folder_path + "result.txt", 'a') as f:
			f.write(setting + "  \n")
			f.write('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
			f.write('\n')
			f.write('traj_mse:{}, traj_mae:{}, traj_rse:{}, traj_corr:{}'.format(traj_mse, traj_mae, traj_rse, traj_corr))
			f.write('\n')
			f.write('\n')

		print('time shift selcted sample', np.mean(indices_sims))
		save_dict = {'pred': preds, 'label': trues, 'input': inputx, 'sample_sims': sample_sims, 'indices_sims': indices_sims,
					'metrics': np.array([mae, mse, rmse, mape, mspe, rse]),
					'prior_errs': prior_errs_list, 'uncertainty': uncertainty_list,
					'real_errs': real_errs, 'K': K_list, 'uncertain_indices': uncertain_indices,
					'traj_pred': traj_preds, 'traj_true': traj_trues, 'traj_input': traj_inputs, 'traj_err_list': traj_err_list,
					'traj_uncertains': traj_uncertains}
		joblib.dump(save_dict, folder_path + 'adapt_results.pkl')

		plot_figures(save_dict, folder_path)
		return

	def adaptable_predict_simple_batch(self, setting, adapt_layers, load_path=None, stop_iter=-1):
		"""Predict over the test split with SGD adaptation on mini-batches drawn from the replay buffer.

		The test split is walked twice with batch size 1:

		1. A gradient-free pass records the prior MSE of every sample to flag
		   "abnormal" samples: those above ``args.prior_thresh`` when it is set,
		   otherwise the 1% of samples with the largest error.
		2. The main pass. For each normal sample, once the replay buffer holds
		   more than 5 entries, ``args.adapt_batch`` stored samples are drawn and
		   one SGD step is taken on them; the current sample is then predicted
		   in eval mode and pushed into the replay buffer.

		Results are written to ``./results/adaptation/<setting>/`` and passed to
		``plot_figures``. The saved uncertainty is always zero in this variant.

		Fixes over the previous version:
		  * the top-1% selection used ``pes[-ab_num:]``, which returns the whole
		    list when ``ab_num`` is 0 (fewer than 100 test samples) and flagged
		    every sample as abnormal;
		  * ``sigma_vec`` is now assigned on every iteration instead of silently
		    reusing the value left over from an earlier one;
		  * the result file is closed even if writing fails.

		NOTE(review): unlike ``adaptable_predict``, the data here is not passed
		through ``test_data.normalize`` -- confirm that this is intended.

		Args:
			setting: Run identifier used as the results folder name.
			adapt_layers: Names (as yielded by ``model.named_parameters()``) of the
				parameters that may be adapted.
			load_path: Optional checkpoint to load before adapting.
			stop_iter: When positive, the main pass stops once the sample index exceeds it.

		Returns:
			None. Also returns early, without saving anything, if the prior error
			of a sample becomes NaN or infinite.

		Raises:
			NotImplementedError: if ``args.adapt`` is anything other than ``'sgd'``.

		Side effects:
			Sets ``self.args.batch_size`` to 1.
		"""
		args = self.args
		args.batch_size = 1
		test_data, test_loader = data_provider(args, flag='test')

		if load_path is not None:
			print('loading model', load_path)
			self.model.load_state_dict(torch.load(load_path, map_location=self.device))

		replay_buffer = ReplayBuffer(max_size=self.args.buffer_size, pull_mode=self.args.replay_buffer,
									delay=self.args.update_step_replay, device=self.device)

		# Collect the parameters selected for adaptation.
		adapt_weights = []
		for name, p in self.model.named_parameters():
			if name in adapt_layers:
				adapt_weights.append(p)
				print(name, p.size())

		if self.args.adapt != 'sgd':
			raise NotImplementedError("adaptable_predict_simple_batch only supports adapt='sgd'")
		optimizer_replay = torch.optim.SGD(adapt_weights, lr=self.args.lr_replay)

		preds = []
		trues = []
		inputx = []
		prior_errs_list, post_errs_list = [], []
		uncertainty_list, real_errs = [], []
		sample_sims = []
		# Never filled in this variant; kept so the saved dictionary has the same keys.
		K_list = []
		uncertain_indices = []
		folder_path = './results/adaptation/' + setting + '/'
		if not os.path.exists(folder_path):
			os.makedirs(folder_path)

		abnormal_points = []
		pes = []

		# ---- Pass 1: prior error of every sample, used to flag abnormal points ----
		self.model.eval()
		with torch.no_grad():
			for i, (data_x, data_y, data_x_mark, data_y_mark) in enumerate(test_loader):
				data_x = data_x.float().to(self.device)
				data_y = data_y.float().to(self.device)
				data_x_mark = data_x_mark.float().to(self.device)
				data_y_mark = data_y_mark.float().to(self.device)
				outputs, batch_y, _ = self.infer_sample(data_x, data_y, data_x_mark, data_y_mark)
				prior_e = ((outputs - batch_y) ** 2).mean().data.item()
				pes.append((i, prior_e))
				if args.prior_thresh is not None:
					if prior_e > self.args.prior_thresh:
						abnormal_points.append(i)
		if args.prior_thresh is None:
			# No threshold given: take the 1% of samples with the largest error.
			pes = sorted(pes, key=lambda x: x[1])
			ab_num = int(0.01 * len(pes))
			# Explicit start index: ``pes[-0:]`` would select every sample.
			abnormal_points = [pn[0] for pn in pes[len(pes) - ab_num:]]
		abnormal_points = set(abnormal_points)

		print('Abnormal points:', len(abnormal_points))
		i_iter = 0
		non_sims = 0

		# ---- Pass 2: adapt on replayed mini-batches and predict ----
		test_data, test_loader = data_provider(args, flag='test')
		self.model.train()
		for i, (data_x, data_y, data_x_mark, data_y_mark) in enumerate(test_loader):
			if i > stop_iter and stop_iter > 0:
				break

			data_x = data_x.float().to(self.device)
			data_y = data_y.float().to(self.device)
			data_x_mark = data_x_mark.float().to(self.device)
			data_y_mark = data_y_mark.float().to(self.device)
			# ``i_iter`` counts only the samples not flagged as abnormal.
			if i not in abnormal_points:
				i_iter += 1

			# feedforward compensation
			data_rep = self.args.replay_buffer != 'none' and len(replay_buffer.data) > 5 and i not in abnormal_points
			if data_rep:
				with torch.no_grad():
					outputs, batch_y, batch_x, batch_feat = self.infer_sample(data_x, data_y, data_x_mark, data_y_mark,
																			return_feature=True)
				if self.args.replay_feature == 'traj':
					current_feature = batch_x
				else:
					current_feature = batch_feat
				features_rep, labels_rep, inp_marks_rep, out_marks_rep, inp_trajs_rep, rewards = replay_buffer.get_sample(
					size=args.adapt_batch, current_feature=current_feature)
				cur_sample_sims = replay_buffer.cur_sample_sims

				outputs_rep, y_rep, x_rep, feat_rep = self.infer_sample(inp_trajs_rep, labels_rep, inp_marks_rep,
																		out_marks_rep,
																		return_feature=True)

				if cur_sample_sims[0] > self.args.sim_update_thresh or self.args.replay_buffer != 'feature_sim':
					# One SGD step on the first update_step_replay steps of the replayed batch.
					y_adapt_rep = y_rep[:, :self.args.update_step_replay]
					y_hat_adapt_rep = outputs_rep[:, :self.args.update_step_replay]
					optimizer_replay.zero_grad()
					loss_rep = (y_adapt_rep - y_hat_adapt_rep).pow(2).mean()
					loss_rep.backward()
					optimizer_replay.step()
				else:
					non_sims += 1

			# uncertainty estimation and prediction
			self.model.eval()
			with torch.no_grad():
				outputs, batch_y, batch_x, batch_feat = self.infer_sample(data_x, data_y, data_x_mark, data_y_mark,
																		return_feature=True)
				prior_e = ((outputs - batch_y) ** 2).mean().data
				if torch.isnan(prior_e) or torch.isinf(prior_e):
					# The model diverged: abort without saving any result.
					print('Non-finite prior error at step {}; aborting adaptation.'.format(i))
					return
				# This variant does not estimate uncertainty: it is always zero.
				sigma_vec = torch.zeros_like(outputs)
				if data_rep:
					outputs_rep, y_rep, x_rep, feat_rep = self.infer_sample(inp_trajs_rep, labels_rep,
																			inp_marks_rep, out_marks_rep,
																			return_feature=True)

					if self.args.replay_feature == 'traj':
						rep_feature = x_rep
					else:
						rep_feature = feat_rep
					# Distance between the current feature and each replayed feature.
					z_t = current_feature.view((len(current_feature), -1))
					z_s = rep_feature.view((len(rep_feature), -1))
					z_err = torch.norm(z_t - z_s, dim=1).detach().cpu().numpy()
					sample_sims.append(z_err)
				else:
					sample_sims.append(np.zeros((args.adapt_batch)))
			preds.append(outputs.detach().cpu().numpy())
			trues.append(batch_y.detach().cpu().numpy())
			inputx.append(batch_x.detach().cpu().numpy())
			prior_errs_list.append(prior_e.item())
			real_errs.append(torch.abs(outputs - batch_y).detach().cpu().numpy())
			uncertainty_list.append(self.args.uncertainty_ratio * sigma_vec.detach().cpu().numpy())

			# save buffer
			if self.args.replay_buffer != 'none' and self.args.adapt != 'none':
				if self.args.replay_feature == 'traj':
					feat = batch_x
				else:
					feat = batch_feat
				replay_buffer.push(feat, data_y, data_x_mark, data_y_mark, data_x, reward=[i_iter] * len(feat),
									task=0)

			# eval posterior (the model is still in eval mode at this point)
			with torch.no_grad():
				outputs, batch_y, batch_x = self.infer_sample(data_x, data_y, data_x_mark, data_y_mark,
															return_feature=False)
				post_e = ((outputs - batch_y) ** 2).mean().data
				post_errs_list.append(post_e.item())
			self.model.train()
			if i_iter % 10 == 0:
				print(f'step:{i}, prior_error:{prior_e}, posterior_error:{post_e}')

		preds = np.concatenate(preds, axis=0)
		trues = np.concatenate(trues, axis=0)
		inputx = np.concatenate(inputx, axis=0)
		print('non_sims:', non_sims)
		mae, mse, rmse, mape, mspe, rse, corr = metric(preds, trues)
		print('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
		with open(folder_path + "result.txt", 'a') as f:
			f.write(setting + "  \n")
			f.write('mse:{}, mae:{}, rse:{}, corr:{}'.format(mse, mae, rse, corr))
			f.write('\n')
			f.write('\n')

		save_dict = {'pred': preds, 'label': trues, 'input': inputx, 'sample_sims': sample_sims,
					'metrics': np.array([mae, mse, rmse, mape, mspe, rse, corr]),
					'prior_errs': prior_errs_list, 'post_errs': post_errs_list,
					'uncertainty': np.concatenate(uncertainty_list, axis=0),
					'real_errs': np.concatenate(real_errs, axis=0),
					'K': np.array(K_list), 'uncertain_indices': uncertain_indices}
		joblib.dump(save_dict, folder_path + 'adapt_results.pkl')

		plot_figures(save_dict, folder_path)
		return


def update_k(ks, real_errs, pred_errs):
	"""Rescale each entry of ``ks`` by comparing real and predicted errors.

	All three inputs are flattened and compared element by element:

	* ``real > pred * 2``  -> ``k * 1.5``
	* ``real < pred / 2``  -> ``k / 1.5``
	* otherwise            -> ``k`` is kept as is

	The first rule takes precedence when both hold (possible for negative
	``pred``). Comparisons involving NaN are false, so such entries are kept.

	Args:
		ks: array-like of current values, one per error entry.
		real_errs: array-like of observed errors; its shape defines the shape
			of the result.
		pred_errs: array-like of predicted errors.

	Returns:
		``numpy.ndarray`` of updated values with the shape of ``real_errs``.
		Because of the float factors the result dtype is a floating type.
	"""
	real_errs = np.asarray(real_errs)
	out_shape = real_errs.shape

	k_flat = np.asarray(ks).ravel()
	real_flat = real_errs.ravel()
	pred_flat = np.asarray(pred_errs).ravel()

	grow = real_flat > pred_flat * 2
	shrink = real_flat < pred_flat / 2

	# Nested where keeps the if/elif precedence: `grow` is checked before `shrink`.
	new_k = np.where(grow, k_flat * 1.5, np.where(shrink, k_flat / 1.5, k_flat))
	return new_k.reshape(out_shape)


def _save_line_plot(series, ylabel, out_path, xlabel='time step'):
	"""Draw labelled 1-D series on a new figure, save it, show it, then close it.

	Args:
		series: iterable of ``(values, label, plot_kwargs)`` tuples; each
			``values`` is plotted against ``0 .. len(values) - 1``.
		ylabel: label of the y axis.
		out_path: file path passed to ``plt.savefig``.
		xlabel: label of the x axis.
	"""
	fig = plt.figure()
	try:
		plt.xlabel(xlabel)
		plt.ylabel(ylabel)
		for values, label, plot_kwargs in series:
			values = np.asarray(values)
			plt.plot(np.arange(len(values)), values, label=label, **plot_kwargs)
		plt.legend()
		plt.savefig(out_path)
		plt.show()
	finally:
		# Close this specific figure so repeated calls do not accumulate open figures.
		plt.close(fig)


def _plot_truth_vs_pred(truth, pred, step, ind, ylabel, out_path):
	"""Plot ``truth[:, step, ind]`` against ``pred[:, step, ind]`` and save the figure.

	``step`` is clamped to the last index of axis 1 of ``truth`` so that a
	short horizon does not raise ``IndexError``.
	"""
	truth = np.asarray(truth)
	pred = np.asarray(pred)
	step = min(step, truth.shape[1] - 1)
	_save_line_plot(
		[
			(truth[:, step, ind], 'ground truth', {}),
			(pred[:, step, ind], 'pred', {}),
		],
		ylabel,
		out_path,
	)


def plot_figures(result, folder_path, plot_step=10, traj_plot_step=1, ind=0):
	"""Save diagnostic plots for an adaptation run under ``<folder_path>/figs/``.

	Figures written:

	* ``error.png``     - ``result['prior_errs']`` over time.
	* ``pred.png``      - ``result['label']`` vs ``result['pred']`` at
	  ``[:, plot_step, ind]``.
	* ``pred_traj.png`` - ``result['traj_true']`` vs ``result['traj_pred']`` at
	  ``[:, traj_plot_step, ind]``; only drawn when both keys are present in
	  ``result`` (previously their absence raised ``KeyError``).

	Other entries of ``result`` are not used here.

	Args:
		result: mapping with the keys listed above.
		folder_path: directory under which the ``figs/`` folder is created.
		plot_step: index along axis 1 used for ``pred.png`` (clamped to the
			last available index).
		traj_plot_step: index along axis 1 used for ``pred_traj.png``
			(clamped likewise).
		ind: index along axis 2 to plot.
	"""
	print('Train done! Prepare results.')

	save_dir = os.path.join(folder_path, 'figs/')
	os.makedirs(save_dir, exist_ok=True)

	prior_errs = result['prior_errs']
	_save_line_plot(
		[(prior_errs, 'prior', {'color': 'b'})],
		'MSE error',
		save_dir + 'error.png',
	)
	print(f' Prior Error:{np.mean(prior_errs)}')

	_plot_truth_vs_pred(
		result['label'], result['pred'], plot_step, ind,
		'target', save_dir + 'pred' + '.png',
	)

	# Trajectory-level outputs are optional: not every caller stores them.
	if 'traj_true' in result and 'traj_pred' in result:
		_plot_truth_vs_pred(
			result['traj_true'], result['traj_pred'], traj_plot_step, ind,
			'trajectory', save_dir + 'pred_traj' + '.png',
		)

	print('results save on',save_dir)
