import torch
import torch.nn as nn
import copy

from models.vit import VisionTransformer, PatchEmbed, Block,resolve_pretrained_cfg, build_model_with_cfg, checkpoint_filter_fn

class ViT_Prompts(VisionTransformer):
    """Vision Transformer whose forward pass can splice extra prompt tokens
    ("instance tokens") into the token sequence right after the class token.

    The constructor accepts the same arguments as ``VisionTransformer`` and
    forwards them unchanged.
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            representation_size=representation_size,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            init_values=init_values,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
        )

    def forward(self, x, instance_tokens=None, **kwargs):
        """Encode ``x`` and return the pooled, ``fc_norm``-ed representation.

        Args:
            x: Input passed to ``self.patch_embed``.
            instance_tokens: Optional prompt tokens. They are cast to the
                token dtype and broadcast against a
                ``(batch, 1, embed)`` zero tensor, so a shared prompt is
                repeated for every sample. They receive no positional
                embedding.
            **kwargs: Accepted for call compatibility; not used here.

        Returns:
            The output of ``self.fc_norm`` applied to the (optionally pooled)
            normalized tokens. The classification head is not applied.
        """
        tokens = self.patch_embed(x)
        batch_size = tokens.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        prompts = None
        if instance_tokens is not None:
            # Adding a zero tensor broadcasts the prompts across the batch.
            batch_zeros = torch.zeros(
                batch_size, 1, tokens.shape[-1],
                dtype=tokens.dtype, device=tokens.device,
            )
            prompts = instance_tokens.to(tokens.dtype) + batch_zeros

        # Positional embeddings cover only the class token and patch tokens.
        tokens = tokens + self.pos_embed.to(tokens.dtype)

        if prompts is not None:
            # Resulting layout: [class token | prompts | patch tokens].
            class_part = tokens[:, :1, :]
            patch_part = tokens[:, 1:, :]
            tokens = torch.cat([class_part, prompts, patch_part], dim=1)

        tokens = self.norm(self.blocks(self.pos_drop(tokens)))

        if self.global_pool:
            if self.global_pool == 'avg':
                # Averages everything after the class token, which includes
                # any inserted prompt tokens.
                tokens = tokens[:, 1:].mean(dim=1)
            else:
                tokens = tokens[:, 0]
        return self.fc_norm(tokens)



def _create_vision_transformer(variant, pretrained=False, **kwargs):
    """Build a ``ViT_Prompts`` model for ``variant`` via ``build_model_with_cfg``.

    Args:
        variant: Model variant name used to resolve the pretrained config.
        pretrained: Forwarded to ``build_model_with_cfg``.
        **kwargs: Extra model arguments. ``representation_size`` is popped
            and handled here; everything else is forwarded as-is.

    Returns:
        The model returned by ``build_model_with_cfg``.

    Raises:
        RuntimeError: If a truthy ``features_only`` is passed.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE: the config is resolved from the variant name alone; kwargs are
    # not consulted at this step.
    cfg = resolve_pretrained_cfg(variant)
    cfg_num_classes = cfg['num_classes']
    requested_num_classes = kwargs.get('num_classes', cfg_num_classes)

    # A representation size is dropped when the requested number of classes
    # differs from the pretrained config's number of classes.
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and requested_num_classes != cfg_num_classes:
        repr_size = None

    return build_model_with_cfg(
        ViT_Prompts,
        variant,
        pretrained,
        pretrained_cfg=cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in cfg['url'],
        **kwargs,
    )



