import torch 
import os
import json
import numpy as np
from collections import defaultdict
from torch_geometric.utils import to_undirected, mask_to_index
from torch_geometric.data import Data
from torch_sparse import SparseTensor

# Root directory of the GraphAD data. Loaders below read "{PATH}/datasets/..."
# and write caches beneath it. Placeholder value — set it for your machine.
PATH = "/path/to/GraphAD_data"
# Root directory holding pre-computed attack artifacts (perturbed edge indices,
# attacked-text JSON files). Set independently of PATH. Placeholder value.
ATKG_PATH = "/path/to/GraphAD_data/atkg"


def generate_attacked_embeddings(dataset_name, atk_meta_info, changed_node_ids, attacked_texts, emb_model, device, path_prefix=PATH, setting="transductive"):
    """
    Embed attacked node texts with ``emb_model`` and cache the result on disk.

    The cache file is
    ``{path_prefix}/atkg/{dataset_name}/<attack dir>_{setting}/{emb_model}_embeddings_seed{seed}_ptb{int(ptb_rate*100)}.pt``.
    If that file already exists, its contents are returned and no model is loaded.

    Args:
        dataset_name: Name of the dataset.
        atk_meta_info: Attack metadata; the keys 'attack', 'ptb_rate' and 'seed' are read.
        changed_node_ids: IDs of the attacked nodes, stored next to the embeddings.
        attacked_texts: Attacked texts, one per entry of ``changed_node_ids``.
        emb_model: "bow" to reuse the saved bag-of-words vectorizer. Any other
            value is handed to ``TextEncoder``: as an "LLM" encoder for the names
            in ``llm_models`` below, as an "LM" encoder otherwise.
        device: Device the encoder runs on and the result is moved to.
        path_prefix: Data root directory.
        setting: "inductive" or "transductive"; each gets its own cache directory.

    Returns:
        torch.Tensor: Embeddings of the attacked texts.
        List: Node IDs the rows correspond to (read from the cache on a hit,
        otherwise ``changed_node_ids``).
    """
    attack = atk_meta_info['attack']
    ptb_rate = atk_meta_info['ptb_rate']
    seed = atk_meta_info['seed']

    # Inductive and transductive runs are cached in separate directories.
    if attack == 'textfooler':
        attack_dir = f"text_textfooler_{setting}"
    elif attack == 'gpt':
        attack_dir = f"llm_gpt-4o-mini_{setting}"
    elif attack in ('llm', 'llm_Ministral'):
        attack_dir = f"llm_Ministral-8B_{setting}"
    else:
        # Any other attack is cached under its own name.
        attack_dir = f"{attack}_{setting}"
    cache_dir = f"{path_prefix}/atkg/{dataset_name}/{attack_dir}"

    os.makedirs(cache_dir, exist_ok=True)
    cache_file = f"{cache_dir}/{emb_model}_embeddings_seed{seed}_ptb{int(ptb_rate*100)}.pt"

    if os.path.exists(cache_file):
        print(f"Loading cached embeddings from {cache_file}")
        stored = torch.load(cache_file, map_location=device)
        return stored['embeddings'], stored['node_ids']

    print(f"Generating new embeddings for {len(attacked_texts)} attacked texts using {emb_model}...")

    if emb_model == "bow":
        import pickle
        from sklearn.feature_extraction.text import CountVectorizer  # noqa: F401 — a missing scikit-learn fails here, before any file access

        # The vocabulary is not built here; it must already exist on disk.
        vocab_path = f"{path_prefix}/datasets/vocab/{dataset_name}/bow_vocabulary.pkl"
        if not os.path.exists(vocab_path):
            raise FileNotFoundError(f"BOW vocabulary not found at {vocab_path}. Please generate it first using embedding.py")
        with open(vocab_path, 'rb') as handle:
            vectorizer = pickle.load(handle)
        new_embeddings = torch.FloatTensor(vectorizer.transform(attacked_texts).toarray())
    else:
        from .lm import TextEncoder
        llm_models = ("Mistral-7B", "Qwen-7B", "Llama-8B", "Qwen3-8B", "Ministral-8B")
        encoder_type = "LLM" if emb_model in llm_models else "LM"

        torch.cuda.empty_cache()  # release cached GPU memory before the encoder is built
        text_encoder = TextEncoder(emb_model, encoder_type, device)

        with torch.no_grad():
            pieces = []
            total = len(attacked_texts)
            # One text per forward pass; each result is moved to the CPU straight
            # away and the GPU cache is emptied, keeping peak GPU memory low.
            for done, text in enumerate(attacked_texts, start=1):
                if len(text) == 0:
                    text = "Empty text"
                pieces.append(text_encoder.forward(text, pooling="mean").cpu())
                torch.cuda.empty_cache()
                if done % 10 == 0:
                    print(f"Processed {done}/{total} texts")
            new_embeddings = torch.cat(pieces, dim=0)

        # Drop the encoder (and its model, when present) before returning.
        if hasattr(text_encoder, 'model'):
            del text_encoder.model
        del text_encoder
        torch.cuda.empty_cache()

    torch.save({'embeddings': new_embeddings, 'node_ids': changed_node_ids, 'attack_info': atk_meta_info}, cache_file)
    print(f"Cached embeddings to {cache_file}")

    return new_embeddings.to(device), changed_node_ids


def re_split_data(num_node, train_percent=0.6, val_percent=0.2, test_percent=0.2, device="cuda:0"):
    """
    Randomly partition ``num_node`` nodes into train/val/test boolean masks.

    Node ids are shuffled with NumPy's global RNG (``np.random.shuffle``), so seed
    ``np.random`` beforehand for a reproducible split. Split boundaries are
    ``int(num_node * cumulative_fraction)``. If the fractions sum to less than 1,
    the trailing shuffled nodes belong to no split.

    Args:
        num_node: Number of nodes in the graph.
        train_percent: Fraction of nodes in the train split.
        val_percent: Fraction of nodes in the validation split.
        test_percent: Fraction of nodes in the test split.
        device: Device of the returned masks.

    Returns:
        (train_mask, val_mask, test_mask): ``torch.bool`` tensors of length ``num_node``.
    """
    node_ids = np.arange(num_node)
    np.random.shuffle(node_ids)

    train_end = int(num_node * train_percent)
    val_end = int(num_node * (train_percent + val_percent))
    test_end = int(num_node * (train_percent + val_percent + test_percent))

    def _to_mask(ids):
        # Scatter into a dense boolean array: O(num_node) total. The previous
        # `idx in ids` membership test per node was O(num_node * len(ids)).
        # An explicit bool dtype also keeps the mask boolean when num_node == 0.
        mask = np.zeros(num_node, dtype=bool)
        mask[ids] = True
        return torch.from_numpy(mask).to(device)

    train_mask = _to_mask(node_ids[:train_end])
    val_mask = _to_mask(node_ids[train_end:val_end])
    test_mask = _to_mask(node_ids[val_end:test_end])
    return train_mask, val_mask, test_mask
    

