import pickle
import random
import re
import numpy as np
import torch
from transformers import AutoTokenizer

from dataclasses import dataclass, field
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoTokenizer,
    LlamaForCausalLM,
    set_seed, )
from datasets import load_dataset
import os
from typing import Optional
from collator import DataCollator
from torch.utils.data import DataLoader

# Directory containing this module; used to locate the local `dataset.py` loading script.
CURRENT_DIR = os.path.dirname(__file__)

def _load_train_set(args):
    """Load the "train" split via the local dataset.py script, optionally sub-sampled.

    When ``0 < args.sample_size < len(train)``, exactly ``args.sample_size`` rows are drawn
    uniformly without replacement (using the ``random`` module); otherwise all rows are used.
    """
    raw_datasets = load_dataset(
        os.path.join(CURRENT_DIR, "dataset.py"),
        data_dir=args.data_dir
    )

    raw_datasets.cleanup_cache_files()
    print(raw_datasets)

    n_train = len(raw_datasets["train"])
    if 0 < args.sample_size < n_train:
        print(f"Sample size: {args.sample_size}")
        train_set = raw_datasets["train"].select(random.sample(range(n_train), args.sample_size))
    else:
        print("Use all data")
        train_set = raw_datasets["train"]

    # Slice rows first so only five examples are materialized for the preview.
    print("Examples: \n", train_set[:5]["Instance"])
    return train_set


def _load_model_and_tokenizer(model_name_or_path: str):
    """Load a LlamaForCausalLM and its tokenizer with bos=1, eos=2 and pad=1 token ids forced."""
    config = AutoConfig.from_pretrained(model_name_or_path)
    config.bos_token_id = 1
    config.eos_token_id = 2
    config.pad_token_id = 1  # padding reuses the bos id

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    tokenizer.pad_token_id = 1

    model = LlamaForCausalLM.from_pretrained(
        model_name_or_path,
        from_tf=bool(".ckpt" in model_name_or_path),
        config=config,
        use_safetensors=True,
        device_map='auto'
    )
    model.resize_token_embeddings(len(tokenizer))
    return model, tokenizer


def _accumulate_importances(model, batches, device="cuda"):
    """Accumulate ``|grad * weight / len(batches)|`` per trainable parameter over *batches*.

    :param model: module whose forward accepts ``input_ids`` and ``labels`` keyword arguments
        and returns an object with a ``loss`` attribute
    :param batches: sized iterable of dicts holding ``"input_ids"`` and ``"labels"`` tensors
    :param device: device the batch tensors are moved to before the forward pass
    :return: dict parameter name -> accumulated importance tensor; parameters that never
        receive a gradient (frozen, or unused by the forward pass) are absent
    :raises ValueError: if *batches* is empty
    """
    n_batches = len(batches)
    if n_batches == 0:
        raise ValueError("cannot compute importances from an empty data loader")
    scale = 1 / n_batches

    importances = {}
    for batch in batches:
        model.zero_grad()
        # Forward pass
        outputs = model(input_ids=batch["input_ids"].to(device), labels=batch["labels"].to(device))
        # Back-propagate the loss
        outputs.loss.backward()

        for name, param in model.named_parameters():
            # Skip frozen parameters and parameters the forward pass did not reach.
            if not param.requires_grad or param.grad is None:
                continue
            importance = torch.abs(param.grad * param.data * scale)
            if name in importances:
                importances[name] += importance
            else:
                importances[name] = importance

    return importances


def _compute_importance_for_each_model(args, model_name_or_path: str):
    """Score every trainable parameter of a causal LM by gradient-times-weight importance.

    For each batch of the (optionally sub-sampled) training split, the loss is
    back-propagated and ``|grad * weight / n_batches|`` is summed per parameter. That sum is
    the mean over batches of the per-batch absolute importance.

    :param args: run configuration; reads ``seed``, ``data_dir``, ``sample_size``,
        ``batch_size``, ``max_source_length``, ``max_target_length`` and ``fp16``
    :param model_name_or_path: checkpoint loaded with ``LlamaForCausalLM``
    :return: dict mapping parameter name -> importance tensor shaped like the parameter
    :raises ValueError: if the training split yields no batches
    """
    set_seed(args.seed)

    train_set = _load_train_set(args)
    model, tokenizer = _load_model_and_tokenizer(model_name_or_path)

    data_collator = DataCollator(
        tokenizer,
        model=model,
        padding="longest",
        max_source_length=args.max_source_length,
        max_target_length=args.max_target_length,
        pad_to_multiple_of=8 if args.fp16 else None,
    )
    data_loader = DataLoader(
        train_set,
        batch_size=args.batch_size,
        collate_fn=data_collator
    )

    model.train()
    return _accumulate_importances(model, tqdm(data_loader), device="cuda")

