import numpy as np
from mmcv.parallel import DataContainer as DC

from mmdet3d.core.points import BasePoints
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor

@PIPELINES.register_module()
class FormatBundleMap(object):
    """Format data for map tasks and wrap it in DataContainers.

    Only the fields that are present in ``results`` are touched. They are
    formatted as follows.

    - points: (1) take ``.tensor`` of the ``BasePoints`` object,
      (2) to DataContainer
    - voxels, coors, voxel_centers, num_points: (1) to tensor,
      (2) to DataContainer (stack=False)
    - img (only if ``process_img`` is True): (1) transpose with axes
      ``(2, 0, 1)``, (2) to tensor, (3) to DataContainer (stack=True).
      A list of images is stacked along a new leading axis first.
    - semantic_mask: (1) to tensor, (2) to DataContainer (stack=True)
    - vectors: (1) to DataContainer (stack=False, cpu_only=True)
    - polys: (1) to DataContainer (stack=False, cpu_only=True)

    Args:
        process_img (bool): Whether to format ``results['img']``.
            Defaults to True.
        keys (list[str]): Stored on the instance as ``self.keys``. It is not
            read by ``__call__`` of this class.
        meta_keys (list[str]): Stored on the instance as ``self.meta_keys``.
            It is not read by ``__call__`` of this class.
    """

    def __init__(self, process_img=True,
                 keys=['img', 'semantic_mask', 'vectors'],
                 meta_keys=['intrinsics', 'extrinsics']):

        self.process_img = process_img
        # The defaults in the signature are list objects created once at
        # function-definition time. Copy list arguments so that instances
        # never alias each other (or the shared defaults).
        self.keys = list(keys) if isinstance(keys, list) else keys
        self.meta_keys = (
            list(meta_keys) if isinstance(meta_keys, list) else meta_keys)

    def __call__(self, results):
        """Call function to transform and format common fields in results.

        Args:
            results (dict): Result dict contains the data to convert.

        Returns:
            dict: The same dict object, with the handled fields replaced by
                their DataContainer-wrapped versions.
        """
        # Format 3D data
        if 'points' in results:
            assert isinstance(results['points'], BasePoints)
            results['points'] = DC(results['points'].tensor)

        for key in ['voxels', 'coors', 'voxel_centers', 'num_points']:
            if key not in results:
                continue
            results[key] = DC(to_tensor(results[key]), stack=False)

        if 'img' in results and self.process_img:
            img = results['img']
            if isinstance(img, list):
                # Multiple images in a single frame: move the last axis of
                # each image to the front, then stack along a new axis 0.
                # NOTE(review): presumably HWC -> CHW per image; confirm
                # against the image loading step.
                img = np.stack(
                    [single.transpose(2, 0, 1) for single in img], axis=0)
            else:
                img = img.transpose(2, 0, 1)
            # transpose returns a strided view; make the memory contiguous
            # before converting to a tensor.
            img = np.ascontiguousarray(img)
            results['img'] = DC(to_tensor(img), stack=True)

        if 'semantic_mask' in results:
            results['semantic_mask'] = DC(
                to_tensor(results['semantic_mask']), stack=True)

        if 'vectors' in results:
            # vectors may have different sizes, so they are not stacked and
            # are passed through without tensor conversion.
            vectors = results['vectors']
            results['vectors'] = DC(vectors, stack=False, cpu_only=True)

        if 'polys' in results:
            results['polys'] = DC(results['polys'], stack=False, cpu_only=True)

        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(process_img={self.process_img}, '
        repr_str += f'keys={self.keys}, '
        repr_str += f'meta_keys={self.meta_keys})'
        return repr_str