import torch
from torch.utils.data import Dataset, DataLoader
from MeshConvertPytorch3D import verts_proj, verts_proj_matrix, if_visiable_pytorch_given_pix_to_face, limit_to_img_size, limit_to_obj_mask
from MeshUtils import load_off, MeshInterpolateModule, pre_process_mesh_pascal, center_crop_fun, RasterizationSettings, MeshRasterizer, camera_position_from_spherical_angles, campos_to_R_T, camera_position_to_spherical_angle
# from pytorch3d.renderer import PerspectiveCameras, OpenGLPerspectiveCameras
from pytorch3d.renderer import OpenGLPerspectiveCameras, PerspectiveCameras
from SphereSampleManager import SphereSampleManager

from lib.NCEAverage_new2 import NearestMemoryManager, mask_remove_near
from models.KeypointRepresentationNet import NetE2E
import os
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
import random
import datetime

import seaborn as sns
import matplotlib.pyplot as plt
import argparse


# ----------------------------------------------------------------------------
# Command-line options and experiment configuration. Everything below runs at
# import time (argument parsing, file loading, directory creation).
# ----------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Pose estimation')
# Backbone key; looked up in net_type_mapping / net_dfeature_mapping below.
parser.add_argument('--net_type', default='resnet50_pre', type=str, help='')
# parser.add_argument('--net_type', default='resnet50_pre', type=str, help='')
# Object category; must be a key of image_start_all, image_size_ori and batch_size below.
parser.add_argument('--cate', default='car', type=str, help='')
args = parser.parse_args()


# ToTensor followed by ImageNet mean/std normalization.
standard_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

cate = args.cate
# Seed ("start") images per category. Only the keys (image names) are read in
# this file. NOTE(review): the integer values are not used here.
image_start_all = dict()
image_start_all['car'] = {'002905_0':2, '000920_0': 2, '000120_0':2, '007195_4':2, '001217_1': 2, '002017_0':2, '003507_0':2}


image_start = image_start_all[cate]
device = 'cuda:0'
# CAD mesh directory variant: meshes are read from CAD_<mesh_d>/<cate>.
mesh_d = 'buildn'
# mesh_d = 'sphere1'

mesh_path = '../PASCAL3D/PASCAL3D+_release1.1/CAD_' + mesh_d + '/' + cate

anno_path = '../KITTI/KITTI_train_distcrop/annotations/' + cate
img_path = '../KITTI/KITTI_train_distcrop/images/' + cate


net_type = args.net_type

# Checkpoints are written here every 20 outer steps.
save_dir = '../3DrepresentationData/Unsupervised_V5_%s_trained_KITTI_50/' % net_type

# Per-category image size and batch size. image_size_ori is halved to get the
# projection principal point and divided by the network stride to get the
# feature-map size; presumably (height, width) - TODO confirm.
image_size_ori = {'car': (256, 672), 'bus': (320, 800), 'motorbike': (512, 512)}[cate]
batch_size = {'car': 32, 'bus': 30, 'motorbike': 36}[cate]

# List of annotated image names used for the supervised "anno" loss; None disables it.
use_anno_list = '../KITTI/KITTI_train_distcrop/list/car_mostly_visible_f012_50.txt'
# use_anno_list = None
# Epochs of training on the annotated subset before step 0; <= 0 skips it.
extra_train_epochs = -1

# Noise points per image and group count; the memory bank holds
# n_verts + num_noise * max_group entries.
num_noise = 5
max_group = 512

# Distance threshold passed to mask_remove_near.
distance_thr = 48

# NOTE(review): dual_side and n_render_iter are not read anywhere in this file.
dual_side = True
n_render_iter = 1
# When not None, retrieval similarities are multiplied by gt_mask(..., thr=gt_pose_mask).
# gt_pose_mask = 8
gt_pose_mask = None

# Number of mesh copies handed to MeshInterpolateModule.
n_hypo = 1

# Camera distance used for every rendered / annotated pose.
distance = 5
# Only theta_base (in-plane rotation used when rendering) is read in this file.
elevation_base, azimuth_base, theta_base = 0, np.pi / 2, 0


# Number of outer pose-expansion steps after step 0.
total_epochs = 100

# Per-element cross-entropy; the training functions average it themselves.
criterion = torch.nn.CrossEntropyLoss(reduction='none').cuda()

# Mesh whose vertices define the camera poses managed by pose_manager.
pose_template = './UVsamples4p1.off'

# NOTE(review): group_sel and accumulated_num are not read anywhere in this file.
group_sel = 2

accumulated_num = 0

verts_pose, faces_pose = load_off(pose_template)
pose_manager = SphereSampleManager(verts_pose, faces_pose, batch_size=1, distance=distance, return_idx=True)

# Per-pose object masks stored in <template>_mask_<cate>.npz under keys '0'..'K-1';
# stacked into a single tensor indexed by pose id.
mask_for_pose0 = dict(np.load(pose_template.split('.off')[0] + '_mask_%s.npz' % cate))
mask_for_pose = []
for k in range(len(mask_for_pose0.keys())):
    mask_for_pose.append(torch.from_numpy(mask_for_pose0[str(k)]).unsqueeze(0))
mask_for_pose = torch.cat(mask_for_pose)


# lr = 1e-5
lr = 1e-4
# The lr is multiplied by update_lr_ every update_lr_epoch_n inner epochs.
update_lr_epoch_n = 30
update_lr_ = 0.5
# Maximum number of images selected per pose (num_per_pos of similarity_to_select).
use_pose_num = 25
num_workers = 0

