import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from src.dataloaders.utils import load_sparse_data, load_data, to_sparse_bin_data, BaseDataset

import matplotlib.pyplot as plt
import pickle as pk

# Directory holding the raw Gisette files and the cached ``.npy`` splits.
base_dir = os.path.join(os.path.dirname(__file__), '../../data/gisette/')

# Cached splits, in the order they are cut from the shuffled samples, and the
# fractions of the sample count at which the first three splits end
# (60% train, 10% valid1, 10% valid2, remaining 20% test).
_SPLIT_NAMES = ('train', 'valid1', 'valid2', 'test')
_SPLIT_ENDS = (0.6, 0.7, 0.8)


def _split_paths(directory):
    """Return the path of every cached ``<split>_<data|label>.npy`` file."""
    return [os.path.join(directory, f'{split}_{kind}.npy')
            for split in _SPLIT_NAMES
            for kind in ('data', 'label')]


def _resplit_gisette(directory):
    """Rebuild the cached train/valid1/valid2/test splits in ``directory``.

    The raw ``gisette_train`` and ``gisette_valid`` files are concatenated,
    the features are divided by 1000, labels below zero are mapped to 0, the
    samples are shuffled and then cut at 60% / 70% / 80% of the sample count.
    Each part is written as ``<split>_data.npy`` and ``<split>_label.npy``.

    The shuffle uses NumPy's global random state without seeding it here, so
    the resulting split differs between regenerations unless the caller has
    seeded ``np.random`` beforehand.
    """
    raw_parts = ('train', 'valid')
    data = np.concatenate([
        load_data(os.path.join(directory, f'gisette_{part}.data'))
        for part in raw_parts
    ])
    label = np.concatenate([
        load_data(os.path.join(directory, f'gisette_{part}.labels'))
        for part in raw_parts
    ])
    data = data / 1000  # original: 0~999
    label[label < 0] = 0

    # Shuffle samples and labels with the same permutation so they stay aligned.
    idx = np.arange(len(data))
    np.random.shuffle(idx)
    data = data[idx]
    label = label[idx]

    # Cut both arrays at identical positions along the sample axis.
    cuts = [int(frac * len(data)) for frac in _SPLIT_ENDS]
    data_parts = np.split(data, cuts)
    label_parts = np.split(label, cuts)

    pos_counts = ', '.join(str(np.sum(part == 1)) for part in label_parts)
    print(f'resplit data, the pos in data is: ({pos_counts})')

    for split, split_data, split_label in zip(_SPLIT_NAMES, data_parts, label_parts):
        np.save(os.path.join(directory, f'{split}_data.npy'), split_data)
        np.save(os.path.join(directory, f'{split}_label.npy'), split_label)


# Regenerate the cache unless *every* split file is present. Checking only a
# subset of the files would let a partially written cache pass as complete and
# leave the missing splits unavailable.
if not all(os.path.exists(path) for path in _split_paths(base_dir)):
    _resplit_gisette(base_dir)



class Gisette(BaseDataset):
    """Gisette dataset read from the module-level ``base_dir`` directory.

    All loading logic lives in ``BaseDataset``; this class only binds the
    Gisette data directory to it.
    """

    def __init__(self, dtype: str = 'train', max_sample_size=None) -> None:
        """Forward the Gisette data directory and options to ``BaseDataset``.

        Args:
            dtype: Split selector passed through to ``BaseDataset``.
                Presumably one of the cached split prefixes written by this
                module ('train', 'valid1', 'valid2', 'test') -- TODO confirm
                against ``BaseDataset``.
            max_sample_size: Passed through unchanged to ``BaseDataset``;
                presumably caps the number of samples, with ``None`` meaning
                no cap -- TODO confirm against ``BaseDataset``.
        """
        super().__init__(base_dir, dtype, max_sample_size)


# class Gisette(Dataset):
#     mean = info['mean']
#     std = info['std']
#
#     def __init__(self, type='train', norm=True):
#         if type in ['valid', 'eval', 'test']:
#             type = 'valid'
#         self.type = type
#         self.data = load_data(os.path.join(base_dir, f'gisette_{self.type}.data'))
#         self.label = load_data(os.path.join(base_dir, f'gisette_{self.type}.labels'))
#
#         if norm:
#             self.data = (self.data - self.mean) / (self.std + 1e-8)
#             # self.data[np.isnan(self.data)] = 0
#
#         self.label[np.where(self.label < 0)] = 0
#
#         self.data = self.data.astype('float32')
#         self.label = self.label.astype('int64').reshape(-1)
#
#     def __getitem__(self, item):
#         return self.data[item], self.label[item]
#
#     def __len__(self):
#         return len(self.label)
