import os
import sys
from PIL import Image
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
import point_cloud_utils as pcu
from sklearn.neighbors import KDTree
from pytorch3d.io import load_obj

from corr.utils import get_abs_path
import os


class MeshLoader():
    """Load every reconstructed mesh of one category together with its index data.

    Meshes are read from ``<root_path>/mesh/<cate>/<id>_recon_mesh.ply`` and
    normalised by :meth:`get_meshes`. Per-instance index data is read from
    ``<root_path>/index/<cate>/<chosen_id>/<id>/index.npy``, where ``chosen_id``
    is a fixed instance id hard-coded per category.

    Attributes:
        anno_parts: annotated part names for the category.
        mesh_path: directory holding the reconstructed ``*_recon_mesh.ply`` files.
        mesh_name_dict: instance id -> position in ``mesh_list`` / ``index_list``.
        mesh_list: ``(verts, faces)`` tensor pairs, one per instance.
        index_list: objects loaded from each instance's ``index.npy``.
    """

    def __init__(self, dataset_config, cate='car', type=None):
        # ``type`` is accepted for call-site compatibility but is not used.
        skip_list = ['511962626501e4abf500cc506a763c18']
        if cate == 'car':
            chosen_id = '372ceb40210589f8f500cc506a763c18'
            self.anno_parts = ['body', 'front_right_wheel', 'back_right_wheel', 'back_left_wheel',
                               'front_left_wheel', 'left_mirror', 'right_mirror',
                               'left_door', 'right_door']
        elif cate == 'aeroplane':
            chosen_id = '3cb63efff711cfc035fc197bbabcd5bd'
            self.anno_parts = ['body', 'left_wing', 'right_wing', 'left_engine', 'right_engine', 'tail']
        elif cate == 'bicycle':
            chosen_id = '91k7HKqdM9'
            self.anno_parts = ["body", "front_wheel", "back_wheel", "saddle"]
        else:
            raise NotImplementedError

        # Index files live under a directory named after the category's chosen instance.
        index_path = os.path.join(dataset_config['root_path'], 'index', cate, chosen_id)

        self.mesh_path = os.path.join(dataset_config['root_path'], 'mesh', cate)
        # NOTE: ids are numbered in os.listdir order, which is filesystem-dependent.
        self.mesh_name_dict = self._index_mesh_names(os.listdir(self.mesh_path), skip_list, chosen_id)
        self.mesh_list = [self.get_meshes(name_) for name_ in self.mesh_name_dict.keys()]

        # allow_pickle=True: index files are pickled objects; only load trusted dataset files.
        self.index_list = [np.load(os.path.join(index_path, t, 'index.npy'), allow_pickle=True)[()] for t in
                           self.mesh_name_dict.keys()]

    @staticmethod
    def _index_mesh_names(file_names, skip_list, chosen_id):
        """Map each instance id (file-name prefix before the first '_') to a dense index.

        Ids in ``skip_list`` are ignored. An id that appears in several file names
        keeps the index of its first occurrence, so indices stay unique and
        contiguous. ``chosen_id`` is appended at the end if it was not seen.
        """
        name_dict = dict()
        for file_name in file_names:
            name = file_name.split('_')[0]
            if name in skip_list or name in name_dict:
                continue
            name_dict[name] = len(name_dict)
        if chosen_id not in name_dict:
            name_dict[chosen_id] = len(name_dict)
        return name_dict

    def get_mesh_listed(self):
        """Return ``(verts_list, faces_list)`` as lists of numpy arrays, one entry per mesh."""
        return [t[0].numpy() for t in self.mesh_list], [t[1].numpy() for t in self.mesh_list]

    def get_index_list(self, indexs=None):
        """Stack the loaded index data (optionally only the positions in ``indexs``) into a tensor."""
        if indexs is not None:
            return torch.from_numpy(np.array([self.index_list[t] for t in indexs]))
        return torch.from_numpy(np.array(self.index_list))

    def get_meshes(self, instance_id):
        """Load ``<instance_id>_recon_mesh.ply`` and normalise it.

        The bounding-box centre is moved to the origin and the vertices are divided
        by the bounding-box diagonal length.

        Returns:
            ``(verts, faces)``: a float32 vertex tensor and an int32 face tensor.
        """
        verts, faces, _, _ = pcu.load_mesh_vfnc(os.path.join(self.mesh_path, f'{instance_id}_recon_mesh.ply'))

        faces = torch.from_numpy(faces.astype(np.int32))

        vert_middle = (verts.max(axis=0) + verts.min(axis=0)) / 2
        vert_scale = ((verts.max(axis=0) - verts.min(axis=0)) ** 2).sum() ** 0.5
        verts = verts - vert_middle
        verts = verts / vert_scale
        verts = torch.from_numpy(verts.astype(np.float32))

        return verts, faces


