import torch


def average_lora_fedplora(args, global_model, loc_updates):
    """
    Compute the FedPLoRA style averaging of LoRA and classifier parameters.

    For every key of ``global_model`` whose name contains ``"lora"`` or
    ``"classifier"``, the matching tensors are gathered from the client
    updates that contain that key. All-zero tensors (treated as "no update")
    are discarded, and the remaining ones are averaged element-wise. The
    average is then added on top of the current global value. Keys that
    match neither substring are left untouched.

    Arguments
    ---------
    args:
        Argument object that must contain the attribute agg_mode.
        Currently expects agg_mode to be "ParallelFedAgg". It is only
        consulted when the model has at least one LoRA key and at least one
        client update is supplied.
    global_model:
        A state dict (key -> tensor) of the current global model. It is
        modified in place, and the same object is returned.
    loc_updates:
        A list of state dicts, each representing one client model update.
        A client may omit keys; missing keys are skipped for that client.

    Returns
    -------
    global_model:
        Updated global state dict after averaging and applying selected keys.
        Every updated entry is a detached CPU tensor.
    global_updates:
        A dict that stores the averaged update tensor for each updated key.
        Keys for which every client sent nothing or only zeros are absent.

    Raises
    ------
    ValueError:
        If agg_mode is not supported while LoRA keys need aggregating.
        ValueError subclasses Exception, so existing handlers keep working.
    """
    # Validate once, before touching global_model, so an unsupported mode
    # never leaves the state dict half-updated. The condition matches when
    # the mode is actually needed: some LoRA key exists and there is at
    # least one client update to aggregate.
    if loc_updates and any("lora" in k for k in global_model.keys()):
        if args.agg_mode != "ParallelFedAgg":
            raise ValueError("Not support for agg_mode: {}".format(args.agg_mode))

    global_updates = {}

    # Snapshot the keys: entries are reassigned while we iterate.
    for k in list(global_model.keys()):
        # Only LoRA and classifier (e.g. linear head) parameters are aggregated.
        if "lora" not in k and "classifier" not in k:
            continue

        # Keep client updates that carry this key and hold at least one
        # non-zero entry. Testing any(u != 0), rather than sum(u) != 0,
        # avoids discarding real updates whose entries cancel to zero.
        non_zero_updates = [
            loc_update[k]
            for loc_update in loc_updates
            if k in loc_update and torch.any(loc_update[k] != 0)
        ]
        if not non_zero_updates:
            continue

        # Element-wise average over the contributing clients.
        avg_update = torch.stack(non_zero_updates).mean(dim=0)
        global_updates[k] = avg_update

        # Apply on top of the current global value. The base is moved to CPU
        # (as before), so move the update there too. Otherwise GPU-resident
        # client updates would raise a device-mismatch error.
        base = global_model[k].detach().cpu()
        global_model[k] = base + avg_update.to(base.device)

    return global_model, global_updates