import os
import sys
from PIL import Image
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision
# import point_cloud_utils as pcu
# from sklearn.neighbors import KDTree
from pytorch3d.io import load_obj

from corr.utils import get_abs_path
import os
import json
import trimesh
import pandas as pd


class MeshLoader():
    """Load the normalized meshes and correspondence index arrays for one category.

    Files read under ``dataset_config['root_path']``:
      * ``infos.json``: ``infos[cate]['chosen_instance']`` selects the reference instance.
      * ``CAD_models/<cate>/new_json_files/<id>.json`` (falling back to
        ``json_files``): per-part lists of vertex indices.
      * ``corr_recover_CAD_models/<cate>/<chosen>/``: one mesh per file stem
        (``.obj`` or ``.glb``).
      * ``corr_index/<cate>/<chosen>/<stem>/index.npy``: one index array per mesh.
    """

    def __init__(self, dataset_config, cate):
        cate_id = cate

        self.skip_list = []

        root_path = dataset_config['root_path']
        with open(os.path.join(root_path, 'infos.json')) as f:
            infos = json.load(f)
        chosen_id = infos[cate_id]['chosen_instance']
        cad_path = os.path.join(root_path, 'CAD_models', cate_id)

        # Prefer the chosen instance's file in 'new_json_files', else 'json_files'.
        annotation_path = os.path.join(cad_path, 'new_json_files', f'{chosen_id}.json')
        if not os.path.exists(annotation_path):
            annotation_path = os.path.join(cad_path, 'json_files', f'{chosen_id}.json')
        with open(annotation_path) as f:
            annotation = json.load(f)
        # A part counts as annotated only if it lists at least one vertex.
        self.anno_parts = [key for key, value in annotation.items() if len(value) > 0]
        print("anno_parts: ", self.anno_parts)
        index_path = os.path.join(root_path, 'corr_index', cate_id, chosen_id)
        # Annotation directory used by get_gt_parts; chosen by directory existence
        # (unlike annotation_path above, which is chosen by file existence).
        self.anno_path = os.path.join(cad_path, 'new_json_files')
        if not os.path.exists(self.anno_path):
            self.anno_path = os.path.join(cad_path, 'json_files')

        self.mesh_path = os.path.join(root_path, 'corr_recover_CAD_models', cate_id, chosen_id)
        # Mesh name (file stem) -> position in mesh_list / index_list.
        self.mesh_name_dict = dict()
        for file_name in os.listdir(self.mesh_path):
            name = file_name.split('.')[0]
            # Hidden files such as '.DS_Store' have an empty stem.
            if not name or name in self.skip_list:
                continue
            # Several files may share a stem (e.g. .obj + .mtl); setdefault keeps the
            # first position so positions stay 0..len-1 and match mesh_list order.
            self.mesh_name_dict.setdefault(name, len(self.mesh_name_dict))
        # The chosen (reference) instance is always included.
        self.mesh_name_dict.setdefault(chosen_id, len(self.mesh_name_dict))
        self.mesh_list = [self.get_meshes(name_) for name_ in self.mesh_name_dict]

        print("mesh_list len: ", len(self.mesh_list))

        # NOTE: allow_pickle=True unpickles the file; only use with trusted dataset files.
        self.index_list = [
            np.load(os.path.join(index_path, t, 'index.npy'), allow_pickle=True)[()]
            for t in self.mesh_name_dict
        ]

    def get_mesh_listed(self):
        """Return ``(verts_list, faces_list)`` of all loaded meshes as NumPy arrays."""
        return [t[0].numpy() for t in self.mesh_list], [t[1].numpy() for t in self.mesh_list]

    def get_index_list(self, indexs=None):
        """Stack the index arrays (all, or those at positions ``indexs``) into one tensor."""
        if indexs is not None:
            return torch.from_numpy(np.array([self.index_list[t] for t in indexs]))
        return torch.from_numpy(np.array(self.index_list))

    def get_meshes(self, instance_id):
        """Load ``<mesh_path>/<instance_id>.obj`` (preferred) or ``.glb``.

        Vertices are translated so the bounding-box centre is at the origin and
        divided by the largest bounding-box side length.

        Returns:
            ``(verts, faces)`` as torch tensors; for ``.glb`` these are float32 /
            int32 (the ``.obj`` dtypes come from pytorch3d's ``load_obj`` — presumably
            float32 / int64, confirm). ``(None, None)`` if neither file exists.
        """
        obj_fn = os.path.join(self.mesh_path, f'{instance_id}.obj')
        glb_fn = os.path.join(self.mesh_path, f'{instance_id}.glb')
        if os.path.exists(obj_fn):
            verts, faces_idx, _ = load_obj(obj_fn)
            faces = faces_idx.verts_idx

            # normalize
            vert_middle = (verts.max(dim=0)[0] + verts.min(dim=0)[0]) / 2
            vert_scale = (verts.max(dim=0)[0] - verts.min(dim=0)[0]).max()
            verts = verts - vert_middle
            verts = verts / vert_scale
        elif os.path.exists(glb_fn):
            mesh = trimesh.load_mesh(glb_fn)
            if isinstance(mesh, trimesh.Scene):
                # Convert the scene to a single mesh
                mesh = trimesh.util.concatenate(mesh.dump())

            # Extract vertices and faces
            verts = mesh.vertices
            faces = mesh.faces

            vert_middle = (verts.max(axis=0) + verts.min(axis=0)) / 2
            vert_scale = (verts.max(axis=0) - verts.min(axis=0)).max()
            verts = verts - vert_middle
            verts = verts / vert_scale
            verts = torch.from_numpy(verts.astype(np.float32))
            faces = torch.from_numpy(faces.astype(np.int32))
        else:
            print('no mesh!')
            verts = None
            faces = None

        return verts, faces

    def get_gt_parts(self, instance_id, part_names):
        """Split mesh ``instance_id`` into the annotated parts ``part_names``.

        For each part name:
          * vertices: ``verts[part_vert_index]``, or ``None`` when the part is
            missing from the annotation file or lists no vertices;
          * faces: faces whose every vertex is in the part, re-indexed to
            positions within the part's vertex list, as a torch tensor. Missing
            or empty parts yield an empty ``(0, face_size)`` tensor.

        Returns:
            ``(part_verts, part_faces)`` lists aligned with ``part_names``.
        """
        verts, faces = self.get_meshes(instance_id)
        faces = faces.numpy()
        with open(os.path.join(self.anno_path, f'{instance_id}.json')) as f:
            annotations = json.load(f)
        face_size = faces.shape[1]
        part_verts = []
        part_faces = []
        for part_name in part_names:
            if part_name not in annotations:
                part_verts.append(None)
                # Same empty tensor an annotated-but-empty part produces below.
                part_faces.append(torch.from_numpy(np.empty((0, face_size), dtype=faces.dtype)))
                continue
            part_vert_index = annotations[part_name]
            part_verts.append(verts[part_vert_index] if len(part_vert_index) > 0 else None)

            # Keep only faces whose vertices all belong to the part.
            part_face = faces[np.isin(faces, part_vert_index).all(axis=1)]

            # Re-index mesh vertex ids to their position within part_vert_index.
            vertex_to_new_index = {v: i for i, v in enumerate(part_vert_index)}
            remapped = np.array(
                [[vertex_to_new_index[v] for v in face] for face in part_face],
                dtype=faces.dtype,
            ).reshape(-1, face_size)
            part_faces.append(torch.from_numpy(remapped))

        return part_verts, part_faces


