import json
import os
import time
from abc import ABC
from typing import Optional, Union

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import Tensor
from segment_anything import (SamAutomaticMaskGenerator, SamPredictor,
                              sam_model_registry)
from tqdm import tqdm

from . import utils
# from .scene_property import INPUT_BOX, INPUT_POINT
from .self_prompting import mask_to_prompt
from .prepare_prompts import get_prompt_points


class Sam3D(ABC):
    '''Segment a pretrained radiance field with SAM (Segment Anything) supervision.

    The instance owns a SAM ViT-H predictor loaded from
    ``./data/sam_ckpt/sam_vit_h_4b8939.pth``. :meth:`forward` does the following:

    * reloads a trained scene model through ``utils.load_existed_model``;
    * renders the training views one by one in :meth:`train_step`;
    * prompts SAM on every rendering (a user prompt on the first view, then
      self-generated prompts);
    * optimizes the model's segmentation grid(s) towards the SAM masks.

    Videos and a ``{stage}_segmentation<suffix>.tar`` checkpoint are written
    under ``cfg.basedir/cfg.expname``.
    '''

    def __init__(self, args, cfg, device=torch.device('cuda')):
        '''Load SAM and cache the run configuration.

        Args:
            args: parsed command-line arguments. The class reads, e.g.,
                ``segment``, ``segment_everything``, ``sp_name``, ``no_reload``,
                ``ft_path``, ``freeze_*``, ``render_video_factor``,
                ``num_prompts`` and ``scene``.
            cfg: experiment config. The class reads ``basedir``, ``expname``,
                ``data.*``, ``fine_model_and_render.stepsize`` and
                ``coarse_train``.
            device: device SAM is moved to. Also used to load the scene model.
        '''
        self.cfg = cfg
        self.args = args
        sam_checkpoint = "./data/sam_ckpt/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device)
        self.predictor = SamPredictor(self.sam)
        print("SAM initializd.")
        self.step_size = cfg.fine_model_and_render.stepsize
        self.device = device
        self.seg_everything = self.args.segment_everything
        self.segment = bool(args.segment or self.seg_everything)
        # Suffix appended to every output file name; an explicit --sp_name wins.
        self.e_flag = '_everything' if self.seg_everything else ''
        self.e_flag = args.sp_name if args.sp_name is not None else self.e_flag
        self.base_save_dir = os.path.join(cfg.basedir, cfg.expname)
        # state shared with the interactive prompt backend (get_prompt_points)
        self.context = {'num_clicks': 0, 'click': []}

    @staticmethod
    def _mark_prompt_points(canvas, points, channel=None, half_size=3):
        '''Draw a small square marker for every ``(x, y)`` point onto ``canvas`` in place.

        Each marker spans rows ``y-half_size:y+half_size`` and columns
        ``x-half_size:x+half_size``. It is first zeroed on all channels and then
        set to 1 on one channel:

        * ``channel`` for every marker when ``channel`` is given;
        * otherwise channel ``i`` for the i-th of the first three points and
          channel 2 for the rest.

        Slice starts are clamped at 0. A point close to the top/left border
        would otherwise produce a negative start, which NumPy interprets as
        wrap-around (an empty slice), and the marker would vanish.

        Returns:
            ``canvas`` (modified in place).
        '''
        for idx, point in enumerate(points):
            x, y = int(point[0]), int(point[1])
            rows = slice(max(y - half_size, 0), y + half_size)
            cols = slice(max(x - half_size, 0), x + half_size)
            color = channel if channel is not None else (idx if idx < 3 else 2)
            canvas[rows, cols, :] = 0
            canvas[rows, cols, color] = 1
        return canvas

    @staticmethod
    def _binarize_rendered_mask(scores):
        '''Return a detached 0/1 copy of ``scores``.

        A value becomes 1 when it is both above the mean of ``scores`` and
        positive; every other value becomes 0.
        '''
        mask = scores.detach().clone()
        mask[torch.logical_or(mask <= mask.mean(), mask <= 0)] = 0
        mask[mask != 0] = 1
        return mask

    def forward(self, xyz_min, xyz_max, cfg_model, cfg_train, data_dict, stage='coarse', coarse_ckpt_path=None):
        '''Reload a trained scene model, train its segmentation on the training views and save it.

        Args:
            xyz_min, xyz_max: scene bounding-box corners. They are enlarged with
                ``-=``/``+=`` (in place for array/tensor inputs) when
                ``cfg_model.world_bound_scale`` differs from 1.
                NOTE(review): the enlarged values are not used later in this method.
            cfg_model: model config; only ``world_bound_scale`` is read here.
            cfg_train: training config, forwarded to the model loader and to
                ``utils.create_segmentation_optimizer``.
            data_dict: dataset dict. Reads ``near``, ``far``, ``poses``, ``HW``,
                ``Ks`` and ``i_train``.
            stage: ``'coarse'`` or ``'fine'``. ``'fine'`` switches the model to
                fine mode and additionally trains/exports the dual segmentation.
            coarse_ckpt_path: checkpoint to reload instead of ``fine_last.tar``.
                It is only considered when ``fine_last.tar`` exists.

        Raises:
            AssertionError: if no checkpoint to reload can be determined.
        '''
        if abs(cfg_model.world_bound_scale - 1) > 1e-9:
            xyz_shift = (xyz_max - xyz_min) * (cfg_model.world_bound_scale - 1) / 2
            xyz_min -= xyz_shift
            xyz_max += xyz_shift

        # find whether there is an existing checkpoint to reload
        last_ckpt_path = os.path.join(self.base_save_dir, 'fine_last.tar')
        if self.args.no_reload:
            reload_ckpt_path = None
        elif self.args.ft_path:
            reload_ckpt_path = self.args.ft_path
        elif coarse_ckpt_path is not None and os.path.isfile(last_ckpt_path):
            # NOTE(review): this branch is gated on fine_last.tar existing,
            # not on coarse_ckpt_path existing -- confirm that is intended.
            reload_ckpt_path = coarse_ckpt_path
        elif os.path.isfile(last_ckpt_path):
            reload_ckpt_path = last_ckpt_path
        else:
            reload_ckpt_path = None

        # Segmentation is trained on top of an existing scene model, so one is mandatory.
        assert reload_ckpt_path is not None, 'segmentation must based on a pretrained NeRF'
        print(f'scene_rep_reconstruction ({stage}): reload from {reload_ckpt_path}')
        model, optimizer, start = utils.load_existed_model(self.args, self.cfg,
            cfg_train, reload_ckpt_path, self.device)

        # Optionally freeze parts of the reloaded model.
        freeze_density = self.args.freeze_density
        freeze_rgb = self.args.freeze_rgb
        freeze_feature = self.args.freeze_feature
        for name, param in model.named_parameters():
            if freeze_density and 'density' in name:
                param.requires_grad = False
            if freeze_rgb and ('rgbnet' in name or ('f_k0' not in name and 'k0' in name)):
                param.requires_grad = False
            if freeze_feature and 'f_k0' in name:
                param.requires_grad = False

        if stage == 'fine':
            model.change_to_fine_mode()
            print("Segmentation model: FINE MODE.")
        else:
            print("Segmentation model: COARSE MODE.")

        # in case OOM
        torch.cuda.empty_cache()

        render_viewpoints_kwargs = {
                'model': model,
                'ndc': self.cfg.data.ndc,
                'render_kwargs': {
                    'near': data_dict['near'],
                    'far': data_dict['far'],
                    'bg': 1 if self.cfg.data.white_bkgd else 0,
                    'stepsize': self.step_size,
                    'inverse_y': self.cfg.data.inverse_y,
                    'flip_x': self.cfg.data.flip_x,
                    'flip_y': self.cfg.data.flip_y,
                    'render_depth': True,
                },
            }
        # The optimizer returned by the loader is replaced by one for segmentation.
        optimizer = utils.create_segmentation_optimizer(model, cfg_train)

        # TODO: find the best pose for prompting (see generate_rendering_poses).
        i_train = data_dict['i_train']
        rgbs, depths, bgmaps, segs, dual_segs = self.train_step(
            render_poses=data_dict['poses'][i_train],
            HW=data_dict['HW'][i_train],
            Ks=data_dict['Ks'][i_train],
            optimizer=optimizer,
            render_factor=self.args.render_video_factor,
            stage=stage,
            **render_viewpoints_kwargs
        )

        # Video export is best-effort: report a failure but still save the checkpoint.
        try:
            testsavedir = os.path.join(self.base_save_dir, f'render_video_{stage}_segmentation')
            os.makedirs(testsavedir, exist_ok=True)

            def write_video(tag, frames):
                imageio.mimwrite(os.path.join(testsavedir, 'video.' + tag + self.e_flag + '.mp4'),
                                 utils.to8b(frames), fps=30, quality=8)

            write_video('rgb', rgbs)
            write_video('seg', segs)
            write_video('sam_seg', self.sam_segs)
            if stage == 'fine':
                write_video('dual_seg', dual_segs)
                write_video('dual_sam_seg', self.dual_sam_segs)
            if self.e_flag != '':
                seg_on_rgb = [0.3 * rgb + 0.7 * seg for rgb, seg in zip(rgbs, segs)]
                write_video('training_seg_on_rgb', seg_on_rgb)
        except Exception as err:
            print("There are Some error with generating videos, please check here.")
            print(f"  {type(err).__name__}: {err}")

        ckpt_path = os.path.join(self.base_save_dir, f'{stage}_segmentation' + self.e_flag + '.tar')
        torch.save({
            'model_kwargs': model.get_kwargs(),
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, ckpt_path)
        print(f'scene_rep_reconstruction ({stage}): saved checkpoints at', ckpt_path)

    def train_step(self, model, render_poses, HW, Ks, ndc, optimizer, render_kwargs,
                    render_factor=0, render_fct=0.0, stage='coarse'):
        '''Render each pose, prompt SAM on the rendering and optimize the segmentation grid(s).

        Args:
            model: scene model; called on chunks of 8192 rays per forward pass.
            render_poses: camera-to-world matrices, one per view.
            HW: per-view image sizes ``(H, W)``.
            Ks: per-view intrinsics.
            ndc: forwarded to ``utils.get_rays_of_a_view``.
            optimizer: optimizer stepped once per view via ``optim``.
            render_kwargs: extra keyword arguments for ``model(...)``.
                ``inverse_y`` is also used for ray generation.
            render_factor: if non-zero, ``HW`` and ``Ks`` are divided by it.
            render_fct: forwarded to ``model(...)``.
            stage: ``'coarse'`` or ``'fine'``.

        Returns:
            ``(rgbs, depths, bgmaps, segs, dual_segs)``: per-view NumPy renderings
            and segmentation visualizations. ``dual_segs`` is only filled in the
            fine stage. The SAM masks shown for each view are accumulated in
            ``self.sam_segs`` / ``self.dual_sam_segs``.

        Raises:
            NotImplementedError: for an unknown ``stage``.
        '''
        assert len(render_poses) == len(HW) and len(HW) == len(Ks)
        if render_factor != 0:
            HW = (np.copy(HW) / render_factor).astype(int)
            Ks = np.copy(Ks)
            Ks[:, :2, :3] /= render_factor

        # get obj num for seg_everything
        num_obj = model.seg_mask_grid.grid.shape[1]
        rand_colors = None if stage == 'coarse' else utils.gen_rand_colors(num_obj)

        if stage == 'coarse' and self.seg_everything:
            model.seg_mask_grid.grid.requires_grad = False

        num_epochs = 1  # both stages currently make a single pass over the views

        keys = ['rgb_marched', 'depth', 'alphainv_last', 'seg_mask_marched']
        if stage == 'fine':
            keys.append('dual_seg_mask_marched')

        def render_chunks(all_o, all_d, all_vd):
            # Render flattened rays in chunks of 8192 (bounds peak memory) and keep only `keys`.
            return [
                {k: v for k, v in model(ro, rd, vd, distill_active=False, render_fct=render_fct, **render_kwargs).items() if k in keys}
                for ro, rd, vd in zip(all_o.split(8192, 0), all_d.split(8192, 0), all_vd.split(8192, 0))
            ]

        # main training loop
        rgbs, segs, depths, bgmaps, dual_segs = [], [], [], [], []
        self.sam_segs, self.dual_sam_segs = [], []
        for eph in range(num_epochs):
            for i, c2w in enumerate(tqdm(render_poses)):
                optimizer.zero_grad(set_to_none=True)
                # get data
                H, W = HW[i]
                K = Ks[i]
                rays_o, rays_d, viewdirs = utils.get_rays_of_a_view(
                        H, W, K, c2w, ndc, inverse_y=render_kwargs['inverse_y'],
                        flip_x=self.cfg.data.flip_x, flip_y=self.cfg.data.flip_y)
                rays_o, rays_d, viewdirs = [arr.flatten(0, -2) for arr in [rays_o, rays_d, viewdirs]]
                render_result_chunks = render_chunks(rays_o, rays_d, viewdirs)
                render_result = {
                    k: torch.cat([ret[k] for ret in render_result_chunks]).reshape(H, W, -1)
                    for k in render_result_chunks[0].keys()
                }
                rgb = render_result['rgb_marched'].cpu().numpy()
                seg_m = render_result['seg_mask_marched'] if self.segment else None
                dual_seg_m = render_result['dual_seg_mask_marched'] if stage == 'fine' else None

                depth = render_result['depth'].cpu().numpy()
                bgmap = render_result['alphainv_last'].cpu().numpy()
                rgbs.append(rgb)
                depths.append(depth)
                bgmaps.append(bgmap)

                if self.segment:
                    if not self.seg_everything:
                        segs.append(seg_m.detach().cpu().numpy() / (seg_m.max().item() + 1e-7))
                        if stage == 'fine':
                            dual_segs.append(dual_seg_m.detach().cpu().numpy() / (dual_seg_m.max().item() + 1e-7))
                    elif rand_colors is not None:
                        # Colour each pixel by its best-scoring object. Pixels where every
                        # object score is below that object's mean + std get the extra
                        # "unknown" label num_obj.
                        obj_mean = torch.mean(seg_m.view(-1, num_obj), dim=0, keepdim=True)
                        obj_std = torch.std(seg_m.view(-1, num_obj), dim=0, keepdim=True)
                        unknown_mask = seg_m < obj_mean + obj_std
                        tmp_seg_m = seg_m.detach().clone()
                        tmp_seg_m[unknown_mask == 1] = -100
                        unknown_mask = unknown_mask.sum(-1).detach().cpu().numpy()
                        seg_labels = np.argmax(tmp_seg_m.detach().cpu().numpy(), axis=-1)
                        seg_labels[unknown_mask == num_obj] = num_obj
                        segs.append(rand_colors[seg_labels])

                self.init_image = utils.to8b(rgb)
                self.predictor.set_image(self.init_image)
                index_matrix = _generate_index_matrix(H, W, render_result['depth'].detach().clone())

                if stage == 'coarse':  # coarse stage, get sam seg
                    if i == 0 and eph == 0:
                        print("The first view, in which an initialized prompt is set by user")
                        plt.figure(figsize=(10, 10))
                        plt.imshow(utils.to8b(rgb))
                        plt.axis('on')
                        plt.savefig('tmp.jpg')
                        plt.close()
                        if self.seg_everything:
                            # NOTE: WIP mode -- try to segment all objects from box prompts.
                            # NOTE(review): INPUT_BOX is not defined in this module (its import
                            # from .scene_property is commented out), so this branch raises
                            # NameError until that import is restored.
                            with torch.no_grad():
                                input_boxes = INPUT_BOX[f'{self.args.scene}'].to(self.device)
                                transformed_boxes = self.predictor.transform.apply_boxes_torch(input_boxes, self.init_image.shape[:2])
                                masks, scores, logits = self.predictor.predict_torch(
                                    point_coords=None,
                                    point_labels=None,
                                    boxes=transformed_boxes,
                                    multimask_output=False,
                                )
                                target_masks = masks.squeeze(1).permute([1, 2, 0]).float()
                                for mask_i in range(len(masks)):
                                    plt.figure(figsize=(10, 10))
                                    plt.imshow(target_masks[:, :, mask_i].detach().cpu().numpy())
                                    plt.axis('on')
                                    plt.savefig('tmp_mask_batched_' + str(mask_i) + '.jpg')
                                    plt.close()
                                num_obj = len(masks)
                                # one extra (black) colour for the "unknown" label
                                rand_colors = np.random.rand(num_obj + 1, 3)
                                rand_colors[-1, :] = 0

                            model.change_num_objects(num_obj)
                            model.seg_mask_grid.grid.requires_grad = True
                            model.dual_seg_mask_grid.grid.requires_grad = True

                            print("Reset optimizer for new mask_grid")
                            optimizer = utils.create_optimizer_or_freeze_model(model, cfg_train=self.cfg.coarse_train, global_step=0)
                            optimizer.zero_grad(set_to_none=True)

                            # the initialized num_objects is 1, reset it to num_obj and re-run the mask rendering
                            re_render_result_chunks = render_chunks(rays_o, rays_d, viewdirs)
                            seg_m = torch.cat([ret['seg_mask_marched'] for ret in re_render_result_chunks]).reshape(H, W, -1)

                            segs = []
                            obj_mean = torch.mean(seg_m.view(-1, num_obj), dim=0, keepdim=True)
                            obj_std = torch.std(seg_m.view(-1, num_obj), dim=0, keepdim=True)
                            unknown_mask = seg_m < obj_mean + obj_std
                            unknown_mask = unknown_mask.sum(-1).detach().cpu().numpy()
                            seg_labels = np.argmax(seg_m.detach().cpu().numpy(), axis=-1)
                            seg_labels[unknown_mask == num_obj] = num_obj
                            segs.append(rand_colors[seg_labels])
                            # The optimization step for this loss is taken once, below.
                            loss = seg_loss(target_masks, None, seg_m, 100)
                        else:
                            # user-defined prompt mode
                            masks, scores, logits, selected_mask = self.seg_init_frame_coarse()
                            loss = seg_loss(masks, selected_mask, seg_m)
                    # further views
                    else:
                        loss = self.prompting_coarse(H, W, seg_m, index_matrix, num_obj)
                    optim(optimizer, loss)
                elif stage == 'fine':
                    if i == 0 and eph == 0:
                        assert model.num_objects == 1, "only support single object segmentation now"
                        target_mask, dual_target = self.seg_init_frame_fine(seg_m, model, dual_seg_m)
                        # standard segmentation loss
                        loss = seg_loss(target_mask, None, seg_m)
                        # dual segmentation loss
                        loss += seg_loss(dual_target, None, dual_seg_m)
                    # further views
                    else:
                        loss = self.prompting_fine(H, W, seg_m, dual_seg_m, index_matrix, num_obj)
                    optim(optimizer, loss)
                else:
                    raise NotImplementedError

        return rgbs, depths, bgmaps, segs, dual_segs

    def generate_rendering_poses(self, poses_from_data):
        '''Return the poses to render. Currently the input is returned unchanged.

        TODO: find the best pose for rendering.
        '''
        return poses_from_data

    def seg_init_frame_coarse(self):
        '''Obtain the user's prompt on the first view and the SAM mask they pick.

        Prompt points come from ``get_prompt_points``. If no ``mask_id`` comes
        with them, the following happens:

        * SAM's candidate masks are saved as ``tmp_mask_<j>.jpg``;
        * the user is asked on stdin to choose one.

        The prompt points and chosen mask id are written to
        ``user-specific-prompt.json`` in the experiment directory. The fine
        stage reads them back from there. The chosen mask, with the prompt
        points marked, is appended to ``self.sam_segs``.

        Returns:
            ``(masks, scores, logits, selected_mask)``. The first three are the
            outputs of ``SamPredictor.predict`` with ``multimask_output=True``.
            ``selected_mask`` is the index of the chosen mask.
        '''
        with torch.no_grad():
            prompts = get_prompt_points(self.args, sam=self.predictor,
                    ctx=self.context, init_rgb=self.init_image)
            input_point = prompts['prompt_points']
            input_label = np.ones(len(input_point))

            masks, scores, logits = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )

        if prompts['mask_id'] is None:
            # Save every candidate so the user can inspect them before choosing.
            for j, mask in enumerate(masks):
                plt.figure(figsize=(10, 10))
                plt.imshow(mask)
                plt.axis('on')
                plt.savefig('tmp_mask_' + str(j) + '.jpg')
                plt.close()

            selected_mask = int(input("Please select a mask:"))
        else:
            selected_mask = prompts['mask_id']

        # record the selected prompt and mask
        prompt_file = os.path.join(self.base_save_dir, "user-specific-prompt.json")
        with open(prompt_file, 'w') as f:
            json.dump({"mask_id": selected_mask, "prompt_points": input_point.tolist()}, f)
        print(f"Prompt saved in {prompt_file}")

        sam_seg_show = masks[selected_mask].astype(np.float32)
        sam_seg_show = np.stack([sam_seg_show] * 3, axis=-1)
        self._mark_prompt_points(sam_seg_show, input_point)
        self.sam_segs.append(sam_seg_show)

        return masks, scores, logits, selected_mask

    def seg_init_frame_fine(self, seg_m, model, dual_seg_m):
        '''Rebuild the first-view targets for the fine stage from the saved user prompt.

        The method works in these steps:

        1. Read ``user-specific-prompt.json``.
        2. Re-run SAM with the stored points and use the stored mask id as the
           target.
        3. Build the dual target: pixels covered by the binarized rendering of
           channel 0 but not by the SAM mask.
        4. Reset both segmentation grids of ``model`` to zero.

        Args:
            seg_m: rendered segmentation scores of the first view (channel 0 is used).
            model: scene model whose ``seg_mask_grid`` and ``dual_seg_mask_grid``
                are reset.
            dual_seg_m: rendered dual segmentation scores. Not used here; kept
                for interface compatibility.

        Returns:
            ``(target_mask, dual_target)``, both float tensors on ``seg_m.device``.
        '''
        prompt_file = os.path.join(self.base_save_dir, "user-specific-prompt.json")
        with open(prompt_file, 'r') as f:
            prompt_dict = json.load(f)
        mask_id = prompt_dict['mask_id']
        input_point = np.array(prompt_dict['prompt_points'])

        with torch.no_grad():
            input_label = np.ones(len(input_point))
            masks, scores, logits = self.predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,
            )

        print("user-specific-prompt loaded, the specified prompt mask id is:", mask_id)
        target_mask = torch.as_tensor(masks[mask_id]).float().to(seg_m.device)

        # the rendered segmentation result, binarized
        tmp_rendered_mask = self._binarize_rendered_mask(seg_m[:, :, 0])

        # dual target: rendered as foreground but outside the SAM mask
        dual_target = torch.zeros_like(tmp_rendered_mask)
        dual_target[(tmp_rendered_mask - target_mask) == 1] = 1

        IoU = utils.cal_IoU(tmp_rendered_mask, target_mask)
        print("Current IoU is", IoU)
        if IoU > 0.9:
            print("IoU is larger than 0.9, no refinement is required. Use Ctrl+C to cancel the fine stage training.")
            time.sleep(5)
            print("Begin refinement.")

        # Restart both segmentation grids from zero; each is reset with its own shape/dtype.
        model.seg_mask_grid.grid.data = torch.zeros_like(model.seg_mask_grid.grid)
        model.dual_seg_mask_grid.grid.data = torch.zeros_like(model.dual_seg_mask_grid.grid)

        sam_seg_show = masks[mask_id].astype(np.float32)
        self.sam_segs.append(np.stack([sam_seg_show] * 3, axis=-1))
        dual_sam_seg_show = dual_target.detach().cpu().numpy().astype(np.float32)
        self.dual_sam_segs.append(np.stack([dual_sam_seg_show] * 3, axis=-1))

        return target_mask, dual_target

    def prompting_coarse(self, H, W, seg_m, index_matrix, num_obj):
        '''Self-prompting loss for coarse-stage views after the first one.

        For each object channel ``num`` of ``seg_m``:

        * derive prompt points from the rendered scores with ``mask_to_prompt``;
        * ask SAM for a single mask;
        * skip the object if SAM's mask has IoU < 0.5 with the binarized
          rendering;
        * otherwise pull channel ``num`` towards the mask (``seg_loss`` with
          ``lamda=0.15``) and penalize every other channel inside the mask.

        The SAM mask of object 0, with its prompt points marked, is appended to
        ``self.sam_segs``.

        Returns:
            The accumulated loss. This is the int ``0`` if no object contributed.
        '''
        seg_m_for_prompt = seg_m.detach().clone()
        loss = 0

        for num in range(num_obj):
            with torch.no_grad():
                # self-prompting
                prompt_points, input_label = mask_to_prompt(
                    predictor=self.predictor,
                    rendered_mask_score=seg_m_for_prompt[:, :, num][:, :, None],
                    index_matrix=index_matrix, num_prompts=self.args.num_prompts)

                masks, selected = None, -1
                if len(prompt_points) != 0:
                    masks, scores, logits = self.predictor.predict(
                        point_coords=prompt_points,
                        point_labels=input_label,
                        multimask_output=False,
                    )
                    selected = np.argmax(scores)

            if num == 0:
                # visualization, used for single object only
                sam_seg_show = masks[selected].astype(np.float32) if masks is not None else np.zeros((H, W))
                sam_seg_show = np.stack([sam_seg_show] * 3, axis=-1)
                self._mark_prompt_points(sam_seg_show, prompt_points)
                self.sam_segs.append(sam_seg_show)

            if masks is None:
                continue

            tmp_seg_m = seg_m[:, :, num]
            tmp_rendered_mask = self._binarize_rendered_mask(tmp_seg_m)
            tmp_IoU = utils.cal_IoU(torch.as_tensor(masks[selected]).float(), tmp_rendered_mask)
            print(f"current IoU is: {tmp_IoU}")
            if tmp_IoU < 0.5:
                print("SKIP, unacceptable sam prediction, IoU is", tmp_IoU)
                continue

            loss += seg_loss(masks, selected, tmp_seg_m, 0.15)
            # Penalize the other objects' scores inside this object's SAM mask.
            sam_mask = torch.tensor(masks[selected]).to(seg_m.device)
            for neg_i in range(seg_m.shape[-1]):
                if neg_i != num:
                    loss += (sam_mask * seg_m[:, :, neg_i]).sum()
        return loss

    def prompting_fine(self, H, W, seg_m, dual_seg_m, index_matrix, num_obj):
        '''Self-prompting loss for fine-stage views after the first one.

        The segmentation and dual segmentation scores are smoothed with a 25x25
        average filter. Prompt points are then derived from each with
        ``mask_to_prompt``. SAM is queried twice:

        * for the mask: segmentation points with their labels plus dual points
          with inverted labels;
        * for the dual mask: the mirror image of that.

        Each prediction whose IoU with the corresponding binarized rendering is
        at least 0.5 adds two things: a ``seg_loss`` term (``lamda=0.15``) for its
        own channel, and a penalty on every other channel inside the mask.
        Visualizations for object 0 go to ``self.sam_segs`` / ``self.dual_sam_segs``.

        Returns:
            The accumulated loss. This is the int ``0`` if nothing contributed.
        '''
        def smooth(scores):
            # 25x25 box filter, stride 1, same spatial size, per channel of an (H, W, C) map.
            pooled = torch.nn.functional.avg_pool2d(scores.permute([2, 0, 1]).unsqueeze(0), 25, stride=1, padding=12)
            return pooled.squeeze(0).permute([1, 2, 0])

        loss = 0
        # smoothed scores used to place the prompts of interest
        seg_m_for_prompt = smooth(seg_m.detach().clone())
        dual_seg_m_for_prompt = smooth(dual_seg_m.detach().clone())

        for num in range(num_obj):
            tmp_seg_m = seg_m[:, :, num]
            dual_tmp_seg_m = dual_seg_m[:, :, num]

            with torch.no_grad():
                tmp_rendered_mask = self._binarize_rendered_mask(tmp_seg_m)
                tmp_rendered_dual_mask = self._binarize_rendered_mask(dual_tmp_seg_m)

                # self-prompting
                ori_prompt_points, ori_input_label = mask_to_prompt(
                    predictor=self.predictor,
                    rendered_mask_score=seg_m_for_prompt[:, :, num].unsqueeze(-1),
                    index_matrix=index_matrix, num_prompts=self.args.num_prompts)
                num_self_prompts = len(ori_prompt_points)

                # dual self-prompting
                dual_prompt_points, dual_input_label = mask_to_prompt(
                    predictor=self.predictor,
                    rendered_mask_score=dual_seg_m_for_prompt[:, :, num].unsqueeze(-1),
                    index_matrix=index_matrix, num_prompts=self.args.num_prompts)
                num_dual_self_prompts = len(dual_prompt_points)

                masks, dual_masks = None, None
                if num_self_prompts != 0:
                    prompt_points = np.concatenate([ori_prompt_points, dual_prompt_points], axis=0)
                    input_label = np.concatenate([ori_input_label, 1 - dual_input_label], axis=0)
                    masks, scores, logits = self.predictor.predict(
                        point_coords=prompt_points,
                        point_labels=input_label,
                        multimask_output=False,
                    )

                if num_dual_self_prompts != 0:
                    prompt_points = np.concatenate([ori_prompt_points, dual_prompt_points], axis=0)
                    input_label = np.concatenate([1 - ori_input_label, dual_input_label], axis=0)
                    dual_masks, dual_scores, dual_logits = self.predictor.predict(
                        point_coords=prompt_points,
                        point_labels=input_label,
                        multimask_output=False,
                    )

            if num == 0:
                # visualization, used for single object only
                sam_seg_show = masks[0].astype(np.float32) if masks is not None else np.zeros((H, W))
                sam_seg_show = np.stack([sam_seg_show] * 3, axis=-1)
                self._mark_prompt_points(sam_seg_show, ori_prompt_points, channel=0)
                self._mark_prompt_points(sam_seg_show, dual_prompt_points, channel=2)
                self.sam_segs.append(sam_seg_show)

                dual_sam_seg_show = dual_masks[0].astype(np.float32) if dual_masks is not None else np.zeros((H, W))
                dual_sam_seg_show = np.stack([dual_sam_seg_show] * 3, axis=-1)
                self._mark_prompt_points(dual_sam_seg_show, dual_prompt_points, channel=0)
                self._mark_prompt_points(dual_sam_seg_show, ori_prompt_points, channel=2)
                self.dual_sam_segs.append(dual_sam_seg_show)

            # NOTE(review): as in the original, the dual-mask term is only reached
            # when the primary SAM mask exists too -- confirm that is intended.
            if masks is None:
                continue

            tmp_IoU = utils.cal_IoU(torch.as_tensor(masks[0]).float(), tmp_rendered_mask)
            print("tmp_IoU:", tmp_IoU)
            if tmp_IoU < 0.5:
                print("SKIP, unacceptable sam prediction for original seg, IoU is", tmp_IoU)
            else:
                loss += seg_loss(masks[0], None, tmp_seg_m, 0.15)
                # seg_loss with lamda=0 is -(mask * s).sum(); subtracting it penalizes
                # other channels inside this mask.
                for neg_i in range(seg_m.shape[-1]):
                    if neg_i != num:
                        loss -= seg_loss(masks[0], None, seg_m[:, :, neg_i], 0)

            if dual_masks is not None:
                tmp_IoU = utils.cal_IoU(torch.as_tensor(dual_masks[0]).float(), tmp_rendered_dual_mask)
                print("tmp_dual_IoU:", tmp_IoU)
                if tmp_IoU < 0.5:
                    print("SKIP, unacceptable sam prediction for dual seg, IoU is", tmp_IoU)
                else:
                    loss += seg_loss(dual_masks[0], None, dual_tmp_seg_m, 0.15)
                    for neg_i in range(dual_seg_m.shape[-1]):
                        if neg_i != num:
                            loss -= seg_loss(dual_masks[0], None, dual_seg_m[:, :, neg_i], 0)

        return loss


