from .clients import (
    SupervisedClient, 
    FixMatchClient, 
    OpenMatchClient, 
    OursClient,
    SCOMatchClient,
    ProSubClient,
    )
from .server import (
    SupervisedServer, 
    OpenMatchServer,
    OursServer,
    SCOMatchServer,
    ProSubServer,
    )
from .strategy import (
    FedAvgFinetuneStrategy, 
    FedOursFinetuneStrategy,
    FedProSubFinetuneStrategy,
    )
from .network import OpenNet, OursNet, ProSubNet
from src.core.base import BaseStrategy

# Registry: strategy name -> strategy class (looked up by get_strategy).
# 'FedAvg' maps to the plain BaseStrategy; the '*_finetune' names map to the
# finetune strategy variants imported from .strategy.
name2strategy = {
    'FedAvg': BaseStrategy,
    'FedAvg_finetune': FedAvgFinetuneStrategy,
    'FedOurs_finetune': FedOursFinetuneStrategy,
    'FedProSub_finetune': FedProSubFinetuneStrategy,
}

# Registry: server algorithm name -> server class (looked up by
# get_server_alg). There is no 'fixmatch' entry here even though
# name2client_alg has one, so get_server_alg('fixmatch', ...) yields None.
name2server_alg = {
    'supervised': SupervisedServer,
    'openmatch': OpenMatchServer,
    'ours': OursServer,
    'scomatch': SCOMatchServer,
    'prosub': ProSubServer,
}

# Registry: client algorithm name -> client class (looked up by
# get_client_alg).
name2client_alg = {
    'supervised': SupervisedClient,
    'fixmatch': FixMatchClient,
    'openmatch': OpenMatchClient,
    'ours': OursClient,
    'scomatch': SCOMatchClient,
    'prosub': ProSubClient,
}
def get_strategy(alg, **kwargs):
    """Instantiate the strategy registered under ``alg`` in ``name2strategy``.

    Args:
        alg: Registry key, e.g. ``'FedAvg'`` or ``'FedAvg_finetune'``.
        **kwargs: Forwarded unchanged to the strategy class constructor.

    Returns:
        The constructed strategy, or ``None`` if ``alg`` is not registered.

    Raises:
        Whatever the strategy constructor itself raises.
    """
    # Guard only the registry lookup. Wrapping the constructor call as well
    # would make a KeyError raised inside the strategy's own __init__ look
    # like an unknown strategy name and silently turn it into None.
    strategy_cls = name2strategy.get(alg)
    if strategy_cls is None:
        return None
    return strategy_cls(**kwargs)


def get_server_alg(alg, config, net_builder,
                   train_loader, test_loader, logger,
                   **kwargs):
    """Build the server object registered under ``alg`` in ``name2server_alg``.

    Args:
        alg: Registry key, e.g. ``'supervised'`` or ``'openmatch'``.
        config, net_builder, train_loader, test_loader, logger: Passed to the
            server class constructor as keyword arguments of the same name.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed server, or ``None`` if ``alg`` is not registered or
        the constructor raises. In the latter case the traceback is printed
        via ``traceback.print_exc()``.
    """
    # Unknown names are reported by returning None rather than raising.
    if alg not in name2server_alg:
        return None

    ctor_args = {
        'config': config,
        'net_builder': net_builder,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'logger': logger,
    }
    try:
        return name2server_alg[alg](**ctor_args, **kwargs)
    except Exception:
        # Best-effort construction: show the full traceback, hand back None.
        import traceback
        traceback.print_exc()
        return None


def get_client_alg(alg, cid, config, net_builder, train_loader, **kwargs):
    """Build the client object registered under ``alg`` in ``name2client_alg``.

    Args:
        alg: Registry key, e.g. ``'fixmatch'`` or ``'ours'``.
        cid: Client identifier, passed to the constructor as ``cid``.
        config, net_builder, train_loader: Passed to the constructor as
            keyword arguments of the same name.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The constructed client, or ``None`` on failure. Failure covers both
        an unregistered ``alg`` (a KeyError from the lookup) and any exception
        raised by the constructor. In either case a one-line ``[DEBUG]``
        message with the exception type and text is printed.
    """
    try:
        client_cls = name2client_alg[alg]
        client = client_cls(cid=cid,
                            config=config,
                            net_builder=net_builder,
                            train_loader=train_loader,
                            **kwargs)
    except Exception as e:
        print(f"[DEBUG] Client creation failed for {alg}: {type(e).__name__}: {str(e)}")
        return None
    else:
        return client