import glob
import math
import numpy as np
import os.path as osp
import pointgroup_ops
import scipy.interpolate as interpolate
import scipy.ndimage as ndimage
import torch
import torch_scatter
from torch.utils.data import Dataset
from typing import Dict, Sequence, Tuple, Union

from ..utils import Instances3D
import pickle as pkl
import os
from pathlib import Path
import yaml
import fpsample

class S3DISDataset(Dataset):
    """S3DIS point-cloud dataset for 3D instance segmentation.

    Scans are globbed from ``<data_root>/preprocess/<Area>*<suffix>`` and their
    superpoints are read from ``<data_root>/superpoints/<scan_id>.pth``.
    The ``'train'`` prefix loads Areas 1-4 and 6; any other prefix loads Area 5.
    """

    CLASSES = (
        "ceiling",
        "floor",
        "wall",
        "beam",
        "column",
        "window",
        "door",
        "chair",
        "table",
        "bookcase",
        "sofa",
        "board",
        "clutter",
    )
    BENCHMARK_SEMANTIC_IDXS = list(range(15))  # NOTE DUMMY values just for save results

    def __init__(self,
                 data_root,
                 prefix,
                 suffix,
                 voxel_cfg=None,
                 training=True,
                 with_label=True,
                 mode=4,
                 with_elastic=True,
                 use_xyz=True,
                 logger=None,
                 use_normalized=False,
                 exclude_zero_gt=False,
                 with_normals=False,
                 resample=False,
                 trainval=False,
                 num_classes=20,
                 stuff_class_ids=(0, 1),
                 sub_epoch_size=3000):
        """Store the configuration and collect the scan files of the split.

        Args:
            data_root (str): dataset root holding ``preprocess/`` and ``superpoints/``.
            prefix (str): split name; ``'train'`` selects the training areas.
                It is converted to a one-element list by ``get_filenames``.
            suffix (str): filename suffix of preprocessed scans (stripped to form the scan id).
            voxel_cfg: config object; ``scale``, ``max_npoint`` and ``spatial_shape`` are read from it.
            training (bool): use training subsampling/augmentation instead of the test path.
            mode (int): voxelization mode forwarded to ``pointgroup_ops.voxelization_idx``.
            logger: optional logger; when ``None`` nothing is logged.
            resample (bool): index through ``self.scan_ids`` with an epoch offset (see ``__getitem__``).
            trainval (bool): selects the epoch offset used by the ``resample`` path.
            stuff_class_ids (Sequence[int]): stored as a list; not used inside this class.
            sub_epoch_size (int): stored; not used inside this class.
        """
        self.data_root = data_root
        self.prefix = prefix
        self.suffix = suffix
        self.voxel_cfg = voxel_cfg
        self.training = training
        self.with_label = with_label
        self.mode = mode
        self.with_elastic = with_elastic
        self.use_xyz = use_xyz
        self.logger = logger
        self.filenames = self.get_filenames()
        if self.logger is not None:
            self.logger.info(f'Load {self.prefix} dataset: {len(self.filenames)} scans')
        self.use_normalized = use_normalized
        self.exclude_zero_gt = exclude_zero_gt
        self.with_normals = with_normals
        self.resample = resample
        self.trainval = trainval
        self.num_classes = num_classes
        # copy so the (immutable) default is never shared / mutated across instances
        self.stuff_class_ids = list(stuff_class_ids)
        self.sub_epoch_size = sub_epoch_size
        # bookkeeping for the ``resample`` path of ``__getitem__``
        self.epoch_idx = 0
        self.last_index = -1

    def get_filenames(self):
        """Return the sorted scan files of the current split.

        Side effect: a string ``self.prefix`` is wrapped into a one-element list.

        Raises:
            AssertionError: if an area has no matching file.
        """
        areas = ['Area_1', 'Area_2', 'Area_3', 'Area_4', 'Area_6'] if self.prefix == 'train' else ['Area_5']
        filenames_all = []
        for p in areas:
            filenames = glob.glob(osp.join(self.data_root, "preprocess", p + "*" + self.suffix))
            assert len(filenames) > 0, f"Empty {p}"
            filenames_all.extend(filenames)
        if isinstance(self.prefix, str):
            self.prefix = [self.prefix]

        return sorted(filenames_all)

    def load(self, filename):
        """Load one scan plus its superpoints and subsample its points.

        Subsampling:
            * training, more than 5,000,000 points: ``voxel_cfg.max_npoint`` random points;
            * training, otherwise: a random 25% of the points;
            * evaluation, more than 5,000,000 points: every 4th point;
            * evaluation, otherwise: all points in random order.

        Superpoint ids and instance ids are re-compacted after subsampling.

        Returns:
            tuple: ``(xyz, rgb, superpoint, semantic_label, instance_label, None)``; the last
            slot is the (absent) normal array.
        """
        scan_id = osp.basename(filename).replace(self.suffix, "")

        # NOTE: torch.load unpickles the file -- only use with trusted data.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which may reject pickled
        # numpy arrays; confirm against the installed torch version.
        xyz, rgb, semantic_label, instance_label = torch.load(filename)

        superpoint_filename = osp.join(self.data_root, "superpoints", scan_id + ".pth")
        superpoint = torch.load(superpoint_filename)

        N = xyz.shape[0]
        if self.training:
            if N > 5000000:  # NOTE Avoid OOM
                inds = np.random.choice(N, int(self.voxel_cfg.max_npoint), replace=False)
            else:
                inds = np.random.choice(N, int(N * 0.25), replace=False)
        elif N > 5000000:  # NOTE Avoid OOM
            inds = np.arange(N)[::4]
        else:
            inds = np.random.choice(N, int(N), replace=False)

        xyz = xyz[inds]
        rgb = rgb[inds]
        # re-number surviving superpoints to 0..S-1
        superpoint = np.unique(superpoint[inds], return_inverse=True)[1]
        semantic_label = semantic_label[inds]
        instance_label = self.get_cropped_inst_label(instance_label, inds)

        return xyz, rgb, superpoint, semantic_label, instance_label, None

    def getCroppedInstLabel(self, instance_label, valid_idxs):
        """Backward-compatible alias of ``get_cropped_inst_label``."""
        return self.get_cropped_inst_label(instance_label, valid_idxs)

    def __len__(self):
        return len(self.filenames)

    def transform_train(self, xyz, rgb, superpoint, semantic_label, instance_label, normal=None):
        """Training augmentation: jitter/flip/rotate, colour noise, scaling, elastic distortion.

        Returns:
            tuple: ``(xyz, xyz_middle, rgb, superpoint, semantic_label, instance_label, normal)``
            where ``xyz_middle`` is the augmented, unscaled cloud and ``xyz`` is its scaled
            (and optionally elastically distorted) copy shifted to a non-negative range.
        """
        xyz_middle, normal = self.data_aug(xyz, True, True, True, normal)
        rgb += np.random.randn(3) * 0.1  # one colour offset shared by all points
        xyz = xyz_middle * self.voxel_cfg.scale
        if self.with_elastic:
            xyz = self.elastic(xyz, 6, 40.)
            xyz = self.elastic(xyz, 20, 160.)
        xyz = xyz - xyz.min(0)

        return xyz, xyz_middle, rgb, superpoint, semantic_label, instance_label, normal

    def transform_test(self, xyz, rgb, superpoint, semantic_label=None, instance_label=None, normal=None):
        """Test-time transform: no augmentation, scaling and shift to a non-negative range.

        Returns:
            tuple: same layout as ``transform_train``.
        """
        xyz_middle, _ = self.data_aug(xyz, False, False, False, None)
        xyz = xyz_middle * self.voxel_cfg.scale
        xyz -= xyz.min(0)
        valid_idxs = np.ones(xyz.shape[0], dtype=bool)  # no cropping at test time
        superpoint = np.unique(superpoint[valid_idxs], return_inverse=True)[1]
        if instance_label is not None:
            instance_label = self.get_cropped_inst_label(instance_label, valid_idxs)

        return xyz, xyz_middle, rgb, superpoint, semantic_label, instance_label, normal

    def data_aug(self, xyz, jitter=False, flip=False, rot=False, normal=None):
        """Apply a random linear transform to the cloud (and normals, if given).

        Args:
            xyz (np.ndarray, [N, 3]): point coordinates.
            jitter (bool): add Gaussian noise (std 0.1) to the identity matrix.
            flip (bool): randomly negate the x axis.
            rot (bool): random rotation about the z axis.
            normal (np.ndarray | None): normals transformed with the same matrix.

        Returns:
            tuple: ``(xyz @ m, normal @ m or None)``.
        """
        m = np.eye(3)
        if jitter:
            m += np.random.randn(3, 3) * 0.1
        if flip:
            m[0][0] *= np.random.randint(0, 2) * 2 - 1  # flip x randomly
        if rot:
            theta = np.random.rand() * 2 * math.pi
            m = np.matmul(
                m,
                [[math.cos(theta), math.sin(theta), 0], [-math.sin(theta), math.cos(theta), 0], [0, 0, 1]])  # rotation
        if normal is not None:
            normal = np.matmul(normal, m)
        return np.matmul(xyz, m), normal

    def crop(self, xyz: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        r"""
        crop the point cloud to reduce training complexity

        Args:
            xyz (np.ndarray, [N, 3]): input point cloud to be cropped; must be non-negative

        Returns:
            Tuple[np.ndarray, np.ndarray]: shifted point cloud and boolean valid indices
        """
        xyz_offset = xyz.copy()
        valid_idxs = xyz_offset.min(1) >= 0
        assert valid_idxs.sum() == xyz.shape[0]

        full_scale = np.array([self.voxel_cfg.spatial_shape[1]] * 3)
        room_range = xyz.max(0) - xyz.min(0)
        # shrink the x/y window until at most 2 * max_npoint points fall inside it
        while valid_idxs.sum() > self.voxel_cfg.max_npoint * 2:
            offset = np.clip(full_scale - room_range + 0.001, None, 0) * np.random.rand(3)
            xyz_offset = xyz + offset
            valid_idxs = (xyz_offset.min(1) >= 0) * ((xyz_offset < full_scale).sum(1) == 3)
            full_scale[:2] -= 32

        return xyz_offset, valid_idxs

    def elastic(self, xyz, gran, mag):
        """Elastic distortion (from point group)

        Args:
            xyz (np.ndarray): input point cloud
            gran (float): distortion param (grid spacing of the noise lattice)
            mag (float): distortion scalar

        Returns:
            xyz: point cloud with elastic distortion
        """
        blur0 = np.ones((3, 1, 1)).astype('float32') / 3
        blur1 = np.ones((1, 3, 1)).astype('float32') / 3
        blur2 = np.ones((1, 1, 3)).astype('float32') / 3

        bb = np.abs(xyz).max(0).astype(np.int32) // gran + 3
        noise = [np.random.randn(bb[0], bb[1], bb[2]).astype('float32') for _ in range(3)]
        # two passes of separable box blur along each axis; ``ndimage.convolve`` replaces the
        # deprecated ``ndimage.filters`` namespace
        for blur in (blur0, blur1, blur2, blur0, blur1, blur2):
            noise = [ndimage.convolve(n, blur, mode='constant', cval=0) for n in noise]
        ax = [np.linspace(-(b - 1) * gran, (b - 1) * gran, b) for b in bb]
        interp = [interpolate.RegularGridInterpolator(ax, n, bounds_error=False, fill_value=0) for n in noise]

        def g(xyz_):
            return np.hstack([i(xyz_)[:, None] for i in interp])

        return xyz + g(xyz) * mag

    def get_cropped_inst_label(self, instance_label: np.ndarray, valid_idxs: np.ndarray) -> np.ndarray:
        r"""
        get the instance labels after crop operation and recompact

        Gaps in ``0..max`` are filled by moving the current largest id into the gap, so the
        result uses contiguous ids ``0..K-1`` (negative ignore ids are left untouched).

        Args:
            instance_label (np.ndarray, [N]): instance label ids of point cloud
            valid_idxs (np.ndarray, [N] bool or integer indices): points to keep

        Returns:
            np.ndarray: processed instance labels (empty input is returned unchanged)
        """
        instance_label = instance_label[valid_idxs]
        if len(instance_label) == 0:
            return instance_label

        j = 0
        while j < instance_label.max():
            if not (instance_label == j).any():
                instance_label[instance_label == instance_label.max()] = j
            j += 1
        return instance_label

    def batch_giou_cross(self, boxes1, boxes2):
        """Pairwise IoU of axis-aligned boxes (plain IoU, despite the name).

        Args:
            boxes1 (torch.Tensor, [N, 6]): ``(xmin, ymin, zmin, xmax, ymax, zmax)``.
            boxes2 (torch.Tensor, [M, 6]): same layout.

        Returns:
            torch.Tensor, [N, M]: IoU of every box pair.
        """
        boxes1 = boxes1[:, None, :]
        boxes2 = boxes2[None, :, :]
        intersection = torch.prod(
            torch.clamp(
                (torch.min(boxes1[..., 3:], boxes2[..., 3:]) - torch.max(boxes1[..., :3], boxes2[..., :3])), min=0.0
            ),
            -1,
        )  # [N, M]

        boxes1_volumes = torch.prod(torch.clamp((boxes1[..., 3:] - boxes1[..., :3]), min=0.0), -1)
        boxes2_volumes = torch.prod(torch.clamp((boxes2[..., 3:] - boxes2[..., :3]), min=0.0), -1)

        union = boxes1_volumes + boxes2_volumes - intersection
        return intersection / (union + 1e-6)

    def is_within_bb_torch(self, points, bb_min, bb_max):
        """Return a bool tensor: True where all coordinates lie in ``[bb_min, bb_max]`` (broadcast)."""
        return torch.all(points >= bb_min, dim=-1) & torch.all(points <= bb_max, dim=-1)

    def is_box1_in_box2(self, box1, box2, offset=0.05):
        """True if ``box1`` (shrunk by ``offset`` on every side) lies inside ``box2``."""
        return torch.all((box1[:3] + offset) >= box2[:3]) & torch.all((box1[3:] - offset) <= box2[3:])

    def get_instance3D(self, instance_label, semantic_label, superpoint, coord_float, scan_id):
        """Build the per-scan ground-truth ``Instances3D`` container.

        Args:
            instance_label (torch.Tensor, [N]): instance id per point (contiguous ids expected).
            semantic_label (torch.Tensor, [N]): semantic id per point; instances whose semantic
                id is -100 are skipped.
            superpoint (torch.Tensor, [N]): superpoint id per point.
            coord_float (torch.Tensor, [N, 3]): augmented, unscaled coordinates.
            scan_id (str): scan name; not used here.

        Returns:
            Instances3D: point/superpoint masks, labels, boxes, centers, box masks, the
            "covered by at most one box" indicators and 250 farthest-point samples.
        """
        num_insts = instance_label.max().item() + 1
        num_points = len(instance_label)
        gt_masks, gt_labels = [], []
        gt_bboxes, gt_centers = [], []
        instance_boxes = []

        if self.use_normalized:
            scene_min = coord_float.min(0)[0]
            scene_max = coord_float.max(0)[0]

        # per-point id: (semantic + 1) * 1000 + instance + 1; 0 means "no kept instance"
        gt_inst = torch.zeros(num_points, dtype=torch.int64)
        fps_idx = fpsample.fps_sampling(coord_float.numpy(), 250)
        fps_sample = coord_float[fps_idx.astype(np.int32)]
        for i in range(num_insts):
            idx = torch.where(instance_label == i)
            # every instance must carry exactly one semantic class
            assert len(torch.unique(semantic_label[idx])) == 1
            sem_id = semantic_label[idx][0]
            if sem_id == -100:
                continue

            gt_mask = torch.zeros(num_points)
            gt_mask[idx] = 1
            gt_masks.append(gt_mask)
            gt_labels.append(sem_id)
            gt_inst[idx] = (sem_id + 1) * 1000 + i + 1

            # axis-aligned box of the instance
            xyz_i = coord_float[idx]
            mean_xyz_i = xyz_i.mean(0)
            min_xyz_i = xyz_i.min(0)[0]
            max_xyz_i = xyz_i.max(0)[0]
            center_xyz_i = (min_xyz_i + max_xyz_i) / 2
            hwz_i = max_xyz_i - min_xyz_i

            gt_bbox = torch.cat([mean_xyz_i, center_xyz_i, hwz_i], dim=0)
            instance_boxes.append(torch.cat([min_xyz_i, max_xyz_i], dim=0))
            gt_center = mean_xyz_i
            if self.use_normalized:
                scene_range = scene_max - scene_min
                mean_xyz_i_norm = (mean_xyz_i - scene_min) / scene_range
                center_xyz_i_norm = (center_xyz_i - scene_min) / scene_range
                hwz_i_norm = hwz_i / scene_range
                gt_center = mean_xyz_i_norm
                gt_bbox = torch.cat([gt_bbox, mean_xyz_i_norm, center_xyz_i_norm, hwz_i_norm], dim=0)

            gt_bboxes.append(gt_bbox)
            gt_centers.append(gt_center)

        if len(gt_masks) > 0:
            gt_masks = torch.stack(gt_masks, dim=0)
            gt_spmasks = torch_scatter.scatter_mean(gt_masks.float(), superpoint, dim=-1)
            gt_spmasks = (gt_spmasks > 0.5).float()
        else:
            gt_masks = torch.tensor([])
            gt_spmasks = torch.tensor([])

        if instance_boxes:
            instance_boxes = torch.stack(instance_boxes)
            # [N, K] bool: point n lies inside the box of instance k
            gt_box_masks = self.is_within_bb_torch(
                coord_float[:, None, :], instance_boxes[None, :, :3], instance_boxes[None, :, 3:]
            )

            # Resolve overlapping boxes: when one box is (nearly) contained in another, the
            # shared points are removed from the enclosing box's mask.
            cross_box_iou = self.batch_giou_cross(instance_boxes, instance_boxes)
            cross_box_iou.fill_diagonal_(0.0)
            n_boxes = len(instance_boxes)
            box_visited = torch.zeros(n_boxes)
            for b1 in range(n_boxes):
                overlap_cond = (cross_box_iou[b1] > 0.0001) & (box_visited == 0)
                overlap_inds = torch.nonzero(overlap_cond).view(-1).int()
                if len(overlap_inds) == 0:
                    box_visited[b1] = 1
                    continue

                for b2 in overlap_inds:
                    intersect_cond = (gt_box_masks[:, b1] == 1) & (gt_box_masks[:, b2] == 1)
                    intersect_inds = torch.nonzero(intersect_cond).view(-1)
                    if len(intersect_inds) == 0:
                        continue
                    if self.is_box1_in_box2(instance_boxes[b1], instance_boxes[b2], offset=0.05):
                        # b1 inside b2: shared points stay only in b1's mask
                        gt_box_masks[intersect_inds, b2] = 0
                        box_visited[b1] = 1
                        break
                    if self.is_box1_in_box2(instance_boxes[b2], instance_boxes[b1], offset=0.05):
                        # b2 inside b1: shared points stay only in b2's mask
                        gt_box_masks[intersect_inds, b1] = 0
                        box_visited[b2] = 1
                box_visited[b1] = 1

            gt_box_masks = gt_box_masks.float()
            # points covered by at most one (post-resolution) box
            det_idx = torch.sum(gt_box_masks, dim=1) <= 1
            gt_box_masks = gt_box_masks.T  # [K, N]
            gt_box_spmasks = torch_scatter.scatter_mean(gt_box_masks.float(), superpoint, dim=-1)
            gt_box_spmasks = (gt_box_spmasks > 0.5).float()
            sp_det_idx = torch.sum(gt_box_spmasks.T, dim=1) <= 1
        else:
            gt_box_masks = torch.tensor([])
            gt_box_spmasks = torch.tensor([])
            det_idx = torch.tensor([])
            sp_det_idx = torch.tensor([])

        gt_labels = torch.tensor(gt_labels)

        if len(gt_bboxes) > 0:
            gt_bboxes = torch.stack(gt_bboxes, dim=0)
            gt_centers = torch.stack(gt_centers, dim=0)
        else:
            gt_bboxes = torch.tensor(gt_bboxes)
            gt_centers = torch.tensor(gt_centers)
        assert gt_labels.shape[0] == gt_bboxes.shape[0]

        inst = Instances3D(num_points, gt_instances=gt_inst.numpy())
        inst.gt_sem = semantic_label.long()
        inst.gt_labels = gt_labels.long()
        inst.gt_spmasks = gt_spmasks
        inst.gt_bboxes = gt_bboxes
        inst.gt_masks = gt_masks
        inst.gt_box_masks = gt_box_masks
        inst.gt_box_spmasks = gt_box_spmasks
        inst.det_idx = det_idx
        inst.sp_det_idx = sp_det_idx
        inst.gt_centers = gt_centers
        inst.fps_sample = fps_sample
        return inst

    def __getitem__(self, index: int) -> Dict:
        """Load, transform and package one scan.

        Returns:
            dict: keys ``scan_id``, ``coord`` (long, scaled), ``coord_float`` (unscaled),
            ``feat`` (rgb), ``superpoint``, ``inst`` (``Instances3D``) and ``normal``.

        Raises:
            RuntimeError: if ``resample`` is enabled but ``self.scan_ids`` was never assigned.
        """
        if self.resample:
            scan_ids = getattr(self, 'scan_ids', None)
            if scan_ids is None:
                raise RuntimeError('resample=True requires `scan_ids` to be set on the dataset')
            # a smaller index than last time means the sampler started a new epoch
            if index < self.last_index:
                self.epoch_idx += 1
            if self.trainval:
                iter_ = index + self.epoch_idx * 1513  # 378
            else:
                iter_ = index + self.epoch_idx * 1201  # 301
            filename = osp.join(self.data_root, scan_ids[iter_])
            self.last_index = index
        else:
            filename = self.filenames[index]
        scan_id = osp.basename(filename).replace(self.suffix, '')

        data = self.load(filename)
        data = self.transform_train(*data) if self.training else self.transform_test(*data)
        xyz, xyz_middle, rgb, superpoint, semantic_label, instance_label, normal = data

        coord = torch.from_numpy(xyz).long()
        coord_float = torch.from_numpy(xyz_middle).float()
        feat = torch.from_numpy(rgb).float()
        superpoint = torch.from_numpy(superpoint)
        if normal is not None:
            normal = torch.from_numpy(normal).float()

        if semantic_label is not None:
            semantic_label = torch.from_numpy(semantic_label).long()
        else:
            semantic_label = torch.ones(xyz.shape[0]).long() * (-100)

        if instance_label is not None:
            instance_label = torch.from_numpy(instance_label).long()
        else:
            instance_label = torch.zeros(xyz.shape[0]).long()

        inst = self.get_instance3D(instance_label, semantic_label, superpoint, coord_float, scan_id)

        return {'scan_id': scan_id,
                'coord': coord,
                'coord_float': coord_float,
                'feat': feat,
                'superpoint': superpoint,
                'inst': inst,
                'normal': normal}

    def collate_fn(self, features) -> Dict:
        """Merge ``__getitem__`` dicts into one batch and voxelize it.

        Superpoint ids are offset so they are unique across the batch; ``batch_offsets``
        holds the cumulative superpoint counts (``[B + 1]``).
        """
        scan_ids, coords, coords_float, feats, superpoints, insts = [], [], [], [], [], []
        batch_offsets = [0]
        superpoint_bias = 0

        for i, data in enumerate(features):
            coord = data['coord']
            coord_float = data['coord_float']
            # out-of-place add: the item's own superpoint tensor is left untouched
            superpoint = data['superpoint'] + superpoint_bias
            superpoint_bias = superpoint.max().item() + 1
            batch_offsets.append(superpoint_bias)

            scan_ids.append(data['scan_id'])
            # prepend the batch index as column 0
            coords.append(torch.cat([torch.full((coord.shape[0], 1), i, dtype=torch.long), coord], 1))
            coords_float.append(coord_float)
            feats.append(data['feat'])
            superpoints.append(superpoint)
            insts.append(data['inst'])

        # merge all scan in batch
        batch_offsets = torch.tensor(batch_offsets, dtype=torch.int)  # int [B+1]
        coords = torch.cat(coords, 0)  # long [B*N, 1 + 3], the batch item idx is put in b_xyz[:, 0]
        coords_float = torch.cat(coords_float, 0)  # float [B*N, 3]
        feats = torch.cat(feats, 0)  # float [B*N, 3]
        superpoints = torch.cat(superpoints, 0).long()  # long [B*N, ]
        if self.use_xyz:
            feats = torch.cat((feats, coords_float), dim=1)
        # NOTE: normals are not appended to feats here even when ``with_normals`` is set.

        # voxelize
        spatial_shape = np.clip((coords.max(0)[0][1:] + 1).numpy(), self.voxel_cfg.spatial_shape[0], None)  # long [3]
        voxel_coords, p2v_map, v2p_map = pointgroup_ops.voxelization_idx(coords, len(features), self.mode)

        return {
            'scan_ids': scan_ids,
            'voxel_coords': voxel_coords,
            'p2v_map': p2v_map,
            'v2p_map': v2p_map,
            'spatial_shape': spatial_shape,
            'feats': feats,
            'superpoints': superpoints,
            'batch_offsets': batch_offsets,
            'insts': insts,
            'coords_float': coords_float,
        }