def seg_loss(mask: Tensor, selected_mask: Optional[Tensor], seg_m: Tensor, lamda: float = 5.0) -> Tensor:
    """
    Compute a segmentation loss that rewards predicted mass inside the target
    mask and penalizes predicted mass outside it.

        loss = -sum(target * pred) + lamda * sum((1 - target) * pred)

    where ``target`` is ``mask`` (or ``mask[selected_mask]``) and ``pred`` is
    ``seg_m.squeeze(-1)``.

    Args:
        mask: Binary ground-truth segmentation mask(s).
        selected_mask: Index used to pick one mask out of ``mask``
            (``mask[selected_mask]``), or ``None`` to use ``mask`` as-is.
        seg_m: Predicted segmentation mask tensor; a trailing singleton
            dimension is squeezed away.
        lamda: Weighting factor for the outside-mask penalty. Default is 5.0.

    Returns:
        Computed segmentation loss (a tensor on ``seg_m``'s device).

    Raises:
        AssertionError: If ``seg_m`` is ``None``.
    """
    assert seg_m is not None, "Segmentation mask is None."
    target = mask if selected_mask is None else mask[selected_mask]
    # Cast to the prediction's dtype *before* computing ``1 - target``: masks
    # coming from a SAM predictor may be boolean, and torch refuses ``1 - t``
    # for bool tensors. Both branches now share this single conversion path.
    target = utils.to_tensor(target, seg_m.device).to(seg_m.dtype)
    pred = seg_m.squeeze(-1)
    mask_loss = -(target * pred).sum()
    out_mask_loss = lamda * ((1 - target) * pred).sum()
    return mask_loss + out_mask_loss