def load_graph_dataset(dataset_name, device, re_split=False, path_prefix=PATH):
    """
    Load a saved graph, symmetrise its edges and optionally redraw its split.

    Args:
        dataset_name: Dataset file stem. A "taglas_" prefix selects the
            ``datasets/taglas_data`` directory (the prefix is stripped).
        device: Device the graph is moved to.
        re_split: 1 -> random 10/10/80 train/val/test split; 2 -> random
            60/20/20 split; any other value keeps the stored masks.
        path_prefix: Data root directory.

    Returns:
        The loaded graph object with an undirected ``edge_index``.
    """
    taglas_prefix = 'taglas_'
    if dataset_name.startswith(taglas_prefix):
        source = f"{path_prefix}/datasets/taglas_data/{dataset_name[len(taglas_prefix):]}.pt"
    else:
        source = f"{path_prefix}/datasets/{dataset_name}.pt"
    graph = torch.load(source, weights_only=False).to(device)

    graph.edge_index = to_undirected(graph.edge_index)

    if re_split == 1:
        fractions = (0.1, 0.1, 0.8)
    elif re_split == 2:
        fractions = (0.6, 0.2, 0.2)
    else:
        fractions = None

    if fractions is not None:
        train_p, val_p, test_p = fractions
        graph.train_mask, graph.val_mask, graph.test_mask = re_split_data(
            graph.num_nodes, train_percent=train_p, val_percent=val_p, test_percent=test_p, device=device
        )

    return graph


def load_graph_dataset_for_gnn(dataset_name, device, re_split=False, path_prefix=PATH, emb_model="shallow"):
    """
    Load a graph via ``load_graph_dataset`` and set its node features ``x``.

    Args:
        dataset_name: Dataset name, forwarded to ``load_graph_dataset``.
        device: Target device.
        re_split: Forwarded to ``load_graph_dataset``.
        path_prefix: Data root directory.
        emb_model: Name of a precomputed feature file under
            ``{path_prefix}/datasets/{emb_model}/``. "shallow" keeps the stored
            features, except for the datasets listed below: they get 300-d
            Node2Vec features, cached under ``{path_prefix}/datasets/Node2Vec``.

    Returns:
        The graph object with ``x`` set as described.
    """
    graph_data = load_graph_dataset(dataset_name, device, re_split, path_prefix)

    if emb_model != "shallow":
        feat_path = f"{path_prefix}/datasets/{emb_model}/{dataset_name}.pt"
        assert os.path.exists(feat_path), f"Embedding file not found: {feat_path}"
        graph_data.x = torch.load(feat_path, map_location=device, weights_only=False).to(device).type(torch.float)

    # TODO: check datasets that need shallow embedding
    # Apply Node2Vec for datasets without shallow embeddings
    if emb_model == "shallow" and dataset_name in ["reddit", "instagram", "computer", "photo", "history"]:
        # NOTE: Node2Vec caches written before the row-order fix below have rows in
        # gensim vocabulary order. Delete them so they are regenerated.
        node2vec_cache = f"{path_prefix}/datasets/Node2Vec/{dataset_name}.pt"
        if os.path.exists(node2vec_cache):
            node_feat = torch.load(node2vec_cache, map_location=device).to(device)
        else:
            from node2vec import Node2Vec
            from torch_geometric.utils.convert import to_networkx

            nx_graph = to_networkx(graph_data)
            node2vec = Node2Vec(nx_graph, dimensions=300, walk_length=30, num_walks=10, workers=4)
            node2vec_model = node2vec.fit(window=10, min_count=1, batch_words=4)
            print(node2vec_model.wv.vectors.shape, type(node2vec_model.wv.vectors))
            # Rows of `wv.vectors` follow gensim's vocabulary order (sorted by walk
            # frequency), not node id. The node2vec package trains on walks of
            # string node keys, so look up every node explicitly: row i is node i.
            vectors = np.stack([node2vec_model.wv[str(node)] for node in range(graph_data.num_nodes)])
            node_feat = torch.FloatTensor(vectors).to(device)
            os.makedirs(f"{path_prefix}/datasets/Node2Vec", exist_ok=True)
            torch.save(node_feat, node2vec_cache)
        graph_data.x = node_feat

    return graph_data


def load_graph_dataset_for_llaga(dataset_name, device, re_split=0, encoder="roberta", seed=0):
    """
    Load a graph for LLaGA, using ``encoder`` node features.

    The inductive loader is used when ``re_split == 2``, or when ``dataset_name``
    is "arxiv" and ``re_split == 0``. Every other case uses the transductive loader.
    Both read from the module-level ``PATH``.

    Returns:
        Transductive: the graph object.
        Inductive: ``(graph_data, (train_data, val_data, test_data))``.
    """
    use_inductive = re_split == 2 or (re_split == 0 and dataset_name == "arxiv")
    if use_inductive:
        graph_data, splits = load_inductive_graph_dataset_for_gnn(dataset_name, device, re_split, path_prefix=PATH, emb_model=encoder, seed=seed)
        train_data, val_data, test_data = splits
        return graph_data, (train_data, val_data, test_data)
    return load_graph_dataset_for_gnn(dataset_name, device, re_split, path_prefix=PATH, emb_model=encoder)


def _split_masks_from_cache(train_data, val_data, num_nodes, device):
    """Rebuild (train, val, test) masks from the node ids stored in cached split files."""
    train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    train_mask[train_data.node_ids.to(device)] = True
    # val_data holds train + val nodes (see how it is built in the uncached branch).
    train_val_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    train_val_mask[val_data.node_ids.to(device)] = True
    val_mask = train_val_mask & ~train_mask
    # The 60/20/20 split drawn by re_split_data assigns every remaining node to test.
    test_mask = ~train_val_mask
    return train_mask, val_mask, test_mask


