import torch
import os
import json

@torch.no_grad()
def merge_lora_modules(model, lora_client_map, device="cuda", log_file=None):
    """Average the per-client LoRA A/B weights of each group into one module.

    For every group in ``lora_client_map`` and every projection type
    (``query`` and ``value``), the weights ``lora_A{cid}`` / ``lora_B{cid}``
    of the group's clients are averaged element-wise and written into
    ``lora_A{group_id}`` / ``lora_B{group_id}`` of the same parent module.

    State-dict keys are matched by dotted suffix, so both a bare key such as
    ``query.lora_A0.weight`` and a nested one such as
    ``encoder.layer.3.query.lora_A0.weight`` are found. Each parent module
    (i.e. each distinct key prefix) is averaged independently; weights from
    different layers are never mixed.

    Args:
        model: ``torch.nn.Module`` whose state dict holds the LoRA weights.
        lora_client_map: Mapping ``group_id -> iterable of client ids``.
            ``group_id`` must be convertible with ``int()``.
        device: Device on which the averaging is performed. The default is
            ``"cuda"``; pass ``"cpu"`` on machines without a GPU.
        log_file: Optional path. When given, log lines are appended to it
            (UTF-8); otherwise they are printed to stdout.

    Returns:
        The same ``model`` object, updated in place.

    Notes:
        A client contributes to a module only if both its A and B weights
        exist there. If the target ``lora_A{group_id}`` / ``lora_B{group_id}``
        does not exist in the model, nothing can be written; this is logged
        instead of being dropped silently.
    """
    log_msgs = ["=== Merging LoRA modules based on lora_client_map ==="]
    merged_state = model.state_dict()
    updated_state = {}

    # Candidate keys are loop-invariant: compute them once, not per group.
    lora_A_keys = [k for k in merged_state if "lora_A" in k]

    for group_idx, client_ids in lora_client_map.items():
        group_idx = int(group_idx)
        log_msgs.append(f"Processing group {group_idx} -> clients {client_ids}")

        for proj_type in ["query", "value"]:
            # key prefix (module path incl. trailing dot, or "") ->
            # ([A tensors], [B tensors]) collected from this group's clients.
            per_module = {}

            for cid in client_ids:
                a_suffix = f"{proj_type}.lora_A{cid}.weight"
                b_suffix = f"{proj_type}.lora_B{cid}.weight"
                dotted_a_suffix = "." + a_suffix

                for a_key in lora_A_keys:
                    # Require a full path component match so that e.g.
                    # "myquery.lora_A0.weight" is not mistaken for "query".
                    if a_key != a_suffix and not a_key.endswith(dotted_a_suffix):
                        continue
                    prefix = a_key[: len(a_key) - len(a_suffix)]
                    b_key = prefix + b_suffix
                    if b_key not in merged_state:
                        continue  # need the matching B to form a pair
                    a_list, b_list = per_module.setdefault(prefix, ([], []))
                    a_list.append(merged_state[a_key].to(device))
                    b_list.append(merged_state[b_key].to(device))

            for prefix, (a_list, b_list) in per_module.items():
                new_A_key = f"{prefix}{proj_type}.lora_A{group_idx}.weight"
                new_B_key = f"{prefix}{proj_type}.lora_B{group_idx}.weight"

                # load_state_dict(strict=False) would silently ignore keys the
                # model does not have; surface that in the log instead.
                if new_A_key not in merged_state or new_B_key not in merged_state:
                    log_msgs.append(
                        f"  ! Skipped {prefix}{proj_type} group {group_idx}: "
                        f"target lora_A{group_idx}/lora_B{group_idx} not found in model"
                    )
                    continue

                # Element-wise mean over the clients of this group.
                updated_state[new_A_key] = torch.mean(torch.stack(a_list), dim=0)
                updated_state[new_B_key] = torch.mean(torch.stack(b_list), dim=0)

                log_msgs.append(
                    f"  → Merged {prefix}{proj_type} group {group_idx}: "
                    f"{len(a_list)} modules fused"
                )

    # Only the averaged tensors changed; strict=False permits a partial load.
    if updated_state:
        model.load_state_dict(updated_state, strict=False)

    if log_file:
        # Explicit UTF-8: the log lines contain non-ASCII characters.
        with open(log_file, "a", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in log_msgs)
    else:
        for line in log_msgs:
            print(line)

    return model
