import os
import sys
import torch
from tqdm import tqdm
import numpy as np
import torch.nn as nn
# from models.network_lora_orthgonal import CONFIGS, LoraOrthVisionTransformer
from models.hide_features_ViT import CONFIGS, VisionTransformer
from RLRRDatasets.VTABDataLoader import get_data
from RLRRDatasets.VTABConfig import DATA_CONFIGS
from utils import (seed_torch, Logger, count_parameters)
from any_tools import get_hist_graph, get_bottle_k, get_top_k, calculate_posibility
import argparse
from Pruning_Config.loader_and_pruning import get_global_FFN_mask_by_distribution, get_global_FFN_mask_by_norm, get_local_FFN_mask_by_norm


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because the string is non-empty. This parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string keeps its previous meaning (bool("") is False).
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Parse the command-line options of this script from ``sys.argv``.

    Note that ``batch_size``, ``simple_aug`` and ``num_classes`` are later
    overwritten from ``DATA_CONFIGS`` by ``get_model_trainloader``.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--name", default="parameter-efficient fine-tuning")
    parser.add_argument("--dataset_name", default="eurosat")
    parser.add_argument("--model_type", default="ViT-B_16")
    parser.add_argument("--dataset_dir", default="/home/datasets/vtab-1k/")

    parser.add_argument("--pretrained_dir", type=str, default="ViT-B_16.npz")  # imagenet21k_
    parser.add_argument("--device", default='cuda:1', type=str)
    parser.add_argument("--num_classes", default=100, type=int)
    parser.add_argument("--img_size", default=224, type=int)
    parser.add_argument("--batch_size", default=256, type=int)
    # type=bool would treat "--simple_aug False" as True; parse the text explicitly.
    parser.add_argument("--simple_aug", default=True, type=_str2bool)

    parser.add_argument("--pruning_type", default="distribution", choices=["distribution", "norm", "importance"])
    parser.add_argument("--rate", default=0.3, type=float)
    parser.add_argument("--pruning_area", default="global", choices=["global", "local"])
    parser.add_argument("--stor_dir", default="Pruning_Config/")
    parser.add_argument('--seed', type=int, default=42)

    return parser.parse_args()


def setup(args):
    """Build the vision transformer and load its pretrained weights.

    Reads ``args.model_type`` (key into ``CONFIGS``), ``args.img_size``,
    ``args.num_classes`` and ``args.pretrained_dir`` (a file opened with
    ``np.load``).

    Returns:
        The ``VisionTransformer`` instance after ``load_from`` has been called.
    """
    vit = VisionTransformer(
        CONFIGS[args.model_type],
        args.img_size,
        zero_head=True,
        num_classes=args.num_classes,
    )
    # The model is built first, then the checkpoint is read and handed over.
    pretrained_weights = np.load(args.pretrained_dir)
    vit.load_from(pretrained_weights)

    return vit


@torch.no_grad()
def train(model, train_loader, layer_index, before_act=False):
    """Run the model over a data loader and collect its per-batch outputs.

    Despite the name, nothing is trained: the model is switched to eval mode
    and gradient tracking is disabled by the decorator.

    Args:
        model: module called as ``model(data, layer_index, before_act)``.
        train_loader: iterable yielding ``(data, target)`` pairs; the target
            is not used.
        layer_index: forwarded to the model unchanged.
        before_act: forwarded to the model unchanged.

    Returns:
        list: one model output per batch, in loader order.
    """
    model.eval()

    # Take the device from the model itself rather than from a module-level
    # global, which only exists when this file is executed as a script.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    output_set = []
    for data, _target in tqdm(train_loader):
        if target_device is not None:
            data = data.to(target_device)
        output_set.append(model(data, layer_index, before_act))

    return output_set


def get_model_trainloader(args):
    """Create the model and the training data loader for ``args.dataset_name``.

    Side effects on ``args``: sets ``data_path`` and overwrites ``batch_size``,
    ``simple_aug`` and ``num_classes`` with the values stored in
    ``DATA_CONFIGS`` for the dataset. Also prints ``args`` and that config.

    Returns:
        tuple: ``(model, train_loader)``; the model has been moved to
        ``args.device``.
    """
    data_config = DATA_CONFIGS[args.dataset_name]
    args.data_path = os.path.join(args.dataset_dir, args.dataset_name)
    args.batch_size = data_config['batch_size']
    args.simple_aug = data_config['simple_aug']
    args.num_classes = data_config['num_classes']
    # Only the training split is needed here; the second loader is discarded.
    train_loader, _ = get_data(data_path=args.data_path, batch_size=args.batch_size,
                               simple_aug=args.simple_aug)

    model = setup(args)
    # Use the requested device directly instead of a module-level global that
    # is defined only under the ``__main__`` guard.
    model.to(torch.device(args.device))
    print(args, data_config)

    return model, train_loader

