import os
import sys
from PIL import Image
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import point_cloud_utils as pcu
from sklearn.neighbors import KDTree
from pytorch3d.io import load_obj

from corr.utils import get_abs_path
import os


class MeshLoader():
    """Load the reconstructed meshes and per-instance index arrays of one category.

    Instance ids are taken from the file names in ``<root>/mesh/<cate>``, or, when ``type``
    is given, from ``<root>/image/<type>/<cate>``. The id is the part of each name before
    the first ``'_'``. The category's reference model ``chosen_id`` is always included.

    Attributes populated here:
        mesh_name_dict: instance id -> position in ``mesh_list`` / ``index_list``.
        mesh_list: ``(verts, faces)`` torch tensors per instance (see ``get_meshes``).
        index_list: object loaded from ``<root>/index/<cate>/<chosen_id>/<instance>/index.npy``.
            With ``dataset_config['ori_mesh']`` set, each entry is replaced by the indices of
            the nearest vertices on the original ShapeNet mesh (``ori_mesh_list``).

    Args:
        dataset_config: dict with ``'root_path'`` and optionally ``'ori_mesh'`` (bool).
        cate: ``'car'``, ``'aeroplane'``, ``'boat'`` or ``'bicycle'``.
        type: optional split name used to list instances from the image directory
            (shadows the builtin ``type``; kept for backward compatibility).

    Raises:
        NotImplementedError: if ``cate`` is not supported.
    """

    def __init__(self, dataset_config, cate='car', type=None):
        if cate == 'car':
            self.skip_list = ['17c32e15723ed6e0cd0bf4a0e76b8df5']
            self.ray_list = ['85f6145747a203becc08ff8f1f541268', '5343e944a7753108aa69dfdc5532bb13',
                             '67a3dfa9fb2d1f2bbda733a39f84326d']
            self.up_list = []
            cate_id = '02958343'
            chosen_id = '372ceb40210589f8f500cc506a763c18'
            self.anno_parts = ['body', 'front_left_wheel', 'front_right_wheel', 'back_left_wheel', 'back_right_wheel',
                               'left_door', 'right_door']
        elif cate == 'aeroplane':
            self.skip_list = []
            self.ray_list = []
            self.up_list = ['1d63eb2b1f78aa88acf77e718d93f3e1']
            cate_id = '02691156'
            # Alternative reference model: '1d63eb2b1f78aa88acf77e718d93f3e1'.
            chosen_id = '3cb63efff711cfc035fc197bbabcd5bd'
            self.anno_parts = ["body", "left_wheel", "right_wheel", "left_wing", "right_wing", "left_engine",
                               "right_engine", "tail"]
        elif cate == 'boat':
            self.skip_list = []
            self.ray_list = []
            self.up_list = []
            cate_id = '04530566'
            chosen_id = '2340319ec4d93ae8c1df6b0203ecb359'
            self.anno_parts = ['body', 'sail']
        elif cate == 'bicycle':
            self.skip_list = ['3ZOy2KonL0']
            self.ray_list = []
            self.up_list = []
            cate_id = '02834778'
            chosen_id = '91k7HKqdM9'
            self.anno_parts = ["body", "front_wheel", "back_wheel", "saddle"]
        else:
            raise NotImplementedError

        index_path = os.path.join(dataset_config['root_path'], 'index', cate, chosen_id)

        self.mesh_path = os.path.join(dataset_config['root_path'], 'mesh', cate)
        name_path = self.mesh_path
        if type is not None:
            name_path = os.path.join(dataset_config['root_path'], 'image', type, cate)
        self.mesh_name_dict = self._build_name_dict(os.listdir(name_path), self.skip_list, chosen_id)
        self.mesh_list = [self.get_meshes(name_) for name_ in self.mesh_name_dict.keys()]

        self.index_list = [np.load(os.path.join(index_path, t, 'index.npy'), allow_pickle=True)[()]
                           for t in self.mesh_name_dict.keys()]

        if dataset_config.get('ori_mesh', False):
            self.ori_mesh_path = os.path.join(dataset_config['root_path'], 'ori_mesh', cate_id)
            self.ori_mesh_list = [self.get_ori_meshes(name_) for name_ in self.mesh_name_dict.keys()]

            # Re-express every index on the original mesh: take the indexed vertices of the
            # reconstructed mesh and replace each by its nearest original-mesh vertex.
            for idx, index in enumerate(self.index_list):
                sample_verts = self.mesh_list[idx][0][index]
                ori_verts = self.ori_mesh_list[idx][0]
                kdtree = KDTree(ori_verts)
                _, nearest_idx = kdtree.query(sample_verts, k=1)
                self.index_list[idx] = nearest_idx[:, 0]

    @staticmethod
    def _build_name_dict(file_names, skip_list, chosen_id):
        """Map instance ids to consecutive indices in first-seen order.

        The instance id is the part of each file name before the first ``'_'``. Ids in
        ``skip_list``, or ids that still contain a ``'.'``, are ignored. ``chosen_id`` is
        appended if absent. A repeated id keeps its first index, so every value equals the
        key's position in iteration order, which is how ``mesh_list``/``index_list`` are built.
        """
        name_dict = dict()
        for file_name in file_names:
            name = file_name.split('_')[0]
            if name in skip_list or '.' in name:
                continue
            name_dict.setdefault(name, len(name_dict))
        name_dict.setdefault(chosen_id, len(name_dict))
        return name_dict

    def get_mesh_listed(self):
        """Return ``(verts_list, faces_list)`` of the reconstructed meshes as numpy arrays."""
        return [t[0].numpy() for t in self.mesh_list], [t[1].numpy() for t in self.mesh_list]

    def get_ori_mesh_listed(self):
        """Return ``(verts_list, faces_list)`` of the original meshes as numpy arrays.

        Only available when the loader was built with ``dataset_config['ori_mesh']`` set;
        otherwise ``ori_mesh_list`` does not exist and this raises ``AttributeError``.
        """
        return [t[0].numpy() for t in self.ori_mesh_list], [t[1].numpy() for t in self.ori_mesh_list]

    def get_index_list(self, indexs=None):
        """Return the index arrays stacked into one tensor.

        Args:
            indexs: optional sequence of instance positions to select; all instances if None.
        """
        if indexs is not None:
            return torch.from_numpy(np.array([self.index_list[t] for t in indexs]))
        return torch.from_numpy(np.array(self.index_list))

    def get_meshes(self, instance_id, ):
        """Load ``<mesh_path>/<instance_id>_recon_mesh.ply`` and normalize it.

        Vertices are centered on their bounding-box middle and divided by the bounding-box
        diagonal length. Hand-tuned correction: for ids in ``ray_list`` the center's
        component 1 is raised by 0.05, and for ids in ``up_list`` it is lowered by 0.08.

        Returns:
            ``(verts, faces)``: float32 and int32 torch tensors.
        """
        verts, faces, _, _ = pcu.load_mesh_vfnc(os.path.join(self.mesh_path, f'{instance_id}_recon_mesh.ply'))

        faces = torch.from_numpy(faces.astype(np.int32))

        vert_middle = (verts.max(axis=0) + verts.min(axis=0)) / 2
        if instance_id in self.ray_list:
            vert_middle[1] += 0.05
        if instance_id in self.up_list:
            vert_middle[1] -= 0.08
        vert_scale = ((verts.max(axis=0) - verts.min(axis=0)) ** 2).sum() ** 0.5
        verts = verts - vert_middle
        verts = verts / vert_scale
        verts = torch.from_numpy(verts.astype(np.float32))

        return verts, faces

    def get_ori_meshes(self, instance_id, ):
        """Load the original ``<ori_mesh_path>/<instance_id>/models/model_normalized.obj``.

        Vertices are centered on their bounding-box middle (no scaling).

        Returns:
            ``(verts, faces)`` torch tensors, where faces are the obj vertex indices.
        """
        mesh_fn = os.path.join(self.ori_mesh_path, instance_id, 'models', 'model_normalized.obj')
        verts, faces_idx, _ = load_obj(mesh_fn)
        faces = faces_idx.verts_idx

        vert_middle = (verts.max(dim=0)[0] + verts.min(dim=0)[0]) / 2
        verts = verts - vert_middle

        return verts, faces