def prompting(args, num_obj, seg_m, predictor, index_matrix, sam_segs, stage):
    '''Self-prompt SAM from the current segmentation and accumulate a loss.

    Used for the rest of the turns: self-prompts update the segmentation volume.
    For every object channel, prompt points are derived from a smoothed copy of
    that channel via ``mask_to_prompt`` and fed to ``predictor.predict``. The
    predicted mask is compared with a binarized version of the rendered
    channel; predictions with IoU < 0.5 are skipped. Otherwise the prediction
    supervises the channel through ``seg_loss``.

    Args:
        args: Run arguments; only ``args.num_prompts`` is read here.
        num_obj: Number of object channels to process.
        seg_m: Rendered segmentation indexed as ``seg_m[:, :, num]``, so the
            last axis is the object channel (presumably [H, W, num_obj]).
        predictor: SAM predictor; presumably already set to the current view's
            image — TODO confirm with caller.
        index_matrix: Per-pixel matrix forwarded to ``mask_to_prompt``.
        sam_segs: List that receives an RGB visualization of object 0's SAM
            mask with the prompt points marked.
        stage: ``'coarse'`` also pushes down the other channels inside the SAM
            mask; any other value weights the loss by the SAM mask's area fraction.

    Returns:
        ``(sam_segs, loss)``; ``loss`` stays the int ``0`` if every object
        was skipped, otherwise it is a tensor.
    '''
    loss = 0
    seg_m_clone = seg_m.detach().clone()
    # 25x25 average pooling with padding 12 and stride 1 keeps H and W unchanged.
    seg_m_for_prompt = torch.nn.functional.avg_pool2d(
        seg_m_clone.permute([2, 0, 1]).unsqueeze(0), 25, stride=1, padding=12)
    seg_m_for_prompt = seg_m_for_prompt.squeeze(0).permute([1, 2, 0])

    for num in range(num_obj):
        with torch.no_grad():
            prompt_points, input_label = mask_to_prompt(
                predictor, seg_m_for_prompt[..., num:num + 1], index_matrix, args.num_prompts)
            masks, scores, _ = predictor.predict(
                point_coords=prompt_points,
                point_labels=input_label,
                multimask_output=False,
            )
            selected = np.argmax(scores)
        sam_mask = masks[selected]
        tmp_seg_m = seg_m[:, :, num]

        if num == 0:
            # used for single object only: SAM mask as RGB with 6x6 prompt markers
            sam_seg_show = sam_mask.astype(np.float32)
            sam_seg_show = np.stack([sam_seg_show] * 3, axis=-1)
            for ip, point in enumerate(prompt_points):
                # Clamp at 0: a negative start would wrap around and draw nothing.
                rows = slice(max(point[1] - 3, 0), point[1] + 3)
                cols = slice(max(point[0] - 3, 0), point[0] + 3)
                sam_seg_show[rows, cols, :] = 0
                # Cycle R/G/B so more than three prompts cannot index past the channel axis.
                sam_seg_show[rows, cols, ip % 3] = 1
            sam_segs.append(sam_seg_show)

        # Binarize the rendered channel: keep pixels above both its mean and zero.
        tmp_rendered_mask = tmp_seg_m.detach().clone()
        tmp_rendered_mask[torch.logical_or(tmp_rendered_mask <= tmp_rendered_mask.mean(), tmp_rendered_mask <= 0)] = 0
        tmp_rendered_mask[tmp_rendered_mask != 0] = 1
        tmp_IoU = utils.cal_IoU(torch.as_tensor(sam_mask).float(), tmp_rendered_mask)

        print(f"current IoU is: {tmp_IoU}")
        if tmp_IoU < 0.5:
            print("SKIP, unacceptable sam prediction, IoU is", tmp_IoU.item())
            continue

        if stage == 'coarse':
            # area weighted loss
            loss += seg_loss(masks, selected, tmp_seg_m, 0.15)
            # Penalize every other object channel inside this object's SAM mask.
            sam_mask_t = torch.as_tensor(sam_mask).to(seg_m.device)
            for neg_i in range(seg_m.shape[-1]):
                if neg_i == num:
                    continue
                loss += (sam_mask_t * seg_m[:, :, neg_i]).sum()
        else:
            sam_mask_t = torch.as_tensor(sam_mask)
            area_fraction = sam_mask_t.sum() / sam_mask_t.numel()
            loss += area_fraction * seg_loss(masks, selected, tmp_seg_m, 0.15)

    return sam_segs, loss


