import argparse
import os

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision.transforms as transforms
from keras.datasets import mnist, fashion_mnist
from sklearn.preprocessing import LabelEncoder

from dataloading import MNISTDataset, torch_train_val_split
from models import LeNet5, CNN2D, deepNN
from pruning import pruning_experiment
from training import train, eval

# Command-line interface: selects the model, the dataset and the input image
# size used by the rest of the script.
parser = argparse.ArgumentParser(description='Run second pruning experiment.')
# The help text lists every model name the script accepts, including
# "deepNN", which was previously undocumented.
parser.add_argument('--model', type=str, default='LeNet5',
                    help='Model to use. Possible options: "deepNN", "LeNet5", "CNN2D"')
parser.add_argument('--dataset', type=str, default='MNIST',
                    help='Dataset used. Possible options: "MNIST", "FASHION_MNIST"')
parser.add_argument('--image_size', type=int, default=32,
                    help='Input image size. Default 32')
# Parsed immediately at module level: the values feed the configuration
# constants defined right below.
args = parser.parse_args()


################# Configuration  ######################
# Values taken straight from the command line
DATASET = args.dataset
IMAGE_SIZE = args.image_size

# Training parameters
BATCH_SIZE, EPOCHS = 128, 20
EXECUTIONS = 5  # Number of models trained (and later pruned) in total
PATH = 'models/'  # Prefix used when saving and loading trained models
# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model parameters
MODEL = args.model
# Layer sizes depend on the dataset: MNIST gets the smaller configuration,
# any other dataset the larger one.
if DATASET == 'MNIST':
    CONV_FILTERS, HIDDEN_SIZE = 16, 500
else:
    CONV_FILTERS, HIDDEN_SIZE = 32, 1000


############# Datasets and Dataloaders ################
# Load the raw train/test arrays for the requested dataset; both supported
# datasets are configured here with 10 output classes.
if DATASET == "MNIST":
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    output_size = 10

elif DATASET == "FASHION_MNIST":
    (X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
    output_size = 10

else:
    raise ValueError("Invalid dataset")

# Resize input images to IMAGE_SIZE x IMAGE_SIZE
transform = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
])

# Whole training set as one float tensor: resized, given a singleton dimension
# at index 1 (presumably the channel axis — TODO confirm) and divided by 255.
# It is kept on DEVICE and passed to the 'thinet' pruning method later on.
train_data = transform(torch.tensor(X_train)).unsqueeze(1).float() / 255
train_data = train_data.to(DEVICE)

# Encode the labels as consecutive integers.
# NOTE(review): the Keras loaders presumably already return integer labels,
# in which case this is an identity mapping — confirm before removing.
le = LabelEncoder()
y_train = le.fit_transform(y_train)  # learn the label mapping on the training labels
# Reuse the mapping learned above: re-fitting on the test labels could assign
# different integers to the same class if the two label sets ever differed.
y_test = le.transform(y_test)

# Define our PyTorch-based Dataset
train_set = MNISTDataset(X_train, y_train, transform=transform)
test_set = MNISTDataset(X_test, y_test, transform=transform)

# Defining DataLoaders: 70/30 train/validation split of the training set,
# plus an unshuffled loader over the test set.
train_loader, val_loader = torch_train_val_split(train_set,
                                                 BATCH_SIZE,
                                                 val_size=.3,
                                                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


############ Model, Criterion, Optimizer ##############
# torch.save does not create missing parent directories, so make sure the
# checkpoint location exists up front instead of failing only after a full
# training run has completed.
checkpoint_dir = os.path.dirname(PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

for i in range(EXECUTIONS):  # Train several models to examine the experiment's stability
    if MODEL == 'deepNN':
        # Default input image size = 28
        model = deepNN(n_classes=output_size)

    elif MODEL == 'LeNet5':
        # Default input image size = 32
        model = LeNet5(n_classes=output_size)

    elif MODEL == 'CNN2D':
        # Default input image size = 28
        model = CNN2D(n_classes=output_size, hidden_size=HIDDEN_SIZE, conv_filters=CONV_FILTERS)

    else:
        raise ValueError("Invalid model's name")

    # Move the model weights to the CPU or GPU
    model = model.to(DEVICE)
    print(model)

    # Criterion and optimizer selection
    criterion = nn.CrossEntropyLoss()  # for multiclass classification
    parameters = model.parameters()
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    # NOTE(review): T_max=200 is ten times EPOCHS (20); if train() steps the
    # scheduler once per epoch the learning rate barely anneals — confirm
    # this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    ###################### Training  ######################
    model = train(model,
                  EPOCHS,
                  optimizer,
                  criterion,
                  scheduler,
                  train_loader,
                  val_loader)

    # One checkpoint per execution; the file name encodes model, dataset,
    # image size and execution index so the models can be reloaded later.
    torch.save(model, PATH + '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(MODEL, DATASET, IMAGE_SIZE, i))

    # `eval` is the evaluation function imported from the project's training
    # module (it shadows the builtin of the same name).
    accuracy = eval(model, test_loader, criterion)
    print('Test Accuracy of model is {:.2f}'.format(100 * accuracy))


############## Compression Experiment  ###############
# Settings handed to every pruning_experiment call. The 'ratios' list is
# shorter for CNN2D than for the other models.
if MODEL == 'CNN2D':
    ratios = [1, 0.5, 0.25, 0.1, 0.05]
else:
    ratios = [1, 0.9, 0.75, 0.5, 0.3, 0.25, 0.2, 0.15, 0.1, 0.05]

info = dict(
    name=MODEL,
    dataset=DATASET,
    imsize=IMAGE_SIZE,
    ratios=ratios,
    repetitions=5,  # How many times to repeat compression algorithm
)

# Pruning methods to run, each paired with its method-specific keyword
# arguments. Every model gets 'neural_path_kmeans'; the remaining methods
# are skipped for CNN2D.
pruning_runs = [('neural_path_kmeans', {})]
if MODEL != 'CNN2D':
    pruning_runs.extend([
        ('thinet', {'dataset': train_data, 'w2_rescale': True}),
        ('random_structured', {}),
        ('l1_structured', {}),
    ])

for i in range(EXECUTIONS):  # Repeat over every trained model to ensure stability
    # NOTE(review): torch.load unpickles a complete model object here; newer
    # PyTorch releases may default to weights-only loading — confirm against
    # the installed version. Only load checkpoints this script produced.
    checkpoint = PATH + '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(MODEL, DATASET, IMAGE_SIZE, i)
    model = torch.load(checkpoint)

    # Move the model weights to the CPU or GPU
    model = model.to(DEVICE)
    print(model)

    # `criterion` is the loss object left over from the training section.
    for method, extra_kwargs in pruning_runs:
        pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                           method=method, **extra_kwargs)
