import os
import sys
sys.path.insert(0, '..')
import argparse
import numpy as np
import logging

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms

import data_utils
from functions import * 
from network import StylizedFacePoint

"""
import warnings
warnings.filterwarnings("ignore")
"""

# Horizontal-flip landmark correspondence for the 98-point layout this script
# trains on. Indices are 0-based: the table contains every value in range(98)
# exactly once and is its own inverse (POINTS_FLIP_98[POINTS_FLIP_98[i]] == i).
# It is therefore used as-is. Shifting it by -1, the convention for 1-based
# tables, would map landmark 32 to index -1 and mis-pair every mirrored point.
# NOTE(review): presumably consumed by data_utils as `target[points_flip]`
# during random flipping -- confirm against data_utils.ImageFolder.
POINTS_FLIP_98 = [32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17,
                  16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
                  46, 45, 44, 43, 42, 50, 49, 48, 47, 37, 36, 35, 34, 33, 41, 40, 39, 38,
                  51, 52, 53, 54, 59, 58, 57, 56, 55,
                  72, 71, 70, 69, 68, 75, 74, 73, 64, 63, 62, 61, 60, 67, 66, 65,
                  82, 81, 80, 79, 78, 77, 76, 87, 86, 85, 84, 83,
                  92, 91, 90, 89, 88, 95, 94, 93, 97, 96]

_BANNER = '###########################################'


def _parse_args(argv=None):
    """Parse the command line for training.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The argparse namespace, with ``decay_steps`` converted from the
        comma-separated string to a list of ints.
    """
    parser = argparse.ArgumentParser(description="Train StylizedFacePoint")
    parser.add_argument("--experiment_name", type=str, default="Exp1", help="the name of experiment")
    parser.add_argument("--data_name", type=str, default="FLSC", help="the name of dataset")
    parser.add_argument("--dataset_url", type=str, default="data/", help="the path of dataset")
    parser.add_argument("--num_lms", type=int, default=98, help="the number of landmarks in dataset")
    parser.add_argument("--num_nb", type=int, default=3, help="the number of neighbor landmarks")
    parser.add_argument("--input_size", type=int, default=256, help="the size of input images")
    parser.add_argument("--net_stride", type=int, default=16, help="the stride of network")
    parser.add_argument("--nstack", type=int, default=4, help="the number of stages in stacked hourglass network")
    parser.add_argument("--num_epochs", type=int, default=600, help="the number of training epoches")
    parser.add_argument("--batch_size", type=int, default=16, help="the number of images in a batch")
    parser.add_argument("--init_lr", type=float, default=0.0001, help="the initial learning rate")
    parser.add_argument("--decay_steps", type=str, default="200,300,450", help="the epoches of learning rate decay")
    parser.add_argument("--gamma", type=float, default=0.1, help="the decay weight of learning rate")
    # `choices` rejects unknown loss names at parse time instead of leaving the
    # criterion as None and failing somewhere inside training.
    parser.add_argument("--criterion_cls", type=str, default="l2", choices=['l1', 'l2'],
                        help="the type of main heatmap regression loss(l1/l2)")
    parser.add_argument("--criterion_reg", type=str, default="l1", choices=['l1', 'l2'],
                        help="the type of offset regression and neighbor regression loss(l1/l2)")
    parser.add_argument("--cls_loss_weight", type=float, default=10.0, help="the weight of main heatmap regression loss")
    parser.add_argument("--reg_loss_weight", type=float, default=1.0, help="the weight of offset regression and neighbor regression loss")
    parser.add_argument("--use_gpu", action="store_true", help="use gpu for model training")
    parser.add_argument("--gpu_id", type=int, default=0, help="the index of gpu(if use)")
    args = parser.parse_args(argv)

    # The flip table is hard-wired to 98 landmarks; any other count cannot be
    # flipped correctly, so fail before creating directories or loading data.
    if args.num_lms != len(POINTS_FLIP_98):
        parser.error('--num_lms must be {} to match the built-in flip table, got {}'.format(
            len(POINTS_FLIP_98), args.num_lms))

    args.decay_steps = [int(step) for step in args.decay_steps.split(',')]
    return args


