# basic functions
import os
import sys
import math
import numpy as np
import shutil
import setproctitle
import argparse
import matplotlib.pyplot as plt
# torch functions
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader, TensorDataset
# local functions
from models import *
from tr import *
from densenet import DenseNet
#------------------------------------------------------------------------------

# arguments setting
parser = argparse.ArgumentParser()
parser.add_argument('--batchSz', type=int, default=64, help='mini batch size')
parser.add_argument('--latent_dim', type=int, default=2, help='the dimension of latent space')
parser.add_argument('--nEpochs', type=int, default=1, help='the number of outter loop')
parser.add_argument('--cuda_device', type=int, default=0, help='choose cuda device')
parser.add_argument('--no-cuda', action='store_true', help='if TRUE, cuda will not be used')
parser.add_argument('--save', help='path to save results')
# NOTE(review): despite its name, --lr is documented as the tuning parameter
# lambda; it is not passed to the Adam optimizers in this file, so it is
# presumably consumed via `args` inside train()/test() -- confirm in tr.py.
parser.add_argument('--lr', type=float, default=1000, help='tuning parameter-lambda')
parser.add_argument('--seed', type=int, default=1, help='random seed')

args = parser.parse_args()
# CUDA is used only when it was not disabled AND a GPU is actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
args.save = args.save or 'Results/Sim'
setproctitle.setproctitle(args.save)

# Select the GPU *before* seeding: torch.cuda.manual_seed seeds only the
# current device, so calling it first would seed GPU 0 instead of the
# device chosen with --cuda_device.
if args.cuda:
    torch.cuda.set_device(args.cuda_device)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Start from a clean output directory: any previous results stored under
# args.save are deleted before this run writes its own.
if os.path.exists(args.save):
    shutil.rmtree(args.save)
os.makedirs(args.save, exist_ok=True)

# Simulated S-curve data, already split into train/test parts by s_curve().
X_train, X_test, y_train, y_test, idx1, idx2 = s_curve(n_points=20000, args=args)

# Wrap each split as a TensorDataset: features are cast to float32, labels
# keep whatever dtype s_curve returned them in.
train_dat, test_dat = (
    TensorDataset(torch.from_numpy(feats).float(), torch.from_numpy(labels))
    for feats, labels in ((X_train, y_train), (X_test, y_test))
)

# Only the training loader shuffles; the test order stays fixed.
trainLoader = DataLoader(train_dat, batch_size=args.batchSz, shuffle=True)
testLoader = DataLoader(test_dat, batch_size=args.batchSz, shuffle=False)

# nets and optimizers setting
# Construction order matters for reproducibility (weight init draws from the
# seeded RNG): the discriminator is built first, then the DenseNet.
D_net = Discriminator(ndim=args.latent_dim)
R_net = DenseNet(growthRate=12, depth=20, reduction=0.5,
                 bottleneck=True, ndim=args.latent_dim, nClasses=10)

# Report the total number of scalar parameters in each network.
for tag, net in (('R_net', R_net), ('D_net', D_net)):
    n_params = sum(p.data.nelement() for p in net.parameters())
    print('  + Number of params ({}) : {}'.format(tag, n_params))

if args.cuda:
    R_net, D_net = R_net.cuda(), D_net.cuda()

# Adam with L2 weight decay; no lr is given, so Adam's default learning rate
# is used for both networks.
optimizer_R, optimizer_D = (
    optim.Adam(model.parameters(), weight_decay=1e-4) for model in (R_net, D_net)
)

#------------------------------------------------------------------------------
# train models
# Plot the test-set latent space every PLOT_EVERY epochs (1 = every epoch).
PLOT_EVERY = 1

# The CSV handles are passed to train()/test() (presumably for per-epoch
# metric logging -- confirm in tr.py). Using `with` guarantees both files are
# flushed and closed even if an epoch raises.
with open(os.path.join(args.save, 'train.csv'), 'w') as trainF, \
        open(os.path.join(args.save, 'test.csv'), 'w') as testF:
    for epoch in range(1, args.nEpochs + 1):
        train(args, epoch, R_net, D_net, trainLoader, optimizer_R, optimizer_D, trainF, device)
        test(args, epoch, R_net, testLoader, optimizer_R, testF, device)
        # Checkpoint both networks after every epoch (files are overwritten).
        torch.save(R_net.state_dict(), os.path.join(args.save, 'R.pt'))
        torch.save(D_net.state_dict(), os.path.join(args.save, 'D.pt'))
        if epoch % PLOT_EVERY == 0:
            # Font size must be set before any text is created; updating it
            # after the labels made the first epoch's plot use a smaller font.
            plt.rcParams.update({'font.size': 20})
            fig = plt.figure(figsize=(8, 8))
            # npLoader presumably returns R_net's outputs (at least 2 columns,
            # as indexed below) and labels for the test set -- confirm in tr.py.
            # Distinct names keep the raw X_test/y_test arrays intact.
            latent_test, latent_labels = npLoader(testLoader, R_net, device)
            # NOTE(review): `sns` (seaborn) is not imported in this file; it is
            # presumably re-exported by `from models import *` or
            # `from tr import *` -- an explicit import would be safer.
            sns.set_style("whitegrid")
            plt.scatter(latent_test[:, 0], latent_test[:, 1], c=latent_labels, cmap=plt.cm.Spectral)
            plt.xlabel('Feature-one')
            plt.ylabel('Feature-two')
            plt.savefig(os.path.join(args.save, 'latent_{}.png'.format(epoch)), dpi=30)
            plt.show()
            # Release the figure; otherwise one figure leaks per epoch.
            plt.close(fig)

print("Done")