from typing import List, Any

import torch
import numpy as np
from scipy.optimize import minimize
from copy import deepcopy
import os

from codes.components.worker_toy import ByzantineWorker
from codes.components.utils import save_txt

# Process-wide side effect of importing this module: tensors are printed with
# 20 digits of precision from here on (the attack code below prints losses and
# parameters for debugging).
torch.set_printoptions(precision=20)


def extract_initial_parameters(model, device):
    """Return an independent replica of ``model`` living on ``device``.

    A new instance of ``type(model)`` is created with no constructor
    arguments (so the model class must be default-constructible), moved to
    ``device`` and loaded with ``model``'s state dict. Every parameter
    gradient that is currently set on ``model`` is cloned onto the matching
    parameter of the replica; parameters without a gradient keep ``grad``
    unset on the replica.

    Args:
        model: Source module whose weights and gradients are replicated.
        device: Device passed to ``.to()`` for the new instance.

    Returns:
        The newly built module.
    """
    replica = type(model)().to(device)
    replica.load_state_dict(model.state_dict())

    # Snapshot the source gradients by parameter name, then clone the ones
    # that exist so the replica never shares gradient storage with `model`.
    source_grads = dict((name, p.grad) for name, p in model.named_parameters())
    for name, p in replica.named_parameters():
        grad = source_grads[name]
        if grad is None:
            continue
        p.grad = grad.clone()

    return replica


def extract_mini_batches(train_loader, max_batches_per_epoch):
    """Collect the leading batches yielded by ``train_loader``.

    Iteration stops as soon as ``max_batches_per_epoch`` batches have been
    gathered (the check happens after each append, so at least one batch is
    taken from a non-empty loader).

    Args:
        train_loader: Iterable yielding batches.
        max_batches_per_epoch: Number of batches to gather.

    Returns:
        List with the gathered batches, in iteration order.

    Raises:
        AssertionError: If the loader is exhausted before enough batches
            were gathered; the payload is ``[collected, requested]``.
    """
    collected = []
    for batch in train_loader:
        collected.append(batch)
        if len(collected) >= max_batches_per_epoch:
            break
    else:
        # Loader ran dry before the quota was reached.
        assert False, [len(collected), max_batches_per_epoch]
    return collected


def extract_subdataset(train_loader, max_batches_per_epoch):
    """Materialise every ``(X, y)`` batch of ``train_loader`` into a list.

    Args:
        train_loader: Iterable yielding two-element batches.
        max_batches_per_epoch: Accepted for signature compatibility; it is
            not used, so the whole loader is always consumed.

    Returns:
        List of ``(X_batch, y_batch)`` tuples in iteration order.
    """
    return [(X_batch, y_batch) for X_batch, y_batch in train_loader]


