import os
import yaml
from ml_collections import ConfigDict

import torch
import torch_geometric

from sacred import Experiment
# Sacred experiment object; `default_config` and `run` below register themselves
# on it through the @experiment.config / @experiment.automain decorators.
experiment = Experiment("VanillaAttack")

from utils.data import load_dataset, make_dataset_splits, load_dataset_splits
from utils.split import SplitManager, node_induced_subgraph
from utils.storage import TensorHash
from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance
from utils.attack import load_attack_class


from robust_diffusion.data import SparseGraph
from robust_diffusion.data import count_edges_for_idx
from robust_diffusion.helper import utils as robust_utils
from robust_diffusion.train import train


@experiment.config
def default_config():
    # Sacred config scope: each local variable assigned here becomes a config
    # entry that Sacred injects by name into `run`. Only use `#` comments here;
    # any new local would become an extra config key.
    ## Experiment configs
    dataset_name = "cora_ml"
    # Split settings forwarded unchanged to make_dataset_splits. None presumably
    # means "use the defaults from conf/data-configs.yaml" — confirm in utils.data.
    training_split = None
    validation_split = None
    training_split_type = None
    validation_split_type = None

    model_name = "GCN"
    # Forwarded to create_model_instance. None presumably selects the defaults
    # from conf/model-configs.yaml — confirm in utils.model.
    model_params = None
    # Attack budget as a fraction: run() perturbs epsilon * n_feasible_edges
    # edges (truncated to int), where n_feasible_edges is half of the
    # count_edges_for_idx result for the test nodes.
    epsilon = 0.1

    attack_name = "PRBCD"
    # None -> run() falls back to the `attack_name` entry of conf/attack-configs.yaml.
    attack_params = None

    # Passed to the split/model helpers and recorded in the report as the
    # "inductive" (True) or "transductive" (False) setting.
    inductive = False

# NOTE: helpers must be defined *above* `@experiment.automain`: automain runs the
# experiment at decoration time when this file is executed as a script, so
# anything defined below it would not exist yet.
def _load_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and close it.

    An empty file makes ``safe_load`` return ``None``; that is mapped to ``{}``
    so the caller's ``.get(key, default)`` lookups fall back to their defaults.
    """
    with open(path, encoding="utf-8") as fh:
        return yaml.safe_load(fh) or {}


def _save_best_effort(obj, directory, filename, what):
    """``torch.save`` ``obj`` to ``directory/filename``, creating ``directory`` first.

    Saving is best-effort: a failure is printed as ``Error saving {what}: ...``
    and not re-raised. ``OSError`` is caught in addition to ``RuntimeError``
    because filesystem problems (permissions, full disk, a file where the
    directory should be) surface as ``OSError`` from ``os.makedirs``/``open``.
    """
    try:
        os.makedirs(directory, exist_ok=True)
        torch.save(obj, os.path.join(directory, filename))
    except (OSError, RuntimeError) as e:
        print(f"Error saving {what}: {e}")


@experiment.automain
def run(dataset_name, training_split, validation_split, training_split_type, validation_split_type, 
        model_name, model_params, epsilon, attack_name, attack_params, inductive):
    """Train a model, attack it on the test nodes, and save the perturbation and a report.

    Steps:
      1. read ``conf/general-config.yaml`` plus the data/model/attack config files;
      2. create (and save) dataset splits via ``make_dataset_splits``;
      3. train the model via ``create_model_instance`` and print its clean accuracy;
      4. run ``attack_name`` on the test graph with a budget of
         ``int(epsilon * n_feasible_edges)`` edges and print the perturbed accuracy;
      5. save the perturbed adjacency under ``results_root`` and a report dict
         under ``reports_root`` (both best-effort).

    All parameters are Sacred config entries (see ``default_config``). Sacred
    injects them by name, so the signature must keep matching the config keys.
    """
    ## Loading general configs (like dataset_root, etc.)
    general_config = _load_yaml("conf/general-config.yaml")
    default_dataset_configs = _load_yaml("conf/data-configs.yaml").get("configs").get("default")
    default_model_configs = _load_yaml("conf/model-configs.yaml").get("configs")
    default_attack_configs = _load_yaml("conf/attack-configs.yaml").get("configs")

    # extracting configs
    dataset_root = general_config.get("dataset_root", "data/")
    splits_root = general_config.get("splits_root", "splits/")
    models_root = general_config.get("models_root", "models/")
    results_root = general_config.get("results_root", "results/")
    reports_root = general_config.get("reports_root", "reports/")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("Experiment Started")
    print(dataset_name)

    # Loading the dataset, creating splits, and saving them (for both transductive and inductive)
    data = make_dataset_splits(dataset_name=dataset_name, 
                               training_split=training_split, validation_split=validation_split, 
                               training_split_type=training_split_type, validation_split_type=validation_split_type, 
                               inductive=inductive, 
                               default_dataset_configs=default_dataset_configs, dataset_root=dataset_root, splits_root=splits_root, device=device)
    # Alternative: reuse previously saved splits instead of creating new ones.
    # data = load_dataset_splits(dataset_name, 
    #                         "0x30d202b2fcf2b06",
    #                             inductive=inductive, dataset_root=dataset_root, splits_root=splits_root, device=device)
    training_attr = data["training_attr"]
    training_adj = data["training_adj"]
    labels = data["labels"]
    training_idx = data["training_idx"]
    validation_idx = data["validation_idx"]
    test_attr = data["test_attr"]
    test_adj = data["test_adj"]
    test_mask = data["test_mask"]
    dataset_info = data["dataset_info"]
    split_name = data["split_name"]

    # Loading and training the model
    model_instance = create_model_instance(
        model_name=model_name, model_params=model_params, dataset_info=dataset_info, 
        training_attr=training_attr, training_adj=training_adj, labels=labels, training_idx=training_idx, validation_idx=validation_idx,
        test_attr=test_attr, test_adj=test_adj, test_mask=test_mask, inductive=inductive, split_name=split_name,
        models_root=models_root, 
        default_model_configs=default_model_configs, 
        device=device)

    # Alternative: reuse a previously trained model instead of training a new one.
    # model_instance = load_model_instance(
    #     model_storage_name='GCN-0xbd14035a4016352-tr-cora_ml-0x30d202b2fcf2b06', 
    #     model_name=model_name, model_params=model_params, 
    #     test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, dataset_info=dataset_info, inductive=inductive,
    #     models_root=models_root,
    #     default_model_configs=default_model_configs, device=device)

    model = model_instance["model"]
    acc = model_instance["accuracy"]
    print("Accuracy (Clean): ", acc)
    model_params = model_instance["model_params"]
    model_storage_name = model_instance["model_storage_name"]

    # Nodes under attack: every node selected by test_mask (computed once, reused below).
    idx_attack = test_mask.nonzero(as_tuple=True)[0].cpu().numpy()
    # The count is halved, presumably because count_edges_for_idx counts each
    # undirected edge from both endpoints — confirm in robust_diffusion.data.
    n_feasible_edges = count_edges_for_idx(test_adj.cpu(), idx_attack) / 2
    # int() truncates exactly like `.int().item()` on a 0-d tensor, and also
    # works if count_edges_for_idx returns a plain Python/NumPy number.
    n_attack_edges = int(n_feasible_edges * epsilon)

    if attack_params is None:
        attack_params = default_attack_configs.get(attack_name)
    if not isinstance(attack_params, ConfigDict):
        # Sacred injects dict-valued config as a (read-only) plain mapping, which
        # supports neither attribute assignment nor `.to_dict()`; copy it into a
        # mutable ConfigDict (this also leaves the injected value untouched).
        attack_params = ConfigDict(attack_params)
    # NOTE(review): `device` is added before hashing below, so if
    # TensorHash.hash_model_params hashes every key, the attack storage name
    # differs between CPU and CUDA runs — confirm that is intended.
    attack_params.device = device
    adversary = load_attack_class(attack_name)(
        attr=test_attr, adj=test_adj, labels=labels, model=model, 
        idx_attack=idx_attack,
        data_device=device, make_undirected=True, binary_attr=False,
        **attack_params.to_dict())
    adversary.attack(n_attack_edges)
    # (sic) method name as defined by the attack classes.
    pert_adj, pert_attr = adversary.get_pertubations()
    adv_acc = accuracy(model, pert_attr, pert_adj, labels, test_mask)

    attack_config_hash = TensorHash.hash_model_params(model_name=attack_name, model_params=attack_params)
    epsilon_str = str(epsilon).replace(".", "_")
    attack_storage_name = f"{attack_name}-eps{epsilon_str}-{attack_config_hash}-{model_storage_name}"

    _save_best_effort(pert_adj, results_root, f"{attack_storage_name}-adj.pt", "perturbed adj")

    print("Accuracy (Perturbed):", adv_acc)
    report = {
        "model": model_name,
        "dataset": dataset_name,
        "dataset_info": dataset_info.to_dict(),
        "model_params": model_params,
        "clean_accuracy": acc,
        "setting": "inductive" if inductive else "transductive",
        "attack": attack_name,
        "attack_params": attack_params,
        "attack_accuracy": adv_acc,
        "epsilon": epsilon,
        "n_attack_edges": n_attack_edges,
        "model_storage_name": model_storage_name,
        "attack_storage_name": attack_storage_name
    }

    _save_best_effort(report, reports_root, f"{attack_storage_name}-report.pt", "report")

    print("Experiment Finished")


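# No explicit entry point is needed: @experiment.automain on `run` already calls
# experiment.run_commandline() when this file is executed as a script.
# Equivalent manual form, if `run` were decorated with @experiment.main instead:
# if __name__ == '__main__':
#     experiment.run_commandline()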