import argparse
import numpy as np
import os
import time
import pickle
import copy
import gc


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import RandomSampler
import torch.backends.cudnn as cudnn


from dataset import *
from myCNN import smallCNN1,smallCNN2
from SGLD import SGLD, NoisySGD
from signSGD import signSGD, BernoulliSGD
from utils import *






parser = argparse.ArgumentParser(description='SGLD and Noisy Sign-SGD Example')

# Command-line options as (flag, add_argument keyword arguments) pairs.
# They are registered in exactly this order, which is also the order in which
# they show up in the generated --help text.
_CLI_OPTIONS = (
    # Data
    ('--data', dict(
        type=str,
        choices=['mnist', 'fashion', 'cifar10'],
        default='mnist',
        help='Type of dataset: mnist|fashion|cifar10.',
    )),
    ('--num-classes', dict(
        type=int,
        default=10,
        help='Number of classes in the dataset. Taken into account only by classifiers and data generators that support multi-class',
    )),
    ('--num-samples-per-class', dict(
        type=int,
        default=1000,
        help='Number of samples per class to get a subset. 0 means use the entire dataset.',
    )),
    # Model
    ('--arch', dict(
        type=str,
        choices=['cnn1', 'cnn2'],
        default='cnn1',
        help='Type of network architecture: cnn1 for mnist and fashion mnist|cnn2 for cifar10.',
    )),
    # Training
    ('--optimizer', dict(
        type=str,
        default='sgld',
        choices=['sgld', 'noisy_signSGD'],
    )),
    ('--batch-size', dict(
        type=int,
        default=100,
        help='Number of samples in a training batch.',
    )),
    ('--learning-rate', dict(
        type=float,
        default=4e-3,
        help='Learning rate.',
    )),
    ('--num-epochs', dict(
        type=int,
        default=1,
        help='Cancel training after maximum number of epochs',
    )),
    ('--decay', dict(
        type=float,
        default=0.96,
        help='learning rate decay',
    )),
    # SGLD
    ('--alpha', dict(
        type=float,
        default=55000,
        help='scaling factor: inverse temperature or direct scaling in SGLD, and the scaling in noisy sign-sgd ',
    )),
    # Results
    ('--save_circ', dict(
        type=int,
        default=10,
        help='Save weights at every N iteration.',
    )),
    ('--results-folder', dict(
        type=str,
        default='results',
        help='Folder in which to put all results folders',
    )),
)

for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

# Parse the process command line once; the rest of the script reads `args`.
args = parser.parse_args()


# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Switch on cuDNN's benchmark mode (its auto-tuner for convolution algorithms).
cudnn.benchmark = True



######### Load dataset, set optimizer and learning rate scheduling  ################

# Cache the command-line values used to build the data pipeline and optimizer.
batch_size = args.batch_size
num_samples_per_class = args.num_samples_per_class

learning_rate = args.learning_rate
alpha = args.alpha
decay_amount = args.decay

# --learning-rate and --decay used to be parsed and then ignored: every
# configuration below hard-coded its own values.  A flag value that differs
# from the parser default is now treated as an explicit override of the
# preset; a value equal to the default keeps the preset, so runs that leave
# the flags untouched train exactly as before.
# Limitation: a value that coincides with the parser default cannot be told
# apart from "flag not given", so it never overrides a preset.
lr_override = None
if learning_rate != parser.get_default('learning_rate'):
    lr_override = learning_rate
gamma_override = None
if decay_amount != parser.get_default('decay'):
    gamma_override = decay_amount

if args.data == 'cifar10':
    train_loader, test_loader = get_cifar10_data(batch_size, num_samples_per_class)
    model = smallCNN2().to(device)
    # Only an SGLD configuration exists for cifar10; say so instead of
    # silently ignoring a different --optimizer choice.
    use_sgld = True
    if args.optimizer != 'sgld':
        print('Warning: --optimizer {} is not configured for cifar10; '
              'using SGLD instead.'.format(args.optimizer))
    preset_lr, step_size, preset_gamma = 5e-3, 5, 0.995
