import torch
import torch.nn as nn
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import argparse
from vgg import vgg16, vgg16_bn, vgg19
from torchvision.models.vgg import vgg16 as cleanVGG16
from torchvision.models.vgg import vgg19 as cleanVGG19
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from dataset import *
from transform import pad_zero, get_model, getdata_set_name, num_class
from torch.optim import SGD
import global_time
import time
import sys
import math

# For 100-class runs only: maps the (0-indexed) epoch that has just finished
# to the learning rate used from the next epoch onwards.
_LR_SCHEDULE_100_CLASSES = {59: 0.002, 119: 0.0004, 159: 0.00008}


def _build_optimizer(model, lr):
    """Return an SGD optimizer over all parameters of ``model`` at rate ``lr``."""
    return SGD(model.parameters(), lr=lr, weight_decay=5e-4, momentum=0.9)


def main(is_blind, batch_size, load_epoch, train_loader,
         val_loader, saved_path, log_file,
         model_name, datasetname):
    """Train a model for 200 epochs, validating and logging after each one.

    Args:
        is_blind: When true, every input batch is passed through ``pad_zero``
            before the forward pass, only the first ``target.size(0)`` rows of
            the model output are used for the loss / prediction, and the
            output gradient is passed through ``pad_zero`` before it is
            back-propagated.
        batch_size: Number added to the accumulation counter per training
            batch; the optimizer steps when the counter equals this value
            (i.e. after every batch).
        load_epoch: If ``>= 0``, ``load_model(model, load_epoch)`` is called
            before training starts.
        train_loader: Sized iterable yielding ``(image, target)`` batches.
        val_loader: Sized iterable yielding ``(image, target)`` batches.
        saved_path: Forwarded to ``save_model`` at the end of every epoch.
        log_file: Path of the text log; it is opened in ``'w'`` mode, so an
            existing file is truncated.
        model_name: Forwarded to ``get_model``.
        datasetname: Forwarded to ``num_class`` to obtain the class count.

    Raises:
        SystemExit: With status 0 as soon as the running mean of the training
            loss is NaN or exceeds 10.0 (the value is written to the log
            first).

    The log file is managed by a ``with`` block, so it is closed on every exit
    path, including exceptions and the early ``sys.exit``.
    """
    print(len(train_loader), len(val_loader))
    # system settings -- the script is hard-wired to the first CUDA device
    device = torch.device("cuda:0")
    dtype = torch.float32

    # get models
    num_classes = num_class(datasetname)
    model = get_model(model_name, is_blind, device, dtype, num_classes)

    model.to(device)
    print(model)

    # get loss function & optimizer
    loss_fn = CrossEntropyLoss().to(device)
    optimizer = _build_optimizer(model, 0.01)

    with open(log_file, 'w') as log:
        # load model data if applicable
        # NOTE(review): the epoch loop below still starts at 0 after a load,
        # so resumed runs re-use epoch numbers 0..199 -- confirm this is intended.
        if load_epoch >= 0:
            load_model(model, load_epoch)

        mini_batch_size = batch_size

        for epoch in range(200):
            average_loss = 0.0
            print("training epcho ", epoch)
            mini_batch_count = 0
            model.train()
            for step, (image, target) in enumerate(tqdm(train_loader), start=1):
                image = image.to(device)
                target = target.to(device)

                if is_blind:
                    image = pad_zero(image)

                y_pred = model(image)

                # The loss is evaluated on a detached copy of the output so
                # that its gradient w.r.t. the output can be read out (and
                # padded in blind mode) before being pushed back through the
                # model by hand.
                y_copy = y_pred.detach().clone()
                if is_blind:
                    y_copy = y_copy[0:target.size(0)]
                y_copy.requires_grad = True
                loss = loss_fn(y_copy, target)
                loss.backward()

                grad_back = y_copy.grad
                if is_blind:
                    grad_back = pad_zero(grad_back)
                y_pred.backward(gradient=grad_back)

                # incremental mean of the per-batch loss over this epoch
                average_loss += (loss.item() - average_loss) / step
                if math.isnan(average_loss) or average_loss > 10.0:
                    # Training diverged: record the value and stop. The
                    # enclosing ``with`` closes the log as SystemExit unwinds.
                    log.write("average loss ")
                    log.write(str(average_loss))
                    log.flush()
                    sys.exit(0)

                # update mini batch
                mini_batch_count += batch_size
                if mini_batch_count == mini_batch_size:
                    optimizer.step()
                    optimizer.zero_grad()
                    mini_batch_count = 0

            # end of epoch & validation
            save_model(model, saved_path, epoch)
            model.eval()
            print("validating epoch ", epoch)
            correct = 0
            total_num = 0
            for image, target in tqdm(val_loader):
                # Batches with an odd number of samples are not evaluated.
                # NOTE(review): presumably a requirement of the blind/padded
                # model -- confirm whether it is needed when is_blind is false.
                if target.size(0) % 2 != 0:
                    continue

                image = image.to(device)
                target = target.to(device)

                # padding zeros
                if is_blind:
                    image = pad_zero(image)

                with torch.no_grad():
                    y_pred = model(image)
                if is_blind:
                    y_pred = y_pred[0:target.size(0)]
                pred = torch.argmax(y_pred, dim=1)

                correct += (pred == target).sum().item()
                total_num += target.size(0)

            # If every batch was skipped (or the loader is empty) there is
            # nothing to average over; report NaN instead of crashing.
            acc = correct / total_num if total_num else float("nan")

            # print result and save model
            log.write(f"{epoch}, average_train loss {average_loss}, top1 acc {acc}\n")
            log.flush()
            print("Epoch ", epoch, "average train loss ", average_loss, "top 1 acc ", acc)

            # Step-wise learning-rate decay for 100-class datasets. A fresh
            # optimizer is built at each step, which also clears the SGD
            # momentum buffers.
            if num_classes == 100 and epoch in _LR_SCHEDULE_100_CLASSES:
                new_lr = _LR_SCHEDULE_100_CLASSES[epoch]
                optimizer = _build_optimizer(model, new_lr)
                print("new lr is ", new_lr)

