from copy import deepcopy

import numpy as np
import torch
from scipy.optimize import Bounds
from scipy.optimize import OptimizeResult
from scipy.optimize import minimize

from codes.components.worker import ByzantineWorker

# torch.set_printoptions(precision=20)
# torch.set_default_dtype(torch.float64)


def compute_jacobian(inputs, output):
    """
    Compute d output / d inputs one output column at a time.

    For each column ``i`` a backward pass is run with a one-hot
    ``grad_output`` selecting column ``i`` for every batch row, so row ``b``
    of the result holds ``d output[:, i].sum() / d inputs[b]``.  This equals
    the per-sample Jacobian only when each output row depends solely on its
    own input row (no cross-sample coupling).

    :param inputs: Batch X Size (e.g. Depth X Width X Height); must have
        ``requires_grad=True``.  Its ``.grad`` is zeroed and overwritten.
    :param output: Batch X Classes, computed from ``inputs``.
    :return: jacobian: Batch X Classes X Size, on ``inputs.device`` with
        ``inputs.dtype``.
    """
    assert inputs.requires_grad

    num_classes = output.size(1)

    # Allocate on the inputs' device/dtype so float64 or non-CUDA accelerator
    # inputs are neither down-cast nor placed on the wrong device.
    jacobian = torch.zeros(num_classes, *inputs.size(), dtype=inputs.dtype, device=inputs.device)
    grad_output = torch.zeros_like(output)

    for i in range(num_classes):
        # Gradients accumulate across backward calls; reset before each pass.
        if inputs.grad is not None:
            inputs.grad.zero_()
        grad_output.zero_()
        grad_output[:, i] = 1
        # retain_graph: the same graph is differentiated num_classes times.
        output.backward(grad_output, retain_graph=True)
        jacobian[i] = inputs.grad.detach()

    # Classes X Batch X Size -> Batch X Classes X Size
    return torch.transpose(jacobian, dim0=0, dim1=1)


def zero_gradients(inputs):
    """Reset ``inputs.grad`` to zeros in place; do nothing when it is None."""
    grad = inputs.grad
    if grad is None:
        return
    grad.zero_()


def extract_initial_parameters(model, device):
    """Return a fresh instance of ``type(model)`` on *device* with *model*'s weights and gradients.

    The copy is built with ``type(model)()``, so the model class must be
    constructible without arguments.  Weights are transferred through
    ``state_dict``; every non-None parameter gradient is cloned onto the
    matching parameter of the copy, so modifying the copy's gradients does not
    touch *model*.  Parameters whose gradient is None are left untouched.
    """
    replica = type(model)().to(device)
    replica.load_state_dict(model.state_dict())

    source_grads = dict((pname, param.grad) for pname, param in model.named_parameters())
    for pname, param in replica.named_parameters():
        source_grad = source_grads[pname]
        if source_grad is None:
            continue
        param.grad = source_grad.clone()

    return replica