else:
    # mnist and fashion share the network and the hyper-parameter presets;
    # only the data loader differs.
    load_data = {'mnist': get_mnist_data, 'fashion': get_fashion_data}[args.data]
    train_loader, test_loader = load_data(batch_size, num_samples_per_class)
    model = smallCNN1().to(device)
    use_sgld = args.optimizer == 'sgld'
    if use_sgld:
        preset_lr, step_size, preset_gamma = 4e-3, 5, 0.96
    else:  # 'noisy_signSGD'
        preset_lr, step_size, preset_gamma = 1e-4, 30, 0.1

criterion = nn.NLLLoss().to(device)

effective_lr = preset_lr if lr_override is None else lr_override
effective_gamma = preset_gamma if gamma_override is None else gamma_override

if use_sgld:
    optimizer = SGLD(model.parameters(), lr=effective_lr, beta=alpha)
else:
    optimizer = BernoulliSGD(model.parameters(), lr=effective_lr, alpha=alpha)

# The scheduler is stepped once per epoch by the training loop, so step_size
# counts epochs.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=step_size, gamma=effective_gamma)




###################################################################################

    
num_epochs = args.num_epochs

# history['train'] collects one [epoch, loss, accuracy] row per epoch;
# history['test'] receives a single [loss, accuracy] row after training.
history = {'train': [], 'test': []}

# Global optimisation-step counter; it keeps running across epochs and is
# what adjust_beta receives as its step argument.
t = 0
############################ start training ########################################
try:
    # Key the beta adjustment off the optimizer that was actually built rather
    # than the CLI flag: the two differ when cifar10 is selected, because that
    # configuration always constructs SGLD.
    # NOTE(review): assumes adjust_beta is meant for every SGLD instance - confirm.
    uses_sgld = isinstance(optimizer, SGLD)

    for epoch in range(num_epochs):
        model.train()

        # Snapshot the weights as they are at the START of selected epochs
        # (0, 1, every save_circ-th one and the last), i.e. before this
        # epoch's updates.  A save_circ of 0 disables the periodic snapshots
        # instead of raising ZeroDivisionError.
        periodic = args.save_circ != 0 and epoch % args.save_circ == 0
        if epoch in (0, 1) or periodic or epoch == num_epochs - 1:
            state = {'state_dict': model.state_dict()}
            save_torch_results(args, 'state_iter_{}'.format(epoch), state)

        # Running sums for this epoch's metrics.
        total = 0        # samples seen
        correct = 0      # samples classified correctly
        losses = 0       # sum of per-batch losses
        num_batches = 0
        cut_perc = 0     # not read anywhere in the visible code; kept as-is

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            if uses_sgld:
                adjust_beta(optimizer, t, alpha)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update (SGLD or noisy signSGD)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track loss and accuracy; detach() keeps the bookkeeping out of
            # the autograd graph without the deprecated `.data` attribute.
            losses += loss.item()
            total += labels.size(0)
            _, predicted = torch.max(outputs.detach(), 1)
            correct += (predicted == labels).sum().item()

            num_batches += 1
            t += 1

        # learning rate scheduling (once per epoch)
        scheduler.step()

        # One collection per epoch instead of a full gc.collect() per batch.
        gc.collect()

        # An empty loader yields NaN metrics rather than a NameError or
        # ZeroDivisionError.
        train_loss = losses / num_batches if num_batches else float('nan')
        train_acc = correct / total if total else float('nan')
        print ('Epoch [{}/{}], Train Loss: {:.4f}, Acc: {:.2f}%'.format(epoch+1, num_epochs,train_loss ,train_acc * 100))
        history['train'].append([epoch, train_loss, train_acc])

    # Single evaluation on the test set after the last epoch.  num_epochs is
    # used instead of epoch+1 so that --num-epochs 0 does not hit an undefined
    # loop variable; for any run with at least one epoch the text is identical.
    test_loss, test_acc = test_eval(test_loader, model, criterion, device)
    print('Epoch [{}/{}], Test Loss: {:.4f}, Acc: {:.2f}%'.format(num_epochs, num_epochs,test_loss ,test_acc * 100))

    history['test'].append([test_loss, test_acc])
    save_results(args, 'results', history)

except Exception:
    # Best-effort run: report the failure and let the script continue.
    # NOTE(review): the process still exits with status 0 after a failed run,
    # and the history gathered so far is not saved - confirm this is intended.
    import traceback
    traceback.print_exc()
