from __future__ import annotations
from argparse import Namespace
from typing import Type
import torch
import torch.cuda as cuda
import torch.nn as nn

import utils.tensor_utils as tensor_utils
import utils.utils as utils

from fl.client import local_update_batch
from fl.aggregate import aggregate

class Attack:
    """Base class for Byzantine attacks that craft malicious client updates.

    Subclasses override :meth:`prepare_attack` (called once at the end of
    ``__init__``) to precompute whatever they need, and :meth:`get_original`
    to return the malicious tensor for one (attacker, parameter) pair.
    :meth:`get` then assembles one update dict per attacker.
    """

    def __init__(
        self,
        global_model: nn.Module, communication_round: int,
        attacker_datas: list[torch.Tensor], attacker_labels: list[torch.Tensor],
        user_updates: list[dict[str, torch.Tensor]], n_user_samples: list[int],
        args: Namespace,
    ) -> None:
        """Store the attack context and run :meth:`prepare_attack`.

        Args:
            global_model: current global model; the keys of its
                ``state_dict()`` are the parameter names emitted by :meth:`get`.
            communication_round: index of the current round, forwarded by
                subclasses to local training.
            attacker_datas: one data collection per attacker; ``len()`` of each
                is reported as that attacker's sample count.
            attacker_labels: labels matching ``attacker_datas``.
            user_updates: benign users' updates (may be empty).
            n_user_samples: sample counts matching ``user_updates``.
            args: experiment configuration read by subclasses.
        """
        self.global_model = global_model
        self.communication_round = communication_round
        self.attacker_datas = attacker_datas
        self.attacker_labels = attacker_labels
        self.user_updates = user_updates
        self.n_user_samples = n_user_samples
        self.args = args
        self.prepare_attack()

    def prepare_attack(self) -> None:
        """Hook for subclasses to precompute state; the base class does nothing."""

    def get_original(self, at_idx: int, pname: str) -> torch.Tensor:
        """Return the malicious tensor for attacker ``at_idx`` and parameter ``pname``.

        Raises:
            NotImplementedError: always in the base class; subclasses must override.
        """
        # Previously `pass`, which silently produced updates full of None.
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_original()"
        )

    def get(self) -> tuple[list[dict[str, torch.Tensor]], list[int]]:
        """Build the malicious updates for every attacker.

        Returns:
            ``(updates, n_samples)`` where ``updates[i]`` maps every key of
            ``global_model.state_dict()`` to ``get_original(i, key)`` and
            ``n_samples[i]`` is ``len(attacker_datas[i])``.
        """
        # The parameter names are the same for every attacker; compute them once.
        param_names = list(self.global_model.state_dict())
        at_grads = [
            {pname: self.get_original(at_idx=at_idx, pname=pname) for pname in param_names}
            for at_idx in range(len(self.attacker_datas))
        ]
        n_attacker_samples = [len(attacker_data) for attacker_data in self.attacker_datas]
        return at_grads, n_attacker_samples

class BitFlip(Attack):
    """Submit the negation of each attacker's honestly trained update.

    ``prepare_attack`` runs ``local_update_batch`` on the attackers' true data
    and labels; ``get_original`` hands back the sign-flipped tensors.
    """

    def prepare_attack(self):
        # local_update_batch yields two values; only the per-attacker updates matter here.
        honest_updates, _ = local_update_batch(
            self.global_model,
            self.communication_round,
            self.attacker_datas,
            self.attacker_labels,
            self.args,
        )
        self.benign_attacker_grads = honest_updates

    @torch.no_grad()
    def get_original(self, at_idx: int, pname: str) -> torch.Tensor:
        honest = self.benign_attacker_grads[at_idx][pname]
        return torch.neg(honest)

class LabelFlip(Attack):
    """Train on label-flipped attacker data and submit the resulting updates.

    Each label ``y`` is replaced by ``args.n_classes - y - 1`` before local
    training via ``local_update_batch``.
    """

    def prepare_attack(self):
        mirrored_labels = [
            self.args.n_classes - labels - 1
            for labels in self.attacker_labels
        ]
        # Second return value of local_update_batch is not needed.
        flipped_updates, _ = local_update_batch(
            self.global_model,
            self.communication_round,
            self.attacker_datas,
            mirrored_labels,
            self.args,
        )
        self.label_flip_grad = flipped_updates

    @torch.no_grad()
    def get_original(self, at_idx: int, pname: str) -> torch.Tensor:
        per_attacker = self.label_flip_grad[at_idx]
        return per_attacker[pname]
    
