import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

## Code modified from:
## https://github.com/xehartnort/Purchase100-dataset
## For downloading purchase100.npz:
## https://github.com/xehartnort/Purchase100-dataset/releases/download/v2.0/purchase100.npz
class Purchase100(Dataset):
    """Purchase100 dataset loaded from a ``.npz`` archive.

    Each feature row ``x`` is scaled by ``sqrt(count(x > 0))`` and a subset /
    ordering of rows is selected with ``perm``.
    """

    def __init__(self, root_dir='./purchase100.npz', perm=None, num_samples=60000):
        """
        Args:
            root_dir: path to a ``.npz`` file containing ``features`` and
                ``labels`` arrays.
            perm: optional index array selecting (and ordering) rows. When
                ``None``, a random permutation of ``range(num_samples)`` is
                drawn, so only the first ``num_samples`` rows are used.
            num_samples: number of leading rows permuted when ``perm`` is None.
        """
        # The NpzFile holds an open file handle; the context manager closes it.
        with np.load(root_dir) as data:
            features = data['features']
            labels = data['labels']

        if perm is None:
            # Uses NumPy's global RNG: seed it beforehand (np.random.seed) for
            # a reproducible split.
            perm = np.random.permutation(num_samples)

        self.labels = torch.LongTensor(labels)[perm]

        # Vectorised equivalent of scaling each row x by sqrt(number of x > 0).
        scale = np.sqrt(np.sum(features > 0, axis=1, keepdims=True))
        self.features = torch.FloatTensor(features * scale)[perm]

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair(s) at ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.features[idx], self.labels[idx]

class Purchase_Net(nn.Module):
    """
    Simple fully connected network for Purchase-100.

    Architecture: Linear(in_features -> hidden_features) -> tanh ->
    Linear(hidden_features -> num_classes). Raw outputs of the last layer are
    returned; no softmax is applied.
    """
    def __init__(self, in_features=600, hidden_features=128, num_classes=100):
        """
        Args:
            in_features: size of each input vector (600 by default).
            hidden_features: width of the single hidden layer (128 by default).
            num_classes: number of outputs (100 by default).
        """
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return unnormalised class scores for ``x`` of shape (..., in_features)."""
        # torch.tanh computes the same element-wise tanh as F.tanh.
        x = torch.tanh(self.fc1(x))
        return self.fc2(x)

def random_flip(fets, ratio):
    """
    Randomly flip a given ratio of binary features.

    One set of ``int(n_features * ratio)`` distinct feature positions is drawn
    and XOR-ed into ``fets``. The mask broadcasts over leading dimensions, so
    every row of a batch has the same positions flipped.

    Args:
        fets: tensor whose last dimension indexes features. Non-zero entries
            are treated as 1, so the output is binarised.
        ratio: fraction of features to flip; values >= 1 flip every feature.

    Returns:
        Float tensor of 0.0/1.0 values with the same shape as ``fets``.

    Raises:
        ValueError: if ``ratio`` is negative (a negative slice bound would
            otherwise silently flip almost every feature).
    """
    if ratio < 0:
        raise ValueError(f"ratio must be non-negative, got {ratio}")
    n_features = fets.shape[-1]
    # randperm stays on the default CPU generator so the random stream matches
    # the original implementation; only the mask is moved to fets' device.
    inds = torch.randperm(n_features)[:int(n_features * ratio)]
    mask = torch.zeros(n_features, device=fets.device)
    mask[inds] = 1
    return 1.0 * torch.logical_xor(fets, mask)