import timm
import torch
import math
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer, PatchEmbed

def build_promptmodel(modelname='vit_base_patch16_224',  Prompt_Token_num=10, VPT_type="Deep", args=None):
    """Build a frozen, prompt-tunable ViT initialised from a pretrained timm model.

    The pretrained weights of ``modelname`` are copied into a ``VPT_ViT``
    (constructed with its default architecture arguments), the classification
    head is replaced by ``torch.nn.Identity`` and ``Freeze()`` is called on the
    result.

    Args:
        modelname: Name passed to ``timm.create_model`` to fetch the
            pretrained backbone weights.
        Prompt_Token_num: Number of prompt tokens, forwarded to ``VPT_ViT``.
        VPT_type: Prompt variant (``"Deep"`` or ``"Shallow"``), forwarded to
            ``VPT_ViT``.
        args: Configuration object forwarded unchanged to ``VPT_ViT``.

    Returns:
        The ``VPT_ViT`` instance with backbone weights loaded, an identity
        head, and ``Freeze()`` applied.
    """
    basic_model = timm.create_model(modelname, pretrained=True)
    model = VPT_ViT(Prompt_Token_num=Prompt_Token_num, VPT_type=VPT_type, args=args)

    backbone_state = basic_model.state_dict()
    # The classifier is discarded below, so its weights are never loaded.
    # Popping with a default keeps this working for backbones whose state
    # dict carries no linear head.
    for head_key in ('head.weight', 'head.bias'):
        backbone_state.pop(head_key, None)

    # Non-strict: the prompt tokens (and the removed head) have no
    # counterpart in the pretrained state dict.
    model.load_state_dict(backbone_state, strict=False)

    model.head = torch.nn.Identity()

    model.Freeze()

    return model


