import os
from typing import Union

import torch

from opencood.utils.box_utils import project_points_by_matrix
from opencood.visualization import simple_vis
from opencood.tools.inference_utils import get_center_box, boxes_to_corners_3d


class VisUtil:
    """Process-wide holder of visualization state plus static helpers.

    All state lives in class attributes and every method is static, so the
    class is used as a namespace (``VisUtil.set_args(...)``, ``VisUtil.vis()``)
    and is never instantiated.
    """

    # global attr
    is_vis: bool = False                    # master switch for visualization
    root_path: str = './work_dirs/vis'      # root directory for saved images
    lidar_range: Union[list, None] = None   # forwarded to simple_vis.visualize
    left_hand: bool = False                 # forwarded to simple_vis.visualize
    save_vis_n: int = -1                    # save interval used by is_vis_now
    training: bool = False
    fusion_debug: bool = False

    # iter attr
    idx: int = 0
    frame_id: str = '000000'
    scene_points: Union[list, torch.Tensor, None] = None
    cluster_points: Union[list, torch.Tensor, None] = None
    # Either a tensor or a dict keyed by "<stage>_stage" (see get_stage_arg).
    pred_bboxes: Union[torch.Tensor, dict, None] = None
    gt_bboxes: Union[torch.Tensor, dict, None] = None

    @staticmethod
    def is_vis_now():
        """Return whether the current iteration should be visualized.

        True when ``is_vis`` is set and ``idx`` is a multiple of
        ``save_vis_n``. The default ``save_vis_n == -1`` matches every index;
        ``0`` is treated the same way instead of raising ZeroDivisionError.
        """
        if not VisUtil.is_vis:
            return False
        interval = VisUtil.save_vis_n
        if interval == 0:
            return True
        return VisUtil.idx % interval == 0

    @staticmethod
    def get_vis_path(vis_type, name):
        """Return ``<root_path>/<vis_type>_<name>``, creating it if needed."""
        cur_save_path = os.path.join(VisUtil.root_path, f'{vis_type}_{name}')
        # exist_ok avoids the check-then-create race of a manual exists() test.
        os.makedirs(cur_save_path, exist_ok=True)
        return cur_save_path

    @staticmethod
    def get_image_path(path, idx, frame_id, stage='one', prefix=''):
        """Build the image file path for one frame.

        Args:
            path (str): directory the image is saved in.
            idx (int): iteration index, first component of the file name.
            frame_id (str): frame identifier, second component.
            stage (str, optional): stage tag. Defaults to 'one'.
            prefix (str, optional): extra tag appended after the stage when
                non-empty. Defaults to ''.

        Returns:
            str: ``<path>/<idx>_<frame_id>_<stage>[_<prefix>].png``.
        """
        if prefix == '':
            name = f'{idx}_{frame_id}_{stage}.png'
        else:
            name = f'{idx}_{frame_id}_{stage}_{prefix}.png'
        return os.path.join(path, name)

    @staticmethod
    def get_vis_points(points, valid_mask=None, vote_targets=None, vote_mask=None, transform=None):
        """Split points into foreground / background (and vote centers).

        Args:
            points (Torch): point tensor; the first three columns are used as
                coordinates when a transform or vote offset is applied.
            valid_mask (Torch, optional): boolean foreground mask over points.
                When None, every point is treated as foreground.
            vote_targets (Torch, optional): per-point offsets added to the
                masked point coordinates to obtain the voted centers.
            vote_mask (Torch, optional): boolean mask selecting the points
                that vote. Vote centers are only produced when both
                ``vote_targets`` and ``vote_mask`` are given.
            transform (Torch, optional): matrix handed to
                ``project_points_by_matrix`` (cast to float first).

        Returns:
            list: ``[foreground, background]`` plus one trailing tensor of
            vote centers when votes were requested.
        """
        if valid_mask is None:
            valid_mask = torch.ones(points.shape[0], dtype=torch.bool, device=points.device)

        vote_points = []    # voted point's center
        if vote_targets is not None and vote_mask is not None:
            # TODO: transforming the points and the offsets separately gives
            # incorrect results (presumably the matrix also translates, which
            # must not be applied to offset vectors -- to be confirmed), so
            # the offset is added before projecting.
            offset_points = points[vote_mask, :3] + vote_targets[vote_mask]
            if transform is not None:
                offset_points = project_points_by_matrix(offset_points, transform.float())
            vote_points.append(offset_points)       # vote target center, blue

        if transform is not None:
            points = project_points_by_matrix(points[:, :3], transform.float())

        vis_points = [
            points[valid_mask],        # foreground points
            points[~valid_mask],       # background points
        ]
        return vis_points + vote_points

    @staticmethod
    def get_scene_points(points, vote_target, pts_coors=None, valid_mask=None,
                         img_metas=None, proj_first=True, ignore_proj=False):
        """Get scene points, this function should be used when scene level batch is 1

        The result is stored in ``VisUtil.scene_points``; nothing is returned.

        Args:
            points (Torch): shape(Nx3)
            vote_target (Torch): per-point vote offsets, or None.
            pts_coors (Torch, optional): car level batch idx of points. A 1-D
                tensor is expanded to one column; column 0 is the agent index.
                Required when points are handled per agent. Defaults to None.
            valid_mask (Torch, optional): Foreground mask. Defaults to None.
            img_metas (list, optional): per-agent metadata dicts. When given
                and ``proj_first`` is True, ``img_metas[0]['proj_first']``
                decides the mode; ``img_metas[i]['proj2ego_matrix']`` is read
                for each agent in per-agent mode. Defaults to None.
            proj_first (bool, optional): whether project first. Defaults to True.
            ignore_proj (bool, optional): in per-agent mode, skip the
                projection and replace the vote centers (if any) with a
                single zero point. Defaults to False.

        Raises:
            ValueError: per-agent mode is selected but ``pts_coors`` is None.
        """
        # Only defer to the metadata when the caller supplied it.
        if proj_first and img_metas is not None:
            proj_first = img_metas[0]['proj_first']
        if pts_coors is not None and pts_coors.dim() == 1:
            pts_coors = pts_coors.unsqueeze(1)

        if proj_first:
            # The foreground mask doubles as the vote mask here.
            VisUtil.scene_points = VisUtil.get_vis_points(
                points, valid_mask, vote_target, valid_mask, None)
            return

        if pts_coors is None:
            raise ValueError('pts_coors is required when proj_first is False')

        agent_ids = pts_coors[:, 0]
        num_agents = int(agent_ids.max()) + 1 if agent_ids.numel() > 0 else 0

        scene_points = []
        for i in range(num_agents):
            bs_mask = agent_ids == i
            agent_valid = valid_mask[bs_mask] if valid_mask is not None else None
            if vote_target is not None:
                agent_vote_target = vote_target[bs_mask]
                agent_vote_mask = agent_valid
            else:
                agent_vote_target = None
                agent_vote_mask = None

            # With ignore_proj the projected result would be thrown away, so
            # the transform is skipped altogether.
            transform = None if ignore_proj else img_metas[i]['proj2ego_matrix']
            agent_vis_points = VisUtil.get_vis_points(
                points[bs_mask], agent_valid, agent_vote_target, agent_vote_mask, transform)
            # Vote centers exist only when vote targets were given.
            if ignore_proj and len(agent_vis_points) > 2:
                agent_vis_points[2] = torch.zeros(1, 3)
            scene_points += agent_vis_points
        VisUtil.scene_points = scene_points

    @staticmethod
    def get_cluster_points(cluster_xyz, cluster_inds, agent_num=0):
        """Group cluster centers per agent into ``VisUtil.cluster_points``.

        Args:
            cluster_xyz (Torch): one row per cluster.
            cluster_inds (Torch): 2-D index tensor aligned with
                ``cluster_xyz``; column 1 is used as the agent index.
            agent_num (int, optional): number of agents. When 0 it is derived
                from the largest agent index in column 1. Defaults to 0.
        """
        agent_ids = cluster_inds[:, 1]
        if agent_num == 0:
            # Derive the count from the same column the mask is built on.
            agent_num = int(agent_ids.max()) + 1 if agent_ids.numel() > 0 else 0
        VisUtil.cluster_points = [cluster_xyz[agent_ids == i] for i in range(agent_num)]

    @staticmethod
    def set_args(**argv):
        """Set every given keyword as a class attribute of VisUtil."""
        for key, value in argv.items():
            setattr(VisUtil, key, value)

    @staticmethod
    def get_stage_arg(key, stage='one'):
        """Return attribute ``key``, resolved per stage for box attributes.

        For 'pred_bboxes' / 'gt_bboxes' stored as a dict, the entry under
        ``"<stage>_stage"`` is returned; any other value is returned as is.
        """
        value = getattr(VisUtil, key)
        if key in ('pred_bboxes', 'gt_bboxes') and isinstance(value, dict):
            return value[f"{stage}_stage"]
        return value

    @staticmethod
    def vis(method='bev', vis_pred_bbox=False, vis_gt_bbox=False,
            vis_cluster=False, stage='one', prefix=None):
        """Render the current frame with ``simple_vis.visualize``.

        The image goes to ``<root_path>/<method>_<stage>[_pred][_gt][_<prefix>]``
        under the name ``<idx>_<frame_id>_<stage>.png``.

        Args:
            method (str, optional): passed to the visualizer and used in the
                directory name. Defaults to 'bev'.
            vis_pred_bbox (bool, optional): draw predicted boxes. Defaults to False.
            vis_gt_bbox (bool, optional): draw ground-truth boxes. Defaults to False.
            vis_cluster (bool, optional): pass the cluster points. Defaults to False.
            stage (str, optional): stage whose boxes are drawn. Defaults to 'one'.
            prefix (str, optional): extra tag appended to the directory name.
                Defaults to None.
        """
        # set image path
        name = f'{stage}'
        name += '_pred' if vis_pred_bbox else ''
        name += '_gt' if vis_gt_bbox else ''
        name += f'_{prefix}' if prefix is not None else ''
        path = VisUtil.get_vis_path(method, name)
        # The prefix is already part of the directory, so the file name is
        # left unprefixed to keep existing output locations unchanged.
        img_path = VisUtil.get_image_path(path, VisUtil.idx, VisUtil.frame_id, stage)
        simple_vis.visualize(
            VisUtil.get_stage_arg('pred_bboxes', stage) if vis_pred_bbox else None,
            VisUtil.get_stage_arg('gt_bboxes', stage)   if vis_gt_bbox else None,
            VisUtil.scene_points,
            VisUtil.lidar_range,
            img_path,
            method=method,
            left_hand=VisUtil.left_hand,
            vis_gt_box=vis_gt_bbox,
            vis_pred_box=vis_pred_bbox,
            cluster_pcd=VisUtil.cluster_points if vis_cluster else None
        )

    @staticmethod
    def vis_with_stages(method='bev', vis_pred_bbox=False, vis_gt_bbox=False,
            vis_cluster=False, is_two_stage=False):
        """Call ``vis`` for stage 'one', and also for 'two' when two-stage."""
        stages = ('one', 'two') if is_two_stage else ('one',)
        for stage in stages:
            VisUtil.vis(method, vis_pred_bbox, vis_gt_bbox, vis_cluster, stage)

    @staticmethod
    def get_preds_and_gts_with_format(pred_dicts):
        """Extract predicted and ground-truth boxes from ``pred_dicts``.

        When ``pred_dicts['pred_box']`` is a dict, both results are dicts with
        its keys and the same ``pred_dicts['gt_box']`` is shared by every key;
        otherwise both values are returned unchanged.

        Returns:
            tuple: ``(pred_bboxes, gt_bboxes)``.
        """
        pred_box = pred_dicts['pred_box']
        gt_box = pred_dicts['gt_box']
        if isinstance(pred_box, dict):
            pred_bboxes = dict(pred_box)
            gt_bboxes = {key: gt_box for key in pred_box}
            return pred_bboxes, gt_bboxes
        return pred_box, gt_box