import torch
import torch.nn as nn
from models.clip.prompt_learner import cfgc, load_clip_to_cpu, TextEncoder, PromptLearner_v2, PromptLearner_v4, PromptLearner_v3, PromptLearner_flower,  PromptLearner_nwpu,  PromptLearner_dog, PromptLearner_ucf
from utils.class_names import cifar10_classnames, cifar100_classnames, stanfordcars_classnames,  dtd_classnames, SAT_classnames, Aircraft_classnames, flower_classnames, nwpu_classnames, pattern_classnames,  imagenet_classnames, dog_classnames, ucf_classnames
from models.lin import gcn_prompt_net
import copy
from models.vit import VisionTransformer, PatchEmbed, Block,resolve_pretrained_cfg, build_model_with_cfg, checkpoint_filter_fn

class ViT_Prompts(VisionTransformer):
    """Vision Transformer that splices per-layer prompt tokens into its token sequence.

    Two optional prompt sets can be passed to :meth:`forward` and
    :meth:`forward_attn`; both are indexed by block along their first axis:

    * ``instance_tokens`` -- ``instance_tokens[i]`` holds the prompts for block
      ``i``; the number of prompts is read from ``shape[1]`` and every layer's
      prompts are broadcast over the batch.
    * ``gcn_tokens`` -- ``gcn_tokens[i]`` holds the prompts for block ``i``;
      the number of prompts is read from ``shape[2]``.

    Inside the blocks the sequence is laid out as
    ``[cls | instance prompts | gcn prompts | patch tokens]``. Block 0 sees the
    layer-0 prompts; before every later block the prompt slots are overwritten
    with that block's own prompts.
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
        # Plain-Python bookkeeping read by the prompt logic and by callers
        # (number of blocks and token width). Assigned before the parent
        # constructor, exactly as the original code did.
        self.layers = depth
        self.output_dim = embed_dim
        super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool,
            embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, representation_size=representation_size,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate, weight_init=weight_init, init_values=init_values,
            embed_layer=embed_layer, norm_layer=norm_layer, act_layer=act_layer, block_fn=block_fn)

    def _expand_over_batch(self, tokens, num_prompts, x):
        """Return ``tokens`` as a ``(tokens.shape[0], batch, num_prompts, dim)`` tensor in ``x``'s dtype/device."""
        batch, dim = x.shape[0], x.shape[-1]
        expanded = torch.zeros(tokens.shape[0], batch, num_prompts, dim, dtype=x.dtype, device=x.device)
        # Adding a (batch, 1, dim) zero tensor broadcasts each layer's prompts
        # over the batch; it is loop-invariant, so it is allocated once.
        batch_zeros = torch.zeros(batch, 1, dim, dtype=x.dtype, device=x.device)
        for i in range(self.layers):
            expanded[i] = tokens[i].to(x.dtype) + batch_zeros
        return expanded

    def _embed_with_prompts(self, x, instance_tokens, gcn_tokens):
        """Patch-embed ``x``, add cls token and positions, and insert the layer-0 prompts.

        Returns:
            ``(tokens, layer_instance, layer_gcn)`` where the last two are the
            batch-expanded per-layer prompts, or ``None`` for a prompt set that
            was not supplied.
        """
        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        layer_instance, layer_gcn = None, None
        if instance_tokens is not None:
            layer_instance = self._expand_over_batch(instance_tokens, instance_tokens.shape[1], x)
        if gcn_tokens is not None:
            layer_gcn = self._expand_over_batch(gcn_tokens, gcn_tokens.shape[2], x)

        # Positional embeddings only cover cls + patch tokens, so they are
        # added before any prompt is inserted.
        x = x + self.pos_embed.to(x.dtype)

        # gcn prompts go in first and instance prompts second, each right
        # after cls, which yields [cls | instance | gcn | patches].
        if layer_gcn is not None:
            x = torch.cat([x[:, :1, :], layer_gcn[0], x[:, 1:, :]], dim=1)
        if layer_instance is not None:
            x = torch.cat([x[:, :1, :], layer_instance[0], x[:, 1:, :]], dim=1)

        return self.pos_drop(x), layer_instance, layer_gcn

    @staticmethod
    def _swap_prompts(x, instance_prompts, gcn_prompts):
        """Overwrite the prompt slots of ``x`` with the given prompts.

        Either argument may be ``None`` when that prompt set is absent.

        Returns:
            ``(tokens, first_patch_index)`` -- the rebuilt sequence and the
            index of the first token that follows the prompt slots.
        """
        pieces = [x[:, :1, :]]
        first_patch = 1
        for prompts in (instance_prompts, gcn_prompts):
            if prompts is not None:
                pieces.append(prompts)
                first_patch += prompts.shape[1]
        pieces.append(x[:, first_patch:, :])
        return torch.cat(pieces, dim=1), first_patch

    def forward(self, x, instance_tokens=None, gcn_tokens=None, **kwargs):
        """Encode images, optionally conditioning every block on prompt tokens.

        Args:
            x: input passed to ``self.patch_embed``.
            instance_tokens: optional per-layer prompts shared by the batch.
            gcn_tokens: optional per-layer prompts. Supplying them without
                ``instance_tokens`` used to fail on an undefined variable and
                is now supported.
            **kwargs: accepted and ignored.

        Returns:
            The pooled feature after ``self.norm`` and ``self.fc_norm``: the
            cls token, or the mean over all non-cls tokens (prompt tokens
            included) when ``self.global_pool == 'avg'``.
        """
        x, layer_instance, layer_gcn = self._embed_with_prompts(x, instance_tokens, gcn_tokens)

        if layer_instance is None and layer_gcn is None:
            x = self.blocks(x)
        else:
            for i, block in enumerate(self.blocks):
                if i > 0:
                    # Block 0 already sees the layer-0 prompts inserted above.
                    x, _ = self._swap_prompts(
                        x,
                        None if layer_instance is None else layer_instance[i],
                        None if layer_gcn is None else layer_gcn[i])
                x = block(x)

        x = self.norm(x)
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x

    def forward_attn(self, x, instance_tokens=None, gcn_tokens=None, **kwargs):
        """Run the encoder and return the last block's cls attention over patch tokens.

        When ``gcn_tokens`` is given, the final block is evaluated through
        ``block.forward_att`` and the result is sliced as ``attn[0, :, 0, y:]``
        (first batch element, query position 0, keys after the prompt slots)
        and averaged over the remaining leading axis with ``keepdim=True``.
        NOTE(review): ``forward_att`` presumably returns attention weights of
        shape (batch, heads, tokens, tokens) -- confirm against ``Block``.

        The final block was previously hard-coded as index 11; it is now the
        last entry of ``self.blocks``, which is identical for depth 12.

        When ``gcn_tokens`` is ``None`` (no prompts, or instance prompts only)
        the original behaviour is kept: the un-normalised token sequence after
        the last block is returned instead of an attention map.

        Args:
            x: input passed to ``self.patch_embed``.
            instance_tokens: optional per-layer prompts shared by the batch.
            gcn_tokens: optional per-layer prompts.
            **kwargs: accepted and ignored.
        """
        x, layer_instance, layer_gcn = self._embed_with_prompts(x, instance_tokens, gcn_tokens)

        if layer_instance is None and layer_gcn is None:
            return self.blocks(x)

        want_attention = layer_gcn is not None
        last_index = len(self.blocks) - 1

        for i, block in enumerate(self.blocks):
            instance_prompts = None if layer_instance is None else layer_instance[i]
            gcn_prompts = None if layer_gcn is None else layer_gcn[i]

            if i > 0:
                x, first_patch = self._swap_prompts(x, instance_prompts, gcn_prompts)
            else:
                # Layer-0 prompts are already in place; only the offset is needed.
                first_patch = 1
                for prompts in (instance_prompts, gcn_prompts):
                    if prompts is not None:
                        first_patch += prompts.shape[1]

            if want_attention and i == last_index:
                attn = block.forward_att(x)
                attn = attn[0, :, 0, first_patch:]
                return torch.mean(attn, dim=0, keepdim=True)

            x = block(x)
        return x