class StabilityBreakingAttack(ByzantineWorker):
    """Omniscient Byzantine worker sending a scalar interpolation of the honest mean.

    Every round the malicious gradient is
    ``(1 - lambda_t) * honest_avg + lambda_t * (-honest_avg)``, where
    ``honest_avg`` is the mean of the honest workers' gradients.  For the first
    ``length`` rounds ``lambda_t`` is chosen by ``scipy.optimize.minimize`` so
    as to maximise the mean ``max - min`` spread (over dim 1) of a stand-in
    model's outputs on the honest workers' current batches, evaluated after one
    simulated aggregation + optimizer step.  Afterwards ``lambda_t = 0``, i.e.
    the honest mean is echoed.

    ``opt_path_ful`` holds the chosen ``lambda_t`` per round.  A worker whose
    path is shorter than a peer ``StabilityBreakingAttack``'s copies that
    path instead of re-solving the round.
    """

    # scipy.optimize.minimize methods tried in turn; the lowest objective wins.
    solver_methods = ('Powell',)
    # Solver iterations allowed without improving the best objective before a
    # run is stopped early (the best point found so far is still used).
    max_iter_without_improvement = 200
    # When no solver reports success, restart Powell once from the best point.
    retry_on_failure = True

    class _NoImprovement(Exception):
        """Raised by the solver callback to stop a stagnating run early."""

    def __init__(self, n, f, agg, length, *args, **kwargs):
        """
        :param n: worker count; ``n - f`` is used as the honest-worker count
            when averaging the objective.
        :param f: number of malicious gradients injected into the aggregation.
        :param agg: aggregator; replaced by a deep copy of
            ``self.simulator.aggregator`` on every ``omniscient_callback``.
        :param length: number of rounds during which ``lambda_t`` is optimised.
        """
        super().__init__(*args, **kwargs)
        self.n = n
        self.f = f
        self.agg = agg
        self.length = length
        self.opt_path_ful = []

    def get_gradient(self):
        return self._gradient

    def omniscient_callback(self):
        """Compute this round's malicious gradient and store it in ``self._gradient``."""
        honest_gradients = []
        datas = []
        for w in self.simulator.workers:
            if not isinstance(w, ByzantineWorker):
                honest_gradients.append(w.get_gradient())
                datas.append((w.running["data"], w.running["target"]))

        honest_avg = torch.mean(torch.stack(honest_gradients, 0), 0)

        self.agg = deepcopy(self.simulator.aggregator)

        # control the attack length: after `length` rounds stop optimising.
        if len(self.opt_path_ful) < self.length:
            peer_path = self._path_from_peers()
            if peer_path:
                # A peer attacker already solved this round: share its path.
                self.opt_path_ful = peer_path
            else:
                objective = self._make_objective(datas, honest_gradients, honest_avg)
                best_x = self._solve(objective, np.zeros(1))
                self.opt_path_ful.append(best_x[0])
        else:
            self.opt_path_ful.append(0)

        self._gradient = self._malicious_gradient(self.opt_path_ful[-1], honest_avg)

    @staticmethod
    def _malicious_gradient(lambda_t, honest_avg):
        """Interpolate between ``honest_avg`` (lambda 0) and ``-honest_avg`` (lambda 1)."""
        return (1 - lambda_t) * honest_avg + lambda_t * (-honest_avg)

    def _path_from_peers(self):
        """Return a copy of the first longer path held by a peer attacker, or ``[]``."""
        for w in self.simulator.workers:
            if isinstance(w, StabilityBreakingAttack) and len(w.opt_path_ful) > len(self.opt_path_ful):
                return deepcopy(w.opt_path_ful)
        return []

    def _make_objective(self, datas, honest_gradients, honest_avg):
        """Build ``objective(g)`` returning the negated reward for ``lambda_t = g[0]``.

        Each call restores a stand-in model/optimizer to this round's starting
        state, aggregates the honest gradients plus ``f`` malicious ones, takes
        one optimizer step and returns minus the summed per-batch mean output
        spread divided by ``n - f``.
        """
        standin_model = extract_initial_parameters(self.model, self.device)
        lr = self.optimizer.param_groups[0]['lr']
        standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
        standin_optm.load_state_dict(deepcopy(self.optimizer.state_dict()))

        # Snapshots restored before every evaluation so that every candidate
        # is scored from the same model/optimizer state.
        init_theta = deepcopy(standin_model.state_dict())
        init_optm_state = deepcopy(standin_optm.state_dict())
        aggregator = self.agg

        def objective(g):
            # Fresh copies: aggregators/gradients may be mutated by aggregation.
            agg = deepcopy(aggregator)
            standin_model.load_state_dict(deepcopy(init_theta))
            standin_optm.load_state_dict(deepcopy(init_optm_state))
            grads = deepcopy(honest_gradients)

            lambda_t = g[0]
            for _ in range(self.f):
                grads.append(self._malicious_gradient(lambda_t, honest_avg))
            aggregated = agg(grads)

            # Scatter the flat aggregated vector back into the parameter grads.
            beg = 0
            for group in standin_optm.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    end = beg + p.grad.numel()
                    p.grad.data = aggregated[beg:end].reshape_as(p.grad.data).clone().detach()
                    beg = end
            if self.clipping:
                torch.nn.utils.clip_grad_norm_(standin_model.parameters(), max_norm=self.clipping)
            standin_optm.step()

            reward = 0
            with torch.no_grad():
                for X_batch, _ in datas:
                    output = standin_model(X_batch)
                    spread = output.max(dim=1)[0] - output.min(dim=1)[0]
                    reward += spread.mean()

            reward = reward / (self.n - self.f)
            return -reward.cpu().numpy()

        return objective

    def _run_solver(self, objective, x0, method):
        """Run one scipy solver from *x0*; return ``(best_f, best_x, success)``.

        The best point is tracked over every objective evaluation, so the
        callback never re-runs the (expensive) objective.  Only values strictly
        below 0 count as improvements, so ``best_x`` stays at zeros if nothing
        negative is ever found.  A run stalled for
        ``max_iter_without_improvement`` iterations is stopped and counted as a
        success; any other exception propagates.
        """
        best_f = 0.0
        best_x = np.zeros_like(x0, dtype=float)
        best_at_last_iter = best_f
        stale_iters = 0

        def tracked_objective(g):
            nonlocal best_f, best_x
            value = objective(g)
            if value < best_f:
                best_f = float(value)
                # Copy: the solver may reuse the array it passes in.
                best_x = np.array(g, dtype=float)
            return value

        def callback(xk):
            nonlocal best_at_last_iter, stale_iters
            if best_f < best_at_last_iter:
                best_at_last_iter = best_f
                stale_iters = 0
            else:
                stale_iters += 1
            if stale_iters >= self.max_iter_without_improvement:
                raise self._NoImprovement(
                    f'No improvement for {self.max_iter_without_improvement} iterations, stopping optimization.')

        try:
            res = minimize(tracked_objective, x0, method=method, callback=callback)
        except self._NoImprovement:
            # Deliberate early stop: the best point found so far is kept.
            return best_f, best_x, True
        return best_f, best_x, bool(res.success)

    def _solve(self, objective, x0):
        """Return the best point found across ``solver_methods``."""
        results = []
        best_f, best_x = 0.0, np.zeros_like(x0, dtype=float)
        for method in self.solver_methods:
            best_f, best_x, success = self._run_solver(objective, x0, method)
            if success:
                results.append((best_f, best_x))

        if not results:
            if self.retry_on_failure:
                # No solver reported success: restart Powell from the last
                # run's best point and keep whichever result is better.
                retry_f, retry_x, _ = self._run_solver(objective, best_x, 'Powell')
                if retry_f < best_f:
                    best_f, best_x = retry_f, retry_x
            results.append((best_f, best_x))

        # min() returns the first of equal values, matching list.index(min(...)).
        return min(results, key=lambda r: r[0])[1]

    def set_gradient(self, gradient) -> None:
        raise NotImplementedError

    def apply_gradient(self) -> None:
        raise NotImplementedError


