from copy import deepcopy
import json

import numpy as np
import torch
from scipy.optimize import Bounds
from scipy.optimize import OptimizeResult
from scipy.optimize import minimize

from codes.components.worker import ByzantineWorker
from codes.components.utils import save_txt
# torch.set_printoptions(threshold=10)

class OptimizerCallback:
    """Track the best objective value seen by an optimizer and stop stalled runs.

    ``wrapped_cost`` is meant to be passed to the optimizer as the objective and
    ``callback`` as its per-iteration callback. Between two callback invocations
    the wrapper remembers the best point evaluated; the callback then folds that
    into the overall best and counts iterations without improvement.

    Args:
        cost_function: Callable mapping a parameter vector ``g`` to a scalar cost.
        initial_value: Sentinel used as the "no value yet" cost (default ``inf``).
        search_size: Length of the parameter vector; also sets the patience to
            ``200 * search_size ** 2`` callback invocations.
    """

    def __init__(self, cost_function, initial_value=float('inf'), search_size=1):
        self.cost_function = cost_function
        self.init_value = initial_value
        self.current_fun_value = initial_value
        # Best cost / point seen since the previous callback invocation.
        self.best_fun_value_iter = initial_value
        self.best_g_iter = np.zeros(search_size)
        # Best cost / point seen over the whole run.
        self.best_res = initial_value
        self.best_g = np.zeros(search_size)
        self.search_size = search_size
        self.iter_without_improvement = 0
        self.max_iter_without_improvement = 200 * search_size ** 2

    def wrapped_cost(self, g):
        """Evaluate the cost at ``g`` and remember the best point of this iteration.

        Returns:
            The value returned by ``cost_function(g)``, unchanged.
        """
        self.current_fun_value = self.cost_function(g)
        if self.current_fun_value <= self.best_fun_value_iter:
            self.best_fun_value_iter = self.current_fun_value
            # Store a private copy: the caller owns ``g`` and may reuse or
            # mutate that buffer after this call returns.
            self.best_g_iter = np.array(g, copy=True)
        return self.current_fun_value

    def callback(self, g):
        """Merge the iteration's best point into the global best.

        Args:
            g: Current iterate supplied by the optimizer (not used; the best
                evaluated point recorded by ``wrapped_cost`` is used instead).

        Raises:
            EarlyTerminationException: After ``max_iter_without_improvement``
                consecutive invocations without a strictly better cost.
        """
        f = self.best_fun_value_iter

        if f < self.best_res:
            self.best_res = f
            self.best_g = self.best_g_iter
            self.iter_without_improvement = 0
        else:
            self.iter_without_improvement += 1

        # Reset the per-iteration trackers for the next optimizer iteration.
        self.best_fun_value_iter = self.init_value
        self.best_g_iter = np.zeros(self.search_size)

        if self.iter_without_improvement >= self.max_iter_without_improvement:
            raise EarlyTerminationException(f'No improvement for {self.max_iter_without_improvement} '
                                            f'iterations, stopping optimization.')


class EarlyTerminationException(Exception):
    """Signal that the optimization should stop because it has stopped improving."""