class PartsLoader():
    """Load annotated part meshes per instance, each re-centred on its own bounding box.

    For every id in ``chosen_ids`` the part meshes in
    ``<root_path>/<cate>/<chosen_id>_uda`` are moved into the frame of that
    instance's full mesh. That frame is defined by the mesh's bounding-box centre
    and, for bicycles, its bounding-box diagonal as scale. Each part is then
    re-centred on its own bounding-box centre. The removed centre is stored as
    that part's offset.

    Attributes:
        parts_meshes: instance id -> list of ``(verts, faces)`` tensor pairs aligned with ``part_names``.
        parts_offsets: instance id -> list of per-part offsets aligned with ``part_names``.
        part_names: part vocabulary taken from the first instance, or None before any is loaded.
    """

    def __init__(self, dataset_config, cate='car', chosen_ids=None):
        if cate == 'car':
            cate_id = '02958343'
        elif cate == 'aeroplane':
            cate_id = '02691156'
        elif cate == 'bicycle':
            cate_id = '02834778'
        else:
            raise NotImplementedError
        if chosen_ids is None:
            # ``None`` is the advertised default: treat it as "no instances" instead of
            # failing with "'NoneType' object is not iterable".
            chosen_ids = []

        self.parts_meshes = dict()
        self.parts_offsets = dict()
        self.part_names = None
        self.dataset_config = dataset_config
        self.cate = cate

        # Bicycles use reconstructed .ply meshes; the other categories use .obj files.
        use_ply = cate in ['bicycle']
        ext = '.ply' if use_ply else '.obj'

        for chosen_id in chosen_ids:
            vert_middle, vert_scale = self._reference_frame(cate, cate_id, chosen_id, use_ply)
            part_path = os.path.join(dataset_config['root_path'], cate, f'{chosen_id}_uda')

            part_meshes = []
            offsets = []
            if self.part_names is None:
                # The first instance defines the part vocabulary, in os.listdir order
                # (filesystem-dependent). Later instances are aligned to that vocabulary.
                self.part_names = []
                for name in os.listdir(part_path):
                    if not name.endswith(ext):
                        continue
                    part_verts, part_faces = self._read_part(os.path.join(part_path, name), use_ply)
                    part_verts, part_faces, offset = self._center_part(
                        part_verts, part_faces, vert_middle, vert_scale)
                    offsets.append(offset)
                    part_meshes.append((part_verts, part_faces))
                    # NOTE(review): only the text before the first '_' is kept, so a file named
                    # 'front_wheel_recon.ply' becomes 'front' and is later looked up as
                    # 'front_recon.ply'. Confirm part file names contain no '_' in the part name.
                    self.part_names.append(name.split('.')[0].split('_')[0])
            else:
                for name in self.part_names:
                    part_fn = os.path.join(part_path, f'{name}_recon{ext}')
                    if not os.path.exists(part_fn):
                        # Placeholder keeps list positions aligned with self.part_names.
                        # NOTE(review): this offset has shape (1, 3), while real offsets have shape (3,).
                        part_meshes.append((torch.zeros(1, 3), torch.zeros(1, 3)))
                        offsets.append(np.array([[0., 0., 0.]]))
                        continue
                    part_verts, part_faces = self._read_part(part_fn, use_ply)
                    part_verts, part_faces, offset = self._center_part(
                        part_verts, part_faces, vert_middle, vert_scale)
                    offsets.append(offset)
                    part_meshes.append((part_verts, part_faces))

            self.parts_meshes[chosen_id] = part_meshes
            self.parts_offsets[chosen_id] = offsets

    def _reference_frame(self, cate, cate_id, chosen_id, use_ply):
        """Return ``(vert_middle, vert_scale)`` of the instance's full mesh.

        ``vert_middle`` is the bounding-box centre as a tensor. ``vert_scale`` is the
        bounding-box diagonal for .ply meshes and ``1`` (no scaling) for .obj meshes.
        """
        root = self.dataset_config['root_path']
        if use_ply:
            recon_mesh_path = os.path.join(root, 'mesh', cate, f'{chosen_id}_recon_mesh.ply')
            chosen_verts, _, _, _ = pcu.load_mesh_vfnc(recon_mesh_path)
            vert_scale = ((chosen_verts.max(axis=0) - chosen_verts.min(axis=0)) ** 2).sum() ** 0.5
            chosen_verts = torch.from_numpy(chosen_verts.astype(np.float32))
        else:
            ori_mesh_path = os.path.join(root, 'ori_mesh', cate_id, chosen_id, 'models', 'model_normalized.obj')
            chosen_verts, _, _ = load_obj(ori_mesh_path)
            vert_scale = 1
        vert_middle = (chosen_verts.max(dim=0)[0] + chosen_verts.min(dim=0)[0]) / 2
        if chosen_id in ['1d63eb2b1f78aa88acf77e718d93f3e1', '3cb63efff711cfc035fc197bbabcd5bd']:
            # Ad-hoc y shift for these two instances (reason not recorded here).
            vert_middle[1] -= 0.08
        return vert_middle, vert_scale

    @staticmethod
    def _read_part(part_fn, use_ply):
        """Read one part mesh as ``(verts, faces)`` tensors (.ply via pcu, .obj via pytorch3d)."""
        if use_ply:
            part_verts, part_faces, _, _ = pcu.load_mesh_vfnc(part_fn)
            return (torch.from_numpy(part_verts.astype(np.float32)),
                    torch.from_numpy(part_faces.astype(np.int32)))
        part_verts, faces_idx, _ = load_obj(part_fn)
        return part_verts, faces_idx.verts_idx

    @staticmethod
    def _center_part(part_verts, part_faces, vert_middle, vert_scale):
        """Move a part into the instance frame, then centre it on its own bounding box.

        Returns ``(centred_verts, faces, offset)``, where ``offset`` is the removed
        bounding-box centre as a numpy array.
        """
        part_verts = part_verts - vert_middle
        part_verts = part_verts / vert_scale
        part_middle = (part_verts.max(dim=0)[0] + part_verts.min(dim=0)[0]) / 2
        return part_verts - part_middle, part_faces, np.array(part_middle)

    def get_name_listed(self):
        """Return the part vocabulary (None if no instance was loaded)."""
        return self.part_names

    def get_part_mesh(self, id=None, name=None):
        """Return the part meshes of instance ``id``.

        With ``name=None``, returns ``(verts_list, faces_list)`` as numpy arrays for
        all parts. Otherwise returns the ``(verts, faces)`` tensor pair of that part.
        Raises KeyError for an unknown id and ValueError for an unknown part name.
        """
        part_meshes = self.parts_meshes[id]
        if name is None:
            return [mesh[0].numpy() for mesh in part_meshes], [mesh[1].numpy() for mesh in part_meshes]
        return part_meshes[self.part_names.index(name)]

    def get_offset(self, id=None, name=None):
        """Return all offsets of instance ``id``, or only the offset of part ``name``."""
        offsets = self.parts_offsets[id]
        if name is None:
            return offsets
        return offsets[self.part_names.index(name)]


