from libs import *
from dataset import *
from models import *
from train import *
from logger import logger

def train_Median(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Federated training where the server aggregates gradients with ``median``.

    Every round, honest clients (indices not in ``attackers``) compute local
    gradients with ``train_grad`` starting from the current global model; each
    attacker then submits a crafted gradient selected by ``attack_name``. All
    gradients are passed to ``median``, which returns the new global model.

    Args:
        data_obj: Provides ``n_client``, ``client_x``, ``client_y``,
            ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: Base step size; round ``i`` uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: Forwarded to ``train_grad``.
        epoch: Forwarded to ``train_grad``.
        com_amount: Number of communication rounds.
        test_per: Evaluate on the test set every ``test_per`` rounds.
        init_model: Its ``state_dict`` initialises the global model.
        model_func: Zero-argument factory returning a fresh model.
        attackers: Indices of Byzantine clients.
        attack_name: ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``, ``"FoE"`` or
            ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)`` holding
        ``[test_loss, test_acc]`` for evaluated rounds and zeros elsewhere.

    Raises:
        ValueError: If ``attackers`` is non-empty and ``attack_name`` is not a
            supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")
    # Fail fast: with an unknown name no attack branch would assign, and the
    # attackers would silently resend the last honest client's gradient.
    if len(attackers) > 0 and attack_name not in supported_attacks:
        raise ValueError(
            "Unknown attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):
        logger().info("Communication Round {:d}".format(i + 1))

        # Step size decays with the round index.
        round_lr = learning_rate / (3 * i / 500 + 1)

        # Rebinding drops the previous round's gradient lists before the new
        # gradients are computed.
        client_gradients = []
        byzantine = []

        for client in range(n_client):
            if client not in attackers:
                client_gradients.append(
                    train_grad(
                        avg_model,
                        client_x[client],
                        client_y[client],
                        round_lr,
                        batch_size,
                        epoch,
                        data_obj.dataset,
                    )
                )
        # Equals n_client - len(attackers) when attackers are distinct
        # indices in range(n_client).
        n_honest = len(client_gradients)

        for client in range(n_client):
            if client not in attackers:
                continue
            if attack_name == "Gaussian":
                # Pure noise: N(0, std=90) for every trainable parameter.
                gradients = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            elif attack_name == "Sign-flip":
                # -3 times the SUM (not the mean) of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -3 * torch.sum(stacked, dim=0)
            elif attack_name == "LIE":
                # 0.7 * honest mean plus N(0, std=9) noise; parameters missing
                # from every honest gradient are skipped.
                gradients = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    grads = [cg[name] for cg in client_gradients if name in cg]
                    if grads:
                        noise = torch.normal(0, 9, size=param.shape, device=device)
                        gradients[name] = noise + 0.7 * torch.mean(
                            torch.stack(grads, dim=0), dim=0
                        )
            elif attack_name == "FoE":
                # -0.1 times the mean of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest
            elif attack_name == "Same-value":
                # Every coordinate set to 1.
                gradients = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            byzantine.append(gradients)

        client_gradients.extend(byzantine)
        avg_model = median(avg_model, client_gradients, round_lr)

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_Krum(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Federated training where the server aggregates gradients with ``krum``.

    Every round, honest clients (indices not in ``attackers``) compute local
    gradients with ``train_grad`` starting from the current global model; each
    attacker then submits a crafted gradient selected by ``attack_name``. All
    gradients, together with ``len(attackers)``, are passed to ``krum``, which
    returns the new global model.

    Args:
        data_obj: Provides ``n_client``, ``client_x``, ``client_y``,
            ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: Base step size; round ``i`` uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: Forwarded to ``train_grad``.
        epoch: Forwarded to ``train_grad``.
        com_amount: Number of communication rounds.
        test_per: Evaluate on the test set every ``test_per`` rounds.
        init_model: Its ``state_dict`` initialises the global model.
        model_func: Zero-argument factory returning a fresh model.
        attackers: Indices of Byzantine clients.
        attack_name: ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``, ``"FoE"`` or
            ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)`` holding
        ``[test_loss, test_acc]`` for evaluated rounds and zeros elsewhere.

    Raises:
        ValueError: If ``attackers`` is non-empty and ``attack_name`` is not a
            supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")
    # Fail fast: with an unknown name no attack branch would assign, and the
    # attackers would silently resend the last honest client's gradient.
    if len(attackers) > 0 and attack_name not in supported_attacks:
        raise ValueError(
            "Unknown attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):
        logger().info("Communication Round {:d}".format(i + 1))

        # Step size decays with the round index.
        round_lr = learning_rate / (3 * i / 500 + 1)

        # Rebinding drops the previous round's gradient lists before the new
        # gradients are computed.
        client_gradients = []
        byzantine = []

        for client in range(n_client):
            if client not in attackers:
                client_gradients.append(
                    train_grad(
                        avg_model,
                        client_x[client],
                        client_y[client],
                        round_lr,
                        batch_size,
                        epoch,
                        data_obj.dataset,
                    )
                )
        # Equals n_client - len(attackers) when attackers are distinct
        # indices in range(n_client).
        n_honest = len(client_gradients)

        for client in range(n_client):
            if client not in attackers:
                continue
            if attack_name == "Gaussian":
                # Pure noise: N(0, std=90) for every trainable parameter.
                gradients = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            elif attack_name == "Sign-flip":
                # -3 times the SUM (not the mean) of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -3 * torch.sum(stacked, dim=0)
            elif attack_name == "LIE":
                # 0.7 * honest mean plus N(0, std=9) noise; parameters missing
                # from every honest gradient are skipped.
                gradients = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    grads = [cg[name] for cg in client_gradients if name in cg]
                    if grads:
                        noise = torch.normal(0, 9, size=param.shape, device=device)
                        gradients[name] = noise + 0.7 * torch.mean(
                            torch.stack(grads, dim=0), dim=0
                        )
            elif attack_name == "FoE":
                # -0.1 times the mean of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest
            elif attack_name == "Same-value":
                # Every coordinate set to 1.
                gradients = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            byzantine.append(gradients)

        client_gradients.extend(byzantine)
        avg_model = krum(avg_model, client_gradients, round_lr, len(attackers))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_GM(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Federated training where the server aggregates gradients with ``raga``.

    Every round, honest clients (indices not in ``attackers``) compute local
    gradients with ``train_grad`` starting from the current global model; each
    attacker then submits a crafted gradient selected by ``attack_name``. All
    gradients are passed to ``raga``, which returns the new global model.

    Args:
        data_obj: Provides ``n_client``, ``client_x``, ``client_y``,
            ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: Base step size; round ``i`` uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: Forwarded to ``train_grad``.
        epoch: Forwarded to ``train_grad``.
        com_amount: Number of communication rounds.
        test_per: Evaluate on the test set every ``test_per`` rounds.
        init_model: Its ``state_dict`` initialises the global model.
        model_func: Zero-argument factory returning a fresh model.
        attackers: Indices of Byzantine clients.
        attack_name: ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``, ``"FoE"`` or
            ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)`` holding
        ``[test_loss, test_acc]`` for evaluated rounds and zeros elsewhere.

    Raises:
        ValueError: If ``attackers`` is non-empty and ``attack_name`` is not a
            supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")
    # Fail fast: with an unknown name no attack branch would assign, and the
    # attackers would silently resend the last honest client's gradient.
    if len(attackers) > 0 and attack_name not in supported_attacks:
        raise ValueError(
            "Unknown attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):
        logger().info("Communication Round {:d}".format(i + 1))

        # Step size decays with the round index.
        round_lr = learning_rate / (3 * i / 500 + 1)

        # Rebinding drops the previous round's gradient lists before the new
        # gradients are computed.
        client_gradients = []
        byzantine = []

        for client in range(n_client):
            if client not in attackers:
                client_gradients.append(
                    train_grad(
                        avg_model,
                        client_x[client],
                        client_y[client],
                        round_lr,
                        batch_size,
                        epoch,
                        data_obj.dataset,
                    )
                )
        # Equals n_client - len(attackers) when attackers are distinct
        # indices in range(n_client).
        n_honest = len(client_gradients)

        for client in range(n_client):
            if client not in attackers:
                continue
            if attack_name == "Gaussian":
                # Pure noise: N(0, std=90) for every trainable parameter.
                gradients = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            elif attack_name == "Sign-flip":
                # -3 times the SUM (not the mean) of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -3 * torch.sum(stacked, dim=0)
            elif attack_name == "LIE":
                # 0.7 * honest mean plus N(0, std=9) noise; parameters missing
                # from every honest gradient are skipped.
                gradients = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    grads = [cg[name] for cg in client_gradients if name in cg]
                    if grads:
                        noise = torch.normal(0, 9, size=param.shape, device=device)
                        gradients[name] = noise + 0.7 * torch.mean(
                            torch.stack(grads, dim=0), dim=0
                        )
            elif attack_name == "FoE":
                # -0.1 times the mean of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest
            elif attack_name == "Same-value":
                # Every coordinate set to 1.
                gradients = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            byzantine.append(gradients)

        client_gradients.extend(byzantine)
        avg_model = raga(avg_model, client_gradients, round_lr)

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_MCA(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Federated training where the server aggregates gradients with ``mca``.

    Every round, honest clients (indices not in ``attackers``) compute local
    gradients with ``train_grad`` starting from the current global model; each
    attacker then submits a crafted gradient selected by ``attack_name``. All
    gradients are passed to ``mca``, which returns the new global model.

    Args:
        data_obj: Provides ``n_client``, ``client_x``, ``client_y``,
            ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: Base step size; round ``i`` uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: Forwarded to ``train_grad``.
        epoch: Forwarded to ``train_grad``.
        com_amount: Number of communication rounds.
        test_per: Evaluate on the test set every ``test_per`` rounds.
        init_model: Its ``state_dict`` initialises the global model.
        model_func: Zero-argument factory returning a fresh model.
        attackers: Indices of Byzantine clients.
        attack_name: ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``, ``"FoE"`` or
            ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)`` holding
        ``[test_loss, test_acc]`` for evaluated rounds and zeros elsewhere.

    Raises:
        ValueError: If ``attackers`` is non-empty and ``attack_name`` is not a
            supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")
    # Fail fast: with an unknown name no attack branch would assign, and the
    # attackers would silently resend the last honest client's gradient.
    if len(attackers) > 0 and attack_name not in supported_attacks:
        raise ValueError(
            "Unknown attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):
        logger().info("Communication Round {:d}".format(i + 1))

        # Step size decays with the round index.
        round_lr = learning_rate / (3 * i / 500 + 1)

        # Rebinding drops the previous round's gradient lists before the new
        # gradients are computed.
        client_gradients = []
        byzantine = []

        for client in range(n_client):
            if client not in attackers:
                client_gradients.append(
                    train_grad(
                        avg_model,
                        client_x[client],
                        client_y[client],
                        round_lr,
                        batch_size,
                        epoch,
                        data_obj.dataset,
                    )
                )
        # Equals n_client - len(attackers) when attackers are distinct
        # indices in range(n_client).
        n_honest = len(client_gradients)

        for client in range(n_client):
            if client not in attackers:
                continue
            if attack_name == "Gaussian":
                # Pure noise: N(0, std=90) for every trainable parameter.
                gradients = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            elif attack_name == "Sign-flip":
                # -3 times the SUM (not the mean) of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -3 * torch.sum(stacked, dim=0)
            elif attack_name == "LIE":
                # 0.7 * honest mean plus N(0, std=9) noise; parameters missing
                # from every honest gradient are skipped.
                gradients = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    grads = [cg[name] for cg in client_gradients if name in cg]
                    if grads:
                        noise = torch.normal(0, 9, size=param.shape, device=device)
                        gradients[name] = noise + 0.7 * torch.mean(
                            torch.stack(grads, dim=0), dim=0
                        )
            elif attack_name == "FoE":
                # -0.1 times the mean of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest
            elif attack_name == "Same-value":
                # Every coordinate set to 1.
                gradients = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            byzantine.append(gradients)

        client_gradients.extend(byzantine)
        avg_model = mca(avg_model, client_gradients, round_lr)

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_CClip(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Federated training where the server aggregates gradients with ``cclip``.

    Every round, honest clients (indices not in ``attackers``) compute local
    gradients with ``train_grad`` starting from the current global model; each
    attacker then submits a crafted gradient selected by ``attack_name``. All
    gradients are passed to ``cclip``, which returns the new global model and
    an aggregate gradient. That aggregate is fed back into ``cclip`` on the
    next round; ``None`` is passed on the first round.

    Args:
        data_obj: Provides ``n_client``, ``client_x``, ``client_y``,
            ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: Base step size; round ``i`` uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: Forwarded to ``train_grad``.
        epoch: Forwarded to ``train_grad``.
        com_amount: Number of communication rounds.
        test_per: Evaluate on the test set every ``test_per`` rounds.
        init_model: Its ``state_dict`` initialises the global model.
        model_func: Zero-argument factory returning a fresh model.
        attackers: Indices of Byzantine clients.
        attack_name: ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``, ``"FoE"`` or
            ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)`` holding
        ``[test_loss, test_acc]`` for evaluated rounds and zeros elsewhere.

    Raises:
        ValueError: If ``attackers`` is non-empty and ``attack_name`` is not a
            supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")
    # Fail fast: with an unknown name no attack branch would assign, and the
    # attackers would silently resend the last honest client's gradient.
    if len(attackers) > 0 and attack_name not in supported_attacks:
        raise ValueError(
            "Unknown attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    # NumPy (not ``cp``), matching the other train_* functions, so callers can
    # handle every trainer's result the same way.
    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    # cclip's previous aggregate; there is none before the first round.
    avg_grad = None

    for i in trange(com_amount, desc="Total Time", ascii=True):
        logger().info("Communication Round {:d}".format(i + 1))

        # Step size decays with the round index.
        round_lr = learning_rate / (3 * i / 500 + 1)

        # Rebinding drops the previous round's gradient lists before the new
        # gradients are computed.
        client_gradients = []
        byzantine = []

        for client in range(n_client):
            if client not in attackers:
                client_gradients.append(
                    train_grad(
                        avg_model,
                        client_x[client],
                        client_y[client],
                        round_lr,
                        batch_size,
                        epoch,
                        data_obj.dataset,
                    )
                )
        # Equals n_client - len(attackers) when attackers are distinct
        # indices in range(n_client).
        n_honest = len(client_gradients)

        for client in range(n_client):
            if client not in attackers:
                continue
            if attack_name == "Gaussian":
                # Pure noise: N(0, std=90) for every trainable parameter.
                gradients = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            elif attack_name == "Sign-flip":
                # -3 times the SUM (not the mean) of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -3 * torch.sum(stacked, dim=0)
            elif attack_name == "LIE":
                # 0.7 * honest mean plus N(0, std=9) noise; parameters missing
                # from every honest gradient are skipped.
                gradients = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    grads = [cg[name] for cg in client_gradients if name in cg]
                    if grads:
                        noise = torch.normal(0, 9, size=param.shape, device=device)
                        gradients[name] = noise + 0.7 * torch.mean(
                            torch.stack(grads, dim=0), dim=0
                        )
            elif attack_name == "FoE":
                # -0.1 times the mean of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest
            elif attack_name == "Same-value":
                # Every coordinate set to 1.
                gradients = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            byzantine.append(gradients)

        client_gradients.extend(byzantine)
        avg_model, avg_grad = cclip(avg_model, client_gradients, round_lr, avg_grad)

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_H(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Federated training where the server aggregates gradients with ``h``.

    Every round, honest clients (indices not in ``attackers``) compute local
    gradients with ``train_grad`` starting from the current global model; each
    attacker then submits a crafted gradient selected by ``attack_name``. All
    gradients, together with ``len(attackers)``, are passed to ``h``, which
    returns the new global model.

    Args:
        data_obj: Provides ``n_client``, ``client_x``, ``client_y``,
            ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: Base step size; round ``i`` uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: Forwarded to ``train_grad``.
        epoch: Forwarded to ``train_grad``.
        com_amount: Number of communication rounds.
        test_per: Evaluate on the test set every ``test_per`` rounds.
        init_model: Its ``state_dict`` initialises the global model.
        model_func: Zero-argument factory returning a fresh model.
        attackers: Indices of Byzantine clients.
        attack_name: ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``, ``"FoE"`` or
            ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)`` holding
        ``[test_loss, test_acc]`` for evaluated rounds and zeros elsewhere.

    Raises:
        ValueError: If ``attackers`` is non-empty and ``attack_name`` is not a
            supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")
    # Fail fast: with an unknown name no attack branch would assign, and the
    # attackers would silently resend the last honest client's gradient.
    if len(attackers) > 0 and attack_name not in supported_attacks:
        raise ValueError(
            "Unknown attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):
        logger().info("Communication Round {:d}".format(i + 1))

        # Step size decays with the round index.
        round_lr = learning_rate / (3 * i / 500 + 1)

        # Rebinding drops the previous round's gradient lists before the new
        # gradients are computed.
        client_gradients = []
        byzantine = []

        for client in range(n_client):
            if client not in attackers:
                client_gradients.append(
                    train_grad(
                        avg_model,
                        client_x[client],
                        client_y[client],
                        round_lr,
                        batch_size,
                        epoch,
                        data_obj.dataset,
                    )
                )
        # Equals n_client - len(attackers) when attackers are distinct
        # indices in range(n_client).
        n_honest = len(client_gradients)

        for client in range(n_client):
            if client not in attackers:
                continue
            if attack_name == "Gaussian":
                # Pure noise: N(0, std=90) for every trainable parameter.
                gradients = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            elif attack_name == "Sign-flip":
                # -3 times the SUM (not the mean) of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -3 * torch.sum(stacked, dim=0)
            elif attack_name == "LIE":
                # 0.7 * honest mean plus N(0, std=9) noise; parameters missing
                # from every honest gradient are skipped.
                gradients = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    grads = [cg[name] for cg in client_gradients if name in cg]
                    if grads:
                        noise = torch.normal(0, 9, size=param.shape, device=device)
                        gradients[name] = noise + 0.7 * torch.mean(
                            torch.stack(grads, dim=0), dim=0
                        )
            elif attack_name == "FoE":
                # -0.1 times the mean of the honest gradients (previously
                # computed twice, the first result always overwritten).
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest
            elif attack_name == "Same-value":
                # Every coordinate set to 1.
                gradients = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            byzantine.append(gradients)

        client_gradients.extend(byzantine)
        avg_model = h(avg_model, client_gradients, round_lr, len(attackers))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_H_public(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Federated training where the server aggregates gradients with ``h_public``.

    Every round, honest clients (indices not in ``attackers``) compute local
    gradients with ``train_grad`` starting from the current global model; each
    attacker then submits a crafted gradient selected by ``attack_name``. All
    gradients, together with ``len(attackers)``, are passed to ``h_public``,
    which returns the new global model.

    Args:
        data_obj: Provides ``n_client``, ``client_x``, ``client_y``,
            ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: Base step size; round ``i`` uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: Forwarded to ``train_grad``.
        epoch: Forwarded to ``train_grad``.
        com_amount: Number of communication rounds.
        test_per: Evaluate on the test set every ``test_per`` rounds.
        init_model: Its ``state_dict`` initialises the global model.
        model_func: Zero-argument factory returning a fresh model.
        attackers: Indices of Byzantine clients.
        attack_name: ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``, ``"FoE"`` or
            ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)`` holding
        ``[test_loss, test_acc]`` for evaluated rounds and zeros elsewhere.

    Raises:
        ValueError: If ``attackers`` is non-empty and ``attack_name`` is not a
            supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")
    # Fail fast: with an unknown name no attack branch would assign, and the
    # attackers would silently resend the last honest client's gradient.
    if len(attackers) > 0 and attack_name not in supported_attacks:
        raise ValueError(
            "Unknown attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):
        logger().info("Communication Round {:d}".format(i + 1))

        # Step size decays with the round index.
        round_lr = learning_rate / (3 * i / 500 + 1)

        # Rebinding drops the previous round's gradient lists before the new
        # gradients are computed.
        client_gradients = []
        byzantine = []

        for client in range(n_client):
            if client not in attackers:
                client_gradients.append(
                    train_grad(
                        avg_model,
                        client_x[client],
                        client_y[client],
                        round_lr,
                        batch_size,
                        epoch,
                        data_obj.dataset,
                    )
                )
        # Equals n_client - len(attackers) when attackers are distinct
        # indices in range(n_client).
        n_honest = len(client_gradients)

        for client in range(n_client):
            if client not in attackers:
                continue
            if attack_name == "Gaussian":
                # Pure noise: N(0, std=90) for every trainable parameter.
                gradients = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            elif attack_name == "Sign-flip":
                # -3 times the SUM (not the mean) of the honest gradients.
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -3 * torch.sum(stacked, dim=0)
            elif attack_name == "LIE":
                # 0.7 * honest mean plus N(0, std=9) noise; parameters missing
                # from every honest gradient are skipped.
                gradients = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    grads = [cg[name] for cg in client_gradients if name in cg]
                    if grads:
                        noise = torch.normal(0, 9, size=param.shape, device=device)
                        gradients[name] = noise + 0.7 * torch.mean(
                            torch.stack(grads, dim=0), dim=0
                        )
            elif attack_name == "FoE":
                # -0.1 times the mean of the honest gradients (previously
                # computed twice, the first result always overwritten).
                gradients = {}
                for key in client_gradients[0]:
                    stacked = torch.stack([cg[key] for cg in client_gradients])
                    gradients[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest
            elif attack_name == "Same-value":
                # Every coordinate set to 1.
                gradients = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }
            byzantine.append(gradients)

        client_gradients.extend(byzantine)
        avg_model = h_public(avg_model, client_gradients, round_lr, len(attackers))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_Zeno(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Simulate federated training under a Byzantine attack, aggregated by ``zeno``.

    In every communication round each honest client (index not in
    ``attackers``) computes gradients with ``train_grad`` from the current
    global model.  Each attacker then forges a gradient according to
    ``attack_name``.  The combined list (honest first, forged last) is handed
    to ``zeno`` (defined elsewhere), which returns the new global model.

    Args:
        data_obj: dataset container; reads ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: base learning rate; round ``i`` (0-based) uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: forwarded to ``train_grad``.
        epoch: forwarded to ``train_grad``.
        com_amount: number of communication rounds.
        test_per: evaluate the global model every ``test_per`` rounds.
        init_model: model whose ``state_dict`` initialises the global model.
        model_func: zero-argument constructor for the model architecture.
        attackers: client indices that behave as Byzantine clients.
        attack_name: one of ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``,
            ``"FoE"``, ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)``.  Row ``i`` holds
        ``[test_loss, test_acc]`` for evaluated rounds and zeros otherwise.

    Raises:
        ValueError: if some client in ``range(n_client)`` is an attacker and
            ``attack_name`` is not a supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y
    # Number of honest clients, assuming ``attackers`` holds distinct indices
    # in range(n_client).
    n_honest = n_client - len(attackers)

    # Fail fast: an unknown name used to fall through and make every attacker
    # re-send the last honest client's gradient dict.
    if attack_name not in supported_attacks and any(
        client in attackers for client in range(n_client)
    ):
        raise ValueError(
            "Unsupported attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):

        logger().info("Communication Round {:d}".format(i + 1))

        # Rebinding releases the previous round's gradients before training.
        client_gradients = []
        byzantine = []

        round_lr = learning_rate / (3 * i / 500 + 1)

        # Honest clients: local gradients from the current global model.
        for client in range(n_client):
            if client in attackers:
                continue
            client_gradients.append(
                train_grad(
                    avg_model,
                    client_x[client],
                    client_y[client],
                    round_lr,
                    batch_size,
                    epoch,
                    data_obj.dataset,
                )
            )

        # Byzantine clients.  Forged gradients go to ``byzantine``, so
        # ``client_gradients`` holds only honest gradients while forging.
        for client in range(n_client):
            if client not in attackers:
                continue

            if attack_name == "Gaussian":
                # Pure noise, N(0, 90^2), for every trainable parameter.
                forged = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Sign-flip":
                # -3 x the sum of all honest gradients.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -3 * torch.sum(stacked, dim=0)

            elif attack_name == "LIE":
                # N(0, 9^2) noise plus 0.7 x the honest mean, per trainable
                # parameter present in at least one honest gradient.
                forged = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    honest = [g[name] for g in client_gradients if name in g]
                    if honest:
                        forged[name] = torch.normal(
                            0, 9, size=param.shape, device=device
                        ) + 0.7 * torch.mean(torch.stack(honest, dim=0), dim=0)

            elif attack_name == "FoE":
                # -0.1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest

            elif attack_name == "Same-value":
                # Every trainable parameter's gradient is all ones.
                forged = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            byzantine.append(forged)

        client_gradients.extend(byzantine)
        avg_model = zeno(avg_model, client_gradients, round_lr)

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel

def train_FLTurst(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Simulate federated training under a Byzantine attack, aggregated by ``trust``.

    The function name keeps its historical spelling so existing callers still
    resolve it.  In every communication round each honest client (index not
    in ``attackers``) computes gradients with ``train_grad`` from the current
    global model.  Each attacker then forges a gradient according to
    ``attack_name``.  The combined list (honest first, forged last) is handed
    to ``trust`` (defined elsewhere; presumably FLTrust), which returns the
    new global model.

    Args:
        data_obj: dataset container; reads ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: base learning rate; round ``i`` (0-based) uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: forwarded to ``train_grad``.
        epoch: forwarded to ``train_grad``.
        com_amount: number of communication rounds.
        test_per: evaluate the global model every ``test_per`` rounds.
        init_model: model whose ``state_dict`` initialises the global model.
        model_func: zero-argument constructor for the model architecture.
        attackers: client indices that behave as Byzantine clients.
        attack_name: one of ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``,
            ``"FoE"``, ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)``.  Row ``i`` holds
        ``[test_loss, test_acc]`` for evaluated rounds and zeros otherwise.

    Raises:
        ValueError: if some client in ``range(n_client)`` is an attacker and
            ``attack_name`` is not a supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y
    # Number of honest clients, assuming ``attackers`` holds distinct indices
    # in range(n_client).
    n_honest = n_client - len(attackers)

    # Fail fast: an unknown name used to fall through and make every attacker
    # re-send the last honest client's gradient dict.
    if attack_name not in supported_attacks and any(
        client in attackers for client in range(n_client)
    ):
        raise ValueError(
            "Unsupported attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):

        logger().info("Communication Round {:d}".format(i + 1))

        # Rebinding releases the previous round's gradients before training.
        client_gradients = []
        byzantine = []

        round_lr = learning_rate / (3 * i / 500 + 1)

        # Honest clients: local gradients from the current global model.
        for client in range(n_client):
            if client in attackers:
                continue
            client_gradients.append(
                train_grad(
                    avg_model,
                    client_x[client],
                    client_y[client],
                    round_lr,
                    batch_size,
                    epoch,
                    data_obj.dataset,
                )
            )

        # Byzantine clients.  Forged gradients go to ``byzantine``, so
        # ``client_gradients`` holds only honest gradients while forging.
        for client in range(n_client):
            if client not in attackers:
                continue

            if attack_name == "Gaussian":
                # Pure noise, N(0, 90^2), for every trainable parameter.
                forged = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Sign-flip":
                # -3 x the sum of all honest gradients.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -3 * torch.sum(stacked, dim=0)

            elif attack_name == "LIE":
                # N(0, 9^2) noise plus 0.7 x the honest mean, per trainable
                # parameter present in at least one honest gradient.
                forged = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    honest = [g[name] for g in client_gradients if name in g]
                    if honest:
                        forged[name] = torch.normal(
                            0, 9, size=param.shape, device=device
                        ) + 0.7 * torch.mean(torch.stack(honest, dim=0), dim=0)

            elif attack_name == "FoE":
                # -0.1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest

            elif attack_name == "Same-value":
                # Every trainable parameter's gradient is all ones.
                forged = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            byzantine.append(forged)

        client_gradients.extend(byzantine)
        avg_model = trust(avg_model, client_gradients, round_lr)

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_H_median(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Simulate federated training under a Byzantine attack, aggregated by ``h_median``.

    In every communication round each honest client (index not in
    ``attackers``) computes gradients with ``train_grad`` from the current
    global model.  Each attacker then forges a gradient according to
    ``attack_name``.  The combined list (honest first, forged last) is handed
    to ``h_median`` (defined elsewhere) together with ``len(attackers)``.
    ``h_median`` returns the new global model.

    Args:
        data_obj: dataset container; reads ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: base learning rate; round ``i`` (0-based) uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: forwarded to ``train_grad``.
        epoch: forwarded to ``train_grad``.
        com_amount: number of communication rounds.
        test_per: evaluate the global model every ``test_per`` rounds.
        init_model: model whose ``state_dict`` initialises the global model.
        model_func: zero-argument constructor for the model architecture.
        attackers: client indices that behave as Byzantine clients.
        attack_name: one of ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``,
            ``"FoE"``, ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)``.  Row ``i`` holds
        ``[test_loss, test_acc]`` for evaluated rounds and zeros otherwise.

    Raises:
        ValueError: if some client in ``range(n_client)`` is an attacker and
            ``attack_name`` is not a supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y
    # Number of honest clients, assuming ``attackers`` holds distinct indices
    # in range(n_client).
    n_honest = n_client - len(attackers)

    # Fail fast: an unknown name used to fall through and make every attacker
    # re-send the last honest client's gradient dict.
    if attack_name not in supported_attacks and any(
        client in attackers for client in range(n_client)
    ):
        raise ValueError(
            "Unsupported attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):

        logger().info("Communication Round {:d}".format(i + 1))

        # Rebinding releases the previous round's gradients before training.
        client_gradients = []
        byzantine = []

        round_lr = learning_rate / (3 * i / 500 + 1)

        # Honest clients: local gradients from the current global model.
        for client in range(n_client):
            if client in attackers:
                continue
            client_gradients.append(
                train_grad(
                    avg_model,
                    client_x[client],
                    client_y[client],
                    round_lr,
                    batch_size,
                    epoch,
                    data_obj.dataset,
                )
            )

        # Byzantine clients.  Forged gradients go to ``byzantine``, so
        # ``client_gradients`` holds only honest gradients while forging.
        for client in range(n_client):
            if client not in attackers:
                continue

            if attack_name == "Gaussian":
                # Pure noise, N(0, 90^2), for every trainable parameter.
                forged = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Sign-flip":
                # -3 x the sum of all honest gradients.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -3 * torch.sum(stacked, dim=0)

            elif attack_name == "LIE":
                # N(0, 9^2) noise plus 0.7 x the honest mean, per trainable
                # parameter present in at least one honest gradient.
                forged = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    honest = [g[name] for g in client_gradients if name in g]
                    if honest:
                        forged[name] = torch.normal(
                            0, 9, size=param.shape, device=device
                        ) + 0.7 * torch.mean(torch.stack(honest, dim=0), dim=0)

            elif attack_name == "FoE":
                # -0.1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest

            elif attack_name == "Same-value":
                # Every trainable parameter's gradient is all ones.
                forged = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            byzantine.append(forged)

        client_gradients.extend(byzantine)
        avg_model = h_median(avg_model, client_gradients, round_lr, len(attackers))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_H_Krum(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Simulate federated training under a Byzantine attack, aggregated by ``h_krum``.

    In every communication round each honest client (index not in
    ``attackers``) computes gradients with ``train_grad`` from the current
    global model.  Each attacker then forges a gradient according to
    ``attack_name``.  The combined list (honest first, forged last) is handed
    to ``h_krum`` (defined elsewhere) together with ``len(attackers)``.
    ``h_krum`` returns the new global model.

    Args:
        data_obj: dataset container; reads ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: base learning rate; round ``i`` (0-based) uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: forwarded to ``train_grad``.
        epoch: forwarded to ``train_grad``.
        com_amount: number of communication rounds.
        test_per: evaluate the global model every ``test_per`` rounds.
        init_model: model whose ``state_dict`` initialises the global model.
        model_func: zero-argument constructor for the model architecture.
        attackers: client indices that behave as Byzantine clients.
        attack_name: one of ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``,
            ``"FoE"``, ``"Same-value"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)``.  Row ``i`` holds
        ``[test_loss, test_acc]`` for evaluated rounds and zeros otherwise.

    Raises:
        ValueError: if some client in ``range(n_client)`` is an attacker and
            ``attack_name`` is not a supported attack.
    """
    supported_attacks = ("Gaussian", "Sign-flip", "LIE", "FoE", "Same-value")

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y
    # Number of honest clients, assuming ``attackers`` holds distinct indices
    # in range(n_client).
    n_honest = n_client - len(attackers)

    # Fail fast: an unknown name used to fall through and make every attacker
    # re-send the last honest client's gradient dict.
    if attack_name not in supported_attacks and any(
        client in attackers for client in range(n_client)
    ):
        raise ValueError(
            "Unsupported attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):

        logger().info("Communication Round {:d}".format(i + 1))

        # Rebinding releases the previous round's gradients before training.
        client_gradients = []
        byzantine = []

        round_lr = learning_rate / (3 * i / 500 + 1)

        # Honest clients: local gradients from the current global model.
        for client in range(n_client):
            if client in attackers:
                continue
            client_gradients.append(
                train_grad(
                    avg_model,
                    client_x[client],
                    client_y[client],
                    round_lr,
                    batch_size,
                    epoch,
                    data_obj.dataset,
                )
            )

        # Byzantine clients.  Forged gradients go to ``byzantine``, so
        # ``client_gradients`` holds only honest gradients while forging.
        for client in range(n_client):
            if client not in attackers:
                continue

            if attack_name == "Gaussian":
                # Pure noise, N(0, 90^2), for every trainable parameter.
                forged = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Sign-flip":
                # -3 x the sum of all honest gradients.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -3 * torch.sum(stacked, dim=0)

            elif attack_name == "LIE":
                # N(0, 9^2) noise plus 0.7 x the honest mean, per trainable
                # parameter present in at least one honest gradient.
                forged = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    honest = [g[name] for g in client_gradients if name in g]
                    if honest:
                        forged[name] = torch.normal(
                            0, 9, size=param.shape, device=device
                        ) + 0.7 * torch.mean(torch.stack(honest, dim=0), dim=0)

            elif attack_name == "FoE":
                # -0.1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest

            elif attack_name == "Same-value":
                # Every trainable parameter's gradient is all ones.
                forged = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            byzantine.append(forged)

        client_gradients.extend(byzantine)
        avg_model = h_krum(avg_model, client_gradients, round_lr, len(attackers))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_H_GM(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Simulate federated training under a Byzantine attack, aggregated by ``h_gm``.

    In every communication round each honest client (index not in
    ``attackers``) computes gradients with ``train_grad`` from the current
    global model.  Each attacker then forges a gradient according to
    ``attack_name``.  The combined list (honest first, forged last) is handed
    to ``h_gm`` (defined elsewhere) together with ``len(attackers)``.
    ``h_gm`` returns the new global model.

    Args:
        data_obj: dataset container; reads ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: base learning rate; round ``i`` (0-based) uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: forwarded to ``train_grad``.
        epoch: forwarded to ``train_grad``.
        com_amount: number of communication rounds.
        test_per: evaluate the global model every ``test_per`` rounds.
        init_model: model whose ``state_dict`` initialises the global model.
        model_func: zero-argument constructor for the model architecture.
        attackers: client indices that behave as Byzantine clients.
        attack_name: one of ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``,
            ``"FoE"``, ``"Same-value"``, ``"Our-attack"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)``.  Row ``i`` holds
        ``[test_loss, test_acc]`` for evaluated rounds and zeros otherwise.

    Raises:
        ValueError: if some client in ``range(n_client)`` is an attacker and
            ``attack_name`` is not a supported attack.
    """
    supported_attacks = (
        "Gaussian",
        "Sign-flip",
        "LIE",
        "FoE",
        "Same-value",
        "Our-attack",
    )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y
    # Number of honest clients, assuming ``attackers`` holds distinct indices
    # in range(n_client).
    n_honest = n_client - len(attackers)

    # Fail fast: an unknown name used to fall through and make every attacker
    # re-send the last honest client's gradient dict.
    if attack_name not in supported_attacks and any(
        client in attackers for client in range(n_client)
    ):
        raise ValueError(
            "Unsupported attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):

        logger().info("Communication Round {:d}".format(i + 1))

        # Rebinding releases the previous round's gradients before training.
        client_gradients = []
        byzantine = []

        round_lr = learning_rate / (3 * i / 500 + 1)

        # Honest clients: local gradients from the current global model.
        for client in range(n_client):
            if client in attackers:
                continue
            client_gradients.append(
                train_grad(
                    avg_model,
                    client_x[client],
                    client_y[client],
                    round_lr,
                    batch_size,
                    epoch,
                    data_obj.dataset,
                )
            )

        # Byzantine clients.  Forged gradients go to ``byzantine``, so
        # ``client_gradients`` holds only honest gradients while forging.
        for client in range(n_client):
            if client not in attackers:
                continue

            if attack_name == "Gaussian":
                # Pure noise, N(0, 90^2), for every trainable parameter.
                forged = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Sign-flip":
                # -3 x the sum of all honest gradients.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -3 * torch.sum(stacked, dim=0)

            elif attack_name == "LIE":
                # N(0, 9^2) noise plus 0.7 x the honest mean, per trainable
                # parameter present in at least one honest gradient.
                forged = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    honest = [g[name] for g in client_gradients if name in g]
                    if honest:
                        forged[name] = torch.normal(
                            0, 9, size=param.shape, device=device
                        ) + 0.7 * torch.mean(torch.stack(honest, dim=0), dim=0)

            elif attack_name == "FoE":
                # -0.1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest

            elif attack_name == "Same-value":
                # Every trainable parameter's gradient is all ones.
                forged = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Our-attack":
                # -1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -1 * torch.sum(stacked, dim=0) / n_honest

            byzantine.append(forged)

        client_gradients.extend(byzantine)
        avg_model = h_gm(avg_model, client_gradients, round_lr, len(attackers))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def train_H_MCA(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Simulate federated training under a Byzantine attack, aggregated by ``h_mca``.

    In every communication round each honest client (index not in
    ``attackers``) computes gradients with ``train_grad`` from the current
    global model.  Each attacker then forges a gradient according to
    ``attack_name``.  The combined list (honest first, forged last) is handed
    to ``h_mca`` (defined elsewhere) together with ``len(attackers)``.
    ``h_mca`` returns the new global model.

    Args:
        data_obj: dataset container; reads ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: base learning rate; round ``i`` (0-based) uses
            ``learning_rate / (3 * i / 500 + 1)``.
        batch_size: forwarded to ``train_grad``.
        epoch: forwarded to ``train_grad``.
        com_amount: number of communication rounds.
        test_per: evaluate the global model every ``test_per`` rounds.
        init_model: model whose ``state_dict`` initialises the global model.
        model_func: zero-argument constructor for the model architecture.
        attackers: client indices that behave as Byzantine clients.
        attack_name: one of ``"Gaussian"``, ``"Sign-flip"``, ``"LIE"``,
            ``"FoE"``, ``"Same-value"``, ``"Our-attack"``.

    Returns:
        ``np.ndarray`` of shape ``(com_amount, 2)``.  Row ``i`` holds
        ``[test_loss, test_acc]`` for evaluated rounds and zeros otherwise.

    Raises:
        ValueError: if some client in ``range(n_client)`` is an attacker and
            ``attack_name`` is not a supported attack.
    """
    supported_attacks = (
        "Gaussian",
        "Sign-flip",
        "LIE",
        "FoE",
        "Same-value",
        "Our-attack",
    )

    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y
    # Number of honest clients, assuming ``attackers`` holds distinct indices
    # in range(n_client).
    n_honest = n_client - len(attackers)

    # Fail fast: an unknown name used to fall through and make every attacker
    # re-send the last honest client's gradient dict.
    if attack_name not in supported_attacks and any(
        client in attackers for client in range(n_client)
    ):
        raise ValueError(
            "Unsupported attack_name {!r}; expected one of {}".format(
                attack_name, supported_attacks
            )
        )

    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    for i in trange(com_amount, desc="Total Time", ascii=True):

        logger().info("Communication Round {:d}".format(i + 1))

        # Rebinding releases the previous round's gradients before training.
        client_gradients = []
        byzantine = []

        round_lr = learning_rate / (3 * i / 500 + 1)

        # Honest clients: local gradients from the current global model.
        for client in range(n_client):
            if client in attackers:
                continue
            client_gradients.append(
                train_grad(
                    avg_model,
                    client_x[client],
                    client_y[client],
                    round_lr,
                    batch_size,
                    epoch,
                    data_obj.dataset,
                )
            )

        # Byzantine clients.  Forged gradients go to ``byzantine``, so
        # ``client_gradients`` holds only honest gradients while forging.
        for client in range(n_client):
            if client not in attackers:
                continue

            if attack_name == "Gaussian":
                # Pure noise, N(0, 90^2), for every trainable parameter.
                forged = {
                    name: torch.normal(0, 90, size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Sign-flip":
                # -3 x the sum of all honest gradients.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -3 * torch.sum(stacked, dim=0)

            elif attack_name == "LIE":
                # N(0, 9^2) noise plus 0.7 x the honest mean, per trainable
                # parameter present in at least one honest gradient.
                forged = {}
                for name, param in avg_model.named_parameters():
                    if not param.requires_grad:
                        continue
                    honest = [g[name] for g in client_gradients if name in g]
                    if honest:
                        forged[name] = torch.normal(
                            0, 9, size=param.shape, device=device
                        ) + 0.7 * torch.mean(torch.stack(honest, dim=0), dim=0)

            elif attack_name == "FoE":
                # -0.1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -0.1 * torch.sum(stacked, dim=0) / n_honest

            elif attack_name == "Same-value":
                # Every trainable parameter's gradient is all ones.
                forged = {
                    name: torch.ones(size=param.shape, device=device)
                    for name, param in avg_model.named_parameters()
                    if param.requires_grad
                }

            elif attack_name == "Our-attack":
                # -1 x the honest average.
                forged = {}
                for key in client_gradients[0].keys():
                    stacked = torch.stack([g[key] for g in client_gradients])
                    forged[key] = -1 * torch.sum(stacked, dim=0) / n_honest

            byzantine.append(forged)

        client_gradients.extend(byzantine)
        avg_model = h_mca(avg_model, client_gradients, round_lr, len(attackers))

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel


def _make_byzantine_gradient(attack_name, model, honest_gradients):
    """Craft one Byzantine client's gradient dict for ``attack_name``.

    Args:
        attack_name: one of "Gaussian", "Sign-flip", "LIE", "FoE",
            "Same-value", "Our-attack".
        model: current global model; its trainable parameters define the
            keys/shapes of the noise-based attacks ("Gaussian", "LIE",
            "Same-value").
        honest_gradients: list of gradient dicts (parameter name -> tensor)
            produced by the honest clients this round. Must be non-empty for
            the attacks that read ``honest_gradients[0]``.

    Returns:
        dict mapping parameter name to the malicious gradient tensor.

    Raises:
        ValueError: if ``attack_name`` is not a recognised attack.
    """
    n_honest = len(honest_gradients)

    def summed(key):
        # Element-wise sum of every honest client's gradient for ``key``.
        return torch.sum(torch.stack([g[key] for g in honest_gradients]), dim=0)

    if attack_name == "Gaussian":
        return {
            name: torch.normal(0, 90, size=param.shape, device=device)
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    if attack_name == "Sign-flip":
        # NOTE(review): unlike FoE / Our-attack this scales the *sum* of the
        # honest gradients rather than their mean — confirm this is intended.
        return {key: -3 * summed(key) for key in honest_gradients[0].keys()}

    if attack_name == "LIE":
        crafted = {}
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            grads = [g[name] for g in honest_gradients if name in g]
            if grads:
                # Noise is drawn before the mean, matching the original RNG order.
                crafted[name] = torch.normal(
                    0, 9, size=param.shape, device=device
                ) + 0.7 * torch.mean(torch.stack(grads, dim=0), dim=0)
        return crafted

    if attack_name == "FoE":
        return {
            key: -0.1 * summed(key) / n_honest for key in honest_gradients[0].keys()
        }

    if attack_name == "Same-value":
        return {
            name: torch.ones(size=param.shape, device=device)
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    if attack_name == "Our-attack":
        return {
            key: -1 * summed(key) / n_honest for key in honest_gradients[0].keys()
        }

    # Previously an unknown name silently re-submitted the last honest
    # gradient (a leaked loop variable), turning a typo into a no-attack run.
    raise ValueError("Unknown attack_name: {!r}".format(attack_name))


def train_H_CClip(
    data_obj,
    learning_rate,
    batch_size,
    epoch,
    com_amount,
    test_per,
    init_model,
    model_func,
    attackers,
    attack_name,
):
    """Run federated training aggregated with ``h_cclip`` under a Byzantine attack.

    Each round, every honest client computes gradients via ``train_grad``.
    Each client listed in ``attackers`` then submits a gradient crafted by
    ``attack_name``. All gradients are aggregated with ``h_cclip``, which also
    carries its running aggregate (``avg_grad``) from one round to the next.

    Args:
        data_obj: dataset container exposing ``n_client``, ``client_x``,
            ``client_y``, ``test_x``, ``test_y`` and ``dataset``.
        learning_rate: base step size, decayed as ``lr / (3 * i / 500 + 1)``.
        batch_size, epoch: forwarded to ``train_grad``.
        com_amount: number of communication rounds.
        test_per: evaluate the global model every ``test_per`` rounds.
        init_model: model whose state dict initialises the global model.
        model_func: zero-arg factory building a fresh model instance.
        attackers: collection of Byzantine client indices.
        attack_name: attack to simulate (see ``_make_byzantine_gradient``).

    Returns:
        numpy array of shape ``(com_amount, 2)``. Row ``i`` holds
        ``[test_loss, test_acc]`` for evaluated rounds and zeros otherwise.

    Raises:
        ValueError: if ``attackers`` selects at least one client and
            ``attack_name`` is not recognised.
    """
    n_client = data_obj.n_client
    client_x = data_obj.client_x
    client_y = data_obj.client_y
    attacker_ids = set(attackers)  # O(1) membership tests in the round loop

    # numpy (not cupy) so the return type matches the other train_* functions.
    test_perf_sel = np.zeros((com_amount, 2))

    avg_model = model_func().to(device)
    avg_model.load_state_dict(init_model.state_dict())

    # h_cclip's running aggregate from the previous round; None before round 1.
    avg_grad = None

    for i in trange(com_amount, desc="Total Time", ascii=True):

        logger().info("Communication Round {:d}".format(i + 1))

        # Decayed step size shared by local training and aggregation.
        round_lr = learning_rate / (3 * i / 500 + 1)

        honest_gradients = [
            train_grad(
                avg_model,
                client_x[client],
                client_y[client],
                round_lr,
                batch_size,
                epoch,
                data_obj.dataset,
            )
            for client in range(n_client)
            if client not in attacker_ids
        ]

        # Generated per attacker (not once) so random attacks draw fresh noise.
        byzantine = [
            _make_byzantine_gradient(attack_name, avg_model, honest_gradients)
            for client in range(n_client)
            if client in attacker_ids
        ]

        avg_model, avg_grad = h_cclip(
            avg_model,
            honest_gradients + byzantine,
            round_lr,
            len(attackers),
            avg_grad,
        )

        if (i + 1) % test_per == 0:
            loss_test, acc_test = evaluate_global_model(
                data_obj.test_x, data_obj.test_y, avg_model, data_obj.dataset
            )
            test_perf_sel[i] = np.array([loss_test, acc_test])

            logger().info(
                "**** Communication %3d, Test Accuracy: %.4f, Loss: %.4f"
                % (i + 1, acc_test, loss_test)
            )

    return test_perf_sel
