import torch
import torch.nn as nn
from models.clip.prompt_learner import cfgc, load_clip_to_cpu, TextEncoder, PromptLearner_v2, PromptLearner_v4, PromptLearner_v3, PromptLearner_flower,  PromptLearner_nwpu,  PromptLearner_dog, PromptLearner_ucf
from utils.class_names import cifar10_classnames, cifar100_classnames, stanfordcars_classnames,  dtd_classnames, SAT_classnames, Aircraft_classnames, flower_classnames, nwpu_classnames, pattern_classnames,  imagenet_classnames, dog_classnames, ucf_classnames
from models.lin import gcn_prompt_net
import copy

class cNet_v22(nn.Module):
    """CLIP-based model whose visual prompts are generated by a GCN over class text features.

    The text side uses a dataset-specific learnable prompt learner; the image side
    feeds ``gcn_prompt_net`` outputs into the CLIP visual encoder's
    ``forward_attention``.

    Args:
        args: configuration mapping. Keys read in this class: ``'dataset'``,
            ``'top_k'``, ``'prompt_length_c'`` and (in ``encode_image_attention``)
            ``'embd_dim'``.

    Raises:
        ValueError: if ``args['dataset']`` is not a supported dataset name.
    """

    @staticmethod
    def _prompt_learner_registry():
        """Map each supported ``args['dataset']`` value to ``(learner class, class names)``.

        Built on demand (not as a class attribute) so the class can be defined
        without resolving the project-level imports eagerly.
        """
        return {
            'cifar': (PromptLearner_v2, cifar100_classnames),
            'cifar10': (PromptLearner_v2, cifar10_classnames),
            'cars': (PromptLearner_v2, stanfordcars_classnames),
            'dtd': (PromptLearner_v2, dtd_classnames),
            'sat': (PromptLearner_v3, SAT_classnames),
            'aircraft': (PromptLearner_v4, Aircraft_classnames),
            'flower': (PromptLearner_flower, flower_classnames),
            'nwpu': (PromptLearner_nwpu, nwpu_classnames),
            'pattern': (PromptLearner_nwpu, pattern_classnames),
            'Imagenet': (PromptLearner_v2, imagenet_classnames),
            'dog': (PromptLearner_dog, dog_classnames),
            'ucf': (PromptLearner_ucf, ucf_classnames),
        }

    def __init__(self, args):
        super().__init__()
        # Resolve the dataset first so an unsupported name fails fast, before the
        # CLIP checkpoint is loaded (previously it surfaced later as an
        # AttributeError on the missing ``self.prompt_learner``).
        dataset = args['dataset']
        registry = self._prompt_learner_registry()
        if dataset not in registry:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected one of {sorted(registry)}"
            )
        learner_cls, classnames = registry[dataset]

        # Attribute assignment order is kept identical to the original so that
        # submodule/parameter registration (state_dict / parameters() order) is unchanged.
        self.cfg = cfgc()
        clip_model = load_clip_to_cpu(self.cfg)
        self.clip_model = clip_model
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.top_k = args["top_k"]
        self.args = args
        self.layer_num = self.image_encoder.layers

        # The prompt learner is built from the class-name list, so the class count
        # is taken from that same list to keep the two consistent.
        self.class_num = len(classnames)
        self.prompt_learner = learner_cls(self.cfg, classnames, self.clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts

        self.numtask = 0

        # NOTE(review): relu / bn1_ / bn2_ are not used by the visible methods;
        # presumably consumed by the unreleased forward(). Kept for checkpoint compatibility.
        self.relu = nn.ReLU(inplace=True)
        text_embed_width = self.text_encoder.text_projection.shape[1]
        self.bn1_ = nn.BatchNorm1d(text_embed_width)
        self.bn2_ = nn.BatchNorm1d(text_embed_width)
        # Independent copy of the visual encoder's conv1, so the GCN branch's
        # patch embedding does not share weights with the visual encoder.
        self.linear_projection = copy.deepcopy(self.image_encoder.conv1)
        self.gcn_prompt = gcn_prompt_net(use_stochastic=False, gcn_len=self.args["prompt_length_c"])

    def _class_keys_and_tokens(self, image):
        """Return ``(class_keys, image_tokens)`` shared by the encode_* methods.

        ``class_keys`` are the text-encoder features of the learned class prompts,
        L2-normalised along the last dimension and cast to float32.
        ``image_tokens`` are ``image`` passed through ``self.linear_projection``
        (in the CLIP dtype), then cast to float32.
        """
        # NOTE(review): ``self.global_pool`` is not created in __init__; it must be
        # assigned on the instance before these methods are called — confirm where.
        device = self.global_pool.device
        prompts = self.prompt_learner().to(device)
        tokenized_prompts = self.tokenized_prompts.to(device)
        text_features = self.text_encoder(prompts, tokenized_prompts)  # one row per class (assumed)
        class_keys = text_features / text_features.norm(dim=-1, keepdim=True)

        image_tokens = self.linear_projection(image.type(self.dtype)).to(dtype=torch.float32)
        return class_keys.to(dtype=torch.float32), image_tokens

    def encode_image_attention(self, image):
        """Encode ``image`` with the CLIP visual encoder, injecting GCN-generated prompts.

        The GCN output is reshaped to
        ``(-1, layer_num, args['prompt_length_c'], args['embd_dim'])`` and
        permuted so the layer axis comes first, then handed to
        ``image_encoder.forward_attention`` together with ``self.global_pool``
        and the encoder's ``class_embedding``.

        Returns:
            Whatever ``self.image_encoder.forward_attention`` returns.
        """
        class_keys, image_tokens = self._class_keys_and_tokens(image)

        # gcn_prompt_net also returns hidden tokens and attention tokens; only the
        # prompts are needed here.
        _, gcn_prompt, _ = self.gcn_prompt(image_tokens, class_keys)

        gcn_prompt = gcn_prompt.reshape(
            -1, self.layer_num, self.args["prompt_length_c"], self.args["embd_dim"]
        ).type(self.dtype)
        # Move the layer axis to the front (presumably so the encoder can index
        # prompts per transformer layer — confirm against forward_attention).
        gcn_prompt = gcn_prompt.permute(1, 0, 2, 3)
        return self.image_encoder.forward_attention(
            image.type(self.dtype), self.global_pool, gcn_prompt, self.image_encoder.class_embedding
        )

    def encode_cross_attention(self, image, target=None, p_target=None):
        """Return the GCN attention-map rows for class ``target`` and class ``p_target``.

        Both rows are taken from index 0 of the first axis of the attention maps
        and squeezed. NOTE(review): assumes axis 0 is the batch, so only the first
        sample is used; ``target`` / ``p_target`` are presumably the true and a
        wrong (predicted) class — confirm with callers.

        Returns:
            ``(real_maps, wrong_maps)``.
        """
        class_keys, image_tokens = self._class_keys_and_tokens(image)

        att_maps = self.gcn_prompt.forward_att_map(image_tokens, class_keys)

        real_maps = att_maps[0, target, :].squeeze()
        wrong_maps = att_maps[0, p_target, :].squeeze()
        return real_maps, wrong_maps

    def forward(self, image, target=None, p_target=None):
        """Model forward pass — implementation withheld by the authors.

        The published code omits this body ("released after accepted"). The
        original stub returned a dict with keys ``'logits'`` (a
        ``torch.cat(..., dim=1)`` of per-part logits), ``'features'``,
        ``'class_features'``, ``'increase_sim'`` and ``'reduce_sim'``, but
        built it from names that were never defined, so every call failed with
        ``NameError``.

        Raises:
            NotImplementedError: always, until the real implementation is added.
        """
        raise NotImplementedError(
            "cNet_v22.forward has not been released; it is expected to return a dict "
            "with keys 'logits', 'features', 'class_features', 'increase_sim', 'reduce_sim'."
        )