def _attach_split_texts(graph_data, train_data, val_data, test_data, train_mask, train_val_mask):
    """Attach per-split raw_texts and the shared label_name to the three split objects."""
    # Convert the masks to Python lists once. Indexing a device tensor element by
    # element may force one host/device sync per node.
    keep_train = train_mask.tolist()
    keep_train_val = train_val_mask.tolist()
    texts = graph_data.raw_texts
    train_data.raw_texts = [text for text, keep in zip(texts, keep_train) if keep]
    val_data.raw_texts = [text for text, keep in zip(texts, keep_train_val) if keep]
    # In test phase, we can access all nodes' raw_texts, but only train+val nodes' labels are known
    test_data.raw_texts = list(texts)
    for split in (train_data, val_data, test_data):
        split.label_name = graph_data.label_name


def load_inductive_graph_dataset(dataset_name, device, re_split=False, path_prefix=PATH, emb_model="shallow", seed=0):
    """
    Load a graph and build inductive train/val/test views of it.

    - train_data: train nodes only, with the edges among them.
    - val_data: train + val nodes, with the edges among them.
    - test_data: every node, with the full undirected edge set.

    Each view carries ``node_ids`` (global ids), ``x``, ``y``, ``raw_texts`` and
    ``label_name``. The tensor parts are cached in ``{path_prefix}/datasets/inductive``.

    Args:
        dataset_name: Dataset file stem. A "taglas_" prefix selects ``datasets/taglas_data``.
        device: Target device.
        re_split: Must be 0 for "arxiv" (stored split) and 2 (random 60/20/20 split)
            for every other dataset.
        path_prefix: Data root directory.
        emb_model: Not used by this function; kept for caller compatibility.
        seed: Names the cached train/val files. The random split itself is drawn from
            NumPy's global RNG (via ``re_split_data``).

    Returns:
        (graph_data, (train_data, val_data, test_data))
    """
    if dataset_name.startswith('taglas_'):
        taglas_name = dataset_name[len('taglas_'):]
        graph_data = torch.load(f"{path_prefix}/datasets/taglas_data/{taglas_name}.pt", weights_only=False).to(device)
    else:
        graph_data = torch.load(f"{path_prefix}/datasets/{dataset_name}.pt", weights_only=False).to(device)

    graph_data.edge_index = to_undirected(graph_data.edge_index)

    split_dir = f"{path_prefix}/datasets/inductive"
    train_path = f"{split_dir}/{dataset_name}_train_data_{seed}.pt"
    val_path = f"{split_dir}/{dataset_name}_val_data_{seed}.pt"
    test_path = f"{split_dir}/{dataset_name}_test_data.pt"

    if dataset_name == "arxiv":
        assert re_split == 0, "Only re_split == 0 (default split) is supported for dataset arxiv."
    else:
        assert re_split == 2, "Only re_split == 2 is supported for inductive learning."
        # Drawn even when the cache exists, so NumPy's global RNG advances the same
        # way whether or not the splits are cached.
        graph_data.train_mask, graph_data.val_mask, graph_data.test_mask = re_split_data(graph_data.num_nodes, train_percent=0.6, val_percent=0.2, test_percent=0.2, device=device)

    if os.path.exists(train_path) and os.path.exists(val_path) and os.path.exists(test_path):
        print("Load inductive splits from cache.")
        train_data = torch.load(train_path, weights_only=False)
        val_data = torch.load(val_path, weights_only=False)
        test_data = torch.load(test_path, weights_only=False)
        if dataset_name != "arxiv":
            # The masks just drawn match the cached split only if the RNG state equals
            # that of the run that wrote the cache. The cached node ids are
            # authoritative, so derive the masks from them. This keeps graph_data's
            # masks, the split tensors and the raw_texts below aligned.
            graph_data.train_mask, graph_data.val_mask, graph_data.test_mask = _split_masks_from_cache(
                train_data, val_data, graph_data.num_nodes, device
            )
        _attach_split_texts(graph_data, train_data, val_data, test_data,
                            graph_data.train_mask, graph_data.train_mask | graph_data.val_mask)
        return graph_data, (train_data.to(device), val_data.to(device), test_data.to(device))

    train_mask = graph_data.train_mask
    train_val_mask = torch.logical_or(train_mask, graph_data.val_mask)

    row, col = graph_data.edge_index
    adj = SparseTensor(row=row, col=col, sparse_sizes=(graph_data.num_nodes, graph_data.num_nodes))

    # Induced sub-adjacencies: train sees only train-train edges, val sees train+val edges.
    train_row, train_col, _ = adj[train_mask][:, train_mask].coo()
    train_edge_index = torch.stack([train_row, train_col], dim=0)

    val_row, val_col, _ = adj[train_val_mask][:, train_val_mask].coo()
    val_edge_index = torch.stack([val_row, val_col], dim=0)

    # Train data (only train nodes)
    train_data = Data(
        x=graph_data.x[train_mask],
        edge_index=train_edge_index,
        node_ids=torch.where(train_mask)[0],
        y=graph_data.y[train_mask],
    )

    # Val data (train + val nodes)
    val_data = Data(
        x=graph_data.x[train_val_mask],
        edge_index=val_edge_index,
        node_ids=torch.where(train_val_mask)[0],
        y=graph_data.y[train_val_mask],
    )

    # Test data (all nodes)
    test_data = Data(
        x=graph_data.x,
        edge_index=graph_data.edge_index,
        node_ids=torch.arange(graph_data.num_nodes),
        y=graph_data.y,
    )

    # Save only the tensor parts; texts and label names are re-attached on load.
    os.makedirs(split_dir, exist_ok=True)
    torch.save(train_data.cpu(), train_path)
    torch.save(val_data.cpu(), val_path)
    torch.save(test_data.cpu(), test_path)

    _attach_split_texts(graph_data, train_data, val_data, test_data, train_mask, train_val_mask)

    return graph_data, (train_data.to(device), val_data.to(device), test_data.to(device))


