#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torchvision, torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


def load_database_CIFAR10(root="data", download=True):
    """Load the full CIFAR-10 train and test splits into module-level tensors.

    Sets the globals ``X_train_rescale``, ``Y_train_rescale``,
    ``X_test_rescale`` and ``Y_test_rescale``. Images keep the per-sample
    layout produced by ``ToTensor`` (they are not flattened); labels are
    one-hot encoded using the default float dtype. Each split is returned
    in a random order (``shuffle=True``).

    Args:
        root: Directory where torchvision stores / looks for the dataset.
        download: Download the dataset if it is not already under ``root``.

    Returns:
        The tuple ``(X_train_rescale, Y_train_rescale, X_test_rescale,
        Y_test_rescale)`` -- the same tensors that are assigned to the
        globals, so callers may use either.
    """
    global X_train_rescale
    global Y_train_rescale

    global X_test_rescale
    global Y_test_rescale

    training_data = torchvision.datasets.CIFAR10(
        root=root,
        train=True,
        download=download,
        transform=ToTensor(),
    )
    test_data = torchvision.datasets.CIFAR10(
        root=root,
        train=False,
        download=download,
        transform=ToTensor(),
    )
    num_classes = len(training_data.classes)

    # One batch spanning the whole split materialises it as a single tensor.
    # The batch size is taken from the dataset rather than a hard-coded count.
    train_dataloader = DataLoader(
        training_data, batch_size=len(training_data), shuffle=True
    )
    test_dataloader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)
    X_train_rescale, train_labels_indx = next(iter(train_dataloader))
    X_test_rescale, test_labels_indx = next(iter(test_dataloader))

    # one_hot returns int64; cast to the default float dtype so the result
    # matches what torch.zeros(...) produced before.
    float_dtype = torch.get_default_dtype()
    Y_train_rescale = torch.nn.functional.one_hot(
        train_labels_indx, num_classes=num_classes
    ).to(float_dtype)
    Y_test_rescale = torch.nn.functional.one_hot(
        test_labels_indx, num_classes=num_classes
    ).to(float_dtype)

    return X_train_rescale, Y_train_rescale, X_test_rescale, Y_test_rescale
    
    
    
    
    
    
def load_database_MNIST(root="data", download=True):
    """Load the full MNIST train and test splits into module-level tensors.

    Sets the globals ``X_train_rescale``, ``Y_train_rescale``,
    ``X_test_rescale`` and ``Y_test_rescale``. Each image tensor has its
    channel axis (index 1) dropped, keeping only channel 0. It is then
    flattened to shape ``(N, H*W)``. Labels are one-hot encoded using the
    default float dtype. Each split is returned in a random order
    (``shuffle=True``).

    Args:
        root: Directory where torchvision stores / looks for the dataset.
        download: Download the dataset if it is not already under ``root``.

    Returns:
        The tuple ``(X_train_rescale, Y_train_rescale, X_test_rescale,
        Y_test_rescale)`` -- the same tensors that are assigned to the
        globals, so callers may use either.
    """
    global X_train_rescale
    global Y_train_rescale

    global X_test_rescale
    global Y_test_rescale

    training_data = torchvision.datasets.MNIST(
        root=root,
        train=True,
        download=download,
        transform=ToTensor(),
    )
    test_data = torchvision.datasets.MNIST(
        root=root,
        train=False,
        download=download,
        transform=ToTensor(),
    )
    num_classes = len(training_data.classes)

    # One batch spanning the whole split materialises it as a single tensor.
    # The batch size is taken from the dataset rather than a hard-coded count.
    train_dataloader = DataLoader(
        training_data, batch_size=len(training_data), shuffle=True
    )
    test_dataloader = DataLoader(test_data, batch_size=len(test_data), shuffle=True)
    X_train_rescale, train_labels_indx = next(iter(train_dataloader))
    X_test_rescale, test_labels_indx = next(iter(test_dataloader))

    # Keep only channel 0, then flatten every remaining axis after the batch axis.
    X_train_rescale = torch.flatten(X_train_rescale[:, 0, :, :], start_dim=1, end_dim=-1)
    X_test_rescale = torch.flatten(X_test_rescale[:, 0, :, :], start_dim=1, end_dim=-1)

    # one_hot returns int64; cast to the default float dtype so the result
    # matches what torch.zeros(...) produced before.
    float_dtype = torch.get_default_dtype()
    Y_train_rescale = torch.nn.functional.one_hot(
        train_labels_indx, num_classes=num_classes
    ).to(float_dtype)
    Y_test_rescale = torch.nn.functional.one_hot(
        test_labels_indx, num_classes=num_classes
    ).to(float_dtype)

    return X_train_rescale, Y_train_rescale, X_test_rescale, Y_test_rescale