class SSNLPAttack(ByzantineWorker):
    """
    Implementation of the SSNLP Attack in a federated learning environment. This attacker is a
    Byzantine Worker that manipulates its gradients to affect the global model adversely.

    Args:
        n (int): Total number of workers.
        f (int): Number of Byzantine (malicious) workers.
        T (int): The total number of epochs.
        save_dir (str): Directory in which the solved lambda path is written as ``lambdas.txt``.
        max_batches_per_epoch (int): Maximum number of mini-batches to be processed in one epoch.
        search_size (int, optional): Length of each segment to search for an optimal attack. If not provided, it defaults to max_batches_per_epoch.
        agg (function): Aggregation function used in the federated learning setup.
        args1: Additional arguments related to the attack mechanism.
        *args: Variable length argument list for the superclass.
        **kwargs: Arbitrary keyword arguments for the superclass.
    """

    def __init__(self, n, f, T, save_dir,
                 max_batches_per_epoch, search_size,
                 agg, args1, *args, **kwargs):
        """Store the attack configuration and initialise empty solver caches.

        ``*args`` and ``**kwargs`` are forwarded unchanged to the parent
        constructor. After that call the code reads ``self.model`` and
        ``self.optimizer``, so the parent is expected to have set them.
        """
        super().__init__(*args, **kwargs)

        # Federation layout, schedule and output location.
        self.n, self.f, self.T = n, f, T
        self.save_dir = save_dir
        self.agg = agg
        self.max_batches_per_epoch = max_batches_per_epoch
        # A falsy search_size (None or 0) falls back to max_batches_per_epoch.
        self.search_size = search_size if search_size else max_batches_per_epoch
        self.args1 = args1

        # Solved attack coefficients: the whole history and the not-yet-consumed part.
        self.opt_path_ful = []
        self.opt_path = []

        # Lazily filled cache of mini-batches, plus RNG bookkeeping.
        self.datasets = []
        self.random_states = {}
        self.init_states = []

        # Snapshots kept only for checking correctness of the implementation.
        self.theta_check = deepcopy(self.model.state_dict())
        self.state_check = deepcopy(self.optimizer.state_dict())

    def get_gradient(self):
        """
        Fetch the gradient manipulated by the Byzantine worker.

        Returns:
            Tensor: The manipulated gradient of the worker.
        """
        return self._gradient

    def extract_subdataset(self, train_loader, max_batches_per_epoch):
        """
        Extracts mini-batches from a data loader.

        Args:
            train_loader (torch.utils.data.DataLoader): The data loader from which batches are extracted.

        Returns:
            list[tuple]: A list of mini-batches.
        """
        mini_batches = []
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            mini_batches.append((X_batch, y_batch))
        return mini_batches

    def extract_mini_batches(self, train_loader, max_batches_per_epoch):
        """
        Extracts a limited number of mini-batches from a data loader.

        Args:
            train_loader (torch.utils.data.DataLoader): The data loader from which batches are extracted.
            max_batches_per_epoch (int): The maximum number of batches to extract.

        Returns:
            list[tuple]: A list of mini-batches.
        """
        mini_batches = []
        pass_flag = 0
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            mini_batches.append((X_batch, y_batch))
            if len(mini_batches) >= max_batches_per_epoch:
                pass_flag = 1
                break

        # Ensure that the desired number of mini-batches were extracted
        assert pass_flag, [len(mini_batches), max_batches_per_epoch]

        return mini_batches

    def extract_initial_parameters(self, model, device):
        """
        Extracts the initial parameters and gradients of a given model into a new
        model of the same type without modifying the original model.

        Args:
            model (torch.nn.Module): The input model from which parameters and gradients will be extracted.
            device (torch.device): The device on which the new model will reside.

        Returns:
            torch.nn.Module: A new model initialized with the parameters and gradients of the input model.
        """
        # self.cache_random_state()
        # self.restore_random_state()
        # Create a new model of the same type and move it to the specified device
        standin_model = type(model)().to(device)

        # Extract the state dictionary of the original model
        state_dict_model = model.state_dict()

        # Load the state dictionary of the original model into the stand-in model
        standin_model.load_state_dict(state_dict_model)

        # Extract gradients from the original model
        grads1 = {name: param.grad for name, param in model.named_parameters()}

        # Manually copy the gradients to the stand-in model
        for name, param in standin_model.named_parameters():
            if grads1[name] is not None:
                param.grad = grads1[name].clone()
        # self.restore_random_state()

        return standin_model

    def cache_and_restore_init_state(self) -> None:
        flag = False
        for w in self.simulator.workers:
            if w.worker_rank == 0:
                self.taw_init_states = w.taw_init_states
                self.random_states = w.taw_init_states
                # self.restore_random_state()
                flag = True
        assert flag, "Fail to restore init state"

    def cache_init_states(self) -> None:
        for w in self.simulator.workers:
            self.init_states.append(w.taw_init_states)

    def _restore_init_state(self, r_states) -> None:
        if self.use_cuda:
            if torch.cuda.is_available():
                torch.cuda.set_rng_state(r_states["torch_cuda"])
            elif torch.backends.mps.is_available():
                pass
                # torch.backends.mps.manual_seed(self.random_states["torch_cuda"])
        # random.setstate(self.random_states["random"])
        torch.set_rng_state(r_states["torch"])
        np.random.set_state(r_states["numpy"])

    def cache_random_state(self) -> None:
        """Snapshot the current RNG states into ``self.random_states``.

        Keys written: ``"torch"`` and ``"numpy"`` always; ``"torch_cuda"`` only
        when ``self.use_cuda`` is set and CUDA is available.
        """
        snapshot = self.random_states
        if self.use_cuda:
            if not torch.cuda.is_available():
                # Availability is still probed, but no state is cached for MPS.
                torch.backends.mps.is_available()
            else:
                snapshot["torch_cuda"] = torch.cuda.get_rng_state()
        snapshot["torch"] = torch.get_rng_state()
        snapshot["numpy"] = np.random.get_state()

    def restore_random_state(self) -> None:
        """Reset the torch (and, when enabled, CUDA) and NumPy RNGs from ``self.random_states``."""
        cached = self.random_states
        if self.use_cuda:
            if not torch.cuda.is_available():
                # Availability is still probed, but nothing is restored for MPS.
                torch.backends.mps.is_available()
            else:
                torch.cuda.set_rng_state(cached["torch_cuda"])
        torch.set_rng_state(self.random_states["torch"])
        np.random.set_state(self.random_states["numpy"])

    # Define the objective function for the NLP solver
    def fun(self, weight, init_theta_epoch, init_state_dict, init_velocitylst_epoch, aggr, labor_model,
            labor_optm, datasets, start, end):
        """Build the weighted mean-loss objective for iterations ``start .. end - 1``.

        Args:
            weight: Initial multiplier applied to the first iteration's loss; it is
                multiplied by ``self.args1.nlpobj`` after every iteration.
            init_theta_epoch: Model state dict reloaded at the start of every evaluation.
            init_state_dict: Optimizer state dict reloaded at the start of every evaluation.
            init_velocitylst_epoch: Per-worker momentum vectors (indexed by worker) to start from.
            aggr: Aggregator; a deep copy is called on the list of gradients.
            labor_model: Scratch model the simulation runs on (mutated in place).
            labor_optm: Optimizer bound to ``labor_model`` (mutated in place).
            datasets: ``datasets[k][itr]`` is the ``(X, y)`` batch of worker ``k`` at iteration ``itr``.
            start: First simulated iteration (inclusive).
            end: Last simulated iteration (exclusive).

        Returns:
            A callable ``v(g)`` returning minus the mean of the weighted per-iteration losses.
        """

        def v(g):
            """Simulate training under attack coefficients ``g`` and return the negated reward."""
            # Every evaluation restarts from the same snapshot, so v(g) depends on g only
            # (up to any randomness inside the model or aggregator).
            agg = deepcopy(aggr)
            labor_model.load_state_dict(deepcopy(init_theta_epoch))
            labor_optm.load_state_dict(deepcopy(init_state_dict))
            velocity_lst = deepcopy(init_velocitylst_epoch)
            coe = deepcopy(weight)
            losses = []
            for itr in range(start, end):
                labor_model.train()
                grads = []
                samples = []
                honest_grads = []
                for k in range(self.n):
                    mini_batches = datasets[k]
                    X_batch, y_batch = mini_batches[itr]
                    X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                    samples.append(mini_batches[itr])
                    # At itr == start the supplied velocities are used as they are;
                    # gradients are only recomputed for later iterations.
                    if itr != start:
                        labor_optm.zero_grad()
                        output = labor_model(X_batch)
                        loss = self.loss_func(output, y_batch)
                        loss.backward()
                        # Flatten all parameter gradients into a single vector.
                        gradient = torch.cat(
                            [param.grad.clone().detach().flatten() for param in labor_model.parameters() if
                             param.grad is not None])

                        # Momentum update: v <- momentum * v + (1 - momentum) * grad.
                        velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                    # Only the first n - f workers contribute as honest workers.
                    if k < self.n - self.f:
                        grads.append(velocity_lst[k].clone().detach())
                        honest_grads.append(velocity_lst[k].clone().detach())

                honest_avg = torch.stack(honest_grads, dim=0).mean(dim=0)
                # Each of the f attackers submits (1 - lambda) * avg + lambda * (-avg),
                # i.e. (1 - 2 * lambda) * avg, with lambda taken from g cyclically.
                for k in range(self.f):
                    lambda_t = g[itr % self.search_size]
                    grads.append((1 - lambda_t) * honest_avg + lambda_t * (-honest_avg))
                aggregated = agg(grads)

                # Scatter the flat aggregated vector back into the parameters' .grad buffers.
                beg = 0
                for group in labor_optm.param_groups:
                    for p in group["params"]:
                        if p.grad is None:
                            continue
                        ter = beg + len(p.grad.view(-1))
                        x = aggregated[beg:ter].reshape_as(p.grad.data)
                        p.grad.data = x.clone().detach()
                        beg = ter
                if self.clipping:
                    torch.nn.utils.clip_grad_norm_(labor_model.parameters(), max_norm=self.clipping)
                labor_optm.step()

                # Evaluate the updated model on this iteration's batches.
                # NOTE(review): `samples` holds one batch per worker (n of them) while the
                # sum is divided by n - f — confirm this normalisation is intended.
                loss = 0
                labor_model.eval()
                with torch.no_grad():
                    for sam in samples:
                        X_batch, y_batch = sam
                        X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                        output = labor_model(X_batch)
                        loss += self.loss_func(output, y_batch)
                    losses.append(coe * (loss / (self.n - self.f)))
                    coe *= self.args1.nlpobj

            # Move the per-iteration losses to host memory and average them.
            reward = np.mean(np.array([ls.cpu().numpy() for ls in losses]))
            print(self.args1.seed, self.args1.agg, self.args1.f, start, reward, reward.dtype)
            # The caller minimises, so the reward (loss to be maximised) is negated.
            return -reward

        return v

    def fun_last(self, weight, init_theta_epoch, init_state_dict, init_velocitylst_epoch, aggr, labor_model,
                 labor_optm, datasets, start, end):

        def v(g):
            # self.restore_random_state()
            agg = deepcopy(aggr)
            labor_model.load_state_dict(deepcopy(init_theta_epoch))
            labor_optm.load_state_dict(deepcopy(init_state_dict))
            velocity_lst = deepcopy(init_velocitylst_epoch)
            coe = deepcopy(weight)
            losses = []
            loss = 0
            # for t in range(T):
            for itr in range(start, end):
                labor_model.train()
                grads = []
                samples = []
                honest_grads = []
                for k in range(self.n - self.f):
                    mini_batches = datasets[k]
                    X_batch, y_batch = mini_batches[itr]
                    X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                    # samples.append(mini_batches[itr + 1])
                    samples.append(mini_batches[itr])
                    labor_optm.zero_grad()
                    output = labor_model(X_batch)
                    loss = self.loss_func(output, y_batch)
                    loss.backward()
                    gradient = torch.cat(
                        [param.grad.clone().detach().flatten() for param in labor_model.parameters() if
                         param.grad is not None])

                    velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                    grads.append(velocity_lst[k].clone().detach())
                    honest_grads.append(velocity_lst[k].clone().detach())

                honest_avg = torch.stack(honest_grads, dim=0).mean(dim=0)
                # print("honest_avg:301", t, itr, honest_avg.shape)
                for k in range(self.f):
                    lambda_t = g[itr % self.search_size]
                    grads.append((1 - lambda_t) * honest_avg + lambda_t * (-honest_avg))
                # print("grads:305", grads[0].shape, grads[-1].shape)
                aggregated = agg(grads)

                beg = 0
                for group in labor_optm.param_groups:
                    for p in group["params"]:
                        if p.grad is None:
                            continue
                        ter = beg + len(p.grad.view(-1))
                        x = aggregated[beg:ter].reshape_as(p.grad.data)
                        p.grad.data = x.clone().detach()
                        beg = ter
                if self.clipping:
                    torch.nn.utils.clip_grad_norm_(labor_model.parameters(), max_norm=self.clipping)
                labor_optm.step()

                loss = 0
                labor_model.eval()
                with torch.no_grad():
                    for sam in samples:
                        X_batch, y_batch = sam
                        X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                        output = labor_model(X_batch)
                        loss += self.loss_func(output, y_batch)
            losses = loss / (self.n - self.f)
            # print(self.args1.seed, self.args1.agg, self.args1.f, start, losses)
            # self.restore_random_state()
            return -losses

        return v

    def fun_lower(self, weight, init_theta_epoch, init_state_dict, init_velocitylst_epoch, aggr, labor_model,
                  labor_optm, datasets, start, end):
        """Build the "minimum loss" objective for iterations ``start .. end - 1``.

        Args:
            weight: Copied into a local that is never read; has no effect on the result.
            init_theta_epoch: Model state dict reloaded at the start of every evaluation.
            init_state_dict: Optimizer state dict reloaded at the start of every evaluation.
            init_velocitylst_epoch: Per-worker momentum vectors (indexed by worker) to start from.
            aggr: Aggregator; a deep copy is called on the list of gradients.
            labor_model: Scratch model the simulation runs on (mutated in place).
            labor_optm: Optimizer bound to ``labor_model`` (mutated in place).
            datasets: ``datasets[k][itr]`` is the ``(X, y)`` batch of worker ``k`` at iteration ``itr``.
            start: First simulated iteration (inclusive).
            end: Last simulated iteration (exclusive).

        Returns:
            A callable ``v(g)`` returning minus the smallest per-iteration average loss.
        """

        def v(g):
            """Simulate training under attack coefficients ``g``; return minus the minimum loss."""
            # Every evaluation restarts from the same snapshot.
            agg = deepcopy(aggr)
            labor_model.load_state_dict(deepcopy(init_theta_epoch))
            labor_optm.load_state_dict(deepcopy(init_state_dict))
            velocity_lst = deepcopy(init_velocitylst_epoch)
            coe = deepcopy(weight)  # unused below
            losses = []
            for itr in range(start, end):
                labor_model.train()
                grads = []
                samples = []
                honest_grads = []
                # NOTE(review): unlike `fun`, gradients are recomputed for every iteration
                # including itr == start, and only honest workers are visited — confirm.
                for k in range(self.n - self.f):
                    mini_batches = datasets[k]
                    X_batch, y_batch = mini_batches[itr]
                    X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                    samples.append(mini_batches[itr])
                    labor_optm.zero_grad()
                    output = labor_model(X_batch)
                    loss = self.loss_func(output, y_batch)
                    loss.backward()
                    # Flatten all parameter gradients into a single vector.
                    gradient = torch.cat(
                        [param.grad.clone().detach().flatten() for param in labor_model.parameters() if
                         param.grad is not None])

                    # Momentum update: v <- momentum * v + (1 - momentum) * grad.
                    velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                    grads.append(velocity_lst[k].clone().detach())
                    honest_grads.append(velocity_lst[k].clone().detach())

                honest_avg = torch.stack(honest_grads, dim=0).mean(dim=0)
                # Each attacker submits (1 - lambda) * avg + lambda * (-avg).
                for k in range(self.f):
                    lambda_t = g[itr % self.search_size]
                    grads.append((1 - lambda_t) * honest_avg + lambda_t * (-honest_avg))
                aggregated = agg(grads)

                # Scatter the flat aggregated vector back into the parameters' .grad buffers.
                beg = 0
                for group in labor_optm.param_groups:
                    for p in group["params"]:
                        if p.grad is None:
                            continue
                        ter = beg + len(p.grad.view(-1))
                        x = aggregated[beg:ter].reshape_as(p.grad.data)
                        p.grad.data = x.clone().detach()
                        beg = ter
                if self.clipping:
                    torch.nn.utils.clip_grad_norm_(labor_model.parameters(), max_norm=self.clipping)
                labor_optm.step()

                # Average loss of the updated model over this iteration's honest batches.
                loss = 0
                labor_model.eval()
                with torch.no_grad():
                    for sam in samples:
                        X_batch, y_batch = sam
                        X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                        output = labor_model(X_batch)
                        loss += self.loss_func(output, y_batch)
                    losses.append(loss / (self.n - self.f))
            # The reward is the smallest loss over the window, taken on the host.
            reward = np.min(np.array([ls.cpu().numpy() for ls in losses]))
            # The caller minimises, so the reward is negated.
            return -reward

        return v

    def omniscient_callback(self):
        """Solve (when needed) for the next segment of attack coefficients and craft the gradient.

        When ``self.opt_path`` is empty, a new segment of ``search_size`` lambdas is
        obtained, either by adopting a longer path from another ``SSNLPAttack``
        worker or by running the SciPy solver on a simulated copy of the training
        process. The next lambda is then popped from ``self.opt_path`` and the
        malicious gradient ``(1 - lambda) * avg + lambda * (-avg)`` of the honest
        workers' average gradient is stored in ``self._gradient``.
        """
        # Lambdas left over from a previous solve: nothing to compute.
        if self.opt_path:
            pass
        else:
            # Identify workers that are NLP attackers
            nlpattackers = []
            for w in self.simulator.workers:
                if isinstance(w, SSNLPAttack):
                    nlpattackers.append(w)

            # Adopt the path of the first attacker that has solved further than this one.
            opt_path = []
            for w in nlpattackers:
                if len(w.opt_path_ful) > len(self.opt_path_ful):
                    opt_path = deepcopy(w.opt_path_ful)
                    break
            # If we've found an optimal path, set it for the current instance
            if opt_path:
                # NOTE(review): opt_path becomes a copy of the *whole* path, not only the
                # part beyond this instance's current length — confirm this is intended.
                self.opt_path_ful = opt_path
                self.opt_path = deepcopy(opt_path)
                pass
            else:
                # Get datasets for the NLP problem; if not available, compute them
                datasets = []
                if self.datasets:
                    datasets = self.datasets
                else:
                    # Rewind every worker's sampler to epoch 0 and remember its loader.
                    data_loaders = []
                    for w in self.simulator.workers:
                        w.data_loader.sampler.set_epoch(0)
                        data_loaders.append(w.data_loader)
                    # For each epoch, take max_batches_per_epoch batches from each loader,
                    # then advance the samplers to the next epoch.
                    for epoch in range(1, self.T + 1):
                        datas = []
                        for data_loader in data_loaders:
                            mini_batches = self.extract_mini_batches(data_loader, self.max_batches_per_epoch)
                            datas.append(mini_batches)
                        datasets.append(datas)
                        for data_loader in data_loaders:
                            data_loader.sampler.set_epoch(epoch)

                    # Put the samplers back to epoch 0 once everything has been read.
                    for w in self.simulator.workers:
                        w.data_loader.sampler.set_epoch(0)

                    # Regroup from [epoch][worker][batch] to a flat [worker][iteration] layout.
                    reshaped_datasets = []
                    num_workers = len(data_loaders)

                    for _ in range(num_workers):
                        reshaped_datasets.append([])

                    for epoch_data in datasets:
                        for j in range(num_workers):
                            reshaped_datasets[j].extend(epoch_data[j])

                    # Append each worker's first batch again as one extra trailing entry.
                    for worker_data in reshaped_datasets:
                        if worker_data:
                            first_element = worker_data[0]
                            worker_data.append(first_element)

                    datasets = reshaped_datasets
                    self.datasets = datasets

                # Stand-in model/optimizer: a copy of the real ones, later advanced by
                # replaying the chosen lambdas.
                standin_model = self.extract_initial_parameters(self.model, self.device)
                optimizer1_state_dict = self.optimizer.state_dict()
                lr = self.optimizer.param_groups[0]['lr']
                standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
                standin_optm.load_state_dict(deepcopy(optimizer1_state_dict))

                initial_theta = standin_model.state_dict()
                init_state = standin_optm.state_dict()

                # Labor model/optimizer: scratch copies mutated by every cost evaluation.
                labor_model = self.extract_initial_parameters(standin_model, self.device)
                optimizer2_state_dict = standin_optm.state_dict()
                lr = standin_optm.param_groups[0]['lr']
                labor_optm = type(self.optimizer)(labor_model.parameters(), lr=lr)
                labor_optm.load_state_dict(deepcopy(optimizer2_state_dict))

                # The search is unconstrained.
                bounds = None

                # nlpobj == 0 -> final-loss objective, nlpobj < 0 -> minimum-loss objective,
                # nlpobj > 0 -> weighted mean-loss objective.
                cost = self.fun_last if self.args1.nlpobj == 0 else self.fun_lower if self.args1.nlpobj < 0 else self.fun

                weight = 1
                theta = deepcopy(initial_theta)
                state = deepcopy(init_state)
                # Starting momentum vectors: one per simulator worker, in worker order.
                velocity_lst = []
                for w in self.simulator.workers:
                    velocity_lst.append(w.get_true_gradient())

                # The segment to solve starts right after the lambdas already found.
                theta1 = deepcopy(theta)
                start = len(self.opt_path_ful)
                end = start + self.search_size

                # Initial guess: every lambda equal to -1.
                g0 = np.ones(self.search_size) * -1
                funs = []
                xs = []
                solver = []
                # Only Powell is tried; results are collected per method.
                methods = ['Powell']
                for method in methods:
                    optimizer_callback = OptimizerCallback(cost_function=cost(weight, theta, state,
                                                                              velocity_lst, self.agg, labor_model,
                                                                              labor_optm, datasets, start, end
                                                                              ), initial_value=float('inf'),
                                                           search_size=self.search_size)
                    try:
                        res = minimize(optimizer_callback.wrapped_cost, g0,
                                       method=method, bounds=bounds, callback=optimizer_callback.callback)
                    except EarlyTerminationException as e:
                        # The callback gave up: build a result object from its best point.
                        print(str(e))
                        res = OptimizeResult()
                        res.x = optimizer_callback.best_g
                        res.fun = optimizer_callback.best_res
                        res.success = False
                        res.message = str(e)

                    print(self.args1.seed, self.args1.agg, self.args1.f, start, optimizer_callback.best_res,
                          optimizer_callback.best_g, res.message)
                    # The callback's best point is recorded, not res.x.
                    funs.append(optimizer_callback.best_res)
                    xs.append(optimizer_callback.best_g)
                    solver.append(method)
                # Keep the solution with the lowest cost among the methods tried.
                ind = funs.index(min(funs))
                best_x = xs[ind]
                best_solver = solver[ind]
                lambda_ts = []
                for itr in range(self.search_size):
                    lambda_ts.append(best_x[itr])

                """
                    Only for checking correctness of implementation
                """
                # The snapshot handed to the cost function must not have been modified.
                theta2 = deepcopy(theta)
                assert compare_theta_dicts(theta1, theta2), [theta1, theta2]

                # Replay the chosen lambdas on the stand-in model, mirroring the simulation
                # performed inside `fun` (no gradient recomputation at itr == start).
                for itr in range(start, end):
                    grads = []
                    honest_grads = []
                    for k in range(self.n):
                        if itr != start:
                            standin_model.train()
                            mini_batches = datasets[k]
                            X_batch, y_batch = mini_batches[itr]
                            X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
                            standin_optm.zero_grad()
                            output = standin_model(X_batch)
                            loss = self.loss_func(output, y_batch)
                            loss.backward()
                            # Flatten all parameter gradients into a single vector.
                            gradient = torch.cat(
                                [param.grad.clone().detach().flatten() for param in standin_model.parameters() if
                                 param.grad is not None])


                            # Momentum update: v <- momentum * v + (1 - momentum) * grad.
                            velocity_lst[k] = velocity_lst[k] * self.momentum + (1 - self.momentum) * gradient
                        # Only the first n - f workers contribute as honest workers.
                        if k < self.n - self.f:
                            grads.append(velocity_lst[k].clone().detach())
                            honest_grads.append(velocity_lst[k].clone().detach())

                    honest_avg = torch.stack(honest_grads, dim=0).mean(dim=0)
                    for k in range(self.f):
                        lambda_t = lambda_ts[itr % self.search_size]
                        grads.append((1 - lambda_t) * honest_avg + lambda_t * (-honest_avg))

                    aggregated = self.agg(grads)

                    # Scatter the flat aggregated vector back into the parameters' .grad buffers.
                    beg = 0
                    for group in standin_optm.param_groups:
                        for p in group["params"]:
                            if p.grad is None:
                                continue
                            ter = beg + len(p.grad.view(-1))
                            x = aggregated[beg:ter].reshape_as(p.grad.data)
                            p.grad.data = x.clone().detach()
                            beg = ter
                    if self.clipping:
                        torch.nn.utils.clip_grad_norm_(standin_model.parameters(), max_norm=self.clipping)
                    standin_optm.step()
                    # NOTE(review): `weight` is a local that is reset to 1 on every solve and
                    # is not read after this loop — confirm whether it should persist.
                    weight *= self.args1.nlpobj

                # Keep the replayed state for consistency checks.
                self.theta_check = standin_model.state_dict()
                self.state_check = standin_optm.state_dict()

                self.opt_path_ful.extend(lambda_ts)
                self.opt_path = deepcopy(lambda_ts)
                # Once the full path has at least T entries, write it to <save_dir>/lambdas.txt.
                if len(self.opt_path_ful) >= self.T:
                    lambdas = []
                    for la in self.opt_path_ful:
                        lambdas.append(str(la))
                    save_lambda = self.save_dir + '/lambdas.txt'
                    save_txt(lambdas, save_lambda)

                # Once solved, assign the determined optimal path to all NLP attackers
                for w in nlpattackers:
                    w.opt_path_ful = deepcopy(self.opt_path_ful)
                    w.opt_path = deepcopy(self.opt_path)

        # Extract the next action (lambda_t) from the optimal path and determine the malicious gradient accordingly
        lambda_t = self.opt_path.pop(0)
        # Average the gradients reported by every non-Byzantine worker.
        gradients = []
        for w in self.simulator.workers:
            if not isinstance(w, ByzantineWorker):
                gradients.append(w.get_gradient())

        stacked_gradients = torch.stack(gradients, 0)
        honest_avg = torch.mean(stacked_gradients, 0)
        mal_grad = (1 - lambda_t) * honest_avg + lambda_t * (-honest_avg)
        self._gradient = mal_grad

    def set_gradient(self, gradient) -> None:
        raise NotImplementedError

    def apply_gradient(self) -> None:
        raise NotImplementedError


