import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from src.dataloaders.utils import load_sparse_data, load_data, to_sparse_bin_data, BaseDataset, expand_data


# Directory holding the raw QSAR CSV and the generated .npy files, resolved
# relative to the directory containing this source file.
base_dir = os.path.join(os.path.dirname(__file__), '../../data/qsar/')


# File-name stems written by _build_cache; each stem gets a `<stem>_data.npy`
# and a `<stem>_label.npy` file.
_SPLIT_FILE_STEMS = ('support_target', 'train', 'valid1', 'valid2', 'test')
# Cumulative split boundaries as fractions of the shuffled dataset:
# train | valid1 | valid2 | test = 70% | 10% | 5% | 15%.
# (An earlier experiment used (0.5, 0.65, 0.8).)
_SPLIT_FRACTIONS = (0.7, 0.8, 0.85)


def _cache_is_complete(root):
    """Return True when every `.npy` file written by `_build_cache` exists in *root*.

    All files are checked, not just the support files, so a run that was
    interrupted half-way through writing is redone from scratch instead of
    leaving later splits permanently missing (or mixing splits produced by
    different shuffles).
    """
    return all(os.path.exists(os.path.join(root, f'{stem}_{kind}.npy'))
               for stem in _SPLIT_FILE_STEMS
               for kind in ('data', 'label'))


def _save_split(root, stem, split_data, split_label):
    """Save one split's features and labels under *root* as int8 `.npy` files."""
    np.save(os.path.join(root, f'{stem}_data.npy'), split_data.astype('int8'))
    np.save(os.path.join(root, f'{stem}_label.npy'), split_label.astype('int8'))


def _build_cache(root):
    """Read the raw QSAR CSV from *root*, shuffle and split it, and save every split.

    Args:
        root: directory containing `qsar_oral_toxicity.csv`; output `.npy`
            files are written to the same directory.
    """
    raw = pd.read_csv(os.path.join(root, 'qsar_oral_toxicity.csv'), sep=';', header=None)
    # Columns 0..1023 are the features; column 1024 holds the class name.
    # Any value other than 'negative' is mapped to label 1.
    label = (raw[1024] != 'negative').to_numpy().astype(np.int64)
    data = np.array(raw[np.arange(1024)])

    # Shuffle with the global NumPy RNG, so seeding np.random before import
    # still controls the split.
    idx = np.arange(len(data))
    np.random.shuffle(idx)
    data = data[idx]
    label = label[idx]

    # np.split at [c0, c1, c2] yields [:c0], [c0:c1], [c1:c2], [c2:].
    cuts = [int(frac * len(data)) for frac in _SPLIT_FRACTIONS]
    train_data, eval_data1, eval_data2, test_data = np.split(data, cuts)
    train_label, eval_label1, eval_label2, test_label = np.split(label, cuts)

    print(f'resplit data, the pos in data is: ({np.sum(train_label == 1)}, {np.sum(eval_label1 == 1)}, {np.sum(eval_label2 == 1)}, {np.sum(test_label == 1)})')

    # The support set is everything except the test split.
    support_data, support_label = expand_data(np.concatenate([train_data, eval_data1, eval_data2]),
                                              np.concatenate([train_label, eval_label1, eval_label2]))
    _save_split(root, 'support_target', support_data, support_label)

    train_data, train_label = expand_data(train_data, train_label)
    _save_split(root, 'train', train_data, train_label)

    _save_split(root, 'valid1', eval_data1, eval_label1)
    _save_split(root, 'valid2', eval_data2, eval_label2)
    _save_split(root, 'test', test_data, test_label)


if not _cache_is_complete(base_dir):
    _build_cache(base_dir)


class QSAR(BaseDataset):
    """QSAR oral-toxicity dataset rooted at this module's ``base_dir``.

    ``dtype`` and ``max_sample_size`` are forwarded unchanged to
    ``BaseDataset``; presumably ``dtype`` selects which split is loaded —
    confirm against ``BaseDataset``.
    """

    def __init__(self, dtype='train', max_sample_size=None):
        data_root = base_dir
        super(QSAR, self).__init__(data_root, dtype, max_sample_size)