class VPT_ViT(VisionTransformer):
    """Vision Transformer with learnable prompt tokens appended to the sequence.

    Two variants are selected through ``VPT_type``:

    * ``"Deep"``: one set of prompt tokens per transformer block; the prompts
      are appended before each block and stripped again after it.
    * anything else (``"Shallow"``): a single set of prompt tokens appended
      once before the first block and stripped after the last one.

    The output of every block (prompts still attached) is recorded in
    ``self.tfmout`` during ``forward_features``.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None, Prompt_Token_num=1,
                 VPT_type="Shallow", basic_state_dict=None, args=None):
        """Create the backbone and the prompt parameter.

        Args:
            img_size .. act_layer: Forwarded unchanged to
                ``VisionTransformer.__init__``.
            Prompt_Token_num: Number of prompt tokens per prompt set.
            VPT_type: ``"Deep"`` for per-block prompts, otherwise a single
                shared ("Shallow") prompt set.
            basic_state_dict: Optional state dict loaded non-strictly into the
                backbone right after construction.
            args: Subscriptable configuration; ``args['intra_share'] == 1``
                selects zero-initialised prompts. ``None`` is treated as
                "not shared" (uniform initialisation).
        """
        super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
                         embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, embed_layer=embed_layer,
                         norm_layer=norm_layer, act_layer=act_layer)

        print('Using VPT model')
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, False)

        self.tfmout = []  # per-block outputs of the most recent forward pass
        self.args = args
        self.VPT_type = VPT_type

        # Deep prompts need one set per block; shallow prompts need only one.
        prompt_sets = depth if VPT_type == "Deep" else 1
        self.Prompt_Tokens = nn.Parameter(torch.zeros(prompt_sets, Prompt_Token_num, embed_dim))

        # `args` defaults to None, which must not crash on subscription.
        shared = args is not None and args['intra_share'] == 1
        if shared:
            print('use share to init prompt with all 0!')
        else:
            print('use not share to init prompt with U dis!')
            # NOTE(review): 768 * 2 is hard-coded; presumably it stands for
            # 3 * patch_size**2 + embed_dim with the default sizes. Confirm
            # before using other patch or embedding sizes.
            val = math.sqrt(6. / float(768 * 2))
            nn.init.uniform_(self.Prompt_Tokens, -val, val)

    def New_CLS_head(self, new_classes=15):
        """Replace ``self.head`` with a fresh linear layer of ``new_classes`` outputs."""
        self.head = nn.Linear(self.embed_dim, new_classes)

    def Freeze(self):
        """Disable gradients everywhere except the prompt tokens and the head."""
        for param in self.parameters():
            param.requires_grad = False

        self.Prompt_Tokens.requires_grad = True

        # Tolerate a missing head without a blanket `except`.
        head = getattr(self, 'head', None)
        if head is not None:
            for param in head.parameters():
                param.requires_grad = True

    def UnFreeze(self):
        """Enable gradients for every parameter of the model."""
        for param in self.parameters():
            param.requires_grad = True

    def obtain_prompt(self):
        """Return a dict with the head's state dict and the prompt parameter.

        Returns:
            ``{'head': <head state dict>, 'Prompt_Tokens': <nn.Parameter>}``.
            The prompt entry is the live parameter object, not a copy.
        """
        prompt_state_dict = {'head': self.head.state_dict(),
                             'Prompt_Tokens': self.Prompt_Tokens}
        return prompt_state_dict

    def load_prompt(self, prompt_state_dict):
        """Load a head and prompt tokens produced by ``obtain_prompt``.

        The head is loaded on a best-effort basis (non-strict); a failure is
        reported and skipped. The prompt tokens are only replaced when their
        shape matches the current ``self.Prompt_Tokens``; otherwise a
        diagnostic is printed and the model is left unchanged.

        Args:
            prompt_state_dict: Mapping with a ``'Prompt_Tokens'`` tensor and,
                optionally, a ``'head'`` state dict.
        """
        try:
            self.head.load_state_dict(prompt_state_dict['head'], False)
        except Exception:
            # Deliberately best-effort: a missing or incompatible head is
            # skipped, but KeyboardInterrupt/SystemExit still propagate.
            print('head not match, so skip head')
        else:
            print('prompt head match')

        loaded_prompt = prompt_state_dict['Prompt_Tokens']
        if self.Prompt_Tokens.shape == loaded_prompt.shape:
            # `Tensor.to` returns a new tensor rather than moving in place,
            # so the move has to happen before wrapping in the Parameter.
            target_device = self.Prompt_Tokens.device
            self.Prompt_Tokens = nn.Parameter(loaded_prompt.detach().to(target_device))
        else:
            print('\n !!! cannot load prompt')
            print('shape of model req prompt', self.Prompt_Tokens.shape)
            print('shape of model given prompt', loaded_prompt.shape)
            print('')

    def forward_features(self, x):
        """Run the backbone with prompt tokens and return the normalised tokens.

        ``self.tfmout`` is reset and then filled with the output of every
        block (prompt tokens included). The returned sequence has the prompt
        tokens removed and ``self.norm`` applied.
        """
        self.tfmout = []  # renew
        x = self.patch_embed(x)
        batch_size = x.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)

        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        prompt_count = self.Prompt_Tokens.shape[1]

        if self.VPT_type == "Deep":
            for index, block in enumerate(self.blocks):
                # Attach this block's own prompts, run the block, then strip
                # the prompts so the next block receives fresh ones.
                block_prompts = self.Prompt_Tokens[index].unsqueeze(0).expand(batch_size, -1, -1)
                x = torch.cat((x, block_prompts), dim=1)
                x = block(x)
                self.tfmout.append(x)
                # Explicit end index: `:-prompt_count` would yield an empty
                # slice when there are zero prompt tokens.
                x = x[:, :x.shape[1] - prompt_count]
        else:
            # Shallow: the same prompts travel through every block.
            shared_prompts = self.Prompt_Tokens.expand(batch_size, -1, -1)
            x = torch.cat((x, shared_prompts), dim=1)
            for block in self.blocks:
                x = block(x)
                self.tfmout.append(x)
            x = x[:, :x.shape[1] - prompt_count]

        x = self.norm(x)
        return x

    def forward(self, x):
        """Return the class-token feature (index 0 of the token sequence).

        NOTE(review): ``self.head`` is not applied here, so a head installed
        with ``New_CLS_head`` has no effect on this output; confirm intent.
        """
        x = self.forward_features(x)
        x = x[:, 0, :]
        return x

    def get_each_tfmout(self):
        """Return the recorded per-block outputs and clear the internal buffer.

        Raises:
            AssertionError: If the buffer does not hold exactly one output per
                transformer block (e.g. no forward pass has run since the last
                call).
        """
        # Compare against the real depth rather than a hard-coded 12.
        assert len(self.tfmout) == len(self.blocks), "NO"
        tfmout = self.tfmout
        self.tfmout = []
        return tfmout

