import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import argparse
import os

from dataloading import MNISTDataset, torch_train_val_split
from models import vgg_bn_drop, vgg_bn_drop_100
from training import train, eval
from pruning import pruning_experiment

# Device used when moving models; CUDA if available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Command-line interface for the experiment.
parser = argparse.ArgumentParser(description='Run third pruning experiment.')
# NOTE(review): the default 'CIFAR-VGG' is not handled by the model-building
# loop of this script (only 'AlexNet' is), so running without --model ends in
# a ValueError there — confirm the intended default.
parser.add_argument('--model', type=str, default='CIFAR-VGG',
                    help='CNN model. Possible Options: "CIFAR-VGG", "AlexNet"')
# `choices` rejects unsupported datasets at parse time; otherwise an unknown
# name would only surface later, when no dataset branch defines `train_set`.
parser.add_argument('--dataset', default='CIFAR10', choices=['CIFAR10', 'CIFAR100'],
                    help='Dataset used. Possible options: "CIFAR10", "CIFAR100"')
parser.add_argument('--image_size', type=int, default=32,
                    help='Input image size. Default 32')
parser.add_argument('--epochs', type=int, default=28,
                    help='Number of training epochs. Default 28')
args = parser.parse_args()

################# Configuration  ######################
# Settings taken from the command line.
MODEL = args.model              # which CNN to build
DATASET = args.dataset          # which torchvision dataset to load
IMAGE_SIZE = args.image_size    # side length images are resized to
EPOCHS = args.epochs            # training epochs per model

# Fixed settings for training and checkpointing.
BATCH_SIZE = 64
EXECUTIONS = 5  # number of independently trained models (stability check)
PATH = 'models/'  # prefix for saved model checkpoints
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # CPU fallback when no GPU


############# Datasets and Dataloaders ################
# Both pipelines normalise with the same fixed per-channel (mean, std) values.
# These look like the figures commonly quoted for CIFAR-10; note they are also
# applied unchanged to CIFAR-100.

# Training pipeline: resize, random horizontal flip (augmentation), normalise.
transform_train = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test pipeline: identical to training but without augmentation.
transform_test = transforms.Compose([
    transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

if DATASET == "CIFAR10":
    train_set = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    test_set = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)

    output_size = 10

elif DATASET == "CIFAR100":
    train_set = torchvision.datasets.CIFAR100(
        root='./data', train=True, download=True, transform=transform_train)
    test_set = torchvision.datasets.CIFAR100(
        root='./data', train=False, download=True, transform=transform_test)

    output_size = 100

else:
    # Without this branch an unknown name would fail later with a NameError on
    # `train_set`; fail here with an explicit message instead.
    raise ValueError("Invalid dataset's name: {!r}".format(DATASET))

# Full training set kept under a separate name; the compression stage passes
# it to the ThiNet pruning run as `dataset=`.
train_data = train_set

# Defining DataLoaders: 25% of the training set is held out for validation.
train_loader, val_loader = torch_train_val_split(train_set,
                                                 BATCH_SIZE,
                                                 val_size=.25,
                                                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


############ Model, Criterion, Optimizer ##############
for i in range(EXECUTIONS):  # train several models to examine the stability of the results

    if MODEL == 'AlexNet':
        # Redefine AlexNet's FC layer sizes to be compatible with CIFAR:
        # pooling the features to 1x1 leaves a 256-dim vector, so the
        # classifier becomes 256 -> 512 -> 512 -> output_size.
        # NOTE(review): torchvision's AlexNet feature stack appears to need
        # inputs of roughly >= 63 px; with --image_size 32 the final MaxPool
        # receives a 1x1 map and errors — confirm and pass a larger size.
        model = models.alexnet(pretrained=True)
        model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        model.classifier[1] = nn.Linear(256, 512)
        model.classifier[4] = nn.Linear(512, 512)
        model.classifier[6] = nn.Linear(512, output_size)
        # Freeze the pretrained convolution weights; only the classifier trains.
        for param in model.features.parameters():
            param.requires_grad = False

    else:
        raise ValueError(
            "Invalid model's name: {!r} (supported: 'AlexNet')".format(MODEL))

    # Move the model weights to the selected device (CPU or GPU).
    model = model.to(device)
    print(model)

    # Criterion and optimizer selection
    criterion = nn.CrossEntropyLoss()
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters, lr=0.01, momentum=0.9, weight_decay=5e-4)
    # NOTE(review): T_max=200 is independent of EPOCHS (default 28); if the
    # scheduler is stepped once per epoch, only the start of the cosine curve
    # is used — confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    ###################### Training  ######################
    # One checkpoint file per (model, dataset, image size, run index).
    ckpt_path = os.path.join(
        PATH, '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(MODEL, DATASET, IMAGE_SIZE, i))
    if os.path.exists(ckpt_path):
        # Reuse a previously trained model. map_location lets a checkpoint that
        # was saved on a GPU load on a CPU-only machine (and vice versa).
        model = torch.load(ckpt_path, map_location=DEVICE)
    else:
        model = train(model,
                      EPOCHS,
                      optimizer,
                      criterion,
                      scheduler,
                      train_loader,
                      val_loader)

        # torch.save does not create missing directories.
        os.makedirs(PATH, exist_ok=True)
        torch.save(model, ckpt_path)

    accuracy = eval(model, test_loader, criterion)
    print('Test Accuracy of model is {:.2f}'.format(100 * accuracy))


############## Compression Experiment  ###############
info = {
    'name': MODEL,
    'dataset': DATASET,
    'imsize': IMAGE_SIZE,
    'ratios': [1, 0.9, 0.7, 0.5, 0.3, 0.2, 0.15, 0.10, 0.05, 0.02, 0.01, 0.005],
    'repetitions': 5,  # How many times to repeat compression algorithm
}

# Pruning methods applied to every trained model, in this order, together with
# the extra keyword arguments each one needs. 'tropnnc' is currently disabled.
pruning_runs = [
    ('neural_path_kmeans', {}),
    # ('tropnnc', {}),
    ('iterative_tropnnc', {}),
    ('thinet', {'dataset': train_data, 'w2_rescale': MODEL == "AlexNet"}),
    ('random_structured', {}),
    ('l1_structured', {}),
]

for i in range(EXECUTIONS):  # run the experiment on every trained model to examine stability

    ckpt_path = os.path.join(
        PATH, '{}_trained_in_{}_imsize_{}_version_{}.pkl'.format(MODEL, DATASET, IMAGE_SIZE, i))
    model = torch.load(ckpt_path, map_location=DEVICE)

    # Move the model weights to the selected device (CPU or GPU).
    model = model.to(device)
    print(model)

    # `criterion` is the loss created in the training stage above.
    # NOTE(review): the same `model` object is passed to every method; if
    # pruning_experiment modifies it in place, later methods would start from
    # an already-pruned model — confirm it works on a copy.
    for method, extra_kwargs in pruning_runs:
        pruning_experiment(model, train_loader, val_loader, test_loader, criterion, info,
                           method=method, **extra_kwargs)
