import math
import torch
import torch.nn as nn
import numpy as np
from torch.nn.modules.utils import _pair
from scipy import ndimage
import torch.nn.functional as F
from models import configs

# Sub-layer names inside each "Transformer/encoderblock_{n}" group of the
# checkpoint. LoraBlock joins them as ROOT + '/' + <name> + '/' + <param> to
# look up kernels, biases and LayerNorm scales.
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"  # first FFN projection (loaded into ffn.fc1)
FC_1 = "MlpBlock_3/Dense_1"  # second FFN projection (loaded into ffn.fc2)
ATTENTION_NORM = "LayerNorm_0"  # norm applied before the attention sub-layer
MLP_NORM = "LayerNorm_2"  # norm applied before the FFN sub-layer

# Model name -> config object, built by the project-local `models.configs`
# module. Every config is constructed eagerly when this module is imported.
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'testing': configs.get_testing(),
}


def np2th(weights, conv=False):
    """Wrap a NumPy array as a tensor, optionally reordering a conv kernel.

    Args:
        weights: NumPy array (``torch.from_numpy`` only accepts ndarrays).
        conv: When True, ``weights`` is treated as an HWIO convolution kernel
            and transposed to PyTorch's OIHW layout before wrapping.

    Returns:
        A tensor that shares memory with the (possibly transposed) array.
    """
    array = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(array)


class Lora_Linear(nn.Module):
    """Frozen linear layer with an optional trainable low-rank (LoRA) update.

    The wrapped ``nn.Linear`` (``self.mlp``) is always frozen. When
    ``enable_lora`` is True, the trainable update
    ``(x @ Lora_A^T @ Lora_B^T) * scaling`` is added to its output, where
    ``Lora_A`` has shape ``(lora_rank, size_in)`` and ``Lora_B`` has shape
    ``(size_out, lora_rank)``. ``Lora_B`` is zero-initialised, so the update
    starts at zero.

    When ``FFN`` is True the frozen layer is structurally pruned by ``mask``,
    whose non-zero entries mark the kept units:

    * ``size_in < size_out`` (expanding layer, e.g. fc1): ``mask`` has length
      ``size_out``. Only the kept outputs are computed, then scattered back to
      full width with zeros at the pruned positions.
    * otherwise (e.g. fc2): ``mask`` has length ``size_in``. Only the kept
      input features are fed to the frozen layer.

    In every configuration the module consumes ``size_in`` features and
    produces ``size_out`` features. Pruning is applied whether or not LoRA is
    enabled.

    Args:
        size_in: Input feature size.
        size_out: Output feature size.
        bias: Whether the frozen linear layer has a bias.
        enable_lora: Whether to add the trainable low-rank update.
        FFN: Whether to apply mask-based pruning to the frozen layer.
        mask: 1-D pruning mask tensor (required when ``FFN`` is True).
        lora_rank: Rank of the LoRA update (default 1).

    Raises:
        ValueError: If ``FFN`` is True and ``mask`` is None.
    """

    def __init__(self, size_in, size_out, bias=True, enable_lora=False, FFN=False, mask=None, lora_rank=1):
        super(Lora_Linear, self).__init__()
        self.enable_lora = enable_lora
        self.FFN = FFN
        self.mask = mask
        self.size_in = size_in
        self.size_out = size_out
        if self.FFN:
            if mask is None:
                raise ValueError("mask is required when FFN=True")
            # Number of kept units. Use the same non-zero test as the forward
            # selection, and a plain int so nn.Linear does not get a tensor.
            kept = int(mask.bool().sum())
            # Must use the same predicate as _frozen_forward so that init and
            # forward always prune the same side (also when size_in == size_out).
            if size_in < size_out:
                self.mlp = nn.Linear(size_in, kept, bias=bias)
            else:
                self.mlp = nn.Linear(kept, size_out, bias=bias)
        else:
            self.mlp = nn.Linear(size_in, size_out, bias=bias)
        self.rank = lora_rank

        self.num = min(size_in, size_out)
        if self.enable_lora:
            self.Lora_A = nn.Parameter(torch.empty(self.rank, size_in), requires_grad=True)
            self.Lora_B = nn.Parameter(torch.empty(size_out, self.rank), requires_grad=True)
            self.scaling = 1 / self.rank
            nn.init.kaiming_uniform_(self.Lora_A)
            nn.init.zeros_(self.Lora_B)

        self._frozen_param()

    def _frozen_param(self):
        """Disable gradients for the wrapped (frozen) linear layer."""
        for param in self.mlp.parameters():
            param.requires_grad = False

    def expand_with_mask(self, features, mask):
        """Scatter pruned features back to their full width.

        Args:
            features: Tensor of shape ``[..., num_selected_cols]`` holding only
                the columns where ``mask`` is non-zero. Any number of leading
                dims is accepted.
            mask: 1-D tensor of length ``num_cols`` with 0/1 entries;
                ``mask.sum()`` must equal ``num_selected_cols``.

        Returns:
            Tensor of shape ``[..., num_cols]`` with the same dtype and device
            as ``features``. It holds ``features`` where ``mask`` is non-zero
            and zeros elsewhere; gradients flow back to ``features``.
        """
        assert features.dim() >= 1, "features must have at least one dimension"
        assert mask.dim() == 1, "mask 必须是一维张量"
        assert mask.sum() == features.size(-1), "mask 中 1 的数量必须等于 tensor 的列数"

        output = features.new_zeros(*features.shape[:-1], mask.size(0))
        output[..., mask.bool()] = features
        return output

    def _frozen_forward(self, x):
        """Apply the frozen (possibly pruned) layer to a full-width input."""
        if not self.FFN:
            return self.mlp(x)
        # Keep the index on x's device so a CPU mask also works for CUDA inputs.
        mask_bool = self.mask.to(device=x.device, dtype=torch.bool)
        if self.size_in < self.size_out:
            return self.expand_with_mask(self.mlp(x), mask_bool)
        return self.mlp(x[..., mask_bool])

    def forward(self, x):
        """Return the frozen (pruned) projection of ``x``, plus the LoRA update if enabled."""
        result = self._frozen_forward(x)
        if self.enable_lora:
            # The low-rank update always sees the full-width input. The add is
            # out-of-place so the frozen layer's output tensor is not mutated.
            result = result + (x @ self.Lora_A.transpose(0, 1) @ self.Lora_B.transpose(0, 1)) * self.scaling
        return result