def load_inductive_graph_dataset_for_gnn(dataset_name, device, re_split=False, path_prefix=PATH, emb_model="shallow", seed=0):
    """
    Load inductive splits via ``load_inductive_graph_dataset`` and set their node features.

    Features are read from ``{path_prefix}/datasets/{emb_model}/{dataset_name}.pt``
    unless ``emb_model == "shallow"``. For "shallow", the stored features are kept,
    except for the datasets listed below: they get 300-d Node2Vec features. Train
    and val receive the feature rows of their ``node_ids``; test and the full graph
    receive the whole matrix.

    Returns:
        (graph_data, (train_data, val_data, test_data)), all moved to ``device``.
    """
    graph_data, (train_data, val_data, test_data) = load_inductive_graph_dataset(dataset_name, device, re_split, path_prefix, emb_model, seed)

    node_feat = None
    if emb_model != "shallow":
        feat_path = f"{path_prefix}/datasets/{emb_model}/{dataset_name}.pt"
        assert os.path.exists(feat_path), f"Embedding file not found: {feat_path}"
        node_feat = torch.load(feat_path, map_location=device, weights_only=False).to(device).type(torch.float)
    elif dataset_name in ["reddit", "instagram", "computer", "photo", "history"]:
        # TODO: check datasets that need shallow embedding
        # Apply Node2Vec for datasets without shallow embeddings.
        # NOTE(review): Node2Vec is fit on the full graph, including val/test nodes.
        # NOTE: Node2Vec caches written before the row-order fix below have rows in
        # gensim vocabulary order. Delete them so they are regenerated.
        node2vec_cache = f"{path_prefix}/datasets/Node2Vec/{dataset_name}.pt"
        if os.path.exists(node2vec_cache):
            node_feat = torch.load(node2vec_cache, map_location=device).to(device)
        else:
            from node2vec import Node2Vec
            from torch_geometric.utils.convert import to_networkx

            nx_graph = to_networkx(graph_data)
            node2vec = Node2Vec(nx_graph, dimensions=300, walk_length=30, num_walks=10, workers=4)
            node2vec_model = node2vec.fit(window=10, min_count=1, batch_words=4)
            print(node2vec_model.wv.vectors.shape, type(node2vec_model.wv.vectors))
            # Rows of `wv.vectors` follow gensim's vocabulary order (sorted by walk
            # frequency), not node id. The node2vec package trains on walks of
            # string node keys, so look up every node explicitly: row i is node i.
            vectors = np.stack([node2vec_model.wv[str(node)] for node in range(graph_data.num_nodes)])
            node_feat = torch.FloatTensor(vectors).to(device)
            os.makedirs(f"{path_prefix}/datasets/Node2Vec", exist_ok=True)
            torch.save(node_feat, node2vec_cache)

    if node_feat is not None:
        graph_data.x = node_feat
        train_data.x = node_feat[train_data.node_ids]
        val_data.x = node_feat[val_data.node_ids]
        test_data.x = node_feat

    graph_data = graph_data.to(device)
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    test_data = test_data.to(device)

    return graph_data, (train_data, val_data, test_data)


def load_atk_graph_dataset(dataset_name, device, atk_meta_info, re_split=False, path_prefix=PATH):
    """
    Load attacked graph dataset for transductive setting

    Args:
        dataset_name: Name of the dataset
        device: Device to load data on
        atk_meta_info: Dictionary containing attack information
            - attack: attack name. Text attacks: textfooler, llm, llm_Ministral, gpt.
              Structure attacks: pgd, grbcd, prbcd, dice, metattack. Any other name
              is classified with 'atk_type'.
            - ptb_rate: perturbation rate
            - atk_emb_type: embedding type used for attack (bow, roberta, Mistral-7B)
            - seed: random seed
            - atk_type: 'structure' (default) or 'text'; only consulted for attack
              names not listed above
        re_split: 1 -> random 10/10/80 split, 2 -> random 60/20/20 split,
            anything else keeps the stored masks
        path_prefix: Path prefix for the original data (attack artifacts are read
            from ATKG_PATH)

    Returns:
        Attacked graph data with an undirected edge_index

    Raises:
        FileNotFoundError: if the attack artifact does not exist.
        ValueError: if the attack or attack type is not supported.
    """
    # Load original graph data
    if dataset_name.startswith('taglas_'):
        taglas_name = dataset_name[len('taglas_'):]
        graph_data = torch.load(f"{path_prefix}/datasets/taglas_data/{taglas_name}.pt", weights_only=False).to(device)
    else:
        graph_data = torch.load(f"{path_prefix}/datasets/{dataset_name}.pt", weights_only=False).to(device)

    # Extract attack metadata
    attack = atk_meta_info['attack']
    ptb_rate = atk_meta_info['ptb_rate']
    atk_emb_type = atk_meta_info['atk_emb_type']
    seed = atk_meta_info['seed']
    ptb = int(ptb_rate * 100)

    # Attack name -> directory stem of its transductive attacked-text files.
    text_attack_dirs = {
        'textfooler': 'text_textfooler',
        'llm': 'llm_Ministral-8B',
        'llm_Ministral': 'llm_Ministral-8B',
        'gpt': 'llm_gpt-4o-mini',
    }
    structure_attacks = ['pgd', 'grbcd', 'prbcd', 'dice', 'metattack']

    # Classify by attack name first, as the inductive loader does. Fall back to the
    # explicit 'atk_type' only for other names. Checking atk_type first let its
    # 'structure' default send text attacks down the structure branch.
    if attack in text_attack_dirs:
        atk_type = 'text'
    elif attack in structure_attacks:
        atk_type = 'structure'
    else:
        atk_type = atk_meta_info.get('atk_type', 'structure')

    if atk_type == 'structure':
        # Structure attack: replace edge_index
        atk_path = f"{ATKG_PATH}/{dataset_name}/{attack}/{atk_emb_type}_{ptb}_{seed}.pt"
        if not os.path.exists(atk_path):
            raise FileNotFoundError(f"Attack file not found: {atk_path}")

        attacked_edge_index = torch.load(atk_path, map_location=device)
        graph_data.edge_index = to_undirected(attacked_edge_index)

    elif atk_type == 'text':
        # Text attack: replace raw_texts of the attacked nodes
        if attack not in text_attack_dirs:
            raise ValueError(f"Unsupported text attack: {attack}")
        atk_path = f"{ATKG_PATH}/{dataset_name}/{text_attack_dirs[attack]}_transductive/attacked_texts_seed{seed}_ptb{ptb}.json"
        if not os.path.exists(atk_path):
            raise FileNotFoundError(f"Attack file not found: {atk_path}")

        with open(atk_path, 'r') as f:
            attacked_data = json.load(f)

        # Handle both old and new format
        if isinstance(attacked_data, dict) and "attacked_texts" in attacked_data:
            # New format with metadata
            node_to_text = {item['node_id']: item['attacked_text'] for item in attacked_data["attacked_texts"]}
        else:
            # Old format (direct mapping)
            node_to_text = attacked_data

        # The clean edges are kept. Symmetrise them as every other loader in this
        # module does.
        graph_data.edge_index = to_undirected(graph_data.edge_index)

        # Replace raw_texts with attacked versions for full graph (transductive)
        new_raw_texts = list(graph_data.raw_texts)
        for node_id, attacked_text in node_to_text.items():
            new_raw_texts[int(node_id)] = attacked_text
        graph_data.raw_texts = new_raw_texts

    else:
        raise ValueError(f"Unsupported attack type: {atk_type} or attack: {attack}")

    # Apply re-split if needed
    if re_split == 1:
        graph_data.train_mask, graph_data.val_mask, graph_data.test_mask = re_split_data(graph_data.num_nodes, train_percent=0.1, val_percent=0.1, test_percent=0.8, device=device)
    elif re_split == 2:
        graph_data.train_mask, graph_data.val_mask, graph_data.test_mask = re_split_data(graph_data.num_nodes, train_percent=0.6, val_percent=0.2, test_percent=0.2, device=device)

    return graph_data