def optim(optimizer, loss, clip=None, model=None):
    """Perform a single optimization step using the given optimizer and loss.

    A ``loss`` that is not a ``torch.Tensor`` (e.g. a loss accumulator that is
    still the int ``0`` because nothing contributed) is ignored: gradients are
    not zeroed and no step is taken. Nothing is raised in that case.

    Args:
        optimizer: PyTorch optimizer to use for the optimization step.
        loss: The loss to optimize; only tensors trigger a step.
        clip: Optional max gradient norm passed to ``clip_grad_norm_``.
        model: Optional PyTorch model whose parameters to clip. When ``None``
            and ``clip`` is given, the optimizer's own parameters are clipped.
    """
    if not isinstance(loss, torch.Tensor):
        # Nothing to optimize this iteration.
        return
    optimizer.zero_grad()
    loss.backward()
    if clip is not None:
        if model is not None:
            params = model.parameters()
        else:
            # Previously this crashed on ``None.parameters()``; clip what the optimizer updates.
            params = [p for group in optimizer.param_groups for p in group['params']]
        torch.nn.utils.clip_grad_norm_(params, clip)
    optimizer.step()


def _generate_index_matrix(H, W, depth_map):
    '''Generate the index matrix: each pixel's normalized coordinates plus its depth.

    Channel 0 holds the row index mapped to (0, 1] (values 1/H .. 1), channel 1
    the column index mapped to (0, 1] (values 1/W .. 1), and channel 2 the
    depth min-max normalized to [0, 1]. A constant depth map yields zeros
    instead of NaN.

    Args:
        H: Number of rows.
        W: Number of columns.
        depth_map: Depth tensor concatenated on the last axis with an
            [H, W, 2] tensor, so it must be [H, W, 1].

    Returns:
        Tensor of shape [H, W, 3] on ``depth_map``'s device.
    '''
    device = depth_map.device
    # NOTE, range (1, H) = arange(1, H+1); built on depth's device so cat works on GPU too.
    xs = torch.arange(1, H + 1, device=device) / H
    ys = torch.arange(1, W + 1, device=device) / W
    # 'ij' layout (row coordinate varies along dim 0), i.e. torch.meshgrid's
    # default, built by broadcasting to avoid the missing-`indexing` warning.
    grid_x = xs[:, None].expand(H, W)
    grid_y = ys[None, :].expand(H, W)
    index_matrix = torch.stack([grid_x, grid_y], dim=-1)  # [H, W, 2]

    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_norm = (depth_map - depth_min) / depth_range  # [H, W, 1]
    else:
        # Constant depth: 0/0 would fill the channel with NaN.
        depth_norm = torch.zeros_like(depth_map)
    return torch.cat([index_matrix, depth_norm], dim=-1)

