from segment_anything import sam_model_registry
import torch
import torch.nn as nn


class SAM_APG(nn.Module):
    """Wrap a SAM backbone and choose which of its parameters are trainable.

    The backbone is built through ``sam_model_registry`` and parameters are
    frozen/unfrozen by substring match on their names (see
    ``_REQUIRES_GRAD_RULES``).  Parameters that match no rule keep whatever
    ``requires_grad`` value the backbone gave them.  Optionally, a full state
    dict for this wrapper is restored from ``cfg.snapshot``.
    """

    # Backbone variant, checkpoint file and input resolution handed to
    # ``sam_model_registry``.  Class attributes so a subclass can override
    # them without re-implementing ``__init__``.
    SAM_TYPE = "vit_h"
    SAM_CHECKPOINT = "sam_vit_h_4b8939.pth"
    IMAGE_SIZE = 352

    # Ordered (name substring, requires_grad) rules.  Every rule whose
    # substring occurs in a parameter name is applied in this order, so the
    # LAST matching rule wins; e.g. "image_encoder.prompt_generator.*" ends up
    # trainable, while "mask_decoder.prompt_generator.*" ends up frozen.
    _REQUIRES_GRAD_RULES = (
        ("image_encoder", False),
        ("prompt_generator", True),
        ("mask_downsample", False),
        ("prompt_encoder", False),
        ("mask_decoder", False),
    )

    def __init__(self, cfg):
        """Build the backbone, configure trainable parameters, load a snapshot.

        Args:
            cfg: Configuration object or ``None``.  Only ``cfg.snapshot`` is
                read here; when it is truthy it is passed to ``torch.load``
                and the result is loaded into this module with
                ``load_state_dict``.
        """
        super().__init__()
        self.cfg = cfg

        print(f"load {self.SAM_TYPE}")
        # This registry returns (model, image_embedding_size); only the model
        # is needed.  pixel_mean=0 / pixel_std=1 make the backbone's own
        # normalization an identity -- presumably inputs are normalized
        # upstream (TODO confirm against the data pipeline).
        self.sam, _ = sam_model_registry[self.SAM_TYPE](
            image_size=self.IMAGE_SIZE,
            num_classes=1,
            checkpoint=self.SAM_CHECKPOINT,
            pixel_mean=[0, 0, 0],
            pixel_std=[1, 1, 1],
        )

        n_trainable = self._configure_trainable(self.sam)
        print("select:", n_trainable / 1e6)  # trainable parameters, in millions

        if self.cfg is not None and self.cfg.snapshot:
            print('load checkpoint')
            # Deserialize onto the CPU first: a snapshot saved from GPU
            # tensors would otherwise fail on a CPU-only host (and cost a
            # transient GPU copy).  load_state_dict copies the values into
            # this module's parameters wherever they currently live.
            state_dict = torch.load(self.cfg.snapshot, map_location="cpu")
            self.load_state_dict(state_dict)

    @staticmethod
    def _configure_trainable(model: nn.Module) -> int:
        """Apply ``_REQUIRES_GRAD_RULES`` to ``model`` and count trainable params.

        Returns:
            Total number of scalar elements in parameters of ``model`` that
            have ``requires_grad`` set after the rules are applied.
        """
        n_trainable = 0
        for name, param in model.named_parameters():
            for fragment, trainable in SAM_APG._REQUIRES_GRAD_RULES:
                if fragment in name:
                    param.requires_grad = trainable
            if param.requires_grad:
                n_trainable += param.numel()
        return n_trainable

    def forward(self, x, multimask_output: bool = True, image_size=None):
        """Run the wrapped backbone.

        Args:
            x: Forwarded unchanged to the backbone as ``batched_input``; its
                expected structure is defined by the backbone (not shown here).
            multimask_output: Forwarded unchanged to the backbone.
            image_size: Forwarded unchanged to the backbone.

        Returns:
            The backbone's four outputs, in its order:
            ``(coarse_map, Background_outputs, sod_outputs, cod_outputs)``.
            The unpacking below raises ``ValueError`` if the backbone does not
            return exactly four values.
        """
        coarse_map, Background_outputs, sod_outputs, cod_outputs = self.sam(
            batched_input=x,
            multimask_output=multimask_output,
            image_size=image_size,
        )
        return coarse_map, Background_outputs, sod_outputs, cod_outputs


