import argparse
import numpy as np
import pandas as pd
import torch
from torch import nn
import time
from time import gmtime, strftime
import csv
import copy
import random
import math
import sys

import torchvision
from torchvision import datasets, transforms

def data_prepare(args):
    """Build the torchvision train/test datasets selected by ``args``.

    Args:
        args: mapping with (at least) the keys ``"model"`` and ``"dataset"``.
            ``args["dataset"]`` must be one of ``"MNIST"``, ``"FMNIST"``,
            ``"CIFAR10"``, ``"SVHN"`` or ``"CIFAR100"``. ``args["model"]``
            selects the preprocessing: ``"ResNet"`` enables random-crop /
            horizontal-flip augmentation plus normalisation for the CIFAR
            datasets, and ``"lenet_5"`` / ``"L4_Net"`` resize (F)MNIST images
            to 32x32 and normalise them. Every other combination uses a plain
            ``ToTensor()`` conversion.

    Returns:
        ``(train_dataset, test_dataset)``. Data is stored under ``Datasets/``
        and downloaded on first use.

    Raises:
        ValueError: if ``args["dataset"]`` is not one of the supported names.
    """
    dataset = args["dataset"]
    model = args["model"]

    supported = ("MNIST", "FMNIST", "CIFAR10", "SVHN", "CIFAR100")
    # Fail fast with a clear message instead of the UnboundLocalError the
    # final `return` used to raise for an unknown dataset name.
    if dataset not in supported:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected one of {supported}"
        )

    # Plain tensor conversion is the default for both splits. It is assigned
    # unconditionally so that model="ResNet" combined with a non-CIFAR
    # dataset no longer fails on an undefined `transform`.
    transform = transforms.ToTensor()
    transform_train, transform_test = transform, transform

    if model == "ResNet":
        print("ResNet transform!")
        # Per-channel (3-channel) mean/std; these are the commonly quoted
        # CIFAR-10 statistics.
        normalize = transforms.Normalize(
            (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        )
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    if dataset in ("MNIST", "FMNIST") and model in ("lenet_5", "L4_Net"):
        # These models take 32x32 inputs; (F)MNIST images are 28x28.
        transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    root = "Datasets"
    if dataset == "MNIST":
        train_dataset = torchvision.datasets.MNIST(
            root=root, train=True, transform=transform, download=True
        )
        test_dataset = torchvision.datasets.MNIST(
            root=root, train=False, transform=transform
        )
    elif dataset == "FMNIST":
        train_dataset = torchvision.datasets.FashionMNIST(
            root=root, train=True, transform=transform, download=True
        )
        test_dataset = torchvision.datasets.FashionMNIST(
            root=root, train=False, transform=transform
        )
    elif dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(
            root=root, train=True, transform=transform_train, download=True
        )
        test_dataset = torchvision.datasets.CIFAR10(
            root=root, train=False, transform=transform_test
        )
    elif dataset == "SVHN":
        train_dataset = torchvision.datasets.SVHN(
            root=root, split="train", transform=transform, download=True
        )
        test_dataset = torchvision.datasets.SVHN(
            root=root, split="test", transform=transform, download=True
        )
    else:  # "CIFAR100" (membership validated above)
        # Previously built with transform=None, which yields PIL images that
        # the default DataLoader collate cannot batch. Use the same
        # transforms as CIFAR10 so every dataset here returns tensors.
        train_dataset = torchvision.datasets.CIFAR100(
            root=root,
            train=True,
            transform=transform_train,
            target_transform=None,
            download=True,
        )
        test_dataset = torchvision.datasets.CIFAR100(
            root=root,
            train=False,
            transform=transform_test,
            target_transform=None,
            download=True,
        )

    return train_dataset, test_dataset

def random_mini_batches(X, Y, mini_batch_size, seed = 0):
    """Shuffle (X, Y) together and split them into mini-batches.

    Arguments:
    X -- input data indexed by example along axis 0, i.e. shape
         (number of examples, ...).
    Y -- labels aligned with X along axis 0, i.e. shape
         (number of examples,) or (number of examples, k).
    mini_batch_size -- positive integer, number of examples per mini-batch.
    seed -- seed for the shuffle. NOTE: this reseeds NumPy's *global* RNG
            via ``np.random.seed`` (kept for backward compatibility), so it
            also affects later ``np.random`` calls made by the caller.

    Returns:
    mini_batches -- list of (mini_batch_X, mini_batch_Y) tuples. All batches
                    hold ``mini_batch_size`` examples, except possibly the
                    last, which holds the remainder. An empty input gives an
                    empty list.

    Raises:
    ValueError -- if ``mini_batch_size`` is not positive.
    """
    if mini_batch_size <= 0:
        raise ValueError(
            f"mini_batch_size must be positive, got {mini_batch_size}"
        )

    np.random.seed(seed)  # reproducible shuffle (reseeds the global RNG)
    m = X.shape[0]  # number of examples

    # Shuffle X and Y with the same permutation so rows stay paired.
    # Indexing along axis 0 only works for any rank (1-D labels included).
    permutation = list(np.random.permutation(m))
    shuffled_X = X[permutation]
    shuffled_Y = Y[permutation]

    # Stepping by mini_batch_size yields the full batches followed by the
    # (shorter) remainder batch, if any, in a single pass.
    return [
        (shuffled_X[start:start + mini_batch_size],
         shuffled_Y[start:start + mini_batch_size])
        for start in range(0, m, mini_batch_size)
    ]


