import os
import time
import argparse
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CyclicLR, CosineAnnealingLR
from torch.utils.tensorboard import SummaryWriter
# import tensorflow as tf
from lib.network import PoseNet, PoseNetV2, PoseNetV3
from lib.loss import Loss
from data.pose_dataset import PoseDataset
from lib.utils import setup_logger, compute_sRT_errors, load_obj
from lib.align import estimateSimilarityTransform
from lib.smr import SoftRenderer
from lib import geom_utils
from lib.transformations import quaternion_matrix
import pdb
import wandb

# Command-line interface of the training script.
parser = argparse.ArgumentParser()

# --- data ---------------------------------------------------------------
parser.add_argument('--dataset', type=str, default='CAMERA+Real', help='CAMERA or CAMERA+Real')
parser.add_argument('--data_dir', type=str, default='data/NOCS/', help='data directory')
parser.add_argument('--n_pts', type=int, default=1024, help='number of foreground points')
parser.add_argument('--n_cat', type=int, default=1, help='number of object categories')
parser.add_argument('--nv_prior', type=int, default=1024, help='number of vertices in shape priors')
parser.add_argument('--img_size', type=int, default=192, help='cropped image size')

# --- optimisation / runtime ----------------------------------------------
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--num_workers', type=int, default=4, help='number of data loading workers')
parser.add_argument('--gpu', type=str, default='0', help='GPU to use')
parser.add_argument('--n_gpus', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate')
parser.add_argument('--start_epoch', type=int, default=1, help='which epoch to start')
parser.add_argument('--max_epoch', type=int, default=50, help='max number of epochs to train')
parser.add_argument('--resume_model', type=str, default='', help='resume from saved model')
# The help text below used to be a copy of --resume_model's.
parser.add_argument('--select_class', type=str, default='bottle',
                    help="object category to train on: bottle, bowl, camera, can, laptop, mug, or 'all'")
parser.add_argument('--renderer_type', type=str, default='softmax', help='choices are [hard, softmax]')
parser.add_argument('--result_dir', type=str, default='work_dirs/camera', help='directory to save train results')
# The help text below used to be a copy of --result_dir's.
parser.add_argument('--deform_epoch', type=int, default=0, help='deformation epoch setting')

# --- boolean switches (all default to False) -----------------------------
parser.add_argument('--use_dz', action='store_true')
parser.add_argument('--semi', action='store_true')
parser.add_argument('--use_point_reg', action='store_true')
parser.add_argument('--use_nocs_map', action='store_true')
parser.add_argument('--use_fc', action='store_true')
# NOTE(review): '--implict' looks like a misspelling of "implicit", but the
# flag name (and opt.implict) is part of the interface, so it is kept as is.
parser.add_argument('--implict', action='store_true')
parser.add_argument('--max_point', action='store_true')
parser.add_argument('--version', type=str, default='v1')
parser.add_argument('--use_rgb', action='store_true')
parser.add_argument('--use_co3d', action='store_true')
parser.add_argument('--sep_stage', action='store_true')
parser.add_argument('--with_recon', action='store_true')
parser.add_argument('--feat_align', action='store_true')
parser.add_argument('--use_wild6d', action='store_true')
parser.add_argument('--finetune', action='store_true')
parser.add_argument('--no_pose_loss', action='store_true')
parser.add_argument('--semi_percent',  type=float, default=1.0)






# Parse the command line at import time. The settings added below are not
# exposed as flags; they are attached to the same namespace object.
opt = parser.parse_args()

vars(opt).update({
    # Step learning-rate schedule used by train_net: once the epoch number
    # exceeds decay_epoch[k], the learning rate becomes opt.lr * decay_rate[k].
    'decay_epoch': [0, 10, 20, 30, 40],
    'decay_rate': [1.0, 0.6, 0.3, 0.1, 0.01],
    # Fixed weights -- presumably read by lib.loss.Loss, which is handed opt
    # (TODO confirm; they are not read anywhere in this script).
    'corr_wt': 2.0,
    'cd_wt': 5.0,
    'entropy_wt': 0.0001,
    'deform_wt': 0.01,
    'mask_wt': 0.2,
    'pose_wt': 0.2,
    'pose_param_wt': 0.01,
    'project_wt': 0.0,
    'recon_wt': 0.5,
    'align_wt': 0.1,
})
# Selecting "all" overrides whatever --n_cat was given.
if opt.select_class == "all":
    opt.n_cat = 6

# The run is named after the last path component of the result directory.
wandb.init(project="semi-pose", config=opt, name=opt.result_dir.rsplit('/', 1)[-1])

# Per-channel constants shaped for broadcasting against (N, 3, H, W) tensors.
# NOTE(review): the values match the commonly used ImageNet statistics;
# neither tensor is referenced elsewhere in this script.
mean = torch.FloatTensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
std = torch.FloatTensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)

