import torch
import torch.nn as nn
import copy

from models.clip.prompt_learner_shared import cfgc, load_clip_to_cpu, TextEncoder, PromptLearnerShared
from utils.class_names import core50_classnames, domainnet_classnames, cddb_classnames
from models.clip import clip
class Encode():

    def __init__(self, args):
        #super(Encode, self).__init__()

        self.cfg = cfgc()
        clip_model = load_clip_to_cpu(self.cfg)
        self.clip_model = clip_model

        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.class_num = 1#类数量
        if args["dataset"] == "cddb":
            self.class_num = 2
            #nn.Linear设置网络中的全连接层的,in_features由输入张量的形状决定，out_features则决定了输出张量的形状
            #一个会话相当于一个数据集任务（比如s1、s2），这里只total_sessions=训练总域数量，保存的模型也是每个域一个模型
            #测试时相当于先找对应的域，再和该域下绑定的模型进行预测
            #ModuleList以列表的形式来保存多个子模块，而且其中的模块会被正确地登记注册，可以计算梯度
            # self.classifier_pool = nn.ModuleList([
            #     nn.Linear(args["embd_dim"], self.class_num, bias=True)
            #     for i in range(args["total_sessions"])
            # ])
        elif args["dataset"] == "domainnet":
            self.class_num = 345
            # self.classifier_pool = nn.ModuleList([
            #     nn.Linear(args["embd_dim"], self.class_num, bias=True)
            #     for i in range(args["total_sessions"])
            # ])
        elif args["dataset"] == "core50":
            self.class_num = 50
            # self.classifier_pool = nn.ModuleList([
            #     nn.Linear(args["embd_dim"], self.class_num, bias=True)
            #     for i in range(args["total_sessions"])
            # ])

        else:
            raise ValueError('Unknown datasets: {}.'.format(args["dataset"]))
        #单个图像提示Li的长度为10，一个语言提示Ll的长度为16
        # self.prompt_pool = nn.ModuleList([
        #     nn.Linear(args["embd_dim"], args["prompt_length"], bias=False)
        #     for i in range(args["total_sessions"])
        # ])

    # self.numtask = 0# 当前任务数量

    
    def feature_dim(self):
        return self.image_encoder.out_dim

    def extract_vector(self, image):
        image_features = self.image_encoder(image.type(self.dtype))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        return image_features

    # def forward(self, image):
    #     pass
        # logits = []
        # #结合图片和当前任务的prompt 作为输入，得到输出后再进入当前任务的classifier
        # image_features = self.image_encoder(image, self.prompt_pool[self.numtask-1].weight)
        # for prompts in [self.classifier_pool[self.numtask-1]]:
        #     logits.append(prompts(image_features))

        # return {
        #     'logits': torch.cat(logits, dim=1),
        #     'features': image_features
        # }

    # def interface(self, image, selection):
    #     pass
        #根据选择的任务索引找到对应的prompt，再结合Vit得到输出图片向量
        # instance_batch = torch.stack([i.weight for i in self.prompt_pool], 0)[selection, :, :]
        # image_features = self.image_encoder(image, instance_batch)
        # #把向量到分类池里都算一遍，再根据selection的索引找到对应任务的分类器算的预测结果
        # logits = []
        # for prompt in self.classifier_pool:
        #     logits.append(prompt(image_features))
        # logits = torch.cat(logits,1)
        # selectedlogit = []
        # for idx, ii in enumerate(selection):
        #     selectedlogit.append(logits[idx][self.class_num*ii:self.class_num*ii+self.class_num])
        # selectedlogit = torch.stack(selectedlogit)
        # return selectedlogit

    # def update_fc(self, nb_classes):
    #     self.numtask +=1

    # def copy(self):
    #     return copy.deepcopy(self)

    # def freeze(self):
    #     for param in self.parameters():
    #         param.requires_grad = False#参数不再更新
    #     self.eval()

    #     return self
