from ml_collections import ConfigDict
import torch
import os
from copy import deepcopy

from utils.storage import TensorHash

from robust_diffusion.attacks.prbcd import PRBCD
from robust_diffusion.attacks.prbcd_constrained import LRBCD
from robust_diffusion.attacks.pgd import PGD
from robust_diffusion.data import count_edges_for_idx

from utils.general import accuracy
from eva_attack import EvaAttack, EvaFast, EvaLocal, EvaTarget, EvaSGCPoisoningAttack
#from eva_cert import EvaCert

def load_attack_class(attack_name):
    """Return the attack class registered under ``attack_name``.

    The lookup is case-insensitive (the name is lower-cased before matching).

    Raises:
        ValueError: if no attack is registered under that name.
    """
    # EvaCert is not registered: its import and its branch are commented out in this file.
    registry = {
        "prbcd": PRBCD,
        "lrbcd": LRBCD,
        "pgd": PGD,
        "evaattack": EvaAttack,
        "evafast": EvaFast,
        "evalocal": EvaLocal,
        "evatarget": EvaTarget,
        "evasgcpoisoningattack": EvaSGCPoisoningAttack,
    }
    wanted = attack_name.lower()
    if wanted not in registry:
        raise ValueError(f"Attack {attack_name} not found")
    return registry[wanted]
    
def attack_storage_label(attack_name, attack_params, epsilon, dataset_info, model_storage_name, split_name, suffix=""):
    """Return the storage stem (file name without extension) for an attack's artifacts.

    The stem has the form
    ``<attack_name>-<config hash>-<epsilon>-<dataset name>-<model_storage_name>-<suffix>``.
    The config hash comes from ``TensorHash.hash_model_params``. In the epsilon
    part, every '.' is replaced by '_'.

    Args:
        attack_name: Attack identifier. It is also passed to the hash as ``model_name``.
        attack_params: Attack configuration to hash. ``None`` hashes an empty ConfigDict.
        epsilon: Perturbation budget, rendered with ``str``.
        dataset_info: Object exposing a ``dataset_name`` attribute.
        model_storage_name: Storage name of the attacked model.
        split_name: Accepted but NOT part of the label.
        suffix: Last component of the stem. An empty suffix leaves a trailing '-'.
    """
    # NOTE(review): split_name is ignored, so artifacts for different splits share a
    # label. Including it would rename every label and orphan already-saved files.
    # setting_str = "ind" if inductive else "tr"
    hashed_params = ConfigDict() if attack_params is None else attack_params
    config_hash = TensorHash.hash_model_params(model_name=attack_name, model_params=hashed_params)
    eps_tag = str(epsilon).replace('.', '_')
    dataset_name = dataset_info.dataset_name
    return f"{attack_name}-{config_hash}-{eps_tag}-{dataset_name}-{model_storage_name}-{suffix}"