def load_inductive_atk_graph_dataset(dataset_name, device, atk_meta_info, re_split=False, path_prefix=PATH, emb_model="shallow", seed=0):
    """
    Load attacked graph dataset for inductive setting

    Args:
        dataset_name: Name of the dataset
        device: Device to load data on
        atk_meta_info: Dictionary containing attack information
            - attack: attack name. Text attacks: textfooler, llm, llm_Ministral, gpt.
              Structure attacks: pgd, grbcd, prbcd, dice, metattack, pgdguard*.
              Hybrid: wtgia. Any other name is classified with 'atk_type'.
            - ptb_rate: perturbation rate
            - atk_emb_type: embedding type used for attack (bow, roberta, Mistral-7B)
            - seed: random seed of the attack
            - atk_type: 'structure' (default), 'text' or 'hybrid'; only consulted
              for attack names not listed above
            - injection: WTGIA injection variant (default 'atdgia')
        re_split: Re-split type (should be 2 for inductive)
        path_prefix: Path prefix for original data
        emb_model: Embedding model type (forwarded to load_inductive_graph_dataset)
        seed: Random seed for data splitting

    Returns:
        (full_graph_data, (train_data, val_data, test_data)) where test_data contains attacked content.
        Structure attacks replace only test_data.edge_index. Text attacks replace
        raw_texts in both test_data and full_graph_data. WTGIA appends injected
        nodes (edges, labels, texts) to both.

    Raises:
        FileNotFoundError: if a required attack artifact does not exist.
        ValueError: if a text attack name has no known attacked-text file.
    """
    # load_inductive_graph_dataset returns device-resident splits with raw_texts and
    # label_name attached. Attack those objects directly instead of re-reading the
    # split cache from disk; if the re-read were ever skipped, the attack would be
    # silently lost.
    graph_data, (train_data, val_data, test_data) = load_inductive_graph_dataset(
        dataset_name, device, re_split, path_prefix, emb_model, seed
    )

    # Extract attack metadata
    attack = atk_meta_info['attack']
    ptb_rate = atk_meta_info['ptb_rate']
    atk_emb_type = atk_meta_info['atk_emb_type']
    atk_seed = atk_meta_info['seed']
    ptb = int(ptb_rate * 100)

    # Attack name -> directory stem of its inductive attacked-text files.
    text_attack_dirs = {
        'textfooler': 'text_textfooler',
        'llm': 'llm_Ministral-8B',
        'llm_Ministral': 'llm_Ministral-8B',
        'gpt': 'llm_gpt-4o-mini',
    }

    # Determine attack type based on attack name
    if attack in text_attack_dirs:
        atk_type = 'text'
    elif attack in ['pgd', 'grbcd', 'prbcd', 'dice', 'metattack'] or attack.startswith('pgdguard'):
        atk_type = 'structure'
    elif attack == 'wtgia':
        atk_type = 'hybrid'
    else:
        atk_type = atk_meta_info.get('atk_type', 'structure')  # Default to structure attack

    if atk_type == 'structure':
        # Structure attack: replace test_data edge_index
        if attack.startswith('pgdguard'):
            # pgdguard attacks use roberta embeddings and have suffix format (pgdguard_0.5)
            atk_path = f"{ATKG_PATH}/{dataset_name}/{attack}/roberta_{ptb}_{atk_seed}.pt"
        else:
            atk_path = f"{ATKG_PATH}/{dataset_name}/{attack}/{atk_emb_type}_{ptb}_{atk_seed}.pt"

        if not os.path.exists(atk_path):
            raise FileNotFoundError(f"Attack file not found: {atk_path}")

        attacked_edge_index = torch.load(atk_path, map_location=device)
        test_data.edge_index = to_undirected(attacked_edge_index)
        print(f"Applied structure attack: {attack} with ptb_rate={ptb_rate} for {dataset_name}")

    elif atk_type == 'text':
        # Text attack: replace raw_texts in test_data (and in the full graph)
        if attack not in text_attack_dirs:
            raise ValueError(f"Unsupported text attack: {attack}")
        atk_path = f"{ATKG_PATH}/{dataset_name}/{text_attack_dirs[attack]}_inductive/attacked_texts_seed{atk_seed}_ptb{ptb}.json"
        if not os.path.exists(atk_path):
            raise FileNotFoundError(f"Attack file not found: {atk_path}")

        with open(atk_path, 'r') as f:
            attacked_data = json.load(f)

        if isinstance(attacked_data, dict) and "attacked_texts" in attacked_data:
            node_to_text = {item['node_id']: item['attacked_text'] for item in attacked_data["attacked_texts"]}
        else:
            node_to_text = attacked_data

        # Map global node ids to positions in test_data.raw_texts. tolist() reads
        # the ids in one transfer instead of one .item() per node.
        global_to_local = {global_id: local_idx for local_idx, global_id in enumerate(test_data.node_ids.tolist())}

        new_test_raw_texts = list(test_data.raw_texts)
        for global_node_id, attacked_text in node_to_text.items():
            local_idx = global_to_local.get(int(global_node_id))
            if local_idx is not None:
                new_test_raw_texts[local_idx] = attacked_text
        test_data.raw_texts = new_test_raw_texts

        new_full_raw_texts = list(graph_data.raw_texts)
        for node_id, attacked_text in node_to_text.items():
            new_full_raw_texts[int(node_id)] = attacked_text
        graph_data.raw_texts = new_full_raw_texts

        print(f"Updated {len(node_to_text)} attacked texts in inductive setting")

    elif atk_type == 'hybrid' and attack == 'wtgia':
        # WTGIA: load the injected graph and append the injected nodes
        atkg_path = f"{path_prefix}/atkg"
        injection = atk_meta_info.get('injection', 'atdgia')
        attack_name = f"wtgia_{injection}"

        edge_path = f"{atkg_path}/{dataset_name}/{attack_name}/{atk_emb_type}_{ptb}_{atk_seed}.pt"
        texts_path = f"{atkg_path}/{dataset_name}/{attack_name}_texts/llama-3.1-8B_{ptb}_{atk_seed}.json"

        if not os.path.exists(edge_path):
            raise FileNotFoundError(f"WTGIA data not found: {edge_path}")

        attacked_edge_index = torch.load(edge_path, map_location=device)
        # Load texts if available
        injected_texts = []
        if os.path.exists(texts_path):
            with open(texts_path, 'r') as f:
                injected_texts = json.load(f).get('texts', [])

        # Injected nodes are appended after the original ones (labels and masks are
        # extended at the end below). Count them from the edge index too: a missing
        # or short texts file then still yields node counts consistent with the
        # edges. The gap is padded with placeholder texts. Previously n_inject came
        # only from the texts, which made the placeholder fallback unreachable.
        num_original = graph_data.num_nodes
        n_from_edges = int(attacked_edge_index.max()) + 1 - num_original if attacked_edge_index.numel() > 0 else 0
        n_inject = max(len(injected_texts), n_from_edges)
        injected_texts = list(injected_texts) + [f"Injected node {i}" for i in range(len(injected_texts), n_inject)]

        test_data.edge_index = attacked_edge_index
        test_data.y = torch.cat([test_data.y, torch.zeros(n_inject, dtype=torch.long, device=device)])
        test_data.raw_texts.extend(injected_texts)
        test_data.num_nodes += n_inject

        graph_data.edge_index = attacked_edge_index
        graph_data.y = torch.cat([graph_data.y, torch.zeros(n_inject, dtype=torch.long, device=device)])
        graph_data.raw_texts.extend(injected_texts)
        graph_data.num_nodes += n_inject

        print(test_data, graph_data)

        # Update masks to exclude injected nodes from evaluation
        for mask_name in ['train_mask', 'val_mask', 'test_mask']:
            if hasattr(graph_data, mask_name):
                original_mask = getattr(graph_data, mask_name)
                extended_mask = torch.cat([original_mask, torch.zeros(n_inject, dtype=torch.bool, device=device)])
                setattr(graph_data, mask_name, extended_mask)

        print(f"WTGIA: Loaded {n_inject} injected nodes in inductive setting")

    # NOTE(review): any other atk_type (or 'hybrid' with an attack other than
    # 'wtgia') returns the clean splits unchanged, as before.
    return graph_data, (train_data, val_data, test_data)


