import torch
import torch.nn as nn
import copy
import numpy as np
import torch.nn.functional as F

from models.vit_lora import VisionTransformer, PatchEmbed,  resolve_pretrained_cfg, build_model_with_cfg, \
    checkpoint_filter_fn, Block_GR_Lora
    

class ViT_gr_lora(VisionTransformer):
    """VisionTransformer variant whose blocks are called with a task id (GR-LoRA).

    Every constructor argument is forwarded unchanged, by keyword, to
    ``VisionTransformer.__init__``; ``block_fn`` defaults to ``Block_GR_Lora``
    and ``n_tasks`` / ``rank`` are passed along as well (presumably consumed
    by the LoRA blocks -- confirm in ``models.vit_lora``).

    Both forward methods return ``self.norm(tokens)[:, 0]``, i.e. the
    normalized first (class) token of the final layer.
    """

    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block_GR_Lora, n_tasks=10, rank=64):
        backbone_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            num_classes=num_classes,
            global_pool=global_pool,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            representation_size=representation_size,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            weight_init=weight_init,
            init_values=init_values,
            embed_layer=embed_layer,
            norm_layer=norm_layer,
            act_layer=act_layer,
            block_fn=block_fn,
            n_tasks=n_tasks,
            rank=rank,
        )
        super().__init__(**backbone_kwargs)

    def _gr_lora_tokens(self, x):
        """Patchify ``x``, prepend the class token, add positional embeddings, apply dropout.

        NOTE(review): ``x`` is presumably an image batch (B, C, H, W) as expected
        by ``self.patch_embed`` -- confirm.
        """
        patches = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(patches.size(0), -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)
        # Positional table may be longer than the sequence; use only the prefix.
        tokens = tokens + self.pos_embed[:, :tokens.size(1)]
        return self.pos_drop(tokens)

    def forward_shared(self, x, task_id, register_blk=-1, get_feat=False, get_cur_feat=False, get_cur_x=False, get_old_x=False):
        """Run all blocks with ``task_id`` and the feature-capture flags; return the class token.

        The third positional argument to each block is True only for the block
        whose index equals ``register_blk`` (never, for the default -1).
        """
        tokens = self._gr_lora_tokens(x)
        for depth_idx, block in enumerate(self.blocks):
            tokens = block(
                tokens, task_id, depth_idx == register_blk,
                get_feat=get_feat, get_cur_feat=get_cur_feat,
                get_cur_x=get_cur_x, get_old_x=get_old_x,
            )
        return self.norm(tokens)[:, 0]

    def forward(self, x, task_id, register_blk=-1, get_feat=False, get_cur_x=False, aux_id=None):
        """Run all blocks with ``task_id`` and an auxiliary id; return the class token.

        ``aux_id`` defaults to ``task_id`` when not given; the same value is
        passed to every block.
        """
        block_aux = task_id if aux_id is None else aux_id
        tokens = self._gr_lora_tokens(x)
        for depth_idx, block in enumerate(self.blocks):
            tokens = block(
                tokens, task_id, depth_idx == register_blk,
                get_feat=get_feat, get_cur_x=get_cur_x, aux_id=block_aux,
            )
        return self.norm(tokens)[:, 0]


