# basic functions
import os
import sys
import math
import numpy as np
import shutil
import setproctitle
import argparse
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix

# torch functions
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision
from torchvision import models, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
# import torch_optimizer as optim

# local functions
from SpinalVGGnet import SpinalVGG
from kernelmodel import *
from tr import *
#------------------------------------------------------------------------------

# arguments setting
# Command-line configuration for one training + evaluation run.
parser = argparse.ArgumentParser()
parser.add_argument('--batchSz', type=int, default=128, help='mini batch size')
parser.add_argument('--latent_dim', type=int, default=8, help='the dimension of latent space')
parser.add_argument('--nEpochs', type=int, default=200, help='the number of outter loop')
parser.add_argument('--cuda_device', type=int, default=0, help='choose cuda device')
parser.add_argument('--no-cuda', action='store_true', help='if TRUE, cuda will not be used')
# Optional: when omitted, the script substitutes its own default results path.
parser.add_argument('--save', help='path to save results')
# NOTE(review): the optimizers built in this script do not read args.lr; it is
# only reachable through `args` as passed to train()/test(). The help text
# suggests it is a tuning parameter (lambda) rather than a learning rate --
# confirm against train().
parser.add_argument('--lr', type=float, default=1e-4, help='tuning parameter-lambda')
# Restrict to the datasets this script knows how to load, so a typo is
# rejected here with a usage message instead of failing later with a NameError.
parser.add_argument('--dataset', type=str, default="mnist",
                    choices=('mnist', 'kmnist'), help='mnist or kmnist')
parser.add_argument('--seed', type=int, default=1, help='random seed')
args = parser.parse_args()

# Run on the GPU only when it was not disabled on the command line and torch
# reports a CUDA device as available.
args.cuda = (not args.no_cuda) and torch.cuda.is_available()
device = torch.device("cuda") if args.cuda else torch.device("cpu")

# Fall back to the stock results directory when --save was omitted or empty;
# the same string is used as the process title.
if not args.save:
    args.save = 'Classification/Results/MNIST_8_test'
setproctitle.setproctitle(args.save)

# Seed torch's RNGs, then select the requested CUDA device.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.cuda_device)

# Start every run from an empty results directory: anything already stored
# under args.save is deleted before the directory is recreated.
if os.path.exists(args.save):
    shutil.rmtree(args.save)
os.makedirs(args.save, exist_ok=True)

# get dataloaders
# Training images get random perspective/rotation augmentation before being
# converted to tensors and normalised; test images are only converted and
# normalised. The same (0.1307,), (0.3081,) mean/std pair is applied to
# whichever dataset is selected.
trainTransform = transforms.Compose([
    transforms.RandomPerspective(),
    transforms.RandomRotation(10, fill=(0,)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
testTransform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Extra DataLoader options are only passed when running on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

# Map the --dataset value to its torchvision dataset class. The dataset name
# doubles as the download/cache root directory ('mnist' or 'kmnist').
dataset_classes = {'mnist': dset.MNIST, 'kmnist': dset.KMNIST}
if args.dataset not in dataset_classes:
    # Fail with a clear message rather than a NameError on train_set below.
    raise ValueError('Unsupported dataset {!r}; expected one of {}'.format(
        args.dataset, sorted(dataset_classes)))
dataset_cls = dataset_classes[args.dataset]

train_set = dataset_cls(root=args.dataset, train=True, download=True,
                        transform=trainTransform)
test_set = dataset_cls(root=args.dataset, train=False, download=True,
                       transform=testTransform)

# Shuffle only the training split; keep the test split in dataset order.
trainLoader = DataLoader(train_set, batch_size=args.batchSz,
                         shuffle=True, **kwargs)
testLoader = DataLoader(test_set, batch_size=args.batchSz,
                        shuffle=False, **kwargs)


# Build the two networks; both are sized by the latent dimension.
R_net = SpinalVGG(Ldim=args.latent_dim)
D_net = Discriminator(ndim=args.latent_dim)

# Report the total number of parameter elements of each network.
n_params_R = sum(p.data.nelement() for p in R_net.parameters())
n_params_D = sum(p.data.nelement() for p in D_net.parameters())
print('  + Number of params (net) : {}'.format(n_params_R))
print('  + Number of params (Dnet) : {}'.format(n_params_D))

# Move both networks to the GPU when CUDA is in use.
if args.cuda:
    R_net, D_net = R_net.cuda(), D_net.cuda()


# One Adam optimizer per network: R_net uses a fixed learning rate of 0.005,
# D_net uses Adam's default learning rate with weight decay 1e-4.
optimizer_R = optim.Adam(R_net.parameters(), lr=0.005)
optimizer_D = optim.Adam(D_net.parameters(), weight_decay=1e-4)

# Output files inside the results directory; they are opened for writing here
# and handed to train()/test(), and closed later in the script.
trainF = open(os.path.join(args.save, 'train.csv'), 'w')
testF = open(os.path.join(args.save, 'test.csv'), 'w')
f = open(os.path.join(args.save, 'res.txt'), 'w')
#------------------------------------------------------------------------------

# train models
# The checkpoint paths do not change between epochs, so each epoch overwrites
# the previous checkpoint with the latest weights.
R_ckpt_path = os.path.join(args.save, 'R.pt')
D_ckpt_path = os.path.join(args.save, 'D.pt')
try:
    for epoch in range(1, args.nEpochs + 1):
        train(args, epoch, R_net, D_net, trainLoader, optimizer_R, optimizer_D, trainF, f, device)
        test(args, epoch, R_net, testLoader, optimizer_R, testF, f, device)
        torch.save(R_net.state_dict(), R_ckpt_path)
        torch.save(D_net.state_dict(), D_ckpt_path)
finally:
    # Close the per-epoch CSV logs even if an epoch raises, so whatever was
    # written so far is flushed to disk.
    trainF.close()
    testF.close()

#------------------------------------------------------------------------------

# evaluate models
# Fit a k-NN classifier on the outputs of npLoader for the training split and
# score it on the test split; the accuracy is printed and appended to res.txt.
try:
    R_net.eval()
    torch.cuda.empty_cache()
    # presumably npLoader runs every batch of the loader through R_net and
    # returns (features, labels) as NumPy arrays -- confirm against its definition.
    # NOTE(review): trainLoader is the shuffled, augmented training loader, so
    # the training features come from randomly transformed images -- confirm
    # this is intended.
    X_train, y_train = npLoader(trainLoader, R_net, device)
    X_test, y_test = npLoader(testLoader, R_net, device)

    # Standardise using statistics of the training features only.
    scaler = StandardScaler()
    scaler.fit(X_train)
    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)

    # KNN for classification
    classifier = KNeighborsClassifier(n_neighbors=5)
    classifier.fit(X_train, y_train)
    y_pred = classifier.predict(X_test)
    print(confusion_matrix(y_test, y_pred))
    print(classification_report(y_test, y_pred))

    # Divide by the number of predictions (shape[0]), not the shape tuple:
    # dividing by the tuple yields a 1-element array, and formatting that with
    # '%f' relies on an array-to-scalar conversion NumPy has deprecated.
    acc = 100 * np.sum(y_pred == y_test) / y_pred.shape[0]
    print('Accuracy: %f' % acc)
    print('Done!')
    f.write('Accuracy: %f \n' % acc)
    f.write('Done!\n')
finally:
    # Always release the results file, even if evaluation fails part-way.
    f.close()