class UDAPart(Dataset):
    """UDA-Part dataset of real test images or synthetic renders.

    With ``data_type == 'test'``, the images, part segmentations and pose
    annotations under ``<root_path>/pascalUDApart`` are loaded into memory in
    ``__init__``. For any other ``data_type``, only file names are collected from
    ``<root_path>/{syn_image,render_img,angle}/<data_type>/<category>``. Those
    files are then read lazily in ``__getitem__``.
    """

    def __init__(self, data_type, category, root_path, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        super().__init__()
        self.data_type = data_type
        self.category = category

        if self.data_type == 'test':
            self._init_test_split(root_path)
        else:
            self._init_synthetic_split(root_path)

    @staticmethod
    def _imread(path):
        """Read ``path`` with ``cv2.IMREAD_UNCHANGED``; raise FileNotFoundError if it cannot be read."""
        image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread reports missing or undecodable files by returning None, not by raising.
            raise FileNotFoundError(f'could not read image: {path}')
        return image

    @staticmethod
    def _sorted_frame_files(base_path, instance_id):
        """List ``base_path/instance_id`` sorted by the integer in each file name's first 3 characters."""
        instance_dir = os.path.join(base_path, instance_id)
        names = sorted(os.listdir(instance_dir), key=lambda x: int(x[:3]))
        return [os.path.join(instance_dir, x) for x in names]

    def _init_test_split(self, root_path):
        """Eagerly load the real test images, segmentations, annotations and optional pose predictions."""
        # All paths are built from the resolved root, as in the synthetic split.
        self.root_path = get_abs_path(os.path.join(root_path, 'pascalUDApart'))

        img_path = os.path.join(self.root_path, 'images', self.category)
        seg_path = os.path.join(self.root_path, 'images', self.category + '_merge')
        anno_path = os.path.join(self.root_path, 'annotations', self.category)
        pose_path = os.path.join(self.root_path, 'images', f'{self.category}_pose')
        pose_fn = os.path.join(pose_path, 'pascal3d_occ0_%s_val.pth' % self.category)
        if os.path.exists(pose_fn):
            # Optional pose predictions. The second '/'-separated component of each key
            # is matched against image names in __getitem__.
            self.preds = torch.load(pose_fn)
            self.kkeys = {t.split('/')[1]: t for t in self.preds.keys()}
        else:
            self.preds = None
            self.kkeys = dict()

        self.images = []
        self.segs = []
        self.annos = []
        self.names = []
        for image_name in os.listdir(img_path):
            if 'seg' in image_name:
                continue
            image = cv2.cvtColor(self._imread(os.path.join(img_path, image_name)), cv2.COLOR_BGR2RGB)
            seg = self._imread(os.path.join(seg_path, image_name.replace('.JPEG', '_seg.png')))
            anno_fn = os.path.join(anno_path, image_name.replace('.JPEG', '.npz'))
            # Materialise the archive and close it. Keeping NpzFile objects alive
            # would hold one open file handle per image.
            with np.load(anno_fn) as anno_file:
                anno = dict(anno_file)
            self.images.append(image)
            self.segs.append(seg)
            self.annos.append(anno)
            self.names.append(image_name.split('.')[0])

    def _init_synthetic_split(self, root_path):
        """Collect aligned image / angle / render file lists for every synthetic instance."""
        self.root_path = get_abs_path(root_path)
        self.skip_list = ['17c32e15723ed6e0cd0bf4a0e76b8df5']

        self.img_path = os.path.join(self.root_path, 'syn_image', self.data_type, self.category)
        self.render_img_path = os.path.join(self.root_path, 'render_img', self.data_type, self.category)
        self.angle_path = os.path.join(self.root_path, 'angle', self.data_type, self.category)

        # Instance sub-directories are the entries whose names contain no '.'.
        self.instance_list = [x for x in os.listdir(self.img_path) if '.' not in x]

        self.img_fns = []
        self.angle_fns = []
        self.render_img_fns = []
        for instance_id in self.instance_list:
            if instance_id in self.skip_list:
                continue
            self.img_fns += self._sorted_frame_files(self.img_path, instance_id)
            self.angle_fns += self._sorted_frame_files(self.angle_path, instance_id)
            self.render_img_fns += self._sorted_frame_files(self.render_img_path, instance_id)

            # Checked per instance so a mismatch is caught at the instance that caused it.
            assert len(self.img_fns) == len(self.angle_fns) == len(self.render_img_fns), \
                f'{len(self.img_fns)}, {len(self.angle_fns)}, {len(self.render_img_fns)}'

    def __getitem__(self, item):
        if self.data_type == 'test':
            return self._get_test_item(item)
        return self._get_synthetic_item(item)

    def _get_test_item(self, item):
        """Build a sample from the in-memory test data (image in CHW layout, scaled to [0, 1])."""
        sample = dict()
        ori_img = self.images[item]
        img = ori_img / 255.
        seg = self.segs[item]
        anno = self.annos[item]
        name = self.names[item]
        pose_pred = 0.
        if name in self.kkeys:
            # Shallow copy so the adjustments below do not mutate the cached predictions;
            # otherwise each access to the same item would re-apply them.
            pose_pred = dict(self.preds[self.kkeys[name]]['final'][0])
            if self.category == 'bicycle':
                pose_pred['azimuth'] = pose_pred['azimuth'] + np.pi / 2
            pose_pred['elevation'] = np.pi - pose_pred['elevation']

        distance = float(anno['distance'])
        elevation = float(anno['elevation'])
        azimuth = float(anno['azimuth'])
        theta = float(anno['theta'])

        img = img.transpose(2, 0, 1)

        sample['img'] = np.ascontiguousarray(img).astype(np.float32)
        sample['img_ori'] = np.ascontiguousarray(ori_img).astype(np.float32)
        sample['seg'] = np.ascontiguousarray(seg).astype(np.float32)
        sample['distance'] = distance
        sample['elevation'] = elevation
        sample['azimuth'] = azimuth
        sample['theta'] = theta
        sample['pose_pred'] = pose_pred
        sample['name'] = name
        return sample

    def _get_synthetic_item(self, item):
        """Read one synthetic frame: RGB image, object mask from the render's 4th channel, and camera pose."""
        sample = dict()
        ori_img = cv2.cvtColor(self._imread(self.img_fns[item]), cv2.COLOR_BGR2RGB)
        render_img = self._imread(self.render_img_fns[item])
        # allow_pickle=True: angle files are pickled objects; only load trusted dataset files.
        angle = np.load(self.angle_fns[item], allow_pickle=True)[()]

        instance_id = os.path.basename(os.path.dirname(self.img_fns[item]))

        img = ori_img.transpose(2, 0, 1)
        mask = render_img[:, :, 3]
        # cv2.resize takes (width, height); nearest-neighbour keeps the mask binary-friendly.
        mask = cv2.resize(mask, (img.shape[2], img.shape[1]), interpolation=cv2.INTER_NEAREST)
        mask[mask > 0] = 1

        distance = angle['dist']
        elevation = np.pi / 2 - angle['phi']
        azimuth = angle['theta'] + np.pi / 2
        theta = angle['camera_rotation']

        img = img / 255.0
        sample['img'] = np.ascontiguousarray(img).astype(np.float32)
        # NOTE(review): unlike the test split, 'img_ori' here is the normalised CHW image,
        # not the raw HWC pixels. Confirm consumers expect this.
        sample['img_ori'] = np.ascontiguousarray(img).astype(np.float32)
        sample['obj_mask'] = np.ascontiguousarray(mask).astype(np.float32)

        sample['distance'] = distance
        sample['elevation'] = elevation
        sample['azimuth'] = azimuth
        sample['theta'] = theta
        sample['instance_id'] = instance_id
        sample['this_name'] = item
        return sample

    def __len__(self):
        if self.data_type == 'test':
            return len(self.images)
        else:
            return len(self.img_fns)