def _create_vision_transformer_gr_lora(variant, pretrained=False, **kwargs):
    """Build a ``ViT_gr_lora`` for the timm-style model name ``variant``.

    The pretrained config is resolved from ``variant`` and handed to
    ``build_model_with_cfg`` together with ``checkpoint_filter_fn``; the custom
    loader flag is set when the config URL contains ``'npz'``.

    ``representation_size`` is taken out of ``kwargs`` and discarded (set to
    None) when a ``num_classes`` different from the config's is requested.

    Raises:
        RuntimeError: if ``features_only`` is requested (truthy).
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    cfg = resolve_pretrained_cfg(variant)
    cfg_num_classes = cfg['num_classes']
    representation_size = kwargs.pop('representation_size', None)
    # Only keep the representation size when the head size matches the config.
    head_resized = kwargs.get('num_classes', cfg_num_classes) != cfg_num_classes
    if representation_size is not None and head_resized:
        representation_size = None

    return build_model_with_cfg(
        ViT_gr_lora, variant, pretrained,
        pretrained_cfg=cfg,
        representation_size=representation_size,
        pretrained_filter_fn=checkpoint_filter_fn,
        pretrained_custom_load='npz' in cfg['url'],
        **kwargs)

    
class SiNet_GR_LoRA(nn.Module):
    """Task-incremental model: a GR-LoRA ViT encoder plus one linear head per session.

    ``args`` must provide ``total_sessions``, ``rank``, ``init_cls``,
    ``embd_dim`` and ``device`` (a sequence; its first element is stored as
    ``self.device``).

    ``numtask`` counts calls to :meth:`update_fc`. Every logits path
    concatenates, along dim 1, the outputs of the first ``numtask`` heads in
    ``classifier_pool``. Each head has ``class_num`` (= ``init_cls``) outputs.
    """

    def __init__(self, args):
        super(SiNet_GR_LoRA, self).__init__()

        model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
                            n_tasks=args["total_sessions"], rank=args["rank"])
        self.image_encoder = _create_vision_transformer_gr_lora(
            'vit_base_patch16_224_in21k', pretrained=True, **model_kwargs)

        # Every session's head has the same width: init_cls outputs.
        self.class_num = args["init_cls"]
        self.cur_task = -1
        self.device = args['device'][0]

        # All heads are allocated up front; update_fc only advances numtask.
        self.classifier_pool = nn.ModuleList([
            nn.Linear(args["embd_dim"], self.class_num, bias=True)
            for _ in range(args["total_sessions"])
        ])

        self.numtask = 0
        self.total_sessions = args["total_sessions"]

    def update_fc(self, nb_classes):
        """Advance to the next session (increments ``numtask`` and ``cur_task``).

        ``nb_classes`` is accepted for interface compatibility but unused; the
        heads were already created in ``__init__``.
        """
        self.numtask += 1
        self.cur_task += 1

    def extract_vector(self, image, task=None):
        """Return encoder features for ``image`` under ``task`` (default: latest task)."""
        if task is None:
            task = self.numtask - 1
        return self.image_encoder(image, task)

    def extract_vector_by_auxid(self, image, task_id, aux_id):
        """Return encoder features computed with an explicit ``task_id`` and ``aux_id``."""
        return self.image_encoder(image, task_id=task_id, aux_id=aux_id)

    def extract_vector_shared(self, image, task_id=None):
        """Return features from the encoder's ``forward_shared`` path (default: latest task)."""
        if task_id is None:
            task_id = self.numtask - 1
        return self.image_encoder.forward_shared(image, task_id)

    def forward_shared(self, image, get_feat=False, get_cur_x=False, fc_only=False):
        """Same computation as :meth:`forward`; returns ``{'logits', 'features'}``.

        NOTE(review): despite its name this calls the encoder's regular
        ``forward`` (not ``image_encoder.forward_shared``) -- confirm intent.
        ``fc_only`` is accepted but unused.
        """
        image_features = self.image_encoder(image, task_id=self.numtask - 1,
                                            get_feat=get_feat, get_cur_x=get_cur_x)
        return {
            'logits': self.forward_only_fc(image_features),
            'features': image_features,
        }

    def forward_only_fc(self, features):
        """Concatenate (dim 1) the outputs of the first ``numtask`` heads applied to ``features``."""
        return torch.cat(
            [head(features) for head in self.classifier_pool[0:self.numtask]], dim=1)

    def forward(self, image, get_feat=False, get_cur_x=False, fc_only=False):
        """Encode ``image`` with the latest task id and apply all seen heads.

        ``fc_only`` is accepted but unused.

        Returns:
            dict with ``"logits"`` (concatenated head outputs) and
            ``"features"`` (encoder output).
        """
        image_features = self.image_encoder(image, task_id=self.numtask - 1,
                                            get_feat=get_feat, get_cur_x=get_cur_x)
        return {
            "logits": self.forward_only_fc(image_features),
            "features": image_features,
        }

    def interface(self, image, increment):
        """Task-agnostic inference by minimum-entropy selection over auxiliary task ids.

        With one seen task, returns the head logits of the encoder run with
        ``task_id=0``. Otherwise, for every seen task ``i`` the encoder is run
        (without gradients) with ``aux_id=i``, and the per-sample softmax
        entropy of the logits is computed. That entropy is replaced by 1e9
        whenever the top-1 prediction lies outside columns
        ``[i * increment, (i + 1) * increment)``. For each sample, the logits of
        the pass with the lowest entropy are returned, moved to ``self.device``.

        NOTE(review): the mask assumes task ``i`` owns logit columns
        ``[i*increment, (i+1)*increment)``; each head emits ``class_num``
        columns, so this presumably requires ``increment == class_num`` -- confirm.

        Args:
            image: input batch forwarded to ``self.image_encoder``.
            increment: number of logit columns attributed to each task.

        Returns:
            Tensor of concatenated head logits, one row per sample.

        Raises:
            RuntimeError: if no task has been registered yet (``numtask < 1``).
        """
        if self.numtask < 1:
            raise RuntimeError("interface() called before any task was registered via update_fc()")

        if self.numtask - 1 == 0:
            image_features = self.image_encoder(image, task_id=self.numtask - 1)
            return self.forward_only_fc(image_features)

        all_entropies = []
        all_logits = []
        for i in range(self.numtask):
            with torch.no_grad():
                image_features = self.image_encoder(image, task_id=self.numtask - 1, aux_id=i)
                logits = self.forward_only_fc(image_features)

            probs = nn.functional.softmax(logits, dim=1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)  # one value per sample
            pred_labels = torch.topk(logits, k=1, dim=1, largest=True, sorted=True)[1].squeeze(1)
            # A pass whose prediction falls outside its own task's columns is never selected
            # (unless every pass is out of range for that sample).
            out_of_task_mask = (pred_labels < i * increment) | (pred_labels >= (i + 1) * increment)
            entropy[out_of_task_mask] = 1e9

            # Stay on-device: no per-task GPU->CPU->numpy round trip.
            all_entropies.append(entropy)
            all_logits.append(logits)

        all_entropies = torch.stack(all_entropies)  # (numtask, batch)
        all_logits = torch.stack(all_logits)        # (numtask, batch, n_logits)

        min_entropy_indices = torch.argmin(all_entropies, dim=0)  # (batch,)
        batch_idx = torch.arange(min_entropy_indices.size(0), device=min_entropy_indices.device)
        return all_logits[min_entropy_indices, batch_idx].to(self.device)

    def classifier_recall(self):
        """Restore ``classifier_pool`` from ``self.old_state_dict``.

        NOTE(review): ``old_state_dict`` is not assigned anywhere in this
        class; callers must set it beforehand or this raises AttributeError.
        """
        self.classifier_pool.load_state_dict(self.old_state_dict)

    def copy(self):
        """Return a deep copy of this module."""
        return copy.deepcopy(self)

    def freeze(self):
        """Disable gradients for all parameters, switch to eval mode, and return ``self``."""
        for param in self.parameters():
            param.requires_grad = False
        self.eval()

        return self