import warnings
from torch.utils.data import Dataset
from tqdm import tqdm
from pathlib import Path
import open3d as o3d
import os
import numpy as np

import hashlib
import torch
import matplotlib.pyplot as plt

# ShapeNetCore synset id -> human-readable category label.
# Used to validate/translate the `categories` argument of the datasets below and to
# build on-disk paths (`<root>/<synset>/...`).
synset_to_label = {
    '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket',
    '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench',
    '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus',
    '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera',
    '02954340': 'cap', '02958343': 'car', '03001627': 'chair',
    '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor',
    '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can',
    '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard',
    '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file',
    '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar',
    '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop',
    '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone',
    '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug',
    '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol',
    '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control',
    '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard',
    '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel',
    '04554684': 'washer', '02992529': 'cellphone',
    '02843684': 'birdhouse', '02871439': 'bookshelf',
    # '02858304': 'boat', no boat in our dataset, merged into vessels
    # '02834778': 'bicycle', not in our taxonomy
}

# Inverse mapping: category label -> synset id (for ShapeNetCore classes).
# Every label above is unique, so this inversion loses no entries.
label_to_synset = {v: k for k, v in synset_to_label.items()}

def _convert_categories(categories):
    """Translate ShapeNetCore category labels to synset ids.

    Each entry may be a label (e.g. ``'chair'``) or a synset id
    (e.g. ``'03001627'``). Labels found in ``label_to_synset`` are translated;
    every other entry is passed through unchanged.

    Args:
        categories: iterable of labels and/or synset ids.

    Returns:
        list of synset id strings, in the same order as ``categories``.

    Raises:
        AssertionError: if ``categories`` is None.

    Warns:
        UserWarning: if an entry is neither a known label nor a known synset id.
    """
    assert categories is not None, 'List of categories cannot be empty!'
    # Materialise once so a one-shot iterable survives both passes below.
    categories = list(categories)
    unknown = [c for c in categories
               if c not in synset_to_label and c not in label_to_synset]
    if unknown:
        warnings.warn('Some or all of the categories requested are not part of '
                      'ShapeNetCore: {}. Data loading may fail if these categories '
                      'are not available.'.format(unknown))
    return [label_to_synset.get(c, c) for c in categories]