class Lie(Attack):
    """Coordinate-wise "mean minus z standard deviations" attack.

    Presumably the "A Little Is Enough" attack (Baruch et al., 2019) - TODO confirm.
    Every attacker submits the same update ``mean - args.lie_z * std``. The
    mean and population std are taken per coordinate over reference updates:
    the benign ``user_updates`` when any are given, otherwise the attackers'
    own honestly trained updates.
    """

    def prepare_attack(self):
        if not self.user_updates:
            # local_update_batch returns two values (see BitFlip/LabelFlip);
            # binding the whole pair here would corrupt the reference set.
            # Called outside no_grad, like the other attacks, so local
            # training keeps access to autograd.
            ref_updates, _ = local_update_batch(
                self.global_model, self.communication_round,
                self.attacker_datas, self.attacker_labels, self.args,
            )
        else:
            ref_updates = self.user_updates
        with torch.no_grad():
            flat_ref_updates, restore_info = tensor_utils.flatten_named_tensors(ref_updates)
            flat_avg = flat_ref_updates.mean(dim=0)
            flat_std = flat_ref_updates.std(dim=0, unbiased=False)
            flat_at_update = flat_avg - self.args.lie_z * flat_std
            self.at_update = tensor_utils.restore_tensor(flat_at_update, restore_info)

    @torch.no_grad()
    def get_original(self, at_idx: int, pname: str) -> torch.Tensor:
        """Return the shared malicious tensor for ``pname`` (same for all attackers)."""
        return self.at_update[pname]

# refer to https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning
class MinMax(Attack):
    """Min-Max model-poisoning attack.

    Adapted from https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning.
    All attackers send ``avg - lamda * deviation``. ``avg`` is the mean of the
    reference updates: the benign ``user_updates`` when any are given,
    otherwise the attackers' honestly trained updates. ``deviation`` is chosen
    by ``args.dev_type``.

    ``lamda`` is the last value accepted by a bisection search that starts at
    10.0. A value is accepted when the malicious update's maximum squared L2
    distance to any reference update is at most the maximum pairwise squared
    distance among the reference updates. If no value is accepted,
    ``lamda`` is 0 and the malicious update equals ``avg``.
    """

    def prepare_attack(self):
        if not self.user_updates:
            # local_update_batch returns two values; keep only the update list.
            ref_grads, _ = local_update_batch(
                self.global_model, self.communication_round,
                self.attacker_datas, self.attacker_labels, self.args,
            )
        else:
            ref_grads = self.user_updates
        with torch.no_grad():
            flat_ref_grads, restore_info = tensor_utils.flatten_named_tensors(ref_grads)

            flat_grad_avg = flat_ref_grads.mean(dim=0)
            deviation = self._compute_deviation(self.args.dev_type, flat_ref_grads, flat_grad_avg)

            lamda = 10.0
            threshold_diff = 1e-5
            lamda_fail = lamda
            lamda_succ = 0

            # Row i holds squared L2 distances from reference update i to all
            # others; stacking once avoids re-copying with torch.cat per row.
            distances = torch.stack(
                [torch.norm(flat_ref_grads - grad, dim=1) ** 2 for grad in flat_ref_grads]
            )
            max_distance = torch.max(distances)
            del distances

            while abs(lamda_succ - lamda) > threshold_diff:
                flat_at_grad = flat_grad_avg - lamda * deviation
                distance = torch.norm(flat_ref_grads - flat_at_grad, dim=1) ** 2
                max_d = torch.max(distance)

                if max_d <= max_distance:
                    lamda_succ = lamda
                    lamda = lamda + lamda_fail / 2
                else:
                    lamda = lamda - lamda_fail / 2

                lamda_fail = lamda_fail / 2

            flat_at_grad = flat_grad_avg - lamda_succ * deviation
            self.at_grad = tensor_utils.restore_tensor(flat_at_grad, restore_info)

    @staticmethod
    def _compute_deviation(dev_type, flat_ref_grads, flat_grad_avg):
        """Return the perturbation direction selected by ``dev_type``.

        The result stays on the inputs' device; the previous hard-coded
        ``.cuda()`` broke CPU-only runs.

        Raises:
            ValueError: if ``dev_type`` is not 'unit_vec', 'sign' or 'std'.
        """
        if dev_type == 'unit_vec':
            # Unit vector along the benign average; subtracting it moves against it.
            return flat_grad_avg / torch.norm(flat_grad_avg)
        if dev_type == 'sign':
            return torch.sign(flat_grad_avg)
        if dev_type == 'std':
            return torch.std(flat_ref_grads, 0, unbiased=False)
        raise ValueError(
            f"unsupported dev_type {dev_type!r}; expected 'unit_vec', 'sign' or 'std'"
        )

    @torch.no_grad()
    def get_original(self, at_idx: int, pname: str) -> torch.Tensor:
        """Return the shared malicious tensor for ``pname`` (same for all attackers)."""
        return self.at_grad[pname]

