import os
import sys
import torch
from tqdm import tqdm
import numpy as np
# from models.Original_LoRA import CONFIGS, LoRAVisionTransformer
from models.network_lora import CONFIGS, LoraVisionTransformer
from models.lora_with_config import CONFIGS, LoraPruningVisionTransformer
from models.hide_features_ViT import CONFIGS, VisionTransformer
from RLRRDatasets.VTABDataLoader import get_data
from RLRRDatasets.VTABConfig import DATA_CONFIGS
from utils import (seed_torch, Logger, count_parameters)
import torch.nn as nn
from any_tools import get_hist_graph, get_bottle_k, get_top_k, calculate_posibility
from utils import (seed_torch, accuracy, AverageMeter, Logger, count_parameters)
import argparse


def _str_to_bool(value):
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    accepts the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Define the command-line options and parse ``sys.argv``.

    Despite its name, this function returns the parsed arguments, not the
    parser object.

    Returns:
        argparse.Namespace: the parsed options. ``--local-rank`` is exposed as
        ``args.local_rank``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--name", default="parameter-efficient fine-tuning")
    parser.add_argument("--dataset_name", default="dtd")
    parser.add_argument("--model_type", default="ViT-B_16")

    # Which model variant to build, and the checkpoint used by the "lora" variant.
    parser.add_argument("--type", default="vit", choices=["vit", "lora", "mask_lora"])
    parser.add_argument("--loramodel_dir", default="/home/lzj/orthPEFT/output/dtd/dtd 72.71_original.pth")

    # Pruning options; they select the mask file and the checkpoint used by
    # the "mask_lora" variant.
    parser.add_argument("--rate", default=0.2, type=float)
    parser.add_argument("--pruning_type", default="distribution", choices=["distribution", "norm", "importance"])
    parser.add_argument("--pruning_area", default="global", choices=["global", "local"])
    parser.add_argument("--model_dir", default="/home/lzj/orthPEFT/output/cifar/cifar 65.01111.pth")

    parser.add_argument("--dataset_dir", default="/home/datasets/vtab-1k/vtab-1k/")
    parser.add_argument("--pretrained_dir", type=str, default="ViT-B_16.npz")  # imagenet21k_
    parser.add_argument("--output_dir", default="output", type=str)  # -aug all-no-res
    parser.add_argument("--device", default='cuda:0', type=str)

    parser.add_argument("--num_workers", default=6, type=int)
    parser.add_argument("--img_size", default=224, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--num_classes", default=100, type=int)
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--learning_rate", default=3e-3, type=float)
    parser.add_argument("--weight_decay", default=5e-5, type=float)
    # A real converter is required here: type=bool would turn "False" into True.
    parser.add_argument("--simple_aug", default=True, type=_str_to_bool)

    # follow SSF
    parser.add_argument("--warmup_epochs", default=10, type=int)
    parser.add_argument("--sched", choices=["cosine", "linear"], default="cosine")
    parser.add_argument("--lr_cycle_decay", default=0.5, type=float)
    parser.add_argument("--cooldown_epochs", default=10, type=int)

    parser.add_argument("--local-rank", type=int, default=-1)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=2)
    parser.add_argument('--loss_scale', type=float, default=0)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    return parser.parse_args()