class LoraAttention(nn.Module):
    """Multi-head self-attention whose Q/K/V/output projections are ``Lora_Linear`` layers.

    Args:
        config: Model config. Reads ``hidden_size`` and the ``num_heads`` and
            ``attention_dropout_rate`` entries of ``config.transformer``.
        enable_lora: Forwarded to all four projections.
    """

    def __init__(self, config, enable_lora=False):
        super(LoraAttention, self).__init__()
        hidden = config.hidden_size
        heads = config.transformer["num_heads"]
        self.num_attention_heads = heads
        self.attention_head_size = int(hidden / heads)
        self.all_head_size = heads * self.attention_head_size

        def make_projection(size_out):
            return Lora_Linear(hidden, size_out, bias=True, enable_lora=enable_lora)

        self.query = make_projection(self.all_head_size)
        self.key = make_projection(self.all_head_size)
        self.value = make_projection(self.all_head_size)
        self.out = make_projection(hidden)

        # Both dropouts use the attention dropout rate.
        attn_rate = config.transformer["attention_dropout_rate"]
        self.attn_dropout = nn.Dropout(attn_rate)
        self.proj_dropout = nn.Dropout(attn_rate)
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: ``(B, N, H*D) -> (B, H, N, D)``."""
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Return the attention output; it has the same shape as ``hidden_states``."""
        q, k, v = (
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key, self.value)
        )

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        probs = self.attn_dropout(self.softmax(scores))

        # Merge heads back: (B, H, N, D) -> (B, N, H*D).
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        context = context.view(*context.size()[:-2], self.all_head_size)
        return self.proj_dropout(self.out(context))


class LoraMLP(nn.Module):
    """Two-layer GELU feed-forward network built from mask-pruned ``Lora_Linear`` layers.

    Both projections are created with ``FFN=True`` and share ``mask``. LoRA is
    always enabled on them: the ``enable_lora`` argument is accepted but not
    forwarded (the projections hard-code ``enable_lora=True``).

    Args:
        config: Model config. Reads ``hidden_size`` and the ``mlp_dim`` and
            ``dropout_rate`` entries of ``config.transformer``.
        mask: 1-D pruning mask over the ``mlp_dim`` hidden units.
        enable_lora: Unused (see above).
    """

    def __init__(self, config, mask, enable_lora=False):
        super(LoraMLP, self).__init__()
        self.mask = mask
        hidden_dim = config.hidden_size
        inner_dim = config.transformer["mlp_dim"]
        self.fc1 = Lora_Linear(hidden_dim, inner_dim, bias=True, enable_lora=True, FFN=True, mask=mask)
        self.fc2 = Lora_Linear(inner_dim, hidden_dim, bias=True, enable_lora=True, FFN=True, mask=mask)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """fc1 -> GELU -> dropout -> fc2 -> dropout (dropouts only while training)."""
        hidden = self.act(self.fc1(x))
        hidden = self.dropout(hidden) if self.training else hidden
        projected = self.fc2(hidden)
        return self.dropout(projected) if self.training else projected


