import os
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader



# Number of tensors stored per split file, in this order:
# inputs, labels, x_sols, y_sols, z_sols, x_sols_sum, y_sols_sum, z_sols_sum.
_NUM_PORTFOLIO_TENSORS = 8


def _load_portfolio_split(data_dir, prefix, suffix):
	"""Load ``portfolio_{prefix}{suffix}.pt`` from *data_dir* as a TensorDataset.

	Raises:
		ValueError: if the file does not hold exactly eight tensors.
	"""
	path = os.path.join(data_dir, "portfolio_{}{}.pt".format(prefix, suffix))
	# NOTE: torch.load unpickles the file, so only point this at trusted data.
	tensors = tuple(torch.load(path))
	if len(tensors) != _NUM_PORTFOLIO_TENSORS:
		raise ValueError("{}: expected {} tensors, got {}".format(
			path, _NUM_PORTFOLIO_TENSORS, len(tensors)))
	return TensorDataset(*tensors)


def load_dataset_portfolio_multi(data_dir, args, use_test_set=True, data_ver=''):
	"""Build train and evaluation DataLoaders from saved portfolio tensor files.

	Loads ``portfolio_train{data_ver}.pt`` plus either
	``portfolio_test{data_ver}.pt`` or ``portfolio_val{data_ver}.pt`` from
	*data_dir*. Each file must contain a sequence of exactly eight tensors,
	ordered: inputs, labels, x_sols, y_sols, z_sols, x_sols_sum, y_sols_sum,
	z_sols_sum. Every batch yielded by the loaders is a list of those eight
	tensors in the same order.

	Args:
		data_dir: Directory containing the ``.pt`` files.
		args: Object with a ``batch_size`` attribute (the only field read).
		use_test_set: If True, evaluate on the ``test`` split; otherwise on ``val``.
		data_ver: Suffix appended to each file stem, e.g. ``'_real10k2'``.

	Returns:
		``(loader_train, loader_val)``. The training loader reshuffles every
		epoch; the evaluation loader iterates in file order.

	Raises:
		ValueError: If a file does not contain exactly eight tensors.
	"""
	val_prefix = "test" if use_test_set else "val"

	dataset_train = _load_portfolio_split(data_dir, "train", data_ver)
	dataset_val = _load_portfolio_split(data_dir, val_prefix, data_ver)
	# Log the shapes of all eight training tensors as a quick sanity check.
	print('train_inputs.shape', *(t.shape for t in dataset_train.tensors))

	loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True)
	loader_val = DataLoader(dataset_val, batch_size=args.batch_size)
	return loader_train, loader_val




