from genericpath import exists
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
# Per-image preprocessing pipeline:
#   1. ToTensor: PIL image -> float tensor of shape (C, H, W) scaled to [0, 1].
#   2. Normalize with mean 0.1307 / std 0.3081 (the values conventionally used
#      for MNIST).
#   3. Resize to a 3x3 grid, so each image becomes a (1, 3, 3) tensor.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        transforms.Resize(size=[3, 3]),
    ]
)
def _stack_dataset(dataset):
    """Materialise every (image, label) pair of *dataset* as two tensors.

    Each item is fetched exactly once, so the transform pipeline runs once per
    sample. Indexing the dataset separately for the image and the label would
    run the transform twice per sample.

    Returns:
        samples: the per-item image tensors concatenated along dim 0.
        labels: an int64 tensor of the per-item labels, in dataset order.
    """
    images = []
    targets = []
    for idx in range(len(dataset)):
        image, target = dataset[idx]
        images.append(image)
        targets.append(target)
    # Build the labels directly as int64 instead of going through a float32
    # tensor and casting back.
    return torch.cat(images, dim=0), torch.tensor(targets, dtype=torch.long)


train_dataset = datasets.MNIST('../data', train=True, download=True,
                        transform=transform)
test_dataset = datasets.MNIST('../data', train=False,
                        transform=transform)
# With the 3x3 transform each image is (1, 3, 3), so concatenating on dim 0
# yields an (N, 3, 3) tensor per split.
train_samples, train_labels = _stack_dataset(train_dataset)
test_samples, test_labels = _stack_dataset(test_dataset)

# Binarise both splits with a single threshold: the mean value over all
# training pixels. A pixel becomes 1 if it is strictly above that mean and 0
# otherwise. The test split is deliberately thresholded with the *training*
# statistic, so both splits get identical preprocessing and no test-set
# statistics are used. (`x > mean` is equivalent to `(x - mean) > 0` for
# finite floats, so training-set results match mean-centering then `gt(0)`.)
train_mean = train_samples.mean()
train_samples = train_samples.gt(train_mean).long()
test_samples = test_samples.gt(train_mean).long()

# Split the binarised samples by digit class (0-9) and save one tensor file per
# class and split, then print the per-class sample counts.
# torch.save does not create parent directories, so create them up front.
# The paths are relative to the current working directory.
os.makedirs('./data/train', exist_ok=True)
os.makedirs('./data/test', exist_ok=True)
for i in range(10):
    # A boolean mask selects the same rows, in the same order, as collecting
    # matching indices in a Python loop, but uses one vectorised comparison.
    train_mask = train_labels == i
    test_mask = test_labels == i
    torch.save(train_samples[train_mask], './data/train/samples_{}.pt'.format(i))
    torch.save(test_samples[test_mask], './data/test/samples_{}.pt'.format(i))
    print('--------------------{}-----------------'.format(i))
    print(int(train_mask.sum()))
    print(int(test_mask.sum()))