def compare_theta_dicts(dict1, dict2):
    """Tell whether each tensor of ``dict1`` matches the same-named one in ``dict2``.

    The comparison is driven by ``dict1`` only: entries that exist solely in
    ``dict2`` are never inspected, and a key missing from ``dict2`` surfaces
    as a ``KeyError``. Tensors are compared with ``torch.equal``.

    Args:
        dict1: Mapping from name to tensor; its keys are the ones checked.
        dict2: Mapping looked up with the keys of ``dict1``.

    Returns:
        bool: ``False`` as soon as one pair differs, ``True`` otherwise
        (including when ``dict1`` is empty).
    """
    return all(
        torch.equal(tensor, dict2[name])
        for name, tensor in dict1.items()
    )


def compare_randome_state_dict(dict1, dict2):
    """Check whether two cached RNG-state dictionaries hold the same states.

    The keys of ``dict1`` drive the comparison:

    * ``'torch'`` and ``'torch_cuda'`` values are compared with
      ``torch.equal``.
    * ``'numpy'`` values are walked pairwise; when both items of a pair are
      ``np.ndarray`` they are compared with ``np.array_equal``, otherwise
      with ``!=``.
    * Any other key is skipped.

    Args:
        dict1: Mapping whose keys are iterated.
        dict2: Mapping looked up with the compared keys of ``dict1``.

    Returns:
        bool: ``False`` at the first mismatch, ``True`` if none is found.

    Raises:
        KeyError: If ``dict2`` lacks one of the compared keys.
    """
    for key in dict1:
        # Membership test on purpose: `key == 'torch' or 'torch_cuda'` is
        # always truthy (a non-empty string literal), which would route the
        # 'numpy' entry into torch.equal and make the branch below dead code.
        if key in ('torch', 'torch_cuda'):
            if not torch.equal(dict1[key], dict2[key]):
                return False
        elif key == 'numpy':
            for item1, item2 in zip(dict1[key], dict2[key]):
                if isinstance(item1, np.ndarray) and isinstance(item2, np.ndarray):
                    # Element-wise `!=` on arrays has no single truth value.
                    if not np.array_equal(item1, item2):
                        return False
                else:
                    if item1 != item2:
                        return False

    return True

