import os
import pandas as pd
import numpy as np
from src.dataloaders.utils import BaseDataset, expand_data
from torch.utils.data import Dataset
# from dataloaders.utils import load_sparse_data, load_data, to_sparse_bin_data

import matplotlib.pyplot as plt
import pickle as pk


# Directory that holds the raw p53 file and every cached artefact derived from
# it, resolved relative to this module's location.
base_dir = os.path.join(os.path.split(__file__)[0], '../../data/p53/')


# One-time clean-up of the raw file: for every line, drop the line terminator
# and the trailing field separator, and blank out '?' placeholders so they
# become empty CSV fields.
if not os.path.exists(os.path.join(base_dir, 'K9_process.data')):
    processed_path = os.path.join(base_dir, 'K9_process.data')
    # Write to a sibling temp file and rename it at the very end. Otherwise an
    # interrupted run would leave a truncated 'K9_process.data' behind, and the
    # existence check above would treat it as complete on the next import.
    partial_path = processed_path + '.tmp'
    # `with` guarantees both handles are closed even if the loop raises.
    with open(os.path.join(base_dir, 'K9.data'), 'r') as f, open(partial_path, 'w') as wf:
        for i, line in enumerate(f):
            # Strip the newline (if any) and then a single trailing comma,
            # instead of blindly cutting two characters: a final line without
            # a newline would otherwise lose a real character.
            tmp = line.rstrip('\n')
            if tmp.endswith(','):
                tmp = tmp[:-1]
            wf.write(tmp.replace('?', '') + '\n')
            print(f'\r {i}/31420', end='')
    os.replace(partial_path, processed_path)


# Build and cache the train / valid1 / valid2 / test arrays. The first and the
# last file written below act as the "already done" marker.
if not (os.path.exists(os.path.join(base_dir, 'support_target_data.npy'))
        and os.path.exists(os.path.join(base_dir, 'test_label.npy'))):

    frame = pd.read_csv(os.path.join(base_dir, 'K9_process.data'), header=None, engine='c')

    # Column 5408 is the class string: 'inactive' -> 0, anything else -> 1.
    label = frame[[5408]].squeeze().apply(lambda x: int(x != 'inactive'))
    label = np.asarray(label).reshape(-1)
    # Columns 0..5407 are the features.
    data = np.array(frame[list(range(5408))])

    # Keep only the rows whose feature mean is not NaN (i.e. no missing value).
    complete_rows = ~np.isnan(data.mean(axis=1))
    data = data[complete_rows]
    label = label[complete_rows]

    # Shuffle features and labels with one shared random order.
    order = np.arange(len(data))
    np.random.shuffle(order)
    data, label = data[order], label[order]

    # 60% train / 10% valid1 / 10% valid2 / 20% test.
    n_rows = len(data)
    cut_train = int(0.6 * n_rows)
    cut_valid1 = int(0.7 * n_rows)
    cut_valid2 = int(0.8 * n_rows)

    train_data = data[:cut_train]
    eval_data1 = data[cut_train:cut_valid1]
    eval_data2 = data[cut_valid1:cut_valid2]
    test_data = data[cut_valid2:]

    train_label = label[:cut_train]
    eval_label1 = label[cut_train:cut_valid1]
    eval_label2 = label[cut_valid1:cut_valid2]
    test_label = label[cut_valid2:]

    # The support set is everything except the test split, passed through
    # expand_data before being stored.
    support_data, support_label = expand_data(
        np.concatenate([train_data, eval_data1, eval_data2]),
        np.concatenate([train_label, eval_label1, eval_label2]),
    )
    np.save(os.path.join(base_dir, 'support_target_data.npy'), support_data)
    np.save(os.path.join(base_dir, 'support_target_label.npy'), support_label)

    # Report the positive counts per split before the train split is expanded.
    print(f'resplit data, the pos in data is: ({np.sum(train_label == 1)}, {np.sum(eval_label1 == 1)}, {np.sum(eval_label2 == 1)}, {np.sum(test_label == 1)})')
    train_data, train_label = expand_data(train_data, train_label)

    # Persist the splits; 'test_label.npy' is deliberately written last.
    for file_name, array in (
        ('train_data.npy', train_data),
        ('train_label.npy', train_label),
        ('valid1_data.npy', eval_data1),
        ('valid1_label.npy', eval_label1),
        ('valid2_data.npy', eval_data2),
        ('valid2_label.npy', eval_label2),
        ('test_data.npy', test_data),
        ('test_label.npy', test_label),
    ):
        np.save(os.path.join(base_dir, file_name), array)


class P53(BaseDataset):
    """Dataset for the p53 data stored under the module-level ``base_dir``.

    The class adds no behaviour of its own: it only binds ``BaseDataset`` to
    the p53 data directory.
    """

    def __init__(self, dtype='train', max_sample_size=None):
        """Forward ``base_dir`` and the given options to ``BaseDataset``.

        Args:
            dtype: Passed through unchanged. Presumably selects which cached
                split to load (the module writes 'train', 'valid1', 'valid2',
                'test' and 'support_target' files) -- TODO confirm against
                ``BaseDataset``.
            max_sample_size: Passed through unchanged. Presumably an optional
                cap on the number of samples -- TODO confirm against
                ``BaseDataset``.
        """
        super().__init__(base_dir, dtype, max_sample_size)
