import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from dataloaders.utils import load_sparse_data, load_data
import pickle as pk


# Directory that holds the dataset's .npy files, resolved relative to the
# location of this module rather than the current working directory.
_module_dir = os.path.dirname(__file__)
base_dir = os.path.join(_module_dir, '../../data/forest/')

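# Disabled reference snippet: it appears to be the one-off script that produced
# the train/eval .npy files via a random 80/20 split of `dataset`.
# data = dataset.data
# label = dataset.labels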
# idx = np.arange(len(data))
# np.random.shuffle(idx)
# train_idx, eval_idx = idx[:int(0.8 * len(idx))], idx[int(0.8 * len(idx)):]
# train_data, eval_data = data[train_idx], data[eval_idx]
# train_label, eval_label = label[train_idx], label[eval_idx]
# np.save('train_data.npy', train_data)
# np.save('eval_data.npy', eval_data)
# np.save('train_label.npy', train_label)
# np.save('eval_label.npy', eval_label)
#
# class Forest(Dataset):
#     def __init__(self):
#         data = pd.read_csv(os.path.join(base_dir, 'train.csv'))
#         data = np.array(data)
#         label = data[:, -1]
#         data = data[:, 1: -1]
#
#         real_nums, bin_nums = data[:, :10], data[:, 10:]
#
#         mean = np.mean(real_nums, 0)
#         std = np.std(real_nums, 0)
#         real_nums = (real_nums - mean) / std
#
#         out_data = np.concatenate([real_nums, bin_nums], -1)
#
#         self.data = out_data
#         self.labels = label - 1
#
#     def __getitem__(self, item):
#         return self.data[item].astype('float32'), self.labels[item]
#
#     def __len__(self):
#         return len(self.labels)


class Forest(Dataset):
    """Map-style dataset backed by two ``.npy`` arrays stored in ``base_dir``.

    The arrays are read from ``<split>_data.npy`` and ``<split>_label.npy``
    and kept in memory as ``float32`` and ``int64`` respectively.
    """

    def __init__(self, dtype='train'):
        """Load the arrays for the requested split.

        Args:
            dtype: Name of the split. ``'valid'``, ``'eval'`` and ``'test'``
                all resolve to the files saved with the ``eval`` prefix; any
                other value is used verbatim as the file-name prefix.
        """
        # The three held-out aliases share a single set of files on disk.
        split = 'eval' if dtype in ('valid', 'eval', 'test') else dtype

        features = np.load(os.path.join(base_dir, f'{split}_data.npy'))
        targets = np.load(os.path.join(base_dir, f'{split}_label.npy'))

        # Cast once at load time so that indexing returns ready-to-use values.
        self.data = features.astype('float32')
        self.label = targets.astype('int64')

    def __getitem__(self, item):
        """Return the ``(data, label)`` pair stored at index ``item``."""
        sample = self.data[item]
        target = self.label[item]
        return sample, target

    def __len__(self):
        """Return the number of entries in ``self.data``."""
        return len(self.data)