# def deep_clone(obj):
#     if isinstance(obj, torch.Tensor):
#         return obj.clone().detach()
#     elif isinstance(obj, dict):
#         return {k: deep_clone(v) for k, v in obj.items()}
#     else:
#         cp = obj
#         return dp(cp)
#
# # def deep_clone_state_dict(state_dict):
# def deepcopy(state_dict):
#     return deep_clone(state_dict)


# class NLP1Attack(ByzantineWorker):
#     """
#     Class that implements the NLP1Attack, a subversion of SSNLP with
#     segment length 1
#     """
#
#     def __init__(self, n, f, T, use_cuda,
#                  max_batches_per_epoch, agg, length, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.n = n
#         self.f = f
#         self.T = T
#         self.use_cuda = use_cuda
#         self.max_batches_per_epoch = max_batches_per_epoch
#         self.agg = agg
#         self.length = length
#         self.opt_path_ful = []
#
#         self.datasets = []
#
#         self.random_states = {}
#
#     def get_gradient(self):
#         """Return the malicious gradient created by the attacker."""
#         return self._gradient
#
#     def omniscient_callback(self):
#         """
#         Main method for the attack where the malicious gradient is computed.
#         """
#         # Gather gradients and data from honest workers.
#         honest_gradients = []
#         datas = []
#         for w in self.simulator.workers:
#             if not isinstance(w, ByzantineWorker):
#                 honest_gradients.append(w.get_gradient())
#                 datas.append((w.running["data"], w.running["target"]))
#
#         # print('datas1', datas[0][-1], datas[1][-1])
#
#         stacked_gradients = torch.stack(honest_gradients, 0)
#         honest_avg = torch.mean(stacked_gradients, 0)
#
#         self.agg = deepcopy(self.simulator.aggregator)
#
#         # control the attack length: after how many iterations we stop the attack
#         # If the attack hasn't reached its designated length
#         if len(self.opt_path_ful) < self.length:
#             # apply optimal path to all Byzantine workers
#             nlp1attackers = []
#             for w in self.simulator.workers:
#                 if isinstance(w, NLP1Attack):
#                     nlp1attackers.append(w)
#
#             opt_path = []
#             for w in nlp1attackers:
#                 if len(w.opt_path_ful) > len(self.opt_path_ful):
#                     opt_path = deepcopy(w.opt_path_ful)
#                     break
#             # If an optimal path is found among them, use it
#             if opt_path:
#                 self.opt_path_ful = opt_path
#                 pass
#             else:
#                 # Compute datasets for the NLP problem
#                 datasets = []
#                 if self.datasets:
#                     datasets = self.datasets
#                 else:
#                     honest_data_loaders = []
#                     for w in self.simulator.workers:
#                         if not isinstance(w, ByzantineWorker):
#                             w.data_loader.sampler.set_epoch(0)
#                             honest_data_loaders.append(w.data_loader)
#                     for epoch in range(1, self.T + 1):
#                         datas = []
#                         for data_loader in honest_data_loaders:
#                             mini_batches = extract_mini_batches(data_loader, self.max_batches_per_epoch)
#                             datas.append(mini_batches)
#                         datasets.append(datas)
#                         for data_loader in honest_data_loaders:
#                             data_loader.sampler.set_epoch(epoch)
#
#                     for w in self.simulator.workers:
#                         w.data_loader.sampler.set_epoch(len(self.opt_path_ful) % self.max_batches_per_epoch)
#
#                     reshaped_datasets = []
#                     num_honest_workers = len(honest_data_loaders)
#
#                     for _ in range(num_honest_workers):
#                         reshaped_datasets.append([])
#
#                     for epoch_data in datasets:
#                         for j in range(num_honest_workers):
#                             reshaped_datasets[j].extend(epoch_data[j])
#
#                     for worker_data in reshaped_datasets:
#                         if worker_data:
#                             first_element = worker_data[0]
#                             worker_data.append(first_element)
#
#                     datasets = reshaped_datasets
#                     self.datasets = datasets
#
#                 # datas = [worker_data[len(self.opt_path_ful) + 1] for worker_data in datasets]
#                 datas = [worker_data[len(self.opt_path_ful)] for worker_data in datasets]
#                 # print('DATAS2', datas[0][-1], datas[1][-1])
#
#                 # Extract parameters from the model for manipulation
#                 standin_model = extract_initial_parameters(self.model, self.device)
#                 optimizer1_state_dict = self.optimizer.state_dict()
#                 lr = self.optimizer.param_groups[0]['lr']
#                 standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
#                 standin_optm.load_state_dict(deepcopy(optimizer1_state_dict))
#
#                 initial_theta = standin_model.state_dict()
#                 init_state = standin_optm.state_dict()
#
#                 # Define the objective function for the NLP solver
#                 def fun(datas, honest_grads, honest_avg, init_theta_epoch, init_state_dict, aggr):
#                     """
#                     The objective function that needs to be minimized by the NLP solver.
#                     """
#
#                     def v(g):
#                         self.cache_random_state()
#                         agg = deepcopy(aggr)
#                         standin_model.load_state_dict(deepcopy(init_theta_epoch))
#                         standin_optm.load_state_dict(deepcopy(init_state_dict))
#                         grads = deepcopy(honest_grads)
#                         standin_model.train()
#
#                         for k in range(self.f):
#                             lambda_t = g[0]
#                             grads.append((1 - lambda_t) * honest_avg + lambda_t * (-honest_avg))
#                         # print("grads:1409", g.dtype, grads[0].dtype, grads[-1].dtype)
#                         aggregated = agg(grads)
#
#                         beg = 0
#                         for group in standin_optm.param_groups:
#                             for p in group["params"]:
#                                 if p.grad is None:
#                                     continue
#                                 ter = beg + len(p.grad.view(-1))
#                                 x = aggregated[beg:ter].reshape_as(p.grad.data)
#                                 p.grad.data = x.clone().detach()
#                                 beg = ter
#                         if self.clipping:
#                             torch.nn.utils.clip_grad_norm_(standin_model.parameters(), max_norm=self.clipping)
#                         standin_optm.step()
#
#                         reward = 0
#                         standin_model.eval()
#                         with torch.no_grad():
#                             for sam in datas:
#                                 X_batch, y_batch = sam
#                                 X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
#                                 output = standin_model(X_batch)
#                                 reward += self.loss_func(output, y_batch)
#
#                         reward = reward / (self.n - self.f)
#                         reward = reward.cpu().numpy()
#                         # print(len(self.opt_path_ful), reward)
#                         # print(len(self.opt_path_ful), self.f, reward, reward.dtype)
#                         self.restore_random_state()
#                         return -reward
#
#                     return v
#
#                 bounds = None
#
#                 cost = fun
#
#                 theta = deepcopy(initial_theta)
#                 state = deepcopy(init_state)
#
#                 g0 = np.zeros(1)
#                 funs = []
#                 xs = []
#                 solver = []
#                 # for method in ['Nelder-Mead', 'Powell', 'CG', 'BFGS', 'L-BFGS-B', 'TNC', 'COBYLA', 'SLSQP', 'trust-constr']:
#                 methods = ['Powell']
#                 for method in methods:
#                     optimizer_callback = OptimizerCallback(cost_function=cost(datas, honest_gradients, honest_avg,
#                                                                               theta, state, self.agg),
#                                                            initial_value=float('inf'), search_size=1)
#                     try:
#                         res = minimize(optimizer_callback.wrapped_cost, g0,
#                                        method=method, bounds=bounds, callback=optimizer_callback.callback)
#                     except EarlyTerminationException as e:
#                         print(str(e))
#                         res = OptimizeResult()
#                         res.x = optimizer_callback.best_g
#                         res.fun = optimizer_callback.best_res
#                         res.success = False
#                         res.message = str(e)
#
#                     print(self.f, len(self.opt_path_ful), optimizer_callback.best_res, optimizer_callback.best_g, res.message)
#                     funs.append(optimizer_callback.best_res)
#                     xs.append(optimizer_callback.best_g)
#                     solver.append(method)
#                     # print("Done:", len(self.opt_path_ful), method)
#                 # if not funs:
#                 #     # raise RuntimeError('NLP Attack Fail!')
#                 #     funs.append(best_res)
#                 #     xs.append(best_g)
#                 #     solver.append(method)
#                     # print("Done:", len(self.opt_path_ful), method)
#                 ind = funs.index(min(funs))
#                 best_x = xs[ind]
#                 best_solver = solver[ind]
#                 # print(best_solver)
#                 # print(funs[ind])
#                 # print(len(self.opt_path_ful), best_x)
#
#                 grads = deepcopy(honest_gradients)
#                 for k in range(self.f):
#                     lambda_t = best_x[0]
#                     grads.append((1 - lambda_t) * honest_avg + lambda_t * (-honest_avg))
#                     # print(grads[-1].dtype)
#                 aggregated = self.agg(grads)
#
#                 self.opt_path_ful.append(best_x[0])
#         else:
#             self.opt_path_ful.append(0)
#
#         # Apply the malicious modification to the gradient
#         lambda_t = self.opt_path_ful[-1]
#         mal_grad = (1 - lambda_t) * honest_avg + lambda_t * (-honest_avg)
#         # print("nonlinear:1993:mal_grad", lambda_t, mal_grad, mal_grad.dtype)
#         self._gradient = mal_grad
#
#     def cache_random_state(self) -> None:
#         # self.random_states["random"] = random.getstate()
#         if self.use_cuda:
#             if torch.cuda.is_available():
#                 self.random_states["torch_cuda"] = torch.cuda.get_rng_state()
#             elif torch.backends.mps.is_available():
#                 pass
#                 # You cannot cache the MPS RNG state, but you can set a seed.
#                 # self.random_states["torch_cuda"] = random.randint(0, 2 ** 32 - 1)
#                 # torch.backends.mps.manual_seed(self.random_states["torch_cuda"])
#         self.random_states["torch"] = torch.get_rng_state()
#         self.random_states["numpy"] = np.random.get_state()
#
#     def restore_random_state(self) -> None:
#         if self.use_cuda:
#             if torch.cuda.is_available():
#                 torch.cuda.set_rng_state(self.random_states["torch_cuda"])
#             elif torch.backends.mps.is_available():
#                 pass
#                 # torch.backends.mps.manual_seed(self.random_states["torch_cuda"])
#         # random.setstate(self.random_states["random"])
#         torch.set_rng_state(self.random_states["torch"])
#         np.random.set_state(self.random_states["numpy"])
#
#     def set_gradient(self, gradient) -> None:
#         """Method to set the gradient. Not implemented for the NLP1Attack."""
#         raise NotImplementedError
#
#     def apply_gradient(self) -> None:
#         """Method to apply the gradient. Not implemented for the NLP1Attack."""
#         raise NotImplementedError