class ShapeNet_Multiview_Points(Dataset):
    """ShapeNet point clouds paired with multi-view, depth-derived partial point clouds.

    For every requested category, each ``<root_pc>/<synset>/<split>/<name>.npy``
    shape is loaded together with every ``*_cam_params.npz`` view found under
    ``<root_views>/<synset>/<name>/``. Each view's ``*_depth.png`` is
    back-projected into a single-view point cloud, and the result is cached as
    ``.npy`` under ``cache``. Shapes with fewer than ``_MIN_VIEWS`` views, or any
    view that fails to load, are skipped. All clouds are normalised with one
    dataset-wide mean/std (or the supplied stats).
    """

    # Points sub-sampled (without replacement) from every back-projected view.
    _SV_POINTS_PER_VIEW = 600
    # Shapes with fewer camera-parameter files than this are skipped.
    _MIN_VIEWS = 20

    def __init__(self, root_pc: str, root_views: str, cache: str, categories: list = None, split: str = 'val',
                 npoints=2048, sv_samples=800, all_points_mean=None, all_points_std=None, get_image=False):
        """
        Args:
            root_pc: root of the full point clouds (``<root_pc>/<synset>/<split>/<name>.npy``).
            root_views: root of the rendered views (``<root_views>/<synset>/<name>/``).
            cache: directory under which back-projected views are cached.
            categories: category labels or synset ids; defaults to ``['chair']``.
            split: split sub-directory of ``root_pc`` to load.
            npoints: points returned per sample; also the padded length of ``sv_points``.
            sv_samples: maximum number of points kept per view in ``sv_points``.
            all_points_mean, all_points_std: optional precomputed normalisation stats.
                Both must be given; otherwise they are computed from this dataset.
            get_image: if True and ``split != 'train'``, ``__getitem__`` also
                returns the per-view depth images.

        Raises:
            ValueError: if a category directory is missing under ``root_views``,
                or no shape could be loaded at all.
        """
        if categories is None:
            categories = ['chair']
        self.root = Path(root_views)
        self.split = split
        self.get_image = get_image
        params = {
            'cat': categories,
            'npoints': npoints,
            'sv_samples': sv_samples,
        }
        # The cache directory is keyed on a hash of the loading parameters.
        params = tuple(sorted(params.items()))
        self.cache_dir = Path(cache) / 'svpoints/{}/{}'.format(
            '_'.join(categories), hashlib.md5(bytes(repr(params), 'utf-8')).hexdigest())

        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.paths = []
        self.synset_idxs = []
        self.synsets = _convert_categories(categories)
        self.labels = [synset_to_label[s] for s in self.synsets]
        self.npoints = npoints
        self.sv_samples = sv_samples

        # The four lists below are index-aligned: entry k of each describes the same shape.
        self.all_points = []
        self.all_points_sv = []
        self.all_mids = []  # "<split>/<name>" of every shape that was kept
        self.imgs = []      # per-view depth-image paths of every shape that was kept

        for syn, label in zip(self.synsets, self.labels):
            class_target = self.root / syn
            if not class_target.exists():
                raise ValueError('Class {0} ({1}) was not found at location {2}.'.format(
                    syn, label, str(class_target)))

            sub_path_pc = os.path.join(root_pc, syn, split)
            if not os.path.isdir(sub_path_pc):
                print("Directory missing : %s" % sub_path_pc)
                continue

            # Sorted so dataset indices are reproducible across runs and filesystems.
            mids = [os.path.join(split, x[:-len('.npy')])
                    for x in sorted(os.listdir(sub_path_pc)) if x.endswith('.npy')]

            for mid in tqdm(mids):
                loaded = self._load_shape(root_pc, syn, mid)
                if loaded is None:
                    continue
                point_cloud, sv_points, img_paths = loaded
                self.all_points.append(point_cloud)
                self.all_points_sv.append(sv_points)
                self.all_mids.append(mid)
                self.imgs.append(img_paths)

        if not self.all_points:
            raise ValueError('No shapes with at least {} valid views were found for categories {} '
                             '(split {!r}).'.format(self._MIN_VIEWS, categories, split))

        self.all_points = np.stack(self.all_points, axis=0)
        self.all_points_sv = np.stack(self.all_points_sv, axis=0)
        if all_points_mean is not None and all_points_std is not None:  # using loaded dataset stats
            self.all_points_mean = all_points_mean
            self.all_points_std = all_points_std
        else:  # normalize across the dataset: per-axis mean, one scalar std
            self.all_points_mean = self.all_points.reshape(-1, 3).mean(axis=0).reshape(1, 1, 3)
            self.all_points_std = self.all_points.reshape(-1).std(axis=0).reshape(1, 1, 1)

        self.all_points = (self.all_points - self.all_points_mean) / self.all_points_std
        # The first 10000 points of each cloud are used for training, the rest for evaluation.
        self.train_points = self.all_points[:, :10000]
        self.test_points = self.all_points[:, 10000:]
        self.all_points_sv = (self.all_points_sv - self.all_points_mean) / self.all_points_std

    def _load_shape(self, root_pc, syn, mid):
        """Load one shape and all of its views.

        Returns:
            ``(point_cloud, sv_points, img_paths)``, where ``sv_points`` is the
            stack of per-view clouds. Returns ``None`` when the shape has fewer
            than ``_MIN_VIEWS`` views or any of its views fails to render.
        """
        name = os.path.basename(mid)
        # Sorted so the view order (and hence the sv_points/image order) is deterministic.
        cams_pths = sorted((self.root / syn / name).glob('*_cam_params.npz'))
        if len(cams_pths) < self._MIN_VIEWS:
            return None
        point_cloud = np.load(os.path.join(root_pc, syn, mid + ".npy"))

        shape_cache = self.cache_dir / name
        shape_cache.mkdir(parents=True, exist_ok=True)
        sv_points_group = []
        img_path_group = []
        for cp in cams_pths:
            cp = str(cp)
            vp = cp.split('cam_params')[0] + 'depth.png'
            depth_minmax_pth = cp.split('_cam_params')[0] + '.npy'
            cache_pth = str(shape_cache / os.path.basename(depth_minmax_pth))

            # NpzFile keeps the archive open; close it once the arrays are read.
            with np.load(cp) as cam_params:
                extr = cam_params['extr']
                intr = cam_params['intr']
            self.transform = DepthToSingleViewPoints(cam_ext=extr, cam_int=intr)

            try:
                sv_point_cloud = self._render(cache_pth, vp, depth_minmax_pth)
            except Exception as e:  # best-effort: one bad view drops the whole shape
                print('Skipping {}: {}'.format(mid, e))
                return None
            img_path_group.append(vp)
            sv_points_group.append(sv_point_cloud)
        return point_cloud, np.stack(sv_points_group, axis=0), img_path_group

    def get_pc_stats(self, idx):
        """Return the normalisation (mean, std), both reshaped to (1, 1, -1); ``idx`` is unused."""
        return self.all_points_mean.reshape(1, 1, -1), self.all_points_std.reshape(1, 1, -1)

    def __len__(self):
        """Returns the length of the dataset. """
        return len(self.all_points)

    def __getitem__(self, index):
        """Return one sample as a dict.

        Keys:
            ``train_points``: ``npoints`` points drawn with replacement from the training part.
            ``test_points``: the first ``npoints`` points of the held-out part.
            ``sv_points``: per view, the first ``sv_samples`` points followed by zero padding
                up to ``npoints``.
            ``masks``: 1 where ``sv_points`` holds real points, 0 on the padding.
            ``mean``/``std``: normalisation stats.
            ``idx``, ``name``: the index and ``"<split>/<name>"`` of the shape.
            ``image``: only when ``split != 'train'`` and ``get_image``.
        """
        tr_out = self.train_points[index]
        tr_idxs = np.random.choice(tr_out.shape[0], self.npoints)
        tr_out = tr_out[tr_idxs, :]

        gt_points = self.test_points[index][:self.npoints]

        m, s = self.get_pc_stats(index)

        sv_points = self.all_points_sv[index]

        idxs = np.arange(0, sv_points.shape[-2])[:self.sv_samples]

        data = torch.cat([torch.from_numpy(sv_points[:, idxs]).float(),
                          torch.zeros(sv_points.shape[0], self.npoints - idxs.shape[0], sv_points.shape[2])], dim=1)
        masks = torch.zeros_like(data)
        masks[:, :idxs.shape[0]] = 1

        res = {'train_points': torch.from_numpy(tr_out).float(),
               'test_points': torch.from_numpy(gt_points).float(),
               'sv_points': data,
               'masks': masks,
               'std': s, 'mean': m,
               'idx': index,
               'name': self.all_mids[index],
               }

        if self.split != 'train' and self.get_image:
            img_lst = []
            for n in range(self.all_points_sv.shape[1]):
                # Keep the first three channels, channels-first.
                img = torch.from_numpy(plt.imread(self.imgs[index][n])).float().permute(2, 0, 1)[:3]
                img_lst.append(img)
            res['image'] = torch.stack(img_lst, dim=0)

        return res

    def _render(self, cache_path, depth_pth, depth_minmax_pth):
        """Return the sub-sampled single-view points of one view, using the on-disk cache.

        On a cache miss, the view is back-projected with ``self.transform`` (which
        must already be set up for this view). Then ``_SV_POINTS_PER_VIEW`` points
        are drawn without replacement, and the result is saved to ``cache_path``.

        Raises:
            ValueError: if the view yields fewer than ``_SV_POINTS_PER_VIEW`` points.
        """
        if os.path.exists(cache_path):
            return np.load(cache_path)

        data, _ = self.transform(depth_pth, depth_minmax_pth)
        n_keep = self._SV_POINTS_PER_VIEW
        # Sampling n_keep points without replacement needs at least n_keep points.
        if data.shape[0] < n_keep:
            raise ValueError('Only {} points found'.format(data.shape[0]))
        data = data[np.random.choice(data.shape[0], n_keep, replace=False)]
        np.save(cache_path, data)
        return data



