import json
from utils import *
import argparse

def _str2bool(value):
    """Convert a command-line string such as ``"false"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the usual spellings explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string is kept falsy, matching what bool('') returned before.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Create the argument parser
parser = argparse.ArgumentParser()

# Register the command-line arguments
parser.add_argument('--GA_model_path', type=str, help='GA模型路径')
parser.add_argument('--GD_model_path', type=str, help='GD模型路径')
parser.add_argument('--base_model_path', type=str, help='基础模型路径')
parser.add_argument('--result_model_dir', type=str, help='结果模型目录')
parser.add_argument('--cache_dir', type=str, help='缓存目录')
parser.add_argument('--granularity', type=str, help='粒度')
parser.add_argument('--TOP_RATE_GA', type=float, help='GA的TOP_RATE')
parser.add_argument('--TOP_RATE_GD', type=float, help='GD的TOP_RATE')
parser.add_argument('--DARE_DROP_RATE_GA', type=float, help='GA的DARE_DROP_RATE')
parser.add_argument('--DARE_DROP_RATE_GD', type=float, help='GD的DARE_DROP_RATE')
parser.add_argument('--SCALE_RATE_GA', type=float, help='GA的SCALE_RATE')
parser.add_argument('--SCALE_RATE_GD', type=float, help='GD的SCALE_RATE')
# type=bool would treat "--fp16 False" as True; parse the text explicitly.
parser.add_argument('--fp16', type=_str2bool, default=True)
parser.add_argument('--data_dir', type=str)
parser.add_argument('--max_source_length', type=int, default=128)
parser.add_argument('--max_target_length', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--sample_size', type=int, default=100)

# Parse the command-line arguments
args = parser.parse_args()

# Unpack the parsed options into module-level settings used by main().
cache_dir = args.cache_dir + '/'
granularity = args.granularity
TOP_RATE_GA, TOP_RATE_GD = args.TOP_RATE_GA, args.TOP_RATE_GD
DARE_DROP_RATE_GA, DARE_DROP_RATE_GD = args.DARE_DROP_RATE_GA, args.DARE_DROP_RATE_GD
SCALE_RATE_GA, SCALE_RATE_GD = args.SCALE_RATE_GA, args.SCALE_RATE_GD

# Mask builders selectable through --granularity.
METHODS = {
    'neuron': get_mask_neuron_level,
    'head': get_mask_head_level
}

# All mask files for one (granularity, TOP_RATE_GA, TOP_RATE_GD) combination
# live in the same cache sub-directory.
_mask_dir = cache_dir + granularity + '/' + f"TOP_RATE_GA-{TOP_RATE_GA}-TOP_RATE_GD-{TOP_RATE_GD}" + '/'

# The result directory name encodes every hyper-parameter as "<label>-<value>",
# joined by "-".
_result_tag = "-".join(
    f"{label}-{value}"
    for label, value in (
        ('granularity', granularity),
        ('TOP_RATE_GA', TOP_RATE_GA),
        ('TOP_RATE_GD', TOP_RATE_GD),
        ('DARE_DROP_RATE_GA', DARE_DROP_RATE_GA),
        ('DARE_DROP_RATE_GD', DARE_DROP_RATE_GD),
        ('SCALE_RATE_GA', SCALE_RATE_GA),
        ('SCALE_RATE_GD', SCALE_RATE_GD),
    )
)

save_paths = {
    'GA_model': args.GA_model_path,
    'GD_model': args.GD_model_path,
    'base_model': args.base_model_path,
    'mask_GA': _mask_dir + 'mask_GA.pkl',
    'mask_GD': _mask_dir + 'mask_GD.pkl',
    'mask_intersection': _mask_dir + 'mask_intersection.pkl',
    'importances_GA': cache_dir + 'importances_GA.pkl',
    'importances_GD': cache_dir + 'importances_GD.pkl',
    'result_model': f"{args.result_model_dir}/SNIP_DARE/{_result_tag}"
}

# Echo the resolved paths so a run can be traced from its log.
for key, path in save_paths.items():
    print(f"{key}: {path}")

# Resolve the mask builder last; an unknown granularity raises KeyError here.
get_mask = METHODS[granularity]
# torch.set_num_threads(100) 

def main():
    """Merge the masked GA and GD weight deltas into the base model and save it.

    All configuration is read from the module-level settings (``save_paths``,
    ``TOP_RATE_*``, ``DARE_DROP_RATE_*``, ``SCALE_RATE_*``, ``get_mask``).

    Steps:
        1. Load the base model and obtain the GA/GD masks from ``get_mask``.
        2. Print how many entries each mask keeps and the resulting remain
           rates relative to the total parameter count.
        3. For every entry of the base model's state dict, add
           ``SCALE_RATE_GA * (GA - base) * mask_GA`` and
           ``SCALE_RATE_GD * (GD - base) * mask_GD`` in place.
        4. Convert the merged model to half precision and save it, together
           with the base tokenizer, under ``save_paths['result_model']``.
    """
    # Load the base model and derive the per-branch parameter budgets.
    base_model = LlamaForCausalLM.from_pretrained(save_paths['base_model'])
    n_param_all = sum(p.numel() for p in base_model.parameters())
    n_param_top_GA = int(n_param_all * TOP_RATE_GA)
    n_param_top_GD = int(n_param_all * TOP_RATE_GD)
    print(f"n_param_all: {n_param_all}, n_param_top_GA: {n_param_top_GA}, n_param_top_GD: {n_param_top_GD}")

    n_param_tops = {'GA': n_param_top_GA, 'GD': n_param_top_GD, 'all': n_param_all}
    mask_GA, mask_GD = get_mask(args, save_paths, base_model, n_param_tops)

    # Sum of each mask tensor, keyed by the name with 'model.layers.' removed.
    remain_weight_GA = {
        name.replace('model.layers.', ''): mask.flatten().sum().item()
        for name, mask in mask_GA.items()
    }
    remain_weight_GD = {
        name.replace('model.layers.', ''): mask.flatten().sum().item()
        for name, mask in mask_GD.items()
    }

    safe_remain_rate = sum(remain_weight_GA.values()) / n_param_all
    # The GD rate must come from the GD mask (it previously reused the GA sums).
    unsafe_remain_rate = sum(remain_weight_GD.values()) / n_param_all

    print(f"safe_remain_rate: {safe_remain_rate}, unsafe_remain_rate: {unsafe_remain_rate}")

    print(f"remain_weight_GA: {remain_weight_GA}, remain_weight_GD: {remain_weight_GD}")

    # NOTE: these rates are only reported; the random DARE re-selection that
    # would consume them is not applied in the merge below.
    DARE_REMAIN_RATE_GA_ELSE = (1 - DARE_DROP_RATE_GA) - safe_remain_rate
    DARE_REMAIN_RATE_GD_ELSE = (1 - DARE_DROP_RATE_GD) - unsafe_remain_rate

    print(f"DARE_REMAIN_RATE_GA_ELSE: {DARE_REMAIN_RATE_GA_ELSE}, DARE_REMAIN_RATE_GD_ELSE: {DARE_REMAIN_RATE_GD_ELSE}")

    GA_model = LlamaForCausalLM.from_pretrained(save_paths['GA_model'])
    GD_model = LlamaForCausalLM.from_pretrained(save_paths['GD_model'])
    base_tokenizer = AutoTokenizer.from_pretrained(save_paths['base_model'])

    # Build each state dict once; calling state_dict() per access would
    # reconstruct the whole mapping on every lookup.
    base_state = base_model.state_dict()
    GA_state = GA_model.state_dict()
    GD_state = GD_model.state_dict()

    # Merge
    print("begin to merge")
    for name in tqdm(base_state):
        base_weight = base_state[name]
        merged = base_weight

        # An entry absent from a mask contributes no delta, so that branch is
        # skipped instead of multiplying by an all-zero mask.
        if name in mask_GA:
            safe_delta_weight = (GA_state[name] - base_weight) * mask_GA[name]
            merged = merged + SCALE_RATE_GA * safe_delta_weight

        if name in mask_GD:
            unsafe_delta_weight = (GD_state[name] - base_weight) * mask_GD[name]
            merged = merged + SCALE_RATE_GD * unsafe_delta_weight

        # Write back in place so the update lands in base_model's own tensor.
        if merged is not base_weight:
            base_weight.copy_(merged)

    # Save the result
    base_model = base_model.half()
    base_model.save_pretrained(save_paths["result_model"])
    base_tokenizer.save_pretrained(save_paths["result_model"])


if __name__ == '__main__':
    main()