def _create_vision_transformer(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    # NOTE this extra code to support handling of repr size for in21k pretrained models
    pretrained_cfg = resolve_pretrained_cfg(variant)
    # pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
    default_num_classes = pretrained_cfg['num_classes']
    num_classes = kwargs.get('num_classes', default_num_classes)
    repr_size = kwargs.pop('representation_size', None)
    if repr_size is not None and num_classes != default_num_classes:
        repr_size = None

    model = build_model_with_cfg(
        ViT_Prompts, variant, pretrained,
        pretrained_cfg=pretrained_cfg,
        representation_size=repr_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in pretrained_cfg['url'],
        **kwargs)
    return model


class cNet_v22_vit(nn.Module):
    """Network pairing a prompt-aware ViT image encoder with a CLIP text encoder.

    The constructor loads CLIP (for the text encoder, logit scale and dtype),
    builds a pretrained ``vit_base_patch16_224_in21k`` ``ViT_Prompts`` image
    encoder, picks a dataset-specific prompt learner and creates the
    ``gcn_prompt_net`` that generates per-layer prompts from patch features
    and text features.

    NOTE(review): ``self.global_pool`` is read by the ``encode_*`` methods but
    is not assigned anywhere in the visible code; it has to be provided
    elsewhere (presumably by the unreleased part) before they are called.
    """

    def __init__(self, args):
        """Create the encoders, prompt learner and prompt generator.

        Args:
            args: mapping; the keys ``"top_k"``, ``"dataset"`` and
                ``"prompt_length_c"`` are read here, and ``"embd_dim"`` is
                read later by :meth:`encode_image_attention`.

        Raises:
            ValueError: if ``args['dataset']`` is not a supported dataset name
                (previously this surfaced as an ``AttributeError`` on
                ``prompt_learner``).
        """
        super().__init__()
        self.cfg = cfgc()
        clip_model = load_clip_to_cpu(self.cfg)
        model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
        self.image_encoder = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=True, **model_kwargs)
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.class_num = 1
        self.top_k = args["top_k"]
        self.args = args
        self.layer_num = self.image_encoder.layers

        self.class_num, self.prompt_learner = self._build_prompt_learner(args['dataset'], clip_model)

        self.tokenized_prompts = self.prompt_learner.tokenized_prompts

        # Independent copy of the patch embedding, used to feed the prompt
        # generator without going through the image encoder itself.
        self.linear_projection = copy.deepcopy(self.image_encoder.patch_embed)
        self.gcn_prompt = gcn_prompt_net(use_stochastic=False, gcn_len=self.args["prompt_length_c"])

    def _build_prompt_learner(self, dataset, clip_model):
        """Return ``(class_num, prompt_learner)`` for the named dataset."""
        # dataset name -> (prompt learner class, class names, fixed class count).
        # A count of None means "use len(class names)"; the two CIFAR entries
        # keep their original hard-coded counts.
        registry = {
            'cifar': (PromptLearner_v2, cifar100_classnames, 100),
            'cifar10': (PromptLearner_v2, cifar10_classnames, 10),
            'cars': (PromptLearner_v2, stanfordcars_classnames, None),
            'dtd': (PromptLearner_v2, dtd_classnames, None),
            'sat': (PromptLearner_v3, SAT_classnames, None),
            'aircraft': (PromptLearner_v4, Aircraft_classnames, None),
            'flower': (PromptLearner_flower, flower_classnames, None),
            'nwpu': (PromptLearner_nwpu, nwpu_classnames, None),
            'pattern': (PromptLearner_nwpu, pattern_classnames, None),
            'Imagenet': (PromptLearner_v2, imagenet_classnames, None),
            'dog': (PromptLearner_dog, dog_classnames, None),
            'ucf': (PromptLearner_ucf, ucf_classnames, None),
        }
        if dataset not in registry:
            raise ValueError(
                f"Unsupported dataset {dataset!r}; expected one of {sorted(registry)}")
        learner_cls, classnames, fixed_count = registry[dataset]
        class_num = len(classnames) if fixed_count is None else fixed_count
        return class_num, learner_cls(self.cfg, classnames, clip_model)

    def _normalized_text_features(self):
        """Encode the learned class prompts and L2-normalise them along the last axis."""
        # NOTE(review): self.global_pool is not defined in the visible code.
        prompts = self.prompt_learner().to(self.global_pool.device)
        tokenized_prompts = self.tokenized_prompts.to(self.global_pool.device)
        text_features = self.text_encoder(prompts, tokenized_prompts)  # class_num feature
        return text_features / text_features.norm(dim=-1, keepdim=True)

    def _patch_feature_map(self, image):
        """Project ``image`` to patch tokens and reshape them to ``(-1, output_dim, 14, 14)``."""
        image_tokens = self.linear_projection(image)
        image_tokens = image_tokens.permute(0, 2, 1)
        # The 14x14 grid is hard-coded; it matches 224-pixel inputs with 16-pixel patches.
        return image_tokens.reshape(-1, self.image_encoder.output_dim, 14, 14)

    def encode_image_attention(self, image):
        """Return the image encoder's ``forward_attn`` output for ``image``.

        Prompts are generated by ``self.gcn_prompt`` from the patch feature
        map and the normalised text features, reshaped to
        ``(layer_num, -1, prompt_length_c, embd_dim)`` and passed as
        ``gcn_tokens``; ``self.global_pool`` is passed as ``instance_tokens``.
        """
        class_pool_key_norm = self._normalized_text_features()
        image_tokens = self._patch_feature_map(image)

        # Only the generated prompts are needed from the three returned values.
        _, gcn_prompt, _ = self.gcn_prompt(image_tokens,
                                           class_pool_key_norm.to(dtype=torch.float32))

        gcn_prompt = gcn_prompt.reshape(-1, self.layer_num, self.args["prompt_length_c"], self.args["embd_dim"])
        gcn_prompt = gcn_prompt.permute(1, 0, 2, 3)  # layer axis first, as forward_attn indexes by block

        return self.image_encoder.forward_attn(image, self.global_pool, gcn_prompt)

    def encode_cross_attention(self, image, target=None, p_target=None):
        """Return the attention maps of the first image for two class indices.

        Args:
            image: input images; only batch element 0 is used for the result.
            target: index selecting the first returned map.
            p_target: index selecting the second returned map.

        Returns:
            ``(real_maps, wrong_maps)``: ``att_maps[0, target, :]`` and
            ``att_maps[0, p_target, :]``, each squeezed.
        """
        class_pool_key_norm = self._normalized_text_features()
        image_tokens = self._patch_feature_map(image)
        att_maps = self.gcn_prompt.forward_att_map_vit(image_tokens, class_pool_key_norm.to(dtype=torch.float32))

        real_maps = att_maps[0, target, :].squeeze()
        wrong_maps = att_maps[0, p_target, :].squeeze()
        return real_maps, wrong_maps

    def forward(self, image, target=None, p_target=None):
        # NOTE(review): placeholder only. The names used in the return value
        # below are not defined in this file, so calling forward() currently
        # raises NameError until the withheld implementation is added.
        #
        # This part of code will be released after accepted.
        #
        return {
            'logits': torch.cat(logits, dim=1),
            'features': image_features,
            'increase_sim': increase_sim,
            'reduce_sim': reduce_sim,
        }
