# import sys
# sys.path.append("/workspace/Project_EvoWire/EVAttack/adversarial_training")

import os
import yaml
from tqdm import tqdm
import torch

from utils.data import load_arxiv_dataset_splits
from utils.model import load_model_instance, load_robust_model_instance
from utils.attack import load_attack_instance, create_attack_instance

from sacred import Experiment

# Sacred experiment object: ``default_config`` below registers its defaults via
# ``@experiment.config`` and ``run`` is its ``@experiment.automain`` entry point.
experiment = Experiment("VanillaAttack")

@experiment.config
def default_config():
    """Default configuration for the VanillaAttack experiment.

    Sacred executes this body and turns every local variable into a config
    entry that is injected into ``run`` by parameter name, so keep the body
    limited to config assignments and do not rename the variables.
    """
    # NOTE: sacred locates this body by matching the plain def line above, so do
    # not add a return annotation to it; the comment directly above each
    # assignment is what sacred shows as that entry's documentation.
    ## Experiment configs
    # only "ogbn-arxiv" is supported; run() rejects any other dataset
    dataset_name = "ogbn-arxiv"

    # model_name in ["GPRGNN", "APPNP"]
    model_name = "GPRGNN"
    # number of dataset splits to evaluate; run() fails if fewer split files exist
    n_splits = 10

    # None presumably means "use conf/model-configs.yaml defaults" -- TODO confirm in utils.model
    model_params = None
    # attack budget forwarded to the attack loaders (presumably a fraction of edges; confirm in utils.attack)
    epsilon = 0.1

    # attack_name in ["PRBCD", "LRBCD", "Evafast"]
    attack_name = "PRBCD"
    # None: evaluate the normally trained model; otherwise the model robustly trained with this attack
    train_attack_name = None
    # None presumably means "use conf/attack-configs.yaml defaults" -- TODO confirm in utils.attack
    attack_params = None

    # forwarded to the dataset, model and attack loaders (presumably selects the inductive setting)
    inductive = True