class PartLoader():
    """Load the annotated part meshes of a single reference model of one category.

    The reference model is the ShapeNet ``model_normalized.obj`` of a fixed per-category id.
    Every file in ``<root>/part/<cate>`` whose name contains ``.obj`` is one part. Its
    vertices are shifted by the reference model's bounding-box center and then re-centered
    on the part's own bounding-box center, which is stored as that part's offset.

    Raises:
        NotImplementedError: if ``cate`` is not supported.
    """

    def __init__(self, dataset_config, cate='car'):
        # category -> (ShapeNet synset id, reference model id)
        known_categories = {
            'car': ('02958343', '372ceb40210589f8f500cc506a763c18'),
            'aeroplane': ('02691156', '1d63eb2b1f78aa88acf77e718d93f3e1'),
            'boat': ('04530566', '2340319ec4d93ae8c1df6b0203ecb359'),
            'chair': ('03001627', '10de9af4e91682851e5f7bff98fb8d02'),
            'bicycle': ('02834778', '91k7HKqdM9'),
        }
        if cate not in known_categories:
            raise NotImplementedError
        cate_id, chosen_id = known_categories[cate]

        root = dataset_config['root_path']
        reference_fn = os.path.join(root, 'ori_mesh', cate_id, chosen_id, 'models', 'model_normalized.obj')
        reference_verts, _, _ = load_obj(reference_fn)
        reference_center = (reference_verts.max(axis=0)[0] + reference_verts.min(axis=0)[0]) / 2

        self.part_meshes = []
        self.part_names = []
        self.offsets = []
        parts_dir = os.path.join(root, 'part', cate)
        obj_files = [entry for entry in os.listdir(parts_dir) if '.obj' in entry]
        for file_name in obj_files:
            verts, faces_idx, _ = load_obj(os.path.join(parts_dir, file_name))
            verts = verts - reference_center
            part_center = (verts.max(axis=0)[0] + verts.min(axis=0)[0]) / 2
            self.offsets.append(np.array(part_center))
            self.part_meshes.append((verts - part_center, faces_idx.verts_idx))
            self.part_names.append(file_name.split('.')[0])

        print('part names: ', self.part_names)

    def get_name_listed(self):
        """Return the part names in loading order."""
        return self.part_names

    def get_part_mesh(self, name=None):
        """Return one part's ``(verts, faces)`` tensors, or all parts as numpy lists.

        With ``name`` set, the stored tensor tuple of that part is returned. Without it,
        the result is ``(verts_list, faces_list)`` of numpy arrays for every part.
        """
        if name is not None:
            return self.part_meshes[self.part_names.index(name)]
        verts_list = [verts.numpy() for verts, _ in self.part_meshes]
        faces_list = [faces.numpy() for _, faces in self.part_meshes]
        return verts_list, faces_list

    def get_offset(self, name=None):
        """Return one part's center offset, or the list of all offsets when ``name`` is None."""
        if name is not None:
            return self.offsets[self.part_names.index(name)]
        return self.offsets


