from __future__ import print_function, absolute_import, division

import datetime
import os
import os.path as path
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn

from function_baseline.config import get_parse_args
from function_baseline.data_preparation import data_preparation
from function_baseline.model_pos_preparation_scale_ensemble import model_pos_preparation
from function_baseline.model_pos_train_datascale_adaptive_ensemble import train
from function_poseaug.model_pos_eval_datascale_adaptive_ensemble_output_avg import evaluate
from utils.log import Logger, savefig
from utils.utils import save_ckpt, loss


def _perturb_layers(layers, perturb_bias=False):
    """Orthogonally re-initialise each layer and add index-scaled Gaussian noise.

    The k-th layer (1-based) gets ``0.1 * k * N(0, 1)`` noise added to its
    weight after ``orthogonal_`` initialisation, so later layers are perturbed
    more strongly than earlier ones.

    Args:
        layers: iterable of modules exposing ``.weight`` (and ``.bias`` when
            ``perturb_bias`` is true).
        perturb_bias: when true, also add the same-scaled noise to ``.bias``.
    """
    for branch_idx, layer in enumerate(layers, start=1):
        noise_scale = 0.1 * branch_idx
        torch.nn.init.orthogonal_(layer.weight)
        # In-place parameter edits must not be tracked by autograd.
        with torch.no_grad():
            layer.weight.add_(noise_scale * torch.randn_like(layer.weight))
            if perturb_bias:
                layer.bias.add_(noise_scale * torch.randn_like(layer.bias))


def _diversify_branches(model_pos, posenet_name):
    """Apply the per-branch initialisation matching the given PoseNet type.

    Which attributes hold the per-branch layers depends on ``posenet_name``:
    ``'gcn'`` uses ``regression_scale_stages``; ``'poseformer'`` uses
    ``scale_stages`` plus element ``[1]`` of each entry in
    ``regression_stages`` (biases are perturbed too); every other model uses
    ``regression_stages``.
    """
    if posenet_name == 'gcn':
        _perturb_layers(model_pos.regression_scale_stages)
    elif posenet_name == 'poseformer':
        _perturb_layers(model_pos.scale_stages, perturb_bias=True)
        _perturb_layers((stage[1] for stage in model_pos.regression_stages),
                        perturb_bias=True)
    else:
        _perturb_layers(model_pos.regression_stages)


def main(args):
    """Train the pose model, evaluating, logging and checkpointing every epoch.

    Args:
        args: parsed settings object. This function reads ``num_branches``,
            ``posenet_name``, ``lr``, ``lr_decay``, ``lr_gamma``, ``max_norm``,
            ``epochs``, ``snapshot``, ``checkpoint``, ``keypoints`` and
            ``note``; it is also forwarded to the data/model/logger helpers.

    Side effects:
        Creates a timestamped checkpoint directory and writes ``log.txt``,
        checkpoint files and ``log.eps`` into it.
    """
    print('==> Using settings {}'.format(args))
    # Prefer the GPU, but do not fail outright on CPU-only machines.
    # NOTE(review): the project helpers called below may still assume CUDA - confirm.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('==> Using device: {}'.format(device))

    print('==> Loading dataset...')
    data_dict = data_preparation(args)

    print("==> Creating PoseNet model...")
    model_pos = model_pos_preparation(args, data_dict['dataset'], device)
    if args.num_branches > 1:
        # Give every ensemble branch a distinct starting point.
        _diversify_branches(model_pos, args.posenet_name)

    print("==> Prepare optimizer...")
    criterion = nn.MSELoss(reduction='mean').to(device)               # Simple Baseline, VideoPose
    optimizer = torch.optim.Adam(model_pos.parameters(), lr=args.lr)  # Simple Baseline, VideoPose

    run_name = datetime.datetime.now().strftime('%m%d%H%M%S') + '_' + args.note
    ckpt_dir_path = path.join(args.checkpoint, args.posenet_name, args.keypoints, run_name)

    os.makedirs(ckpt_dir_path, exist_ok=True)
    print('==> Making checkpoint dir: {}'.format(ckpt_dir_path))

    logger = Logger(os.path.join(ckpt_dir_path, 'log.txt'), args)
    logger.set_names(['epoch', 'lr', 'loss_train', 'error_h36m_p1', 'error_h36m_p2', 'error_3dhp_p1', 'error_3dhp_p2', 'pck', 'auc'])

    #################################################
    # ########## start training here
    #################################################
    start_epoch = 0
    error_best = None
    pck_best = None

    glob_step = 0
    lr_now = args.lr

    try:
        for epoch in range(start_epoch, args.epochs):
            print('\nEpoch: %d' % (epoch + 1))

            # Train for one epoch
            epoch_loss, lr_now, glob_step = train(data_dict['train_loader'], model_pos, criterion, optimizer, device, args.lr, lr_now,
                                                  glob_step, args.lr_decay, args.lr_gamma, max_norm=args.max_norm, num_branches=args.num_branches)

            # Switch to eval mode before measuring the ensemble
            model_pos.eval()

            # Evaluate (PCK/AUC are only kept for the 3DHP set)
            error_h36m_p1, error_h36m_p2, _, _ = evaluate(data_dict['H36M_test'], model_pos, device, args.num_branches)
            error_3dhp_p1, error_3dhp_p2, pck_3dhp, auc_3dhp = evaluate(data_dict['3DHP_test'], model_pos, device, args.num_branches, flipaug='_flip')

            # Update log file
            logger.append([epoch + 1, lr_now, epoch_loss, error_h36m_p1, error_h36m_p2, error_3dhp_p1, error_3dhp_p2, pck_3dhp, auc_3dhp])

            # Update checkpoints: lowest H36M P1 error -> 'best', highest 3DHP PCK -> 'best_pck'
            if error_best is None or error_best > error_h36m_p1:
                error_best = error_h36m_p1
                save_ckpt({'state_dict': model_pos.state_dict(), 'epoch': epoch + 1}, ckpt_dir_path, suffix='best')

            if pck_best is None or pck_best < pck_3dhp:
                pck_best = pck_3dhp
                save_ckpt({'state_dict': model_pos.state_dict(), 'epoch': epoch + 1}, ckpt_dir_path, suffix='best_pck')

            # Periodic snapshot every `args.snapshot` epochs
            if (epoch + 1) % args.snapshot == 0:
                save_ckpt({'state_dict': model_pos.state_dict(), 'epoch': epoch + 1}, ckpt_dir_path)

            # Extra 5% decay applied directly to the optimizer on every other epoch.
            # NOTE(review): train() also receives lr/lr_decay/lr_gamma and returns lr_now,
            # which is the value logged above - confirm the two schedules do not conflict.
            if epoch % 2 == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.95
    finally:
        # Close the log file even when training or evaluation raises.
        logger.close()

    logger.plot(['loss_train', 'error_h36m_p1'])
    savefig(path.join(ckpt_dir_path, 'log.eps'))



if __name__ == '__main__':
    args = get_parse_args()

    # Seed every random number generator this script touches from one value.
    random_seed = args.random_seed
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this assignment
    # only reaches processes launched after this point, not the current one.
    os.environ['PYTHONHASHSEED'] = str(random_seed)
    # Settings from https://pytorch.org/docs/stable/notes/randomness.html
    torch.backends.cudnn.deterministic = True
    # Benchmark mode may select different convolution algorithms from run to
    # run, which defeats the seeding above, so it has to stay disabled.
    cudnn.benchmark = False

    main(args)
