import pandas as pd
# import sys
# sys.path.append("/workspace/Project_EvoWire/EVAttack/adversarial_training")

import os
import yaml
from ml_collections import ConfigDict
from tqdm import tqdm
from copy import deepcopy

import torch
import torch_geometric

from utils.data import load_dataset, make_dataset_splits, load_dataset_splits, check_dataset_valid
from utils.split import SplitManager, node_induced_subgraph
from utils.storage import TensorHash
from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance, load_robust_model_instance, from_sparse_GCN, from_sparse_GPRGNN
from utils.attack import load_attack_class, attack_storage_label, create_attack_instance, load_attack_instance

from robust_diffusion.data import SparseGraph
from robust_diffusion.data import count_edges_for_idx
from robust_diffusion.helper import utils as robust_utils
from robust_diffusion.train import train

from sacred import Experiment



# Sacred experiment object; the config function and the main entry point
# defined in this file are registered on it through its decorators.
experiment = Experiment("TargettedAttack")

def filtered_induced_subgraph(attr, adj, labels, target_mask, idxs=None, depth=2):
    """Cut the graph down to the neighbourhood reachable from the target nodes.

    Args:
        attr: Node attributes, indexed by node along the first dimension.
        adj: Adjacency structure. It must support ``clone``, ``@``, ``+``,
            row/column indexing and ``to_torch_sparse_coo_tensor()``
            (presumably a ``torch_sparse.SparseTensor`` -- TODO confirm).
        labels: Node labels, indexed by node.
        target_mask: Mask over all nodes whose non-zero entries are the targets.
        idxs: Optional iterable of node-index collections (in original node
            numbering) to translate into the numbering of the subgraph.
        depth: Number of expansion passes applied to the adjacency.

    Returns:
        ``(attr, adj, labels)`` restricted to the kept nodes, plus -- only when
        ``idxs`` is given -- a list with one index tensor per entry of ``idxs``
        holding the positions of its members inside the kept node set.

    NOTE(review): the accumulator starts at ``adj`` and every pass adds
    ``adj @ accumulator``, so after ``depth`` passes it contains walks of up to
    ``depth + 1`` steps. Confirm whether ``depth`` hops was the intent.
    A target is only part of the result if it is itself reached by such a walk
    (e.g. a target whose adjacency row is empty is dropped).
    """
    seeds = target_mask.nonzero(as_tuple=True)[0]

    # Each pass extends the reachable walks by one more step.
    reach = adj.clone()
    for _ in range(depth):
        reach = reach + (adj @ reach)

    # Columns with a stored entry in any seed row are the nodes to keep.
    seed_rows = reach[seeds].to_torch_sparse_coo_tensor().coalesce()
    keep = seed_rows.indices()[1].unique()

    sub_attr = attr[keep]
    sub_adj = adj[keep, keep]
    sub_labels = labels[keep]

    if idxs is None:
        return sub_attr, sub_adj, sub_labels

    def _relabel(node_idx):
        # Mark the members on the full graph, then read off where they land
        # after the same node selection that produced the subgraph.
        member = torch.zeros_like(target_mask)
        member[node_idx] = 1
        return member[keep].nonzero(as_tuple=True)[0]

    return sub_attr, sub_adj, sub_labels, [_relabel(node_idx) for node_idx in idxs]

# Default Sacred configuration for this experiment.
# Sacred turns the local variables assigned in this function into config
# entries and hands them to `run` by matching parameter names, so the names
# below have to stay in sync with the signature of `run`.
@experiment.config
def default_config():
    ## Experiment configs
    dataset_name = "cora_ml"

    # model_name in ["GCN", "DenseGCN", "GAT", "GPRGNN", "DenseGPRGNN", "APPNP", "ChebNetII", "SoftMedian_GDC"]
    model_name = "GCN"
    # number of matching split files that `run` requires to exist
    n_splits = 10

    # split filters forwarded to check_dataset_valid when `run` selects split files
    training_split = None
    validation_split = None
    training_split_type = None
    validation_split_type = None
    test_split = None
    test_split_type = None

    model_params = None
    # NOTE(review): `run` rebinds epsilon in its own loop over a fixed list, so this value appears to be unused - confirm
    epsilon = 0.1

    # attack_name in ["PRBCD", "LRBCD", "EvaAttack", "Evafast", "PGD"]
    # NOTE(review): the default "EvaLocal" is not in the list above; `run` only compares attack_name with "PGD" and always builds an 'EVATARGET' attack
    attack_name = "EvaLocal"
    # when not None, `run` loads a robustly trained model instead of the plain one
    train_attack_name = None
    # NOTE(review): `run` overwrites attack_params for every target node, so this value appears to be unused - confirm
    attack_params = None

    inductive = False