# class NLP1AttackLC(ByzantineWorker):
#
#     def __init__(self, n, f, T, max_batches_per_epoch, agg, length, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.n = n
#         self.f = f
#         self.T = T
#         self.max_batches_per_epoch = max_batches_per_epoch
#         self.agg = agg
#         self.length = length
#         self.opt_path_ful = []
#
#         self.datasets = []
#
#     def get_gradient(self):
#         return self._gradient
#
#     def omniscient_callback(self):
#         honest_gradients = []
#         datas = []
#         for w in self.simulator.workers:
#             if not isinstance(w, ByzantineWorker):
#                 honest_gradients.append(w.get_gradient())
#                 datas.append((w.running["data"], w.running["target"]))
#
#         stacked_gradients = torch.stack(honest_gradients, 0)
#         honest_avg = torch.mean(stacked_gradients, 0)
#
#         self.agg = deepcopy(self.simulator.aggregator)
#
#         # control the attack length: after how many iterations we stop the attack
#         if len(self.opt_path_ful) < self.length:
#             # apply optimal path to all Byzantine workers
#             stabattackers = []
#             for w in self.simulator.workers:
#                 if isinstance(w, NLP1AttackLC):
#                     stabattackers.append(w)
#
#             opt_path = []
#             for w in stabattackers:
#                 if len(w.opt_path_ful) > len(self.opt_path_ful):
#                     opt_path = deepcopy(w.opt_path_ful)
#                     break
#             if opt_path:
#                 self.opt_path_ful = opt_path
#                 pass
#             else:
#                 # do the NLP solver work
#                 datasets = []
#                 if self.datasets:
#                     datasets = self.datasets
#                 else:
#                     honest_data_loaders = []
#                     for w in self.simulator.workers:
#                         if not isinstance(w, ByzantineWorker):
#                             w.data_loader.sampler.set_epoch(0)
#                             honest_data_loaders.append(w.data_loader)
#                     for epoch in range(1, self.T + 1):
#                         datas = []
#                         for data_loader in honest_data_loaders:
#                             mini_batches = extract_mini_batches(data_loader, self.max_batches_per_epoch)
#                             datas.append(mini_batches)
#                         datasets.append(datas)
#                         for data_loader in honest_data_loaders:
#                             data_loader.sampler.set_epoch(epoch)
#
#                     for w in self.simulator.workers:
#                         w.data_loader.sampler.set_epoch(len(self.opt_path_ful) % self.max_batches_per_epoch)
#
#                     reshaped_datasets = []
#                     num_honest_workers = len(honest_data_loaders)
#
#                     for _ in range(num_honest_workers):
#                         reshaped_datasets.append([])
#
#                     for epoch_data in datasets:
#                         for j in range(num_honest_workers):
#                             reshaped_datasets[j].extend(epoch_data[j])
#
#                     for worker_data in reshaped_datasets:
#                         if worker_data:
#                             first_element = worker_data[0]
#                             worker_data.append(first_element)
#
#                     datasets = reshaped_datasets
#                     self.datasets = datasets
#
#                 # datas = [worker_data[len(self.opt_path_ful) + 1] for worker_data in datasets]
#                 datas = [worker_data[len(self.opt_path_ful)] for worker_data in datasets]
#
#                 standin_model = extract_initial_parameters(self.model, self.device)
#                 optimizer1_state_dict = self.optimizer.state_dict()
#                 lr = self.optimizer.param_groups[0]['lr']
#                 standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
#                 standin_optm.load_state_dict(deepcopy(optimizer1_state_dict))
#
#                 initial_theta = standin_model.state_dict()
#                 init_state = standin_optm.state_dict()
#
#                 def fun(datas, honest_grads, honest_avg, init_theta_epoch, init_state_dict, aggr):
#
#                     def v(g):
#                         agg = deepcopy(aggr)
#                         standin_model.load_state_dict(deepcopy(init_theta_epoch))
#                         standin_optm.load_state_dict(deepcopy(init_state_dict))
#                         grads = deepcopy(honest_grads)
#
#                         byzantine_grad = sum(a * b for a, b in zip(g, grads))
#                         for k in range(self.f):
#                             grads.append(deepcopy(byzantine_grad))
#                         aggregated = agg(grads)
#
#                         beg = 0
#                         for group in standin_optm.param_groups:
#                             for p in group["params"]:
#                                 if p.grad is None:
#                                     continue
#                                 ter = beg + len(p.grad.view(-1))
#                                 x = aggregated[beg:ter].reshape_as(p.grad.data)
#                                 p.grad.data = x.clone().detach()
#                                 beg = ter
#                         if self.clipping:
#                             torch.nn.utils.clip_grad_norm_(standin_model.parameters(), max_norm=self.clipping)
#                         standin_optm.step()
#
#                         reward = 0
#                         with torch.no_grad():
#                             for sam in datas:
#                                 X_batch, y_batch = sam
#                                 output = standin_model(X_batch)
#                                 reward += self.loss_func(output, y_batch)
#
#                         reward = reward / (self.n - self.f)
#                         reward = reward.cpu().numpy()
#                         # print(reward.dtype, g.dtype)
#                         return -reward
#
#                     return v
#
#                 bounds = None
#
#                 cost = fun
#
#                 theta = deepcopy(initial_theta)
#                 state = deepcopy(init_state)
#
#                 g0 = np.zeros(len(honest_gradients))
#                 funs = []
#                 xs = []
#                 solver = []
#                 # for method in ['Nelder-Mead', 'Powell', 'CG', 'BFGS', 'L-BFGS-B', 'TNC', 'COBYLA', 'SLSQP', 'trust-constr']:
#                 methods = ['Powell']
#                 for method in methods:
#                     best_res = 0
#                     best_g = np.zeros(1)
#                     iter_without_improvement = 0
#                     max_iter_without_improvement = 200
#
#                     def callback(g):
#                         nonlocal best_res, best_g, iter_without_improvement
#                         f = cost(datas, honest_gradients, honest_avg, theta, state, self.agg)(g)
#
#                         # if np.isnan(f):
#                         #     print("Warning: f is nan. Skipping this iteration.")
#                         #     return
#
#                         if f < best_res:
#                             best_res = f
#                             best_g = g
#                             iter_without_improvement = 0
#                         else:
#                             iter_without_improvement += 1
#                         if iter_without_improvement >= max_iter_without_improvement:
#                             raise Exception('No improvement for 200 iterations, stopping optimization.')
#
#                     try:
#                         res = minimize(cost(datas, honest_gradients, honest_avg, theta, state, self.agg),
#                                        g0, method=method, bounds=bounds, callback=callback)
#                     except Exception as e:
#                         # print(str(e))  # Output the exception message
#                         res = OptimizeResult()
#                         res.x = best_g  # The best solution found
#                         res.fun = best_res  # The function value at the best solution
#                         res.success = True  # The optimization did not converge, but still see it as success
#                     if res.success:
#                         funs.append(best_res)
#                         xs.append(best_g)
#                         solver.append(method)
#                     # print("Done:", len(self.opt_path_ful), method)
#                 if not funs:
#                     g0 = best_g
#                     res = minimize(cost(datas, honest_gradients, honest_avg, theta, state, self.agg), g0,
#                                    method='Powell', bounds=bounds, callback=callback)
#                     # raise RuntimeError('NLP Attack Fail!')
#                     funs.append(best_res)
#                     xs.append(best_g)
#                     solver.append(method)
#                     # print("Done:", len(self.opt_path_ful), method)
#                 ind = funs.index(min(funs))
#                 best_x = xs[ind]
#                 best_solver = solver[ind]
#                 # print(best_solver)
#                 # print(funs[ind])
#                 # print(best_x)
#
#                 self.opt_path_ful.append(best_x)
#         else:
#             self.opt_path_ful.append(np.ones(len(honest_gradients)) / len(honest_gradients))
#
#         lambda_ts = self.opt_path_ful[-1]
#         # print(lambda_t)
#         mal_grad = sum(a * b for a, b in zip(lambda_ts, honest_gradients))
#         # print("nonlinear:118:mal_grad", mal_grad)
#         self._gradient = mal_grad
#
#     def set_gradient(self, gradient) -> None:
#         raise NotImplementedError
#
#     def apply_gradient(self) -> None:
#         raise NotImplementedError


