import torch
import torch.nn as nn
from typing import Any, Dict, List, Tuple, Type
from .modeling.common import LayerNorm2d
import torch.nn.functional as F

from typing import Type


def build_sam_upscale_h(checkpoint=None):
    """Build the upscale model registered under ``"vit_h"`` and ``"default"``.

    Args:
        checkpoint: Optional checkpoint path, forwarded unchanged to
            ``_build_upscale``. ``None`` skips weight loading.

    Returns:
        Whatever ``_build_upscale`` returns for the given checkpoint.
    """
    model = _build_upscale(checkpoint=checkpoint)
    return model

# Alias: ``build_sam`` is the same callable as ``build_sam_upscale_h``.
build_sam = build_sam_upscale_h


def build_sam_upscale_l(checkpoint=None):
    """Build the upscale model registered under ``"vit_l"``.

    Args:
        checkpoint: Optional checkpoint path, forwarded unchanged to
            ``_build_upscale``. ``None`` skips weight loading.

    Returns:
        Whatever ``_build_upscale`` returns for the given checkpoint.
    """
    model = _build_upscale(checkpoint=checkpoint)
    return model


def build_sam_upscale_b(checkpoint=None):
    """Build the upscale model registered under ``"vit_b"``.

    Args:
        checkpoint: Optional checkpoint path, forwarded unchanged to
            ``_build_upscale``. ``None`` skips weight loading.

    Returns:
        Whatever ``_build_upscale`` returns for the given checkpoint.
    """
    model = _build_upscale(checkpoint=checkpoint)
    return model


# Lookup table from variant name to its builder function; "default" points
# at the same builder as "vit_h".
upscale_model_registry = dict(
    default=build_sam_upscale_h,
    vit_h=build_sam_upscale_h,
    vit_l=build_sam_upscale_l,
    vit_b=build_sam_upscale_b,
)
def _build_upscale(checkpoint=None):
    """Construct a ``GSUpscaleModel`` in eval mode, optionally loading weights.

    Args:
        checkpoint: Path to a state-dict file readable by ``torch.load``, or
            ``None`` to return the model with its freshly initialised weights.

    Returns:
        The ``GSUpscaleModel`` instance, already switched to ``eval()`` mode.

    Raises:
        OSError: If ``checkpoint`` cannot be opened.
        RuntimeError: If the renamed checkpoint keys do not match the model
            (``load_state_dict`` is called with its default strictness).
    """
    prompt_embed_dim = 256
    gs_upscale = GSUpscaleModel(
        transformer_dim=prompt_embed_dim,
    )
    gs_upscale.eval()

    if checkpoint is not None:
        # NOTE(review): torch.load unpickles the file, so only pass
        # checkpoints that come from a trusted source.
        with open(checkpoint, "rb") as f:
            # map_location="cpu" lets checkpoints saved from GPU tensors load
            # on hosts without CUDA; load_state_dict copies the values into
            # the model's own parameters afterwards.
            state_dict = torch.load(f, map_location="cpu")

        # Rename the prefix used in the checkpoint to this model's attribute
        # name; keys without that prefix pass through unchanged.
        new_state_dict = {
            key.replace('mask_decoder.output_upscaling', 'output_upscale'): value
            for key, value in state_dict.items()
        }
        gs_upscale.load_state_dict(new_state_dict)

    return gs_upscale


# 1. Define the custom upscale model
class GSUpscaleModel(nn.Module):
    """Stand-alone upscaling head made of two transposed convolutions.

    The ``output_upscale`` stack applies, in order: a stride-2
    ``ConvTranspose2d`` (``transformer_dim`` -> ``transformer_dim // 4``
    channels), ``LayerNorm2d``, an activation, a second stride-2
    ``ConvTranspose2d`` (-> ``transformer_dim // 8`` channels) and a final
    activation. With ``kernel_size=2, stride=2`` each transposed convolution
    doubles the spatial size, so the output is 4x the input resolution.

    Args:
        transformer_dim: Number of channels of the input feature map.
        activation: Module class instantiated (with no arguments) after each
            upscaling stage.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        activation: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # The layer order is kept fixed so the state-dict keys
        # (output_upscale.0, .1, .3) stay the same.
        self.output_upscale = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )

    @property
    def device(self) -> Any:
        """Device of the module's first parameter (CPU if it has none).

        The previous implementation read ``self.pixel_mean``, which this
        class never defines, so every access raised ``AttributeError``.
        """
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def forward(self, attn):
        """Run ``attn`` through the upscaling stack and return the result.

        Args:
            attn: Input feature map; assumed to be laid out as
                ``(batch, transformer_dim, H, W)`` as ``ConvTranspose2d``
                expects — TODO confirm against callers.

        Returns:
            The output of ``self.output_upscale(attn)``.
        """
        return self.output_upscale(attn)