# Inner training epochs per step.
each_train_epoch = 120

# Backbone key -> NetE2E net_type, and -> feature dimension of the memory bank.
net_type_mapping = {'vggp4': 'vgg_pool4', 'resnet50_pre': 'resnet50_pre', 'resnetupsample': 'resnetupsample'}
net_dfeature_mapping = {'vggp4': 512, 'resnet50_pre': 1024, 'resnetupsample': 2048}


# Output directory for the retrieval plots written by draw_one_vis.
img_dir = 'img_unsuper_%s_KITTI' % net_type
feature_d = net_dfeature_mapping[net_type]
# feature_d = 1024

os.makedirs(img_dir, exist_ok=True)

def load_one_image(img_path_, transforms_=standard_transforms):
    """Open ``img_path_``, force it to 3-channel RGB, and return ``transforms_`` applied to it."""
    rgb_image = Image.open(img_path_).convert('RGB')
    transformed = transforms_(rgb_image)
    return transformed


def load_one_image_with_anno(img_path_, anno_path_, transforms_=standard_transforms):
    """Load an RGB image together with its annotation file.

    Returns:
        ``(transforms_(image), np.load(anno_path_, allow_pickle=True))``.
        NOTE: ``allow_pickle=True`` may unpickle objects - only use on trusted annotation files.
    """
    rgb_image = Image.open(img_path_).convert('RGB')
    annotation = np.load(anno_path_, allow_pickle=True)
    transformed = transforms_(rgb_image)
    return transformed, annotation


class Pascal3D(Dataset):
    """Dataset over ``<name>.JPEG`` images and matching ``<name>.npz`` annotations.

    Each item is ``(transform(image), anno)``. ``anno`` is a dict with the
    ``useful_keys`` fields of the annotation file, cast to the listed dtypes.
    An optional index mask (see :meth:`enable_mask`) restricts and reorders the
    items exposed by ``__getitem__`` and ``__len__``.
    """
    # Annotation fields that are kept, with the dtype each is cast to.
    useful_keys = {'elevation': np.float32, 'azimuth': np.float32, 'cad_index': np.int32}

    def __init__(self, img_path=img_path, anno_path=anno_path, image_list=None, enable_cache=True, transform=None):
        """
        Args:
            img_path: directory holding ``<name>.JPEG`` images.
            anno_path: directory holding ``<name>.npz`` annotations.
            image_list: ``None`` to use every file in ``img_path``; a path to a text file
                with one image name per line; or an iterable of names (extensions stripped).
            enable_cache: keep loaded images/annotations in memory after the first access.
            transform: callable applied to each PIL image; ``None`` returns the PIL image as is.
        """
        if image_list is None:
            all_imgs = os.listdir(img_path)
        elif isinstance(image_list, str):
            with open(image_list) as list_file:
                all_imgs = [t.strip() + '.JPEG' for t in list_file.readlines()]
        else:
            all_imgs = [t.strip().split('.')[0] + '.JPEG' for t in image_list]
        # Image names are stored without extension.
        self.all_imgs = [t.split('.')[0] for t in all_imgs]
        self.img_path = img_path
        self.anno_path = anno_path

        self.cache_anno = dict()
        self.cache_img = dict()
        self.transform = transform
        self.enable_cache = enable_cache

        self.enabled_mask = False
        self.mask = None

    def _load_by_name(self, img_name):
        """Return ``(image, anno)`` for one image name, going through the cache when enabled."""
        if self.enable_cache and img_name in self.cache_anno:
            img_ = self.cache_img[img_name]
            anno = self.cache_anno[img_name]
        else:
            img_ = Image.open(os.path.join(self.img_path, img_name + '.JPEG')).convert('RGB')
            # NOTE: allow_pickle=True can unpickle objects - annotation files must be trusted.
            with np.load(os.path.join(self.anno_path, img_name + '.npz'), allow_pickle=True) as anno_file:
                # Cast on every path so cached and uncached items share the same dtypes.
                anno = {k_: anno_file[k_].astype(dtype_) for k_, dtype_ in self.useful_keys.items()}
            if self.enable_cache:
                self.cache_img[img_name] = img_
                self.cache_anno[img_name] = anno

        if self.transform is None:
            return img_, anno
        return self.transform(img_), anno

    def __getitem__(self, item):
        # With an active mask, ``item`` indexes the mask, whose entries index ``all_imgs``.
        index = self.mask[item] if self.enabled_mask else item
        return self._load_by_name(self.all_imgs[index])

    def __len__(self):
        if self.enabled_mask:
            return len(self.mask)
        return len(self.all_imgs)

    def get_list_images(self, img_name_list):
        """Load the named images (independently of any active mask) as one batch.

        Returns:
            (images stacked on dim 0, dict mapping each annotation key to a stacked tensor)

        Raises:
            ValueError: if a name is not part of this dataset.
        """
        imgs, annos = [], []
        for img_name in img_name_list:
            if img_name not in self.all_imgs:
                raise ValueError('%r is not in list' % (img_name,))
            # Load by name, not by index: __getitem__ would reinterpret the index through the mask.
            img, anno = self._load_by_name(img_name)
            imgs.append(img)
            annos.append(anno)
        return torch.stack(imgs), {k_: torch.stack([torch.from_numpy(anno_[k_]) for anno_ in annos]) for k_ in annos[0].keys()}

    def enable_mask(self, mask):
        """Expose only the ``all_imgs`` positions listed in ``mask`` (tensor, array or list)."""
        if isinstance(mask, torch.Tensor):
            # reshape(-1), not squeeze(): a 1-element mask must stay a 1-element list.
            mask = mask.reshape(-1).cpu().numpy()

        if isinstance(mask, np.ndarray):
            mask = mask.reshape(-1).tolist()
        self.enabled_mask = True
        self.mask = mask

    def disable_mask(self):
        """Expose the full image list again."""
        self.enabled_mask = False