class DepthToSingleViewPoints(object):
    '''
    Back-project one rendered depth image into a single-view point cloud.

    Built from one view's camera parameters. Calling the instance with a depth
    PNG and its depth-range ``.npy`` returns the back-projected points (via
    Open3D) together with the decoded depth image.
    '''
    def __init__(self, cam_ext, cam_int):
        # Camera matrices arrive flattened; reshape to 4x4 extrinsic / 3x3 intrinsic.
        self.cam_ext = cam_ext.reshape(4, 4)
        self.cam_int = cam_int.reshape(3, 3)

    @staticmethod
    def _decode_depth(depth_img, depth_minmax):
        """Map a normalised, inverted depth image back to the [min, max] depth range.

        Pixels equal to 0 are treated as background and given a negative depth.
        NOTE(review): presumably this makes Open3D drop them, since it projects
        only positive depths. The PNG encoding (inverted, min/max-normalised) is
        inferred from this decoding; confirm against the renderer.
        """
        d_min = np.min(depth_minmax)
        d_max = np.max(depth_minmax)
        mask = np.where(depth_img == 0, -1.0, 1.0)
        return ((1 - depth_img) * (d_max - d_min) + d_min) * mask

    def _intrinsic_args(self, depth_img):
        """Positional args for ``o3d.camera.PinholeCameraIntrinsic``: (width, height, fx, fy, cx, cy)."""
        # numpy images are (rows, cols) == (height, width); Open3D wants width first.
        height, width = depth_img.shape[:2]
        return (width, height,
                self.cam_int[0, 0], self.cam_int[1, 1],
                self.cam_int[0, 2], self.cam_int[1, 2])

    def __call__(self, depth_pth, depth_minmax_pth):
        """Back-project one view.

        Args:
            depth_pth: path of the depth PNG; only its first channel is used.
            depth_minmax_pth: path of a ``.npy`` whose min and max give the depth
                range used to de-normalise the PNG.

        Returns:
            ``(pc, depth_img)``: the points from ``pcd.points`` as a numpy array,
            and the decoded depth image (background pixels negative).
        """
        depth_minmax = np.load(depth_minmax_pth)
        depth_img = self._decode_depth(plt.imread(depth_pth)[..., 0], depth_minmax)

        intr = o3d.camera.PinholeCameraIntrinsic(*self._intrinsic_args(depth_img))
        depth_im = o3d.geometry.Image(depth_img.astype(np.float32, copy=False))

        pcd = o3d.geometry.PointCloud.create_from_depth_image(depth_im, intr, self.cam_ext, depth_scale=1.)
        pc = np.asarray(pcd.points)

        return pc, depth_img

    def __repr__(self):
        return '{}(cam_ext={}, cam_int={})'.format(
            type(self).__name__, self.cam_ext.tolist(), self.cam_int.tolist())