from Topology import RandomGossip, Ring, Torus, FedLay
import importlib

def create_topologies(topology_type, sizes, neighbors, sample_type=None):
    """Return one freshly built topology per entry in ``sizes``.

    Args:
        topology_type: name of the topology family; one of ``"random"``,
            ``"ring"``, ``"torus"`` or ``"fedlay"``.
        sizes: iterable of group sizes; each is passed to the topology
            constructor.
        neighbors: forwarded to ``RandomGossip``; ignored by the other types.
        sample_type: forwarded to ``FedLay``; ignored by the other types.

    Returns:
        list of topology objects, in the same order as ``sizes``.

    Raises:
        KeyError: if ``topology_type`` is not one of the names above.
    """
    # Each builder is a lambda, so the topology class is only resolved when
    # a topology is actually constructed.
    builders = {
        "random": lambda size: RandomGossip(size, neighbors),
        "ring": lambda size: Ring(size),
        "torus": lambda size: Torus(size),
        # NOTE(review): FedLay's second argument is fixed at 3; its meaning
        # is defined by Topology.FedLay, which is not visible here.
        "fedlay": lambda size: FedLay(size, 3, sample_type),
    }
    build = builders[topology_type]
    return list(map(build, sizes))

def get_groups(configs, client_dataloaders, encoder_fns, classifier_dim, output_dim):
    """Create every client plus the topologies that connect each client group.

    Client ids are assigned in this order: ``group_size`` unimodal clients for
    modality 0, then ``group_size`` for modality 1, and so on. These are
    followed by the remaining ``n_clients - modality_num * group_size``
    multimodal clients, which hold every modality.

    Args:
        configs: run configuration. Attributes read here are ``mode``,
            ``topology``, ``n_neighbours``, ``sample_type``, ``dataset``,
            ``group_size`` and ``n_clients``. The object is also forwarded to
            every client constructor.
        client_dataloaders: indexable by client id; entry ``i`` is handed to
            client ``i``.
        encoder_fns: forwarded unchanged to every client constructor.
        classifier_dim: forwarded unchanged to every client constructor.
        output_dim: forwarded unchanged to every client constructor.

    Returns:
        A 6-tuple ``(modality_groups, client_groups, modality_topologies,
        task_topologies, whole_group, global_topologies)``:

        * ``modality_groups[m]`` holds every client that has modality ``m``
          (its unimodal clients followed by all multimodal clients);
        * ``client_groups`` holds one list per unimodal group, then one list
          with the multimodal clients;
        * ``whole_group`` lists every client in id order;
        * each ``*_topologies`` list has one topology per corresponding group,
          sized to that group's length.

    Raises:
        ValueError: if ``configs.mode`` is not a known client mode, or if
            ``configs.n_clients`` is smaller than ``modality_num * group_size``.
    """
    client_class_map = {
        "modality": "ModalityClient",
        "task": "TaskClient",
        "parse": "ParseClient",
        "hybrid": "HybridClient",
    }

    class_name = client_class_map.get(configs.mode)
    if class_name is None:
        raise ValueError(f"Unknown mode: {configs.mode}")

    topology_type = configs.topology
    neighbors = configs.n_neighbours
    sample_type = configs.sample_type
    # The "cap" dataset uses three modalities; every other dataset uses two.
    modality_num = 3 if configs.dataset == "cap" else 2
    group_size = configs.group_size
    n_clients = configs.n_clients

    n_multimodal = n_clients - modality_num * group_size
    if n_multimodal < 0:
        # A negative count would make range() yield nothing. The unimodal loop
        # would then still create more clients than n_clients.
        raise ValueError(
            f"n_clients ({n_clients}) must be at least modality_num * group_size "
            f"({modality_num} * {group_size} = {modality_num * group_size})"
        )

    # The client class is looked up by name in Clients.baseclients.
    client_lib = importlib.import_module('Clients.baseclients')
    TargetClient = getattr(client_lib, class_name)

    sizes = [group_size] * modality_num + [n_multimodal]
    print(sizes)

    modality_groups = [[] for _ in range(modality_num)]
    client_groups = [[] for _ in range(modality_num + 1)]
    whole_group = []

    def make_client(client_id, modality_mask):
        """Instantiate the configured client class for one client id."""
        return TargetClient(
            dataloader=client_dataloaders[client_id],
            modality_type=modality_mask,
            client_id=client_id,
            encoder_fns=encoder_fns,
            classifier_dim=classifier_dim,
            output_dim=output_dim,
            configs=configs,
        )

    client_id = 0

    # Unimodal clients: the modality mask is one-hot on modality m.
    for m in range(modality_num):
        for _ in range(group_size):
            modality_mask = [0] * modality_num
            modality_mask[m] = 1
            client = make_client(client_id, modality_mask)
            modality_groups[m].append(client)
            client_groups[m].append(client)
            whole_group.append(client)
            client_id += 1

    # Multimodal clients hold every modality, so they join every modality group.
    for _ in range(n_multimodal):
        client = make_client(client_id, [1] * modality_num)
        for group in modality_groups:
            group.append(client)
        client_groups[-1].append(client)
        whole_group.append(client)
        client_id += 1

    modality_topologies = create_topologies(topology_type, [len(g) for g in modality_groups], neighbors, sample_type)
    task_topologies = create_topologies(topology_type, [len(g) for g in client_groups], neighbors, sample_type)
    global_topologies = create_topologies(topology_type, [len(whole_group)], neighbors, sample_type)

    return modality_groups, client_groups, modality_topologies, task_topologies, whole_group, global_topologies