def _load_yaml(path):
    """Parse the YAML file at ``path`` and return the loaded object.

    The file is read inside a ``with`` block so its handle is closed right
    after parsing instead of being left open.
    """
    with open(path) as stream:
        return yaml.safe_load(stream)


def _index_mask(n_nodes, node_idx, device):
    """Return a boolean mask of length ``n_nodes`` that is True at ``node_idx``."""
    mask = torch.zeros(n_nodes, dtype=torch.bool, device=device)
    mask[node_idx] = True
    return mask


@experiment.automain
def run(
    dataset_name,
    model_name,
    n_splits,
    training_split,
    validation_split,
    training_split_type,
    validation_split_type,
    test_split,
    test_split_type,
    model_params,
    epsilon,
    attack_name,
    train_attack_name,
    attack_params,
    inductive,
):
    """Run the per-node 'EVATARGET' attack on a pretrained model and log stats.

    Steps, in order:
      1. Read the YAML files under ``conf/`` and pick the storage roots.
      2. Collect the split files of ``dataset_name`` accepted by
         ``check_dataset_valid`` and load the first one.
      3. Load the pretrained model (the robust variant when
         ``train_attack_name`` is not None).
      4. For each of the first (at most 100) test nodes, build the subgraph
         around the node and create an 'EVATARGET' attack for increasing
         budgets until ``success_steps`` differs from -1.
      5. Write the collected rows to ``../../eva_target_stats.csv`` (with
         intermediate dumps to ``../../eva_target_stats-<v_idx>.csv``).

    Args:
        dataset_name: Dataset whose split files and model are used.
        model_name: Name of the pretrained model architecture to load.
        n_splits: Number of matching split files that must exist.
        training_split, validation_split, training_split_type,
        validation_split_type, test_split, test_split_type: Filters forwarded
            to ``check_dataset_valid``.
        model_params: Model parameters forwarded to the model loader.
        epsilon: Not used; the budgets come from a fixed list inside the
            function (kept for interface compatibility).
        attack_name: Only compared with "PGD" to decide whether the model is
            converted with ``from_sparse_GCN`` / ``from_sparse_GPRGNN``.
        train_attack_name: When not None, a robustly trained model is loaded.
        attack_params: Not used; rebuilt for every target node (kept for
            interface compatibility).
        inductive: Forwarded to the dataset and model loaders. The attack
            itself is always created with ``inductive=False``.

    Raises:
        ValueError: If fewer than ``n_splits`` (or no) split files are found,
            or if the pretrained model files are missing.
    """
    # Number of test nodes to attack and the budgets tried for each of them.
    max_targets = 100
    epsilon_schedule = (0.001, 0.002, 0.005, 0.007, 0.01, 0.02, 0.05, 0.07, 0.1, 0.2)
    checkpoint_every = 10

    general_config = _load_yaml("conf/general-config.yaml")
    default_model_configs = _load_yaml("conf/model-configs.yaml").get("configs")
    default_attack_configs = _load_yaml("conf/attack-configs.yaml").get("configs")

    # extracting configs
    dataset_root = general_config.get("dataset_root", "data/")
    splits_root = general_config.get("splits_root", "splits/")
    models_root = general_config.get("models_root", "models/")
    reports_root = general_config.get("reports_root", "reports/")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("Experiment Started")
    print("Loading dataset =", dataset_name)

    # NOTE(review): os.listdir returns entries in arbitrary order, so which
    # split ends up first (and is therefore used below) can differ per machine.
    dataset_splits = [
        split_record for split_record in os.listdir(splits_root)
        if split_record.split("-")[0] == dataset_name
        and check_dataset_valid(split_record=split_record, training_split=training_split,
                                validation_split=validation_split, training_split_type=training_split_type,
                                validation_split_type=validation_split_type, test_split=test_split,
                                test_split_type=test_split_type, splits_root=splits_root)]

    # At least one split is needed below even if n_splits is 0 or negative.
    if len(dataset_splits) < max(n_splits, 1):
        raise ValueError("Not enough splits for the dataset. Create the splits by running training scripts.")

    print(f"Found {len(dataset_splits)} splits!")

    print(f"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits")

    # Only the first matching split is attacked.
    split_file = dataset_splits[0]
    split_code = split_file.split("-")[1].replace(".pt", "")

    data = load_dataset_splits(
        dataset_name, split_code, inductive=inductive,
        dataset_root=dataset_root, splits_root=splits_root, device=device)

    labels = data["labels"]
    training_idx = data["training_idx"]
    validation_idx = data["validation_idx"]
    test_attr = data["test_attr"]
    test_adj = data["test_adj"]
    test_mask = data["test_mask"]
    unlabeled_mask = data["unlabeled_mask"]
    dataset_info = data["dataset_info"]
    split_name = data["split_name"]

    test_idx = test_mask.nonzero(as_tuple=True)[0]
    unlabeled_idx = unlabeled_mask.nonzero(as_tuple=True)[0]

    try:
        if train_attack_name is None:
            model_instance = load_model_instance(
                model_name=model_name, model_params=model_params,
                test_attr=test_attr, test_adj=test_adj, labels=labels,
                test_mask=test_mask, unlabeled_mask=unlabeled_mask,
                split_name=split_name, dataset_info=dataset_info,
                inductive=inductive,
                models_root=models_root,
                default_model_configs=default_model_configs, device=device)
        else:
            model_instance = load_robust_model_instance(
                model_name=model_name, model_params=model_params,
                dataset_info=dataset_info,
                test_attr=test_attr, test_adj=test_adj, labels=labels, test_mask=test_mask, unlabeled_mask=unlabeled_mask,
                split_name=split_name, inductive=inductive,
                models_root=models_root, self_training=True, robust_training=True, train_attack_name=train_attack_name, robust_epsilon=0.2,
                default_model_configs=default_model_configs, suffix='', device=device)
    except FileNotFoundError as e:
        print(e)
        # Chain the original error so the missing path stays in the traceback.
        raise ValueError("Model not found. Run training scripts to train the model.") from e

    model = model_instance["model"]
    model_params = model_instance["model_params"]
    model_storage_name = model_instance["model_storage_name"]

    if attack_name == "PGD" and model_name == "GCN":
        model = from_sparse_GCN(model, model_params)
    elif attack_name == "PGD" and model_name == "GPRGNN":
        model = from_sparse_GPRGNN(model, model_params)

    # Degrees on the full test graph; test_adj is not modified in the loop,
    # so this is computed once instead of once per attack.
    node_degrees = test_adj.sum(dim=1)

    # Never index past the end of the test set when it has fewer nodes
    # than max_targets.
    n_targets = min(max_targets, len(test_idx))

    all_stats = []
    for v_idx in range(n_targets):
        torch.cuda.empty_cache()
        print(f"v_idx = {v_idx}")
        target_node = test_idx[v_idx]
        target_mask = torch.zeros_like(test_mask).to(device)
        target_mask[target_node] = 1
        target_idx = target_mask.nonzero(as_tuple=True)[0]

        # Subgraph around the target plus all index sets renumbered to it.
        filtered_attr, filtered_adj, filtered_labels, filtered_idxs = filtered_induced_subgraph(
            test_attr, test_adj, labels, target_mask,
            idxs=[training_idx, validation_idx, test_idx, unlabeled_idx, target_idx], depth=2)

        # Positions 0 and 1 (training / validation) are not needed below.
        test_idx_filtered = filtered_idxs[2]
        unlabeled_idx_filtered = filtered_idxs[3]
        target_idx_filtered = filtered_idxs[4]

        n_filtered = filtered_attr.shape[0]
        test_mask_filtered = _index_mask(n_filtered, test_idx_filtered, device)
        unlabeled_mask_filtered = _index_mask(n_filtered, unlabeled_idx_filtered, device)

        filtered_dataset_info = deepcopy(dataset_info)
        filtered_dataset_info["n_nodes"] = filtered_attr.size(0)
        attack_params = {"attacking_nodes": target_idx_filtered.cpu().numpy().tolist()}

        # Try increasing budgets and stop at the first one for which the
        # attack reports success_steps != -1 (presumably -1 marks a failed
        # attack -- TODO confirm against the attack implementation).
        for eps in epsilon_schedule:
            print("epsilon = ", eps)
            attack = create_attack_instance(
                attack_name='EVATARGET', attack_params=attack_params, epsilon=eps,
                test_attr=filtered_attr, test_adj=filtered_adj, labels=filtered_labels, model=model,
                dataset_info=filtered_dataset_info, model_storage_name=model_storage_name,
                split_name=split_name, test_mask=test_mask_filtered, unlabeled_mask=unlabeled_mask_filtered,
                default_attack_configs=default_attack_configs, reports_root=reports_root,
                device=device, inductive=False, save=False)

            success_steps = attack["attack_obj"].success_steps
            status = {
                "epsilon": eps,
                "n_edges": attack["n_attack_edge"],
                "steps": success_steps,
                "attacking_node": target_node.item(),
                "node_degree": node_degrees[target_node].item(),
                "idx_in_test": v_idx,
                "neigh_edges": (filtered_adj.sum() // 2).item(),
            }
            print(status)
            all_stats.append(status)
            if success_steps != -1:
                break

        # Intermediate dump of everything collected so far.
        # NOTE(review): this reproduces the original cadence, which fires at
        # v_idx = 8, 18, 28, ...; confirm whether every 10th node was intended.
        if (v_idx + 1) % checkpoint_every == checkpoint_every - 1:
            pd.DataFrame(all_stats).to_csv(f"../../eva_target_stats-{v_idx}.csv")

    pd.DataFrame(all_stats).to_csv("../../eva_target_stats.csv")