import os
import sys
from PIL import Image
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import point_cloud_utils as pcu
from sklearn.neighbors import KDTree
from pytorch3d.io import load_obj

from corr.utils import get_abs_path
import os


class MeshLoader():
    """Load and normalize the synthetic reconstruction meshes of one category.

    On construction, every instance id found in
    ``<root_path>/synthetic/mesh/<cate>`` (except ids in ``skip_list``) gets a
    contiguous integer index in ``mesh_name_dict``. Its
    ``<id>_recon_mesh.ply`` is loaded with :meth:`get_meshes` into
    ``mesh_list``. The matching
    ``<root_path>/synthetic/index/<cate>/<chosen_id>/<id>/index.npy`` is
    loaded into ``index_list``. The category's ``chosen_id`` is appended to
    the id list if it was not found in the mesh directory.

    Args:
        dataset_config: mapping that provides at least ``'root_path'``.
        cate: one of ``'car'``, ``'aeroplane'``, ``'boat'``, ``'bicycle'``.

    Raises:
        NotImplementedError: if ``cate`` is not a supported category.
    """

    # Per-category settings:
    #   skip_list  - instance ids never loaded
    #   ray_list   - ids whose normalization center is shifted by +0.05 on axis 1
    #   up_list    - ids whose normalization center is shifted by -0.08 on axis 1
    #   chosen_id  - reference instance; names the index sub-directory and is always loaded
    #   anno_parts - part names for the category
    _CATEGORY_CONFIGS = {
        'car': {
            'skip_list': ['17c32e15723ed6e0cd0bf4a0e76b8df5'],
            'ray_list': ['85f6145747a203becc08ff8f1f541268', '5343e944a7753108aa69dfdc5532bb13',
                         '67a3dfa9fb2d1f2bbda733a39f84326d'],
            'up_list': [],
            'chosen_id': '372ceb40210589f8f500cc506a763c18',
            'anno_parts': ['body', 'front_left_wheel', 'front_right_wheel', 'back_left_wheel',
                           'back_right_wheel', 'left_door', 'right_door'],
        },
        'aeroplane': {
            'skip_list': [],
            'ray_list': [],
            'up_list': ['1d63eb2b1f78aa88acf77e718d93f3e1'],
            'chosen_id': '3cb63efff711cfc035fc197bbabcd5bd',
            'anno_parts': ["body", "left_wheel", "right_wheel", "left_wing", "right_wing",
                           "left_engine", "right_engine", "tail"],
        },
        'boat': {
            'skip_list': [],
            'ray_list': [],
            'up_list': [],
            'chosen_id': '2340319ec4d93ae8c1df6b0203ecb359',
            'anno_parts': ['body', 'sail'],
        },
        'bicycle': {
            'skip_list': ['3ZOy2KonL0'],
            'ray_list': [],
            'up_list': [],
            'chosen_id': '91k7HKqdM9',
            'anno_parts': ["body", "front_wheel", "back_wheel", "saddle"],
        },
    }

    def __init__(self, dataset_config, cate='car'):
        if cate not in self._CATEGORY_CONFIGS:
            raise NotImplementedError
        config = self._CATEGORY_CONFIGS[cate]
        # Copy the lists so that mutating one instance never alters the shared table.
        self.skip_list = list(config['skip_list'])
        self.ray_list = list(config['ray_list'])
        self.up_list = list(config['up_list'])
        self.anno_parts = list(config['anno_parts'])
        chosen_id = config['chosen_id']

        root_path = os.path.join(dataset_config['root_path'], 'synthetic')
        index_path = os.path.join(root_path, 'index', cate, chosen_id)
        self.mesh_path = os.path.join(root_path, 'mesh', cate)

        self.mesh_name_dict = self._build_mesh_name_dict(
            os.listdir(self.mesh_path), self.skip_list, chosen_id)
        # Both lists follow mesh_name_dict's insertion order, so position == stored index.
        self.mesh_list = [self.get_meshes(name_) for name_ in self.mesh_name_dict]
        self.index_list = [np.load(os.path.join(index_path, t, 'index.npy'), allow_pickle=True)[()]
                           for t in self.mesh_name_dict]

    @staticmethod
    def _build_mesh_name_dict(file_names, skip_list, chosen_id):
        """Map each instance id to a contiguous index, in first-seen order.

        The id is the part of a file name before the first ``'_'``. Ids listed
        in ``skip_list``, and ids containing a ``'.'`` (e.g. a file name with
        no ``'_'`` such as ``notes.txt``), are ignored. An id seen again keeps
        its first index. Without that check, indices would stop matching list
        positions. ``chosen_id`` is appended if it is not already present.
        """
        name_dict = {}
        for file_name in file_names:
            name = file_name.split('_')[0]
            if name in skip_list or '.' in name or name in name_dict:
                continue
            name_dict[name] = len(name_dict)
        if chosen_id not in name_dict:
            name_dict[chosen_id] = len(name_dict)
        return name_dict

    def get_mesh_listed(self):
        """Return ``(verts_list, faces_list)`` as numpy arrays, one entry per loaded mesh."""
        return [t[0].numpy() for t in self.mesh_list], [t[1].numpy() for t in self.mesh_list]

    def get_ori_mesh_listed(self):
        """Like :meth:`get_mesh_listed`, but reads ``self.ori_mesh_list``.

        NOTE(review): nothing in this class assigns ``ori_mesh_list``, so this
        raises ``AttributeError`` unless it is set externally first. Confirm intent.
        """
        return [t[0].numpy() for t in self.ori_mesh_list], [t[1].numpy() for t in self.ori_mesh_list]

    def get_index_list(self, indexs=None):
        """Stack the loaded ``index.npy`` contents into one tensor.

        Args:
            indexs: optional iterable of positions into ``index_list``. When
                given, only those entries are stacked, in that order.

        Returns:
            ``torch.Tensor`` built via ``np.array`` over the selected entries.
            This assumes the entries share one shape; verify for your data.
        """
        if indexs is not None:
            return torch.from_numpy(np.array([self.index_list[t] for t in indexs]))
        return torch.from_numpy(np.array(self.index_list))

    def get_meshes(self, instance_id):
        """Load ``<mesh_path>/<instance_id>_recon_mesh.ply`` and normalize it.

        Vertices are centered at the bounding-box midpoint. For ids in
        ``ray_list`` the center's axis-1 coordinate is raised by 0.05; for ids
        in ``up_list`` it is lowered by 0.08. The vertices are then divided by
        the bounding-box diagonal length.

        Returns:
            ``(verts, faces)``: a float32 vertex tensor and an int32 face tensor.
        """
        verts, faces, _, _ = pcu.load_mesh_vfnc(os.path.join(self.mesh_path, f'{instance_id}_recon_mesh.ply'))
        faces = torch.from_numpy(faces.astype(np.int32))

        vert_max = verts.max(axis=0)
        vert_min = verts.min(axis=0)
        vert_middle = (vert_max + vert_min) / 2
        # Hand-tuned per-instance offsets of the normalization center.
        if instance_id in self.ray_list:
            vert_middle[1] += 0.05
        if instance_id in self.up_list:
            vert_middle[1] -= 0.08
        vert_scale = ((vert_max - vert_min) ** 2).sum() ** 0.5  # bounding-box diagonal
        verts = (verts - vert_middle) / vert_scale
        verts = torch.from_numpy(verts.astype(np.float32))

        return verts, faces


