import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from src.dataloaders.utils import load_sparse_data, load_data, to_sparse_bin_data, BaseDataset, expand_data
from scipy import sparse
import pickle as pk


# Dataset directory, resolved relative to the directory containing this module.
base_dir = os.path.join(os.path.dirname(__file__), '../../data/shopping/changed/')
# Path of the dataset's YAML config inside base_dir.
config_path = os.path.join(base_dir, 'data_config.yaml')


def _prepare_splits(root):
    """Split the merged dataset into train/valid/test files and save column stats.

    Reads ``data_changed_csr_merged.npz`` from *root*, shuffles its rows and
    writes a 60/20/20 split to ``train.npz``, ``valid.npz`` and ``test.npz``
    in *root*. The per-column mean and population variance of the full matrix
    are pickled to ``config.pkl`` as ``{'mean': ..., 'variance': ...}``.

    Args:
        root: Directory that holds the merged input and receives the outputs.
    """
    data = sparse.load_npz(os.path.join(root, 'data_changed_csr_merged.npz'))
    total_nums = data.shape[0]

    # NOTE: the shuffle is unseeded, so regenerating the files gives a new split.
    idx = np.arange(total_nums)
    np.random.shuffle(idx)
    train_end, valid_end = int(0.6 * total_nums), int(0.8 * total_nums)
    splits = {
        'train': idx[:train_end],
        'valid': idx[train_end:valid_end],
        'test': idx[valid_end:],
    }
    for name, rows in splits.items():
        sparse.save_npz(os.path.join(root, f'{name}.npz'), data[rows])

    # NOTE(review): statistics are computed over every row, valid/test
    # included, which leaks held-out data into normalisation — confirm intended.
    # Work on a float64 copy: squaring integer data in place could overflow,
    # and it must not modify `data`.
    values = data.astype(np.float64)
    data_mean = np.array(values.mean(0))
    # Var[x] = E[x^2] - E[x]^2 keeps the matrix sparse; subtracting the mean
    # first would densify it.
    values.data **= 2
    variance = np.array(values.sum(0) / total_nums) - data_mean ** 2
    # Rounding can push (near-)constant columns slightly below zero; clamp so
    # np.sqrt(variance) downstream never yields NaN.
    variance = np.maximum(variance, 0)
    with open(os.path.join(root, 'config.pkl'), 'wb') as f:
        pk.dump({'mean': data_mean, 'variance': variance}, f)


# Regenerate when any output is missing, not only train.npz, so an interrupted
# earlier run (e.g. splits written but config.pkl not) is repaired.
if not all(os.path.exists(os.path.join(base_dir, name))
           for name in ('train.npz', 'valid.npz', 'test.npz', 'config.pkl')):
    _prepare_splits(base_dir)


class ShoppingGender(Dataset):
    """Torch ``Dataset`` over one split of the shopping data.

    Each row of the sparse matrix ``<base_dir>/<dtype>.npz`` is one sample.
    All columns except the last two are features. They are z-score normalised
    with the ``mean``/``variance`` pickled in ``<base_dir>/config.pkl``, and
    columns whose std is zero (or undefined) are emitted as 0. The
    second-to-last column is dropped. The last column is the label, shifted
    down by one and returned as ``int64`` (presumably 1-based class ids turned
    0-based — confirm).

    Args:
        dtype: Split name, used as the file stem (e.g. ``'train'``).
        max_sample_size: Optional upper bound on ``len(self)`` and on the rows
            returned by :attr:`data` and :attr:`label`. ``__getitem__`` does
            not enforce it.

    Raises:
        ValueError: If ``dtype`` is not a ``str``.
    """

    def __init__(self, dtype='train', max_sample_size=None):
        if not isinstance(dtype, str):
            raise ValueError(f'dtype must be a str, got {type(dtype).__name__}')
        self.dtype = dtype
        self.max_sample_size = max_sample_size

        # config.pkl is written by this module's own split step. Never point
        # base_dir at untrusted files: unpickling can execute arbitrary code.
        with open(os.path.join(base_dir, 'config.pkl'), 'rb') as f:
            cfg = pk.load(f)
        self._init_stats(cfg['mean'], cfg['variance'])

        self._data = sparse.load_npz(os.path.join(base_dir, f'{dtype}.npz'))

    def _init_stats(self, mean, variance):
        """Store normalisation stats and derive ``std`` and ``zeros_std_idx``."""
        self.mean = mean
        self.variance = variance
        # E[x^2] - E[x]^2 can dip slightly below zero from rounding. Clamp it so
        # sqrt never yields NaN, which the zero-std check below would miss.
        self.std = np.sqrt(np.maximum(variance, 0))
        # ``~(std > 0)`` flags both zero and NaN std as degenerate columns.
        self.zeros_std_idx = np.where(~(self.std > 0))[1]

    def _normalize(self, rows):
        """Return the z-scored feature columns of dense 2-D *rows* as float32."""
        n_features = rows.shape[1] - 2
        mean = self.mean[:, :n_features]
        std = self.std[:, :n_features]
        # Divide by 1 where std is degenerate: this avoids divide-by-zero
        # warnings, and those columns are zeroed right after.
        safe_std = np.where(std > 0, std, 1.0)
        features = ((rows[:, :n_features] - mean) / safe_std).astype('float32')
        # zeros_std_idx spans every column; keep only feature columns so a
        # constant label/extra column cannot index out of bounds.
        zero_idx = self.zeros_std_idx[self.zeros_std_idx < n_features]
        features[:, zero_idx] = 0
        return features

    def __getitem__(self, item):
        row = self._data[item].toarray()
        features = self._normalize(row)[0]
        label = (row[0, -1] - 1).astype('int64')
        return features, label

    def __len__(self):
        if self.max_sample_size is not None:
            return min(self._data.shape[0], self.max_sample_size)
        return self._data.shape[0]

    @property
    def data(self):
        """Normalised float32 feature matrix of the first ``len(self)`` rows."""
        return self._normalize(self._data[:len(self)].toarray())

    @property
    def label(self):
        """Zero-based int64 labels (last column minus one) of the first ``len(self)`` rows."""
        res = self._data[:len(self), -1].toarray().reshape(-1)
        return (res - 1).astype('int64')
