import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import argparse
from sklearn.preprocessing import LabelEncoder
from keras.datasets import mnist
import torchvision.transforms as transforms
import os

from dataloading import MNISTDataset, torch_train_val_split
from models import CNN2D
from pruning import pruning_experiment
from training import train, eval

# Command-line interface. `choices` makes argparse reject an unknown dataset with a
# usage message (and list the valid names in --help) before any data is loaded.
parser = argparse.ArgumentParser(description='Run first pruning experiment.')
parser.add_argument('--dataset', default='MNIST-3_5',
                    choices=['MNIST-3_5', 'MNIST-4_9'],
                    help='Dataset used. "MNIST-3_5", "MNIST-4_9"')
# Free-form string; only "Y"/"y" enables the bound printing (see BOUND below).
parser.add_argument('--bound', default='N',
                    help='Whether to print theoretical upper bound of algorithms. Y/N')
args = parser.parse_args()

################# Configuration  ######################
DATASET = args.dataset
IMAGE_SIZE = 28  # images are resized to IMAGE_SIZE x IMAGE_SIZE (28 is the MNIST size)

# Training parameters
BATCH_SIZE = 128
EPOCHS = 20
EXECUTIONS = 5  # number of models trained (and later pruned) independently
PATH = 'models/'  # directory used to save and load trained models

# Run on the GPU whenever one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model specification
MODEL = 'CNN2D'
HIDDEN_SIZE = 1000

############# Datasets and Dataloaders ################
# Each supported dataset is a binary sub-task of MNIST that keeps only two digits.
DIGIT_PAIRS = {
    "MNIST-3_5": (3, 5),
    "MNIST-4_9": (4, 9),
}


def _filter_digits(images, labels, digits):
    """Keep only the samples whose label is in *digits*.

    Returns two lists (images, labels) that stay aligned element by element.
    """
    kept = [(img, lbl) for img, lbl in zip(images, labels) if lbl in digits]
    return [img for img, _ in kept], [lbl for _, lbl in kept]


if DATASET not in DIGIT_PAIRS:
    raise ValueError("Invalid dataset")

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, y_train = _filter_digits(X_train, y_train, DIGIT_PAIRS[DATASET])
X_test, y_test = _filter_digits(X_test, y_test, DIGIT_PAIRS[DATASET])
output_size = 1  # a single logit, paired with BCEWithLogitsLoss below

# Map the two kept digits to {0, 1} as float32 targets. The encoder is fitted on the
# training labels only and then reused for the test labels, so both splits are
# guaranteed to share one mapping (transform raises on a label unseen in training).
le = LabelEncoder()
y_train = le.fit_transform(y_train).astype('float32')
y_test = le.transform(y_test).astype('float32')

# Define our PyTorch-based Dataset
train_set = MNISTDataset(X_train, y_train, transform=None)
test_set = MNISTDataset(X_test, y_test, transform=None)

# resizing input image
transform = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE))
])

# Stack the per-image arrays into one tensor before resizing: calling torch.tensor()
# on a Python list of ndarrays is very slow (PyTorch emits a warning for it).
# Result: (N, 1, IMAGE_SIZE, IMAGE_SIZE) float tensor scaled to [0, 1].
# NOTE(review): train_data is not referenced anywhere else in this file — confirm it is
# still needed; it occupies device memory for the whole run.
train_images = torch.stack([torch.as_tensor(img) for img in X_train])
train_data = transform(train_images).unsqueeze(1).float() / 255
train_data = train_data.to(DEVICE)

# Defining DataLoaders (30% of the training set is held out for validation)
train_loader, val_loader = torch_train_val_split(train_set,
                                                 BATCH_SIZE,
                                                 val_size=.3,
                                                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


############ Model, Criterion, Optimizer ##############
# torch.save below fails if the checkpoint directory does not exist yet.
os.makedirs(PATH, exist_ok=True)

for i in range(EXECUTIONS):

    model = CNN2D(n_classes=output_size, hidden_size=HIDDEN_SIZE)
    model = model.to(DEVICE)
    print(model)

    # Criterion and optimizer selection: single-logit binary classification.
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # NOTE(review): T_max=200 while EPOCHS=20 — whether the cosine schedule completes
    # depends on how often train() calls scheduler.step(); confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    ###################### Training  ######################
    # Reuse a previously trained model for this version if one was saved.
    model_path = os.path.join(
        PATH, '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(MODEL, DATASET, IMAGE_SIZE, i))
    if os.path.exists(model_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # NOTE(review): this pickles the whole module; torch >= 2.6 defaults to
        # weights_only=True, which may refuse to load it — confirm the torch version.
        model = torch.load(model_path, map_location=DEVICE)
    else:
        model = train(model,
                      EPOCHS,
                      optimizer,
                      criterion,
                      scheduler,
                      train_loader,
                      val_loader)

        torch.save(model, model_path)

    # `eval` here is training.eval (imported above), not the Python builtin.
    accuracy = eval(model, test_loader, criterion)
    print('Test Accuracy of model is {:.2f}'.format(100 * accuracy))


############## Compression Experiment  ###############
# Settings handed to every pruning_experiment call below.
info = dict(
    name=MODEL,
    dataset=DATASET,
    imsize=IMAGE_SIZE,
    ratios=[1, 0.1, 0.05, 0.01, 0.005, 0.003],
    repetitions=5,  # how many times the compression algorithm is repeated
)

# Bound printing is enabled only by an explicit "Y" or "y".
BOUND = args.bound in ('Y', 'y')

# Pruning methods compared on every trained model, in the order they are run.
PRUNING_METHODS = ['zonotope_kmeans', 'improved_zonotope_kmeans',
                   'neural_path_kmeans', 'tropnnc']

# Same loss the models were trained with. It is defined explicitly here instead of
# relying on the `criterion` variable left over from the training loop.
criterion = nn.BCEWithLogitsLoss()

for i in range(EXECUTIONS):  # We execute the experiment in many models to examine its stability
    model_path = os.path.join(
        PATH, '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(MODEL, DATASET, IMAGE_SIZE, i))

    for run, method in enumerate(PRUNING_METHODS):
        # Reload the trained checkpoint for every method so each one starts from the
        # same weights. NOTE(review): whether pruning_experiment modifies the model in
        # place is not visible here; reloading is harmless if it does not.
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        model = torch.load(model_path, map_location=DEVICE)
        model = model.to(DEVICE)
        if run == 0:
            print(model)  # show the architecture once per trained model

        pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                           method=method, print_upper_bound=BOUND)