# NOTE: helpers live above ``run`` on purpose: when this file is executed as a
# script, ``@experiment.automain`` starts the experiment as soon as ``run`` is
# decorated, so anything defined after it would not exist yet.
def _load_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load``, closing the handle."""
    with open(path) as handle:
        return yaml.safe_load(handle)


def _mean_and_std(values):
    """Return ``(mean, std)`` of ``values`` as 0-dim tensors.

    ``torch.std`` uses the unbiased (n - 1) estimator and so yields NaN for a
    single value; in that case the spread is reported as 0 instead.
    """
    stacked = torch.tensor(values)
    mean = torch.mean(stacked)
    if stacked.numel() > 1:
        std = torch.std(stacked)
    else:
        std = torch.zeros_like(mean)
    return mean, std


@experiment.automain
def run(dataset_name, model_name, n_splits, model_params, epsilon, attack_name, train_attack_name, attack_params, inductive):
    """Evaluate clean and attacked accuracy of pretrained models over dataset splits.

    For each of the first ``n_splits`` split files in ``splits_root`` (sorted
    by file name), load the pretrained ``model_name`` model. When
    ``train_attack_name`` is set, load the robustly trained variant instead.
    Then try ``load_attack_instance`` and fall back to
    ``create_attack_instance`` if loading raises. The model's ``accuracy``
    and the attack's ``pert_acc`` are collected, and their mean and standard
    deviation over splits are printed at the end.

    All arguments are injected by sacred from ``default_config``.

    Raises:
        ValueError: if ``dataset_name`` is not "ogbn-arxiv", ``n_splits`` < 1,
            fewer than ``n_splits`` split files exist, or the pretrained model
            file is missing.
    """
    # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
    if dataset_name != "ogbn-arxiv":
        raise ValueError(f"This script only supports 'ogbn-arxiv', got {dataset_name!r}.")
    if n_splits < 1:
        # With no splits the final mean/std would silently be NaN.
        raise ValueError(f"n_splits must be at least 1, got {n_splits}.")

    general_config = _load_yaml("conf/general-config.yaml")
    default_model_configs = _load_yaml("conf/model-configs.yaml").get("arxiv_configs")
    default_attack_configs = _load_yaml("conf/attack-configs.yaml").get("arxiv_configs")

    dataset_root = general_config.get("dataset_root", "data/")
    splits_root = general_config.get("splits_root", "splits/")
    models_root = general_config.get("models_root", "models/")
    reports_root = general_config.get("reports_root", "reports/")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Split files are presumably named "<dataset_name with '-' -> '_'>-<split_code>.pt".
    # os.listdir order is filesystem-dependent, so sort to pick the same
    # n_splits files on every machine and every run.
    split_prefix = dataset_name.replace("-", "_")
    dataset_splits = sorted(
        split_record for split_record in os.listdir(splits_root)
        if split_record.split("-")[0] == split_prefix
    )[:n_splits]

    # Splits are never created here; they come from the training scripts.
    missing_splits = n_splits - len(dataset_splits)
    if missing_splits > 0:
        raise ValueError("Not enough splits for the dataset. Create the splits by running training scripts.")

    print(f"Found {len(dataset_splits)} splits!")

    print(f"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits")

    clean_accs = []
    pert_accs = []
    for split_file in tqdm(dataset_splits):
        split_code = split_file.split("-")[1].replace(".pt", "")

        dataset = load_arxiv_dataset_splits(
            dataset_name, split_code, inductive=inductive,
            dataset_root=dataset_root, splits_root=splits_root, device=device)

        test_attr = dataset["test_attr"]
        test_adj = dataset["test_adj"]
        labels = dataset["labels"]
        unlabeled_mask = dataset["unlabeled_mask"]
        test_mask = dataset["test_mask"]
        split_name = dataset["split_name"]
        dataset_info = dataset["dataset_info"]

        try:
            if train_attack_name is None:
                model_instance = load_model_instance(
                    model_name=model_name, model_params=model_params,
                    test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, unlabeled_mask=unlabeled_mask,
                    split_name=split_name, dataset_info=dataset_info, inductive=inductive, models_root=models_root,
                    default_model_configs=default_model_configs, device=device)
            else:
                # The robust-model lookup uses fixed training settings
                # (self/robust training on, robust_epsilon=0.2).
                model_instance = load_robust_model_instance(
                    model_name=model_name, model_params=model_params,
                    dataset_info=dataset_info,
                    test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, unlabeled_mask=unlabeled_mask,
                    split_name=split_name, inductive=inductive,
                    models_root=models_root, self_training=True, robust_training=True, train_attack_name=train_attack_name, robust_epsilon=0.2,
                    default_model_configs=default_model_configs, suffix='', device=device)
        except FileNotFoundError as e:
            print(e)
            # Chain the cause so the missing path stays in the traceback.
            raise ValueError("Model not found. Run training scripts to train the model.") from e

        model = model_instance["model"]
        acc = model_instance["accuracy"]
        model_storage_name = model_instance["model_storage_name"]
        clean_accs.append(acc)

        # load_attack_instance and create_attack_instance take identical arguments.
        attack_kwargs = dict(
            attack_name=attack_name, attack_params=attack_params, epsilon=epsilon,
            test_attr=test_attr, test_adj=test_adj, labels=labels, model=model,
            dataset_info=dataset_info, model_storage_name=model_storage_name,
            split_name=split_name, test_mask=test_mask, unlabeled_mask=unlabeled_mask,
            default_attack_configs=default_attack_configs, reports_root=reports_root,
            device=device, inductive=inductive)

        # Loading is best-effort: any failure falls back to creating the attack,
        # but the reason is printed rather than discarded.
        attack_exists = False
        try:
            attack = load_attack_instance(**attack_kwargs)
            attack_exists = True
        except Exception as e:
            print("Attack not found. Creating attack instance")
            print(f"  (load_attack_instance raised {type(e).__name__}: {e})")

        if not attack_exists:
            try:
                attack = create_attack_instance(**attack_kwargs)
            except Exception as e:
                print(f"Exception at split {split_code}, file = {split_file}, model {model_storage_name}: {e}")
                raise

        pert_accs.append(attack["pert_acc"])

    mean_clean_acc, std_clean_acc = _mean_and_std(clean_accs)
    mean_pert_acc, std_pert_acc = _mean_and_std(pert_accs)

    print(f"Mean clean accuracy: {mean_clean_acc:.4f} $\\pm$ {std_clean_acc:.4f}")
    print(f"Mean perturbed accuracy: {mean_pert_acc:.4f} $\\pm$ {std_pert_acc:.4f}")
    print("Experiment Finished")