if __name__ == '__main__':
    # Command-line entry point: parse the options, configure ``global_time``,
    # build the train/validation loaders and hand everything to ``main``.
    # ``argparse`` is already imported at the top of the file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--blind', type=int, default=0,
                        help="1 enables blind mode, any other value disables it")
    parser.add_argument('--batch_size', type=int, default=32,
                        help="batch size used for both data loaders")
    parser.add_argument('--load', type=int, default=-1,
                        help="epoch to load model data from; negative means do not load")
    # The value is only ever forwarded to save_model(), so it names the
    # checkpoint location rather than the dataset.
    parser.add_argument("-s", "--saved", required=True,
                        help="the path used to save model checkpoints")
    parser.add_argument('--mean', type=float, default=0.0,
                        help="value assigned to global_time.mean")
    parser.add_argument('--std', type=float, default=1.0,
                        help="value assigned to global_time.std")
    parser.add_argument("-l", "--log", required=True, help="the path to log file")
    parser.add_argument("-m", "--model", required=True, help="model name")
    parser.add_argument("-d", "--dataset", required=True, help="dataset name")

    args = parser.parse_args()
    saved_path = args.saved
    is_blind = args.blind == 1
    batch_size = args.batch_size

    # global_time must be initialised before its mean/std are overridden.
    global_time.init()
    global_time.mean = args.mean
    global_time.std = args.std

    log_file = args.log
    model_name = args.model
    datasetname = getdata_set_name(args.dataset)
    train_loader = data_loader(datasetname, batch_size)
    val_loader = data_loader(datasetname, batch_size, is_train=False)

    main(is_blind, batch_size, args.load, train_loader, val_loader,
         saved_path, log_file, model_name, datasetname)