# class EstNLP1Attack(ByzantineWorker):
#
#     def __init__(self, n, f, T, max_batches_per_epoch, agg, length, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.n = n
#         self.f = f
#         self.T = T
#         self.max_batches_per_epoch = max_batches_per_epoch
#         self.agg = agg
#         self.length = length
#         self.opt_path_ful = []
#
#         self.datasets = []
#
#     def get_gradient(self):
#         return self._gradient
#
#     def omniscient_callback(self):
#         honest_gradients = []
#         for w in self.simulator.workers:
#             if not isinstance(w, ByzantineWorker):
#                 honest_gradients.append(w.get_gradient())
#         honest_num = len(honest_gradients)
#
#         # print('datas1', datas[0][-1], datas[1][-1])
#
#         stacked_gradients = torch.stack(honest_gradients, 0)
#         honest_avg = torch.mean(stacked_gradients, 0)
#
#         self.agg = deepcopy(self.simulator.aggregator)
#
#         # control the attack length: after how many iterations we stop the attack
#         if len(self.opt_path_ful) < self.length:
#             # apply optimal path to all Byzantine workers
#             estnlp1attackers = []
#             for w in self.simulator.workers:
#                 if isinstance(w, EstNLP1Attack):
#                     estnlp1attackers.append(w)
#
#             opt_path = []
#             for w in estnlp1attackers:
#                 if len(w.opt_path_ful) > len(self.opt_path_ful):
#                     opt_path = deepcopy(w.opt_path_ful)
#                     break
#             if opt_path:
#                 self.opt_path_ful = opt_path
#                 pass
#             else:
#                 # do the NLP solver work
#                 datasets = []
#                 if self.datasets:
#                     datasets = self.datasets
#                 else:
#                     byzantine_data_loaders = []
#                     for w in self.simulator.workers:
#                         if isinstance(w, ByzantineWorker):
#                             w.data_loader.sampler.set_epoch(0)
#                             byzantine_data_loaders.append(w.data_loader)
#
#                     datas = []
#                     for data_loader in byzantine_data_loaders:
#                         mini_batches = extract_subdataset(data_loader, self.max_batches_per_epoch)
#                         datas.append(mini_batches)
#                     datasets = datas
#
#                     for w in self.simulator.workers:
#                         w.data_loader.sampler.set_epoch(len(self.opt_path_ful) % self.max_batches_per_epoch)
#
#                     reshaped_datasets = []
#
#                     for ds in datasets:
#                         for d in ds:
#                             reshaped_datasets.append(d)
#
#                     datasets = reshaped_datasets
#                     self.datasets = datasets
#                     print(len(datasets))
#
#                 g = torch.Generator()
#                 g.manual_seed(len(self.opt_path_ful))
#                 indices = torch.randperm(len(datasets), generator=g).tolist()
#
#                 datas = []
#                 for idx in range(honest_num):
#                     datas.append(datasets[indices[idx]])
#
#                 # print('DATAS2', datas[0][-1], datas[1][-1])
#
#                 standin_model = extract_initial_parameters(self.model, self.device)
#                 optimizer1_state_dict = self.optimizer.state_dict()
#                 lr = self.optimizer.param_groups[0]['lr']
#                 standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
#                 standin_optm.load_state_dict(deepcopy(optimizer1_state_dict))
#
#                 initial_theta = standin_model.state_dict()
#                 init_state = standin_optm.state_dict()
#
#                 def fun(datas, honest_grads, honest_avg, init_theta_epoch, init_state_dict, aggr):
#
#                     def v(g):
#                         agg = deepcopy(aggr)
#                         standin_model.load_state_dict(deepcopy(init_theta_epoch))
#                         standin_optm.load_state_dict(deepcopy(init_state_dict))
#                         grads = deepcopy(honest_grads)
#
#                         for k in range(self.f):
#                             lambda_t = g[0]
#                             grads.append((1 - lambda_t) * honest_avg + lambda_t * (-honest_avg))
#                             # print(grads[-1].dtype)
#                         aggregated = agg(grads)
#
#                         beg = 0
#                         for group in standin_optm.param_groups:
#                             for p in group["params"]:
#                                 if p.grad is None:
#                                     continue
#                                 ter = beg + len(p.grad.view(-1))
#                                 x = aggregated[beg:ter].reshape_as(p.grad.data)
#                                 p.grad.data = x.clone().detach()
#                                 beg = ter
#                         if self.clipping:
#                             torch.nn.utils.clip_grad_norm_(standin_model.parameters(), max_norm=self.clipping)
#                         standin_optm.step()
#
#                         reward = 0
#                         with torch.no_grad():
#                             for sam in datas:
#                                 X_batch, y_batch = sam
#                                 X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)
#                                 output = standin_model(X_batch)
#                                 reward += self.loss_func(output, y_batch)
#
#                         reward = reward / (self.n - self.f)
#                         reward = reward.cpu().numpy()
#                         # print(len(self.opt_path_ful), reward)
#                         return -reward
#
#                     return v
#
#                 bounds = None
#
#                 cost = fun
#
#                 theta = deepcopy(initial_theta)
#                 state = deepcopy(init_state)
#
#                 # Here begins the line search to replace the nonlinear programming
#                 # num_steps = 10
#                 # lambdas = np.linspace(-4, 5, num_steps)
#                 # best_lambda = 0
#                 # best_reward = float('inf')
#
#                 # for lambda_t in lambdas:
#                 #     reward = cost(datas, honest_gradients, honest_avg, theta, state, self.agg)([lambda_t])
#                 #     if reward < best_reward:
#                 #         best_reward = reward
#                 #         best_lambda = lambda_t
#                 best_lambda = self.golden_section_search(cost(datas, honest_gradients, honest_avg, theta, state, self.agg), [-4], [5], 1)
#
#                 self.opt_path_ful.append(best_lambda)
#         else:
#             self.opt_path_ful.append(0)
#
#         lambda_t = self.opt_path_ful[-1]
#         # print(self.opt_path_ful)
#         # print(lambda_t)
#         mal_grad = (1 - lambda_t) * honest_avg + lambda_t * (-honest_avg)
#         # print("nonlinear:1978:mal_grad", lambda_t, mal_grad)
#         self._gradient = mal_grad
#
#     def golden_section_search(self, f, a, b, tol):
#         phi = (1 + 5 ** 0.5) / 2  # golden ratio
#         c = [b[0] - (b[0] - a[0]) / phi]
#         d = [a[0] + (b[0] - a[0]) / phi]
#         while abs(c[0] - d[0]) > tol:
#             if f(c) < f(d):
#                 b = d
#             else:
#                 a = c
#
#             c = [b[0] - (b[0] - a[0]) / phi]
#             d = [a[0] + (b[0] - a[0]) / phi]
#
#         return (b[0] + a[0]) / 2
#
#     def set_gradient(self, gradient) -> None:
#         raise NotImplementedError
#
#     def apply_gradient(self) -> None:
#         raise NotImplementedError


