from torch.optim import AdamW
import deepspeed
from deepspeed.runtime.activation_checkpointing import checkpointing
import re
from deepspeed.ops.adam import DeepSpeedCPUAdam




def get_lora_param_groups_cpuOpti(
    model,
    *,
    lora_lr=4e-5,
    base_lr=1e-7,
    lora_weight_decay=0.0,
    base_weight_decay=0.01,
):
    """Build a ``DeepSpeedCPUAdam`` optimizer with separate LoRA / non-LoRA groups.

    Despite its name, this function returns an optimizer instance, not a list
    of parameter groups.

    Every parameter yielded by ``model.named_parameters()`` is placed in one of
    two groups. A parameter goes into the "LoRA" group when its dotted name
    contains one of the components ``lora_A``, ``lora_B``, ``lora_embedding``,
    ``adapter``, ``q_proj`` or ``v_proj`` with at least one component before it
    and one after it (e.g. ``layers.0.q_proj.lora_A.weight``). Everything else
    goes into the second group.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        lora_lr: Learning rate of the LoRA group.
        base_lr: Learning rate of the non-LoRA group.
        lora_weight_decay: Weight decay of the LoRA group.
        base_weight_decay: Weight decay of the non-LoRA group.

    Returns:
        A ``DeepSpeedCPUAdam`` instance configured with the two groups, LoRA
        group first.
    """
    # `[AB]`, not `[A|B]`: inside a character class `|` is a literal pipe, so
    # the old pattern also accepted a component spelled `lora_|`.
    # NOTE(review): if the adapters come from HF PEFT, embedding adapters are
    # presumably named `lora_embedding_A` / `lora_embedding_B`, which the
    # `lora_embedding` alternative does not match -- confirm.
    lora_pattern = re.compile(
        r".*\.(lora_[AB]|lora_embedding|adapter|q_proj|v_proj)\..*"
    )

    lora_params = []
    non_lora_params = []
    for name, param in model.named_parameters():
        if lora_pattern.fullmatch(name):
            lora_params.append(param)
        else:
            non_lora_params.append(param)

    return DeepSpeedCPUAdam(
        model_params=[
            # LoRA parameters
            {
                "params": lora_params,
                "lr": lora_lr,
                "weight_decay": lora_weight_decay,
            },
            # non-LoRA parameters
            {
                "params": non_lora_params,
                "lr": base_lr,
                "weight_decay": base_weight_decay,
            },
        ]
    )

def get_lora_param_groups(
    model,
    *,
    lora_lr=4e-5,
    base_lr=1e-7,
    lora_weight_decay=0.0,
    base_weight_decay=0.01,
):
    """Build a torch ``AdamW`` optimizer with separate LoRA / non-LoRA groups.

    Despite its name, this function returns an optimizer instance, not a list
    of parameter groups.

    Every parameter yielded by ``model.named_parameters()`` is placed in one of
    two groups. A parameter goes into the "LoRA" group when its dotted name
    contains one of the components ``lora_A``, ``lora_B``, ``lora_embedding``,
    ``adapter``, ``q_proj`` or ``v_proj`` with at least one component before it
    and one after it (e.g. ``layers.0.q_proj.lora_A.weight``). Everything else
    goes into the second group.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        lora_lr: Learning rate of the LoRA group.
        base_lr: Learning rate of the non-LoRA group.
        lora_weight_decay: Weight decay of the LoRA group.
        base_weight_decay: Weight decay of the non-LoRA group.

    Returns:
        An ``AdamW`` instance whose ``param_groups[0]`` is the LoRA group and
        ``param_groups[1]`` is the non-LoRA group.
    """
    # `[AB]`, not `[A|B]`: inside a character class `|` is a literal pipe, so
    # the old pattern also accepted a component spelled `lora_|`.
    # NOTE(review): if the adapters come from HF PEFT, embedding adapters are
    # presumably named `lora_embedding_A` / `lora_embedding_B`, which the
    # `lora_embedding` alternative does not match -- confirm.
    lora_pattern = re.compile(
        r".*\.(lora_[AB]|lora_embedding|adapter|q_proj|v_proj)\..*"
    )

    lora_params = []
    non_lora_params = []
    for name, param in model.named_parameters():
        if lora_pattern.fullmatch(name):
            lora_params.append(param)
        else:
            non_lora_params.append(param)

    return AdamW(
        [
            # LoRA parameters
            {
                "params": lora_params,
                "lr": lora_lr,
                "weight_decay": lora_weight_decay,
            },
            # non-LoRA parameters
            {
                "params": non_lora_params,
                "lr": base_lr,
                "weight_decay": base_weight_decay,
            },
        ]
    )

def wrap_distributed_model(args, train_dataset, collate_fn, model, ds_config):
    """Initialise DeepSpeed for ``model`` and return the wrapped training objects.

    Args:
        args: Namespace-like object. Reads ``use_optimizer``,
            ``activation_ckpt`` and, when activation checkpointing is enabled,
            ``deepspeed_config``.
        train_dataset: Forwarded to ``deepspeed.initialize`` as ``training_data``.
        collate_fn: Forwarded to ``deepspeed.initialize`` as ``collate_fn``.
        model: Model to wrap.
        ds_config: Forwarded to ``deepspeed.initialize`` as ``config``.

    Returns:
        ``(train_loader, engine, optimizer, scheduler)``. Note that this order
        differs from the tuple returned by ``deepspeed.initialize``.
    """
    init_kwargs = {
        "model": model,
        "training_data": train_dataset,
        "collate_fn": collate_fn,
        "config": ds_config,
    }
    # A truthy `use_optimizer` means no optimizer is handed to DeepSpeed
    # (presumably it is then built from `ds_config`); otherwise the two-group
    # CPU Adam optimizer is supplied explicitly.
    if not args.use_optimizer:
        init_kwargs["optimizer"] = get_lora_param_groups_cpuOpti(model)

    engine, optimizer, train_loader, scheduler = deepspeed.initialize(**init_kwargs)

    if args.activation_ckpt:
        # NOTE(review): this reads `args.deepspeed_config`, not `ds_config` --
        # confirm both refer to the same configuration.
        checkpointing.configure(None, deepspeed_config=args.deepspeed_config)

    return train_loader, engine, optimizer, scheduler

def print_trainable_parameters(model):
    """Print the trainable and total parameter counts of ``model``.

    One line is written to stdout with the number of trainable elements
    (parameters whose ``requires_grad`` is true), the total number of
    elements, and the trainable share as a percentage with two decimals.
    The labels in the printed line are in Chinese.

    Args:
        model: Object exposing ``parameters()`` whose items provide ``numel()``
            and ``requires_grad`` (e.g. a ``torch.nn.Module``).
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A model without parameters would otherwise raise ZeroDivisionError;
    # report 0.00% in that case.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0

    print(
        f"可训练参数数量: {trainable_params} || "
        f"总参数数量: {all_param} || "
        f"可训练参数占比: {trainable_pct:.2f}%"
    )