import torch
import copy 

def get_proxy_dict(fed_args, global_dict):
    """Create server-side optimizer state for the configured federated algorithm.

    Args:
        fed_args: config object; ``fed_alg`` selects the algorithm and, for the
            FedOpt family, ``fedopt_tau`` sets the second-moment initial value.
        global_dict: mapping of parameter name -> tensor (the global model state).

    Returns:
        ``(proxy_dict, opt_proxy_dict)``:
        * ``'fedadagrad'`` / ``'fedyogi'`` / ``'fedadam'``: zeros shaped like each
          global tensor, and tensors filled with ``fedopt_tau ** 2``.
        * ``'fedavgm'``: zeros shaped like each global tensor, and ``None``.
        * any other algorithm: ``(None, None)``.
    """
    algorithm = fed_args.fed_alg

    if algorithm in ('fedadagrad', 'fedyogi', 'fedadam'):
        first_moment = {name: torch.zeros_like(tensor) for name, tensor in global_dict.items()}
        second_moment = {
            name: torch.ones_like(tensor) * fed_args.fedopt_tau**2
            for name, tensor in global_dict.items()
        }
        return first_moment, second_moment

    if algorithm == 'fedavgm':
        momentum = {name: torch.zeros_like(tensor) for name, tensor in global_dict.items()}
        return momentum, None

    return None, None

def _zeros_on_cpu(template):
    """Return a dict of zero tensors shaped like ``template``'s values, allocated on CPU."""
    return {key: torch.zeros_like(value, device='cpu') for key, value in template.items()}


def _lora_sketches(fed_args, global_dict, adapter, axis):
    """Draw one ``torch.randn(size, fed_args.omega_rsvd)`` matrix per matching LoRA weight.

    Args:
        fed_args: config object; ``omega_rsvd`` (sketch width) is read only when
            at least one matching parameter exists.
        global_dict: mapping of parameter name -> tensor.
        adapter: substring selecting the weights to sketch (``'lora_A'`` or ``'lora_B'``).
        axis: which dimension of the 2-D weight gives the sketch's row count
            (1 -> number of columns, 0 -> number of rows).

    Returns:
        Dict keyed by the parameter name with ``'.<adapter>.weight'`` stripped.
    """
    suffix = '.' + adapter + '.weight'
    sketches = {}
    for name, param in global_dict.items():
        if adapter not in name:
            continue
        n_rows, n_cols = param.shape  # raises ValueError unless the weight is 2-D
        size = n_cols if axis == 1 else n_rows
        # Drawn in global_dict iteration order, so results are reproducible under a fixed seed.
        sketches[name.replace(suffix, '')] = torch.randn(size, fed_args.omega_rsvd)
    return sketches


def get_auxiliary_dict(fed_args, global_dict):
    """Build the global and per-client auxiliary state for ``fed_args.fed_alg``.

    * ``'scaffold'``: zero tensors shaped like every entry of ``global_dict``.
      These are the SCAFFOLD variables c (global), c_i and delta c_i (per client).
      The global copy uses ``zeros_like`` defaults; the per-client copies are
      allocated on CPU to save GPU memory.
    * ``'fedavg_lora'``: for every ``lora_A`` weight of shape (r, k), a random
      (k, ``omega_rsvd``) matrix.
    * ``'fedela'``: for every ``lora_B`` weight of shape (m, r), a random
      (m, ``omega_rsvd``) matrix.
      For both LoRA variants the per-client model list is all ``None``. Each
      client's delta entry is an independent deep copy of the global sketches.
      NOTE(review): ``omega_rsvd`` suggests these are randomized-SVD test
      matrices — confirm against the consumer.
    * anything else: no auxiliary state.

    Args:
        fed_args: config object providing ``fed_alg`` and ``num_clients``
            (plus ``omega_rsvd`` for the LoRA variants).
        global_dict: mapping of parameter name -> tensor (the global model state).

    Returns:
        ``(global_auxiliary, auxiliary_model_list, auxiliary_delta_dict)``.
        Both lists have ``fed_args.num_clients`` entries; ``global_auxiliary``
        is ``None`` for unsupported algorithms.
    """
    algorithm = fed_args.fed_alg
    num_clients = fed_args.num_clients

    if algorithm == 'scaffold':
        global_auxiliary = {key: torch.zeros_like(value) for key, value in global_dict.items()}
        # Per-client state is offloaded to CPU to preserve GPU memory.
        auxiliary_model_list = [_zeros_on_cpu(global_auxiliary) for _ in range(num_clients)]
        auxiliary_delta_dict = [_zeros_on_cpu(global_auxiliary) for _ in range(num_clients)]
        return global_auxiliary, auxiliary_model_list, auxiliary_delta_dict

    # (adapter substring, weight axis that sets the sketch's row count)
    sketch_spec = {'fedavg_lora': ('lora_A', 1), 'fedela': ('lora_B', 0)}.get(algorithm)
    if sketch_spec is not None:
        adapter, axis = sketch_spec
        global_auxiliary = _lora_sketches(fed_args, global_dict, adapter, axis)
        auxiliary_model_list = [None] * num_clients
        # Deep copies so each client can update its sketches independently.
        auxiliary_delta_dict = [copy.deepcopy(global_auxiliary) for _ in range(num_clients)]
        return global_auxiliary, auxiliary_model_list, auxiliary_delta_dict

    return None, [None] * num_clients, [None] * num_clients


def freeze_lora_A(model):
    """Turn off gradients for every parameter of ``model`` whose name contains ``"lora_A"``.

    The model is modified in place; nothing is returned.
    """
    lora_a_params = (tensor for pname, tensor in model.named_parameters() if "lora_A" in pname)
    for tensor in lora_a_params:
        tensor.requires_grad = False