def _report_config(args):
    """Echo every parsed option to stdout, then to the training log."""
    # vars(args) keeps argparse's definition order, so this reproduces the
    # original hand-written listing line for line.
    print(_BANNER)
    for key, value in vars(args).items():
        print(key + ':', value)
    print(_BANNER)
    logging.info(_BANNER)
    for key, value in vars(args).items():
        logging.info('{}: {}'.format(key, value))
    logging.info(_BANNER)


def _build_criterion(name, kind):
    """Return a new loss module for ``name`` ('l1' -> L1Loss, 'l2' -> MSELoss).

    Raises:
        ValueError: if ``name`` is neither 'l1' nor 'l2'. ``kind`` only labels
        the message.
    """
    if name == 'l1':
        return nn.L1Loss()
    if name == 'l2':
        return nn.MSELoss()
    raise ValueError('No such {} criterion: {}'.format(kind, name))


def _select_device(use_gpu, gpu_id):
    """Return ``cuda:<gpu_id>`` when requested and available, else the CPU."""
    if use_gpu and torch.cuda.is_available():
        return torch.device("cuda:{}".format(gpu_id))
    if use_gpu:
        # Previously the CPU fallback was silent; make it visible.
        message = '--use_gpu was given but CUDA is not available; training on CPU.'
        print(message)
        logging.warning(message)
    return torch.device("cpu")


def main(argv=None):
    """Train StylizedFacePoint end to end.

    Parses arguments, prepares the snapshot/log directories and file logging,
    builds the model, losses, optimizer, LR schedule and data loaders, and then
    hands everything to ``train_model``.
    """
    args = _parse_args(argv)

    save_dir = os.path.join('./snapshots', args.data_name, args.experiment_name)
    log_dir = os.path.join('./logs', args.data_name, args.experiment_name)
    # makedirs also creates ./snapshots and ./logs when they are missing; plain
    # os.mkdir raised FileNotFoundError in that case.
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(filename=os.path.join(log_dir, 'train.log'), level=logging.INFO)
    _report_config(args)

    data_root = os.path.join(args.dataset_url, args.data_name)
    # Only the neighbour indices are needed here; the other return values are unused.
    meanface_indices, _, _, _ = get_meanface(os.path.join(data_root, 'meanface.txt'), args.num_nb)

    device = _select_device(args.use_gpu, args.gpu_id)
    net = StylizedFacePoint(args, device).to(device)

    criterion_cls = _build_criterion(args.criterion_cls, 'cls')
    criterion_reg = _build_criterion(args.criterion_reg, 'reg')

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    optimizer = optim.Adam(net.parameters(), lr=args.init_lr)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.decay_steps, gamma=args.gamma)

    labels_train = get_label(args.dataset_url, args.data_name, 'train.txt')
    labels_test = get_label(args.dataset_url, args.data_name, 'test.txt')

    train_transform = transforms.Compose([
        transforms.RandomGrayscale(0.2),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # A copy is passed so the module-level table cannot be mutated downstream.
    train_data = data_utils.ImageFolder(os.path.join(data_root, 'images_train'),
                                        labels_train, args.input_size, args.num_lms,
                                        args.net_stride, list(POINTS_FLIP_98), meanface_indices,
                                        train_transform, True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, drop_last=True)

    test_data = data_utils.ImageFolder(os.path.join(data_root, 'images_test'),
                                       labels_test, args.input_size, args.num_lms,
                                       args.net_stride, list(POINTS_FLIP_98), meanface_indices,
                                       test_transform, False)
    # Evaluation needs no shuffling; a fixed order keeps runs reproducible.
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False, drop_last=False)

    train_model(net, train_loader, test_loader, criterion_cls, criterion_reg, args.cls_loss_weight,
                args.reg_loss_weight, args.num_nb, optimizer, args.num_epochs, scheduler, save_dir, device)


if __name__ == '__main__':
    main()