import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def mix_parameters(clients, W):
    """Mix every client's full model parameters with a mixing matrix.

    After the call, client ``i``'s flattened parameters equal
    ``sum_j W[i, j] * old_j``, where ``old_j`` is client ``j``'s flattened
    parameter vector *before* the call. All rows are computed from the
    pre-call values, so the result does not depend on client order.

    Args:
        clients: Sequence of objects exposing a ``.model`` attribute with a
            ``parameters()`` method (an ``nn.Module``). Every model must have
            the same total number of parameters, all on a single device.
        W: ``(n, n)`` mixing matrix with ``n == len(clients)``. Anything
            accepted by ``torch.as_tensor`` works (tensor, nested list,
            NumPy array). It is cast to the parameters' dtype and device.

    Raises:
        ValueError: if ``W`` does not have shape ``(len(clients), len(clients))``.
    """
    if not clients:
        return

    # A weight update, not a forward pass: don't record an autograd graph.
    with torch.no_grad():
        # Materialize each generator once; it is traversed twice below.
        param_lists = [list(client.model.parameters()) for client in clients]
        stacked_params = torch.stack(
            [torch.nn.utils.parameters_to_vector(params) for params in param_lists]
        )

        W = torch.as_tensor(W, dtype=stacked_params.dtype, device=stacked_params.device)
        n = len(clients)
        if W.shape != (n, n):
            raise ValueError(
                f"W must have shape ({n}, {n}) to mix {n} clients, got {tuple(W.shape)}"
            )

        # Row i of the product is client i's new flattened parameter vector.
        mixed_params = torch.matmul(W, stacked_params)
        for params, mixed in zip(param_lists, mixed_params):
            torch.nn.utils.vector_to_parameters(mixed, params)


def resolve_submodule(model, name):
    """Resolve a dotted path like 'encoders.0' or 'projector.layer1'.

    Each dot-separated component is applied to the current object in turn.
    Purely numeric components index into it (e.g. ``nn.ModuleList`` or
    ``nn.Sequential``). All other components are attribute lookups.

    Args:
        model: Root object (typically an ``nn.Module``) to start from.
        name: Dotted path to the target submodule.

    Returns:
        The object found at ``name``.

    Raises:
        AttributeError: if a non-numeric component does not exist.
        The container's own indexing error (e.g. ``IndexError``) if a numeric
        component is out of range.
    """
    submodule = model
    for attr in name.split("."):
        if attr.isdigit():
            # Index exactly: 'encoders.1' must mean encoders[1], not the last one.
            submodule = submodule[int(attr)]
        else:
            submodule = getattr(submodule, attr)
    return submodule

def mix_partial_parameters(clients, W, part="i_encoder"):
    """Mix the parameters of one named submodule across clients.

    For each client, the submodule at dotted path ``part`` is resolved with
    ``resolve_submodule``. ``part='all'`` targets the whole ``client.model``.
    Afterwards, client ``i``'s submodule parameters equal
    ``sum_j W[i, j] * old_j``, where ``old_j`` is client ``j``'s flattened
    submodule parameter vector before the call. Parameters outside the
    submodule are left untouched.

    Args:
        clients: Sequence of objects exposing a ``.model`` attribute. The
            targeted submodules must have the same total parameter count,
            all on a single device.
        W: ``(n, n)`` mixing matrix with ``n == len(clients)``. Anything
            accepted by ``torch.as_tensor`` works. It is cast to the
            parameters' dtype and device.
        part: Dotted submodule path, or ``'all'`` for the whole model.

    Raises:
        ValueError: if ``W`` does not have shape ``(len(clients), len(clients))``.
    """
    if not clients:
        return

    # Resolve the target module for each client ('all' = the whole model).
    if part == "all":
        targets = [client.model for client in clients]
    else:
        targets = [resolve_submodule(client.model, part) for client in clients]

    # A weight update, not a forward pass: don't record an autograd graph.
    # parameters_to_vector concatenates into fresh memory, so no clone is needed.
    with torch.no_grad():
        param_lists = [list(module.parameters()) for module in targets]
        stacked_params = torch.stack(
            [torch.nn.utils.parameters_to_vector(params) for params in param_lists]
        )

        W = torch.as_tensor(W, dtype=stacked_params.dtype, device=stacked_params.device)
        n = len(clients)
        if W.shape != (n, n):
            raise ValueError(
                f"W must have shape ({n}, {n}) to mix {n} clients, got {tuple(W.shape)}"
            )

        # Row i of the product is client i's new flattened submodule vector.
        mixed_params = torch.matmul(W, stacked_params)
        for params, mixed in zip(param_lists, mixed_params):
            torch.nn.utils.vector_to_parameters(mixed, params)