r'''
    modified training script of GLU-Net
    https://github.com/PruneTruong/GLU-Net
'''

import argparse
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import utils_training.optimize as optimize

from os import path as osp
from tensorboardX import SummaryWriter
from termcolor import colored
from torch.utils.data import DataLoader
from collections import OrderedDict
from models.glbt import GLBT

from utils_training.logger import Logger, AverageMeter
from utils_training.evaluation import Evaluator
from utils_training.utils import parse_list, load_checkpoint, save_checkpoint, boolean_string
from data import download


if __name__ == "__main__":
    # Argument parsing
    parser = argparse.ArgumentParser(description='GLBT Training Script')
    # Paths
    parser.add_argument('--name_exp', type=str,
                        default=time.strftime('%Y_%m_%d_%H_%M'),
                        help='name of the experiment to save')
    parser.add_argument('--logpath', type=str, default='./snapshots')
    parser.add_argument('--pretrained', dest='pretrained', default=None,
                       help='path to pre-trained model')
    parser.add_argument('--model_path', default=None,
                       help='path to pre-trained model')
    parser.add_argument('--start_epoch', type=int, default=-1,
                        help='start epoch')
    parser.add_argument('--epochs', type=int, default=300,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=8,
                        help='training batch size')
    parser.add_argument('--n_threads', type=int, default=1,
                        help='number of parallel threads for dataloaders')
    parser.add_argument('--seed', type=int, default=2021,
                        help='Pseudo-RNG seed')
                        
    parser.add_argument('--datapath', type=str, default='../Datasets')
    parser.add_argument('--benchmark', type=str, default='spair', choices=['pfpascal', 'spair', 'pfwillow'])
    parser.add_argument('--thres', type=str, default='auto', choices=['auto', 'img', 'bbox'])
    parser.add_argument('--alpha', type=float, default=0.1)

    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.0,
                        help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=3e-4, metavar='LR',
                        help='learning rate (default: 3e-5)')
    parser.add_argument('--lr-backbone', type=float, default=3e-6, metavar='LR',
                        help='learning rate (default: 3e-6)')
    parser.add_argument('--scheduler', type=str, default='step', choices=['step', 'cosine'])
    parser.add_argument('--step', type=str, default='[150, 200, 250]')
    parser.add_argument('--step_gamma', type=float, default=0.5)

    parser.add_argument('--feature-size', type=int, default=32)
    parser.add_argument('--feature-proj-dim', type=int, default=128)
    parser.add_argument('--augmentation', type=boolean_string, nargs='?', const=True, default=True)

    # Seed
    args = parser.parse_args()
    Logger.initialize(args, training=True)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    Logger.info('# available GPUs: %d' % torch.cuda.device_count())

    # Initialize Evaluator
    Evaluator.initialize(args.benchmark, args.alpha)
    
    # Dataloader
    train_dataset = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'trn', args.augmentation, args.feature_size)
    val_dataset = download.load_dataset(args.benchmark, args.datapath, args.thres, device, 'test', args.augmentation, args.feature_size)
    train_dataloader = DataLoader(train_dataset,
        batch_size=args.batch_size,
        num_workers=args.n_threads,
        shuffle=True)
    val_dataloader = DataLoader(val_dataset,
        batch_size=args.batch_size,
        num_workers=args.n_threads,
        shuffle=False)

    # Model
    model = GLBT()
    # print(args.model_path)
    if args.model_path:
        checkpoints = torch.load(args.model_path)
        new_state_dict = OrderedDict()
        for k, v in checkpoints.items():
            new_state_dict[k[7:]] = v
        model.load_state_dict(new_state_dict, strict=False)
    Logger.log_params(model)
    param_model = [param for name, param in model.named_parameters() if 'backbone' not in name]
    param_backbone = [param for name, param in model.named_parameters() if 'backbone' in name]

    # Optimizer
    optimizer = optim.AdamW([{'params': param_model, 'lr': args.lr}, {'params': param_backbone, 'lr': args.lr_backbone}], 
                weight_decay=args.weight_decay)
    
    # Scheduler
    scheduler = \
        lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6, verbose=True)\
        if args.scheduler == 'cosine' else\
        lr_scheduler.MultiStepLR(optimizer, milestones=parse_list(args.step), gamma=args.step_gamma, verbose=True)

    if args.pretrained:
        # reload from pre_trained_model
        model, optimizer, scheduler, start_epoch, best_val = load_checkpoint(model, optimizer, scheduler,
                                                                 filename=args.pretrained)
        # now individually transfer the optimizer parts...
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
        cur_snapshot = os.path.basename(os.path.dirname(args.pretrained))
    else:
        if not os.path.isdir(args.logpath):
            os.makedirs(args.logpath)

        cur_snapshot = args.name_exp
        if not osp.isdir(osp.join(args.logpath, cur_snapshot)):
            os.makedirs(osp.join(args.logpath, cur_snapshot))

        with open(osp.join(args.logpath, cur_snapshot, 'args.pkl'), 'wb') as f:
            pickle.dump(args, f)

        best_val = 0
        start_epoch = 0

    # create summary writer
    save_path=osp.join(args.logpath, cur_snapshot)
    train_writer = SummaryWriter(os.path.join(save_path, 'train'))

    model = nn.DataParallel(model)
    model = model.to(device)

    train_started = time.time()

    for epoch in range(start_epoch, args.epochs):
        train_loss = optimize.train_epoch(model,
                                 optimizer,
                                 train_dataloader,
                                 device,
                                 epoch,
                                 train_writer)
        scheduler.step()
        val_loss_grid, val_mean_pck = optimize.validate_epoch(model,
                                                       val_dataloader,
                                                       device,
                                                       epoch=epoch)
        # Save the best model
        if val_mean_pck > best_val:
            best_val = val_mean_pck
            Logger.save_model_mPCK(model, epoch, val_mean_pck)

        Logger.tbd_writer.add_scalars('data/loss', {'trn_loss': train_loss, 'val_loss': val_loss_grid}, epoch)
        Logger.tbd_writer.add_scalars('data/mean_PCK', {'best_mPCK': best_val, 'cur_mPCK': val_mean_pck}, epoch)
        Logger.tbd_writer.flush()
    # Logger.info('==================== Finished Training ====================')
    print(args.seed, 'Training took:', time.time()-train_started, 'seconds')
    Logger.info('# Training took: seconds' % time.time()-train_started)
    Logger.tbd_writer.close()