class PartsLoader():
    """Load annotated part meshes for several models of one category.

    For each id in ``chosen_ids``, the whole model is loaded to find its bounding-box center,
    stored in ``whole_offset``. Each part mesh is then:
      * shifted by that center;
      * divided by the whole-model scale, which is the bounding-box diagonal of the
        reconstructed ``.ply`` for ``'bicycle'``/``'boat'`` and 1 otherwise;
      * re-centered on its own bounding-box center, which is kept as the part's offset.

    Part names are discovered from the files of the FIRST id. Every later id is looked up by
    those names. A missing part file is replaced by a zero placeholder, so list positions stay
    aligned with ``part_names`` across ids.

    Args:
        dataset_config: dict; ``'root_path'`` is required. ``'ori_mesh'`` (bool) selects the
            ``ori_parts`` directory and, for .obj categories, the ``<name>.obj`` naming.
        cate: one of ``'car'``, ``'aeroplane'``, ``'boat'``, ``'bicycle'``, ``'chair'``.
        chosen_ids: iterable of model ids to load.

    Raises:
        NotImplementedError: if ``cate`` is not supported.
        TypeError: if ``chosen_ids`` is None.
    """

    def __init__(self, dataset_config, cate='car', chosen_ids=None):
        cate_ids = {
            'car': '02958343',
            'aeroplane': '02691156',
            'boat': '04530566',
            'bicycle': '02834778',
            'chair': '03001627',
        }
        if cate not in cate_ids:
            raise NotImplementedError
        cate_id = cate_ids[cate]
        if chosen_ids is None:
            raise TypeError('PartsLoader requires chosen_ids (an iterable of model ids)')

        self.parts_meshes = dict()
        self.parts_offsets = dict()
        self.part_names = None
        self.dataset_config = dataset_config
        self.cate = cate

        # bicycle / boat use reconstructed .ply meshes; other categories use ShapeNet .obj files.
        use_ply = cate in ['bicycle', 'boat']
        use_ori_mesh = dataset_config.get('ori_mesh', False)

        self.whole_offset = dict()
        for chosen_id in chosen_ids:
            vert_middle, vert_scale = self._load_whole_center(chosen_id, cate_id, use_ply)
            self.whole_offset[chosen_id] = vert_middle

            if use_ori_mesh:
                part_path = os.path.join(dataset_config['root_path'], 'ori_parts', cate, chosen_id)
            else:
                part_path = os.path.join(dataset_config['root_path'], f'{cate}', chosen_id)

            part_meshes = []
            offsets = []
            if self.part_names is None:
                # First model: discover the part names from the files on disk.
                self.part_names = []
                extension = '.ply' if use_ply else '.obj'
                for file_name in os.listdir(part_path):
                    if extension not in file_name:
                        continue
                    mesh, offset = self._load_part(
                        os.path.join(part_path, file_name), vert_middle, vert_scale, use_ply)
                    offsets.append(offset)
                    part_meshes.append(mesh)
                    self.part_names.append(self._file_to_part_name(file_name))
            else:
                for name in self.part_names:
                    part_fn = os.path.join(part_path, self._part_file_name(name, use_ply, use_ori_mesh))
                    if not os.path.exists(part_fn):
                        # Placeholder keeps positions aligned with self.part_names.
                        # NOTE(review): faces are float zeros and the offset has shape (1, 3),
                        # whereas loaded parts have index faces and (presumably) a (3,) offset.
                        part_meshes.append((torch.zeros(1, 3), torch.zeros(1, 3)))
                        offsets.append(np.array([[0., 0., 0.]]))
                        continue
                    mesh, offset = self._load_part(part_fn, vert_middle, vert_scale, use_ply)
                    offsets.append(offset)
                    part_meshes.append(mesh)

            self.parts_meshes[chosen_id] = part_meshes
            self.parts_offsets[chosen_id] = offsets

            print('part names: ', self.part_names)

    def _load_whole_center(self, chosen_id, cate_id, use_ply):
        """Return ``(bounding-box center, scale)`` of the whole model ``chosen_id``.

        With ``use_ply``, the reconstructed ``<root>/mesh/<cate>/<id>_recon_mesh.ply`` is used
        and the scale is its bounding-box diagonal. Otherwise the ShapeNet
        ``model_normalized.obj`` is used with scale 1.
        """
        root_path = self.dataset_config['root_path']
        if use_ply:
            recon_mesh_path = os.path.join(root_path, 'mesh', self.cate, f'{chosen_id}_recon_mesh.ply')
            chosen_verts, _, _, _ = pcu.load_mesh_vfnc(recon_mesh_path)
            vert_scale = ((chosen_verts.max(axis=0) - chosen_verts.min(axis=0)) ** 2).sum() ** 0.5
            chosen_verts = torch.from_numpy(chosen_verts.astype(np.float32))
        else:
            ori_mesh_path = os.path.join(root_path, 'ori_mesh', cate_id, chosen_id, 'models', 'model_normalized.obj')
            chosen_verts, _, _ = load_obj(ori_mesh_path)
            vert_scale = 1
        vert_middle = (chosen_verts.max(axis=0)[0] + chosen_verts.min(axis=0)[0]) / 2
        if chosen_id in ['1d63eb2b1f78aa88acf77e718d93f3e1', '3cb63efff711cfc035fc197bbabcd5bd']:
            # Hand-tuned correction: lower component 1 of the center for these models.
            vert_middle[1] -= 0.08
        return vert_middle, vert_scale

    @staticmethod
    def _load_part(part_fn, vert_middle, vert_scale, use_ply):
        """Load one part mesh and express it around its own bounding-box center.

        Vertices are shifted by the whole-model center ``vert_middle``, divided by
        ``vert_scale`` and re-centered. The subtracted part center is returned as the offset.

        Returns:
            ``((part_verts, part_faces), offset)`` with ``offset`` as a numpy array.
        """
        if use_ply:
            part_verts, part_faces, _, _ = pcu.load_mesh_vfnc(part_fn)
            part_verts = torch.from_numpy(part_verts.astype(np.float32))
            part_faces = torch.from_numpy(part_faces.astype(np.int32))
        else:
            part_verts, faces_idx, _ = load_obj(part_fn)
            part_faces = faces_idx.verts_idx
        part_verts = part_verts - vert_middle
        part_verts = part_verts / vert_scale
        part_middle = (part_verts.max(axis=0)[0] + part_verts.min(axis=0)[0]) / 2
        part_verts = part_verts - part_middle
        return (part_verts, part_faces), np.array(part_middle)

    @staticmethod
    def _file_to_part_name(file_name):
        """Map a part file name to its part name.

        The extension and a trailing ``_recon`` are removed, and underscores inside the name
        are kept. For example, ``'front_left_wheel_recon.obj'`` becomes ``'front_left_wheel'``
        and ``'body.obj'`` becomes ``'body'``. This lets the name round-trip through
        ``_part_file_name`` for later models.
        """
        stem = file_name.split('.')[0]
        suffix = '_recon'
        if stem.endswith(suffix):
            stem = stem[:-len(suffix)]
        return stem

    @staticmethod
    def _part_file_name(name, use_ply, use_ori_mesh):
        """Return the file name under which part ``name`` is stored for a non-first model."""
        if use_ply:
            return f'{name}_recon.ply'
        if use_ori_mesh:
            return f'{name}.obj'
        return f'{name}_recon.obj'

    def get_name_listed(self):
        """Return the part names (None if no model has been loaded)."""
        return self.part_names

    def get_part_mesh(self, id=None, name=None):
        """Return the part meshes of model ``id``.

        With ``name`` set, that part's ``(verts, faces)`` tensors are returned. Otherwise the
        result is ``(verts_list, faces_list)`` of numpy arrays for all parts.
        """
        part_meshes = self.parts_meshes[id]
        if name is None:
            return [mesh[0].numpy() for mesh in part_meshes], [mesh[1].numpy() for mesh in part_meshes]
        return part_meshes[self.part_names.index(name)]

    def get_offset(self, id=None, name=None):
        """Return the offset of part ``name`` of model ``id``, or all its offsets if ``name`` is None."""
        offsets = self.parts_offsets[id]
        if name is None:
            return offsets
        return offsets[self.part_names.index(name)]

    def get_ori_part(self, id, name):
        """Load ``<root>/ori_parts_1/<cate>/<id>/<name>.obj`` as numpy ``(verts, faces)``.

        Vertices are shifted by the model's ``whole_offset`` and re-centered on the part's
        own bounding-box center.
        NOTE(review): unlike ``_load_part``, no division by the whole-model scale is applied.
        """
        part_path = os.path.join(self.dataset_config['root_path'], 'ori_parts_1', self.cate, id)
        part_fn = os.path.join(part_path, f'{name}.obj')
        part_vert, faces_idx, _ = load_obj(part_fn)
        part_face = faces_idx.verts_idx
        part_vert = part_vert - self.whole_offset[id]
        part_middle = (part_vert.max(axis=0)[0] + part_vert.min(axis=0)[0]) / 2
        part_vert = part_vert - part_middle

        return part_vert.numpy(), part_face.numpy()