@torch.no_grad()
def valid(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and report top-1 accuracy and loss.

    The model is switched to eval mode and gradients are disabled. For every
    batch the cross-entropy loss and the first value returned by
    ``accuracy(..., topk=(1,))`` are fed to their meters together with the
    batch size.

    Args:
        model: callable module taking a batch of inputs and returning logits.
        test_loader: iterable yielding ``(inputs, labels)`` pairs.
        device: device the inputs and labels are moved to.

    Returns:
        tuple: ``(top1.avg, loss.avg)`` as tracked by the two meters.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    acc_meter = AverageMeter('Acc@1', ':6.2f')
    loss_meter = AverageMeter('Loss', ':.4e')

    for images, labels in tqdm(test_loader):
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        n_samples = images.size(0)

        batch_loss = loss_fn(logits, labels)
        batch_acc = accuracy(logits, labels, topk=(1,))[0]
        acc_meter.update(batch_acc.item(), n_samples)
        loss_meter.update(batch_loss.item(), n_samples)

    print('Test :', loss_meter, acc_meter)
    return acc_meter.avg, loss_meter.avg


def load_mask_config(args, activate=False):
    """Load the saved pruning mask selected by the pruning options in ``args``.

    The file is read from
    ``Pruning_Config/{pruning_type}_{dataset_name}_rate{rate}_{pruning_area}.pt``
    relative to the current working directory.

    Args:
        args: namespace providing ``pruning_type``, ``dataset_name``, ``rate``
            and ``pruning_area``.
        activate: unused; kept so existing callers keep working.

    Returns:
        The loaded mask, placed on the CPU.
    """
    save_path = f"Pruning_Config/{args.pruning_type}_{args.dataset_name}_rate{args.rate}_{args.pruning_area}.pt"
    # map_location makes loading independent of the device the mask was saved
    # from (e.g. a CUDA device that is not present on this machine).
    mask = torch.load(save_path, map_location="cpu")
    return mask.cpu()


def setup(args, m_type):
    """Build the model variant selected by ``m_type`` and load its weights.

    * ``"vit"``: ``VisionTransformer`` with only the pretrained ``.npz`` weights.
    * ``"lora"``: ``LoraVisionTransformer`` with the pretrained weights plus the
      checkpoint at ``args.loramodel_dir``.
    * anything else: ``LoraPruningVisionTransformer`` built with the mask from
      ``load_mask_config`` plus the checkpoint at ``args.model_dir``.

    Args:
        args: namespace providing ``model_type``, ``img_size``, ``num_classes``,
            ``drop_path``, ``pretrained_dir`` and the checkpoint/pruning options.
            ``drop_path`` is not a CLI option; ``main`` sets it from the dataset
            config before calling this function.
        m_type: model variant name.

    Returns:
        The constructed model. It is not moved to any device here.
    """
    # NOTE(review): CONFIGS is imported from three model modules at the top of
    # this file, so the last import is the one in effect - confirm the config
    # tables are interchangeable.
    config = CONFIGS[args.model_type]
    common_kwargs = dict(zero_head=True, num_classes=args.num_classes, drop_path=args.drop_path)

    if m_type == "vit":
        model = VisionTransformer(config, args.img_size, **common_kwargs)
        checkpoint_path = None
    elif m_type == "lora":
        model = LoraVisionTransformer(config, args.img_size, **common_kwargs)
        checkpoint_path = args.loramodel_dir
    else:
        mask_config = load_mask_config(args)
        # Report how many mask entries are zero, per row and in total.
        zeros_per_row = (mask_config == 0).sum(dim=1)
        print(zeros_per_row)
        print(zeros_per_row.sum())
        model = LoraPruningVisionTransformer(config, mask_config, args.img_size, **common_kwargs)
        checkpoint_path = args.model_dir

    model.load_from(np.load(args.pretrained_dir))

    if checkpoint_path is not None:
        # Load onto the CPU so the checkpoint does not depend on the device it
        # was saved from; load_state_dict copies the values into the model.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        # strict=False: keys missing from or extra to the checkpoint are tolerated.
        model.load_state_dict(state_dict, strict=False)

    return model


@torch.no_grad()
def train(model, train_loader, layer_index, before_act=False):
    """Collect the model's per-batch outputs over ``train_loader``.

    Despite the name, no training happens: the model is put in eval mode,
    gradients are disabled, and each batch is passed through
    ``model(data, index=layer_index, before_act=before_act)``.

    Args:
        model: module whose forward accepts ``index`` and ``before_act`` keywords.
        train_loader: iterable yielding ``(data, target)`` pairs; the targets
            are not used.
        layer_index: forwarded to the model as ``index``.
        before_act: forwarded to the model as ``before_act``.

    Returns:
        list: one model output per batch, in loader order.
    """
    model.eval()

    # Take the device from the model instead of a module-level global, so the
    # function also works when this file is imported rather than run as a script.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else None

    output_set = []
    for data, _target in tqdm(train_loader):
        if model_device is not None:
            data = data.to(model_device)
        output_set.append(model(data, index=layer_index, before_act=before_act))

    return output_set



def main(args, index, before_act, m_type):
    """Extract features for one dataset and return them as a single tensor.

    Side effects: overwrites several fields of ``args`` with the dataset
    config, creates ``<output_dir>/<dataset_name>`` if needed, and replaces
    ``sys.stdout`` with a ``Logger`` wrapping the previous stdout and a log
    file path.

    Args:
        args: parsed command-line namespace; modified in place.
        index: forwarded to ``train`` as the layer index.
        before_act: forwarded to ``train``.
        m_type: model variant passed to ``setup``.

    Returns:
        torch.Tensor: the per-batch outputs of ``train`` concatenated along dim 0.
    """
    config = DATA_CONFIGS[args.dataset_name]
    args.data_path = os.path.join(args.dataset_dir, args.dataset_name)

    # Dataset-specific settings take precedence over the CLI values.
    args.num_classes = config['num_classes']
    args.learning_rate = config['lr']
    args.min_lr = config['min_lr']
    args.drop_path = config['drop_path']
    args.warmup_lr = config['warmup_lr']
    args.weight_decay = config['weight_decay']
    args.batch_size = config['batch_size']
    args.simple_aug = config['simple_aug']

    run_dir = os.path.join(args.output_dir, args.dataset_name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(run_dir, exist_ok=True)

    # Format only the file name, so braces elsewhere in the path are left alone.
    log_path = os.path.join(run_dir, '{}111.txt'.format(args.dataset_name))
    sys.stdout = Logger(sys.stdout, log_path)

    train_loader, test_loader = get_data(data_path=args.data_path, batch_size=args.batch_size,
                                         simple_aug=args.simple_aug)

    # Resolve the device from args rather than relying on a module-level global.
    device = torch.device(args.device)
    model = setup(args, m_type)
    model.to(device)

    # valid(model, test_loader, device)
    print(args, config)

    output = train(model, train_loader, index, before_act)
    # Merge the list of per-batch outputs into a single tensor.
    features = torch.cat(output, dim=0)
    return features

def draw_distribution(features):
    """Pass the dim-0 mean of ``features`` to ``get_hist_graph``.

    Args:
        features: tensor that is averaged over its first dimension and moved
            to the CPU before being handed to ``get_hist_graph`` (presumably a
            histogram plot - confirm against ``any_tools``).
    """
    averaged = features.mean(dim=0).cpu()
    get_hist_graph(averaged)

def analyze(output, rate, channel=2000):
    """Pass the values of one feature channel to ``get_hist_graph``.

    ``output`` is flattened to 2-D by merging every dimension except the last,
    the resulting shape is printed, and the column selected by ``channel`` is
    moved to the CPU and handed to ``get_hist_graph``.

    Args:
        output: feature tensor whose last dimension indexes the channels.
        rate: currently unused; it was only read by the top-k pruning analysis
            that is no longer active. Kept so existing callers keep working.
        channel: index of the column to inspect. Defaults to 2000, the value
            that used to be hard-coded. Negative indices count from the end.

    Raises:
        IndexError: if ``channel`` is outside the last dimension of ``output``.
    """
    # First flatten the features: merge every dimension except the last.
    feature = output.reshape(-1, output.shape[-1])
    print(feature.shape)

    num_channels = feature.shape[-1]
    if not -num_channels <= channel < num_channels:
        raise IndexError(
            f"channel {channel} is out of range for features with {num_channels} channels"
        )

    channel_values = feature[:, channel]
    get_hist_graph(channel_values.cpu())

if __name__ == '__main__':
    # Script entry point: parse the CLI options and seed the run.
    args = get_args_parser()
    seed_torch(args.seed)
    # Target device for the run, bound at module level.
    device = torch.device(args.device)

    # Extract features with index 0 and before_act enabled, then analyze them.
    output = main(args, index=0, before_act=True, m_type=args.type)
    print(output.size())
    analyze(output, args.rate)