class LoraBlock(nn.Module):
    """Pre-norm transformer encoder block with LoRA attention and a pruned FFN.

    Args:
        config: Model config. ``hidden_size`` is read here; the whole config
            is also passed to ``LoraMLP`` and ``LoraAttention``.
        mask: 1-D pruning mask for the FFN hidden units, forwarded to ``LoraMLP``.
        drop_path: Probability for the two residual-branch dropouts. These are
            plain element-wise ``nn.Dropout`` layers, not per-sample
            stochastic depth.
        enable_lora: Forwarded to ``LoraMLP`` and ``LoraAttention``.
    """

    def __init__(self, config, mask, drop_path=0.0, enable_lora=False):
        super(LoraBlock, self).__init__()
        self.ffn_mask = mask
        self.hidden_size = config.hidden_size
        self.attention_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = LoraMLP(config, self.ffn_mask, enable_lora=enable_lora)
        self.attn = LoraAttention(config, enable_lora=enable_lora)
        self.enable_lora = enable_lora
        self.drop_path1 = nn.Dropout(p=drop_path)
        self.drop_path2 = nn.Dropout(p=drop_path)

    def forward(self, x):
        """Apply attention then FFN, each as a normalised residual branch."""
        x = x + self.drop_path1(self.attn(self.attention_norm(x)))
        x = x + self.drop_path2(self.ffn(self.ffn_norm(x)))
        return x

    def _kept_indices(self):
        """Return the indices of non-zero ``ffn_mask`` entries as a NumPy array.

        The checkpoint arrays are NumPy. Going through the CPU lets the indices
        slice them even when the mask lives on a GPU, where converting a torch
        index tensor to NumPy would fail.
        """
        return torch.nonzero(self.ffn_mask).flatten().cpu().numpy()

    def _fc_load_weight(self, ROOT, Key, Weights, unit):
        """Copy one dense layer of the checkpoint into ``unit.mlp``.

        Args:
            ROOT: Checkpoint prefix of this encoder block.
            Key: Sub-layer name (one of the module-level key constants).
            Weights: Mapping from checkpoint key to NumPy array.
            unit: ``Lora_Linear`` whose frozen ``mlp`` receives the weights.
        """
        mat_weights = Weights[ROOT + '/' + Key + '/' + "kernel"]
        mat_bias = Weights[ROOT + '/' + Key + '/' + "bias"]
        # Flatten per-head kernel axes to a 2-D (in_features, out_features) matrix.
        if Key == ATTENTION_OUT:
            mat_weights = mat_weights.reshape(-1, mat_weights.shape[-1])
        else:
            mat_weights = mat_weights.reshape(mat_weights.shape[0], -1)

        if Key == FC_0:
            # fc1 is pruned on its output side: keep the selected columns and
            # the matching bias entries.
            kept = self._kept_indices()
            mat_weights = mat_weights[:, kept]
            mat_bias = mat_bias[kept]
        elif Key == FC_1:
            # fc2 is pruned on its input side: keep the selected rows; its bias
            # (one entry per output feature) is unaffected.
            kept = self._kept_indices()
            mat_weights = mat_weights[kept, :]

        # The checkpoint matrix is (in, out); nn.Linear.weight is (out, in).
        unit.mlp.weight.copy_(np2th(mat_weights).t())
        unit.mlp.bias.copy_(np2th(mat_bias).view(-1))

    def load_from(self, weights, n_block):
        """Load this block's parameters from ``Transformer/encoderblock_{n_block}``.

        Copies all six dense layers (Q/K/V/out and the two FFN layers) and both
        LayerNorms. Only the frozen ``mlp`` of each ``Lora_Linear`` is
        written; LoRA parameters are left untouched.
        """
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            self._fc_load_weight(ROOT, ATTENTION_Q, weights, self.attn.query)
            self._fc_load_weight(ROOT, ATTENTION_K, weights, self.attn.key)
            self._fc_load_weight(ROOT, ATTENTION_V, weights, self.attn.value)
            self._fc_load_weight(ROOT, ATTENTION_OUT, weights, self.attn.out)
            self._fc_load_weight(ROOT, FC_0, weights, self.ffn.fc1)
            self._fc_load_weight(ROOT, FC_1, weights, self.ffn.fc2)

            self.attention_norm.weight.copy_(np2th(weights[ROOT + '/' + ATTENTION_NORM + '/' + "scale"]))
            self.attention_norm.bias.copy_(np2th(weights[ROOT + '/' + ATTENTION_NORM + '/' + "bias"]))
            self.ffn_norm.weight.copy_(np2th(weights[ROOT + '/' + MLP_NORM + '/' + "scale"]))
            self.ffn_norm.bias.copy_(np2th(weights[ROOT + '/' + MLP_NORM + '/' + "bias"]))