class LabelLoader(Dataset):
    """Dataset over pre-rendered per-image labels.

    Item ``i`` is ``(labels[i], vis[i], masks[i])``. ``vis`` and ``masks`` are
    converted to float32 once, at construction time; ``labels`` is stored as given.
    """

    def __init__(self, labels, vis, masks):
        self.labels = labels
        self.vis = vis.to(torch.float32)
        self.masks = masks.to(torch.float32)

    def __len__(self):
        n_items, *_ = self.labels.shape
        return n_items

    def __getitem__(self, item):
        return tuple(field[item] for field in (self.labels, self.vis, self.masks))


def do_render(inter_module, cam_pos, theta, verts_mask=None):
    """Render the mesh from a batch of cameras and project its vertices.

    Args:
        inter_module: mesh interpolation module. ``forward(cam_pos, theta, return_fragments=True)``
            must return ``(projected_map, fragment)``, and ``.meshes`` must expose packed verts/faces.
        cam_pos: camera positions, one view per entry of dim 0.
        theta: in-plane rotation per view; passed to ``campos_to_R_T`` and the renderer.
        verts_mask: optional per-view object masks. When given, projected vertices and
            visibility go through ``limit_to_obj_mask``. Otherwise vertices are clipped with
            ``limit_to_img_size`` to the global ``image_size_ori``.

    Returns:
        (projected_map, all_vert, all_vis): the rendered map from ``inter_module``, the
        projected vertex coordinates, and the per-view vertex visibility.
    """
    R, T = campos_to_R_T(cam_pos, theta, device=device)
    projected_map, fragment = inter_module.forward(cam_pos, theta, return_fragments=True)
    out_vis = []

    # (sum(V_n), 3)
    verts = inter_module.meshes.verts_packed()
    faces = inter_module.meshes.faces_packed()

    for i in range(cam_pos.shape[0]):
        # Shift view i's face ids by i * n_faces before the visibility test.
        # NOTE(review): assumes view i's faces start at offset i * faces.shape[0] in the
        # rasterizer's packed order (and that background -1 entries are handled by the
        # helper) - confirm against MeshInterpolateModule.
        isvisible_ = if_visiable_pytorch_given_pix_to_face(fragment.pix_to_face[i] - i * faces.shape[0], verts=verts, faces=faces).to(device)
        out_vis.append(isvisible_.unsqueeze(0))
    all_vis = torch.cat(out_vis, dim=0)

    # [n_azum * n_elev, n_vert, 2]
    # The principal point is the centre of the original image size.
    all_vert = verts_proj_matrix(verts.unsqueeze(0), R, T, principal=(image_size_ori[0] // 2, image_size_ori[1] // 2))
    # all_vert = limit_to_img_size(all_vert, img.shape[2::])
    if verts_mask is None:
        all_vert = limit_to_img_size(all_vert, image_size_ori)
    else:
        all_vert, all_vis = limit_to_obj_mask(all_vert, mask=verts_mask, vis=all_vis)


    return projected_map, all_vert, all_vis


def retrieve_similarity(target_maps, this_loader, network, return_annos=True):
    """Score every image of ``this_loader`` against every target feature map.

    The network is put into eval mode and run without gradients. For each image,
    the similarity to target ``j`` is the element-wise product of its feature map
    with ``target_maps[j]``, summed over all remaining dimensions.

    Returns:
        ``sims`` of shape [n_img, n_pos]. When ``return_annos`` is true, also a dict
        of annotation tensors concatenated across loader batches.
    """
    network.eval()
    move_to_gpu = target_maps.device != torch.device('cpu')
    sim_chunks = []
    anno_chunks = []
    with torch.no_grad():
        for batch_img, batch_anno in this_loader:
            if move_to_gpu:
                batch_img = batch_img.cuda()
            feats = network(X=batch_img, mode=0)
            # [n_img, 1, ...] * [1, n_pos, ...] -> [n_img, n_pos]
            sim_chunks.append((feats[:, None] * target_maps[None]).sum(dim=(2, 3, 4)))
            anno_chunks.append(batch_anno)

    sims = torch.cat(sim_chunks, dim=0)
    merged_annos = {key: torch.cat([chunk[key] for chunk in anno_chunks]) for key in anno_chunks[0].keys()}

    if not return_annos:
        return sims
    return sims, merged_annos


def train_fixed(training_images, training_verts, training_vis, training_mask, network, memory_bank, optimizer, cuda_=True):
    """Run one optimization step on a fixed batch that is already on the device.

    Args:
        training_images: image batch fed to ``network``.
        training_verts: projected keypoint positions per image (keypoints on dim 1).
        training_vis: per-keypoint visibility; only visible keypoints enter the loss.
        training_mask: per-image object mask; ``1 - mask`` is passed as ``obj_mask``.
        network, memory_bank, optimizer: the model, its contrastive memory and optimizer.
        cuda_: unused; kept for signature compatibility with ``train_step``.

    Returns:
        The main contrastive loss (without the noise regularizer) as a Python float.
    """
    features = network(X=training_images, keypoint_positions=training_verts,
                       obj_mask=1 - training_mask.type(torch.float32).to(training_images.device), mode=-1)

    # Keypoint ids 0..n_kp-1 repeated for every image (float32, as torch.Tensor(...) produced).
    n_img, n_kp = training_images.shape[0], training_verts.shape[1]
    index = torch.arange(n_kp, dtype=torch.float32, device=device).unsqueeze(0).repeat(n_img, 1)

    score, y_idx, noise_score = memory_bank.forward(features, index, training_vis)
    score /= 0.07  # temperature

    with torch.no_grad():
        mask_distance_legal = mask_remove_near(training_verts, thr=distance_thr, num_neg=num_noise * max_group,
                                               dtype_template=score, neg_weight=5e-3)

    # Boolean mask, as in train_step: an integer 0/1 tensor would otherwise be used as row indices.
    vis_mask = training_vis.view(-1).type(torch.bool)
    n_cls = score.shape[2]
    loss = criterion(
        (score.view(-1, n_cls) - mask_distance_legal.view(-1, n_cls))[vis_mask, :],
        y_idx.view(-1)[vis_mask])

    loss = torch.mean(loss)
    loss_main = loss.item()
    if num_noise > 0:
        # Regularizer pushing down the scores of the noise points.
        loss = loss + torch.mean(noise_score) * 0.1
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss_main


def train_step(this_loader, label_loader, network, memory_bank, optimizer, cuda_=True):
    """Train ``network`` for one pass over paired image and label loaders.

    ``this_loader`` yields ``(images, annos)`` and ``label_loader`` yields
    ``(verts, vis, masks)``. The two are consumed in lockstep (``zip`` stops at the
    shorter one), so they must produce matching batches in the same order. The
    optimizer is stepped once per batch.

    Returns:
        List with the main contrastive loss (float) of every batch.
    """
    network.train()
    history = []
    for (img, _), (verts, vis, masks) in zip(this_loader, label_loader):
        if cuda_:
            img, verts, vis, masks = img.cuda(), verts.cuda(), vis.cuda(), masks.cuda()
        features = network(X=img, keypoint_positions=verts, obj_mask=1 - masks, mode=-1)

        n_img, n_kp = img.shape[0], verts.shape[1]
        kp_ids = torch.Tensor([list(range(n_kp))] * n_img).to(device)
        score, y_idx, noise_score = memory_bank.forward(features, kp_ids, vis)
        score /= 0.07

        with torch.no_grad():
            near_penalty = mask_remove_near(verts, thr=distance_thr, num_neg=num_noise * max_group,
                                            dtype_template=score, neg_weight=5e-3)
        keep = vis.type(torch.bool).to(vis.device).view(-1)

        n_cls = score.shape[2]
        logits = score.view(-1, n_cls) - near_penalty.view(-1, n_cls)
        batch_loss = torch.mean(criterion(logits[keep, :], y_idx.view(-1)[keep]))
        main_value = batch_loss.item()
        if num_noise > 0:
            noise_term = torch.mean(noise_score) * 0.1
            batch_loss += noise_term
        batch_loss.backward()

        history.append(main_value)

        optimizer.step()
        optimizer.zero_grad()
    return history


def similarity_to_select(similarity, num_per_pos):
    """Pick, for every pose, the images that match it best.

    Each image is first assigned to its single most similar pose (row-wise argmax).
    Then, for each pose, the ``num_per_pos`` highest-scoring assigned images are kept.

    Args:
        similarity: [n_img, n_pos] similarity matrix.
        num_per_pos: maximum number of images kept per pose; clamped to n_img
            (``torch.topk`` would otherwise raise).

    Returns:
        [min(num_per_pos, n_img), n_pos] long tensor of image indices. ``-1`` marks
        slots with no assigned image of positive score.
    """
    # [n_img, n_pos]: keep each image's score only at its best pose, zero elsewhere.
    best_pose = torch.argmax(similarity, dim=1)
    assigned = torch.zeros_like(similarity).scatter(1, best_pose.unsqueeze(1), 1.0) * similarity

    # [k, n_pos]
    k = min(num_per_pos, similarity.shape[0])
    max_value, max_idx = torch.topk(assigned, k, dim=0)
    return torch.where(max_value > 0, max_idx, torch.full_like(max_idx, -1))


def get_idx_sel(max_idx):
    """Flatten a selection matrix into (image ids, pose column of each id).

    Args:
        max_idx: [k, n_pos] tensor of image indices with -1 for empty slots
            (as produced by ``similarity_to_select``).

    Returns:
        (sorted unique non-negative ids, for each occurrence of those ids in row-major
        order of (id, row, column), the column where it occurs)
    """
    unique_ids = torch.unique(max_idx)
    valid_ids = unique_ids[unique_ids > -1]
    hits = max_idx[None] == valid_ids[:, None, None]
    return valid_ids, hits.nonzero()[:, 2]


def add_filpped(img=None, pose=None, annos=None):
    """Append horizontally mirrored copies to a batch.

    Args:
        img: optional image batch, mirrored along dim 3.
        pose: optional camera positions with 3 columns; the mirrored copy negates column 0.
        annos: optional dict of per-sample tensors. ``'azimuth'`` is mirrored as
            ``2*pi - azimuth``; every other entry is duplicated unchanged.

    Returns:
        ``(img_out, pose_out)``, or ``(img_out, pose_out, annos_out)`` when ``annos`` is
        given. Mirrored samples are appended after the originals on dim 0; outputs whose
        input was ``None`` are ``None``. ``annos`` itself is not modified - a new dict
        is returned.
    """
    img_out = None if img is None else torch.cat((img, torch.flip(img, dims=(3,))), dim=0)

    if pose is None:
        pose_out = None
    else:
        pose_flipped = torch.cat([-pose[:, 0:1], pose[:, 1:2], pose[:, 2:3]], dim=1)
        pose_out = torch.cat((pose, pose_flipped), dim=0)

    if annos is None:
        return img_out, pose_out

    annos_out = {}
    for k, v in annos.items():
        mirrored = np.pi * 2 - v if k == 'azimuth' else v
        annos_out[k] = torch.cat((v, mirrored))
    return img_out, pose_out, annos_out


def add_filpped_pose(pose, fp=(-1, -1)):
    """Mirror an array of camera positions and append the mirrored rows.

    Column 0 is multiplied by ``fp[0]`` and column 2 by ``fp[1]``; column 1 is kept.

    Returns:
        (original rows followed by mirrored rows, mirrored rows alone)
    """
    factor_first, factor_last = fp[0], fp[1]
    mirrored = np.concatenate(
        [factor_first * pose[:, 0:1], pose[:, 1:2], factor_last * pose[:, 2:3]],
        axis=1,
    )
    stacked = np.concatenate((pose, mirrored), axis=0)
    return stacked, mirrored


def gt_mask(anchor_annos, gt_annos, thr):
    """Indicator of which images lie within ``thr`` of each anchor pose.

    Both dicts provide ``'elevation'`` and ``'azimuth'`` (radians). They are turned into
    camera positions at the global ``distance``.

    Returns:
        float32 tensor [n_img, n_pos]. An entry is 1 where the squared euclidean distance
        between gt camera (row) and anchor camera (column) is below ``thr ** 2``, else 0.
    """
    def _to_cam(annos_):
        return camera_position_from_spherical_angles(distance, annos_['elevation'], annos_['azimuth'], degrees=False)

    gt_pos = _to_cam(gt_annos)
    anchor_pos = _to_cam(anchor_annos)
    sq_dist = (gt_pos.unsqueeze(1) - anchor_pos.unsqueeze(0)).pow(2).sum(2)
    return (sq_dist < thr ** 2).to(torch.float32)


def normalize(x, dim=1, p=2, eps=1e-12):
    """Scale tensor ``x`` to unit L``p`` norm along ``dim``.

    Args:
        x: input tensor.
        dim: dimension along which the norm is taken.
        p: order of the norm. Absolute values are used, so odd orders (e.g. p=1) are correct.
        eps: lower bound on the norm, so all-zero slices give zeros instead of NaN.
    """
    norm = x.abs().pow(p).sum(dim, keepdim=True).pow(1 / p)
    return x / norm.clamp_min(eps)


def draw_one_vis(annos_anchor, annos_sel, all_sel, idx_sel, save_name):
    """Save one azimuth-vs-elevation scatter plot per anchor pose.

    For anchor ``i`` the plot shows the anchor itself in purple (initial jointplot)
    and green (last layer). Retrieved images assigned to anchor ``i`` are red; all
    other retrieved images are blue.

    Args:
        annos_anchor: dict with per-anchor ``'azimuth'`` and ``'elevation'`` values.
        annos_sel: dict with ``'azimuth'`` / ``'elevation'`` tensors for all retrieved images.
        all_sel: indices (into ``annos_sel``) of the selected images.
        idx_sel: anchor index each entry of ``all_sel`` was assigned to.
        save_name: path pattern with one ``%d`` placeholder, filled with the anchor index.
    """
    all_sel = all_sel.cpu()
    idx_sel = idx_sel.cpu()

    for i in range(annos_anchor['azimuth'].shape[0]):
        plt.clf()

        azim_ = np.array([annos_anchor['azimuth'][i]])
        elev_ = np.array([annos_anchor['elevation'][i]])

        # Grid created from the anchor point alone.
        graph = sns.jointplot(x=azim_,
                              y=elev_, color='purple', )

        # Boolean mask over retrieved images: True for those assigned to anchor i.
        this_sel = torch.zeros(annos_sel['azimuth'].shape[0], dtype=torch.bool).scatter(0, all_sel[idx_sel == i], True)

        # The grid's x/y data are overwritten before each plot_joint call to layer extra scatters.
        # NOTE(review): relies on JointGrid.plot_joint reading graph.x / graph.y - confirm for
        # the installed seaborn version.
        graph.x = annos_sel['azimuth'][torch.logical_not(this_sel)].numpy()
        graph.y = annos_sel['elevation'][torch.logical_not(this_sel)].numpy()
        graph.plot_joint(plt.scatter, c='b')

        graph.x = annos_sel['azimuth'][this_sel].numpy()
        graph.y = annos_sel['elevation'][this_sel].numpy()
        graph.plot_joint(plt.scatter, c='r')

        graph.x = azim_
        graph.y = elev_
        graph.plot_joint(plt.scatter, c='g')

        plt.savefig(save_name % i)
        plt.close()


def get_gt_dataset(annos, theta_=theta_base, distance_=distance, device_=device, render_batch_size=20):
    """Render keypoint labels for a list of annotated poses.

    Args:
        annos: sequence of dicts with ``'azimuth'`` and ``'elevation'`` (radians).
        theta_: in-plane rotation, scalar or tensor (a scalar becomes a 1-element tensor).
        distance_: camera distance.
        device_: device used for rendering.
        render_batch_size: number of poses rendered per ``do_render`` call.

    Returns:
        LabelLoader holding projected vertices, visibility, and the per-pose object mask
        (``mask_for_pose`` indexed by ``pose_manager.get_idx``).

    Uses the module-level ``inter_module``, ``pose_manager`` and ``mask_for_pose``.
    """
    with torch.no_grad():
        this_azimuth = torch.from_numpy(np.array([t['azimuth'] for t in annos])).to(device_)
        this_elevation = torch.from_numpy(np.array([t['elevation'] for t in annos])).to(device_)

        cam_pos = camera_position_from_spherical_angles(distance=distance_, azimuth=this_azimuth, elevation=this_elevation,
                                                        degrees=False, device=device_).type(torch.float32)

        if not isinstance(theta_, torch.Tensor):
            theta_ = torch.ones(1).to(device_) * theta_

        all_verts = []
        all_vis = []
        all_masks = []
        # Step through the poses in chunks. The former `n // size + 1` loop produced an
        # empty final chunk (sent to the renderer) whenever n was a multiple of the chunk size.
        for start in range(0, cam_pos.shape[0], render_batch_size):
            this_cam_pos = cam_pos[start: start + render_batch_size]
            this_pos_idx = pose_manager.get_idx(this_cam_pos.cpu().numpy())
            # The object mask is returned with the labels but not applied during rendering.
            _, this_vert, this_vis = do_render(inter_module=inter_module, cam_pos=this_cam_pos,
                                               theta=theta_.repeat_interleave(this_cam_pos.shape[0] // theta_.shape[0]))

            all_masks.append(mask_for_pose[this_pos_idx])
            all_verts.append(this_vert.cpu())
            all_vis.append(this_vis.cpu())

        all_verts = torch.cat(all_verts, dim=0)
        all_vis = torch.cat(all_vis, dim=0)
        all_masks = torch.cat(all_masks, dim=0)
    return LabelLoader(labels=all_verts, vis=all_vis, masks=all_masks)


if __name__ == '__main__':
    # Training driver. Step 0 trains from the hand-picked start images; every later
    # step renders the next batch of sampled poses, retrieves the best-matching
    # images for them, and trains on those pseudo-labels.

    # ------------- Init ----------------
    net = NetE2E(net_type=net_type_mapping[net_type], local_size=(1, 1),
                 output_dimension=-1, reduce_function=None, n_noise_points=num_noise, pretrain=True,
                 noise_on_mask=True)
    # CAD model '01.off' of the category.
    verts, faces = load_off(os.path.join(mesh_path, '01.off'), to_torch=True)

    down_sample_rate = net.net_stride
    feature_size = (image_size_ori[0] // down_sample_rate, image_size_ori[1] // down_sample_rate)
    # Square render size (the larger feature side), later center-cropped back to feature_size.
    feature_size_m = (max(feature_size), max(feature_size))

    # viewpoint * focal
    camera = PerspectiveCameras(focal_length=1.0 * 3000 / down_sample_rate, principal_point=((feature_size_m[0] / 2, feature_size_m[1] / 2), ), image_size=(feature_size_m, ), device=device)
    # camera = OpenGLPerspectiveCameras(device=device, fov=12.0)
    raster_settings = RasterizationSettings(
        image_size=feature_size_m[0],
        blur_radius=0.0,
        faces_per_pixel=1,
        bin_size=0
    )
    rasterizer = MeshRasterizer(
        cameras=camera,
        raster_settings=raster_settings
    )
    # Renders per-vertex features (initially zeros; updated from the memory bank after each step).
    inter_module = MeshInterpolateModule(vertices=[verts] * n_hypo, faces=[faces] * n_hypo, memory_bank=[torch.zeros((verts.shape[0], feature_d))], rasterizer=rasterizer,
                                         post_process=center_crop_fun(feature_size, feature_size_m))

    inter_module = inter_module.cuda()
    net = torch.nn.DataParallel(net.cuda())

    # One memory entry per mesh vertex plus num_noise * max_group noise entries.
    memory_bank = NearestMemoryManager(inputSize=feature_d, outputSize=verts.shape[0] + num_noise * max_group,
                                       K=1, num_noise=num_noise, num_pos=verts.shape[0], momentum=0.95, )
    memory_bank.cuda()

    # NOTE(review): this image list path is hard-coded (car-specific) rather than derived from `cate`.
    dataset = Pascal3D(transform=standard_transforms, image_list='../KITTI/KITTI_train_distcrop/list/car_mostly_visible_f012.txt')
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    theta0 = torch.ones(1).to(device) * theta_base


    # --------------- Prepare Part dataset ----------------
    # Annotated subset used for the supervised "anno" loss.
    if use_anno_list is not None:
        annos_list = open(use_anno_list).readlines()
        annos_list = [t.split('.')[0].strip() for t in annos_list if len(t.strip()) > 0]
        print('num_anno:', len(annos_list))

        # Make sure every start image is annotated: drop as many entries from the tail as
        # there are missing start images and append those, keeping the list length.
        out_list = [k for k in image_start.keys() if k not in annos_list]
        annos_list = annos_list[0:len(annos_list) - len(out_list)] + out_list

        with torch.no_grad():
            useful_anno_keys = ['azimuth', 'elevation']
            out_anno = []
            for anno_name in annos_list:
                this_fl = np.load(os.path.join(anno_path, anno_name + '.npz'))
                out_anno.append({k:float(this_fl[k]) for k in useful_anno_keys})

                out_anno[-1]['name'] = anno_name

            # Labels rendered from the gt poses, aligned in order with gt_image_set.
            gt_label_set = get_gt_dataset(out_anno)
            gt_image_set = Pascal3D(img_path=img_path, anno_path=anno_path, transform=standard_transforms, image_list=annos_list)

            print(len(gt_image_set), len(gt_label_set))

            gt_label_loader = DataLoader(gt_label_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
            gt_image_loader = DataLoader(gt_image_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        
        # Optional pre-training on the annotated subset only.
        if extra_train_epochs > 0:
            optim = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)
            for epoch in range(extra_train_epochs):
                if (epoch + 1) % update_lr_epoch_n == 0:
                    # optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
                    for param_group in optim.param_groups:
                        param_group['lr'] *= update_lr_
                loss_anno = train_step(gt_image_loader, gt_label_loader, network=net,
                                       memory_bank=memory_bank, optimizer=optim)
                print('Train on anno: ', ' epoch:', epoch, ' loss_anno: ', np.mean(loss_anno))


    step = 0
    # ------------- Step 0 ----------------
    optim = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)
    start_imgs, start_annos = dataset.get_list_images(image_start.keys())

    start_cam_pos = camera_position_from_spherical_angles(distance=distance, azimuth=start_annos['azimuth'], elevation=start_annos['elevation'], degrees=False)
    start_cam_pos = start_cam_pos.type(torch.float32).to(device)

    # Double the start set with horizontally mirrored images/poses.
    start_imgs, start_cam_pos, start_annos = add_filpped(start_imgs, start_cam_pos, start_annos)
    # Pose ids of the start cameras (indices into mask_for_pose).
    _, start_idx = pose_manager.get_init(normalize(start_cam_pos, p=2, dim=1).cpu().numpy())

    # Get annotations for start point gt training
    _, start_vert, start_vis = do_render(inter_module=inter_module, cam_pos=start_cam_pos,
                                               theta=theta0.repeat_interleave(start_cam_pos.shape[0] // theta0.shape[0]),
                                               verts_mask=mask_for_pose[start_idx].to(device))

    print(start_cam_pos)

    with torch.no_grad():
        start_imgs = start_imgs.cuda()
        start_features = net(X=start_imgs, mode=0)


    # ---------------Hack for object mask on feature map
    # Downsample each start image's pose mask to feature resolution with max-pooling and
    # zero out features outside it.
    with torch.no_grad():
        this_masks = mask_for_pose[start_idx].type(torch.float32).to(device)
        net_stride = net.module.net_stride
        this_masks = torch.nn.functional.max_pool2d(this_masks.unsqueeze(dim=1), kernel_size=net_stride, stride=net_stride,
                                        padding=(net_stride - 1) // 2)
        start_features *= this_masks

    # ---------------End Hack

    # Score every dataset image against the masked start features and pick images per start pose.
    start_similarity, start_retrieved_annos = retrieve_similarity(start_features, dataloader, net)

    if gt_pose_mask is not None:
        start_similarity *= gt_mask(start_annos, start_retrieved_annos, gt_pose_mask).to(start_similarity.device)

    max_sel_idx = similarity_to_select(start_similarity, use_pose_num)
    # selected_idx: dataset image ids; selected_cam_pos: start-pose index each one was assigned to.
    selected_idx, selected_cam_pos = get_idx_sel(max_sel_idx)

    # NOTE(review): this repeats the render above with the same cameras, theta and masks.
    _, pro_vert, pro_vis = do_render(inter_module=inter_module, cam_pos=start_cam_pos, theta=theta0.repeat_interleave(start_cam_pos.shape[0] // theta0.shape[0]), verts_mask=mask_for_pose[torch.from_numpy(start_idx)].to(device))

    # Pseudo-labels for each selected image = labels of the pose it was assigned to.
    verts_label, vis_label, mask_label = pro_vert[selected_cam_pos], pro_vis[selected_cam_pos], mask_for_pose[torch.from_numpy(start_idx)].to(device)[selected_cam_pos]
    dataloader_train = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Restrict the dataset to the selected images, in the same order as the labels.
    dataset.enable_mask(selected_idx)
    label_loader = DataLoader(LabelLoader(verts_label, vis_label, mask_label), batch_size=batch_size, shuffle=False, num_workers=num_workers)

    for epoch in range(each_train_epoch):
        if (epoch + 1) % update_lr_epoch_n == 0:
            # optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
            for param_group in optim.param_groups:
                param_group['lr'] *= update_lr_
        loss_get = train_step(this_loader=dataloader_train, label_loader=label_loader, network=net,
                              memory_bank=memory_bank, optimizer=optim)
        # NOTE(review): `epoch % 8 < 8` is always true, so the else branch never runs and
        # the start images are trained on every epoch.
        if epoch % 8 < 8:
            loss_start = train_fixed(training_images=start_imgs, training_verts=start_vert, training_vis=start_vis,
                                     training_mask=mask_for_pose[start_idx].to(device), network=net,
                                     memory_bank=memory_bank, optimizer=optim)
            print('step:', step, ' epoch:', epoch, ' loss:', np.mean(loss_get), ' loss_start: ', loss_start)
        else:
            print('step:', step, ' epoch:', epoch, ' loss:', np.mean(loss_get))


        if use_anno_list is not None:
            loss_anno = train_step(gt_image_loader, gt_label_loader, network=net,
                          memory_bank=memory_bank, optimizer=optim)
            print('step:', step, ' epoch:', epoch, ' loss:', np.mean(loss_get), ' loss_anno: ', np.mean(loss_anno))


    # Copy the learned vertex features (first n_verts memory entries) into the renderer.
    inter_module.update_memory(memory_bank=memory_bank.memory[0:verts.shape[0]])

    draw_one_vis(start_annos, start_retrieved_annos, selected_idx, selected_cam_pos, img_dir + '/%s_%d_%s.png' % (cate, step, '%d'))

    dataset.disable_mask()

    step += 1
    # ------------- Step i ----------------
    for i in range(total_epochs):
        # Fresh optimizer each step; `lr` changes below only take effect from the next step.
        optim = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)

        # No grad on render
        with torch.set_grad_enabled(False):
            # Checkpoint every 20 steps.
            if (i + 1) % 20 == 0:
                os.makedirs(save_dir, exist_ok=True)
                saved_dict = dict(net=net.state_dict(), memory=memory_bank.memory, start=list(image_start.keys()))
                torch.save(saved_dict, os.path.join(save_dir, 'saved_model_%d_%d_ungt_%s.pth' % (each_train_epoch, step, cate, )))

            # NOTE(review): presumably "all poses of pose_manager have been visited" - confirm.
            # In that case restart from the start poses and reduce the lr.
            if pose_manager.if_maxed(0):
                cam_pos, pos_idx = start_cam_pos, start_idx
                _, _ = pose_manager.get_init(normalize(cam_pos, p=2, dim=1).cpu().numpy())
                lr *= 0.1

            else:
                cam_pos, pos_idx = next(pose_manager)

                # Add mirrored poses (cols 0 and 2 negated), then mirror the doubled set
                # again with only col 0 negated -> 4x the sampled poses.
                cam_pos, cam_pos_f = add_filpped_pose(cam_pos)
                pos_idx_f = pose_manager.get_idx(cam_pos_f)
                pos_idx = np.concatenate((pos_idx, pos_idx_f))

                cam_pos, cam_pos_f = add_filpped_pose(cam_pos, fp=(-1, 1))
                pos_idx_f = pose_manager.get_idx(cam_pos_f)
                pos_idx = np.concatenate((pos_idx, pos_idx_f))

                pos_idx = torch.from_numpy(pos_idx).type(torch.long)
                cam_pos = torch.from_numpy(cam_pos).to(device)

            # Angles of the current anchor poses (used for gt_mask and the plots).
            _, elev_ach, azum_ach = camera_position_to_spherical_angle(cam_pos)
            # print(elev_ach.shape)
            pseudo_annos = dict(azimuth=azum_ach.cpu().numpy(), elevation=elev_ach.cpu().numpy())

            # Rendered feature maps of the anchor poses serve as retrieval targets.
            pro_feature, pro_vert, pro_vis = do_render(inter_module=inter_module, cam_pos=cam_pos,
                                             theta=theta0.repeat_interleave(cam_pos.shape[0] // theta0.shape[0]),
                                             verts_mask=mask_for_pose[pos_idx].to(device))

            this_similarity, this_retrieved_annos = retrieve_similarity(pro_feature, dataloader, net)

            if gt_pose_mask is not None:
                this_similarity *= gt_mask(pseudo_annos, this_retrieved_annos, gt_pose_mask).to(
                    this_similarity.device)

            max_sel_idx = similarity_to_select(this_similarity, use_pose_num)
            selected_idx, selected_cam_pos = get_idx_sel(max_sel_idx)

            verts_label, vis_label, mask_label = pro_vert[selected_cam_pos], pro_vis[selected_cam_pos], \
                                                 mask_for_pose[pos_idx].to(device)[selected_cam_pos]
            dataloader_train = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

            dataset.enable_mask(selected_idx)
            label_loader = DataLoader(LabelLoader(verts_label, vis_label, mask_label), batch_size=batch_size,
                                      shuffle=False, num_workers=num_workers)

        # Same inner training loop as step 0.
        for epoch in range(each_train_epoch):
            if (epoch + 1) % update_lr_epoch_n == 0:
                # optim = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
                for param_group in optim.param_groups:
                    param_group['lr'] *= update_lr_

            loss_get = train_step(this_loader=dataloader_train, label_loader=label_loader, network=net,
                                  memory_bank=memory_bank, optimizer=optim)
            # NOTE(review): always true; the else branch is unreachable.
            if epoch % 8 < 8:
                loss_start = train_fixed(training_images=start_imgs, training_verts=start_vert, training_vis=start_vis,
                            training_mask=mask_for_pose[start_idx].to(device), network=net,  memory_bank=memory_bank, optimizer=optim)
                print('step:', step, ' epoch:', epoch, ' loss:', np.mean(loss_get), ' loss_start: ', loss_start)
            else:
                print('step:', step, ' epoch:', epoch, ' loss:', np.mean(loss_get))

            if use_anno_list is not None:
                loss_anno = train_step(gt_image_loader, gt_label_loader, network=net,
                              memory_bank=memory_bank, optimizer=optim)
                print('step:', step, ' epoch:', epoch, ' loss:', np.mean(loss_get), ' loss_anno: ', np.mean(loss_anno))


        inter_module.update_memory(memory_bank=memory_bank.memory[0:verts.shape[0]])

        draw_one_vis(pseudo_annos, this_retrieved_annos, selected_idx, selected_cam_pos,
                     img_dir + '/%s_%d_%s.png' % (cate, step, '%d'))

        dataset.disable_mask()
        step += 1


