import random
import torch
import copy
import pdb

# Client selection: choose which clients participate in a given round.
def get_clients_this_round(fed_args, round):
    """Return the list of client indices that participate in this round.

    Args:
        fed_args: Object exposing ``fed_alg`` (str) and, for non-local
            algorithms, ``num_clients`` and ``sample_clients`` (ints).
        round: Round index; used as the seed of the global ``random`` module
            so that the sampled clients are reproducible per round.

    Returns:
        A list of client indices. For ``fed_alg`` values starting with
        ``'local'`` it contains the single client index encoded by the
        trailing digits of the name (e.g. ``'local12'`` -> ``[12]``).
        Otherwise it is either every client (when fewer clients exist than
        requested) or a sorted random sample of ``sample_clients`` clients.

    Raises:
        ValueError: If a ``'local*'`` algorithm name has no trailing digits.
    """
    fed_alg = fed_args.fed_alg
    if fed_alg.startswith('local'):
        # Parse the whole trailing digit run, not just the last character,
        # so that client indices >= 10 are handled correctly.
        client_suffix = fed_alg[len(fed_alg.rstrip('0123456789')):]
        if not client_suffix:
            raise ValueError(
                f"cannot parse a client index from fed_alg={fed_alg!r}"
            )
        return [int(client_suffix)]

    if fed_args.num_clients < fed_args.sample_clients:
        # Not enough clients to sample from: everybody participates.
        return list(range(fed_args.num_clients))

    # Seeding the module-level RNG is kept on purpose: it makes the selection
    # deterministic per round (and is a side effect callers may rely on).
    random.seed(round)
    return sorted(random.sample(range(fed_args.num_clients), fed_args.sample_clients))

# Differential privacy (DP): Gaussian noise injection into model parameters.
def add_noise_to_model(local_dict, sigma):
    """Return a copy of ``local_dict`` with Gaussian noise added to each tensor.

    Args:
        local_dict: Mapping from parameter name to ``torch.Tensor``.
        sigma: Standard deviation of the zero-mean Gaussian noise. A falsy
            value (``None`` or ``0``) disables noising.

    Returns:
        ``local_dict`` itself when ``sigma`` is falsy; otherwise a new dict
        with the same keys whose values are ``value + noise``. The input
        tensors are not modified.
    """
    if not sigma:
        return local_dict

    noisy_dict = {}
    for key, value in local_dict.items():
        # Noise is sampled with the tensor's shape, directly on its device.
        noise = torch.normal(0, sigma, size=value.size(), device=value.device)
        noisy_dict[key] = value + noise
    return noisy_dict

# Helpers used to test a specific approximation (see the functions below).
def get_delta_W(param_dict, sample_num=1, sample_this_round=1):
    """Compute the scaled LoRA product ``B @ A`` for every LoRA pair.

    For each key containing ``".lora_A.weight"`` the matching
    ``".lora_B.weight"`` entry is looked up and the product
    ``B @ A * sample_num / sample_this_round`` is stored under the key with
    the ``".lora_A.weight"`` part removed.

    Args:
        param_dict: Mapping from parameter name to tensor; every
            ``*.lora_A.weight`` key must have a ``*.lora_B.weight`` partner.
        sample_num: Numerator of the scaling factor (default 1).
        sample_this_round: Denominator of the scaling factor (default 1).

    Returns:
        Dict mapping the module prefix to its scaled ``B @ A`` tensor. Keys
        that do not contain ``".lora_A.weight"`` are ignored.

    Raises:
        KeyError: If a ``lora_A`` key has no ``lora_B`` counterpart.
    """
    lora_a_tag = ".lora_A.weight"
    lora_b_tag = ".lora_B.weight"

    res = {}
    for key, lora_a in param_dict.items():
        if lora_a_tag not in key:
            continue
        lora_b = param_dict[key.replace(lora_a_tag, lora_b_tag)]
        # Previously the product was assigned to an unused temporary, so the
        # function always returned an empty dict.
        res[key.replace(lora_a_tag, "")] = lora_b @ lora_a * sample_num / sample_this_round
    return res

def cosine_similarity_weights_simplified(model1, model2):
    """Return the cosine similarity between two parameter dicts as a float.

    The values of each dict are flattened and concatenated (in dict order)
    into one vector per model; the result is ``dot(v1, v2) / (|v1| * |v2|)``.
    Both dicts must therefore yield vectors of equal total length.
    """
    def _as_vector(state):
        # One long 1-D tensor made of every value, in iteration order.
        return torch.cat([tensor.flatten() for tensor in state.values()])

    vec_a = _as_vector(model1)
    vec_b = _as_vector(model2)

    denominator = torch.norm(vec_a, p=2) * torch.norm(vec_b, p=2)
    return (torch.dot(vec_a, vec_b) / denominator).item()
