import torch
import torch.nn as nn
from mmseg.models.backbones.vit import VisionTransformer
from mmseg.registry import MODELS


@MODELS.register_module()
class VisionTransformer_Prompt(VisionTransformer):
    """Vision Transformer backbone with learnable prompt tokens.

    ``prompt_length`` learnable prompt tokens are inserted between the class
    token and the patch tokens, so the encoder sequence is
    ``[cls, prompt_1 .. prompt_P, patch_1 .. patch_N]`` (the class token is
    dropped before the encoder when ``with_cls_token`` is False).
    ``pos_embed`` is extended with one slot per prompt so every token gets a
    positional embedding. ``forward`` returns only the patch tokens, reshaped
    to ``[B, C, h, w]`` feature maps.

    Args:
        prompt_length (int): Number of prompt tokens ``P``; must be
            non-negative. Default: 5.
        prompt_dropout (float): Dropout probability applied to the prompt
            tokens before positional embeddings are added. Default: 0.1.
        **kwargs: Forwarded unchanged to ``VisionTransformer``.
    """

    def __init__(self, prompt_length=5, prompt_dropout=0.1, **kwargs):
        if prompt_length < 0:
            raise ValueError(
                f'prompt_length must be non-negative, got {prompt_length}')

        # Plain attributes are set before the parent constructor so they
        # already exist if anything during parent construction reads them.
        self.prompt_length = prompt_length
        self.prompt_dropout_ratio = prompt_dropout
        super().__init__(**kwargs)

        # Take the width from the parent's pos_embed ([1, 1 + N, C]) instead
        # of duplicating the parent's ``embed_dims`` default here.
        old_pos_embed = self.pos_embed.detach()
        embed_dims = old_pos_embed.shape[-1]
        self._prompt_embed_dims = embed_dims

        self.prompt_embeddings = nn.Parameter(
            torch.randn(1, prompt_length, embed_dims) * 0.02)
        self.prompt_dropout = nn.Dropout(p=self.prompt_dropout_ratio)

        # Extend pos_embed from [1, 1 + N, C] to [1, 1 + P + N, C]: keep the
        # existing cls/patch slots, give the prompt slots a fresh init.
        prompt_pos = old_pos_embed.new_empty(1, prompt_length, embed_dims)
        nn.init.trunc_normal_(prompt_pos, std=0.02)
        self.pos_embed = nn.Parameter(
            self._splice_prompt_pos(old_pos_embed, prompt_pos))
        # NOTE(review): the copy above only sees pos_embed as it is at
        # construction time; pretrained weights loaded later go through
        # ``_load_from_state_dict`` below. Confirm how mmseg's
        # ``VisionTransformer.init_weights`` treats pos_embed.

    @staticmethod
    def _splice_prompt_pos(base_pos_embed, prompt_pos_embed):
        """Insert prompt position slots right after the class-token slot.

        Args:
            base_pos_embed (Tensor): ``[1, 1 + N, C]`` cls + patch slots.
            prompt_pos_embed (Tensor): ``[1, P, C]`` prompt slots.

        Returns:
            Tensor: ``[1, 1 + P + N, C]`` positional embedding.
        """
        return torch.cat(
            [base_pos_embed[:, :1], prompt_pos_embed, base_pos_embed[:, 1:]],
            dim=1)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        """Accept checkpoints whose ``pos_embed`` has no prompt slots.

        If the stored ``pos_embed`` is exactly ``P`` tokens shorter than this
        module's (e.g. saved from a backbone without prompts), the current
        prompt slots are spliced in after the class slot before loading.
        Otherwise the entry is left for the default shape check.
        """
        key = prefix + 'pos_embed'
        ckpt_pos = state_dict.get(key)
        own_shape = self.pos_embed.shape
        if (self.prompt_length > 0 and isinstance(ckpt_pos, torch.Tensor)
                and ckpt_pos.ndim == 3
                and ckpt_pos.shape[0] == own_shape[0]
                and ckpt_pos.shape[2] == own_shape[2]
                and ckpt_pos.shape[1] + self.prompt_length == own_shape[1]):
            prompt_pos = self.pos_embed.detach()[:, 1:1 + self.prompt_length]
            state_dict[key] = self._splice_prompt_pos(
                ckpt_pos,
                prompt_pos.to(device=ckpt_pos.device, dtype=ckpt_pos.dtype))
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def _pos_grid_shape(self, num_patch_pos):
        """Return the ``(h, w)`` patch grid stored in ``pos_embed``.

        Uses ``img_size // patch_size`` when the parent exposes those
        attributes and they agree with ``num_patch_pos`` (handles non-square
        inputs); otherwise requires a square grid.

        Raises:
            ValueError: If no grid shape matches ``num_patch_pos``.
        """
        img_size = getattr(self, 'img_size', None)
        patch_size = getattr(self, 'patch_size', None)
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        if isinstance(img_size, (tuple, list)) and isinstance(patch_size, int):
            pos_h = img_size[0] // patch_size
            pos_w = img_size[1] // patch_size
            if pos_h * pos_w == num_patch_pos:
                return pos_h, pos_w

        side = int(round(num_patch_pos ** 0.5))
        if side <= 0 or side * side != num_patch_pos:
            raise ValueError(
                f'Cannot infer the patch grid of pos_embed from '
                f'{num_patch_pos} patch positions.')
        return side, side

    def _pos_embeding(self, patched_img, hw_shape, pos_embed):
        """Add positional embeddings to a ``[cls, prompts, patches]`` sequence.

        The cls and prompt slots are used as-is; the patch slots are resized
        to ``hw_shape`` with ``self.interpolate_mode`` when the stored grid
        differs from it.

        Args:
            patched_img (Tensor): ``[B, 1 + P + h*w, C]`` token sequence.
            hw_shape (tuple[int, int]): ``(h, w)`` patch grid of the input.
            pos_embed (Tensor): ``[1, 1 + P + N, C]`` positional embedding.

        Returns:
            Tensor: ``patched_img + pos_embed`` after ``self.drop_after_pos``.

        Raises:
            ValueError: If the sequence length does not match ``hw_shape``.
        """
        assert patched_img.ndim == 3 and pos_embed.ndim == 3, \
            'the shapes of patched_img and pos_embed must be [B, L, C]'

        num_prefix = 1 + self.prompt_length
        num_patches = patched_img.shape[1] - num_prefix
        if num_patches != hw_shape[0] * hw_shape[1]:
            raise ValueError(
                f'Expected {num_prefix} prefix tokens plus '
                f'{hw_shape[0]}x{hw_shape[1]} patch tokens, got a sequence '
                f'of length {patched_img.shape[1]}.')

        prefix_weight = pos_embed[:, :num_prefix]  # cls + prompt slots
        patch_weight = pos_embed[:, num_prefix:]

        pos_h, pos_w = self._pos_grid_shape(patch_weight.shape[1])
        target_hw = tuple(hw_shape)
        if (pos_h, pos_w) != target_hw:
            mode = self.interpolate_mode
            # F.interpolate rejects align_corners for non-linear modes
            # such as 'nearest'.
            align_corners = False if mode in (
                'linear', 'bilinear', 'bicubic', 'trilinear') else None
            patch_weight = patch_weight.reshape(
                1, pos_h, pos_w, -1).permute(0, 3, 1, 2)
            patch_weight = nn.functional.interpolate(
                patch_weight, size=target_hw, mode=mode,
                align_corners=align_corners)
            patch_weight = patch_weight.flatten(2).transpose(1, 2)

        new_pos_embed = torch.cat([prefix_weight, patch_weight], dim=1)
        return self.drop_after_pos(patched_img + new_pos_embed)

    def forward(self, x):
        """Encode images and return patch-token feature maps.

        Args:
            x (Tensor): Input batch for ``self.patch_embed`` (presumably
                ``[B, C_in, H, W]`` images - confirm against the patch
                embedding in use).

        Returns:
            tuple: One entry per index in ``self.out_indices``: the patch
            tokens reshaped to ``[B, C, h, w]``, or ``[feature_map,
            cls_token]`` when ``output_cls_token`` is enabled and the class
            token is kept.
        """
        B = x.shape[0]
        x, hw_shape = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        prompt_tokens = self.prompt_dropout(
            self.prompt_embeddings.expand(B, -1, -1))
        x = torch.cat([cls_tokens, prompt_tokens, x], dim=1)
        x = self._pos_embeding(x, hw_shape, self.pos_embed)

        if not self.with_cls_token:
            # Drop only the class token; the prompt tokens must still pass
            # through the encoder layers.
            x = x[:, 1:]
        # Number of tokens in front of the patch tokens in the encoder input.
        num_prefix = self.prompt_length + (1 if self.with_cls_token else 0)
        output_cls_token = (self.with_cls_token
                            and getattr(self, 'output_cls_token', False))

        if self.pre_norm:
            x = self.pre_ln(x)

        last_idx = len(self.layers) - 1
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == last_idx and self.final_norm:
                x = self.norm1(x)
            if i in self.out_indices:
                out = x[:, num_prefix:]
                B, _, C = out.shape
                out = out.reshape(B, hw_shape[0], hw_shape[1],
                                  C).permute(0, 3, 1, 2).contiguous()
                if output_cls_token:
                    out = [out, x[:, 0]]
                outs.append(out)

        return tuple(outs)