def get_importance_for_each_model(args, model_name_or_path: str, save_path: str):
    """Return per-parameter importances for a model, computing and caching them on a miss.

    :param args: run configuration forwarded to ``_compute_importance_for_each_model``;
        only used when no cache exists at ``save_path``
    :param model_name_or_path: checkpoint to score on a cache miss
    :param save_path: pickle cache path
    :return: dict parameter name -> importance tensor on CPU
    """
    importances = load_from_pickle(save_path)
    if importances is None:
        computed = _compute_importance_for_each_model(args, model_name_or_path)
        # Move to CPU before caching so the pickle can be loaded without a GPU.
        importances = {name: importance.cpu() for name, importance in computed.items()}
        save_to_pickle(importances, save_path)

    # Caches written before the CPU conversion above may still hold GPU tensors.
    return {name: importance.cpu() for name, importance in importances.items()}

def load_from_pickle(file_path):
    """Return the object pickled at *file_path*, or ``None`` if the file does not exist.

    NOTE: ``pickle.load`` can execute arbitrary code, so only point this at trusted files.
    """
    if os.path.exists(file_path):
        print(f"Loading data from {file_path}")
        with open(file_path, 'rb') as handle:
            return pickle.load(handle)
    return None

def save_to_pickle(data, file_path):
    """Pickle *data* to *file_path*, creating the parent directory if needed.

    A bare filename (no directory component) is written to the current working directory.
    """
    parent_dir = os.path.dirname(file_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises. exist_ok also
    # tolerates the directory appearing concurrently.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    print(f"Saving data to {file_path}")
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)

def _calculate_weight_location(weight_shapes: dict, weight_nums: dict, flattened_index: int):
    """Map an index into all tensors flattened and concatenated to (name, multi-index).

    :param weight_shapes: parameter name -> shape; dict order defines the concatenation order
    :param weight_nums: parameter name -> number of elements
    :param flattened_index: position in the concatenation of all flattened tensors
    :return: (name, multi_index), where multi_index is the row-major index into that tensor,
        or (None, None) when the index lies past the last tensor
    """
    offset = flattened_index
    for name, shape in weight_shapes.items():
        count = weight_nums[name]
        if offset >= count:
            # Not in this tensor; continue relative to the next one.
            offset -= count
            continue
        # Row-major unravel: peel dimensions off starting from the last one.
        coords = []
        for dim_size in reversed(shape):
            offset, coord = divmod(offset, dim_size)
            coords.append(coord)
        coords.reverse()
        return name, tuple(coords)

    return None, None

def _calculate_multi_dimensional_index(shape, flattened_index):
    """Unravel *flattened_index* into a row-major coordinate tuple for a tensor of *shape*.

    No bounds check is done: an out-of-range index wraps around in the leading dimension.
    """
    coords = [0] * len(shape)
    remainder = flattened_index
    for axis in reversed(range(len(shape))):
        remainder, coords[axis] = divmod(remainder, shape[axis])
    return tuple(coords)

