import argparse
import os

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from sklearn.preprocessing import LabelEncoder
from keras.datasets import mnist
import torchvision.transforms as transforms

from dataloading import MNISTDataset, torch_train_val_split
from models import CNN2D
from pruning import pruning_experiment
from training import train, eval

# Command-line interface: selects the MNIST digit pair to work on and whether
# the pruning runs should also print the theoretical upper bound.
parser = argparse.ArgumentParser(description='Run first pruning experiment.')
parser.add_argument(
    '--dataset',
    default='MNIST-3_5',
    help='Dataset used. "MNIST-3_5", "MNIST-4_9"',
)
parser.add_argument(
    '--bound',
    default='N',
    help='Whether to print theoretical upper bound of algorithms. Y/N',
)
args = parser.parse_args()

################# Configuration  ######################
DATASET = args.dataset
IMAGE_SIZE = 28  # default side length for the MNIST dataset

# Training parameters
BATCH_SIZE = 128
EPOCHS = 20
EXECUTIONS = 5  # number of models trained (and later pruned) in total
PATH = 'models/'  # prefix under which models are saved and reloaded
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Model Specifications
MODEL = 'CNN2D'  # label used in checkpoint file names and experiment info
HIDDEN_SIZE = 1000  # forwarded to CNN2D as hidden_size

############# Datasets and Dataloaders ################
# Every supported dataset name maps to the pair of MNIST digits that is kept
# for the binary classification task.
DIGIT_PAIRS = {
    "MNIST-3_5": (3, 5),
    "MNIST-4_9": (4, 9),
}

# Reject unknown names before the MNIST arrays are loaded.
if DATASET not in DIGIT_PAIRS:
    raise ValueError("Invalid dataset")
kept_digits = DIGIT_PAIRS[DATASET]

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Evaluate the membership test once per sample, then reuse the mask for both
# the images and the labels of a split so the two lists stay aligned.
train_keep = [label in kept_digits for label in y_train]
test_keep = [label in kept_digits for label in y_test]

# The filtered splits are plain Python lists (as before), not ndarrays.
X_train = [image for image, keep in zip(X_train, train_keep) if keep]
y_train = [label for label, keep in zip(y_train, train_keep) if keep]
X_test = [image for image, keep in zip(X_test, test_keep) if keep]
y_test = [label for label, keep in zip(y_test, test_keep) if keep]

# Two classes are distinguished with a single output unit.
output_size = 1

# Map the two digit labels to consecutive class indices and cast to float32.
le = LabelEncoder()
# Fit the label -> index mapping on the training labels only.
y_train = le.fit_transform(y_train).astype('float32')
# Reuse the fitted mapping for the test labels. Re-fitting on the test split
# would silently produce a different mapping whenever its label set differs
# from the training one; transform() raises on labels unseen during fit.
y_test = le.transform(y_test).astype('float32')

# Wrap the filtered images and encoded labels in the project's Dataset class;
# no per-sample transform is attached.
train_set = MNISTDataset(X_train, y_train, transform=None)
test_set = MNISTDataset(X_test, y_test, transform=None)

# Resize pipeline targeting IMAGE_SIZE x IMAGE_SIZE images.
transform = transforms.Compose(
    [transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE))]
)

# Full training set as one tensor: resized, given a dimension at index 1,
# converted to float, divided by 255 and moved to DEVICE.
# NOTE(review): train_data is not referenced by any later statement in the
# visible part of this script - confirm whether it is still needed.
train_data = transform(torch.tensor(X_train))
train_data = train_data.unsqueeze(1).float() / 255
train_data = train_data.to(DEVICE)

# Training / validation loaders: 30% of train_set is requested for validation.
train_loader, val_loader = torch_train_val_split(
    train_set,
    BATCH_SIZE,
    val_size=.3,
    shuffle=True,
)

# The test loader keeps the samples in their original order.
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)


############ Model, Criterion, Optimizer ##############
# Make sure the checkpoint directory exists before any training starts, so a
# missing folder cannot make torch.save fail after a full training run.
checkpoint_dir = os.path.dirname(PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

# Train EXECUTIONS models from scratch; each one is saved under its own
# version index and evaluated on the test loader.
for i in range(EXECUTIONS):

    model = CNN2D(n_classes=output_size, hidden_size=HIDDEN_SIZE)
    model = model.to(DEVICE)
    print(model)

    # Criterion and optimizer selection
    criterion = nn.BCEWithLogitsLoss()  # binary classification
    parameters = model.parameters()
    optimizer = torch.optim.Adam(parameters, lr=1e-3)
    # NOTE(review): T_max is fixed at 200 while EPOCHS is 20; whether the
    # cosine schedule ever completes depends on how train() steps the
    # scheduler - confirm against training.train.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    ###################### Training  ######################
    model = train(model,
                  EPOCHS,
                  optimizer,
                  criterion,
                  scheduler,
                  train_loader,
                  val_loader)

    # The whole model object is saved (not just its state_dict).
    checkpoint_path = PATH + '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(
        MODEL, DATASET, IMAGE_SIZE, i)
    torch.save(model, checkpoint_path)

    accuracy = eval(model, test_loader, criterion)
    print('Test Accuracy of model is {:.2f}'.format(100 * accuracy))


############## Compression Experiment  ###############
# Settings handed to every pruning_experiment call.
info = {
    'name': MODEL,
    'dataset': DATASET,
    'imsize': IMAGE_SIZE,
    'ratios': [1, 0.1, 0.05, 0.01, 0.005, 0.003],
    'repetitions': 5,  # How many times to repeat compression algorithm
}

# Upper-bound printing is enabled only by an explicit "Y" / "y".
BOUND = args.bound in ('Y', 'y')

# Build the loss explicitly for this section instead of relying on the
# variable left behind by the last iteration of the training loop.
criterion = nn.BCEWithLogitsLoss()

for i in range(EXECUTIONS):  # We execute the experiment in many models to examine its stability

    checkpoint_path = PATH + '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(
        MODEL, DATASET, IMAGE_SIZE, i)
    # map_location lets a checkpoint written on a CUDA machine load on a
    # CPU-only one. The file holds a fully pickled model written by this
    # script, so weights_only=False is required on PyTorch versions whose
    # torch.load defaults to weights_only=True (the keyword needs
    # torch >= 1.13). Only load checkpoints from a trusted source this way.
    model = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)

    # move the model weights to cpu or gpu
    model = model.to(DEVICE)
    print(model)

    # NOTE(review): the same model object is passed to both methods; if
    # pruning_experiment modifies the model in place, the second run would
    # not start from the trained weights - confirm against pruning.py.
    pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method='zonotope_kmeans', print_upper_bound=BOUND)
    pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                       method='neural_path_kmeans', print_upper_bound=BOUND)