class ImagenetPart(Dataset):
    """Dataset over PartImageNet test images or synthetic renders of one category.

    With ``data_type == 'test'``, the following are loaded into memory up front:
    every image under ``<root_path>/PartImageNet/images/<category>``, its
    ``*_seg.png`` mask, and its ``.npz`` annotation. Pose predictions are also
    loaded, from
    ``eval/pose_estimation_3d_nemo_<category>/pascal3d_occ0_<category>_val.pth``.
    That path is relative to the current working directory.

    For any other ``data_type``, only file paths are indexed, from
    ``<root_path>/synthetic/{image,render_img,angle}/<data_type>/<category>/<instance_id>/``.
    Files are read lazily in ``__getitem__``.

    Args:
        data_type: ``'test'`` for PartImageNet; otherwise the synthetic split name.
        category: category sub-directory name.
        root_path: dataset root, resolved with ``get_abs_path``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, data_type, category, root_path, **kwargs):
        super().__init__()
        self.data_type = data_type
        self.category = category
        if data_type == 'test':
            self._load_test_split(root_path)
        else:
            self._index_synthetic_split(root_path)

    @staticmethod
    def _read_image(path):
        """Read ``path`` with ``cv2.IMREAD_UNCHANGED``, raising instead of returning ``None``."""
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise FileNotFoundError(f'Unable to read image: {path}')
        return image

    @staticmethod
    def _sorted_files(directory, instance_id):
        """Return paths in ``directory/instance_id``, sorted by the int value of each name's first 3 chars."""
        folder = os.path.join(directory, instance_id)
        return [os.path.join(folder, x) for x in sorted(os.listdir(folder), key=lambda x: int(x[:3]))]

    def _load_test_split(self, root_path):
        """Eagerly load PartImageNet images, segmentations, annotations and pose predictions."""
        root_path = get_abs_path(os.path.join(root_path, 'PartImageNet'))

        data_path = os.path.join(root_path, 'images', self.category)
        anno_path = os.path.join(root_path, 'annotations', self.category)
        pose_path = 'eval/pose_estimation_3d_nemo_%s' % self.category
        self.preds = torch.load(os.path.join(pose_path, 'pascal3d_occ0_%s_val.pth' % self.category))
        # Index predictions by the 2nd '/'-separated component of their key, which is
        # matched against image names below (keys presumably look like '<prefix>/<name>').
        self.kkeys = {t.split('/')[1]: t for t in self.preds.keys()}

        self.save_path = get_abs_path(data_path)

        self.images = []
        self.segs = []
        self.annos = []
        self.names = []
        for image_name in os.listdir(data_path):
            if 'seg' in image_name:  # masks are loaded alongside their image
                continue
            image_fn = os.path.join(data_path, image_name)
            seg_fn = os.path.join(data_path, image_name.replace('.JPEG', '_seg.png'))
            anno_fn = os.path.join(anno_path, image_name.replace('.JPEG', '.npz'))

            image = cv2.cvtColor(self._read_image(image_fn), cv2.COLOR_BGR2RGB)
            seg = self._read_image(seg_fn)
            # Materialize the archive so its file is closed right away. Keeping the lazy
            # NpzFile would hold one open file handle per image for the dataset's lifetime.
            with np.load(anno_fn) as anno_file:
                anno = {key: anno_file[key] for key in anno_file.files}

            self.images.append(image)
            self.segs.append(seg)
            self.annos.append(anno)
            self.names.append(image_name.split('.')[0])

    def _index_synthetic_split(self, root_path):
        """Collect sorted image / render / angle file paths for each synthetic instance."""
        root_path = os.path.join(get_abs_path(root_path), 'synthetic')
        self.skip_list = ['17c32e15723ed6e0cd0bf4a0e76b8df5']

        self.img_path = os.path.join(root_path, 'image', self.data_type, self.category)
        self.render_img_path = os.path.join(root_path, 'render_img', self.data_type, self.category)
        self.angle_path = os.path.join(root_path, 'angle', self.data_type, self.category)

        self.instance_list = [x for x in os.listdir(self.img_path) if '.' not in x]

        self.img_fns = []
        self.angle_fns = []
        self.render_img_fns = []
        for instance_id in self.instance_list:
            if instance_id in self.skip_list:
                continue
            self.img_fns += self._sorted_files(self.img_path, instance_id)
            self.angle_fns += self._sorted_files(self.angle_path, instance_id)
            self.render_img_fns += self._sorted_files(self.render_img_path, instance_id)

            # Samples are paired across the three lists by position, so the counts must agree.
            assert len(self.img_fns) == len(self.angle_fns) == len(self.render_img_fns), \
                f'{len(self.img_fns)}, {len(self.angle_fns)}, {len(self.render_img_fns)}'

    def __getitem__(self, item):
        """Return the sample dict for index ``item`` (its keys differ between splits)."""
        if self.data_type == 'test':
            return self._get_test_item(item)
        return self._get_synthetic_item(item)

    def _get_test_item(self, item):
        """Build a PartImageNet sample: image, segmentation, annotated pose and predicted pose."""
        ori_img = self.images[item]
        seg = self.segs[item]
        anno = self.annos[item]
        name = self.names[item]

        pose_pred = 0.
        if name in self.kkeys:
            # Shallow-copy so the adjustments below never write back into self.preds.
            # Otherwise every access to the same item would apply them again.
            pose_pred = dict(self.preds[self.kkeys[name]]['final'][0])
            if self.category == 'bicycle':
                pose_pred['azimuth'] = pose_pred['azimuth'] + np.pi / 2
            pose_pred['elevation'] = np.pi - pose_pred['elevation']

        img = (ori_img / 255.).transpose(2, 0, 1)  # HWC -> CHW, scaled by 1/255

        sample = dict()
        sample['img'] = np.ascontiguousarray(img).astype(np.float32)
        sample['img_ori'] = np.ascontiguousarray(ori_img).astype(np.float32)
        sample['seg'] = np.ascontiguousarray(seg).astype(np.float32)
        sample['distance'] = float(anno['distance'])
        sample['elevation'] = float(anno['elevation'])
        sample['azimuth'] = float(anno['azimuth'])
        sample['theta'] = float(anno['theta'])
        sample['pose_pred'] = pose_pred
        sample['name'] = name
        return sample

    def _get_synthetic_item(self, item):
        """Build a synthetic sample: image, object mask from the render, and camera pose."""
        img_fn = self.img_fns[item]
        ori_img = cv2.cvtColor(self._read_image(img_fn), cv2.COLOR_BGR2RGB)
        render_img = self._read_image(self.render_img_fns[item])
        # NOTE: allow_pickle is only safe because these are the project's own dataset files.
        angle = np.load(self.angle_fns[item], allow_pickle=True)[()]

        # Each image lives in a directory named after its instance id.
        instance_id = os.path.basename(os.path.dirname(img_fn))

        img = ori_img.transpose(2, 0, 1)
        # Object mask: channel 3 of the render, resized to the image's (W, H) and binarized.
        mask = render_img[:, :, 3]
        mask = cv2.resize(mask, (img.shape[2], img.shape[1]), interpolation=cv2.INTER_NEAREST)
        mask[mask > 0] = 1

        img = img / 255.0
        sample = dict()
        sample['img'] = np.ascontiguousarray(img).astype(np.float32)
        # NOTE: unlike the test split, 'img_ori' here is the scaled CHW image.
        sample['img_ori'] = np.ascontiguousarray(img).astype(np.float32)
        sample['obj_mask'] = np.ascontiguousarray(mask).astype(np.float32)
        sample['distance'] = angle['dist']
        sample['elevation'] = np.pi / 2 - angle['phi']
        sample['azimuth'] = angle['theta'] + np.pi / 2
        sample['theta'] = angle['camera_rotation']
        sample['instance_id'] = instance_id
        sample['this_name'] = item
        return sample

    def __len__(self):
        """Number of samples in the selected split."""
        if self.data_type == 'test':
            return len(self.images)
        return len(self.img_fns)


class Normalize:
    """Transform that applies ImageNet mean/std normalization to ``sample["img"]``.

    The sample dict is updated in place and also returned, so the transform
    can be chained. ``sample["img"]`` must be something
    ``torchvision.transforms.Normalize`` accepts, i.e. a tensor in its
    ``(..., C, H, W)`` layout.
    """

    def __init__(self):
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.trans = torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

    def __call__(self, sample):
        image = sample["img"]
        sample["img"] = self.trans(image)
        return sample