def calculate_head_level_importance(safe_importances, unsafe_importances, n_heads, hidden_size, n_param_tops: dict) -> tuple:
    """Select attention "heads" that rank highly for one model but not for the other.

    Only tensors whose name contains q_proj / k_proj / v_proj are used. Each is viewed as
    ``(hidden_size, n_heads, -1)`` and summed over axes 0 and 2. This gives one score per
    index of the middle axis, keyed as ``"<param name>:<head index>"``.

    NOTE(review): for an ``nn.Linear`` weight stored as (out_features, in_features), this
    view groups the *input* dimension, whereas attention heads usually correspond to groups
    of *output* rows. Confirm this grouping is intended. The mask builder
    ``calculate_mask_attn_head_level`` assumes the same layout.

    :param safe_importances: name -> importance tensor of the "safe" model
    :param unsafe_importances: name -> importance tensor of the "unsafe" model; must contain
        every q/k/v name present in ``safe_importances``
    :param n_heads: number of attention heads
    :param hidden_size: hidden size (first axis of the view)
    :param n_param_tops: parameter budgets. ``'GA'`` sizes the safe top-k and ``'GD'`` the
        unsafe top-k; each is converted to a head count via parameters per head.
    :return: (safe_locations, unsafe_locations): ``"name:head"`` keys that are in that
        model's top-k but not in both top-k sets, listed in descending score order
    :raises ValueError: if ``safe_importances`` holds no q/k/v projection tensors
    """
    def _qkv_views(importances):
        # Keep only attention q/k/v projections and expose the per-"head" middle axis.
        return {
            name: weight.view(hidden_size, n_heads, -1)
            for name, weight in importances.items()
            if "q_proj" in name or "k_proj" in name or "v_proj" in name
        }

    safe_views = _qkv_views(safe_importances)
    unsafe_views = _qkv_views(unsafe_importances)
    if not safe_views:
        raise ValueError("no q_proj/k_proj/v_proj tensors found in safe_importances")

    safe_scores = {}
    unsafe_scores = {}
    # Sum all importances belonging to each head; plain floats keep the sort below cheap.
    for name in tqdm(safe_views):
        param_num = safe_views[name].numel()
        safe_per_head = torch.sum(safe_views[name], dim=(0, 2)).tolist()
        unsafe_per_head = torch.sum(unsafe_views[name], dim=(0, 2)).tolist()
        for head in range(n_heads):
            safe_scores[f"{name}:{head}"] = safe_per_head[head]
            unsafe_scores[f"{name}:{head}"] = unsafe_per_head[head]

    # Convert the parameter budgets into head counts. This uses the size of the last q/k/v
    # tensor, i.e. it assumes all of them hold the same number of elements.
    params_per_head = param_num // n_heads
    TOP_NUM_GA = n_param_tops['GA'] // params_per_head
    TOP_NUM_GD = n_param_tops['GD'] // params_per_head
    print(f"NUM OF TOP ATTN_HEAD GA: {TOP_NUM_GA}")
    print(f"NUM OF TOP ATTN_HEAD GD: {TOP_NUM_GD}")

    # Rank heads by score, highest first (stable: ties keep insertion order).
    print("sorting...")
    safe_ranking = sorted(safe_scores, key=safe_scores.get, reverse=True)
    print(len(safe_ranking))
    unsafe_ranking = sorted(unsafe_scores, key=unsafe_scores.get, reverse=True)
    print("sorted successfully")

    print("calculating locations...")
    safe_top = safe_ranking[:TOP_NUM_GA]
    unsafe_top = unsafe_ranking[:TOP_NUM_GD]
    shared = set(safe_top) & set(unsafe_top)
    # Each side keeps only its own top heads that the other side did not also select.
    result_safe_locations = [location for location in safe_top if location not in shared]
    result_unsafe_locations = [location for location in unsafe_top if location not in shared]
    print("calculated successfully")

    return result_safe_locations, result_unsafe_locations

def calculate_mask_attn_head_level(save_path, remain_locations: list, viewed_attn_shape: tuple = (4096, 32, 128)) -> dict:
    """Build (or load from cache) 0/1 masks that keep whole attention "heads".

    For every ``"name:head"`` location, a float32 mask of ``viewed_attn_shape`` is set to 1
    on ``[:, head, :]``. It is then reshaped to ``(viewed_attn_shape[0], -1)``, merging the
    last two axes back into the weight-matrix layout.

    :param save_path: pickle path used to cache the masks (read first, written on a miss)
    :param remain_locations: ``"param_name:head_index"`` strings for the heads to keep
    :param viewed_attn_shape: per-tensor mask shape before the reshape
    :return: dict parameter name -> mask tensor
    """
    mask_weight = load_from_pickle(save_path)
    if mask_weight is not None:
        return mask_weight

    mask_weight = {}
    for location in tqdm(remain_locations):
        # rpartition splits on the last ":" in case a parameter name ever contains one.
        name, _, head = location.rpartition(":")
        if name not in mask_weight:
            mask_weight[name] = torch.zeros(viewed_attn_shape)
        mask_weight[name][:, int(head), :] = 1

    # Restore the weight-matrix shape: merge the last two viewed axes into one.
    for name in tqdm(mask_weight):
        mask_weight[name] = mask_weight[name].view(viewed_attn_shape[0], -1)

    save_to_pickle(mask_weight, save_path)
    return mask_weight

