# import sys
# sys.path.append("/workspace/Project_EvoWire/EVAttack/adversarial_training")

import os
import yaml
from ml_collections import ConfigDict
from tqdm import tqdm
from copy import deepcopy
import numpy as np
import logging

import torch
import torch_geometric

from utils.data import load_arxiv_dataset_splits
from utils.split import SplitManager, node_induced_subgraph
from utils.storage import TensorHash
from utils.model import create_robust_model_instance, load_robust_model_instance
from utils.attack import load_attack_class, attack_storage_label, create_attack_instance

from robust_diffusion.data import SparseGraph
from robust_diffusion.data import count_edges_for_idx
from robust_diffusion.train import train, train_inductive

from robust_diffusion.helper.utils import calculate_loss
from robust_diffusion.models import create_model, GPRGNN, DenseGPRGNN, ChebNetII
from robust_diffusion.models.gprgnn import GPR_prop

from sacred import Experiment

# Sacred experiment object named "VanillaAttack"; the config function and the
# entry point defined in this file are registered on it through its decorators.
experiment = Experiment("VanillaAttack")

@experiment.config
def default_config():
    # Default configuration of the experiment. Every local variable assigned
    # here is a config entry that is passed to `run` under the same name, so
    # the variable names must stay in sync with `run`'s parameters.

    # Dataset to evaluate on; `run` rejects anything other than "ogbn-arxiv".
    dataset_name = "ogbn-arxiv"

    # model_name in ["GPRGNN", "APPNP"]
    model_name = "GPRGNN"
    # Number of split files taken from the splits directory.
    n_splits = 10

    # Forwarded unchanged to the model load/create helpers.
    # NOTE(review): presumably None means "use the defaults from
    # conf/model-configs.yaml" -- confirm in utils.model.
    model_params = None

    # attack_name in ["PRBCD", "LRBCD", "EVAFAST"]
    # The same attack name is also used for validation inside `run`.
    train_attack_name = "PRBCD"
    # None -> `run` takes the per-attack defaults from the
    # "train_attack_params" / "val_attack_params" sections of the
    # "robust_arxiv_configs" entry in conf/model-configs.yaml.
    train_attack_params = None
    val_attack_params = None
    # Forwarded to create_robust_model_instance.
    # NOTE(review): presumably the validation interval in epochs -- confirm.
    validate_every = 10

    # None -> `run` uses the "training" section of "robust_arxiv_configs".
    train_params = None
    # Forwarded to create_robust_model_instance.
    loss_type = 'tanhMargin'

    # Forwarded to the dataset loader and to the model load/create helpers.
    inductive = True
    # `run` only supports self_training == False.
    self_training = False
    robust_training = True
    # NOTE(review): presumably the perturbation budget used during robust
    # training -- confirm against utils.model.
    robust_epsilon = 0.05

    # Both flags are forwarded to create_robust_model_instance only.
    make_undirected = True
    binary_attr = False


# NOTE: the helpers below must stay ABOVE the `@experiment.automain` function:
# when this file is executed as a script, `automain` runs the experiment at
# decoration time, so everything `run` needs has to be defined by then.

# Attack names for which `run` knows how to look up default parameters.
_KNOWN_ATTACKS = ("LRBCD", "EVAFAST", "PRBCD")


def _load_yaml(path):
    """Parse the YAML file at ``path`` and return its content; the file is closed afterwards."""
    with open(path, encoding="utf-8") as handle:
        return yaml.safe_load(handle)


def _default_attack_params(robust_configs, section, attack_name):
    """Return ``robust_configs[section][attack_name]``; raise ValueError for an unknown attack name."""
    if attack_name not in _KNOWN_ATTACKS:
        raise ValueError(f"{attack_name} Invalid attack name")
    return robust_configs.get(section).get(attack_name)