def load_inductive_atk_graph_dataset_for_gnn(dataset_name, device, atk_meta_info, re_split=False, path_prefix=PATH, emb_model="shallow", seed=0):
    """
    Load attacked graph dataset for GNN models in inductive setting.

    Wraps ``load_inductive_atk_graph_dataset`` and then sets node features:

    * ``emb_model != "shallow"``: features are loaded from
      ``{path_prefix}/datasets/{emb_model}/{dataset_name}.pt``.
      - For text attacks, the rows of attacked nodes are overwritten with
        embeddings of the attacked texts in every split.
      - For WTGIA, embeddings of the injected texts are appended to
        ``graph_data.x`` and ``test_data.x``.
    * ``emb_model == "shallow"`` on a fixed list of datasets: Node2Vec
      features are loaded from cache, or trained and then cached.

    Args:
        dataset_name: Name of the dataset.
        device: Device the returned data objects are moved to.
        atk_meta_info: Attack metadata dict. Reads 'attack', 'seed' and
            'ptb_rate', plus optionally 'atk_type' and 'injection'.
        re_split: Forwarded to ``load_inductive_atk_graph_dataset``.
        path_prefix: Root data directory.
        emb_model: Feature encoder name, or "shallow".
        seed: Forwarded to ``load_inductive_atk_graph_dataset``.

    Returns:
        ``(graph_data, (train_data, val_data, test_data))``, all moved to ``device``.

    Raises:
        ValueError: If the attack is treated as a text attack but has no known
            attacked-text directory.
    """
    graph_data, (train_data, val_data, test_data) = load_inductive_atk_graph_dataset(
        dataset_name, device, atk_meta_info, re_split, path_prefix, emb_model, seed
    )

    # Update node features based on embedding model
    if emb_model != "shallow":
        feat_path = f"{path_prefix}/datasets/{emb_model}/{dataset_name}.pt"
        assert os.path.exists(feat_path)
        node_feat = torch.load(feat_path, map_location=device, weights_only=False).to(device).type(torch.float)
        graph_data.x = node_feat.clone()
        train_data.x = node_feat[train_data.node_ids].clone()
        val_data.x = node_feat[val_data.node_ids].clone()
        test_data.x = node_feat.clone()

        attack = atk_meta_info['attack']

        # Determine attack type based on attack name
        if attack in ['textfooler', 'llm', 'llm_Ministral', 'gpt']:
            atk_type = 'text'
        elif attack in ['pgd', 'grbcd', 'prbcd', 'dice', 'metattack'] or attack.startswith('pgdguard'):
            atk_type = 'structure'
        elif attack == 'wtgia':
            atk_type = 'hybrid'
        else:
            atk_type = atk_meta_info.get('atk_type', 'structure')

        # Handle text attacks - regenerate embeddings for attacked nodes
        if atk_type == 'text':
            # Result directory per text attack. 'llm' and 'llm_Ministral' both
            # denote the Ministral-8B attack, matching the transductive loader
            # and generate_attacked_embeddings.
            text_attack_dirs = {
                'textfooler': 'text_textfooler_inductive',
                'llm': 'llm_Ministral-8B_inductive',
                'llm_Ministral': 'llm_Ministral-8B_inductive',
                'gpt': 'llm_gpt-4o-mini_inductive',
            }
            subdir = text_attack_dirs.get(attack)
            if subdir is None:
                raise ValueError(f"No attacked-text directory known for text attack '{attack}'")
            # Use path_prefix (not the global ATKG_PATH) so a custom data root
            # is honoured, as in the WTGIA branch and the transductive loader.
            atk_path = (
                f"{path_prefix}/atkg/{dataset_name}/{subdir}/"
                f"attacked_texts_seed{atk_meta_info['seed']}_ptb{int(atk_meta_info['ptb_rate']*100)}.json"
            )

            if os.path.exists(atk_path):
                with open(atk_path, 'r') as f:
                    attacked_data = json.load(f)

                # Two on-disk formats are accepted:
                # {"attacked_texts": [{"node_id", "attacked_text"}, ...]},
                # or a plain {node_id: attacked_text} mapping.
                if isinstance(attacked_data, dict) and "attacked_texts" in attacked_data:
                    attacked_texts_data = attacked_data["attacked_texts"]
                    changed_node_ids = [item['node_id'] for item in attacked_texts_data]
                    attacked_texts = [item['attacked_text'] for item in attacked_texts_data]
                else:
                    changed_node_ids = list(attacked_data.keys())
                    attacked_texts = list(attacked_data.values())

                # Generate new embeddings for attacked texts
                new_embeddings, node_ids = generate_attacked_embeddings(
                    dataset_name, atk_meta_info, changed_node_ids, attacked_texts,
                    emb_model, device, path_prefix, setting="inductive"
                )

                def _first_positions(split):
                    """Map global node id -> first local row in ``split`` (empty if no node_ids)."""
                    positions = {}
                    if hasattr(split, 'node_ids'):
                        for local_idx, global_idx in enumerate(split.node_ids.tolist()):
                            positions.setdefault(int(global_idx), local_idx)
                    return positions

                # Build the lookups once instead of scanning node_ids per attacked node.
                train_pos = _first_positions(train_data)
                val_pos = _first_positions(val_data)

                # Update embeddings for attacked nodes in all data splits
                for i, node_id in enumerate(node_ids):
                    global_node_id = int(node_id)
                    graph_data.x[global_node_id] = new_embeddings[i]
                    test_data.x[global_node_id] = new_embeddings[i]
                    if global_node_id in train_pos:
                        train_data.x[train_pos[global_node_id]] = new_embeddings[i]
                    if global_node_id in val_pos:
                        val_data.x[val_pos[global_node_id]] = new_embeddings[i]

                print(f"Updated embeddings for {len(node_ids)} attacked nodes in inductive setting")

        # Handle WTGIA attacks - regenerate embeddings for injected nodes
        elif atk_type == 'hybrid' and attack == 'wtgia':
            atkg_path = f"{path_prefix}/atkg"
            injection = atk_meta_info.get('injection', 'atdgia')
            attack_name = f"wtgia_{injection}"
            ptb_rate = atk_meta_info['ptb_rate']
            atk_seed = atk_meta_info['seed']

            # Load injected texts for embedding generation
            texts_path = f"{atkg_path}/{dataset_name}/{attack_name}_texts/llama-3.1-8B_{int(ptb_rate*100)}_{atk_seed}.json"

            if os.path.exists(texts_path):
                with open(texts_path, 'r') as f:
                    text_data = json.load(f)
                    injected_texts = text_data.get('texts', [])

                if injected_texts:
                    n_inject = len(injected_texts)
                    print(f"Generating embeddings for {n_inject} injected nodes using {emb_model}")

                    # The base loader appends the injected texts after the original
                    # nodes' raw_texts, so injected nodes are the last n_inject ids.
                    original_num_nodes = len(graph_data.raw_texts) - n_inject
                    injected_node_ids = list(range(original_num_nodes, original_num_nodes + n_inject))

                    new_embeddings, _ = generate_attacked_embeddings(
                        dataset_name, atk_meta_info, injected_node_ids, injected_texts,
                        emb_model, device, path_prefix, setting="inductive",
                    )

                    # Original node embeddings followed by the injected ones.
                    graph_data.x = torch.cat([graph_data.x[:original_num_nodes], new_embeddings], dim=0)
                    test_data.x = torch.cat([test_data.x[:original_num_nodes], new_embeddings], dim=0)

                    print(graph_data, test_data)

                    print(f"Updated embeddings for {n_inject} injected nodes in WTGIA attack")
                else:
                    print("No injected texts found for WTGIA attack, using default embeddings")
            else:
                print(f"Injected texts file not found: {texts_path}, using default embeddings")

    # Apply Node2Vec for datasets without shallow embeddings
    if emb_model == "shallow" and dataset_name in ["reddit", "instagram", "computer", "photo", "history"]:
        n2v_path = f"{path_prefix}/datasets/Node2Vec/{dataset_name}.pt"
        if os.path.exists(n2v_path):
            node_feat = torch.load(n2v_path, map_location=device).to(device)
        else:
            from node2vec import Node2Vec
            from torch_geometric.utils.convert import to_networkx

            nx_graph = to_networkx(graph_data)
            node2vec = Node2Vec(nx_graph, dimensions=300, walk_length=30, num_walks=10, workers=4)
            node2vec_model = node2vec.fit(window=10, min_count=1, batch_words=4)
            # gensim orders wv.vectors by token frequency, not by node id, and
            # node2vec stores nodes as string tokens. Look each node up by its
            # id so that row i of node_feat is the embedding of node i.
            # NOTE(review): Node2Vec caches written with the old row order are
            # permuted and should be regenerated.
            vectors = np.stack([node2vec_model.wv[str(i)] for i in range(graph_data.num_nodes)])
            print(vectors.shape, type(vectors))
            node_feat = torch.from_numpy(vectors).float().to(device)
            os.makedirs(f"{path_prefix}/datasets/Node2Vec", exist_ok=True)
            torch.save(node_feat, n2v_path)
        graph_data.x = node_feat
        train_data.x = node_feat[train_data.node_ids]
        val_data.x = node_feat[val_data.node_ids]
        test_data.x = node_feat

    graph_data = graph_data.to(device)
    train_data = train_data.to(device)
    val_data = val_data.to(device)
    test_data = test_data.to(device)

    return graph_data, (train_data, val_data, test_data)


