import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from sklearn.neighbors import KDTree
import matplotlib.pyplot as plt
from PIL import Image
import json
import pytorch3d

from corr.models.base_model import BaseModel
from corr.models.feature_banks import mask_remove_near, remove_near_vertices_dist
from corr.models.mesh_deform_module import MeshDeformModule
from corr.models.solve_pose import pre_compute_kp_coords
from corr.models.solve_pose import solve_pose
from corr.models.deformable_solve_pose import get_pre_render_samples, loss_curve_part, batch_only_scale, part_initialization
from corr.models.deformable_solve_pose import solve_pose
from corr.models.deformable_solve_pose import solve_deform, part_deform
from corr.models.deformable_solve_pose import solve_part_whole as batch_solve_part_whole
from corr.utils import center_crop_fun
from corr.utils import construct_class_by_name
from corr.utils import get_param_samples
from corr.utils import normalize_features
from corr.utils import pose_error, iou, pre_process_mesh_pascal, load_off
from corr.utils.pascal3d_utils import IMAGE_SIZES

from corr.models.project_kp import PackedRaster, func_reselect, to_tensor

from Models.ShapeNeRS import HarmonicEmbedding, EncodeNetwork


def save_off(file_name, vertices, faces):
    """Write a triangle mesh to ``file_name`` as an ASCII OFF file.

    Args:
        file_name: Destination path; the file is created or overwritten.
        vertices: Rows of coordinates. Only the first three entries of each
            row are written, with 16 decimal places. May be a
            ``torch.Tensor`` or any array exposing ``.shape``.
        faces: Rows of vertex indices. Only the first three entries of each
            row are written, each row prefixed with the vertex count ``3``.
            May be a ``torch.Tensor`` or any array exposing ``.shape``.

    Returns:
        None.
    """
    # detach() first: calling .numpy() on a tensor that still requires grad
    # raises, and meshes coming out of an optimisation loop often do.
    if isinstance(vertices, torch.Tensor):
        vertices = vertices.detach().cpu().numpy()
    if isinstance(faces, torch.Tensor):
        faces = faces.detach().cpu().numpy()

    # Collect the lines and write them in one go instead of growing a single
    # string with repeated concatenation.
    lines = ['OFF\n', '%d %d 0\n' % (vertices.shape[0], faces.shape[0])]
    lines.extend('%.16f %.16f %.16f\n' % (v[0], v[1], v[2]) for v in vertices)
    lines.extend('3 %d %d %d\n' % (f[0], f[1], f[2]) for f in faces)
    with open(file_name, 'w') as fl:
        fl.writelines(lines)