class FullNLPAttack(ByzantineWorker):
    """Byzantine worker whose gradients are planned with a nonlinear program.

    On the first ``omniscient_callback`` one attacker replays the honest
    workers' training on private copies of the model/optimizer and asks
    ``scipy.optimize.minimize`` for the sequence of Byzantine gradients that
    maximises the honest loss. The horizon ``T`` is solved in consecutive
    windows of ``search_size`` steps. The resulting path is shared with every
    other ``FullNLPAttack`` worker, and each callback afterwards just pops
    this worker's next gradient from the path.

    Every Byzantine gradient is encoded as two scalars, i.e. the planned
    gradients are 2-element tensors (presumably the toy model has exactly two
    parameters -- TODO confirm against the model definition).

    Args:
        worker_rank: Global rank of this worker; the Byzantine-local rank is
            ``worker_rank - (n - f)``.
        save_dir: Log path; its last four characters are replaced by
            ``"<rank>.txt"`` when gradients are logged.
        momentum: Momentum coefficient applied to the simulated honest
            gradients.
        n: Total number of workers.
        f: Number of Byzantine workers.
        T: Number of steps to plan.
        max_batches_per_epoch: Batches taken from each honest loader per epoch.
        search_size: Window length of one solver call; falsy means ``T``.
            Must divide ``T``.
        agg: Aggregator, called with the list of all workers' gradients.
        args1: Namespace providing ``nlpobj``, ``seed``, ``agg`` and ``f``.
            ``nlpobj == 0`` optimises the loss after the last step of a
            window, ``nlpobj < 0`` the minimum loss over the window, and
            otherwise the mean of losses weighted by powers of ``nlpobj``.
        *args, **kwargs: Forwarded to ``ByzantineWorker``.
    """

    # scipy.optimize.minimize methods tried for every window; the best
    # successful result wins.
    _NLP_SOLVERS = ('Powell',)

    def __init__(self, worker_rank, save_dir, momentum, n, f, T,
                 max_batches_per_epoch, search_size,
                 agg, args1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = momentum
        self.n = n
        self.f = f
        self.T = T
        self.max_batches_per_epoch = max_batches_per_epoch
        self.search_size = search_size if search_size else self.T
        assert self.T % self.search_size == 0
        self.agg = agg
        self.args1 = args1
        self.opt_path_ful = []  # complete planned path, never consumed
        self.opt_path = []      # working copy, consumed by omniscient_callback
        self.datasets = []      # cached honest batches, one list per worker

        self.rank = worker_rank - (self.n - self.f)
        self.save_dir = save_dir
        self.byz_grads = []     # textual log of the gradients handed out

    def get_gradient(self):
        """Log the current malicious gradient to disk and return it."""
        values = self._gradient.detach().cpu().tolist()
        self.byz_grads.append(str(values[0]) + ' ' + str(values[1]))
        # save_dir[:-4] drops the last four characters (presumably a ".txt"
        # suffix) so that every attacker writes to its own file.
        save_txt(self.byz_grads, self.save_dir[:-4] + str(self.rank) + '.txt')
        return self._gradient

    def omniscient_callback(self):
        """Set ``self._gradient`` to this worker's next planned gradient.

        The path is planned (or copied from a sibling attacker that already
        has one) the first time this is called.

        Raises:
            RuntimeError: If no solver succeeds on some window.
        """
        if not self.opt_path_ful:
            nlpattackers = [w for w in self.simulator.workers
                            if isinstance(w, FullNLPAttack)]
            ready = next((w for w in nlpattackers
                          if len(w.opt_path_ful) > len(self.opt_path_ful)), None)
            if ready is not None:
                self.opt_path_ful = deepcopy(ready.opt_path_ful)
                self.opt_path = deepcopy(self.opt_path_ful)
            else:
                self._nlp_plan(nlpattackers)

        # Discard entries until the remaining length is congruent to this
        # worker's rank modulo f, then take the next entry.
        while len(self.opt_path) % self.f != self.rank:
            self.opt_path.pop(0)
        self._gradient = self.opt_path.pop(0)

    def set_gradient(self, gradient) -> None:
        raise NotImplementedError

    def apply_gradient(self) -> None:
        raise NotImplementedError

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _nlp_honest_batches(self):
        """Return (and cache) the batches each honest worker will see."""
        if self.datasets:
            return self.datasets

        honest_loaders = []
        for w in self.simulator.workers:
            if not isinstance(w, ByzantineWorker):
                w.data_loader.sampler.set_epoch(0)
                honest_loaders.append(w.data_loader)

        per_worker = [[] for _ in honest_loaders]
        for epoch in range(1, self.T + 1):
            # Read the batches of the current epoch, then advance the samplers.
            for batches, loader in zip(per_worker, honest_loaders):
                batches.extend(
                    extract_mini_batches(loader, self.max_batches_per_epoch))
            for loader in honest_loaders:
                loader.sampler.set_epoch(epoch)

        # Rewind every sampler so that real training starts from epoch 0.
        for w in self.simulator.workers:
            w.data_loader.sampler.set_epoch(0)

        # The first batch is repeated at the end of each list (presumably so
        # that one index past the horizon stays valid -- TODO confirm).
        for batches in per_worker:
            if batches:
                batches.append(batches[0])

        self.datasets = per_worker
        return per_worker

    def _nlp_replicate(self, model, optimizer):
        """Clone ``model`` and build an optimizer of this worker's optimizer type for it."""
        replica = extract_initial_parameters(model, self.device)
        lr = optimizer.param_groups[0]['lr']
        replica_optm = type(self.optimizer)(replica.parameters(), lr=lr)
        replica_optm.load_state_dict(deepcopy(optimizer.state_dict()))
        return replica, replica_optm

    def _nlp_decode(self, g, step):
        """Turn the slice of ``g`` for window-local ``step`` into f 2-element tensors."""
        base = 2 * self.f * step
        return [
            torch.from_numpy(np.array([g[base + 2 * k], g[base + 2 * k + 1]])
                             ).float().to(self.device)
            for k in range(self.f)
        ]

    def _nlp_step(self, model, optm, agg, datasets, velocity_lst, t, byz_grads):
        """Simulate training step ``t`` on ``model`` and return the batches used.

        ``velocity_lst`` (per-honest-worker momentum buffers) is updated in
        place; ``model`` and ``optm`` are advanced by one optimizer step using
        the aggregate of the honest momenta followed by ``byz_grads``.
        """
        grads = []
        samples = []
        for k in range(self.n - self.f):
            batch = datasets[k][t]
            y_batch = batch.to(self.device)
            samples.append(batch)
            optm.zero_grad()
            # The batch serves as both input and target of the loss.
            loss = self.loss_func(model(y_batch), y_batch)
            loss.backward()
            gradient = torch.cat(
                [param.grad.clone().detach().flatten()
                 for param in model.parameters() if param.grad is not None])
            velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
            grads.append(velocity_lst[k].clone().detach())
        grads.extend(byz_grads)

        aggregated = agg(grads)

        # Scatter the flat aggregate back into the per-parameter gradients.
        beg = 0
        for group in optm.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                ter = beg + param.grad.numel()
                chunk = aggregated[beg:ter].reshape_as(param.grad.data)
                param.grad.data = chunk.clone().detach()
                beg = ter
        optm.step()
        return samples

    def _nlp_objective(self, labor_model, labor_optm, datasets, weight,
                       init_theta, init_state, init_velocity, aggr, start, end):
        """Build the scalar cost handed to the solver for window [start, end)."""
        nlpobj = self.args1.nlpobj
        mode = 'last' if nlpobj == 0 else 'lower' if nlpobj < 0 else 'discounted'

        def objective(g):
            # Every evaluation restarts from the same snapshot, including a
            # private copy of the aggregator.
            agg = deepcopy(aggr)
            labor_model.load_state_dict(deepcopy(init_theta))
            labor_optm.load_state_dict(deepcopy(init_state))
            velocity_lst = deepcopy(init_velocity)
            coe = weight
            losses = []

            for t in range(start, end):
                samples = self._nlp_step(
                    labor_model, labor_optm, agg, datasets, velocity_lst, t,
                    self._nlp_decode(g, t % self.search_size))

                # Honest loss after the update, on the batches just used.
                loss = 0
                with torch.no_grad():
                    for sam in samples:
                        y_batch = sam.to(self.device)
                        loss += self.loss_func(labor_model(y_batch), y_batch)
                step_loss = loss / (self.n - self.f)

                if mode == 'discounted':
                    losses.append(coe * step_loss)
                    coe *= nlpobj
                else:
                    losses.append(step_loss)

            if mode == 'last':
                reward = losses[-1]
                print(self.args1.seed, self.args1.agg, self.args1.f, start, reward)
                # SciPy needs a plain scalar; a (possibly CUDA) tensor is not one.
                return -float(reward)

            values = np.array([ls.cpu().numpy() for ls in losses])
            if mode == 'lower':
                reward = np.min(values)
                print(self.args1.seed, self.args1.agg, self.args1.f, start, reward)
            else:
                reward = np.mean(values)
                print(self.args1.seed, self.args1.agg, self.args1.f, start, reward, reward.dtype)
            return -reward

        return objective

    def _nlp_initial_guess(self):
        """Starting point for the solver: 2 * f * search_size scalars.

        The first ~10% of the entries are -0.4 and the rest are 1. The second
        segment takes whatever is left so the vector always has exactly the
        length the objective indexes into.
        """
        total = 2 * self.f * self.search_size
        n_neg = int(0.2 * self.f * self.search_size)
        return np.concatenate([np.ones(n_neg) * -0.4, np.ones(total - n_neg)])

    def _nlp_plan(self, nlpattackers):
        """Solve for the full Byzantine path and hand it to all attackers."""
        datasets = self._nlp_honest_batches()

        # standin_* follows the committed trajectory; labor_* is scratch space
        # that the objective resets on every evaluation.
        standin_model, standin_optm = self._nlp_replicate(self.model, self.optimizer)
        initial_theta = standin_model.state_dict()
        init_state = standin_optm.state_dict()
        labor_model, labor_optm = self._nlp_replicate(standin_model, standin_optm)

        bounds = None
        weight = 1
        theta = deepcopy(initial_theta)
        state = deepcopy(init_state)
        n_params = sum(param.data.numel() for param in standin_model.parameters())
        velocity_lst = [torch.zeros(n_params).to(self.device)
                        for _ in range(self.n - self.f)]

        for window in range(self.T // self.search_size):
            start = window * self.search_size
            end = start + self.search_size

            theta1 = deepcopy(theta)
            g0 = self._nlp_initial_guess()

            candidates = []  # (objective value, solution, method name)
            for method in self._NLP_SOLVERS:
                res = minimize(
                    self._nlp_objective(labor_model, labor_optm, datasets, weight,
                                        theta, state, velocity_lst, self.agg,
                                        start, end),
                    g0, method=method, bounds=bounds)
                if res.success:
                    candidates.append((res.fun, res.x, method))
            if not candidates:
                raise RuntimeError('NLP Attack Failed!')

            best_fun, best_x, best_solver = min(candidates, key=lambda c: c[0])
            print(best_solver)
            print(best_fun)
            print(best_x)

            grads = []
            for step in range(self.search_size):
                grads.extend(self._nlp_decode(best_x, step))
            # The replay below works on private copies of the tensors.
            byz_grads = deepcopy(grads)

            self.opt_path_ful.extend(grads)
            self.opt_path.extend(grads)

            # Sanity check: the solver must not have touched the snapshot.
            # NOTE(review): compare_state_dicts is not defined in the visible
            # part of this module; presumably it is provided elsewhere in the file.
            theta2 = deepcopy(theta)
            assert compare_state_dicts(theta1, theta2), [theta1, theta2]

            # Commit the chosen gradients on the stand-in model so the next
            # window starts from the resulting state.
            for t in range(start, end):
                offset = self.f * (t % self.search_size)
                self._nlp_step(standin_model, standin_optm, self.agg, datasets,
                               velocity_lst, t, byz_grads[offset:offset + self.f])
                weight *= self.args1.nlpobj

            theta = standin_model.state_dict()
            state = standin_optm.state_dict()

        print('Last theta:', theta, len(self.opt_path))
        for w in nlpattackers:
            w.opt_path_ful = deepcopy(self.opt_path_ful)
            w.opt_path = deepcopy(self.opt_path)


class NLPAttack(ByzantineWorker):
    """
    Args:

    """

    def __init__(self, save_dir, momentum, n, f, T,
                 max_batches_per_epoch, search_size,
                 agg, args1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = momentum
        self.n = n
        self.f = f
        self.T = T
        self.max_batches_per_epoch = max_batches_per_epoch
        if not search_size:
            self.search_size = self.T
        else:
            self.search_size = search_size
        assert self.T % self.search_size == 0
        self.agg = agg
        self.args1 = args1
        self.opt_path_ful = []
        self.opt_path = []
        self.datasets = []

        self.save_dir = save_dir
        self.byz_grads = []

    def get_gradient(self):
        self.byz_grads.append(str(self._gradient.clone().detach().cpu().tolist()[0]) + ' ' + str(
            self._gradient.clone().detach().cpu().tolist()[1]))
        save_txt(self.byz_grads, self.save_dir)
        return self._gradient

    def omniscient_callback(self):
        if self.opt_path_ful:
            pass
        else:
            # apply optimal path to all Byzantine workers
            nlpattackers = []
            for w in self.simulator.workers:
                if isinstance(w, NLPAttack):
                    nlpattackers.append(w)

            opt_path = []
            for w in nlpattackers:
                if len(w.opt_path_ful) > len(self.opt_path_ful):
                    opt_path = deepcopy(w.opt_path_ful)
                    break
            if opt_path:
                self.opt_path_ful = opt_path
                self.opt_path = deepcopy(opt_path)
                pass
            else:
                datasets = []
                if self.datasets:
                    datasets = self.datasets
                else:
                    honest_data_loaders = []
                    for w in self.simulator.workers:
                        if not isinstance(w, ByzantineWorker):
                            w.data_loader.sampler.set_epoch(0)
                            honest_data_loaders.append(w.data_loader)
                    for epoch in range(1, self.T + 1):
                        datas = []
                        for data_loader in honest_data_loaders:
                            mini_batches = extract_mini_batches(data_loader, self.max_batches_per_epoch)
                            datas.append(mini_batches)
                        datasets.append(datas)
                        for data_loader in honest_data_loaders:
                            data_loader.sampler.set_epoch(epoch)

                    for w in self.simulator.workers:
                        w.data_loader.sampler.set_epoch(0)

                    reshaped_datasets = []
                    num_honest_workers = len(honest_data_loaders)

                    for _ in range(num_honest_workers):
                        reshaped_datasets.append([])

                    for epoch_data in datasets:
                        for j in range(num_honest_workers):
                            reshaped_datasets[j].extend(epoch_data[j])

                    for worker_data in reshaped_datasets:
                        if worker_data:
                            first_element = worker_data[0]
                            worker_data.append(first_element)

                    datasets = reshaped_datasets
                    self.datasets = datasets

                # do the NLP solver work
                standin_model = extract_initial_parameters(self.model, self.device)
                optimizer1_state_dict = self.optimizer.state_dict()
                lr = self.optimizer.param_groups[0]['lr']
                standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
                standin_optm.load_state_dict(deepcopy(optimizer1_state_dict))

                initial_theta = standin_model.state_dict()
                init_state = standin_optm.state_dict()

                labor_model = extract_initial_parameters(standin_model, self.device)
                optimizer2_state_dict = standin_optm.state_dict()
                lr = standin_optm.param_groups[0]['lr']
                labor_optm = type(self.optimizer)(labor_model.parameters(), lr=lr)
                labor_optm.load_state_dict(deepcopy(optimizer2_state_dict))


                def fun(weight, init_theta_epoch, init_state_dict, init_velocitylst_epoch, aggr,
                        start=0, end=self.T):

                    def v(g):
                        agg = deepcopy(aggr)
                        labor_model.load_state_dict(deepcopy(init_theta_epoch))
                        labor_optm.load_state_dict(deepcopy(init_state_dict))
                        velocity_lst = deepcopy(init_velocitylst_epoch)
                        coe = deepcopy(weight)
                        losses = []

                        # for t in range(self.T):
                        for t in range(start, end):
                            grads = []
                            samples = []
                            for k in range(self.n - self.f):
                                mini_batches = datasets[k]
                                y_batch = mini_batches[t].to(self.device)
                                # print('y_batch', k, y_batch)
                                samples.append(mini_batches[t])
                                labor_optm.zero_grad()
                                output = labor_model(y_batch)
                                loss = self.loss_func(output, y_batch)
                                loss.backward()
                                gradient = torch.cat(
                                    [param.grad.clone().detach().flatten() for param in labor_model.parameters() if
                                     param.grad is not None])
                                # print("nonconvex402", gradient)

                                velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                                grads.append(velocity_lst[k].clone().detach())
                                # grads.append(gradient)
                            for k in range(self.f):
                                grads.append(torch.from_numpy(np.array([g[2 * (t % self.search_size)], g[2 * (t % self.search_size) + 1]])).float().to(self.device))
                            # print("grads:426", g.dtype, grads[0].dtype, grads[-1].dtype)
                            aggregated = agg(grads)

                            beg = 0
                            for group in labor_optm.param_groups:
                                for p in group["params"]:
                                    if p.grad is None:
                                        continue
                                    ter = beg + len(p.grad.view(-1))
                                    x = aggregated[beg:ter].reshape_as(p.grad.data)
                                    p.grad.data = x.clone().detach()
                                    beg = ter
                            labor_optm.step()

                            loss = 0
                            with torch.no_grad():
                                for sam in samples:
                                    y_batch = sam.to(self.device)
                                    output = labor_model(y_batch)
                                    loss += self.loss_func(output, y_batch)
                                # print('type(loss)', type(loss), loss.dtype)
                                losses.append(coe * (loss / (self.n - self.f)))
                                coe *= self.args1.nlpobj

                        reward = np.mean(np.array([ls.cpu().numpy() for ls in losses]))
                        print(self.args1.seed, self.args1.agg, self.args1.f, start, reward, reward.dtype)
                        # print(t, start, -reward)
                        return -reward

                    return v

                def fun_last(weight, init_theta_epoch, init_state_dict, init_velocitylst_epoch, aggr,
                        start=0, end=self.T):

                    def v(g):
                        agg = deepcopy(aggr)
                        labor_model.load_state_dict(deepcopy(init_theta_epoch))
                        labor_optm.load_state_dict(deepcopy(init_state_dict))
                        velocity_lst = deepcopy(init_velocitylst_epoch)
                        coe = deepcopy(weight)
                        losses = []

                        # for t in range(self.T):
                        for t in range(start, end):
                            grads = []
                            samples = []
                            for k in range(self.n - self.f):
                                mini_batches = datasets[k]
                                y_batch = mini_batches[t].to(self.device)
                                # print('y_batch', k, y_batch)
                                samples.append(mini_batches[t])
                                labor_optm.zero_grad()
                                output = labor_model(y_batch)
                                loss = self.loss_func(output, y_batch)
                                loss.backward()
                                gradient = torch.cat(
                                    [param.grad.clone().detach().flatten() for param in labor_model.parameters() if
                                     param.grad is not None])
                                # print("nonconvex402", gradient)

                                velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                                grads.append(velocity_lst[k].clone().detach())
                                # grads.append(gradient)
                            for k in range(self.f):
                                grads.append(torch.from_numpy(np.array([g[2 * (t % self.search_size)], g[2 * (t % self.search_size) + 1]])).float().to(self.device))
                            aggregated = agg(grads)

                            beg = 0
                            for group in labor_optm.param_groups:
                                for p in group["params"]:
                                    if p.grad is None:
                                        continue
                                    ter = beg + len(p.grad.view(-1))
                                    x = aggregated[beg:ter].reshape_as(p.grad.data)
                                    p.grad.data = x.clone().detach()
                                    beg = ter
                            labor_optm.step()

                            loss = 0
                            with torch.no_grad():
                                for sam in samples:
                                    y_batch = sam.to(self.device)
                                    output = labor_model(y_batch)
                                    loss += self.loss_func(output, y_batch)
                                # print('type(loss)', type(loss), loss.dtype)
                        reward = loss / (self.n - self.f)
                        print(self.args1.seed, self.args1.agg, self.args1.f, start, reward)
                        return -reward

                    return v

                def fun_lower(weight, init_theta_epoch, init_state_dict, init_velocitylst_epoch, aggr,
                        start=0, end=self.T):

                    def v(g):
                        agg = deepcopy(aggr)
                        labor_model.load_state_dict(deepcopy(init_theta_epoch))
                        labor_optm.load_state_dict(deepcopy(init_state_dict))
                        velocity_lst = deepcopy(init_velocitylst_epoch)
                        coe = deepcopy(weight)
                        losses = []

                        # for t in range(self.T):
                        for t in range(start, end):
                            grads = []
                            samples = []
                            for k in range(self.n - self.f):
                                mini_batches = datasets[k]
                                y_batch = mini_batches[t].to(self.device)
                                # print('y_batch', k, y_batch)
                                samples.append(mini_batches[t])
                                labor_optm.zero_grad()
                                output = labor_model(y_batch)
                                loss = self.loss_func(output, y_batch)
                                loss.backward()
                                gradient = torch.cat(
                                    [param.grad.clone().detach().flatten() for param in labor_model.parameters() if
                                     param.grad is not None])
                                # print("nonconvex402", gradient)

                                velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                                grads.append(velocity_lst[k].clone().detach())
                                # grads.append(gradient)
                            for k in range(self.f):
                                grads.append(torch.from_numpy(np.array([g[2 * (t % self.search_size)], g[2 * (t % self.search_size) + 1]])).float().to(self.device))
                            aggregated = agg(grads)

                            beg = 0
                            for group in labor_optm.param_groups:
                                for p in group["params"]:
                                    if p.grad is None:
                                        continue
                                    ter = beg + len(p.grad.view(-1))
                                    x = aggregated[beg:ter].reshape_as(p.grad.data)
                                    p.grad.data = x.clone().detach()
                                    beg = ter
                            labor_optm.step()

                            loss = 0
                            with torch.no_grad():
                                for sam in samples:
                                    y_batch = sam.to(self.device)
                                    output = labor_model(y_batch)
                                    loss += self.loss_func(output, y_batch)
                                # print('type(loss)', type(loss), loss.dtype)
                                losses.append(loss / (self.n - self.f))

                        reward = np.min(np.array([ls.cpu().numpy() for ls in losses]))
                        print(self.args1.seed, self.args1.agg, self.args1.f, start, reward)
                        return -reward

                    return v

                bounds = None

                cost = fun_last if self.args1.nlpobj == 0 else fun_lower if self.args1.nlpobj < 0 else fun

                weight = 1
                lambdas = []
                theta = deepcopy(initial_theta)
                state = deepcopy(init_state)
                velocity_lst = [torch.zeros(sum([param.data.numel() for param in standin_model.parameters()])).to(
                    self.device)] * (self.n - self.f)

                for p in range(int(self.T / self.search_size)):
                    start = p * self.search_size
                    end = start + self.search_size

                    theta1 = deepcopy(theta)

                    # g0 = np.ones(2 * self.search_size) * -1
                    # g0 = np.concatenate([np.ones(int(0.2 * self.search_size)) * -0.7, np.zeros(int(1.8 * self.search_size))])
                    if self.args1.agg == 'cm' or self.args1.agg == 'tm':
                        g0 = np.concatenate(
                            [np.ones(int(0.2 * self.search_size)) * 0, np.ones(int(0.1 * self.search_size)) * 0, np.ones(int(1.7 * self.search_size)) * 0])
                    elif self.args1.agg == 'rfa':
                        g0 = np.concatenate(
                            [np.ones(int(0.2 * self.search_size)) * 0, np.ones(int(1.8 * self.search_size)) * 0])
                    else:
                        g0 = np.concatenate(
                            [np.ones(int(0.2 * self.search_size)) * 0, np.ones(int(1.8 * self.search_size)) * -0.1])
                    funs = []
                    xs = []
                    solver = []
                    # for method in ['Nelder-Mead', 'Powell', 'CG', 'BFGS', 'L-BFGS-B', 'TNC', 'COBYLA', 'SLSQP', 'trust-constr']:
                    for method in ['Powell']:
                        res = minimize(cost(weight, theta, state, velocity_lst, self.agg, start, end),
                                       g0, method=method, bounds=bounds)
                        if res.success:
                            funs.append(res.fun)
                            xs.append(res.x)
                            solver.append(method)
                    if not funs:
                        raise RuntimeError('NLP Attack Failed!')
                    ind = funs.index(min(funs))
                    best_x = xs[ind]
                    best_solver = solver[ind]
                    print(best_solver)
                    print(funs[ind])
                    print(best_x)
                    grads = []
                    for t in range(self.search_size):
                        grads.append(torch.from_numpy(np.array([best_x[2*t], best_x[2*t+1]])).float().to(self.device))
                    byz_grads: list[Any] = deepcopy(grads)
                    # print("byz_grads:620", byz_grads[-1].dtype)

                    self.opt_path_ful.extend(grads)
                    self.opt_path.extend(grads)

                    theta2 = deepcopy(theta)
                    assert compare_state_dicts(theta1, theta2), [theta1, theta2]

                    for t in range(start, end):
                        grads = []
                        for k in range(self.n - self.f):
                            mini_batches = datasets[k]
                            y_batch = mini_batches[t].to(self.device)
                            # print('y_batch', k, y_batch)
                            standin_optm.zero_grad()
                            output = standin_model(y_batch)
                            loss = self.loss_func(output, y_batch)
                            loss.backward()
                            gradient = torch.cat(
                                [param.grad.clone().detach().flatten() for param in standin_model.parameters() if
                                 param.grad is not None])

                            velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                            grads.append(velocity_lst[k].clone().detach())

                        for k in range(self.f):
                            grads.append(byz_grads[t % self.search_size])

                        aggregated = self.agg(grads)

                        beg = 0
                        for group in standin_optm.param_groups:
                            for p in group["params"]:
                                if p.grad is None:
                                    continue
                                ter = beg + len(p.grad.view(-1))
                                x = aggregated[beg:ter].reshape_as(p.grad.data)
                                p.grad.data = x.clone().detach()
                                beg = ter
                        standin_optm.step()
                        weight *= self.args1.nlpobj

                    theta = standin_model.state_dict()
                    state = standin_optm.state_dict()

                print('Last theta:', theta, len(self.opt_path))
                for w in nlpattackers:
                    w.opt_path_ful = deepcopy(self.opt_path_ful)
                    w.opt_path = deepcopy(self.opt_path)

        mal_grad = self.opt_path.pop(0)
        # print("nonlinear:670:mal_grad", mal_grad.dtype)
        self._gradient = mal_grad

    def set_gradient(self, gradient) -> None:
        """Unsupported for this worker: always raises ``NotImplementedError``."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Unsupported for this worker: always raises ``NotImplementedError``."""
        raise NotImplementedError


class NOBLEAttack(ByzantineWorker):
    """Omniscient attacker that plans its whole path of mixing factors up front.

    At every step the worker submits
    ``(1 - lambda_t) * g_avg + lambda_t * (-g_avg)`` where ``g_avg`` is the
    mean of the honest workers' gradients for that step.  The scalars
    ``lambda_t`` are computed once, on the first ``omniscient_callback``, by
    simulating training on copies of the model/optimizer and maximising the
    simulated honest loss with SciPy's Powell method, ``search_size`` steps
    at a time.  The resulting path is copied to every other ``NOBLEAttack``
    worker of the simulator.

    ``self.model``, ``self.optimizer``, ``self.device``, ``self.loss_func``
    and ``self.simulator`` are read but not set here; they are presumably
    provided by ``ByzantineWorker`` / the simulator (TODO confirm).

    Args:
        save_dir: Path handed to ``save_txt`` when logging submitted
            gradients.  With its last 14 characters removed it is also used
            as the directory of ``lambdas.txt``.
        momentum: Coefficient ``m`` of the simulated honest update
            ``v = m * v + (1 - m) * grad``.
        n: Total number of workers; ``n - f`` honest workers are simulated.
        f: Number of Byzantine gradients injected per simulated step.
        T: Number of training steps to plan.
        max_batches_per_epoch: Batches pulled from each honest loader per
            epoch when caching data.
        search_size: Number of consecutive steps optimised jointly.  A falsy
            value means ``T``.  Must divide ``T``.
        agg: Aggregator, called with a list of flat gradient tensors.
        args1: Namespace from which ``seed``, ``agg`` (aggregator name),
            ``f`` and ``nlpobj`` are read.
        *args, **kwargs: Forwarded to ``ByzantineWorker.__init__``.
    """

    def __init__(self, save_dir, momentum, n, f, T,
                 max_batches_per_epoch, search_size,
                 agg, args1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = momentum
        self.n = n
        self.f = f
        self.T = T
        self.max_batches_per_epoch = max_batches_per_epoch
        # A falsy search_size means "optimise all T steps jointly".
        self.search_size = search_size if search_size else self.T
        assert self.T % self.search_size == 0
        self.agg = agg
        self.args1 = args1
        # Full planned path of lambdas, and the not-yet-consumed remainder.
        self.opt_path_ful = []
        self.opt_path = []
        # Cached honest mini-batches, indexed as datasets[worker][step].
        self.datasets = []

        self.save_dir = save_dir
        # One text line per gradient handed out by ``get_gradient``.
        self.byz_grads = []

    def get_gradient(self):
        """Log the first two entries of the malicious gradient and return it.

        The whole log is passed to ``save_txt`` on every call.
        """
        values = self._gradient.clone().detach().cpu().tolist()
        self.byz_grads.append(str(values[0]) + ' ' + str(values[1]))
        save_txt(self.byz_grads, self.save_dir)
        return self._gradient

    @staticmethod
    def _mix_with_negation(honest_avg, lambda_t):
        """Interpolate between ``honest_avg`` (lambda 0) and its negation (lambda 1)."""
        return (1 - lambda_t) * honest_avg + lambda_t * (-honest_avg)

    def _collect_honest_batches(self):
        """Return ``[worker][step]`` mini-batches drawn from the honest loaders.

        For each of ``T`` epochs, ``max_batches_per_epoch`` batches are taken
        from every non-Byzantine worker's loader.  Sampler epochs are advanced
        while reading and reset to 0 on every worker afterwards.
        """
        honest_loaders = []
        for w in self.simulator.workers:
            if not isinstance(w, ByzantineWorker):
                w.data_loader.sampler.set_epoch(0)
                honest_loaders.append(w.data_loader)

        per_worker = [[] for _ in honest_loaders]
        for epoch in range(1, self.T + 1):
            for stream, loader in zip(per_worker, honest_loaders):
                stream.extend(extract_mini_batches(loader, self.max_batches_per_epoch))
            for loader in honest_loaders:
                loader.sampler.set_epoch(epoch)

        for w in self.simulator.workers:
            w.data_loader.sampler.set_epoch(0)

        # Pad every stream with a copy of its first batch.
        for stream in per_worker:
            if stream:
                stream.append(stream[0])
        return per_worker

    def _simulate_step(self, model, optm, velocities, datasets, t, lambda_t, agg):
        """Run one simulated training step under attack and return its batches.

        ``velocities`` (one momentum buffer per honest worker) is updated in
        place, and ``model`` is advanced through ``optm.step()`` using the
        aggregate of the honest momenta plus ``f`` malicious vectors.
        """
        honest, batches = [], []
        for k in range(self.n - self.f):
            batch = datasets[k][t]
            batches.append(batch)
            y_batch = batch.to(self.device)
            optm.zero_grad()
            loss = self.loss_func(model(y_batch), y_batch)
            loss.backward()
            gradient = torch.cat(
                [param.grad.clone().detach().flatten() for param in model.parameters()
                 if param.grad is not None])
            velocities[k] = velocities[k] * self.momentum + (1 - self.momentum) * gradient
            honest.append(velocities[k].clone().detach())

        honest_avg = torch.stack(honest, dim=0).mean(dim=0)
        grads = honest + [self._mix_with_negation(honest_avg, lambda_t) for _ in range(self.f)]
        aggregated = agg(grads)

        # Scatter the flat aggregate back into the per-parameter gradients.
        beg = 0
        for group in optm.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                ter = beg + len(param.grad.view(-1))
                param.grad.data = aggregated[beg:ter].reshape_as(param.grad.data).clone().detach()
                beg = ter
        optm.step()
        return batches

    def _mean_loss(self, model, batches):
        """Return the summed loss of ``model`` on ``batches`` divided by ``n - f``."""
        loss = 0
        with torch.no_grad():
            for sam in batches:
                y_batch = sam.to(self.device)
                loss += self.loss_func(model(y_batch), y_batch)
        return loss / (self.n - self.f)

    def _initial_guess(self):
        """Return the hand-tuned Powell starting point for one segment."""
        size = self.search_size
        if self.args1.agg == 'rfa':
            head_len, head, tail = size // 5 + 1, 0.0, 0.0
        elif self.args1.agg == 'krum':
            head_len = size // 10 + 1
            head, tail = (0.1, 0.0) if size > 4 else (0.0, 1.0)
        else:
            head_len = size // 5 + 1
            if size > self.T / 2:
                head, tail = 0.0, 1.0
            elif size > 4:
                head, tail = 0.17, 1.0
            else:
                head, tail = -0.5, 1.0
        return np.concatenate([np.full(head_len, head), np.full(size - head_len, tail)])

    def _plan_attack_path(self, datasets):
        """Optimise the lambdas segment by segment and store them on ``self``.

        Extends ``self.opt_path_ful`` and ``self.opt_path`` and writes the
        full path to ``lambdas.txt``.

        Raises:
            RuntimeError: if the solver reports failure on a segment.
        """
        # ``standin`` tracks the trajectory actually committed so far;
        # ``labor`` is the scratch copy that the objective resets and rolls out.
        standin_model = extract_initial_parameters(self.model, self.device)
        lr = self.optimizer.param_groups[0]['lr']
        standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
        standin_optm.load_state_dict(deepcopy(self.optimizer.state_dict()))

        initial_theta = standin_model.state_dict()
        init_state = standin_optm.state_dict()

        labor_model = extract_initial_parameters(standin_model, self.device)
        lr = standin_optm.param_groups[0]['lr']
        labor_optm = type(self.optimizer)(labor_model.parameters(), lr=lr)
        labor_optm.load_state_dict(deepcopy(standin_optm.state_dict()))

        # nlpobj == 0: loss after the last step of the segment;
        # nlpobj  < 0: smallest per-step loss of the segment;
        # nlpobj  > 0: mean of per-step losses weighted by nlpobj ** step.
        if self.args1.nlpobj == 0:
            mode = 'last'
        elif self.args1.nlpobj < 0:
            mode = 'min'
        else:
            mode = 'discounted'

        def make_objective(weight, init_theta_epoch, init_state_dict, init_velocitylst_epoch,
                           aggr, start, end):
            """Build the function SciPy minimises (negative reward) for steps [start, end)."""

            def objective(g):
                # Every evaluation restarts from the same snapshot, including a
                # fresh copy of the aggregator.
                agg = deepcopy(aggr)
                labor_model.load_state_dict(deepcopy(init_theta_epoch))
                labor_optm.load_state_dict(deepcopy(init_state_dict))
                velocities = deepcopy(init_velocitylst_epoch)
                coe = weight
                losses = []

                for t in range(start, end):
                    batches = self._simulate_step(labor_model, labor_optm, velocities, datasets,
                                                  t, g[t % self.search_size], agg)
                    step_loss = self._mean_loss(labor_model, batches)
                    if mode == 'discounted':
                        losses.append(coe * step_loss)
                        coe *= self.args1.nlpobj
                    else:
                        losses.append(step_loss)

                if mode == 'last':
                    reward = losses[-1]
                    print(self.args1.seed, self.args1.agg, self.args1.f, start, reward)
                    # Hand SciPy a plain float rather than a torch tensor, which
                    # it cannot convert when the tensor is not on the CPU.
                    return -reward.item()

                values = np.array([ls.cpu().numpy() for ls in losses])
                if mode == 'discounted':
                    reward = np.mean(values)
                    print(self.args1.seed, self.args1.agg, self.args1.f, start, reward, reward.dtype)
                else:
                    reward = np.min(values)
                    print(self.args1.seed, self.args1.agg, self.args1.f, start, reward)
                return -reward

            return objective

        bounds = None
        weight = 1
        theta = deepcopy(initial_theta)
        state = deepcopy(init_state)
        n_params = sum([param.data.numel() for param in standin_model.parameters()])
        velocity_lst = [torch.zeros(n_params).to(self.device) for _ in range(self.n - self.f)]

        for segment in range(int(self.T / self.search_size)):
            start = segment * self.search_size
            end = start + self.search_size

            theta1 = deepcopy(theta)
            g0 = self._initial_guess()

            funs, xs, solver = [], [], []
            for method in ['Powell']:
                res = minimize(make_objective(weight, theta, state, velocity_lst, self.agg, start, end),
                               g0, method=method, bounds=bounds)
                if res.success:
                    funs.append(res.fun)
                    xs.append(res.x)
                    solver.append(method)
            if not funs:
                raise RuntimeError('NLP Attack Failed!')
            ind = funs.index(min(funs))
            best_x = xs[ind]
            print(solver[ind])
            print(funs[ind])
            print(best_x)

            byz_lambdas = [best_x[t] for t in range(self.search_size)]
            self.opt_path_ful.extend(byz_lambdas)
            self.opt_path.extend(byz_lambdas)

            # Sanity check: the search must not have altered the committed snapshot.
            theta2 = deepcopy(theta)
            assert compare_state_dicts(theta1, theta2), [theta1, theta2]

            # Replay the segment with the chosen lambdas on the stand-in copies so
            # the next segment starts from the resulting state.  This uses
            # ``self.agg`` itself, not a copy.
            for t in range(start, end):
                self._simulate_step(standin_model, standin_optm, velocity_lst, datasets,
                                    t, byz_lambdas[t % self.search_size], self.agg)
                weight *= self.args1.nlpobj

            theta = standin_model.state_dict()
            state = standin_optm.state_dict()

        # NOTE(review): the slice drops the last 14 characters of save_dir,
        # presumably a fixed-length file name -- confirm against the caller.
        save_lambda = self.save_dir[:-14] + '/lambdas.txt'
        save_txt([str(la) for la in self.opt_path_ful], save_lambda)
        print('Last theta:', theta, len(self.opt_path))

    def omniscient_callback(self):
        """Make sure a lambda path exists, then build this step's malicious gradient.

        On the first call the path is either copied from another
        ``NOBLEAttack`` worker that already has one, or computed here and
        copied to all ``NOBLEAttack`` workers.  Each call then pops the next
        lambda, so ``IndexError`` is raised once the path is exhausted.
        """
        if not self.opt_path_ful:
            nlpattackers = [w for w in self.simulator.workers if isinstance(w, NOBLEAttack)]
            donor = next((w for w in nlpattackers
                          if len(w.opt_path_ful) > len(self.opt_path_ful)), None)
            if donor is not None:
                self.opt_path_ful = deepcopy(donor.opt_path_ful)
                self.opt_path = deepcopy(self.opt_path_ful)
            else:
                if not self.datasets:
                    self.datasets = self._collect_honest_batches()
                # do the NLP solver work
                self._plan_attack_path(self.datasets)
                # apply optimal path to all Byzantine workers
                for w in nlpattackers:
                    w.opt_path_ful = deepcopy(self.opt_path_ful)
                    w.opt_path = deepcopy(self.opt_path)

        # Extract the next action (lambda_t) from the optimal path and determine
        # the malicious gradient accordingly.
        lambda_t = self.opt_path.pop(0)
        gradients = [w.get_gradient() for w in self.simulator.workers
                     if not isinstance(w, ByzantineWorker)]
        honest_avg = torch.mean(torch.stack(gradients, 0), 0)
        self._gradient = self._mix_with_negation(honest_avg, lambda_t)

    def set_gradient(self, gradient) -> None:
        """Unsupported for this worker: always raises ``NotImplementedError``."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Unsupported for this worker: always raises ``NotImplementedError``."""
        raise NotImplementedError


class TAWFOEAttack(ByzantineWorker):
    """
    Byzantine worker that submits a scaled copy of the honest mean gradient.

    On every ``omniscient_callback`` the worker averages the gradients of all
    non-Byzantine workers and sets its own gradient to ``lambda * honest_avg``.
    ``lambda`` is found with a golden-section search that minimises the
    negated distance between the aggregate (honest gradients plus ``f``
    copies of the candidate) and the honest mean.  The search interval is
    ``[-20, 0]`` unless the history window is full and the mean dot product
    between the current honest mean and the stored ones is positive, in
    which case ``[0, 20]`` is searched instead.

    Args:
        epsilon: Stored on the instance; not read anywhere in this class.
        n: Stored on the instance; not read anywhere in this class.
        f: Number of copies of the malicious gradient appended to the honest
            gradients whenever an aggregate is computed.
        save_dir: Passed as the second argument to ``save_txt`` each time
            ``get_gradient`` logs the current gradient.
        agg: Callable applied to a list of gradient tensors; its result is
            compared against the honest mean in ``cost``.
        *args: Forwarded to ``ByzantineWorker.__init__``.
        **kwargs: Forwarded to ``ByzantineWorker.__init__``.
    """

    def __init__(self, epsilon, n, f, save_dir, agg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon
        self.n = n
        self.f = f
        self._gradient = None
        self.save_dir = save_dir
        self.agg = agg
        # Maximum number of past honest means kept in ``honest_hist``.
        self.window_size = 100
        # One "<first> <second>" string per ``get_gradient`` call.
        self.byz_grads = []
        # Most recent honest mean gradients, oldest first.
        self.honest_hist = []

    def get_gradient(self):
        """Log the first two entries of the current gradient and return it.

        The two entries are appended to ``self.byz_grads`` as a single
        space-separated string, and the whole list is handed to ``save_txt``
        together with ``self.save_dir`` on every call.

        Returns:
            The tensor stored in ``self._gradient``.
        """
        # Convert once; ``tolist`` already yields plain Python objects, so no
        # clone is needed to keep the stored tensor untouched.
        values = self._gradient.detach().cpu().tolist()
        self.byz_grads.append(f"{values[0]} {values[1]}")
        save_txt(self.byz_grads, self.save_dir)
        return self._gradient

    def omniscient_callback(self):
        """Compute this round's malicious gradient from the honest workers.

        Reads ``self.simulator.workers`` (presumably provided by the
        ``ByzantineWorker`` base class -- confirm), updates
        ``self.honest_hist`` and ``self._gradient``, and finally calls
        ``self.agg`` on the honest gradients plus ``self.f`` copies of the
        malicious one.
        """
        # Collect the gradients of every non-Byzantine worker.
        gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        honest_avg = torch.mean(torch.stack(gradients, 0), 0)

        # Alignment of the current honest mean with the stored history
        # (the current mean is appended only after this check).
        dot_products = [
            torch.sum(honest_avg * past).item() for past in self.honest_hist
        ]
        avg_dot_product = sum(dot_products) / len(dot_products) if dot_products else 0
        print("Average Dot Product:", avg_dot_product, len(self.honest_hist))

        # Pick the interval first so only one search is run: every cost
        # evaluation deep-copies the aggregator, and a search whose result
        # would be overwritten is pure overhead.
        if avg_dot_product > 0 and len(self.honest_hist) >= self.window_size:
            lower, upper = [0], [20]
        else:
            lower, upper = [-20], [0]
        best_lambda = self.golden_section_search(
            self.cost(honest_avg, gradients), lower, upper, 0.1)

        # Update the sliding window of honest means.
        self.honest_hist.append(honest_avg)
        if len(self.honest_hist) > self.window_size:
            self.honest_hist.pop(0)

        self._gradient = best_lambda * honest_avg
        gradients.extend([self._gradient] * self.f)
        # Return value is unused; presumably this advances any state kept
        # inside the aggregator -- TODO confirm.
        self.agg(gradients)

    def cost(self, honest_avg, honest_grads):
        """Build the objective minimised by ``golden_section_search``.

        Args:
            honest_avg: Mean of the honest gradients.
            honest_grads: List of the individual honest gradients.

        Returns:
            A function ``v(g)`` where ``g[0]`` is the candidate multiplier.
            It returns the negated norm of ``aggregate - honest_avg``, the
            aggregate being computed over the honest gradients plus
            ``self.f`` copies of ``g[0] * honest_avg``.
        """
        def v(g):
            # Work on deep copies so an evaluation never touches
            # ``self.agg`` or the caller's tensors and list.
            aggr = deepcopy(self.agg)
            honest = deepcopy(honest_avg)
            grads = deepcopy(honest_grads)
            for _ in range(self.f):
                grads.append(g[0] * honest)

            aggregated = aggr(grads)

            return -torch.norm(aggregated - honest)
        return v

    def golden_section_search(self, f, a, b, tol):
        """Golden-section search for a minimiser of ``f`` on ``[a[0], b[0]]``.

        Args:
            f: Objective; called with a one-element list ``[x]``.
            a: One-element list holding the lower bound.
            b: One-element list holding the upper bound.
            tol: Stop once the two interior probe points are within ``tol``.

        Returns:
            The midpoint of the final bracket as a plain number.
        """
        phi = (1 + 5 ** 0.5) / 2  # golden ratio
        lo, hi = a[0], b[0]

        def probes(lo, hi):
            # Interior points that split the bracket in the golden ratio.
            step = (hi - lo) / phi
            return hi - step, lo + step

        c, d = probes(lo, hi)
        while abs(c - d) > tol:
            # Keep the sub-interval that contains the smaller probe value.
            if f([c]) < f([d]):
                hi = d
            else:
                lo = c
            c, d = probes(lo, hi)

        return (hi + lo) / 2

    def set_gradient(self, gradient) -> None:
        """Unsupported on this worker: always raises ``NotImplementedError``."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Unsupported on this worker: always raises ``NotImplementedError``."""
        raise NotImplementedError


def compare_state_dicts(dict1, dict2):
    """Return True when both state dicts hold the same keys and equal tensors.

    Args:
        dict1: Mapping from name to tensor.
        dict2: Mapping from name to tensor.

    Returns:
        ``False`` if the key sets differ or if ``torch.equal`` reports a
        mismatch for any shared key; ``True`` otherwise (including when both
        mappings are empty).
    """
    # Compare key sets up front so the check is symmetric: extra entries in
    # ``dict2`` are no longer ignored and missing ones no longer raise KeyError.
    if dict1.keys() != dict2.keys():
        return False
    return all(torch.equal(dict1[key], dict2[key]) for key in dict1)
