from utils import *
import argparse

# Command-line interface and derived configuration. This runs at import time:
# the globals defined here (args, the rate constants, save_paths, get_mask) are
# read by main(), and args/save_paths are also passed to the mask helper.


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

    argparse's ``type=bool`` treats every non-empty string (including 'False')
    as True, so an explicit parser is needed for a switchable flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _split_paths(value):
    """Split a ',,'-separated list of paths; an omitted argument yields []."""
    return [] if value is None else value.split(',,')


# Mask-selection strategy for each --granularity value.
METHODS = {
    'neuron': get_mask_neuron_level,
    'head': get_mask_head_level
}

# Create the argument parser.
parser = argparse.ArgumentParser()

# Register command-line arguments. Everything main() cannot run without is
# required, so a missing flag fails here with a usage message instead of a
# TypeError/KeyError (or a silent "None/..." output path) later on.
parser.add_argument('--GA_model_path', type=str, required=True, help='GA模型路径')
parser.add_argument('--GD_model_path', type=str, required=True, help='GD模型路径')
parser.add_argument('--base_model_path', type=str, required=True, help='基础模型路径')
parser.add_argument('--result_model_dir', type=str, required=True, help='结果模型目录')
parser.add_argument('--cache_dir', type=str, required=True, help='缓存目录')
parser.add_argument('--granularity', type=str, required=True, choices=sorted(METHODS), help='粒度')
parser.add_argument('--TOP_RATE_GA', type=float, required=True, help='GA的TOP_RATE')
parser.add_argument('--TOP_RATE_GD', type=float, required=True, help='GD的TOP_RATE')
parser.add_argument('--DARE_DROP_RATE_GA', type=float, required=True, help='GA的DARE_DROP_RATE')
parser.add_argument('--DARE_DROP_RATE_GD', type=float, required=True, help='GD的DARE_DROP_RATE')
parser.add_argument('--SCALE_RATE_GA', type=float, required=True, help='GA的SCALE_RATE')
parser.add_argument('--SCALE_RATE_GD', type=float, required=True, help='GD的SCALE_RATE')
parser.add_argument('--fp16', type=_str2bool, default=True)
parser.add_argument('--data_dir', type=str)
parser.add_argument('--max_source_length', type=int, default=128)
parser.add_argument('--max_target_length', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--seed', type=int, default=1234)
parser.add_argument('--sample_size', type=int, default=100)
parser.add_argument('--past_masks_paths_GA', type=str)
parser.add_argument('--past_masks_paths_GD', type=str)

# Parse the command-line arguments.
args = parser.parse_args()

# Expose the values as module-level constants used by main().
cache_dir = args.cache_dir + '/'
granularity = args.granularity
TOP_RATE_GA = args.TOP_RATE_GA
TOP_RATE_GD = args.TOP_RATE_GD
DARE_DROP_RATE_GA = args.DARE_DROP_RATE_GA
DARE_DROP_RATE_GD = args.DARE_DROP_RATE_GD
SCALE_RATE_GA = args.SCALE_RATE_GA
SCALE_RATE_GD = args.SCALE_RATE_GD

# Masks are cached per granularity and per (TOP_RATE_GA, TOP_RATE_GD) pair.
_mask_cache_dir = f"{cache_dir}{granularity}/TOP_RATE_GA-{TOP_RATE_GA}-TOP_RATE_GD-{TOP_RATE_GD}/"

save_paths = {
    'GA_model': args.GA_model_path,
    'GD_model': args.GD_model_path,
    'base_model': args.base_model_path,
    'mask_GA': _mask_cache_dir + 'mask_GA.pkl',
    'mask_GD': _mask_cache_dir + 'mask_GD.pkl',
    'importances_GA': cache_dir + 'importances_GA.pkl',
    'importances_GD': cache_dir + 'importances_GD.pkl',
    'result_model': f"{args.result_model_dir}/SNIP_DARE/granularity-{granularity}-TOP_RATE_GA-{TOP_RATE_GA}-TOP_RATE_GD-{TOP_RATE_GD}-DARE_DROP_RATE_GA-{DARE_DROP_RATE_GA}-DARE_DROP_RATE_GD-{DARE_DROP_RATE_GD}-SCALE_RATE_GA-{SCALE_RATE_GA}-SCALE_RATE_GD-{SCALE_RATE_GD}",
    'past_masks_paths_GA': _split_paths(args.past_masks_paths_GA),
    'past_masks_paths_GD': _split_paths(args.past_masks_paths_GD),
}

for key, value in save_paths.items():
    print(f"{key}: {value}")

# --granularity is restricted to METHODS' keys by argparse, so this lookup cannot fail.
get_mask = METHODS[granularity]

def _mask_density(masks, n_total):
    """Return the fraction of ``n_total`` parameters selected by ``masks``.

    Masks are assumed to be binary (0/1 or bool) — TODO confirm against
    get_mask. Counting non-zeros avoids the overflow a float16 ``.sum()``
    could hit on large tensors.
    """
    n_selected = sum(int(torch.count_nonzero(m)) for m in masks.values())
    return n_selected / n_total


def _extra_keep_prob(drop_rate, top_rate):
    """Return the per-element probability of randomly keeping an extra parameter.

    Parameters in the importance mask (fraction ``top_rate`` of all parameters)
    are always kept. The rest of the keep budget, ``(1 - drop_rate) - top_rate``,
    has to come from the other ``1 - top_rate`` fraction. The per-element draw
    probability is therefore ``budget / (1 - top_rate)``, which makes the
    expected overall keep rate equal the ``1 - drop_rate`` used by the DARE
    rescaling. Returns 0.0 when the mask alone already meets the budget.
    Expects ``0 <= drop_rate < 1``.
    """
    budget = (1 - drop_rate) - top_rate
    if budget <= 0:
        return 0.0
    # budget > 0 and drop_rate >= 0 imply 0 < budget <= 1 - top_rate.
    return budget / (1 - top_rate)


def _dare_delta(finetuned, base, mask, keep_prob, drop_rate):
    """Return the DARE-sparsified, rescaled delta ``finetuned - base``.

    ``mask`` is updated in place: each entry is additionally set to 1 with
    probability ``keep_prob``. The masked delta is divided by
    ``1 - drop_rate``, as in DARE.
    """
    random_tensor = torch.rand(mask.shape, device=mask.device)
    # Strict "<" gives P(keep) == keep_prob exactly; with "<=", keep_prob == 0
    # could still keep elements because torch.rand may return 0.0.
    mask[random_tensor < keep_prob] = 1
    return (finetuned - base) * mask / (1 - drop_rate)


def main():
    """Merge the GA and GD fine-tuned models into the base model and save it.

    Steps:
    1. Load the base model and size the GA/GD top-parameter budgets from
       ``TOP_RATE_GA`` / ``TOP_RATE_GD``.
    2. Obtain the GA/GD importance masks through ``get_mask`` (chosen by
       ``--granularity``).
    3. For every state-dict tensor, extend each mask with randomly kept
       entries (DARE) so the expected overall keep rate is
       ``1 - DARE_DROP_RATE_*``. Then add
       ``SCALE_RATE_GA * safe_delta + SCALE_RATE_GD * unsafe_delta`` to the
       base weights in place.
    4. Convert the merged model to float16 and save it, together with the base
       tokenizer, to ``save_paths['result_model']``.

    Raises:
        ValueError: if a DARE drop rate lies outside ``[0, 1)``. A rate of 1
            would divide every delta by zero.
    """
    for tag, drop_rate in (('GA', DARE_DROP_RATE_GA), ('GD', DARE_DROP_RATE_GD)):
        if not 0 <= drop_rate < 1:
            raise ValueError(f"DARE_DROP_RATE_{tag} must be in [0, 1), got {drop_rate}")

    # Load the base model and compute how many parameters each top mask should cover.
    base_model = LlamaForCausalLM.from_pretrained(save_paths['base_model'])
    n_param_all = sum(p.numel() for p in base_model.parameters())
    n_param_top_GA = int(n_param_all * TOP_RATE_GA)
    n_param_top_GD = int(n_param_all * TOP_RATE_GD)
    print(f"n_param_all: {n_param_all}, n_param_top_GA: {n_param_top_GA}, n_param_top_GD: {n_param_top_GD}")

    n_params = {'GA': n_param_top_GA, 'GD': n_param_top_GD, 'all': n_param_all}
    mask_GA, mask_GD = get_mask(args, save_paths, base_model, n_params)

    # Fraction of all parameters already kept by the importance masks.
    safe_remain_rate = _mask_density(mask_GA, n_param_all)
    unsafe_remain_rate = _mask_density(mask_GD, n_param_all)
    print(f"safe_remain_rate: {safe_remain_rate}, unsafe_remain_rate: {unsafe_remain_rate}")

    # Extra fraction of ALL parameters DARE still has to keep beyond the masks.
    remain_rate_GA_else = (1 - DARE_DROP_RATE_GA) - safe_remain_rate
    remain_rate_GD_else = (1 - DARE_DROP_RATE_GD) - unsafe_remain_rate
    print(f"DARE_REMAIN_RATE_GA_ELSE: {remain_rate_GA_else}, DARE_REMAIN_RATE_GD_ELSE: {remain_rate_GD_else}")

    # Per-element draw probabilities that realise those extra fractions.
    keep_prob_GA = _extra_keep_prob(DARE_DROP_RATE_GA, safe_remain_rate)
    keep_prob_GD = _extra_keep_prob(DARE_DROP_RATE_GD, unsafe_remain_rate)
    print(f"keep_prob_GA: {keep_prob_GA}, keep_prob_GD: {keep_prob_GD}")

    GA_model = LlamaForCausalLM.from_pretrained(save_paths['GA_model'])
    GD_model = LlamaForCausalLM.from_pretrained(save_paths['GD_model'])
    base_tokenizer = AutoTokenizer.from_pretrained(save_paths['base_model'])

    # state_dict() builds a fresh dict on every call, so fetch each one once.
    # Its tensors share storage with the model parameters, so copy_ below
    # updates base_model in place.
    base_state = base_model.state_dict()
    ga_state = GA_model.state_dict()
    gd_state = GD_model.state_dict()

    # Re-seed right before sampling. The DARE draws then depend only on --seed,
    # not on whatever randomness get_mask may have consumed.
    torch.manual_seed(args.seed)

    print("begin to merge")
    with torch.no_grad():
        for name, base_weight in tqdm(base_state.items()):
            # TODO: should the top-masked part get its own scaling rate?
            # Tensors without a mask start empty and are kept only by DARE sampling.
            if name not in mask_GA:
                mask_GA[name] = torch.zeros_like(base_weight)
            if name not in mask_GD:
                mask_GD[name] = torch.zeros_like(base_weight)

            safe_delta_weight = _dare_delta(
                ga_state[name], base_weight, mask_GA[name], keep_prob_GA, DARE_DROP_RATE_GA)
            unsafe_delta_weight = _dare_delta(
                gd_state[name], base_weight, mask_GD[name], keep_prob_GD, DARE_DROP_RATE_GD)

            base_weight.copy_(
                base_weight + SCALE_RATE_GA * safe_delta_weight + SCALE_RATE_GD * unsafe_delta_weight)

    # Save the merged weights in float16 together with the base tokenizer.
    base_model = base_model.half()
    base_model.save_pretrained(save_paths["result_model"])
    base_tokenizer.save_pretrained(save_paths["result_model"])


if __name__ == '__main__':
    main()