# refer to https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning
class MinSum(Attack):
    """Min-Sum model-poisoning attack.

    Adapted from https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning.
    All attackers send ``avg - lamda * deviation``. ``avg`` is the mean of the
    reference updates: the benign ``user_updates`` when any are given,
    otherwise the attackers' honestly trained updates. ``deviation`` is chosen
    by ``args.dev_type``.

    ``lamda`` is the last value accepted by a bisection search that starts at
    10.0. A value is accepted when the sum of squared L2 distances from the
    malicious update to every reference update is at most the smallest such
    sum achieved by any single reference update. If no value is accepted,
    ``lamda`` is 0 and the malicious update equals ``avg``.
    """

    def prepare_attack(self):
        if not self.user_updates:
            # local_update_batch returns two values; keep only the update list.
            ref_grads, _ = local_update_batch(
                self.global_model, self.communication_round,
                self.attacker_datas, self.attacker_labels, self.args,
            )
        else:
            ref_grads = self.user_updates

        with torch.no_grad():
            flat_ref_grads, restore_info = tensor_utils.flatten_named_tensors(ref_grads)

            flat_grad_avg = flat_ref_grads.mean(dim=0)
            deviation = self._compute_deviation(self.args.dev_type, flat_ref_grads, flat_grad_avg)

            lamda = 10.0
            threshold_diff = 1e-5
            lamda_fail = lamda
            lamda_succ = 0

            # Row i holds squared L2 distances from reference update i to all
            # others; stacking once avoids re-copying with torch.cat per row.
            distances = torch.stack(
                [torch.norm(flat_ref_grads - grad, dim=1) ** 2 for grad in flat_ref_grads]
            )
            scores = torch.sum(distances, dim=1)
            min_score = torch.min(scores)
            del distances

            while abs(lamda_succ - lamda) > threshold_diff:
                flat_at_grad = flat_grad_avg - lamda * deviation
                distance = torch.norm(flat_ref_grads - flat_at_grad, dim=1) ** 2
                score = torch.sum(distance)

                if score <= min_score:
                    lamda_succ = lamda
                    lamda = lamda + lamda_fail / 2
                else:
                    lamda = lamda - lamda_fail / 2

                lamda_fail = lamda_fail / 2

            flat_at_grad = flat_grad_avg - lamda_succ * deviation
            self.at_grad = tensor_utils.restore_tensor(flat_at_grad, restore_info)

    @staticmethod
    def _compute_deviation(dev_type, flat_ref_grads, flat_grad_avg):
        """Return the perturbation direction selected by ``dev_type``.

        The result stays on the inputs' device; the previous hard-coded
        ``.cuda()`` broke CPU-only runs.

        Raises:
            ValueError: if ``dev_type`` is not 'unit_vec', 'sign' or 'std'.
        """
        if dev_type == 'unit_vec':
            # Unit vector along the benign average; subtracting it moves against it.
            return flat_grad_avg / torch.norm(flat_grad_avg)
        if dev_type == 'sign':
            return torch.sign(flat_grad_avg)
        if dev_type == 'std':
            return torch.std(flat_ref_grads, 0, unbiased=False)
        raise ValueError(
            f"unsupported dev_type {dev_type!r}; expected 'unit_vec', 'sign' or 'std'"
        )

    @torch.no_grad()
    def get_original(self, at_idx: int, pname: str) -> torch.Tensor:
        """Return the shared malicious tensor for ``pname`` (same for all attackers)."""
        return self.at_grad[pname]