class LoraEncoder(nn.Module):
    """Stack of ``LoraBlock`` layers followed by a final LayerNorm.

    Args:
        config: Model config. Reads ``hidden_size`` and
            ``config.transformer["num_layers"]``.
        pruning_config: Sequence of per-layer FFN masks; entry ``i`` goes to block ``i``.
        drop_path: Largest residual dropout rate. Rates grow linearly from 0
            at the first block to ``drop_path`` at the last.
        enable_lora: Forwarded to every block.
    """

    def __init__(self, config, pruning_config, drop_path=0.0, enable_lora=True):
        super(LoraEncoder, self).__init__()
        num_layers = config.transformer["num_layers"]
        self.enable_lora = enable_lora
        self.layer = nn.ModuleList()
        # Empty per-layer containers; nothing in this class fills or reads them.
        for idx in range(num_layers):
            setattr(self, f"lora_layer_{idx}", nn.ModuleList())
        self.encoder_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.num_blocks = num_layers
        self.depth = num_layers
        rates = torch.linspace(0, drop_path, num_layers).tolist()
        self.layer.extend(
            LoraBlock(config, pruning_config[idx], drop_path=rate, enable_lora=enable_lora)
            for idx, rate in enumerate(rates)
        )

    def forward(self, hidden_states):
        """Run every block in order, then apply the final LayerNorm."""
        out = hidden_states
        for block in self.layer:
            out = block(out)
        return self.encoder_norm(out)


class LoraEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for a non-hybrid ViT.

    Args:
        config: Model config. Reads ``patches["size"]``, ``hidden_size`` and
            ``config.transformer["dropout_rate"]``.
        img_size: Image size (int or pair); each side is divided by the patch size.
        in_channels: Number of input image channels.

    Note: ``self.dropout`` is constructed but not applied in ``forward``.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(LoraEmbeddings, self).__init__()
        height, width = _pair(img_size)
        patch_size = _pair(config.patches["size"])
        patch_h, patch_w = patch_size
        num_patches = (height // patch_h) * (width // patch_w)
        self.hybrid = False

        hidden = config.hidden_size
        # Non-overlapping patches: kernel size equals stride.
        self.patch_embeddings = nn.Conv2d(
            in_channels=in_channels,
            out_channels=hidden,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # One extra position for the class token.
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, hidden))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden))

        self.dropout = nn.Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        """Embed ``x`` as ``[cls, patch_1, ..., patch_P]`` plus position embeddings."""
        patch_tokens = self.patch_embeddings(x).flatten(2).transpose(-1, -2)
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        tokens = torch.cat((cls_tokens, patch_tokens), dim=1)
        return tokens + self.position_embeddings


class LoraPruningTransformer(nn.Module):
    """Embedding layer followed by the LoRA/pruned encoder.

    All embedding parameters are frozen at construction time. The encoder's
    trainability is left to its own modules.
    """

    def __init__(self, config, pruning_config, img_size, drop_path=0.0, enable_lora=True):
        super(LoraPruningTransformer, self).__init__()
        self.embeddings = LoraEmbeddings(config, img_size=img_size)
        self.encoder = LoraEncoder(config, pruning_config, drop_path=drop_path, enable_lora=enable_lora)
        self._frozen_param()

    def _frozen_param(self):
        """Turn off gradients for every parameter of the embedding layer."""
        self.embeddings.requires_grad_(False)

    def forward(self, input_ids):
        """Embed the input images, then encode them."""
        return self.encoder(self.embeddings(input_ids))


