from server.server_special import server_special
from server.server_MFCL import server_MFCL
from server.server_FDCL2 import server_FDCL2
from server.server_FedCIL import server_FedCIL
from server.server_pFedDIL import server_pFedDIL
from server.server_FLwF import server_FLwF
from server.server_SRFDIL import server_SRFDIL


def get_server(args, train_data, test_data, max_class, method_name, task_id, previous_test=None, theta_reg=None, **kwargs):
    match method_name:
        case 'SPECIAL':
            return server_special(args, train_data, test_data, max_class, method_name, task_id, previous_test, theta_reg)
        case 'SPECIAL-C':
            return server_FDCL2(args, train_data, test_data, max_class, method_name, task_id, previous_test,
                         theta_reg)
        case 'FedCIL':
            return server_FedCIL(args, train_data, test_data, max_class, method_name, task_id, previous_test,
                                 theta_reg)
        case 'pFedDIL':
            return server_pFedDIL(args, train_data, test_data, max_class, method_name, task_id, previous_test,
                                  theta_reg,
                                  **kwargs)

        case 'SRFDIL':
            return server_SRFDIL(args, train_data, test_data, max_class, method_name, task_id, previous_test, theta_reg,
                                 **kwargs)
        case 'FLwF':
            return server_FLwF(args, train_data, test_data, max_class, method_name, task_id, previous_test,
                               theta_reg)
        case 'MFCL':
            return server_MFCL(args, train_data, test_data, max_class, method_name, task_id, previous_test, theta_reg,
                               **kwargs)
