from __future__ import print_function
import argparse
import torch
import torch.optim as optim
from gcommand_loader import GCommandLoader
import numpy as np
from model import LeNet, VGG
from train import train, test
from torch.optim.lr_scheduler import MultiStepLR,StepLR

import os


# Training settings


def _str2bool(value):
    """Convert a command-line string such as 'true'/'false' into a bool.

    argparse hands option values over as strings, so without this a flag like
    ``--cuda False`` would produce the truthy string ``'False'``.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {0!r}'.format(value))


parser = argparse.ArgumentParser(
    description='ConvNets for Speech Commands Recognition')
parser.add_argument('--train_path', default='data/speechdata/train',
                    help='path to the train data folder')
parser.add_argument('--test_path', default='data/speechdata/test',
                    help='path to the test data folder')
parser.add_argument('--valid_path', default='data/speechdata/valid',
                    help='path to the valid data folder')
parser.add_argument('--batch_size', type=int, default=256,
                    metavar='N', help='training and valid batch size')
parser.add_argument('--test_batch_size', type=int, default=256,
                    metavar='N', help='batch size for testing')
parser.add_argument('--arc', default='LeNet',
                    help='network architecture: LeNet, VGG11, VGG13, VGG16, VGG19')
parser.add_argument('--epochs', type=int, default=150,
                    metavar='N', help='number of epochs to train')
parser.add_argument('--lr', type=float, default=0.001,
                    metavar='LR', help='learning rate')
parser.add_argument('--momentum', type=float, default=0.9,
                    metavar='M', help='SGD momentum, for SGD only')
# choices mirrors the optimizers actually handled when the optimizer is built;
# anything else would otherwise leave no optimizer defined.
parser.add_argument('--optimizer', default='adam',
                    choices=('adam', 'sgd', 'sgdm', 'sgdn', 'Adadelta'),
                    help='optimization method: adam | sgd | sgdm | sgdn | Adadelta')
# _str2bool so that "--cuda False" really disables CUDA (a bare string is truthy).
parser.add_argument('--cuda', type=_str2bool, default=True,
                    help='enable CUDA (true/false)')
parser.add_argument('--seed', type=int, default=1234,
                    metavar='S', help='random seed')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='num of batches to wait until logging train status')
parser.add_argument('--patience', type=int, default=5, metavar='N',
                    help='how many epochs of no loss improvement should we wait before stop training')

# feature extraction options
# type=float: without it, values given on the command line arrive as strings.
parser.add_argument('--window_size', type=float, default=.02,
                    help='window size for the stft')
parser.add_argument('--window_stride', type=float, default=.01,
                    help='window stride for the stft')
parser.add_argument('--window_type', default='hamming',
                    help='window type for the stft')
parser.add_argument('--normalize', type=_str2bool, default=True,
                    help='boolean (true/false), whether or not to normalize the spect')

args = parser.parse_args()

# --cuda defaults to True; fall back to CPU when no CUDA device is present so
# the later .cuda() / pin_memory usage does not fail on CPU-only machines.
if args.cuda and not torch.cuda.is_available():
    print('CUDA requested but not available; falling back to CPU')
    args.cuda = False
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# loading data


def _make_loader(opts, path, batch_size, shuffle, num_workers):
    """Build a GCommandLoader dataset for *path* and wrap it in a DataLoader.

    Feature-extraction settings (window size/stride/type, normalization) are
    read from *opts*; pinned memory is enabled when ``opts.cuda`` is truthy.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = GCommandLoader(path, window_size=opts.window_size, window_stride=opts.window_stride,
                             window_type=opts.window_type, normalize=opts.normalize)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=num_workers, pin_memory=opts.cuda, sampler=None)
    return dataset, loader


# NOTE(review): with num_workers > 0 and the 'spawn' multiprocessing start
# method (default on Windows and on macOS with Python 3.8+), DataLoader workers
# re-import this module; because the script has no `if __name__ == '__main__':`
# guard, that re-executes the whole script. Consider adding such a guard.
train_dataset, train_loader = _make_loader(args, args.train_path, args.batch_size,
                                           shuffle=True, num_workers=6)
valid_dataset, valid_loader = _make_loader(args, args.valid_path, args.batch_size,
                                           shuffle=False, num_workers=2)
test_dataset, test_loader = _make_loader(args, args.test_path, args.test_batch_size,
                                         shuffle=False, num_workers=2)

# build model: any architecture name starting with 'VGG' is handed to VGG;
# 'LeNet' and every unrecognised name fall back to LeNet.
model = VGG(args.arc) if args.arc.startswith('VGG') else LeNet()

if args.cuda:
    gpu_count = torch.cuda.device_count()
    print('Using CUDA with {0} GPUs'.format(gpu_count))
    # DataParallel replicates the model across all visible GPUs.
    model = torch.nn.DataParallel(model).cuda()

# define optimizer
# The SGD variants use weight decay 5e-4; 'sgdm' and 'sgdn' take their momentum
# from --momentum (default 0.9), 'sgdn' additionally enables Nesterov momentum.
if args.optimizer == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optimizer == 'sgdm':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=0.0005)
elif args.optimizer == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005)
elif args.optimizer == 'Adadelta':
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
elif args.optimizer == 'sgdn':
    optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=0.0005,
                          momentum=args.momentum, nesterov=True)
else:
    # Fail loudly instead of leaving `optimizer` unbound (which previously
    # surfaced as a NameError on the scheduler line).
    raise ValueError('unknown optimizer {0!r}; expected one of: '
                     'adam, sgd, sgdm, sgdn, Adadelta'.format(args.optimizer))

# Learning rate is multiplied by 0.1 after epochs 50 and 100.
scheduler = MultiStepLR(optimizer, milestones=[50, 100], gamma=0.1)
# NOTE(review): --patience is parsed but no early stopping is implemented;
# training always runs for args.epochs epochs.

# Per-run results file: one line per epoch with
# "train_loss  train_acc  valid_acc  test_acc".
log_dir = os.path.join('gcommands', 'file')
if not os.path.isdir(log_dir):
    os.makedirs(log_dir)
log_path = os.path.join(log_dir, str(args.epochs) + 'epochs' + str(args.lr) + str(args.optimizer) + '.txt')

with open(log_path, 'w') as log_file:
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train(train_loader, model, optimizer, epoch, args.cuda, args.log_interval)
        # Adadelta keeps a fixed learning rate; the others follow MultiStepLR.
        # step() without an epoch argument yields the same schedule as the
        # deprecated step(epoch) form.
        if args.optimizer != 'Adadelta':
            scheduler.step()
        valid_loss, valid_acc = test(valid_loader, model, args.cuda)
        test_loss, test_acc = test(test_loader, model, args.cuda)
        log_file.write(str(train_loss) + '  ' + '%.4f' % train_acc + '  '
                       + '%.4f' % valid_acc + '  ' + '%.4f' % test_acc + '\n')
        # Flush each epoch so partial results survive an interrupted run.
        log_file.flush()