def convert_importances_to_indices(importances: dict, n_param_top: int, n_param_all: int, *past_masks_paths) -> list:
    """Select the flat indices of the ``n_param_top`` largest mlp/attn importances.

    Every tensor whose name contains "mlp" or "attn" is flattened and concatenated in dict
    order. The returned indices are positions in that concatenation.

    :param importances: name -> importance tensor (not modified)
    :param n_param_top: number of indices to return
    :param n_param_all: total parameter count; a past task's top rate times this gives the
        number of entries to exclude
    :param past_masks_paths: optional single extra argument, an iterable of mask paths from
        earlier tasks (continual learning). For each path:
        - the top rate is parsed from its ``TOP_RATE_GA-<x>-TOP_RATE_GD-<y>/`` segment;
        - the matching importances pickle (path prefix before "neuron/" +
          ``importances_GA.pkl`` / ``importances_GD.pkl``) is loaded;
        - its top entries are zeroed here so they cannot be selected again.
    :return: 1-D ``torch.LongTensor`` of selected flat indices (despite the ``list`` annotation)
    :raises ValueError: if a past mask path lacks the TOP_RATE or "neuron/" segment, or its
        importances pickle does not exist
    """
    def _flatten_mlp_attn(tensors: dict) -> torch.Tensor:
        # torch.cat always allocates a new tensor, so in-place edits never touch the inputs.
        return torch.cat([t.flatten().to('cpu') for name, t in tensors.items() if "mlp" in name or "attn" in name])

    def _importances_path(mask_path: str) -> str:
        cut = mask_path.find("neuron/")
        if cut == -1:
            raise ValueError(f"cannot derive importances path from {mask_path!r}: no 'neuron/' segment")
        suffix = "importances_GA.pkl" if "mask_GA" in mask_path else "importances_GD.pkl"
        return mask_path[:cut] + suffix

    # Stay in torch: avoids numpy's lack of bfloat16 and an extra full-size copy.
    flat = _flatten_mlp_attn(importances)

    if past_masks_paths:
        past_paths = past_masks_paths[0]
        print(past_paths)
        # Extracts the GA and GD top rates, e.g. "TOP_RATE_GA-0.01-TOP_RATE_GD-0.02/".
        rate_pattern = re.compile(r"TOP_RATE_GA-(.*?)-TOP_RATE_GD-(.*?)/")
        for past_masks_path in past_paths:
            match = rate_pattern.search(past_masks_path)
            if match is None:
                raise ValueError(f"no 'TOP_RATE_GA-<x>-TOP_RATE_GD-<y>/' segment in {past_masks_path!r}")
            rate_ga, rate_gd = match.groups()
            top_rate = float(rate_ga if 'mask_GA' in past_masks_path else rate_gd)
            past_n_param_top = int(n_param_all * top_rate)

            importances_path = _importances_path(past_masks_path)
            past_importances = load_from_pickle(importances_path)
            if past_importances is None:
                raise ValueError(f"importances for past mask {past_masks_path!r} not found at {importances_path!r}")
            past_flat = _flatten_mlp_attn(past_importances)

            print("past_n_param_top:", past_n_param_top)
            # Exclude parameters already claimed by an earlier task.
            _, past_indices = torch.topk(past_flat, k=past_n_param_top, largest=True)
            flat[past_indices] = 0

    _, indices = torch.topk(flat, k=n_param_top, largest=True)
    print("convert_importances_to_indices done")

    return indices
    
def calculate_mask_neuron_level(save_path, base_model, indices) -> dict:
    """Build (or load from cache) 0/1 masks over the mlp/attn tensors of ``base_model``.

    The tensors considered are those in ``base_model.state_dict()`` whose name contains
    "mlp" or "attn", in state_dict order. ``indices`` are positions in the concatenation of
    those tensors flattened row-major; each selected position is set to 1.

    NOTE(review): the indices are produced from an importances dict (built from
    ``named_parameters``), while positions here follow ``state_dict()``. Confirm both list
    the same tensors in the same order, e.g. that no persistent buffer matches "attn".

    :param save_path: pickle path used to cache the masks (read first, written on a miss)
    :param base_model: model whose state_dict defines mask names, shapes, dtypes and devices
    :param indices: flat indices to keep (list of ints or integer tensor)
    :return: dict parameter name -> mask tensor
    :raises ValueError: if any index falls outside the concatenated range
    """
    mask_weight = load_from_pickle(save_path)
    if mask_weight is not None:
        return mask_weight

    # Create the target directory up front so a bad path fails before the heavy work.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Contiguous masks, so that view(-1) exposes their elements in row-major order.
    mask_weight = {
        name: torch.zeros_like(weight, memory_format=torch.contiguous_format)
        for name, weight in base_model.state_dict().items()
        if "mlp" in name or "attn" in name
    }

    flat_indices = torch.as_tensor(indices, dtype=torch.long).flatten()
    total = sum(mask.numel() for mask in mask_weight.values())
    if flat_indices.numel() > 0:
        lowest, highest = int(flat_indices.min()), int(flat_indices.max())
        if lowest < 0 or highest >= total:
            raise ValueError(f"indices must lie in [0, {total}), got values in [{lowest}, {highest}]")

    # Map the flat indices onto each tensor in one vectorized step per tensor.
    offset = 0
    for mask in mask_weight.values():
        size = mask.numel()
        in_tensor = (flat_indices >= offset) & (flat_indices < offset + size)
        local = (flat_indices[in_tensor] - offset).to(mask.device)
        mask.view(-1)[local] = 1
        offset += size

    save_to_pickle(mask_weight, save_path)
    return mask_weight