@experiment.automain
def run(dataset_name, model_name, n_splits, train_params, loss_type, validate_every, make_undirected, binary_attr,
        model_params, train_attack_name, train_attack_params, val_attack_params, inductive, self_training,
        robust_training, robust_epsilon):
    """Load (or, if missing, train) a robust model per split and print the clean accuracy statistics.

    For each of the first ``n_splits`` split files (in sorted order) found in the
    splits directory, the dataset split is loaded, a stored robust model is loaded
    via ``load_robust_model_instance`` and, when that raises ``FileNotFoundError``,
    a new one is created via ``create_robust_model_instance``. The
    ``"clean_accuracy"`` entry of every model instance is collected and the mean
    and standard deviation over the splits are printed.

    All parameters are injected by Sacred from the experiment config. Parameters
    left at ``None`` (``train_params``, ``train_attack_params``,
    ``val_attack_params``) are filled from the ``robust_arxiv_configs`` section of
    ``conf/model-configs.yaml``. The validation attack always uses the same attack
    name as the training attack.

    Raises:
        ValueError: if ``dataset_name`` is not ``"ogbn-arxiv"``, if
            ``self_training`` is enabled, if an attack name is unknown while its
            parameters have to be looked up, or if fewer than ``n_splits``
            split files exist.
    """
    # Explicit checks instead of `assert`, which is stripped when Python runs with -O.
    if dataset_name != "ogbn-arxiv":
        raise ValueError(f"Unsupported dataset {dataset_name!r}; only 'ogbn-arxiv' is supported.")
    if self_training:
        raise ValueError("self_training is not supported by this experiment.")

    ## Loading general configs (like dataset_root, etc.) and initial parameters
    general_config = _load_yaml("conf/general-config.yaml")
    # Parse the model config file once and pick both sections from it.
    model_configs = _load_yaml("conf/model-configs.yaml")
    default_model_configs = model_configs.get("arxiv_configs")
    default_robust_configs = model_configs.get("robust_arxiv_configs")

    dataset_root = general_config.get("dataset_root", "data/")
    splits_root = general_config.get("splits_root", "splits/")
    models_root = general_config.get("models_root", "models/")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if train_params is None:
        train_params = default_robust_configs.get("training")

    # Validation is attacked with the same attack that is used during training.
    val_attack_name = train_attack_name
    if train_attack_params is None:
        train_attack_params = _default_attack_params(
            default_robust_configs, "train_attack_params", train_attack_name)
    if val_attack_params is None:
        val_attack_params = _default_attack_params(
            default_robust_configs, "val_attack_params", val_attack_name)

    # Collect the split files of this dataset. File names are expected to look
    # like "<dataset_name with '_'>-<split_code>.pt". os.listdir returns entries
    # in arbitrary order, so sort to make the selected subset reproducible.
    split_prefix = dataset_name.replace("-", "_")
    dataset_splits = sorted(
        split_record for split_record in os.listdir(splits_root)
        if split_record.split("-")[0] == split_prefix
    )[:n_splits]
    if len(dataset_splits) < n_splits:
        raise ValueError("Not enough splits for the dataset. Create the splits by running training scripts.")
    print(f"Found {len(dataset_splits)} splits!")
    print(f"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits")

    clean_accs = []
    for split_file in tqdm(dataset_splits):
        split_code = split_file.split("-")[1].replace(".pt", "")

        # load the tensors and index sets of this split
        dataset = load_arxiv_dataset_splits(
            dataset_name, split_code, inductive=inductive,
            dataset_root=dataset_root, splits_root=splits_root, device=device)

        training_attr = dataset["training_attr"]
        training_adj = dataset["training_adj"]
        validation_attr = dataset["validation_attr"]
        validation_adj = dataset["validation_adj"]
        test_attr = dataset["test_attr"]
        test_adj = dataset["test_adj"]
        labels = dataset["labels"]
        training_idx = dataset["training_idx"]
        validation_idx = dataset["validation_idx"]
        unlabeled_mask = dataset["unlabeled_mask"]
        test_mask = dataset["test_mask"]
        split_name = dataset["split_name"]
        dataset_info = dataset["dataset_info"]

        try:
            model_instance = load_robust_model_instance(
                model_name=model_name, model_params=model_params, dataset_info=dataset_info,
                test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask,
                unlabeled_mask=unlabeled_mask,
                split_name=split_name, inductive=inductive,
                models_root=models_root, self_training=self_training,
                robust_training=robust_training, train_attack_name=train_attack_name,
                robust_epsilon=robust_epsilon,
                default_model_configs=default_model_configs, suffix='', device=device)
            print("model loaded successfully")

        except FileNotFoundError:
            # No stored model for this split: despite the wording of the
            # message, a new model instance is created right here.
            print("Model not found. Run training scripts to train the model.")
            model_instance = create_robust_model_instance(
                model_name=model_name, model_params=model_params, dataset_info=dataset_info,
                training_attr=training_attr, training_adj=training_adj,
                validation_attr=validation_attr, validation_adj=validation_adj,
                labels=labels, training_idx=training_idx, validation_idx=validation_idx,
                train_attack_name=train_attack_name,
                test_attr=test_attr, test_adj=test_adj, test_mask=test_mask,
                unlabeled_mask=unlabeled_mask, inductive=inductive,
                split_name=split_name, train_params=train_params,
                train_attack_params=train_attack_params, val_attack_params=val_attack_params,
                make_undirected=make_undirected,
                models_root=models_root, default_model_configs=default_model_configs, suffix='',
                self_training=self_training, robust_training=robust_training, robust_epsilon=robust_epsilon,
                validate_every=validate_every, loss_type=loss_type, binary_attr=binary_attr,
                device=device)

        clean_accs.append(model_instance["clean_accuracy"])

    # Build the tensor once and derive both statistics from it.
    clean_acc_tensor = torch.tensor(clean_accs)
    mean_clean_acc = torch.mean(clean_acc_tensor)
    std_clean_acc = torch.std(clean_acc_tensor)
    print(f"Mean clean accuracy on clean dateset: {mean_clean_acc:.4f} $\\pm$ {std_clean_acc:.4f}")

    print("Experiment Finished")