class TAWFOEAttack(ByzantineWorker):
    """
    Omniscient Byzantine worker that submits a scaled copy of the honest mean gradient.

    Every round the mean gradient of the non-Byzantine workers is multiplied by a
    scalar: ``-1`` by default, or ``5`` once a full window of past honest means
    has been collected and their average dot product with the current honest
    mean is positive.

    Args:
        epsilon: stored on the instance; not read by this class.
        n: stored on the instance; not read by this class.
        f: how many copies of the malicious gradient are added next to the
            honest gradients (presumably the number of Byzantine workers --
            confirm against the caller).
        save_dir: stored on the instance; not read by this class.
        agg: aggregator, called with a list of gradient tensors.
        *args, **kwargs: forwarded unchanged to ``ByzantineWorker.__init__``.
    """

    def __init__(self, epsilon, n, f, save_dir, agg, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Constructor arguments, kept verbatim.
        self.epsilon = epsilon
        self.n = n
        self.f = f
        self.save_dir = save_dir
        self.agg = agg
        # Gradient handed out by get_gradient(); filled in by omniscient_callback().
        self._gradient = None
        # Maximum number of past honest means kept in honest_hist.
        self.window_size = 30
        self.honest_hist = []
        # Initialised here but never updated by this class.
        self.byz_grads = []
        self.dot_products = []

    def get_gradient(self):
        """Return the gradient computed by the last ``omniscient_callback`` call."""
        return self._gradient

    def omniscient_callback(self):
        """
        Build this round's malicious gradient from the honest workers' gradients.

        The scale is currently hard-coded; ``cost`` and ``golden_section_search``
        are not called from here.
        """
        # Gradients of every worker that is not a ByzantineWorker.
        gradients = [
            worker.get_gradient()
            for worker in self.simulator.workers
            if not isinstance(worker, ByzantineWorker)
        ]
        honest_avg = torch.stack(gradients, dim=0).mean(dim=0)

        # Dot product of the current honest mean with each stored past mean
        # (the current mean is only added to the history further down).
        dot_products = [(honest_avg * past).sum().item() for past in self.honest_hist]
        if dot_products:
            avg_dot_product = sum(dot_products) / len(dot_products)
        else:
            avg_dot_product = 0
        print("Average Dot Product:", avg_dot_product, len(self.honest_hist))

        # Switch from -1 to 5 only with a full window and a positive average.
        window_full = len(self.honest_hist) >= self.window_size
        best_lambda = 5 if (avg_dot_product > 0 and window_full) else -1

        # Slide the window: add the newest mean, drop the oldest if over capacity.
        self.honest_hist.append(honest_avg)
        if len(self.honest_hist) > self.window_size:
            del self.honest_hist[0]

        self._gradient = best_lambda * honest_avg
        gradients.extend(self._gradient for _ in range(self.f))
        # The aggregation result is discarded.
        # NOTE(review): presumably called for the aggregator's side effects -- confirm.
        self.agg(gradients)

    def cost(self, honest_avg, honest_grads):
        """
        Return an objective ``v(g)`` for a candidate scale ``g[0]``.

        ``v`` aggregates the honest gradients together with ``self.f`` copies of
        ``g[0] * honest_avg`` and returns the negated norm of the difference
        between the aggregate and ``honest_avg``. The aggregator and both inputs
        are deep-copied on every evaluation, so the originals are not modified.
        """
        def v(g):
            # Work on private copies of the aggregator and the inputs.
            aggregator = deepcopy(self.agg)
            reference = deepcopy(honest_avg)
            candidates = deepcopy(honest_grads)
            candidates.extend(g[0] * reference for _ in range(self.f))

            deviation = aggregator(candidates) - reference
            return -torch.norm(deviation)
        return v

    def golden_section_search(self, f, a, b, tol):
        """
        Minimise ``f`` over an interval with golden-section search.

        Args:
            f: objective, called with a one-element list ``[x]``.
            a: one-element sequence holding the lower bound.
            b: one-element sequence holding the upper bound.
            tol: the search stops once the two probe points are no further
                apart than this.

        Returns:
            The midpoint of the final interval.
        """
        phi = (1 + 5 ** 0.5) / 2  # golden ratio
        lo, hi = a[0], b[0]
        while True:
            left = hi - (hi - lo) / phi
            right = lo + (hi - lo) / phi
            if not abs(left - right) > tol:
                break
            # Keep the sub-interval that contains the smaller probe value.
            if f([left]) < f([right]):
                hi = right
            else:
                lo = left

        return (hi + lo) / 2

    def set_gradient(self, gradient) -> None:
        """Not supported by this worker."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Not supported by this worker."""
        raise NotImplementedError