def load_atk_graph_dataset_for_llaga(dataset_name, device, atk_meta_info, re_split=0, encoder="roberta", seed=0, path_prefix=PATH):
    """
    Load attacked graph dataset for LLaGA model.

    The setting is inductive when ``re_split == 2``, or when ``re_split == 0``
    and the dataset is "arxiv". Otherwise it is transductive.

    Args:
        dataset_name: Name of the dataset.
        device: Device to load data onto.
        atk_meta_info: Attack metadata dict, forwarded to the GNN loaders.
        re_split: Split mode; also forwarded to the chosen loader.
        encoder: Embedding model name, forwarded as ``emb_model``.
        seed: Forwarded to the inductive loader only.
        path_prefix: Root data directory (defaults to ``PATH``, as before).

    Returns:
        Transductive: the attacked ``graph_data``.
        Inductive: ``(graph_data, (train_data, val_data, test_data))``.
    """
    is_inductive = re_split == 2 or (re_split == 0 and dataset_name == "arxiv")
    if is_inductive:
        # Inductive setting - use inductive attack loader for GNNs
        graph_data, (train_data, val_data, test_data) = load_inductive_atk_graph_dataset_for_gnn(
            dataset_name, device, atk_meta_info, re_split, path_prefix=path_prefix, emb_model=encoder, seed=seed
        )
        return graph_data, (train_data, val_data, test_data)

    # Transductive setting - the basic attack loader takes no seed.
    return load_atk_graph_dataset_for_gnn(
        dataset_name, device, atk_meta_info, emb_model=encoder, re_split=re_split, path_prefix=path_prefix
    )