class ImprovedLoraPruningVisionTransformer(nn.Module):
    """ViT image classifier with LoRA adapters and structurally pruned FFN layers.

    Args:
        config: Model config. ``classifier`` and ``hidden_size`` are read here;
            the config is also forwarded to the transformer.
        pruning_config: Per-layer FFN masks, forwarded to the transformer.
        img_size: Input image size.
        num_classes: Output size of the linear classification head.
        zero_head: If True, ``load_from`` zero-initialises the head instead of
            loading it from the checkpoint.
        enable_lora: Forwarded to the transformer.
        drop_path: Forwarded to the transformer.
    """

    def __init__(self, config, pruning_config, img_size=224, num_classes=21843, zero_head=False, enable_lora=True,
                 drop_path=0.0):
        super(ImprovedLoraPruningVisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = LoraPruningTransformer(config, pruning_config, img_size, drop_path=drop_path, enable_lora=enable_lora)
        self.head = nn.Linear(config.hidden_size, num_classes)
        self.loss_fct = nn.CrossEntropyLoss()

    def get_parameters(self, lr, weight_decay):
        """Build two optimizer parameter groups.

        Parameters whose name contains ``bias`` or ``norm`` get no weight
        decay; all others use ``weight_decay``. Both groups use ``lr``. Frozen
        parameters are included as well.
        """
        wd_params = []
        no_wd_params = []
        for name, param in self.named_parameters():
            if 'bias' in name or 'norm' in name:
                no_wd_params.append(param)
            else:
                wd_params.append(param)

        params = [
            {"params": wd_params, "lr": lr, "weight_decay": weight_decay},
            {"params": no_wd_params, "lr": lr, "weight_decay": 0.}
        ]

        return params

    def forward(self, x, labels=None):
        """Return logits, or the cross-entropy loss when ``labels`` is given.

        Logits are computed from the first token (index 0) of the encoder output.
        """
        x = self.transformer(x)

        logits = self.head(x[:, 0])
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits

    @staticmethod
    def _resize_posemb(posemb, ntok_new, classifier):
        """Linearly interpolate a position-embedding table to ``ntok_new`` tokens.

        Args:
            posemb: NumPy array of shape ``(1, ntok_old, dim)``.
            ntok_new: Target token count. When ``classifier == "token"`` this
                includes the class token, whose embedding is kept as-is.
            classifier: Classifier type from the config.

        Returns:
            NumPy array of shape ``(1, ntok_new, dim)``.

        Assumes the patch tokens form a square grid both before and after resizing.
        """
        if classifier == "token":
            posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
            ntok_new -= 1
        else:
            posemb_tok, posemb_grid = posemb[:, :0], posemb[0]

        gs_old = int(np.sqrt(len(posemb_grid)))
        gs_new = int(np.sqrt(ntok_new))
        print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
        posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)

        # order=1 spline == linear interpolation over the 2-D grid; the
        # embedding axis is not zoomed.
        zoom = (gs_new / gs_old, gs_new / gs_old, 1)
        posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)
        posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
        return np.concatenate([posemb_tok, posemb_grid], axis=1)

    def load_from(self, weights):
        """Load pretrained ViT weights from a mapping of checkpoint key -> NumPy array.

        Position embeddings are interpolated when the checkpoint's token count
        differs from this model's.
        """
        with torch.no_grad():
            if self.zero_head:
                nn.init.zeros_(self.head.weight)
                nn.init.zeros_(self.head.bias)
            else:
                self.head.weight.copy_(np2th(weights["head/kernel"]).t())
                self.head.bias.copy_(np2th(weights["head/bias"]).t())

            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))
            self.transformer.embeddings.cls_token.copy_(np2th(weights["cls"]))
            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                posemb_new.copy_(posemb)
            else:
                print("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                # Resize entirely in NumPy rather than mixing tensors and arrays.
                resized = self._resize_posemb(posemb.numpy(), posemb_new.size(1), self.classifier)
                posemb_new.copy_(np2th(resized))

            # Load exactly the encoder blocks, by position. Walking every
            # grandchild of the encoder would call load_from on any nested
            # module, not just the blocks.
            for n_block, block in enumerate(self.transformer.encoder.layer):
                block.load_from(weights, n_block=n_block)