cat_names = ['bottle', 'bowl', 'camera', 'can', 'laptop', 'mug']
# One entry per category, in cat_names order: the last element returned by
# load_obj for ./assets/<category>.obj, with a leading axis of size 1 added.
faces_list = []
for cat in cat_names:
    faces_list += [torch.from_numpy(load_obj('./assets/{}.obj'.format(cat))[-1])[None]]


def compute_T(trans, s_box, c_box, bbox, cam, img_size=None):
    """Turn the translation head's output into an (x, y, depth) translation.

    Args:
        trans: tensor read as ``trans[:, :2]`` (centre offset ratios) and
            ``trans[:, 2]`` (depth ratio).
        s_box: per-sample box size; depth is rescaled by ``img_size / s_box``.
        c_box: per-sample box centre that the scaled offsets are added to.
        bbox: tensor whose columns from index 2 onward scale the offset
            ratios (presumably box width/height -- TODO confirm with the
            dataset code).
        cam: batched camera matrices; entries ``[0, 0]``, ``[1, 1]``,
            ``[0, 2]`` and ``[1, 2]`` are used as fx, fy, cx and cy.
        img_size: crop size used for the depth rescaling. Defaults to the
            module-level ``opt.img_size`` so existing callers are unaffected.

    Returns:
        Tensor with ``[x, y, depth]`` stacked along dimension 1.
    """
    if img_size is None:
        img_size = opt.img_size
    ratio_delta_c = trans[:, :2]
    ratio_depth = trans[:, 2]
    pred_depth = ratio_depth * (img_size / s_box)
    # Pixel-space centre of the object.
    pred_c = ratio_delta_c * bbox[:, 2:] + c_box
    # Back-project the pixel centre to camera coordinates at pred_depth.
    pred_x = (pred_c[:, 0] - cam[:, 0, 2]) * pred_depth / cam[:, 0, 0]
    pred_y = (pred_c[:, 1] - cam[:, 1, 2]) * pred_depth / cam[:, 1, 1]
    return torch.stack([pred_x, pred_y, pred_depth], dim=1)