class NewDSTPart(Dataset):
    """Dataset pairing generated part images with renders and pose annotations.

    For every used instance of ``category`` the i-th files (sorted by the integer
    in their first three characters) of these directories form one sample:
      * ``image/<category>/<instance>/new_dst_part_gpt4.1`` (generated image)
      * ``render_img/<category>/<instance>/annotation`` (``.npy`` pose dict)
      * ``render_img/<category>/<instance>/image_render`` (render; its alpha
        channel gives the object mask)

    ``data_type == 'train'`` uses the instances in ``aligned_CAD_models_new``,
    excluding annotated ones (see ``__init__``). Any other value uses the
    annotated instances, at most 10 frames each.
    """

    def __init__(self, data_type, category, root_path, **kwargs):
        super().__init__()
        self.data_type = data_type
        self.category = category

        root_path = get_abs_path(root_path)
        # Instances excluded per category (presumably unusable assets — confirm).
        self.skip_dict = {
            "n02749479": [
                "67f67bc8e340fb23cdbadd1af48b5cd6",
                "9a2d5bc2029c82cc1b7837f7a64e3031"
            ],
            "n02981792": [
                "9f19a29d84234e1aa6022fd16e0a752d",
                "e64b636f26a34668a49316d842e9f084"
            ],
            "n03063689": [
                "0bb17661ab004906be1ab452c6b2ade6",
                "211c6d30db8c4f5682837c2e514b5b39"
            ],
            "n03100240": [
                "edce9973ca8b4862896179439803e3e1",
                "2495432c72324ace8f1be1a2e29f52cb"
            ],
            "n03344393": [],
            "n03445924": [
                "c47ccbe5107d4d8f92b3caaf93f3d9cc"
            ],
            "n03481172": [
                "f5a132de5e4f4458935c528044dc8c52"
            ],
            "n03498962": [],
            "n03594945": [
                "a8c75ce1d4704e55bfecd1e81c60a373"
            ],
            "n03770679": [],
            "n03785016": [
                "518a7d5a4a0b41218650ee0823320f5a",
                "de33815904e241148f62f16ddf8bf931",
                "a8f0d74334034ca1bcf397f21c26a4c3"
            ],
            "n03947888": [
                "6b32fb0dac4c4e79a2a09a93559302e8"
            ],
            "n04037443": [],
            "n04065272": [],
            "n04147183": [
                "a3109a4b09953b5e2b141dc6bd7c4bce",
                "f791188138166b4d9db44e8b39ef337"
            ],
            "n04204347": [],
            "n04266014": [
                "42af3119844a4fb685dcddffedee875a"
            ],
            "n04285008": [],
            "n04465501": [],
            "n04509417": [],
            "n04612504": [],
            "n03272562": [
                "06bf33c545f14898b41c220e88364430",
                "36f89cddb615444491df3e9e419bb9a2"
            ],
            "n03345487": [
                "2de6c417331a4c3890d7a59ae1097f83"
            ],
            "n03417042": [
                "8580a428074449c4b0bf3d402d15d183"
            ],
            "n03496892": [
                "eaf89f5bc8e94f17a82d7dd1afe003f4"
            ],
            "n03599486": [
                "ce5732fe86cc43939ea05dc6f893cd72"
            ],
            "n03649909": [
                "ce58c158034e4ccf8bf9d385427e4b4a",
                "e187c4568ee5433ebab47a9d5bf53ca7"
            ],
            "n03769881": [
                "6c9f30417f2747b895b03d3ed2f212b3",
                "7b7a72708b614d95af15b8ad9235670e"
            ],
            "n04252225": [
                "853a2472d0054c44afa5778fc7d2856f",
                "9541aaf266a14389ae41f7c123cad403",
                "bdfe648359894f5eb8b8967c554c75a9"
            ]
        }
        # Copy so extending skip_list below never mutates skip_dict.
        self.skip_list = list(self.skip_dict.get(self.category, []))

        self.img_path = os.path.join(root_path, 'image', self.category)
        self.mesh_path = os.path.join(root_path, 'aligned_CAD_models_new', self.category)
        self.render_img_path = os.path.join(root_path, 'render_img', self.category)
        self.angle_path = os.path.join(root_path, 'render_img', self.category)

        annotated_mesh_path = os.path.join(root_path, 'CAD_models', category, 'new_json_files')
        if not os.path.exists(annotated_mesh_path):
            annotated_mesh_path = os.path.join(root_path, 'CAD_models', category, 'json_files')
        # Sorted: os.listdir order is arbitrary, and the split below slices this list.
        self.gt_instances = sorted(x.split('.')[0] for x in os.listdir(annotated_mesh_path) if '.' in x)
        print('gt_instances: ', self.gt_instances)

        if self.data_type == 'train':
            # Unique stems: one instance may be stored as several files.
            self.instance_list = sorted({x.split('.')[0] for x in os.listdir(self.mesh_path)})
            if len(self.instance_list) - len(self.gt_instances) < 6:
                # Fewer than 6 non-annotated instances: train on all but the first
                # third of the annotated ones.
                # NOTE(review): the test split still iterates every annotated
                # instance, so train and test overlap in this case — confirm intended.
                self.skip_list += self.gt_instances[:len(self.gt_instances) // 3]
            else:
                self.skip_list += self.gt_instances
            self.image_per_instance = -1
        else:
            self.instance_list = self.gt_instances
            self.image_per_instance = 10

        self.img_fns = []
        self.angle_fns = []
        self.render_img_fns = []
        skipped = set(self.skip_list)
        for instance_id in self.instance_list:
            if instance_id in skipped:
                continue
            self.img_fns += self._list_frames(
                os.path.join(self.img_path, instance_id, 'new_dst_part_gpt4.1'))
            self.angle_fns += self._list_frames(
                os.path.join(self.angle_path, instance_id, 'annotation'))
            self.render_img_fns += self._list_frames(
                os.path.join(self.render_img_path, instance_id, 'image_render'))

        print('length: ', len(self.img_fns), len(self.angle_fns), len(self.render_img_fns))

        assert len(self.img_fns) == len(self.angle_fns) == len(self.render_img_fns), \
            f'{len(self.img_fns)}, {len(self.angle_fns)}, {len(self.render_img_fns)}'

    def _list_frames(self, directory):
        """Paths of ``directory``'s files, sorted by their leading 3-digit number,
        truncated to ``image_per_instance`` unless it is -1."""
        names = sorted(os.listdir(directory), key=lambda name: int(name[:3]))
        if self.image_per_instance != -1:
            names = names[:self.image_per_instance]
        return [os.path.join(directory, name) for name in names]

    def __getitem__(self, item):
        """Load sample ``item``.

        Returns a dict with:
          * ``img`` / ``img_ori``: float32 ``(3, H, W)`` RGB image divided by 255,
            resized to the render's ``H x W``;
          * ``obj_mask``: float32 ``(H, W)``, 1 where the render's 4th (alpha)
            channel is non-zero, else 0;
          * ``distance``, ``elevation``, ``azimuth``, ``theta``: derived from the
            annotation's ``dist``, ``phi``, ``theta`` and ``camera_rotation``;
          * ``instance_id``, ``this_name`` (= ``item``), and ``name``
            (``<instance_id>_<image file stem>``).

        Raises:
            FileNotFoundError: if the image or the render cannot be read.
        """
        img_fn = self.img_fns[item]
        render_fn = self.render_img_fns[item]

        ori_img = cv2.imread(img_fn, cv2.IMREAD_UNCHANGED)
        if ori_img is None:
            raise FileNotFoundError(f'could not read image: {img_fn}')
        ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        render_img = cv2.imread(render_fn, cv2.IMREAD_UNCHANGED)
        if render_img is None:
            raise FileNotFoundError(f'could not read image: {render_fn}')
        angle = np.load(self.angle_fns[item], allow_pickle=True)[()]

        # Layout: <img_path>/<instance_id>/new_dst_part_gpt4.1/<file>
        instance_id = os.path.basename(os.path.dirname(os.path.dirname(img_fn)))

        height, width = render_img.shape[:2]
        # Bring the image to the render's resolution so it aligns with the mask.
        ori_img = cv2.resize(ori_img, (width, height), interpolation=cv2.INTER_LINEAR)
        img = ori_img.transpose(2, 0, 1) / 255.0
        mask = render_img[:, :, 3] > 0

        distance = angle['dist']
        elevation = np.pi / 2 - angle['phi']
        azimuth = angle['theta'] + np.pi / 2
        theta = angle['camera_rotation']

        sample = dict()
        sample['img'] = np.ascontiguousarray(img).astype(np.float32)
        sample['img_ori'] = np.ascontiguousarray(img).astype(np.float32)
        sample['obj_mask'] = np.ascontiguousarray(mask).astype(np.float32)

        sample['distance'] = distance
        sample['elevation'] = elevation
        sample['azimuth'] = azimuth
        sample['theta'] = theta
        sample['instance_id'] = instance_id
        sample['this_name'] = item
        sample['name'] = instance_id + '_' + os.path.basename(img_fn).split('.')[0]

        return sample

    def __len__(self):
        return len(self.img_fns)


class Normalize:
    """Apply ImageNet mean/std normalization to ``sample["img"]``.

    ``sample["img"]`` may be a torch tensor or a NumPy array. NumPy input is
    converted with ``torch.as_tensor`` first, because torchvision's ``Normalize``
    only accepts tensors. The sample dict is modified in place and returned.
    """

    def __init__(self):
        self.trans = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    def __call__(self, sample):
        img = sample["img"]
        if isinstance(img, np.ndarray):
            img = torch.as_tensor(img)
        sample["img"] = self.trans(img)
        return sample
