import os
import os.path as osp
import glob
import json
import pickle

import torch
import trimesh
import numpy as np
from tqdm import tqdm
from omegaconf import OmegaConf
from loguru import logger
from sklearn.decomposition import PCA


class ThumanDataSet(torch.utils.data.Dataset):
    '''
    Map-style dataset over the shapes listed in a JSON split file.

    Every shape is identified by a "fid" string of the form
    ``"<class_name>-<instance_name>-<shape>"``. What an item contains depends
    on ``rep``:

        * ``'sdf'``: two ``.npy`` files per shape are read from
          ``config.sdf_dir`` and ``config.num_samples`` rows are drawn from each.
        * ``'mesh'``: per-shape entries are read from the attributes
          ``verts_init_nml``, ``verts_raw_list`` and ``faces_raw_list``. This
          class never assigns those attributes; they have to be set on the
          instance by other code before items are requested.
        * anything else: the item only carries ``'idx'``.
    '''

    # Values of ``rep`` for which ``__getitem__`` adds data besides ``'idx'``.
    _KNOWN_REPS = ('mesh', 'sdf')

    def __init__(self,
                 mode,
                 rep,
                 config,
                 **kwargs):
        '''
        Args:
            mode: 'train' or 'test'; selects the entry of ``config.split_cfg``
                that names the split file.
            rep: representation served by ``__getitem__`` ('mesh' or 'sdf').
            config: configuration object read through attribute access,
                ``config.split_cfg[split]`` and ``config.get``. Fields read:
                    data_dir: stored as ``self.data_dir``
                    sdf_dir: raw sdf dir
                    raw_mesh_dir: raw mesh dir, might not have consistent topology
                    registration_dir: registered mesh dir, must have consistent topology
                    num_samples: num of samples used to train sdfnet
                    dataset_name: sub-directory of ``splits/`` holding the split file
                    split_cfg: mapping from split name to split file name
                    raw_mesh_file_type: optional, defaults to 'ply'
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'test'.
            AssertionError: if the split file does not describe exactly one
                dataset (see ``get_fid_list``).
        '''
        super().__init__()

        self.rep = rep
        self.config = config
        self.mode = mode
        if self.mode not in ('train', 'test'):
            raise ValueError('invalid mode')
        # The split is named exactly like the mode.
        split = self.mode

        if self.rep not in self._KNOWN_REPS:
            # Kept non-fatal for backward compatibility, but made visible:
            # otherwise the problem only shows up later as missing keys.
            logger.warning(
                f"rep = {self.rep!r} is not one of {self._KNOWN_REPS}; "
                f"{type(self).__name__}.__getitem__ will only fill in 'idx'")

        self.data_dir = config.data_dir
        self.sdf_dir = config.sdf_dir
        self.raw_mesh_dir = config.raw_mesh_dir
        self.registration_dir = config.registration_dir
        self.num_samples = config.num_samples
        # self.template_path = config.template_path
        self.dataset_name = config.dataset_name

        # load data split (resolved relative to the directory of this file)
        split_cfg_fname = config.split_cfg[split]
        current_dir = os.path.dirname(os.path.realpath(__file__))
        split_path = f"{current_dir}/splits/{self.dataset_name}/{split_cfg_fname}"
        with open(split_path, "r", encoding="utf-8") as f:
            split_names = json.load(f)

        self.fid_list = self.get_fid_list(split_names)
        self.num_data = len(self.fid_list)

        self.raw_mesh_file_type = config.get('raw_mesh_file_type', 'ply')
        logger.info(f"dataset mode = {mode}, split = {split}, len = {self.num_data}\n")

    def get_fid_list(self, split_names):
        '''
        Flatten a split description into a list of fid strings.

        Args:
            split_names: nested mapping
                ``{dataset: {class_name: {instance_name: [shape, ...]}}}``
                with exactly one ``dataset`` key.

        Returns:
            List of ``"<class_name>-<instance_name>-<shape>"`` strings in the
            iteration order of the nested containers.

        Raises:
            AssertionError: if ``split_names`` does not have exactly one
                top-level entry.
        '''
        # Raised explicitly (instead of an ``assert`` statement) so the check
        # is still performed under ``python -O``; the exception type is the
        # one callers have always seen.
        if len(split_names) != 1:
            raise AssertionError(
                "split file must describe exactly one dataset, "
                f"got {len(split_names)}")

        fid_list = []
        for classes in split_names.values():
            for class_name, instances in classes.items():
                for instance_name, shapes in instances.items():
                    for shape in shapes:
                        fid_list.append(f"{class_name}-{instance_name}-{shape}")
        return fid_list

    def _require_attrs(self, *names):
        '''Raise an AttributeError listing every attribute in ``names`` that is unset.'''
        missing = [name for name in names if not hasattr(self, name)]
        if missing:
            raise AttributeError(
                f"{type(self).__name__} has no attribute(s) {', '.join(missing)}; "
                "this class does not load them itself, they must be assigned "
                "on the instance before use")

    def _sample_rows(self, rows):
        '''
        Draw ``self.num_samples`` rows from the 2-D tensor ``rows``.

        Indices are drawn independently, so rows can repeat and the result
        always has ``self.num_samples`` rows.
        '''
        picks = (torch.rand(self.num_samples) * rows.shape[0]).long()
        return torch.index_select(rows, 0, picks)

    def update_pca_sv(self, train_pca_axes, train_pca_sv_mean, train_pca_sv_std):
        '''
        Project ``self.verts_init_nml`` onto PCA axes and standardise the result.

        ``self.verts_init_nml`` is flattened to ``(self.num_data, -1)``,
        multiplied by ``train_pca_axes`` transposed, then shifted by
        ``train_pca_sv_mean`` and divided by ``train_pca_sv_std``. The result is
        stored in ``self.pca_sv``; nothing is returned.

        Args:
            train_pca_axes: array whose transpose can be right-multiplied onto
                the flattened vertices (presumably (n_components, n_features),
                to be confirmed against the caller).
            train_pca_sv_mean: value subtracted from the projection.
            train_pca_sv_std: value the centred projection is divided by.

        Raises:
            AttributeError: if ``verts_init_nml`` has not been assigned.
        '''
        self._require_attrs('verts_init_nml')
        flat_verts = self.verts_init_nml.reshape(self.num_data, -1)
        pca_sv = np.matmul(flat_verts, train_pca_axes.transpose())
        self.pca_sv = (pca_sv - train_pca_sv_mean) / train_pca_sv_std

    def __len__(self):
        '''Return the number of fids found in the split file.'''
        return self.num_data

    def __getitem__(self, idx):
        '''
        Build the sample for position ``idx``.

        Returns:
            dict that always contains ``'idx'`` and additionally

            * for ``rep == 'mesh'``: ``'verts_init_nml'``, ``'verts_raw'`` and
              ``'faces_raw'``, taken from the same-purpose instance attributes
              at ``idx``;
            * for ``rep == 'sdf'``: float32 tensors ``'points_mnfld'``
              (first three columns of ``<fid path>.npy``), ``'normals_mnfld'``
              (the remaining columns of the same rows) and
              ``'samples_nonmnfld'`` (rows of ``<fid path>_dist_triangle.npy``),
              each with ``self.num_samples`` rows.

        Raises:
            IndexError: if ``idx`` is outside ``self.fid_list``.
            AttributeError: for ``rep == 'mesh'`` when the mesh attributes have
                not been assigned.
        '''
        data_dict = {'idx': idx}
        fid = self.fid_list[idx]

        if self.rep in ['mesh']:
            # no sdf, only load mesh. TODO: verts num diff, need to use PyG dataloader
            self._require_attrs('verts_init_nml', 'verts_raw_list', 'faces_raw_list')
            data_dict['verts_init_nml'] = self.verts_init_nml[idx]
            data_dict['verts_raw'] = self.verts_raw_list[idx]
            data_dict['faces_raw'] = self.faces_raw_list[idx]

        elif self.rep in ['sdf']:
            # "<class>-<instance>-<shape>" -> "<class>/<instance>/<shape>"
            fname = fid.replace('-', '/')

            # Both files are loaded before any random number is drawn, so a
            # missing file never consumes RNG state.
            point_set_mnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}.npy")).float()
            samples_nonmnfld = torch.from_numpy(np.load(f"{self.sdf_dir}/{fname}_dist_triangle.npy")).float()

            # Points and normals come from the same sampled rows.
            point_set_mnfld = self._sample_rows(point_set_mnfld)
            samples_nonmnfld = self._sample_rows(samples_nonmnfld)

            # currently all center == [0, 0, 0], scale == 1
            data_dict['points_mnfld'] = point_set_mnfld[:, :3]
            data_dict['normals_mnfld'] = point_set_mnfld[:, 3:]
            data_dict['samples_nonmnfld'] = samples_nonmnfld

        return data_dict