def load_atk_graph_dataset_for_gnn(dataset_name, device, atk_meta_info, re_split=False, path_prefix=PATH, emb_model="shallow"):
    """
    Load an attacked graph dataset for GNN models in the transductive setting.

    Node features are set as follows:
    * Neural encoder (``emb_model != "shallow"``): loaded from
      ``{path_prefix}/datasets/{emb_model}/{dataset_name}.pt``.
    * "shallow": loaded from the cached Node2Vec file, for the datasets listed
      below and only if that file exists. Otherwise the features set by
      ``load_atk_graph_dataset`` are kept.

    For text attacks with a neural encoder, the rows of attacked nodes are then
    replaced by embeddings of the attacked texts. Those embeddings come from
    (and are cached by) ``generate_attacked_embeddings``.

    Args:
        dataset_name: Name of the dataset.
        device: Device to load tensors onto.
        atk_meta_info: Attack metadata dict. Reads 'attack', 'seed' and 'ptb_rate'.
        re_split: Forwarded to ``load_atk_graph_dataset``.
        path_prefix: Root data directory.
        emb_model: Feature encoder name, or "shallow".

    Returns:
        The attacked graph data object with updated ``x``.
    """
    # Base loader applies the attack itself (structure / text changes).
    graph_data = load_atk_graph_dataset(dataset_name, device, atk_meta_info, re_split, path_prefix)
    uses_encoder = emb_model != "shallow"

    # Start from the clean (un-attacked) feature matrix.
    if uses_encoder:
        feat_file = f"{path_prefix}/datasets/{emb_model}/{dataset_name}.pt"
        assert os.path.exists(feat_file)
        loaded = torch.load(feat_file, map_location=device, weights_only=False)
        graph_data.x = loaded.to(device).type(torch.float).clone()
    elif dataset_name in ["reddit", "instagram", "computer", "photo", "history"]:
        n2v_file = f"{path_prefix}/datasets/Node2Vec/{dataset_name}.pt"
        if os.path.exists(n2v_file):
            graph_data.x = torch.load(n2v_file, map_location=device).to(device).clone()

    # Only text attacks with a neural encoder need re-embedding.
    attack = atk_meta_info['attack']
    text_attack_dirs = {
        'textfooler': 'text_textfooler_transductive',
        'llm': 'llm_Ministral-8B_transductive',
        'llm_Ministral': 'llm_Ministral-8B_transductive',
        'gpt': 'llm_gpt-4o-mini_transductive',
    }
    if not uses_encoder or attack not in text_attack_dirs:
        return graph_data

    atk_path = (
        f"{path_prefix}/atkg/{dataset_name}/{text_attack_dirs[attack]}/"
        f"attacked_texts_seed{atk_meta_info['seed']}_ptb{int(atk_meta_info['ptb_rate']*100)}.json"
    )
    if not os.path.exists(atk_path):
        return graph_data

    with open(atk_path, 'r') as fh:
        payload = json.load(fh)

    # Accept either {"attacked_texts": [{"node_id", "attacked_text"}, ...]}
    # or a plain {node_id: attacked_text} mapping.
    if isinstance(payload, dict) and "attacked_texts" in payload:
        records = payload["attacked_texts"]
        changed_node_ids = [rec['node_id'] for rec in records]
        attacked_texts = [rec['attacked_text'] for rec in records]
    else:
        changed_node_ids = list(payload.keys())
        attacked_texts = list(payload.values())

    new_embeddings, node_ids = generate_attacked_embeddings(
        dataset_name, atk_meta_info, changed_node_ids, attacked_texts, emb_model, device, path_prefix, setting="transductive"
    )

    # Overwrite the attacked rows; node ids may be strings (JSON keys).
    for row, node_id in enumerate(node_ids):
        graph_data.x[int(node_id)] = new_embeddings[row]

    print(f"Updated embeddings for {len(node_ids)} nodes in transductive setting")
    return graph_data
