import logging
import os
import json
import time
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.backends.cudnn as cudnn
import pickle

from torch.utils.data import DataLoader
from torch import optim

from utils.parameters import get_parameter
from utils.dataset import construct_datasets
from utils.model import Linear, ResNet18, VGG16, MobileNetV2


def train(args, logging, model, loss_func, optimizer, train_loader):
    """Train ``model`` in place for ``args.epoch`` epochs, reporting per-epoch metrics.

    Args:
        args: Namespace providing ``epoch`` (number of epochs to run) and
            ``device`` (target passed to ``Tensor.to``).
        logging: Any object with an ``info(str)`` method (e.g. the ``logging``
            module or a ``Logger``). Note this shadows the module name here.
        model: The ``nn.Module`` to optimise; put into train mode each epoch.
        loss_func: Callable ``(outputs, labels) -> scalar loss tensor``.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of ``(images, labels, _)`` triples; the third
            element is ignored.

    After each epoch the mean per-batch loss and the training accuracy (in
    percent) are printed and logged. If an epoch sees no samples, a short
    notice is reported instead of dividing by zero.
    """
    for epoch in range(args.epoch):
        model.train()
        running_loss, n_batchs, total, correct = 0.0, 0, 0, 0

        for images, labels, _ in train_loader:
            images, labels = images.to(args.device), labels.to(args.device)
            outputs = model(images)
            loss = loss_func(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # detach(): accuracy bookkeeping must stay out of the autograd graph.
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item()
            n_batchs += 1

        if total == 0:
            # Empty loader: the averages below would raise ZeroDivisionError.
            message = 'Epoch %d: no training samples, skipping metrics' % epoch
        else:
            epoch_loss = running_loss / n_batchs
            accuracy = 100 * correct / total
            message = 'Epoch %d training loss: %.3f training accuracy: %.3f%%' % (
                epoch, epoch_loss, accuracy)
        print(message)
        logging.info(message)


def test(args, logging, model, test_loader):
    model.eval()
    total, correct = 0, 0

    with torch.no_grad():
        for idx, (images, labels, _) in enumerate(test_loader):
            images, labels = images.to(args.device), labels.to(args.device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Test accuracy: %.2f%%' % (100 * correct / total))
    logging.info('Test accuracy: %.2f%%' % (100 * correct / total))

def _setup_logging(args):
    """Create ``args.logdir`` and point the root logger at a timestamped log file in it.

    Returns the root logger.
    """
    os.makedirs(args.logdir, exist_ok=True)
    # '-' instead of ':' between time fields: ':' is illegal in Windows file names.
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    log_path = args.dataset + '-pretraining-log-%s' % timestamp + '.txt'
    logging.basicConfig(
        filename=os.path.join(args.logdir, log_path),
        format="%(asctime)s - %(name)s - %(message)s",
        datefmt='%d-%b-%y %H:%M:%S', level=logging.INFO, filemode='w')

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    return logger


def _build_model(args, train_data):
    """Instantiate the network named by ``args.net`` and move it to ``args.device``.

    Raises:
        ValueError: if ``args.net`` is not one of the supported names.
    """
    num_classes = args.num_classes
    if args.net == 'ResNet18':
        model = ResNet18(num_classes)
    elif args.net == 'VGG16':
        # NOTE(review): VGG16/MobileNetV2 are built without num_classes, so they
        # use whatever default head size utils.model defines -- confirm it
        # matches args.num_classes.
        model = VGG16()
    elif args.net == 'MobileNetV2':
        model = MobileNetV2()
    elif args.net == 'Linear':
        # Input width is taken from the length of the first training sample.
        input_shape = len(train_data[0][0])
        model = Linear(input_shape, num_classes)
    else:
        raise ValueError(
            'Unsupported net %r (expected ResNet18, VGG16, MobileNetV2 or Linear)'
            % (args.net,))
    return model.to(args.device)


def _build_optimizer(args, model):
    """Create the optimizer named by ``args.opt`` over ``model``'s parameters.

    Raises:
        ValueError: if ``args.opt`` is neither 'Adam' nor 'SGD'.
    """
    if args.opt == 'Adam':
        return optim.Adam(model.parameters(), lr=args.lr)
    if args.opt == 'SGD':
        return optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    raise ValueError('Unsupported optimizer %r (expected Adam or SGD)' % (args.opt,))


def main():
    """Pretrain ``args.net`` on ``args.dataset``, evaluate it, and save a checkpoint.

    Configuration comes from ``get_parameter()``. A timestamped log file is
    written under ``args.logdir``. The checkpoint dict, with keys ``net``,
    ``epoch``, ``batch_size`` and ``optimizer``, is saved to
    ``<moddir>/<dataset>_<net>_<modname>.pth``.

    Raises:
        ValueError: if ``args.net`` or ``args.opt`` names an unsupported choice.
    """
    args = get_parameter()

    logger = _setup_logging(args)
    logger.info(str(args))

    if torch.cuda.is_available():
        # Let cuDNN benchmark and pick the fastest convolution algorithms.
        cudnn.benchmark = True
    else:
        # Otherwise args.device is assumed to be set by get_parameter().
        args.device = "cpu"
    print(f'device: {args.device}')
    logger.info(f'device: {args.device}')

    os.makedirs(args.moddir, exist_ok=True)

    train_data, test_data = construct_datasets(args, args.dataset, args.datadir, load=args.load)
    train_loader = DataLoader(train_data, batch_size=args.batchsize, shuffle=True, num_workers=2)
    # Accuracy does not depend on sample order, so the test set is not shuffled.
    test_loader = DataLoader(test_data, batch_size=args.batchsize, shuffle=False, num_workers=2)

    model = _build_model(args, train_data)
    loss_func = nn.CrossEntropyLoss()
    optimizer = _build_optimizer(args, model)

    train(args, logger, model, loss_func, optimizer, train_loader)
    test(args, logger, model, test_loader)

    state = {
        'net': model.state_dict(),
        'epoch': args.epoch,
        'batch_size': args.batchsize,
        # NOTE(review): this pickles the optimizer object itself rather than
        # optimizer.state_dict(). Recent torch.load defaults (weights_only=True)
        # may refuse to load it. Consider saving the state_dict instead, unless
        # some consumer needs the object.
        'optimizer': optimizer,
    }

    torch.save(state, os.path.join(args.moddir, args.dataset + '_' + args.net + '_' + str(args.modname) + '.pth'))

if __name__ == "__main__":
    # Run the pretraining pipeline only when executed as a script, not on import.
    main()