def load_attack_instance(attack_name, attack_params, epsilon,
                           test_attr, test_adj, labels, model,
                           dataset_info, model_storage_name, split_name, 
                           test_mask, unlabeled_mask, default_attack_configs=None,
                           reports_root=None, inductive=False,
                           device='cpu', attack_suffix=""):
    """Load previously saved attack artifacts from ``reports_root``.

    The file name is recomputed with ``attack_storage_label`` from the merged
    attack configuration: a deep copy of
    ``default_attack_configs[attack_name.upper()]``, updated with
    ``attack_params``. The name therefore matches the one ``create_attack_instance``
    saves under when both get the same arguments and the same ``attack_suffix``.

    Args:
        attack_name: Attack identifier. Its upper-cased form selects the default config.
        attack_params: Overrides applied on top of the default config, or ``None``.
        epsilon: Perturbation budget used when the artifacts were created.
        dataset_info: Object exposing ``dataset_name``, used in the label.
        model_storage_name: Storage name of the attacked model, used in the label.
        split_name: Forwarded to ``attack_storage_label``.
        test_attr, test_adj, labels, model, test_mask, unlabeled_mask, inductive, device:
            Not used here. They are accepted so the signature mirrors
            ``create_attack_instance``.
        default_attack_configs: Mapping from upper-cased attack name to its default config.
        reports_root: Directory holding the saved files. The path read is
            ``f"{reports_root}/{label}.pt"``.
        attack_suffix: Suffix the artifacts were saved with. The default ``""``
            matches ``create_attack_instance``'s default.

    Returns:
        Whatever ``torch.load`` returns for the stored file. For files written by
        ``create_attack_instance``, this is a dict with the keys "pert_adj",
        "pert_attr", "pert_acc", "n_attack_edge", "epsilon", "attack_params" and
        "attacker_storage_name".

    Raises:
        ValueError: if ``default_attack_configs`` is None.
    """
    if default_attack_configs is None:
        raise ValueError("default_attack_configs is required to load the model")

    attack_configs = deepcopy(default_attack_configs[attack_name.upper()])
    if attack_params is not None:
        attack_configs.update(attack_params)

    attacker_storage_name = attack_storage_label(
        attack_name=attack_name, attack_params=ConfigDict(attack_configs),
        epsilon=epsilon, dataset_info=dataset_info, model_storage_name=model_storage_name,
        split_name=split_name, suffix=attack_suffix)

    # NOTE(review): the stored dict contains a ConfigDict. On torch versions where
    # torch.load defaults to weights_only=True, loading may need weights_only=False.
    # Only do that for trusted files, since it unpickles arbitrary objects.
    return torch.load(f"{reports_root}/{attacker_storage_name}.pt")


def attack_graph(
        attack_name, attack_params, epsilon,
        test_attr, test_adj, labels, model,
        dataset_info, model_storage_name, split_name, 
        test_mask, unlabeled_mask, default_attack_configs=None,
        reports_root=None, inductive=False, training_idx = None,
        device='cpu', save=True, attack_suffix="", n_perturbations=None):
    """Run an attack on the graph and return the adversary object.

    Nothing is evaluated or saved here. ``dataset_info``, ``model_storage_name``,
    ``split_name``, ``reports_root``, ``save`` and ``attack_suffix`` are not used.
    They are accepted for signature parity with ``create_attack_instance``.

    The attacked nodes are ``test_mask`` when ``inductive`` is true, otherwise
    ``test_mask | unlabeled_mask``. The attack config is a deep copy of
    ``default_attack_configs[attack_name.upper()]``, updated with ``attack_params``.

    Args:
        n_perturbations: Explicit edge budget. When None, the budget is
            ``(n_feasible_edges * epsilon).int().item() // 2``, where
            ``n_feasible_edges`` comes from ``count_edges_for_idx``.

    Returns:
        The adversary instance after ``attack`` has been called on it.

    Raises:
        ValueError: if ``default_attack_configs`` is None.
    """
    if default_attack_configs is None:
        raise ValueError("default_attack_configs is required to run the attack")

    eval_mask = test_mask if inductive else (test_mask | unlabeled_mask)
    attack_idx = eval_mask.nonzero(as_tuple=True)[0]

    attack_configs = deepcopy(default_attack_configs[attack_name.upper()])
    if attack_params is not None:
        attack_configs.update(attack_params)

    adversary = load_attack_class(attack_name=attack_name)(
        attr=test_attr, adj=test_adj, labels=labels,
        model=model, idx_attack=attack_idx.cpu().numpy(),
        device=device, data_device=device, make_undirected=True, binary_attr=False,
        training_idx=training_idx,
        **attack_configs)

    if n_perturbations is not None:
        n_attack_edges = n_perturbations
    else:
        # Only count feasible edges when the budget is derived from epsilon.
        n_feasible_edges = count_edges_for_idx(test_adj.cpu(), attack_idx.cpu())
        # TODO: Check if for undirected graphs this number should be divided
        n_attack_edges = (n_feasible_edges * epsilon).int().item() // 2
    adversary.attack(n_attack_edges)
    # The return value is unused here. The call is kept in case the adversary
    # materialises/caches perturbations on first access. TODO confirm and drop if not.
    adversary.get_pertubations()
    return adversary