class Corr(BaseModel):
    def __init__(
            self,
            cfg,
            cate,
            mode,
            backbone,
            memory_bank,
            num_noise,
            max_group,
            down_sample_rate,
            training,
            inference,
            proj_mode='runtime',
            checkpoint=None,
            transforms=None,
            device="cuda:0",
            **kwargs
    ):
        """Configure the model for one dataset/category and build its networks.

        Args:
            cfg: Global config; ``cfg.dataset``, ``cfg.training`` and
                ``cfg.inference`` are read here (``cfg.args.save_dir`` too
                when ``mode == 'test'``).
            cate: Object category name; selects the reference CAD id, part
                names, focal length and number of CAD models.
            mode: Forwarded to the base class. ``'test'`` additionally sets
                ``self.folder`` from the save directory.
            backbone: Keyword arguments used to construct the feature network.
            memory_bank: Keyword arguments used to construct the memory bank.
            num_noise: Stored as ``self.num_noise``.
            max_group: Stored as ``self.max_group``.
            down_sample_rate: Stored as ``self.down_sample_rate``.
            training: Training parameters; ``training.kp_projecter`` is
                updated in place with the chosen focal length.
            inference: Inference parameters.
            proj_mode: Accepted but not read in this method.
            checkpoint: Forwarded to the base class.
            transforms: Forwarded to the base class; ``None`` means an empty
                list.
            device: Forwarded to the base class.
            **kwargs: Accepted but not read in this method.

        Raises:
            NotImplementedError: For an unsupported dataset or category.
        """
        # Use a fresh list per instance rather than a shared mutable default.
        if transforms is None:
            transforms = []
        super().__init__(cfg, cate, mode, checkpoint, transforms, ['loss', 'loss_main', 'loss_reg'], device)
        self.net_params = backbone
        self.memory_bank_params = memory_bank
        self.num_noise = num_noise
        self.max_group = max_group
        self.down_sample_rate = down_sample_rate
        self.training_params = training
        self.inference_params = inference
        self.dataset_config = cfg.dataset
        self.accumulate_steps = 0
        self.cate = cate
        self.num_verts = 1024
        self.visual_kp = cfg.training.visual_kp
        self.visual_part = cfg.inference.visual_part
        self.visual_mesh = cfg.training.visual_mesh
        self.visual_pose = cfg.inference.visual_pose
        self.save_mesh = cfg.inference.save_mesh
        self.part_consistency = cfg.inference.part_consistency
        self.chamfer = cfg.inference.chamfer
        self.use_pred = cfg.inference.use_pred
        self.deform_model_path = f'../corr_shape/saved_models/{cate}_{cfg.inference.deform_model_name}.pth'
        if mode == 'test':
            # Output sub-folder name: last component of the save directory.
            self.folder = cfg.args.save_dir.split('/')[-1]

        # Per-dataset / per-category settings. MeshLoader is imported inside
        # the matching branch so only the selected dataset module is loaded.
        if self.dataset_config['dataset'] == 'imagenet':
            self.focal_length = 3200
            self.n_cad = 8
            from corr.datasets.imagenet_part import MeshLoader, PartsLoader
            if cate == 'car':
                self.chosen_id = '4d22bfe3097f63236436916a86a90ed7'
                self.index_id = '4d22bfe3097f63236436916a86a90ed7'
                self.part_names = ['body', 'front_left_wheel', 'front_right_wheel', 'back_left_wheel', 'back_right_wheel',
                                   'left_door', 'right_door']
            elif cate == 'aeroplane':
                # This category always uses externally supplied poses.
                self.use_pred = True
                self.chosen_id = '3cb63efff711cfc035fc197bbabcd5bd'
                self.index_id = '3cb63efff711cfc035fc197bbabcd5bd'
                self.part_names = ["body", "left_wheel", "right_wheel", "left_wing", "right_wing", "left_engine",
                                   "right_engine", "tail"]
            elif cate == 'boat':
                # NOTE(review): no part_names are set for this category, while
                # the inference build iterates self.part_names - confirm
                # whether 'boat' is expected to reach that path.
                self.chosen_id = '246335e0dfc3a0ea834ac3b5e36b95c'
                self.index_id = '246335e0dfc3a0ea834ac3b5e36b95c'
            elif cate == 'bicycle':
                self.n_cad = 2
                self.chosen_id = '9i2YcKNJpi'
                self.index_id = '91k7HKqdM9'
                self.part_names = ["body", "front_wheel", "back_wheel", "saddle"]
            else:
                raise NotImplementedError
            self.json_path = os.path.join(self.dataset_config['root_path'], 'index', self.cate, self.index_id,
                                          f'{self.cate}_part_indexes.json')
        elif self.dataset_config['dataset'] == 'vehicle':
            from corr.datasets.vehicle_part import MeshLoader, PartsLoader
            if cate == 'airliner':
                self.chosen_id = '22831bc32bd744d3f06dea205edf9704'
                self.part_names = ['right_engine', 'left_engine', 'fuselarge', 'horizontal_stabilizer',
                                   'vertical_stabilizer', 'wheel', 'right_wing', 'left_wing']
                self.focal_length = 800
                self.n_cad = 8
                self.dataset_config['mesh_type'] = 'ply'
            elif cate == 'police':
                self.chosen_id = '372ceb40210589f8f500cc506a763c18'
                self.part_names = ['wheel', 'door', 'back_trunk', 'frame']
                self.n_cad = 8
                self.focal_length = 660
                self.dataset_config['mesh_type'] = 'ply'
            elif cate == 'bicycle':
                # NOTE(review): neither part_names nor n_cad is set for this
                # category - confirm the intended values before using it for
                # inference.
                self.chosen_id = '91k7HKqdM9'
                self.focal_length = 660
                self.dataset_config['mesh_type'] = '!ply'
            elif cate == 'jeep':
                self.chosen_id = '178f22467bae4c729bdcc15dbc7e445d'
                self.part_names = ['front_left_wheel', 'front_right_wheel', 'back_left_wheel', 'back_right_wheel',
                                   'door', 'back_trunk', 'frame']
                self.n_cad = 4
                self.focal_length = 660
                self.dataset_config['mesh_type'] = '!ply'
            elif cate == 'minibus':
                self.chosen_id = '152f62800be34652af0545487129ca2e'
                self.part_names = ['front_left_wheel', 'front_right_wheel', 'back_left_wheel', 'back_right_wheel',
                                   'left_door', 'right_door', 'frame']
                self.n_cad = 4
                self.focal_length = 660
                self.dataset_config['mesh_type'] = 'obj'
            else:
                raise NotImplementedError

            self.json_path = os.path.join(self.dataset_config['root_path'], 'DST_part3d', 'part', f'{cate}_part_indexes.json')
        elif self.dataset_config['dataset'] == 'uda':
            self.focal_length = 3200
            self.n_cad = 8
            from corr.datasets.uda_part import MeshLoader, PartsLoader
            if cate == 'car':
                self.chosen_id = '4d22bfe3097f63236436916a86a90ed7'
                self.part_names = ['body', 'front_right_wheel', 'back_right_wheel', 'back_left_wheel',
                                   'front_left_wheel', 'left_mirror', 'right_mirror',
                                   'left_door', 'right_door']
            elif cate == 'aeroplane':
                # This category always uses externally supplied poses.
                self.use_pred = True
                self.chosen_id = '3cb63efff711cfc035fc197bbabcd5bd'
                self.part_names = ['body', 'left_wing', 'right_wing', 'left_engine', 'right_engine', 'tail']
            elif cate == 'bicycle':
                self.chosen_id = '91k7HKqdM9'
                self.part_names = ["body", "front_wheel", "back_wheel", "saddle"]
            else:
                raise NotImplementedError

            self.json_path = os.path.join(self.dataset_config['root_path'], 'index', self.cate, self.chosen_id,
                                          f'uda_{self.cate}_part_indexes.json')

        else:
            raise NotImplementedError

        # Test mode overrides whatever focal length the category selected.
        # NOTE(review): self.mode is not assigned in this method; presumably
        # the base class sets it from `mode` - confirm.
        if self.mode == "test":
            self.focal_length = 3200

        self.mesh_loader = MeshLoader(self.dataset_config, cate=cate)
        self.anno_parts = self.mesh_loader.anno_parts
        self.build()

        # The keypoint projector must use the focal length chosen above, so
        # it is written into the shared training config before it is copied.
        self.training_params.kp_projecter['focal_length'] = self.focal_length
        self.raster_conf = {
            'image_size': (self.dataset_config.image_sizes, self.dataset_config.image_sizes),
            **self.training_params.kp_projecter
        }
        # A down_rate of -1 means "use the stride of the feature network".
        if self.raster_conf['down_rate'] == -1:
            self.raster_conf['down_rate'] = self.net.module.net_stride
        self.net.module.kwargs['n_vert'] = self.num_verts

        self.projector = PackedRaster(self.raster_conf, self.mesh_loader.get_mesh_listed(), device='cuda')

    def build(self):
        if self.mode == "train":
            self._build_train()
        else:
            self._build_inference()

    def _build_train(self):
        """Create the network, memory bank, optimizer and scheduler for training.

        When ``training_params.separate_bank`` is set, the last visible GPU is
        reserved for the memory bank and the network is data-parallel over the
        remaining ones; otherwise both use the default CUDA placement.
        """
        self.n_gpus = torch.cuda.device_count()
        use_separate_bank = self.training_params.separate_bank
        self.ext_gpu = f"cuda:{self.n_gpus - 1}" if use_separate_bank else ""

        # Feature network, wrapped for multi-GPU execution.
        backbone = construct_class_by_name(**self.net_params)
        if use_separate_bank:
            parallel_net = nn.DataParallel(backbone, device_ids=list(range(self.n_gpus - 1)))
        else:
            parallel_net = nn.DataParallel(backbone)
        self.net = parallel_net.cuda()

        # Memory bank: one slot per vertex plus the noise (clutter) slots.
        bank = construct_class_by_name(
            **self.memory_bank_params,
            output_size=self.num_verts + self.num_noise * self.max_group,
            num_pos=self.num_verts,
            num_noise=self.num_noise)
        self.memory_bank = bank.cuda(self.ext_gpu) if use_separate_bank else bank.cuda()

        self.optim = construct_class_by_name(
            **self.training_params.optimizer, params=self.net.parameters())
        self.scheduler = construct_class_by_name(
            **self.training_params.scheduler, optimizer=self.optim)

    def step_scheduler(self):
        self.scheduler.step()
        if self.training_params.kp_projecter.type == 'voge' or self.training_params.kp_projecter.type == 'vogew':
            self.projector.step()

    def train(self, sample):
        """Run one training iteration on ``sample`` and return its losses.

        The batch is transformed, keypoints are projected from the annotated
        pose, features are extracted at those keypoints and scored against the
        memory bank. The loss is a cross-entropy over visible keypoints, plus a
        noise regulariser when ``num_noise > 0``. Gradients are accumulated
        and the optimizer steps every ``training_params.train_accumulate``
        calls.

        Args:
            sample: Batch dict; reads ``img``, ``obj_mask``, ``instance_id``,
                ``azimuth``, ``elevation``, ``distance`` and ``theta``.

        Returns:
            Dict with float entries ``loss``, ``loss_main`` and ``loss_reg``.
        """
        self.net.train()
        sample = self.transforms(sample)

        img = sample['img'].cuda()
        obj_mask = sample["obj_mask"].cuda()
        batch_size = img.shape[0]

        # One row of vertex ids 0..num_verts-1 per image in the batch.
        index = torch.Tensor([list(range(self.num_verts))] * batch_size).cuda()
        mesh_index = [self.mesh_loader.mesh_name_dict[name] for name in sample['instance_id']]
        get_mesh_index = self.mesh_loader.get_index_list(mesh_index).cuda()

        # Project the mesh vertices with the annotated pose, then keep only
        # the vertices selected for each mesh.
        with torch.no_grad():
            kp, kpvis = self.projector(
                azim=sample['azimuth'].float().cuda(),
                elev=sample['elevation'].float().cuda(),
                dist=sample['distance'].float().cuda(),
                theta=sample['theta'].float().cuda(),
                func_of_mesh=func_reselect,
                indexs=mesh_index,
            )
            kp = torch.gather(kp, dim=1, index=get_mesh_index[..., None].expand(-1, -1, 2))
            kpvis = torch.gather(kpvis, dim=1, index=get_mesh_index)

        # Debugging aid: drawing kp (coloured by kpvis) onto sample['img_ori']
        # shows whether the projected keypoints line up with the image.

        features = self.net.forward(img, keypoint_positions=kp, obj_mask=1 - obj_mask, do_normalize=True)

        # The memory bank may live on its own GPU; move its inputs there.
        if self.training_params.separate_bank:
            bank_inputs = (
                features.to(self.ext_gpu),
                index.to(self.ext_gpu),
                kpvis.to(self.ext_gpu),
            )
        else:
            bank_inputs = (features, index, kpvis)
        logits, y_idx, noise_sim = self.memory_bank(*bank_inputs)

        # VoGE-type rasterizers give a visibility score; the bank receives the
        # raw score above, and it is thresholded here before use as a mask.
        if 'voge' in self.projector.raster_type:
            kpvis = kpvis > self.projector.kp_vis_thr

        logits /= self.training_params.T

        kappas = {
            'pos': self.training_params.get('weight_pos', 0),
            'near': self.training_params.get('weight_near', 1e5),
            'clutter': -math.log(self.training_params.weight_noise),
        }
        num_neg = self.num_noise * self.max_group

        if self.training_params.remove_near_mode == 'vert':
            # Default in VoGE-NeMo: distances are measured between 3D vertices.
            with torch.no_grad():
                all_verts = func_reselect(self.projector.meshes, mesh_index)[1]
                picked_verts = torch.gather(
                    all_verts, dim=1, index=get_mesh_index[..., None].expand(-1, -1, 3))

                pairwise_dist = (
                    picked_verts.unsqueeze(1) - picked_verts.unsqueeze(2)
                ).pow(2).sum(-1).pow(.5)

                mask_distance_legal = remove_near_vertices_dist(
                    pairwise_dist,
                    thr=self.training_params.distance_thr,
                    num_neg=num_neg,
                    kappas=kappas,
                )
                if mask_distance_legal.shape[0] != logits.shape[0]:
                    mask_distance_legal = mask_distance_legal.expand(logits.shape[0], -1, -1).contiguous()
        else:
            # Default in the original NeMo: distances are measured between
            # projected 2D keypoints.
            mask_distance_legal = mask_remove_near(
                kp,
                thr=self.training_params.distance_thr
                    * torch.ones((batch_size,), dtype=torch.float32).cuda(),
                num_neg=num_neg,
                dtype_template=logits,
                kappas=kappas,
            )

        # Cross-entropy over visible keypoints only, after subtracting the mask.
        n_cols = logits.shape[2]
        visible = kpvis.view(-1)
        masked_logits = logits.view(-1, n_cols) - mask_distance_legal.view(-1, n_cols)
        per_keypoint_loss = F.cross_entropy(
            masked_logits[visible, :],
            y_idx.view(-1)[visible],
            reduction="none",
        )
        loss_main = torch.mean(per_keypoint_loss)

        if self.num_noise > 0:
            loss_reg = torch.mean(noise_sim) * self.training_params.loss_reg_weight
            loss = loss_main + loss_reg
        else:
            loss_reg = torch.zeros(1)
            loss = loss_main

        loss.backward()

        # Gradient accumulation: step and reset every `train_accumulate` calls.
        self.accumulate_steps += 1
        if self.accumulate_steps % self.training_params.train_accumulate == 0:
            self.optim.step()
            self.optim.zero_grad()

        scalars = {
            'loss': loss.item(),
            'loss_main': loss_main.item(),
            'loss_reg': loss_reg.item(),
        }
        for key in ('loss', 'loss_main', 'loss_reg'):
            self.loss_trackers[key].append(scalars[key])

        return scalars

    def _build_inference(self):
        """Load the checkpoint and build everything ``evaluate`` needs.

        Steps, in order:
          1. Build the network and memory bank and restore them from
             ``self.checkpoint``; split the saved memory into the vertex
             feature bank and a single normalised clutter feature.
          2. Stop here when ``cfg.task == 'correlation_marking'``.
          3. Build the cameras and rasterizer at feature-map resolution.
          4. Interpolate a feature for every vertex of the chosen mesh from
             its three nearest sampled vertices and wrap mesh + features in
             a ``MeshDeformModule``.
          5. Load the deformation network from ``self.deform_model_path``.
          6. Split the chosen mesh's faces into per-part face lists.
          7. Pre-render pose samples (``'batch'`` init modes) or pre-compute
             keypoint coordinates (other init modes).
        """
        self.net = construct_class_by_name(**self.net_params)
        self.net = nn.DataParallel(self.net).to(self.device)
        self.net.load_state_dict(self.checkpoint["state"])

        self.memory_bank = construct_class_by_name(
            **self.memory_bank_params,
            output_size=self.num_verts,
            num_pos=self.num_verts,
            num_noise=0
        ).to(self.device)

        # The first `bank_size` rows of the saved memory fill the vertex
        # feature bank; every remaining row is treated as clutter.
        bank_size = self.memory_bank.memory.shape[0]
        saved_memory = self.checkpoint["memory"]
        with torch.no_grad():
            self.memory_bank.memory.copy_(saved_memory[0:bank_size])
        memory = saved_memory[0:bank_size].detach().cpu().numpy()
        clutter = saved_memory[bank_size:].detach().cpu().numpy()
        self.feature_bank = torch.from_numpy(memory)
        self.clutter_bank = torch.from_numpy(clutter).to(self.device)
        # Collapse all clutter rows into one normalised feature of shape (1, C).
        self.clutter_bank = normalize_features(
            torch.mean(self.clutter_bank, dim=0)
        ).unsqueeze(0)
        self.kp_features = saved_memory[0:bank_size].to(self.device)

        if self.cfg.task == 'correlation_marking':
            return

        # Rendering happens at feature-map resolution.
        image_h, image_w = (self.dataset_config.image_sizes, self.dataset_config.image_sizes)
        map_shape = (image_h // self.down_sample_rate, image_w // self.down_sample_rate)

        # A value of -1 in the camera config means "derive it from map_shape".
        if self.inference_params.cameras.get('image_size', 0) == -1:
            self.inference_params.cameras['image_size'] = (map_shape, )
        if self.inference_params.cameras.get('principal_point', 0) == -1:
            self.inference_params.cameras['principal_point'] = ((map_shape[1] // 2, map_shape[0] // 2), )
            print('principal_point: ', self.inference_params.cameras['principal_point'])
        self.inference_params.cameras['focal_length'] = self.focal_length / self.down_sample_rate
        print('focal_length: ', self.inference_params.cameras['focal_length'])

        cameras = construct_class_by_name(**self.inference_params.cameras, device=self.device)
        raster_settings = construct_class_by_name(
            **self.inference_params.raster_settings, image_size=map_shape
        )
        # The VoGE renderer names its settings argument differently.
        if self.inference_params.rasterizer.class_name == 'VoGE.Renderer.GaussianRenderer':
            self.rasterizer = construct_class_by_name(
                **self.inference_params.rasterizer, cameras=cameras, render_settings=raster_settings
            )
        else:
            self.rasterizer = construct_class_by_name(
                **self.inference_params.rasterizer, cameras=cameras, raster_settings=raster_settings
            )

        # Only the single reference mesh `chosen_id` is used for inference.
        chosen_verts, chosen_faces = self.mesh_loader.get_meshes(self.chosen_id)

        self.chosen_verts = chosen_verts
        self.chosen_faces = chosen_faces

        xvert = [chosen_verts]
        xface = [chosen_faces]
        mesh_index = [self.mesh_loader.mesh_name_dict[self.chosen_id]]
        get_mesh_index = self.mesh_loader.get_index_list(mesh_index).cuda()

        # Only the sampled vertices have a learned feature. Every mesh vertex
        # gets a feature blended from its 3 nearest sampled vertices, weighted
        # by softmax(1 / distance); 1e-4 avoids dividing by zero on exact hits.
        sample_index = get_mesh_index[0]
        vertex = xvert[0]
        sample_vertex = vertex[sample_index]
        kdtree = KDTree(sample_vertex)
        dist, nearest_idx = kdtree.query(vertex, k=3)
        dist = torch.from_numpy(dist).to(self.feature_bank.device) + 1e-4
        dist = dist.type(torch.float32)
        weight = torch.softmax(1 / dist, dim=1)
        # One indexed gather, (V, 3, C), instead of a Python loop over vertices.
        nearest_feature = self.feature_bank[torch.as_tensor(nearest_idx, dtype=torch.long)]
        feature = (nearest_feature * weight.unsqueeze(-1)).sum(dim=1)
        feature = feature / torch.norm(feature, dim=1, keepdim=True)
        self.chosen_feature = feature

        self.inter_module = MeshDeformModule(
            xvert,
            xface,
            [feature],
            rasterizer=self.rasterizer,
        ).to(self.device)

        self.deform_encoder = HarmonicEmbedding().to(self.device)

        # The saved file holds the whole module, so it is loaded directly onto
        # the inference device rather than building a network first.
        # NOTE(review): torch.load unpickles the file - only load trusted
        # checkpoints. Newer PyTorch releases may also need weights_only=False
        # to load a full module; confirm against the installed version.
        self.deform_net = torch.load(self.deform_model_path, map_location=self.device)
        self.deform_net.eval()

        # Per-part vertex indices of the chosen mesh, keyed by part name.
        with open(self.json_path) as index_file:
            self.part_indexes = json.load(index_file)

        # Read every face once as plain ints. Rows obtained by iterating an
        # array/tensor can alias `self.chosen_faces`, so they must never be
        # written to: doing so would corrupt the full mesh and leak one part's
        # local indices into the membership test of the next part.
        face_triples = [tuple(int(v) for v in face) for face in self.chosen_faces]

        self.parts_faces = []
        for part_id, part_name in enumerate(self.part_names):
            print('divide part, part_id: ', part_id)
            part_index = self.part_indexes[part_name]
            # Global vertex id -> position of that vertex inside the part.
            global_to_local = {vert_id: local_id for local_id, vert_id in enumerate(part_index)}

            # Keep the faces whose three corners all belong to the part,
            # re-indexed into the part's local vertex numbering.
            part_faces = [
                [global_to_local[v1], global_to_local[v2], global_to_local[v3]]
                for v1, v2, v3 in face_triples
                if v1 in global_to_local and v2 in global_to_local and v3 in global_to_local
            ]
            # reshape keeps an (0, 3) shape even when a part has no face.
            part_faces = np.array(part_faces, dtype=np.int32).reshape(-1, 3)
            part_faces = torch.from_numpy(part_faces)
            self.parts_faces.append(part_faces)
            print('part_faces: ', part_faces.shape)

        (azimuth_samples,
         elevation_samples,
         theta_samples,
         distance_samples,
         px_samples,
         py_samples,
         ) = get_param_samples(self.cfg)

        self.init_mode = self.cfg.inference.get('init_mode', '3d_batch')

        if 'batch' in self.init_mode:
            self.feature_pre_rendered, self.cam_pos_pre_rendered, self.theta_pre_rendered = get_pre_render_samples(
                self.inter_module,
                azum_samples=azimuth_samples,
                elev_samples=elevation_samples,
                theta_samples=theta_samples,
                distance_samples=distance_samples,
                device=self.device
            )
            # Only the first sampled distance is recorded for later use.
            self.record_distance = distance_samples[0]

        else:
            # NOTE(review): self.mesh_path is not assigned anywhere in the
            # visible part of this class - presumably provided by the base
            # class; confirm before relying on this branch.
            self.poses, self.kp_coords, self.kp_vis = pre_compute_kp_coords(
                self.mesh_path,
                azimuth_samples=azimuth_samples,
                elevation_samples=elevation_samples,
                theta_samples=theta_samples,
                distance_samples=distance_samples,
            )

        print('deformable_correspondence build inference done')

    def evaluate(self, sample, debug=False):
        self.net.eval()

        sample = self.transforms(sample)
        img = sample['img'].cuda()
        print('img: ', img.shape)
        if img.shape[2] != 512 or img.shape[3] != 512:
            print('wrong img shape')
            print('img: ', img.shape)
            return None

        mesh_index = [0] * img.shape[0]
        vis_mesh_index = [self.mesh_loader.mesh_name_dict[self.chosen_id]] * img.shape[0]

        kwargs_ = dict(indexs=mesh_index)

        with torch.no_grad():
            feature_map = self.net.module.forward_test(img)

        # pose
        if self.use_pred and sample['pose_pred'] == 0:
            return None
        pose_pred = []
        dof = int(self.init_mode.split('d_')[0])
        if not self.use_pred:
            # print('sample distance: ', sample['distance'])

            preds = solve_pose(
                self.cfg,
                feature_map,
                self.inter_module,
                self.clutter_bank,
                cam_pos_pre_rendered=self.cam_pos_pre_rendered,
                theta_pre_rendered=self.theta_pre_rendered,
                feature_pre_rendered=self.feature_pre_rendered,
                device=self.device,
                principal=None,
                distance_source=sample['distance'].to(feature_map.device) if dof == 3 else None,
                distance_target=self.record_distance * torch.ones(feature_map.shape[0]).to(
                    feature_map.device) if dof == 3 else torch.ones(feature_map.shape[0]).to(feature_map.device),
                pre_render=self.cfg.inference.get('pre_render', True),
                dof=dof,
                **kwargs_
            )
            if isinstance(preds, dict):
                preds = [preds]

            for i, pred in enumerate(preds):
                pose_pred.append(pred["final"][0])
                if "azimuth" in sample and "elevation" in sample and "theta" in sample:
                    pred["pose_error"] = pose_error({k: sample[k][i] for k in ["azimuth", "elevation", "theta"]},
                                                    pred["final"][0])
                    print('pose_error: ', pred["pose_error"])
                else:
                    pred["pose_error"] = np.random.rand()

            if self.visual_pose:
                for idx in range(len(img)):
                    self.projector.visual_pose(vis_mesh_index[idx], sample['img_ori'][idx], preds[idx]["final"][0],
                                               self.folder, sample["name"][idx])

            distances = torch.from_numpy(np.array([pred["final"][0]['distance'] for pred in preds])).cuda()
            elevations = torch.from_numpy(np.array([pred["final"][0]['elevation'] for pred in preds])).cuda()
            azimuths = torch.from_numpy(np.array([pred["final"][0]['azimuth'] for pred in preds])).cuda()
            thetas = torch.from_numpy(np.array([pred["final"][0]['theta'] for pred in preds])).cuda()
        else:
            preds = sample['pose_pred']
            if isinstance(preds, dict):
                preds = [preds]

            pose_pred = preds

            if self.visual_pose:
                for idx in range(len(img)):
                    self.projector.visual_pose(vis_mesh_index[idx], sample['img_ori'][idx], preds[idx],
                                               self.folder, sample["name"][idx])

            distances = torch.from_numpy(np.array([pred['distance'] for pred in preds])).cuda()
            elevations = torch.from_numpy(np.array([pred['elevation'] for pred in preds])).cuda()
            azimuths = torch.from_numpy(np.array([pred['azimuth'] for pred in preds])).cuda()
            thetas = torch.from_numpy(np.array([pred['theta'] for pred in preds])).cuda()

        initial_pose = dict(
            distance=distances,
            elevation=elevations,
            azimuth=azimuths,
            theta=thetas,
        )

        # deform
        deformation, latent = solve_deform(
            self.cfg,
            feature_map,
            self.inter_module,
            self.clutter_bank,
            self.deform_net,
            self.deform_encoder,
            self.n_cad,
            initial_pose,
            device=self.device,
        )

        deformation = deformation.detach().cpu()
        verts = self.chosen_verts + deformation * self.cfg.inference.defrom_weight
        faces = self.chosen_faces

        # save mesh as .obj
        if self.save_mesh:
            saved_path = './visual/Mesh/' + self.folder
            if not os.path.exists(saved_path):
                os.makedirs(saved_path)

            save_off(os.path.join(saved_path, sample["name"][0] + '.off'), verts, faces)

            torch.save(dict(deform_verts=verts, deform_faces=faces, latent=latent),
                       os.path.join(saved_path, sample["name"][0] + '_deform.pth'))
            # for idx in range(len(img)):
            #     mesh_path = os.path.join(saved_path, sample["name"][idx] + '.obj')
            #     pytorch3d.io.save_obj(mesh_path, self.chosen_verts, faces)
            #
            # thefile = open(os.path.join(saved_path, sample["name"][0] + '.obj'), 'w')
            # for item in verts:
            #     thefile.write("v {0} {1} {2}\n".format(item[0], item[1], item[2]))
            #
            # save_faces = faces + 1
            #
            # for item in save_faces:
            #     thefile.write("f {0}//{0} {1}//{1} {2}//{2}\n".format(item[0], item[1], item[2]))
            #
            # thefile.close()

            # save_vertices = self.chosen_verts.numpy()
            # save_faces = faces.numpy()
            # thefile = os.path.join(saved_path, sample["name"][0] + '.off')
            # out_string = 'OFF\n'
            # out_string += '%d %d 0\n' % (save_vertices.shape[0], save_faces.shape[0])
            # for v in save_vertices:
            #     out_string += '%.16f %.16f %.16f\n' % (v[0], v[1], v[2])
            # for f in save_faces:
            #     out_string += '3 %d %d %d\n' % (f[0], f[1], f[2])
            # with open(thefile, 'w') as fl:
            #     fl.write(out_string)

        # if self.visual_pose:
        #     for idx in range(len(img)):
        #         self.projector.visual_pose(0, sample['img_ori'][idx], preds[idx]["final"][0],
        #                                    self.folder, sample["name"][idx], vert=verts, face=faces)

        # divide parts
        parts_name = []
        parts_xvert = []
        parts_xface = []
        parts_feature = []
        parts_offset = []
        chosen_scales = []
        chosen_offsets = []
        offset = None
        for part_id, part_name in enumerate(self.part_names):
            part_index = self.part_indexes[part_name]
            part_vert = verts[part_index]
            part_face = self.parts_faces[part_id]

            part_feature = self.chosen_feature[part_index]
            # print('part_name: ', part_name, 'part_vert: ', part_vert.shape, 'part_face: ', part_face.shape,
            #       'part_feature: ', part_feature.shape, 'part_face_max: ', part_face.max(), 'part_face_min: ', part_face.min())

            part_offsets = (part_vert.max(axis=0)[0] + part_vert.min(axis=0)[0]) / 2
            part_vert = part_vert - part_offsets

            part_inter_module = MeshDeformModule(
                [part_vert],
                [part_face],
                features=[part_feature],
                rasterizer=self.rasterizer,
            ).to(self.device)

            if self.cfg.inference.part_initialization is True:
                loss, offset, scale, score = part_initialization(
                    self.cfg,
                    feature_map,
                    self.clutter_bank,
                    part_inter_module,
                    [part_feature],
                    initial_pose,
                    part_offsets.numpy(),
                )
            else:
                loss, scale, score = batch_only_scale(
                    self.cfg,
                    feature_map,
                    self.clutter_bank,
                    part_inter_module,
                    [part_feature],
                    initial_pose,
                    part_offsets.numpy(),
                )

            if score > 0:
                # part deform
                deform_loss, part_deformation = part_deform(
                    self.cfg,
                    feature_map,
                    part_inter_module,
                    self.clutter_bank,
                    self.deform_net,
                    self.deform_encoder,
                    self.n_cad,
                    initial_pose,
                    part_offsets.numpy()
                )

                if deform_loss < loss:
                    part_deformation = part_deformation.detach().cpu()
                    part_vert = part_vert + part_deformation * self.cfg.inference.part_defrom_weight

                parts_xvert.append(part_vert)
                parts_xface.append(part_face)
                parts_feature.append(part_feature)
                parts_offset.append(part_offsets)
                chosen_scales.append(scale)
                parts_name.append(part_name)

                if offset is not None:
                    chosen_offsets.append(offset[0])
            else:
                print('no part ', part_name)

        kwargs_ = dict(chosen_scales=chosen_scales, chosen_offsets=chosen_offsets)

        if len(parts_name) > 0:
            if self.part_consistency is True:
                near_pairs = []
                index_offset_1 = 0
                for part_id1, part1_vert in enumerate(parts_xvert):
                    kdtree = KDTree(part1_vert)

                    if part_id1 > 0:
                        index_offset_1 += len(parts_xvert[part_id1 - 1])
                    index_offset_2 = index_offset_1
                    for part_id2, part2_vert in enumerate(parts_xvert):
                        if part_id2 <= part_id1:
                            continue
                        index_offset_2 += len(parts_xvert[part_id2 - 1])

                        dist, nearest_idx = kdtree.query(part2_vert, k=1)
                        nearest = np.argwhere(dist < self.cfg.inference.dis_threshold)
                        for idx in nearest:
                            self.near_pairs.append((nearest_idx[idx] + index_offset_1, idx + index_offset_2))

                print('near_pairs: ', len(near_pairs))
                kwargs_['near_pairs'] = near_pairs

            parts_inter_module = MeshDeformModule(
                parts_xvert,
                parts_xface,
                parts_feature,
                rasterizer=self.rasterizer,
            ).to(self.device)

            part_preds = batch_solve_part_whole(
                self.cfg,
                feature_map,
                self.clutter_bank,
                parts_inter_module,
                parts_feature,
                initial_pose,
                parts_offset,
                **kwargs_
            )

            parts_xvert = [part_vert.numpy() for part_vert in parts_xvert]
            parts_xface = [part_face.numpy() for part_face in parts_xface]

            save_path = './result/' + self.folder
            if os.path.exists(save_path) is False:
                os.makedirs(save_path)
            # careful !! original image shape different
            # vis_img = sample['img_ori'].numpy()[0]
            # vis_img = vis_img.transpose(1, 2, 0) * 255
            torch.save(dict(verts=parts_xvert, faces=parts_xface, pred_pose=pose_pred[0],
                            pred_part=part_preds[0]["final"]), os.path.join(save_path, sample["name"][0] + '.pth'))

            if self.visual_part:
                # careful !! original image shape different
                vis_img = sample['img_ori'].numpy()[0]
                # print('vis_img: ', vis_img.shape)
                # print('vis_img: ', vis_img.transpose(1, 2, 0).shape)
                self.projector.visual_part_pose(parts_xvert, parts_xface, vis_img * 255,
                                                pose_pred[0], part_preds[0]["final"], self.folder,
                                                sample["name"][0])

            segment = self.projector.get_segment_depth(parts_xvert, parts_xface, pose_pred[0], part_preds[0]["final"])
        else:
            segment = np.ones((512, 512)) * len(parts_name)
        annotations = sample['seg']
        vis_imgs = sample['img_ori'].numpy()
        # print('vis_imgs: ', vis_imgs.shape)
        # print('annotations: ', annotations.shape)
        iou_dict = dict()
        # print('segment: ', segment.min(), segment.max())
        # careful !! axis different for different dataset
        if vis_imgs[0].shape[0] == 3:
            img_mask = np.sum(vis_imgs[0], axis=0) > 0
        elif vis_imgs[0].shape[2] == 3:
            img_mask = np.sum(vis_imgs[0], axis=2) > 0
        else:
            print('wrong ori_img shape')
            exit(0)
        anno = annotations[0].type(torch.int32)
        anno_compare = anno[img_mask]

        # print('anno: ', anno.min(), anno.max())
        if anno.max() != len(self.anno_parts):
            print('not aligned')
            exit(0)
        seg = torch.zeros_like(anno) + len(self.anno_parts)

        total_intersection = 0
        total_union = 0
        intersections = []
        unions = []
        for anno_id, name in enumerate(self.anno_parts):
            for part_id, part_name in enumerate(parts_name):
                if name in part_name:
                    seg[segment == part_id] = anno_id

            seg_compare = seg[img_mask]

            intersection = ((seg_compare == anno_id) & (anno_compare == anno_id)).sum()
            union = ((seg_compare == anno_id) | (anno_compare == anno_id)).sum()
            iou = intersection / union
            iou_dict[name] = iou
            total_intersection += intersection
            total_union += union
            intersections.append(intersection)
            unions.append(union)

        seg_compare = seg[img_mask]
        bg_intersection = ((seg_compare == len(self.anno_parts)) & (anno_compare == len(self.anno_parts))).sum()
        bg_union = ((seg_compare == len(self.anno_parts)) | (anno_compare == len(self.anno_parts))).sum()
        total_intersection += bg_intersection
        total_union += bg_union
        bg_iou = bg_intersection / bg_union
        intersections.append(bg_intersection)
        unions.append(bg_union)
        iou_dict['bg'] = bg_iou

        miou = total_intersection / total_union
        iou_dict['mIoU'] = miou

        iou_dict['intersections'] = intersections
        iou_dict['unions'] = unions

        save_path = './visual/segment/' + self.folder
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        vis_get = seg.clone() / seg.max() * 255
        vis_get = vis_get.detach().cpu().numpy().astype(np.uint8)
        vis_get = Image.fromarray(vis_get)
        vis_get.save(f'{save_path}/{sample["name"][0]}_{miou}_seg.jpg')

        # print('vis_imgs[0]: ', vis_imgs[0].shape)
        # vis_img = vis_imgs[0].transpose(1, 2, 0) * 255
        # vis_img = Image.fromarray(vis_img)
        # vis_img.save(f'{save_path}/{sample["name"][0]}_ori.png')

        vis_anno = anno.clone() / anno.max() * 255
        vis_anno = vis_anno.numpy().astype(np.uint8)
        vis_anno = Image.fromarray(vis_anno)
        vis_anno.save(f'{save_path}/{sample["name"][0]}_{miou}_gt.png')

        return iou_dict

    def get_ckpt(self, **kwargs):
        """Assemble a checkpoint dictionary for the current training state.

        The checkpoint holds the network ``state_dict`` under ``'state'``, the
        memory bank contents under ``'memory'`` and the learning rate of the
        first optimizer parameter group under ``'lr'``.

        Args:
            **kwargs: Extra entries to store in the checkpoint. An entry whose
                key collides with one of the built-in keys replaces it.

        Returns:
            dict: The checkpoint mapping.
        """
        checkpoint = {
            'state': self.net.state_dict(),
            'memory': self.memory_bank.memory,
            'lr': self.optim.param_groups[0]['lr'],
        }
        # Caller-supplied entries are merged last so they take precedence.
        checkpoint.update(kwargs)
        return checkpoint

    def predict_inmodal(self, sample, visualize=False):
        """Predict a foreground mask by comparing keypoint and clutter scores.

        The network is put in eval mode and run on ``sample["img"]``. For every
        feature-map location the best keypoint response (max over
        ``self.kp_features``) is compared against the best clutter response
        (max over all clutter banks); locations where the keypoint score wins
        form the predicted mask.

        Args:
            sample: Batch dictionary. ``sample["img"]`` is required and must
                have batch size 1. When ``"inmodal_mask"`` is present,
                ``sample["inmodal_mask"]`` and ``sample["amodal_mask"]`` are
                read as well and an IoU is reported.
            visualize: Unused; kept so existing callers keep working.

        Returns:
            dict: Always contains ``'clutter_score'``, ``'kp_score'``,
            ``'pred_mask_orig'`` (feature-map resolution) and ``'pred_mask'``
            (upsampled by ``self.down_sample_rate``). When ground truth is
            available it also contains ``'gt_mask'``, ``'obj_mask'`` and
            ``'iou'``, and ``'pred_mask'`` is replaced by the thresholded
            keypoint-score mask restricted to the amodal object mask.

        Raises:
            ValueError: If ``self.clutter_bank`` is an empty list.
        """
        self.net.eval()

        # sample = self.transforms(sample)
        img = sample["img"].to(self.device)
        assert len(img) == 1, "The batch size during validation should be 1"

        with torch.no_grad():
            feature_map = self.net.module.forward_test(img)

        # Normalise to a list of banks. The previous code only bound the local
        # name in the non-list case, so a list-valued bank raised NameError.
        if isinstance(self.clutter_bank, list):
            clutter_banks = self.clutter_bank
        else:
            clutter_banks = [self.clutter_bank]

        clutter_score = None
        for cb in clutter_banks:
            # Each bank acts as a set of 1x1 convolution kernels over the
            # feature map; leading singleton dimensions are squeezed away.
            _score = (
                torch.nn.functional.conv2d(feature_map, cb.unsqueeze(2).unsqueeze(3))
                .squeeze(0)
                .squeeze(0)
            )
            if clutter_score is None:
                clutter_score = _score
            else:
                # Keep the strongest clutter response across banks.
                clutter_score = torch.max(clutter_score, _score)

        if clutter_score is None:
            raise ValueError("clutter_bank is empty; cannot compute clutter score")

        # Dot product between every keypoint feature and every location of the
        # feature map, then the best keypoint per location.
        nkpt, c = self.kp_features.shape
        feature_map_nkpt = feature_map.expand(nkpt, -1, -1, -1)
        kp_features = self.kp_features.view(nkpt, c, 1, 1)
        kp_score = torch.sum(feature_map_nkpt * kp_features, dim=1)
        kp_score, _ = torch.max(kp_score, dim=0)

        clutter_score = clutter_score.detach().cpu().numpy().astype(np.float32)
        kp_score = kp_score.detach().cpu().numpy().astype(np.float32)
        pred_mask = (kp_score > clutter_score).astype(np.uint8)
        # cv2 expects dsize as (width, height).
        pred_mask_up = cv2.resize(
            pred_mask,
            dsize=(pred_mask.shape[1] * self.down_sample_rate, pred_mask.shape[0] * self.down_sample_rate),
            interpolation=cv2.INTER_NEAREST)

        pred = {
            'clutter_score': clutter_score,
            'kp_score': kp_score,
            'pred_mask_orig': pred_mask,
            'pred_mask': pred_mask_up,
        }

        if 'inmodal_mask' in sample:
            gt_mask = sample['inmodal_mask'][0].detach().cpu().numpy()
            pred['gt_mask'] = gt_mask

            obj_mask = sample['amodal_mask'][0].detach().cpu().numpy()
            pred['obj_mask'] = obj_mask

            # With ground truth available the reported mask is the keypoint
            # score thresholded at a fixed value and clipped to the amodal
            # object mask. (An IoU on the unclipped mask used to be computed
            # here first, but it was always overwritten and has been dropped.)
            thr = 0.8
            new_mask = (kp_score > thr).astype(np.uint8)
            new_mask = cv2.resize(new_mask, dsize=(obj_mask.shape[1], obj_mask.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)
            new_mask[obj_mask == 0] = 0
            pred['iou'] = iou(gt_mask, new_mask)
            pred['pred_mask'] = new_mask

        return pred

    def fix_init(self, sample):
        """Run one forward pass in train mode and return detached features.

        The sample is passed through ``self.transforms``; the image and object
        mask are moved to CUDA. Keypoint positions come from one of three
        places:

        * a ``'voge'`` rasterizer: the projector output is handed to the
          network as-is and the network returns ``(features, kpvis)``;
        * ``proj_mode == 'prepared'``: ``sample['kp']`` / ``sample['kpvis']``;
        * otherwise: the projector returns ``(kp, kpvis)``.

        Args:
            sample: Batch dictionary with ``'img'`` and ``'obj_mask'``; pose
                entries (``'azimuth'``, ``'elevation'``, ``'distance'``,
                ``'theta'``, optional ``'principal'``) are needed whenever the
                projector is used, ``'kp'`` / ``'kpvis'`` in prepared mode.

        Returns:
            tuple: ``(features, kpvis)``, both detached from the graph.
        """
        self.net.train()
        sample = self.transforms(sample)

        img = sample['img'].cuda()
        obj_mask = sample["obj_mask"].cuda()
        # An unused per-call vertex index tensor used to be built and copied
        # to the GPU here; it has been removed.

        kwargs_ = dict(principal=sample['principal']) if 'principal' in sample else dict()

        def project():
            # Project the annotated pose without tracking gradients.
            with torch.no_grad():
                return self.projector(azim=sample['azimuth'].float().cuda(),
                                      elev=sample['elevation'].float().cuda(),
                                      dist=sample['distance'].float().cuda(),
                                      theta=sample['theta'].float().cuda(),
                                      **kwargs_)

        if 'voge' in self.projector.raster_type:
            frag_ = project()
            # The network receives the inverted object mask (1 - obj_mask).
            features, kpvis = self.net.forward(img, keypoint_positions=frag_, obj_mask=1 - obj_mask,
                                               do_normalize=True, )
        else:
            if self.training_params.proj_mode == 'prepared':
                kp = sample['kp'].cuda()
                kpvis = sample["kpvis"].cuda().type(torch.bool)
            else:
                kp, kpvis = project()

            features = self.net.forward(img, keypoint_positions=kp, obj_mask=1 - obj_mask, do_normalize=True, )
        return features.detach(), kpvis.detach()