def get_mask_neuron_level(args, save_paths: dict, base_model, n_params: dict):
    """Build (or load) neuron-level masks exclusive to the GA and GD models.

    Each model's top-``n_params[tag]`` flat importance indices are computed. mask_GA keeps
    indices selected for GA but not for GD, and mask_GD keeps the reverse.

    :param args: run configuration forwarded to ``get_importance_for_each_model``
    :param save_paths: paths dict. Reads "mask_GA", "mask_GD", "GA_model", "GD_model",
        "importances_GA", "importances_GD", and optionally "past_masks_paths_GA" and
        "past_masks_paths_GD" (used only when both are present).
    :param base_model: original model whose state_dict defines the mask layout
    :param n_params: parameter counts with keys 'GA', 'GD' (kept per model) and 'all'
    :return: mask_GA, mask_GD
    """
    mask_GA = load_from_pickle(save_paths["mask_GA"])
    mask_GD = load_from_pickle(save_paths["mask_GD"])
    if mask_GA is not None and mask_GD is not None:
        return mask_GA, mask_GD

    use_past_masks = "past_masks_paths_GA" in save_paths and "past_masks_paths_GD" in save_paths

    def _top_indices(tag):
        # Top indices for one model, excluding past tasks' selections when configured.
        importances = get_importance_for_each_model(args, save_paths[f"{tag}_model"], save_paths[f"importances_{tag}"])
        past = (save_paths[f"past_masks_paths_{tag}"],) if use_past_masks else ()
        return convert_importances_to_indices(importances, n_params[tag], n_params['all'], *past)

    indices_GA = np.array(_top_indices("GA"))
    indices_GD = np.array(_top_indices("GD"))

    # topk indices are unique, so assume_unique is safe and faster.
    indices_GA_unique = np.setdiff1d(indices_GA, indices_GD, assume_unique=True).tolist()
    indices_GD_unique = np.setdiff1d(indices_GD, indices_GA, assume_unique=True).tolist()

    mask_GA = calculate_mask_neuron_level(save_paths["mask_GA"], base_model, indices_GA_unique)
    mask_GD = calculate_mask_neuron_level(save_paths["mask_GD"], base_model, indices_GD_unique)

    return mask_GA, mask_GD

def get_mask_head_level(save_paths: dict, base_model, n_param_tops: dict, args=None):
    """Build (or load) attention-head-level masks for the GA and GD sides.

    :param save_paths: paths dict. Reads "GD_model", "importances_GD", "GA_model",
        "importances_GA", "mask_GA" and "mask_GD".
    :param base_model: model whose config supplies num_attention_heads and hidden_size
    :param n_param_tops: parameter budgets with keys 'GA' and 'GD'
    :param args: run configuration forwarded to ``get_importance_for_each_model``; only
        needed when the importances are not cached yet
    :return: mask_GA, mask_GD
    """
    n_heads = base_model.config.num_attention_heads
    hidden_size = base_model.config.hidden_size
    head_dim = hidden_size // n_heads

    # NOTE(review): GD-model importances are passed as the "safe" ranking, whose locations
    # become mask_GA, and GA-model importances become mask_GD. get_mask_neuron_level maps
    # GA importances to mask_GA instead. Confirm this swap is intended.
    importances_GD = get_importance_for_each_model(args, save_paths["GD_model"], save_paths["importances_GD"])
    importances_GA = get_importance_for_each_model(args, save_paths["GA_model"], save_paths["importances_GA"])
    safe_locations, unsafe_locations = calculate_head_level_importance(importances_GD, importances_GA, n_heads, hidden_size, n_param_tops)

    mask_GA = calculate_mask_attn_head_level(save_paths["mask_GA"], safe_locations, (hidden_size, n_heads, head_dim))
    mask_GD = calculate_mask_attn_head_level(save_paths["mask_GD"], unsafe_locations, (hidden_size, n_heads, head_dim))

    return mask_GA, mask_GD