def train_net(estimator, update_deform=False, update_pose=False, joint=False, resume_epoch=1):
    """Train ``estimator`` up to ``opt.max_epoch``, validating and saving every epoch.

    Relies on module-level state: ``opt``, ``faces_list`` / ``cat_names`` and
    the ``logger`` / ``tf_writer`` objects created in the ``__main__`` section.

    Args:
        estimator: network to optimise. It is expected to be wrapped in
            ``DataParallel`` exactly when ``opt.n_gpus > 1`` and to expose
            ``encoder``, ``deform_head`` and ``pose_head`` sub-modules.
        update_deform: forwarded to ``Loss`` as ``use_gt_pose``; also selects
            the ``deform_net`` checkpoint sub-directory.
        update_pose: forwarded to ``Loss`` as ``use_gt_model``; also selects
            the ``pose_net`` checkpoint sub-directory.
        joint: together with ``opt.sep_stage`` selects the ``joint``
            checkpoint sub-directory.
        resume_epoch: first epoch to run. Earlier epochs only advance
            ``global_step`` and the position in the learning-rate schedule.
            Must be at least 1 with the schedule set at the top of this file.

    Returns:
        The ``estimator`` that was passed in, after training.
    """
    criterion = Loss(opt, use_gt_pose=update_deform, use_gt_model=update_pose)
    # dataset
    train_dataset = PoseDataset(opt.dataset, 'train', opt.data_dir,
                                opt.n_pts, opt.img_size, opt.select_class, opt.use_dz,
                                opt.use_co3d, opt.use_wild6d, opt.semi_percent)
    val_dataset = PoseDataset(opt.dataset, 'test', opt.data_dir, opt.n_pts,
                              opt.img_size, opt.select_class)
    # start training
    st_time = time.time()
    train_steps = 1500  # batches drawn per epoch
    global_step = train_steps * (resume_epoch - 1)
    n_decays = len(opt.decay_epoch)
    assert len(opt.decay_rate) == n_decays
    # Find the schedule entry in force at resume_epoch. The epoch loop applies
    # it on its first iteration, which is also where the optimizer is created.
    for i in range(n_decays):
        if resume_epoch > opt.decay_epoch[i]:
            decay_count = i
    train_size = train_steps * opt.batch_size
    # `indices` is a pre-shuffled pool of sample indices; each epoch consumes
    # one "page" of train_size entries and the pool is refilled when short.
    indices = []
    page_start = -train_size
    end_epoch = opt.max_epoch

    for epoch in range(resume_epoch, end_epoch + 1):
        # train one epoch
        logger.info('Time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + \
                    ', ' + 'Epoch %02d' % epoch + ', ' + 'Training started'))
        # create optimizer and adjust learning rate if needed
        # (a fresh Adam instance is built at every decay step)
        if decay_count < len(opt.decay_rate):
            if epoch > opt.decay_epoch[decay_count]:
                current_lr = opt.lr * opt.decay_rate[decay_count]
                # DataParallel hides the sub-modules behind `.module`.
                net = estimator.module if opt.n_gpus > 1 else estimator
                optimizer = torch.optim.Adam([
                    {'params': net.encoder.parameters(), 'lr': current_lr},
                    {'params': net.deform_head.parameters(), 'lr': current_lr},
                    {'params': net.pose_head.parameters(), 'lr': current_lr},
                ])
                decay_count += 1
        # sample train subset
        page_start += train_size
        len_last = len(indices) - page_start
        # Divisor for the slice of the extra real-image batch kept below; it
        # also sizes the CAMERA+Real refill (3 synthetic : 1 real).
        real_ratio = 4
        if opt.dataset == 'CAMERA' and opt.use_wild6d:
            real_ratio = 3
        if len_last < train_size:
            indices = indices[page_start:]
            if opt.dataset == 'CAMERA+Real':
                # CAMERA : Real = 3 : 1
                camera_len = train_dataset.subset_len[0]
                real_len = train_dataset.subset_len[1]
                real_indices = list(range(camera_len, camera_len+real_len))
                camera_indices = list(range(camera_len))
                n_repeat = (train_size - len_last) // (real_ratio * real_len) + 1
                data_list = random.sample(camera_indices, 3*n_repeat*real_len) + real_indices*n_repeat
                random.shuffle(data_list)
                indices += data_list
            else:
                data_list = list(range(train_dataset.length))
                for i in range((train_size - len_last) // train_dataset.length + 1):
                    random.shuffle(data_list)
                    indices += data_list
            page_start = 0
        train_idx = indices[page_start:(page_start+train_size)]
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_idx)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, sampler=train_sampler,
                                                       num_workers=opt.num_workers, pin_memory=True)
        estimator.train()
        for i, data in enumerate(train_dataloader, 1):
            # The dataset appends one extra element holding the real-image
            # batch when co3d / wild6d data is enabled.
            if opt.use_co3d:
                (points, rgb, mask, choose, cat_id, model, prior, sRT, verts, nocs, pose,
                 trans_local, bbox, box_c, box_s, crop_bbox, cam, is_real,
                 is_symmetric, co3d_data) = data
                (points_real, rgb_real, mask_real, choose_real, cam_real, bbox_real,
                 crop_bbox_real, box_c_real, box_s_real) = co3d_data
            elif opt.use_wild6d:
                (points, rgb, mask, choose, cat_id, model, prior, sRT, verts, nocs, pose,
                 trans_local, bbox, box_c, box_s, crop_bbox, cam, is_real,
                 is_symmetric, wild6d_data) = data
                (points_real, rgb_real, mask_real, choose_real, cam_real, bbox_real,
                 crop_bbox_real, box_c_real, box_s_real, is_real_real) = wild6d_data
            else:
                (points, rgb, mask, choose, cat_id, model, prior, sRT, verts, nocs, pose,
                 trans_local, bbox, box_c, box_s, crop_bbox, cam, is_real, is_symmetric) = data
            if opt.select_class != 'all':
                # Single-category run: every sample uses that category's
                # faces, concatenated into one batched tensor.
                faces = [faces_list[cat_names.index(opt.select_class)].long().cuda() for c in cat_id]
                faces = torch.cat(faces, dim=0)
            else:
                # Multi-category run: faces stay a per-sample list.
                faces = [faces_list[c].long().cuda() for c in cat_id]
            points = points.cuda()
            rgb = rgb.cuda()
            mask = mask.float().cuda()
            choose = choose.cuda()
            cat_id = cat_id.cuda()
            model = model.cuda()
            prior = prior.cuda()
            verts = verts.float().cuda()
            sRT = sRT.cuda()
            nocs = nocs.cuda()
            pose = pose.float().cuda()
            trans_local = trans_local.float().cuda()
            bbox = bbox.float().cuda()
            box_c = box_c.float().cuda()
            box_s = box_s.float().cuda()
            K = cam.cuda()
            # Intrinsics adapted to the crop; [480, 640] is presumably the
            # original (height, width) -- TODO confirm against geom_utils.
            cam_crop = geom_utils.get_K_crop_resize(cam, crop_bbox, [480, 640], [opt.img_size, opt.img_size])
            cam_crop = cam_crop.cuda()
            # Append an all-zero last column to the cropped intrinsics.
            cam_crop = torch.cat([cam_crop, torch.zeros_like(K[:, :, None, 0]).cuda()], dim=-1)
            is_real = is_real.cuda()
            is_symmetric = is_symmetric.cuda()
            if opt.use_co3d or opt.use_wild6d:
                # Keep the first 1/real_ratio of the real-image batch and
                # append it to the main batch.
                bs = points_real.size(0)
                points_real = points_real[:bs//real_ratio].float().cuda()
                rgb_real = rgb_real[:bs//real_ratio].cuda()
                mask_real = mask_real[:bs//real_ratio].cuda().float()
                choose_real = choose_real[:bs//real_ratio].cuda()
                # NOTE(review): the size argument is [640, 480] here but
                # [480, 640] for the main batch -- confirm this is intended.
                cam_crop_real = geom_utils.get_K_crop_resize(cam_real, crop_bbox_real,
                                                             [640, 480], [opt.img_size, opt.img_size])
                K_real = cam_real[:bs//real_ratio].cuda().float()
                cam_crop_real = cam_crop_real[:bs//real_ratio].cuda().float()
                bbox_real = bbox_real[:bs//real_ratio].cuda().float()
                crop_bbox_real = crop_bbox_real[:bs//real_ratio].cuda().float()
                box_c_real = box_c_real[:bs//real_ratio].float().cuda()
                box_s_real = box_s_real[:bs//real_ratio].float().cuda()
                cam_crop_real = torch.cat([cam_crop_real, torch.zeros_like(K_real[:, :, None, 0]).cuda()], dim=-1)

                points = torch.cat([points, points_real], dim=0)
                rgb = torch.cat([rgb, rgb_real], dim=0)
                mask = torch.cat([mask, mask_real], dim=0)
                choose = torch.cat([choose, choose_real], dim=0)
                # The real samples reuse cat_id / verts / faces taken from the
                # head of the main batch.
                cat_id = torch.cat([cat_id, cat_id[:bs//real_ratio]], dim=0)
                verts = torch.cat([verts, verts[:bs//real_ratio]], dim=0)
                faces = torch.cat([faces, faces[:bs//real_ratio]], dim=0)
                K = torch.cat([K, K_real], dim=0)
                cam_crop = torch.cat([cam_crop, cam_crop_real], dim=0)
                bbox = torch.cat([bbox, bbox_real], dim=0)
                box_c = torch.cat([box_c, box_c_real], dim=0)
                box_s = torch.cat([box_s, box_s_real], dim=0)
            outputs = estimator(points, rgb, choose, cat_id, verts)
            pred_scales, pred_trans, pred_rots = outputs['pose']
            pred_pose_trans = compute_T(pred_trans, box_s, box_c, bbox, K)
            pred_pose = torch.cat([pred_scales, pred_pose_trans, pred_rots], dim=1)
            total_loss, losses = criterion(outputs, trans_local, pred_pose, pose, mask, verts,
                                           nocs, model, faces, cam_crop, is_real, is_symmetric, epoch, points)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            global_step += 1
            # write results to tensorboard
            for k, v in losses.items():
                tf_writer.add_scalar(k, v, global_step=global_step)

            if i % 50 == 0:
                if opt.semi:
                    logger.info('Batch {0} Loss:{1:f}, corr_loss:{2:f}, cd_loss:{3:f}, mask_loss:{4:f}, pose_loss:{5:f}, backproj_loss:{6:f}'.format(
                        i, total_loss.item(), losses['corr_loss'].item(), losses['cd_loss'].item(),
                        losses['mask_loss'].item(), losses['pose_loss'].item(), losses['backproj_loss'].item()))
                else:
                    logger.info('Batch {0} Loss:{1:f}, corr_loss:{2:f}, cd_loss:{3:f}, mask_loss:{4:f}, pose_loss:{5:f}'.format(
                        i, total_loss.item(), losses['corr_loss'].item(), losses['cd_loss'].item(),
                        losses['mask_loss'].item(), losses['pose_loss'].item()))
                wandb.log(losses)

        logger.info('>>>>>>>>----------Epoch {:02d} train finish---------<<<<<<<<'.format(epoch))

        # evaluate one epoch
        logger.info('Time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) +
                    ', ' + 'Epoch %02d' % epoch + ', ' + 'Testing started'))
        val_loss = 0.0
        total_count = np.zeros((opt.n_cat,), dtype=int)
        strict_success = np.zeros((opt.n_cat,), dtype=int)    # 5 degree and 5 cm
        easy_success = np.zeros((opt.n_cat,), dtype=int)      # 10 degree and 5 cm
        iou_success = np.zeros((opt.n_cat,), dtype=int)       # relative scale error < 0.1

        # Same three counters for the pose assembled from the regression head.
        strict_success_pose = np.zeros((opt.n_cat,), dtype=int)    # 5 degree and 5 cm
        easy_success_pose = np.zeros((opt.n_cat,), dtype=int)      # 10 degree and 5 cm
        iou_success_pose = np.zeros((opt.n_cat,), dtype=int)       # relative scale error < 0.1
        # sample validation subset; clamp so that a validation set smaller
        # than 1500 items does not make random.sample raise ValueError
        val_size = min(1500, val_dataset.length)
        val_idx = random.sample(range(val_dataset.length), val_size)
        val_sampler = torch.utils.data.sampler.SubsetRandomSampler(val_idx)
        val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1, sampler=val_sampler,
                                                     num_workers=opt.num_workers, pin_memory=True)
        estimator.eval()
        # No parameter update happens during validation, so skip building the
        # autograd graph for the forward passes.
        with torch.no_grad():
            for i, data in enumerate(val_dataloader, 1):
                (points, rgb, mask, choose, cat_id, model, prior, sRT, verts, nocs, pose,
                 trans_local, bbox, box_c, box_s, crop_bbox, cam, is_real, is_symmetric) = data
                if opt.select_class != 'all':
                    faces = [faces_list[cat_names.index(opt.select_class)].long().cuda() for c in cat_id]
                    faces = torch.cat(faces, dim=0)
                else:
                    faces = [faces_list[c].long().cuda() for c in cat_id]
                points = points.cuda()
                rgb = rgb.cuda()
                mask = mask.float().cuda()
                choose = choose.cuda()
                cat_id = cat_id.cuda()
                model = model.cuda()
                prior = prior.cuda()
                verts = verts.float().cuda()
                sRT = sRT.cuda()
                nocs = nocs.cuda()
                pose = pose.float().cuda()
                trans_local = trans_local.float().cuda()
                bbox = bbox.float().cuda()
                box_c = box_c.float().cuda()
                box_s = box_s.float().cuda()
                K = cam.cuda()
                cam_crop = geom_utils.get_K_crop_resize(cam, crop_bbox, [480, 640], [opt.img_size, opt.img_size])
                cam_crop = cam_crop.cuda()
                cam_crop = torch.cat([cam_crop, torch.zeros_like(K[:, :, None, 0]).cuda()], dim=-1)
                is_real = is_real.cuda()
                is_symmetric = is_symmetric.cuda()

                outputs = estimator(points, rgb, choose, cat_id, verts)
                pred_scales, pred_trans, pred_rots = outputs['pose']
                pred_pose_trans = compute_T(pred_trans, box_s, box_c, bbox, K)
                pred_pose = torch.cat([pred_scales, pred_pose_trans, pred_rots], dim=1)
                total_loss, _ = criterion(outputs, trans_local, pred_pose, pose, mask, verts,
                                          nocs, model, faces, cam_crop, is_real, is_symmetric, epoch, points)

                # batch_size is 1 here, hence the [0] indexing below
                pred_scales = pred_scales.detach().cpu().numpy()[0]
                pred_pose_trans = pred_pose_trans.detach().cpu().numpy()[0]
                pred_rots = quaternion_matrix(pred_rots.detach().cpu().numpy()[0])

                # estimate pose and scale
                if 'deltas' in outputs.keys():
                    inst_shape = verts + outputs['deltas']
                elif 'deformed_shape' in outputs.keys():
                    inst_shape = outputs['deformed_shape'].float()
                if opt.version == 'v3' or opt.version == 'v4':
                    # for these versions 'assign_mat' is used directly as the
                    # per-point coordinates
                    nocs_coords = outputs['assign_mat'].detach().cpu().numpy()[0]
                else:
                    assign_mat = F.softmax(outputs['assign_mat'], dim=2)
                    nocs_coords = torch.bmm(assign_mat, inst_shape)
                    nocs_coords = nocs_coords.detach().cpu().numpy()[0]
                points = points.cpu().numpy()[0]
                # use choose to remove repeated points
                choose = choose.cpu().numpy()[0]
                _, choose = np.unique(choose, return_index=True)
                nocs_coords = nocs_coords[choose, :]
                points = points[choose, :]
                _, _, _, pred_sRT = estimateSimilarityTransform(nocs_coords, points)
                # 4x4 transform built from the regression head's rotation and
                # translation
                pred_sRT_pose = np.identity(4)
                pred_sRT_pose[:3, :3] = pred_rots[:3, :3]
                pred_sRT_pose[:3, 3] = pred_pose_trans
                # evaluate pose
                cat_id = cat_id.item()
                # Both sets of counters are only updated when the alignment
                # returned a transform; total_count is updated regardless.
                if pred_sRT is not None:
                    sRT = sRT.detach().cpu().numpy()[0]
                    R_error, T_error, IoU = compute_sRT_errors(pred_sRT, sRT)
                    if R_error < 5 and T_error < 0.05:
                        strict_success[cat_id] += 1
                    if R_error < 10 and T_error < 0.05:
                        easy_success[cat_id] += 1
                    if IoU < 0.1:
                        iou_success[cat_id] += 1
                    R_error, T_error, IoU = compute_sRT_errors(pred_sRT_pose, sRT)
                    if R_error < 5 and T_error < 0.05:
                        strict_success_pose[cat_id] += 1
                    if R_error < 10 and T_error < 0.05:
                        easy_success_pose[cat_id] += 1
                    if IoU < 0.1:
                        iou_success_pose[cat_id] += 1

                total_count[cat_id] += 1
                val_loss += total_loss.item()
                if i % 100 == 0:
                    logger.info('Batch {0} Loss:{1:f}'.format(i, total_loss.item()))
        # compute accuracy (percentages per category)
        strict_acc = 100 * (strict_success / total_count)
        easy_acc = 100 * (easy_success / total_count)
        iou_acc = 100 * (iou_success / total_count)
        strict_acc_pose = 100 * (strict_success_pose / total_count)
        easy_acc_pose = 100 * (easy_success_pose / total_count)
        iou_acc_pose = 100 * (iou_success_pose / total_count)
        if opt.n_cat == 1:
            logger.info('{} accuracies:'.format(opt.select_class))
            logger.info('5^o 5cm: {:4f}'.format(strict_acc[0]))
            logger.info('10^o 5cm: {:4f}'.format(easy_acc[0]))
            logger.info('IoU < 0.1: {:4f}'.format(iou_acc[0]))
            logger.info('Pose: 5^o 5cm: {:4f}'.format(strict_acc_pose[0]))
            logger.info('Pose: 10^o 5cm: {:4f}'.format(easy_acc_pose[0]))
            logger.info('Pose: IoU < 0.1: {:4f}'.format(iou_acc_pose[0]))
        else:
            for i in range(opt.n_cat):
                logger.info('{} accuracies:'.format(val_dataset.cat_names[i]))
                logger.info('5^o 5cm: {:4f}'.format(strict_acc[i]))
                logger.info('10^o 5cm: {:4f}'.format(easy_acc[i]))
                logger.info('IoU < 0.1: {:4f}'.format(iou_acc[i]))
                logger.info('Pose: 5^o 5cm: {:4f}'.format(strict_acc_pose[i]))
                logger.info('Pose: 10^o 5cm: {:4f}'.format(easy_acc_pose[i]))
                logger.info('Pose: IoU < 0.1: {:4f}'.format(iou_acc_pose[i]))
        strict_acc = np.mean(strict_acc)
        easy_acc = np.mean(easy_acc)
        iou_acc = np.mean(iou_acc)
        strict_acc_pose = np.mean(strict_acc_pose)
        easy_acc_pose = np.mean(easy_acc_pose)
        iou_acc_pose = np.mean(iou_acc_pose)
        # max(..., 1) only matters for an empty validation set
        val_loss = val_loss / max(val_size, 1)
        tf_writer.add_scalar('5^o5cm_acc', strict_acc, global_step=global_step)
        tf_writer.add_scalar('10^o5cm_acc', easy_acc, global_step=global_step)
        tf_writer.add_scalar('iou_acc', iou_acc, global_step=global_step)
        tf_writer.add_scalar('Pose: 5^o5cm_acc', strict_acc_pose, global_step=global_step)
        tf_writer.add_scalar('Pose: 10^o5cm_acc', easy_acc_pose, global_step=global_step)
        tf_writer.add_scalar('Pose: iou_acc', iou_acc_pose, global_step=global_step)
        logger.info('Epoch {0:02d} test average loss: {1:06f}'.format(epoch, val_loss))
        logger.info('Overall accuracies:')
        logger.info('5^o 5cm: {:4f} 10^o 5cm: {:4f} IoU: {:4f}'.format(strict_acc, easy_acc, iou_acc))
        logger.info('Pose Regression -- 5^o 5cm: {:4f} 10^o 5cm: {:4f} IoU: {:4f}'.format(strict_acc_pose, \
            easy_acc_pose, iou_acc_pose))
        logger.info('>>>>>>>>----------Epoch {:02d} test finish---------<<<<<<<<'.format(epoch))
        wandb.log({'5^o5cm_acc':strict_acc, '10^o5cm_acc':easy_acc, 'iou_acc':iou_acc})

        # save model after each epoch; the sub-directory encodes the stage
        if update_deform:
            save_dir = os.path.join(opt.result_dir, 'deform_net')
        elif update_pose:
            save_dir = os.path.join(opt.result_dir, 'pose_net')
        elif joint and opt.sep_stage:
            save_dir = os.path.join(opt.result_dir, 'joint')
        else:
            save_dir = opt.result_dir
        os.makedirs(save_dir, exist_ok=True)

        # Save the bare module's weights so the checkpoint loads without
        # DataParallel.
        net_to_save = estimator.module if opt.n_gpus > 1 else estimator
        torch.save(net_to_save.state_dict(), '{0}/model_{1:02d}.pth'.format(save_dir, epoch))
    return estimator


def parse_resume_checkpoint(path):
    """Derive the next epoch and the training stage from a checkpoint path.

    The file name is expected to end in ``_<epoch>.<ext>`` (``train_net``
    writes ``model_05.pth``). The stage is inferred from the name of the
    directory that directly contains the file.

    Args:
        path: '/'-separated checkpoint path; a bare file name is accepted.

    Returns:
        ``(next_epoch, stage)`` where ``next_epoch`` is the saved epoch plus
        one and ``stage`` is ``'deform'``, ``'pose'``, ``'joint'`` or ``''``
        when the parent directory matches none of them.

    Raises:
        ValueError: if the epoch part of the file name is not an integer.
    """
    parts = path.split('/')
    # 'model_05.pth' -> '05.pth' -> '05'. Taking everything before the first
    # dot (rather than a fixed two characters) also handles 'model_100.pth'
    # and 'model_5.pth'.
    epoch_token = parts[-1].split('_')[-1].split('.')[0]
    next_epoch = int(epoch_token) + 1
    # A bare file name has no parent directory, hence no stage.
    parent = parts[-2] if len(parts) > 1 else ''
    for stage in ('deform', 'pose', 'joint'):
        if stage in parent:
            return next_epoch, stage
    return next_epoch, ''


if __name__ == '__main__':
    # set result directory
    os.makedirs(opt.result_dir, exist_ok=True)
    # tf_writer and logger are module globals that train_net relies on.
    tf_writer = SummaryWriter(os.path.join(opt.result_dir, 'logs'))
    logger = setup_logger('train_log', os.path.join(opt.result_dir, 'log.txt'))
    for key, value in vars(opt).items():
        logger.info(key + ': ' + str(value))
    # model & loss
    if opt.version == 'v3' or opt.version == 'v4':
        estimator = PoseNetV3(opt)
    elif opt.version == 'v2':
        estimator = PoseNetV2(opt)
    else:
        estimator = PoseNet(opt)
    estimator.cuda()
    resume_epoch = 1
    resume_stage = ''
    if opt.resume_model != '':
        logger.info('Load model from {}'.format(opt.resume_model))
        estimator.load_state_dict(torch.load(opt.resume_model))
        resume_epoch, resume_stage = parse_resume_checkpoint(opt.resume_model)
    # Fine-tuning / wild6d runs load the weights but restart the epoch count.
    if opt.use_wild6d or opt.finetune:
        resume_epoch = 1
    if opt.n_gpus > 1:
        estimator = torch.nn.DataParallel(estimator)
    if opt.sep_stage:
        # Three consecutive stages: deformation -> pose -> joint. When
        # resuming, only the stage being resumed receives resume_epoch; later
        # stages start from epoch 1.
        if resume_stage == 'deform':
            logger.info('Start Training deformation branch')
            estimator = train_net(estimator, update_deform=True, resume_epoch=resume_epoch)
            logger.info('Start Training pose estimation branch')
            estimator = train_net(estimator, update_pose=True)
            logger.info('Start Joint Training deformation & pose estimation branch')
            estimator = train_net(estimator, joint=True)
        elif resume_stage == 'pose':
            logger.info('deformation branch finished')
            logger.info('Start Training pose estimation branch')
            estimator = train_net(estimator, update_pose=True, resume_epoch=resume_epoch)
            logger.info('Start Joint Training deformation & pose estimation branch')
            estimator = train_net(estimator, joint=True)
        elif resume_stage == 'joint':
            logger.info('deformation branch finished')
            logger.info('pose estimation branch finished')
            logger.info('Start Joint Training deformation & pose estimation branch')
            estimator = train_net(estimator, joint=True, resume_epoch=resume_epoch)
        else:
            # No recognisable stage: run all three stages from epoch 1.
            estimator = train_net(estimator, update_deform=True)
            estimator = train_net(estimator, update_pose=True)
            estimator = train_net(estimator, joint=True)
    else:
        estimator = train_net(estimator, joint=True, resume_epoch=resume_epoch)