class SiNet(nn.Module):
    """Prompted ViT image encoder with one prompt and one linear classifier
    per session.

    ``args`` is a mapping read for the keys ``"dataset"``, ``"embd_dim"``,
    ``"total_sessions"`` and ``"prompt_length"``.
    """

    def __init__(self, args):
        super().__init__()

        # Image encoder: ViT-Base (ViT-B/16) from the original paper
        # (https://arxiv.org/abs/2010.11929). Per the original author's note:
        # ImageNet-1k weights fine-tuned from in21k @ 224x224, source
        # https://github.com/google-research/vision_transformer.
        # The embedding size configured here is 768.
        encoder_kwargs = {"patch_size": 16, "embed_dim": 768, "depth": 12, "num_heads": 12}
        self.image_encoder = _create_vision_transformer(
            'vit_base_patch16_224', pretrained=True, **encoder_kwargs)

        # Number of classes predicted by each per-session classifier.
        classes_by_dataset = {"cddb": 2, "domainnet": 345, "core50": 50}
        dataset_name = args["dataset"]
        if dataset_name not in classes_by_dataset:
            raise ValueError('Unknown datasets: {}.'.format(dataset_name))
        self.class_num = classes_by_dataset[dataset_name]

        session_count = args["total_sessions"]

        # One fully connected head (with bias) per session. Author's note: a
        # session corresponds to one dataset task / training domain (e.g. s1,
        # s2); at test time the matching domain is located first and the
        # classifier bound to it makes the prediction. nn.ModuleList keeps
        # the heads registered so their parameters receive gradients.
        self.classifier_pool = nn.ModuleList(
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for _ in range(session_count)
        )

        # One prompt per session. Only the ``weight`` of each bias-free
        # linear layer is used, as a (prompt_length, embd_dim) token matrix.
        # Author's note: a single image prompt has length 10 and a language
        # prompt has length 16.
        self.prompt_pool = nn.ModuleList(
            nn.Linear(args["embd_dim"], args["prompt_length"], bias=False)
            for _ in range(session_count)
        )

        # Number of tasks registered so far (incremented by ``update_fc``).
        self.numtask = 0

    @property
    def feature_dim(self):
        """Return ``out_dim`` as exposed by the image encoder."""
        return self.image_encoder.out_dim

    def extract_vector(self, image):
        """Encode ``image`` without any prompt and L2-normalize the result
        along the last dimension."""
        features = self.image_encoder(image)
        norms = features.norm(dim=-1, keepdim=True)
        return features / norms

    def forward(self, image):
        """Encode ``image`` with the current task's prompt and classify it
        with the current task's head.

        Returns:
            dict with ``'logits'`` (output of the current classifier) and
            ``'features'`` (the encoder output fed to that classifier).
        """
        # NOTE: the current task index is ``numtask - 1``; while ``numtask``
        # is still 0 this is -1, i.e. the last entry of each pool.
        current = self.numtask - 1
        prompt_tokens = self.prompt_pool[current].weight
        image_features = self.image_encoder(image, prompt_tokens)

        per_head_logits = [self.classifier_pool[current](image_features)]
        return {
            'logits': torch.cat(per_head_logits, dim=1),
            'features': image_features
        }

    def interface(self, image, selection):
        """Predict with a per-sample choice of task.

        Args:
            image: Batch of inputs for the image encoder.
            selection: Per-sample task indices, used both to pick each
                sample's prompt and to pick which classifier's logits to keep.

        Returns:
            Tensor stacking, for every sample, the ``class_num`` logits of
            its selected classifier.
        """
        # Gather the prompt belonging to each sample's selected task.
        all_prompts = torch.stack([layer.weight for layer in self.prompt_pool], 0)
        instance_batch = all_prompts[selection, :, :]
        image_features = self.image_encoder(image, instance_batch)

        # Run every classifier, lay the outputs side by side, then slice out
        # the column block of the selected task for each sample.
        all_logits = torch.cat(
            [head(image_features) for head in self.classifier_pool], 1)
        picked = [
            all_logits[row][self.class_num * task:self.class_num * task + self.class_num]
            for row, task in enumerate(selection)
        ]
        return torch.stack(picked)

    def update_fc(self, nb_classes):
        """Advance to the next task. ``nb_classes`` is accepted but unused."""
        self.numtask = self.numtask + 1

    def copy(self):
        """Return a deep copy of this module."""
        return copy.deepcopy(self)

    def freeze(self):
        """Stop gradient updates for all parameters, switch to eval mode and
        return ``self``."""
        for parameter in self.parameters():
            parameter.requires_grad_(False)
        self.eval()
        return self
