"""MNIST dataset, we take the MNIST dataset
https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html
and convert it to tabular form by keeping only the 20 pixel features that STG
selected as the most predictive (their indices are hard-coded below).
"""


import os
import os.path as osp

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as T
from torchvision.datasets import MNIST

from datasets.preprocessing_utils import preprocess_and_save_data


# Row counts forwarded to preprocess_and_save_data for the training and
# validation parts of the data.
train_size = 50_000
val_size = 10_000

# Directory used as the MNIST download root; it is also handed to
# preprocess_and_save_data as its `path` argument.
path = osp.join("datasets", "data", "mnist")

# Fix NumPy's global RNG so that the row permutation drawn in the main section
# is the same on every run.
SEED = 3295
np.random.seed(SEED)


def _flatten_dataset(dataset):
  """Convert an image dataset into a feature matrix and a label vector.

  The dataset is traversed exactly once, so every sample is loaded (and its
  transform applied) a single time. Building the images and the labels with
  two separate comprehensions would fetch and transform each sample twice.

  Args:
    dataset: Iterable of ``(image, label)`` pairs where ``image`` is a tensor
      and ``label`` is an integer class index.

  Returns:
    A tuple ``(X, y)``. ``X`` is a float tensor with one flattened image per
    row; ``y`` is a long tensor with the matching labels, in dataset order.
  """
  images = []
  labels = []
  for image, label in dataset:
    images.append(image.view(-1))
    labels.append(label)
  return torch.stack(images, dim=0).float(), torch.tensor(labels).long()


if __name__ == "__main__":
  # Fetch (downloading if necessary) the official train and test splits; the
  # transform makes every sample's image a tensor.
  mnist_train_val = MNIST(root=path, train=True, download=True, transform=T.ToTensor())
  mnist_test = MNIST(root=path, train=False, download=True, transform=T.ToTensor())

  # One pass per split: flattened images as rows, labels alongside.
  X_train, y_train = _flatten_dataset(mnist_train_val)
  X_test, y_test = _flatten_dataset(mnist_test)

  # Best features chosen by STG (column indices into the flattened image).
  best_features = np.array([348, 243, 327, 269, 153, 430, 543, 271, 461, 154,
                            210, 427, 295, 409, 655, 375, 350, 211, 405, 514])

  X_train = X_train[:, best_features]
  X_test = X_test[:, best_features]

  # Shuffle train set to get train and val set. The permutation comes from
  # NumPy's global RNG, which is seeded at module level.
  shuffle_ids = np.random.permutation(X_train.shape[0])
  X_train = X_train[shuffle_ids]
  y_train = y_train[shuffle_ids]

  # Stack the shuffled train/val rows first and the untouched test rows last.
  # The pieces are already float / long tensors, so no further cast is needed.
  X = torch.cat([X_train, X_test], dim=0)
  y = torch.cat([y_train, y_test], dim=0)

  # Metadata consumed by preprocess_and_save_data.
  # NOTE(review): presumably "num_con_features" / "num_cat_features" are the
  # counts of continuous / categorical columns -- confirm against the helper.
  dataset_dict = {
    "num_con_features": X.shape[1],
    "num_cat_features": 0,
    "most_categories": 0,
    "out_dim": 10,
    "metric": "accuracy",
    "max_dim": None,
  }

  # shuffle=False because the train/val rows were already permuted above.
  # NOTE(review): presumably the helper splits rows positionally (first
  # train_size, then val_size, remainder as test) -- confirm before changing
  # the row order or the sizes.
  preprocess_and_save_data(
    path=path,
    dataset_dict=dataset_dict,
    train_size=train_size,
    val_size=val_size,
    X=X,
    y=y,
    M=None,
    shuffle=False,
    num_bins=200,
    size_normal=1e-5,
    ratio_uniform=0.2,
  )