class StabilityBreakingAttackLC(ByzantineWorker):
    """Omniscient Byzantine worker sending a linear combination of honest gradients.

    Every round the malicious gradient is ``sum_i w_i * honest_gradients[i]``.
    For the first ``length`` rounds the weight vector ``w`` (one entry per
    honest worker) is chosen by ``scipy.optimize.minimize`` so as to maximise
    the mean ``max - min`` spread (over dim 1) of a stand-in model's outputs on
    the honest workers' current batches, evaluated after one simulated
    aggregation + optimizer step.  Afterwards ``w`` is uniform, i.e. the honest
    mean is sent.

    ``opt_path_ful`` holds the chosen weight vector per round.  A worker whose
    path is shorter than a peer ``StabilityBreakingAttackLC``'s copies that
    path instead of re-solving the round.
    """

    # scipy.optimize.minimize methods tried in turn; the lowest objective wins.
    solver_methods = ('Powell',)
    # Solver iterations allowed without improving the best objective before a
    # run is stopped early (the best point found so far is still used).
    max_iter_without_improvement = 200
    # When no solver reports success, restart Powell once from the best point.
    # Off by default here: the N-dimensional re-run is expensive.
    retry_on_failure = False

    class _NoImprovement(Exception):
        """Raised by the solver callback to stop a stagnating run early."""

    def __init__(self, n, f, agg, length, *args, **kwargs):
        """
        :param n: worker count; ``n - f`` is used as the honest-worker count
            when averaging the objective.
        :param f: number of malicious gradients injected into the aggregation.
        :param agg: aggregator; replaced by a deep copy of
            ``self.simulator.aggregator`` on every ``omniscient_callback``.
        :param length: number of rounds during which the weights are optimised.
        """
        super().__init__(*args, **kwargs)
        self.n = n
        self.f = f
        self.agg = agg
        self.length = length
        self.opt_path_ful = []

    def get_gradient(self):
        return self._gradient

    def omniscient_callback(self):
        """Compute this round's malicious gradient and store it in ``self._gradient``."""
        honest_gradients = []
        datas = []
        for w in self.simulator.workers:
            if not isinstance(w, ByzantineWorker):
                honest_gradients.append(w.get_gradient())
                datas.append((w.running["data"], w.running["target"]))

        honest_avg = torch.mean(torch.stack(honest_gradients, 0), 0)

        self.agg = deepcopy(self.simulator.aggregator)

        # control the attack length: after `length` rounds stop optimising.
        if len(self.opt_path_ful) < self.length:
            peer_path = self._path_from_peers()
            if peer_path:
                # A peer attacker already solved this round: share its path.
                self.opt_path_ful = peer_path
            else:
                objective = self._make_objective(datas, honest_gradients, honest_avg)
                best_x = self._solve(objective, np.zeros(len(honest_gradients)))
                self.opt_path_ful.append(best_x)
        else:
            num_honest = len(honest_gradients)
            self.opt_path_ful.append(np.ones(num_honest) / num_honest)

        self._gradient = self._weighted_sum(self.opt_path_ful[-1], honest_gradients)

    @staticmethod
    def _weighted_sum(weights, gradients):
        """Return ``sum_i weights[i] * gradients[i]`` (``zip`` stops at the shorter input)."""
        return sum(a * b for a, b in zip(weights, gradients))

    def _path_from_peers(self):
        """Return a copy of the first longer path held by a peer attacker, or ``[]``."""
        for w in self.simulator.workers:
            if isinstance(w, StabilityBreakingAttackLC) and len(w.opt_path_ful) > len(self.opt_path_ful):
                return deepcopy(w.opt_path_ful)
        return []

    def _make_objective(self, datas, honest_gradients, honest_avg):
        """Build ``objective(g)`` returning the negated reward for weight vector ``g``.

        Each call restores a stand-in model/optimizer to this round's starting
        state, aggregates the honest gradients plus ``f`` copies of their
        ``g``-weighted sum, takes one optimizer step and returns minus the
        summed per-batch mean output spread divided by ``n - f``.
        ``honest_avg`` is accepted for signature parity and is not used here.
        """
        standin_model = extract_initial_parameters(self.model, self.device)
        lr = self.optimizer.param_groups[0]['lr']
        standin_optm = type(self.optimizer)(standin_model.parameters(), lr=lr)
        standin_optm.load_state_dict(deepcopy(self.optimizer.state_dict()))

        # Snapshots restored before every evaluation so that every candidate
        # is scored from the same model/optimizer state.
        init_theta = deepcopy(standin_model.state_dict())
        init_optm_state = deepcopy(standin_optm.state_dict())
        aggregator = self.agg

        def objective(g):
            # Fresh copies: aggregators/gradients may be mutated by aggregation.
            agg = deepcopy(aggregator)
            standin_model.load_state_dict(deepcopy(init_theta))
            standin_optm.load_state_dict(deepcopy(init_optm_state))
            grads = deepcopy(honest_gradients)

            byzantine_grad = self._weighted_sum(g, grads)
            for _ in range(self.f):
                grads.append(deepcopy(byzantine_grad))
            aggregated = agg(grads)

            # Scatter the flat aggregated vector back into the parameter grads.
            beg = 0
            for group in standin_optm.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    end = beg + p.grad.numel()
                    p.grad.data = aggregated[beg:end].reshape_as(p.grad.data).clone().detach()
                    beg = end
            if self.clipping:
                torch.nn.utils.clip_grad_norm_(standin_model.parameters(), max_norm=self.clipping)
            standin_optm.step()

            reward = 0
            with torch.no_grad():
                for X_batch, _ in datas:
                    output = standin_model(X_batch)
                    spread = output.max(dim=1)[0] - output.min(dim=1)[0]
                    reward += spread.mean()

            reward = reward / (self.n - self.f)
            return -reward.cpu().numpy()

        return objective

    def _run_solver(self, objective, x0, method):
        """Run one scipy solver from *x0*; return ``(best_f, best_x, success)``.

        The best point is tracked over every objective evaluation, so the
        callback never re-runs the (expensive) objective.  Only values strictly
        below 0 count as improvements, so ``best_x`` stays at zeros (with the
        same shape as *x0*) if nothing negative is ever found.  A run stalled
        for ``max_iter_without_improvement`` iterations is stopped and counted
        as a success; any other exception propagates.
        """
        best_f = 0.0
        best_x = np.zeros_like(x0, dtype=float)
        best_at_last_iter = best_f
        stale_iters = 0

        def tracked_objective(g):
            nonlocal best_f, best_x
            value = objective(g)
            if value < best_f:
                best_f = float(value)
                # Copy: the solver may reuse the array it passes in.
                best_x = np.array(g, dtype=float)
            return value

        def callback(xk):
            nonlocal best_at_last_iter, stale_iters
            if best_f < best_at_last_iter:
                best_at_last_iter = best_f
                stale_iters = 0
            else:
                stale_iters += 1
            if stale_iters >= self.max_iter_without_improvement:
                raise self._NoImprovement(
                    f'No improvement for {self.max_iter_without_improvement} iterations, stopping optimization.')

        try:
            res = minimize(tracked_objective, x0, method=method, callback=callback)
        except self._NoImprovement:
            # Deliberate early stop: the best point found so far is kept.
            return best_f, best_x, True
        return best_f, best_x, bool(res.success)

    def _solve(self, objective, x0):
        """Return the best point found across ``solver_methods``."""
        results = []
        best_f, best_x = 0.0, np.zeros_like(x0, dtype=float)
        for method in self.solver_methods:
            best_f, best_x, success = self._run_solver(objective, x0, method)
            if success:
                results.append((best_f, best_x))

        if not results:
            if self.retry_on_failure:
                # No solver reported success: restart Powell from the last
                # run's best point and keep whichever result is better.
                retry_f, retry_x, _ = self._run_solver(objective, best_x, 'Powell')
                if retry_f < best_f:
                    best_f, best_x = retry_f, retry_x
            results.append((best_f, best_x))

        # min() returns the first of equal values, matching list.index(min(...)).
        return min(results, key=lambda r: r[0])[1]

    def set_gradient(self, gradient) -> None:
        raise NotImplementedError

    def apply_gradient(self) -> None:
        raise NotImplementedError



def compare_state_dicts(dict1, dict2):
    """Return True iff both state dicts have the same keys and equal tensors.

    Key order is ignored.  Tensors are compared with ``torch.equal``, so both
    shape and every element must match.  A key present in only one of the
    dicts makes the result False (previously a key missing from *dict2*
    raised KeyError and an extra key in *dict2* was ignored).
    """
    if set(dict1) != set(dict2):
        return False
    return all(torch.equal(dict1[key], dict2[key]) for key in dict1)