def create_attack_instance(attack_name, attack_params, epsilon,
                           test_attr, test_adj, labels, model,
                           dataset_info, model_storage_name, split_name, 
                           test_mask, unlabeled_mask, default_attack_configs=None,
                           reports_root=None, inductive=False, training_idx = None,
                           device='cpu', save=True, attack_suffix=""):
    """Run an attack, evaluate the model on the perturbed graph and optionally save the result.

    The attacked and evaluated nodes are ``test_mask`` when ``inductive`` is
    true, otherwise ``test_mask | unlabeled_mask``. The attack config is a deep
    copy of ``default_attack_configs[attack_name.upper()]``, updated with
    ``attack_params``. The edge budget is
    ``(n_feasible_edges * epsilon).int().item() // 2``, where ``n_feasible_edges``
    comes from ``count_edges_for_idx``.

    When ``save`` is true, the artifacts (without the adversary) are written to
    ``f"{reports_root}/{attacker_storage_name}.pt"``. Saving is best-effort:
    failures are printed and the artifacts are still returned.

    Returns:
        dict with keys "pert_adj", "pert_attr", "pert_acc", "n_attack_edge",
        "epsilon", "attack_params", "attacker_storage_name" and "attack_obj".
        "attack_obj" holds the adversary and is not written to disk.

    Raises:
        ValueError: if ``default_attack_configs`` is None.
    """
    if default_attack_configs is None:
        raise ValueError("default_attack_configs is required to load the model")

    eval_mask = test_mask if inductive else (test_mask | unlabeled_mask)
    attack_idx = eval_mask.nonzero(as_tuple=True)[0]

    attack_configs = deepcopy(default_attack_configs[attack_name.upper()])
    if attack_params is not None:
        attack_configs.update(attack_params)

    adversary = load_attack_class(attack_name=attack_name)(
        attr=test_attr, adj=test_adj, labels=labels,
        model=model, idx_attack=attack_idx.cpu().numpy(),
        device=device, data_device=device, make_undirected=True, binary_attr=False,
        training_idx=training_idx,
        **attack_configs)

    n_feasible_edges = count_edges_for_idx(test_adj.cpu(), attack_idx.cpu())
    # TODO: Check if for undirected graphs this number should be divided
    n_attack_edge = (n_feasible_edges * epsilon).int().item() // 2
    adversary.attack(n_attack_edge)
    pert_adj, pert_attr = adversary.get_pertubations()

    pert_acc = accuracy(
        model=model, attr=pert_attr, adj=pert_adj,
        labels=labels, evaluation_mask=eval_mask)

    attacker_storage_name = attack_storage_label(
        attack_name=attack_name, attack_params=ConfigDict(attack_configs),
        epsilon=epsilon, dataset_info=dataset_info, model_storage_name=model_storage_name,
        split_name=split_name, suffix=attack_suffix)
    attack_artifacts = {
        "pert_adj": pert_adj, "pert_attr": pert_attr,
        "pert_acc": pert_acc, "n_attack_edge": n_attack_edge,
        "epsilon": epsilon, "attack_params": ConfigDict(attack_configs),
        "attacker_storage_name": attacker_storage_name
    }
    if save:
        if reports_root is None:
            # Report this explicitly; os.makedirs(None) would fail with an opaque TypeError.
            print("Error in saving attack artifacts: reports_root is None")
        else:
            try:
                os.makedirs(reports_root, exist_ok=True)
                torch.save(attack_artifacts, f"{reports_root}/{attacker_storage_name}.pt")
            except Exception as e:
                # Best-effort: a failed save must not discard the computed attack.
                print("Error in saving attack artifacts", e)

    attack_artifacts["attack_obj"] = adversary
    return attack_artifacts

def adapt_model_to_attack(model, attack):
    """Placeholder for adapting ``model`` to a specific ``attack``.

    Raises:
        NotImplementedError: always; the function has not been implemented yet.
    """
    raise NotImplementedError("This function is not implemented yet")
