import numpy as np
from torch.utils.data import Dataset
from src.verify.trainer.utils import SubShell
import torch
import os


def expand_data(train_data, train_label):
    """Oversample the positive class so the two classes are roughly balanced.

    Every positive row (``train_label == 1``) is appended
    ``n_neg // n_pos - 1`` extra times, where ``n_neg`` / ``n_pos`` are the
    numbers of 0- and 1-labelled samples. The combined set is then shuffled
    using NumPy's global random state.

    Args:
        train_data: array whose first axis is aligned with ``train_label``.
        train_label: 1-D array of binary labels (0 = negative, 1 = positive).

    Returns:
        ``(train_data, train_label)`` after expansion and shuffling. Labels
        are concatenated with ``np.ones`` (float64), so they come back as a
        floating-point array, as in the original implementation.
    """
    train_data = np.asarray(train_data)
    train_label = np.asarray(train_label)

    # Count the classes from the index arrays themselves. The length of
    # np.where's *tuple* would always be 1 for 1-D labels.
    pos_idx = np.flatnonzero(train_label == 1)
    n_pos = len(pos_idx)
    n_neg = int(np.count_nonzero(train_label == 0))

    # Number of *extra* copies of each positive row. With no positives there
    # is nothing to replicate (and n_neg // 0 would raise). If positives are
    # already the majority, the count would be negative and `repeat` would
    # raise. Both cases fall back to 0.
    expand_num = max(n_neg // n_pos - 1, 0) if n_pos else 0

    extra_data = train_data[pos_idx].repeat(expand_num, 0)
    train_data = np.concatenate([train_data, extra_data], 0)
    train_label = np.concatenate([train_label, np.ones(expand_num * n_pos)])

    # Reshuffle so the appended copies are not clustered at the end.
    idx = np.random.permutation(len(train_data))
    return train_data[idx], train_label[idx]


def load_sparse_data(path, dims=20000):
    """Read a sparse ``index:value`` text file into a dense float matrix.

    Each line of the file becomes one row. Every whitespace-separated token on
    the line has the form ``index:value``; ``index`` is treated as 1-based, so
    it is stored in column ``index - 1``. Blank lines become all-zero rows.

    NOTE(review): an index of 0 maps to column -1, i.e. it silently writes the
    last column. Confirm inputs are always 1-based.

    Args:
        path: path of the text file to read.
        dims: number of columns in the returned matrix.

    Returns:
        A ``(num_lines, dims)`` float64 array.
    """
    with open(path) as handle:
        lines = handle.readlines()

    matrix = np.zeros([len(lines), dims])
    for row, line in enumerate(lines):
        for token in line.strip().split():
            index, value = token.split(':')
            matrix[row, int(index) - 1] = float(value)
    return matrix


def to_sparse_bin_data(data, dims):
    """Convert index sequences into dense binary (multi-hot) vectors.

    Args:
        data: either a flat sequence of integer indices, or a sequence whose
            first element is a list / ndarray of such indices. Nested input is
            converted element by element and the results are stacked with
            ``np.array``.
        dims: length of each output vector.

    Returns:
        For flat input, a float array of length ``dims`` with 1 at every listed
        index and 0 elsewhere (all zeros if ``data`` is empty). For nested
        input, an array of such vectors.
    """
    # Check the length before peeking at data[0]. An empty index list (a
    # sample with no active features) must become a zero vector, not raise
    # IndexError.
    if len(data) > 0 and isinstance(data[0], (np.ndarray, list)):
        return np.array([to_sparse_bin_data(item, dims) for item in data])

    result = np.zeros([dims])
    result[np.array(data).astype('int')] = 1
    return result


def load_data(path):
    """Load a whitespace-separated numeric text file into a float array.

    Each non-blank line becomes one row of floats. Blank or whitespace-only
    lines are skipped: they would otherwise contribute an empty row, making
    the rows ragged so that ``np.array`` raises (NumPy >= 1.24) or builds an
    object array (older NumPy).

    Args:
        path: path of the text file to read.

    Returns:
        ``np.array`` of the parsed rows (2-D when all rows have equal length).
    """
    rows = []
    with open(path) as f:
        for line in f:
            values = [float(item) for item in line.split()]
            if values:
                rows.append(values)

    return np.array(rows)


class SelectedDataset(Dataset):
    """Dataset view exposing only the feature entries listed in ``select_ids``.

    Samples come from ``SubShell(*datasets)``. Each ``(x, y)`` pair is returned
    as ``(x[select_ids], y)``:

    * torch tensors are cast with ``.float()`` and the label is passed through;
    * anything else is cast with ``.astype('float32')`` and the label is
      converted with ``int()``.

    NOTE(review): with the default ``select_ids=None``, ``x[None]`` inserts a
    new leading axis rather than selecting all features. Confirm callers
    always pass ``select_ids``.
    """

    def __init__(self, *datasets, select_ids=None):
        self.select_ids = select_ids
        self.dataset = SubShell(*datasets)

    def __getitem__(self, item):
        features, target = self.dataset[item]
        if not isinstance(features, torch.Tensor):
            return features[self.select_ids].astype('float32'), int(target)
        return features[self.select_ids].float(), target

    def __len__(self):
        return self.dataset.__len__()


class MaskDataset(Dataset):
    """Dataset view that zeroes every feature not listed in ``select_ids``.

    The number of features is taken from the last axis of the first sample.
    Kept features are divided by ``decay = len(select_ids) / n_features``, so
    the masked sample keeps roughly the same overall scale (akin to
    inverted-dropout rescaling).

    Tensor samples are returned as ``(x.float(), y)``. Other samples are
    returned as ``(x.astype('float32'), int(y))``.
    """

    def __init__(self, *datasets, select_ids=None):
        self.dataset = SubShell(*datasets)
        self.select_ids = select_ids
        # An empty selection would make decay 0 and every sample NaN/inf.
        # (None still fails with TypeError from len(), as before.)
        if len(select_ids) == 0:
            raise ValueError('select_ids must contain at least one feature index')

        # Load the first sample only once to discover the feature count.
        n_features = self.dataset[0][0].shape[-1]
        self.not_selected_ids = sorted(set(range(n_features)) - set(select_ids))
        self.decay = len(select_ids) / n_features

    def __getitem__(self, item):
        x, y = self.dataset[item]
        is_tensor = isinstance(x, torch.Tensor)

        # Mask a copy: the wrapped dataset may return a view of its own
        # storage (e.g. a row of an ndarray), and zeroing in place would
        # corrupt the underlying data for every later access.
        x = x.clone() if is_tensor else x.copy()
        if self.not_selected_ids:
            # Mask the last axis, which is where n_features was measured.
            x[..., self.not_selected_ids] = 0
        x = x / self.decay

        if is_tensor:
            return x.float(), y
        return x.astype('float32'), int(y)

    def __len__(self):
        return self.dataset.__len__()


class BaseDataset(Dataset):
    def __init__(self, base_dir, dtype='train', max_sample_size=None):
        self.max_sample_size = max_sample_size

        if isinstance(dtype, list):
            data = []
            label = []
            for name in dtype:
                data.append(np.load(os.path.join(base_dir, f'{name}_data.npy')))
                label.append(np.load(os.path.join(base_dir, f'{name}_label.npy')))
            self.data = np.concatenate(data)
            self.label = np.concatenate(label)
        else:
            self.data = np.load(os.path.join(base_dir, f'{dtype}_data.npy'))
            self.label = np.load(os.path.join(base_dir, f'{dtype}_label.npy'))

        self.data = self.data.astype('float32')
        self.label = self.label.reshape(-1).astype('int64')

        if max_sample_size is not None:
            self.data = self.data[:max_sample_size]
            self.label = self.label[:max_sample_size]

    def __getitem__(self, item):
        return self.data[item], self.label[item]

    def __len__(self):
        return len(self.data)