class Ipm(Attack):
    """Scaled negative-mean attack with epsilon tuned against the aggregator.

    Presumably Inner Product Manipulation (Xie et al., 2019) - TODO confirm.
    Every attacker sends ``-epsilon * avg_update``. ``avg_update`` is the
    default ``aggregate`` of the reference updates: the benign
    ``user_updates`` when any are given, otherwise the attackers' honestly
    trained updates.

    ``epsilon`` comes from ``utils.line_maximize`` applied to
    ``eval_epsilon``. ``eval_epsilon`` measures the squared L2 distance
    between the configured robust aggregate (benign + malicious updates) and
    ``avg_update``. line_maximize is assumed to search for the maximizing
    epsilon within ``args.ipm_evals`` evaluations - TODO confirm.
    """

    def prepare_attack(self):
        if not self.user_updates:
            # local_update_batch returns two values; keep only the update list.
            ref_updates, _ = local_update_batch(
                self.global_model, self.communication_round,
                self.attacker_datas, self.attacker_labels, self.args,
            )
        else:
            ref_updates = self.user_updates

        # Invariant across epsilon evaluations; computed once, not per call.
        n_attackers = len(self.attacker_datas)
        n_attacker_samples = [len(attacker_data) for attacker_data in self.attacker_datas]

        with torch.no_grad():
            avg_update, _ = aggregate(client_updates=ref_updates)

            def eval_epsilon(epsilon):
                """Return the squared L2 shift of the robust aggregate caused by ``epsilon``."""
                # Apply the given epsilon; every attacker shares the same update dict.
                at_update = {name: -epsilon * param_update for name, param_update in avg_update.items()}
                at_updates = [at_update] * n_attackers
                # Measure effective squared distance
                agg_update, _ = aggregate(
                    client_updates=self.user_updates + at_updates,
                    n_samples=self.n_user_samples + n_attacker_samples,
                    n_attackers=n_attackers,
                    agg_type=self.args.agg_type,
                    n_subvectors=self.args.n_subvectors,
                    budget=self.args.rfa_budget,
                    filtering_fraction=self.args.dnc_filter, n_sampled_coordinates=self.args.dnc_n_sample,
                )
                return sum(
                    (agg_update[name] - avg_update[name]).square().sum().item()
                    for name in agg_update
                )

            epsilon = utils.line_maximize(scape=eval_epsilon, evals=self.args.ipm_evals)
            self.at_update = {name: -epsilon * param_update for name, param_update in avg_update.items()}

    @torch.no_grad()
    def get_original(self, at_idx: int, pname: str) -> torch.Tensor:
        """Return the shared malicious tensor for ``pname`` (same for all attackers)."""
        return self.at_update[pname]

# Registry consulted by `attack()`: maps an `args.at_type` string to its class.
ATTACKS: dict[str, Type[Attack]] = dict(
    bit_flip=BitFlip,
    ipm=Ipm,
    label_flip=LabelFlip,
    lie=Lie,
    min_max=MinMax,
    min_sum=MinSum,
)

def attack(
    attacker_datas: list[torch.Tensor], attacker_labels: list[torch.Tensor],
    global_model: nn.Module, communication_round: int,
    user_grads: list[dict[str, torch.Tensor]], n_user_samples: list[int],
    args: Namespace,
) -> tuple[list[dict[str, torch.Tensor]], list[int]]:
    """Build this round's malicious updates using the attack named by ``args.at_type``.

    Args:
        attacker_datas: one data collection per attacker.
        attacker_labels: labels matching ``attacker_datas``.
        global_model: current global model.
        communication_round: index of the current round.
        user_grads: benign users' updates (may be empty).
        n_user_samples: sample counts matching ``user_grads``.
        args: configuration; reads ``n_attackers`` and ``at_type`` here, and
            the chosen attack reads further fields.

    Returns:
        ``(updates, n_samples)`` from the selected attack's ``get()``, or
        ``([], [])`` when ``args.n_attackers == 0``.

    Raises:
        ValueError: if ``args.at_type`` is not a key of ``ATTACKS``. This is an
            ``Exception`` subclass, so existing ``except Exception`` handlers
            still catch it.
    """
    if args.n_attackers == 0:
        return [], []

    attack_cls = ATTACKS.get(args.at_type)
    if attack_cls is None:
        raise ValueError(
            f"invalid at_type {args.at_type!r}; expected one of {sorted(ATTACKS)}"
        )

    attack_obj = attack_cls(
        global_model, communication_round,
        attacker_datas, attacker_labels,
        user_grads, n_user_samples,
        args,
    )
    return attack_obj.get()