# End of the helpers used to test a specific approximation.

def _weighted_average(dict_list, key, sample_num_list, clients_this_round, sample_this_round):
    """Return the sample-count-weighted sum of ``dict_list[client][key]`` over the round's clients."""
    return sum(
        dict_list[client][key] * sample_num_list[client] / sample_this_round
        for client in clients_this_round
    )


def _offload_tensors_to_cpu(tensor_dict):
    """Move, in place, every non-CPU tensor value of ``tensor_dict`` to the CPU."""
    for key in tensor_dict.keys():
        value = tensor_dict[key]
        if isinstance(value, torch.Tensor) and value.device.type != 'cpu':
            tensor_dict[key] = value.to('cpu')


def global_aggregate(fed_args, global_dict, local_dict_list, sample_num_list, clients_this_round, round_idx, proxy_dict=None, opt_proxy_dict=None, auxiliary_info=None):
    """Aggregate the clients' local parameters into the global parameters.

    The aggregation rule is selected by ``fed_args.fed_alg``:

    * ``'scaffold'``: sample-weighted average of the local dicts, plus an
      update of the global auxiliary dict by the summed client deltas divided
      by ``fed_args.num_clients``; afterwards the participating clients'
      local dicts and auxiliary deltas are moved to the CPU.
    * ``'fedavgm'``: sample-weighted delta with server momentum
      (``fed_args.fedopt_beta1``) kept in ``proxy_dict``.
    * ``'fedadagrad'`` / ``'fedyogi'`` / ``'fedadam'``: adaptive server
      updates over the keys of ``opt_proxy_dict`` using ``fedopt_eta``,
      ``fedopt_tau`` and (where applicable) ``fedopt_beta1`` /
      ``fedopt_beta2``.
    * ``'flora'``: concatenates the clients' ``lora_A`` entries (scaled by
      their sample weight) along dim 0 and every other entry along dim 1.
    * ``'fedsa_lora'``: sample-weighted average of the ``lora_A`` entries
      only; other entries of ``global_dict`` are left untouched.
    * ``'fedavg_lora'``: two-stage procedure driven by
      ``auxiliary_info = (global_auxiliary, auxiliary_local_dict, agg_stage,
      lora_rank)``; stage ``"comm_1"`` returns the Q factors of a QR
      decomposition, stage ``"comm_2"`` rebuilds ``lora_B`` / ``lora_A`` from
      a truncated SVD.
    * ``'fedela'``: sample-weighted average of the ``lora_B`` entries,
      returned as the auxiliary result; ``global_dict`` is left untouched.
    * anything else: plain sample-weighted average of all entries.

    Args:
        fed_args: Object exposing ``fed_alg`` and the algorithm-specific
            hyper-parameters listed above.
        global_dict: Global parameter dict; updated in place and returned.
        local_dict_list: Sequence indexed by client id of local parameter dicts.
        sample_num_list: Sequence indexed by client id of sample counts.
        clients_this_round: Client ids taking part in this round.
        round_idx: Current round index; momentum is skipped when it is 0.
        proxy_dict: Momentum buffers (needed by the fedavgm/fedopt branches).
        opt_proxy_dict: Second-moment buffers (needed by the fedopt branches).
        auxiliary_info: Algorithm-specific tuple (``scaffold`` and
            ``fedavg_lora`` only).

    Returns:
        Tuple ``(global_dict, global_auxiliary)``; ``global_auxiliary`` is
        ``None`` for algorithms that do not produce auxiliary state.
    """
    sample_this_round = sum(sample_num_list[client] for client in clients_this_round)
    global_auxiliary = None

    if fed_args.fed_alg == 'scaffold':
        for key in global_dict.keys():
            global_dict[key] = _weighted_average(local_dict_list, key, sample_num_list, clients_this_round, sample_this_round)
        global_auxiliary, auxiliary_delta_dict = auxiliary_info
        for key in global_auxiliary.keys():
            delta_auxiliary = sum([auxiliary_delta_dict[client][key] for client in clients_this_round])
            # In-place update: the caller's auxiliary tensors are modified.
            global_auxiliary[key] += delta_auxiliary / fed_args.num_clients
        # Offloading for scaffold: move the consumed client tensors to the CPU.
        # (A leftover ``breakpoint()`` that halted every scaffold round was
        # removed from the end of this branch.)
        for client_idx in clients_this_round:
            _offload_tensors_to_cpu(local_dict_list[client_idx])
            _offload_tensors_to_cpu(auxiliary_delta_dict[client_idx])

    elif fed_args.fed_alg == 'fedavgm':
        # Momentum-based FedAvg
        for key in global_dict.keys():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) * sample_num_list[client] / sample_this_round for client in clients_this_round])
            # On the first round there is no momentum yet: use the raw delta.
            proxy_dict[key] = fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w if round_idx > 0 else delta_w
            global_dict[key] = global_dict[key] + proxy_dict[key]

    elif fed_args.fed_alg == 'fedadagrad':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            # In paper 'adaptive federated optimization', momentum is not used
            proxy_dict[key] = delta_w
            opt_proxy_dict[key] = param + torch.square(proxy_dict[key])
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key])+fed_args.fedopt_tau)

    elif fed_args.fed_alg == 'fedyogi':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            proxy_dict[key] = fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w if round_idx > 0 else delta_w
            delta_square = torch.square(proxy_dict[key])
            opt_proxy_dict[key] = param - (1-fed_args.fedopt_beta2)*delta_square*torch.sign(param - delta_square)
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key])+fed_args.fedopt_tau)

    elif fed_args.fed_alg == 'fedadam':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            proxy_dict[key] = fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w if round_idx > 0 else delta_w
            opt_proxy_dict[key] = fed_args.fedopt_beta2*param + (1-fed_args.fedopt_beta2)*torch.square(proxy_dict[key])
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key])+fed_args.fedopt_tau)

    elif fed_args.fed_alg == 'flora':
        # Stack A and B respectively.
        for key in global_dict.keys():
            if 'lora_A' in key:
                global_dict[key] = torch.cat([local_dict_list[client][key] * sample_num_list[client] / sample_this_round for client in clients_this_round], dim=0)
            else:
                global_dict[key] = torch.cat([local_dict_list[client][key] for client in clients_this_round], dim=1)
        # NOTE(review): the result is discarded; the call is kept so that any
        # error it raises for the stacked matrices still surfaces here.
        get_delta_W(global_dict)

    elif fed_args.fed_alg == 'fedsa_lora':
        # Only aggregate the updates of the lora_A matrices.
        for key in global_dict.keys():
            if 'lora_A' in key:  # only lora_A matrix parameters are processed
                global_dict[key] = _weighted_average(local_dict_list, key, sample_num_list, clients_this_round, sample_this_round)

    elif fed_args.fed_alg == 'fedavg_lora':
        global_auxiliary, auxiliary_local_dict, agg_stage, lora_rank = auxiliary_info
        if agg_stage == "comm_1":
            # Random projection and get Q.
            global_Q = {}
            for key in global_auxiliary.keys():
                Y = _weighted_average(auxiliary_local_dict, key, sample_num_list, clients_this_round, sample_this_round)
                global_Q[key], _ = torch.linalg.qr(Y)
            global_auxiliary = copy.deepcopy(global_Q)

        elif agg_stage == "comm_2":
            # Revise projection and get final SVD.
            for key, Q in global_auxiliary.items():
                P = _weighted_average(auxiliary_local_dict, key, sample_num_list, clients_this_round, sample_this_round)
                U, S, Vh = torch.linalg.svd(P.T, full_matrices=False)

                # Split the singular values evenly between the two factors.
                global_dict[f"{key}.lora_B.weight"] = Q @ (U[:, :lora_rank] * torch.sqrt(S[:lora_rank]))
                global_dict[f"{key}.lora_A.weight"] = torch.sqrt(S[:lora_rank]).unsqueeze(1) * Vh[:lora_rank, :]

    elif fed_args.fed_alg == 'fedela':
        global_B = {}
        for key in global_dict.keys():
            if "lora_B" in key:
                global_B[key] = _weighted_average(local_dict_list, key, sample_num_list, clients_this_round, sample_this_round)

        global_auxiliary = copy.deepcopy(global_B)

    else:   # Normal dataset-size-based aggregation
        # Optional DP (currently disabled): perturb each participating local
        # dict with add_noise_to_model(..., fed_args.dp_sigma) before averaging.
        for key in global_dict.keys():
            global_dict[key] = _weighted_average(local_dict_list, key, sample_num_list, clients_this_round, sample_this_round)

    return global_dict, global_auxiliary