class ImagenetPart(Dataset):
    """Image dataset used for part-level evaluation (``'test'``) or training (any other split).

    ``data_type == 'test'``: loads every image into memory, together with its ``*_seg.png``
    segmentation and ``.npz`` pose annotation from ``<root_path>/pascalimagepart``. It also
    loads pose predictions from
    ``eval/pose_estimation_3d_nemo_<category>/pascal3d_occ0_<category>_val.pth``, a path
    relative to the working directory.

    Any other ``data_type``: indexes samples under
    ``<root_path>/{image,render_img,angle}/<data_type>/<category>/<instance_id>/``.
    Images are decoded lazily in ``__getitem__``.
    """

    def __init__(self, data_type, category, root_path, **kwargs):
        super().__init__()
        self.data_type = data_type
        self.category = category
        if data_type == 'test':
            self._load_test_split(root_path)
        else:
            self._index_train_split(root_path)

    def _load_test_split(self, root_path):
        """Load all test images, segmentations, annotations and pose predictions."""
        root_path = os.path.join(root_path, 'pascalimagepart')
        self.root_path = get_abs_path(root_path)

        # NOTE(review): data/annotation paths use the un-resolved root_path rather than
        # self.root_path, so a relative root_path is resolved against the working directory.
        data_path = os.path.join(root_path, 'new_images', self.category)
        anno_path = os.path.join(root_path, 'annotations', self.category)
        pose_path = 'eval/pose_estimation_3d_nemo_%s' % self.category
        self.preds = torch.load(os.path.join(pose_path, 'pascal3d_occ0_%s_val.pth' % self.category))
        # Index the predictions by the second '/'-separated component of their key, which is
        # matched against image names in __getitem__.
        self.kkeys = {t.split('/')[1]: t for t in self.preds.keys()}
        print('len: ', len(self.kkeys))

        self.save_path = get_abs_path(data_path)

        self.images = []
        self.segs = []
        self.annos = []
        self.names = []
        for image_name in os.listdir(data_path):
            if 'seg' in image_name:
                continue
            image_fn = os.path.join(data_path, image_name)
            seg_fn = os.path.join(data_path, image_name.replace('.JPEG', '_seg.png'))
            image = cv2.imread(image_fn, cv2.IMREAD_UNCHANGED)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            seg = cv2.imread(seg_fn, cv2.IMREAD_UNCHANGED)
            anno_fn = os.path.join(anno_path, image_name.replace('.JPEG', '.npz'))
            # Copy the fields used by __getitem__ and close the archive: keeping the NpzFile
            # objects alive would hold one open file handle per image.
            with np.load(anno_fn) as anno_npz:
                anno = {key: anno_npz[key] for key in ('distance', 'elevation', 'azimuth', 'theta')}
            self.images.append(image)
            self.segs.append(seg)
            self.annos.append(anno)
            self.names.append(image_name.split('.')[0])

    def _index_train_split(self, root_path):
        """Collect the image / angle / rendered-image file paths of a non-test split."""
        self.root_path = get_abs_path(root_path)
        self.skip_list = ['17c32e15723ed6e0cd0bf4a0e76b8df5']

        self.img_path = os.path.join(self.root_path, 'image', self.data_type, self.category)
        self.render_img_path = os.path.join(self.root_path, 'render_img', self.data_type, self.category)
        self.angle_path = os.path.join(self.root_path, 'angle', self.data_type, self.category)

        self.instance_list = [x for x in os.listdir(self.img_path) if '.' not in x]

        self.img_fns = []
        self.angle_fns = []
        self.render_img_fns = []
        # Sort each instance's files by the integer in the first three characters of the name,
        # presumably so the image, angle and render lists correspond entry by entry.
        lambda_fn = lambda x: int(x[:3])
        for instance_id in self.instance_list:
            if instance_id in self.skip_list:
                continue
            img_list = os.listdir(os.path.join(self.img_path, instance_id))
            img_list = sorted(img_list, key=lambda_fn)
            self.img_fns += [os.path.join(self.img_path, instance_id, x) for x in img_list]
            angle_list = os.listdir(os.path.join(self.angle_path, instance_id))
            angle_list = sorted(angle_list, key=lambda_fn)
            self.angle_fns += [os.path.join(self.angle_path, instance_id, x) for x in angle_list]
            render_img_list = os.listdir(os.path.join(self.render_img_path, instance_id))
            render_img_list = sorted(render_img_list, key=lambda_fn)
            self.render_img_fns += [os.path.join(self.render_img_path, instance_id, x) for x in render_img_list]

            assert len(self.img_fns) == len(self.angle_fns) == len(self.render_img_fns), \
                f'{len(self.img_fns)}, {len(self.angle_fns)}, {len(self.render_img_fns)}'

    def __getitem__(self, item):
        if self.data_type == 'test':
            return self._get_test_item(item)
        return self._get_train_item(item)

    def _get_test_item(self, item):
        """Build a test sample: CHW float image in [0, 1], raw image, segmentation, GT pose and prediction."""
        sample = dict()
        ori_img = self.images[item]
        img = ori_img / 255.
        seg = self.segs[item]
        anno = self.annos[item]
        name = self.names[item]
        pose_pred = 0.
        if name in self.kkeys:
            # Copy before adjusting: the dict lives in self.preds, and editing it in place
            # would re-apply these corrections every time the same item is fetched.
            pose_pred = dict(self.preds[self.kkeys[name]]['final'][0])
            if self.category == 'bicycle':
                pose_pred['azimuth'] = pose_pred['azimuth'] + np.pi / 2
            pose_pred['elevation'] = np.pi - pose_pred['elevation']

        distance = float(anno['distance'])
        elevation = float(anno['elevation'])
        azimuth = float(anno['azimuth'])
        theta = float(anno['theta'])

        img = img.transpose(2, 0, 1)

        sample['img'] = np.ascontiguousarray(img).astype(np.float32)
        sample['img_ori'] = np.ascontiguousarray(ori_img).astype(np.float32)
        sample['seg'] = np.ascontiguousarray(seg).astype(np.float32)
        sample['distance'] = distance
        sample['elevation'] = elevation
        sample['azimuth'] = azimuth
        sample['theta'] = theta
        sample['pose_pred'] = pose_pred
        sample['name'] = name
        return sample

    def _get_train_item(self, item):
        """Decode a training sample: image, object mask from the render, and camera pose."""
        sample = dict()
        ori_img = cv2.imread(self.img_fns[item], cv2.IMREAD_UNCHANGED)
        ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        render_img = cv2.imread(self.render_img_fns[item], cv2.IMREAD_UNCHANGED)
        ori_img = np.array(ori_img)
        render_img = np.array(render_img)
        angle = np.load(self.angle_fns[item], allow_pickle=True)[()]

        # The parent directory of the image file is the instance id (see _index_train_split).
        instance_id = os.path.basename(os.path.dirname(self.img_fns[item]))

        img = ori_img.transpose(2, 0, 1)
        # Object mask: channel 3 of the render (presumably alpha), resized to the image size.
        mask = render_img[:, :, 3]
        mask = cv2.resize(mask, (img.shape[2], img.shape[1]), interpolation=cv2.INTER_NEAREST)
        mask[mask > 0] = 1

        distance = angle['dist']
        elevation = np.pi / 2 - angle['phi']
        azimuth = angle['theta'] + np.pi / 2
        theta = angle['camera_rotation']

        img = img / 255.0
        sample['img'] = np.ascontiguousarray(img).astype(np.float32)
        # NOTE(review): unlike the test split, 'img_ori' here is the scaled CHW image, not the raw HWC one.
        sample['img_ori'] = np.ascontiguousarray(img).astype(np.float32)
        sample['obj_mask'] = np.ascontiguousarray(mask).astype(np.float32)

        sample['distance'] = distance
        sample['elevation'] = elevation
        sample['azimuth'] = azimuth
        sample['theta'] = theta
        sample['instance_id'] = instance_id
        sample['this_name'] = item
        return sample

    def __len__(self):
        if self.data_type == 'test':
            return len(self.images)
        else:
            return len(self.img_fns)


class Normalize:
    """Callable transform applying ImageNet mean/std normalization to ``sample["img"]``.

    The sample dict is updated in place and also returned.
    """

    def __init__(self):
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.trans = torchvision.transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

    def __call__(self, sample):
        normalized = self.trans(sample["img"])
        sample["img"] = normalized
        return sample