def get_layer_probs(output):
    """Compute one "confidence" value per feature dimension of a layer output.

    The output is flattened to 2-D, keeping the last dimension. Each column is
    passed to ``calculate_posibility(column, 0)``; the original comments call
    the result a confidence (its exact definition lives in ``any_tools``).

    Returns:
        torch.Tensor: 1-D tensor with one entry per feature dimension.
    """
    print("start calculating distribution and believe range")
    # Flatten every leading dimension so that rows are samples/tokens.
    feature = output.reshape(-1, output.shape[-1])
    print(feature.shape)
    # Confidence of each feature dimension, gathered as Python scalars.
    per_dim_probs = [
        calculate_posibility(feature[:, dim], 0).item()
        for dim in range(feature.shape[-1])
    ]
    # Pack the per-dimension values into a single 1-D tensor.
    return torch.from_numpy(np.array(per_dim_probs))

def get_layer_norm(output):
    """Compute the L2 norm of every feature dimension of a layer output.

    The output is flattened to 2-D, keeping the last dimension, and the norm
    is taken down each column (over all rows).

    Returns:
        torch.Tensor: 1-D CPU tensor with one norm per feature dimension.
    """
    print("start calculating feature norm")
    flat = output.reshape(-1, output.shape[-1])
    print(flat.shape)

    column_norms = flat.norm(p=2, dim=0)
    print("output norm size:", column_norms.size())
    return column_norms.cpu()


def get_importance(features, model, num_layers=12):
    """Score each FFN hidden unit as ``feature_norm * weight_norm`` per layer.

    For layer ``i`` the weight ``transformer.encoder.layer.{i}.ffn.fc2.weight``
    is transposed and the mean absolute value of each row of the transposed
    matrix is multiplied elementwise with ``features[i]``.

    Args:
        features: tensor indexed as ``features[layer, :]`` holding the
            per-dimension feature norms of every layer.
        model: module whose ``state_dict()`` contains the ``fc2`` weights.
        num_layers: number of encoder layers to process. Defaults to 12, the
            value that used to be hard-coded.

    Returns:
        torch.Tensor: the per-layer importance vectors stacked along dim 0.
    """
    print("start calculating importance")
    # state_dict() builds a fresh mapping on every call, so fetch it once.
    state = model.state_dict()

    layer_importance = []
    for index in range(num_layers):
        feature_norm = features[index, :]
        weight = state[f"transformer.encoder.layer.{index}.ffn.fc2.weight"].t().cpu()
        # After the transpose, each row holds the weights leaving one fc2 input.
        weight_norm = torch.mean(weight.abs(), dim=1)
        layer_importance.append(feature_norm * weight_norm)

    return torch.stack(layer_importance)

if __name__ == '__main__':
    # Script entry point: collect per-layer FFN statistics on the training
    # set, turn them into a pruning mask and save the mask to disk.
    args = get_args_parser()
    seed_torch(args.seed)
    # Stays a module-level name so that any function in this file that reads
    # the global `device` keeps working.
    device = torch.device(args.device)

    model, train_loader = get_model_trainloader(args)

    # Distribution-based pruning inspects features before the activation;
    # the norm and importance criteria use the features after it.
    before_act = args.pruning_type == "distribution"

    # Number of encoder layers whose intermediate features are collected.
    num_layers = 12

    # Per-layer statistics of the intermediate features.
    layer_probs = []
    for index in range(num_layers):
        output = train(model, train_loader, index, before_act)
        # Merge the list of per-batch outputs into a single tensor.
        output = torch.cat(output, dim=0)

        # Per the original author: the distribution criterion needs the
        # pre-activation features because those follow a normal distribution.
        if args.pruning_type == "distribution":
            probs = get_layer_probs(output)
        # Every other criterion is based on the feature norm.
        else:
            probs = get_layer_norm(output)

        layer_probs.append(probs)

    # Stack into an L x D matrix (layers x feature dimensions).
    layer_probs = torch.stack(layer_probs)
    print("layer probs size: ", layer_probs.size())

    if args.pruning_type == "importance":
        importance = get_importance(layer_probs, model)
        print(importance.size())
    else:
        importance = layer_probs

    if args.pruning_type == "distribution":
        mask = get_global_FFN_mask_by_distribution(importance, args.rate)
    elif args.pruning_area == "global":
        mask = get_global_FFN_mask_by_norm(importance, args.rate)
    else:
        mask = get_local_FFN_mask_by_norm(importance, args.rate)

    path = os.path.join(args.stor_dir, args.dataset_name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(path, exist_ok=True)
    # Join with a separator so the file lands inside the directory created
    # above; plain string concatenation glued the file name onto the directory
    # name and wrote the mask next to the directory instead.
    # NOTE(review): confirm that whatever loads these masks expects this path.
    file_name = f"{args.pruning_type}_{args.dataset_name}_rate{str(args.rate)}_{args.pruning_area}.pt"
    save_path = os.path.join(path, file_name)
    torch.save(mask, save_path)
    print("mask saved to:", save_path)