import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import os
import numpy as np
import argparse
import logging
import json
from datetime import datetime
from itertools import product

from torch_geometric.utils import to_undirected
from torch_geometric.data import Data

from greatx.nn.models import GCN
from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping
from greatx.utils import wrapper

import sys
sys.path.append("../")
from common import set_seed, load_inductive_graph_dataset_for_gnn, load_inductive_atk_graph_dataset_for_gnn

import yaml


class AutoGCN(nn.Module):
    """GCN with an extra ``text_attacked`` output class and a staged inference path.

    Plain calls (``training=True`` or ``apply_auto_logic=False``) run a standard
    GCN forward pass and return ``out_channels + 1`` logits per node; the last
    column is the ``text_attacked`` detection class.

    Staged calls (``training=False`` and ``apply_auto_logic=True``) return
    ``out_channels`` logits per node, obtained in three stages:

    1) Detection: nodes whose argmax is the ``text_attacked`` class are flagged.
    2) Text recovery: flagged nodes are re-predicted with their features zeroed.
    3) Structure recovery: the remaining nodes are re-predicted on a graph that
       drops every edge touching a flagged node and every edge whose endpoint
       features have cosine similarity below ``similarity_threshold``.
    """

    # NOTE(review): ``wrapper`` comes from greatx.utils; presumably it normalises
    # ``hids`` / ``acts`` into lists -- confirm against the GreatX version in use.
    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: list = [16], acts: list = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 normalize: bool = True, similarity_threshold: float = 0.5):
        """Build the hidden GCN stack and the ``out_channels + 1``-way classifier.

        Args:
            in_channels: Size of the input node features.
            out_channels: Number of real (non-attack) classes.
            hids: Hidden sizes, one per hidden layer (never mutated here).
            acts: Activation names, one per hidden layer; same length as ``hids``.
            dropout: Dropout probability applied after every hidden activation.
            bias: Forwarded to every ``GCNConv``.
            normalize: Forwarded to every ``GCNConv``.
            similarity_threshold: Minimum endpoint cosine similarity for an edge
                to survive the stage-3 filter.
        """
        super().__init__()

        self.original_out_channels = out_channels
        # One extra output column acts as the text_attacked detector.
        self.attack_out_channels = out_channels + 1
        self.similarity_threshold = similarity_threshold

        conv = []
        assert len(hids) == len(acts)
        for hid, act in zip(hids, acts):
            conv.append(GCNConv(in_channels, hid, bias=bias, normalize=normalize))
            conv.append(activations.get(act))
            conv.append(nn.Dropout(dropout))
            in_channels = hid

        self.conv_layers = Sequential(*conv)
        # Final layer scores the real classes plus the text_attacked class.
        self.classifier = GCNConv(in_channels, self.attack_out_channels, bias=bias, normalize=normalize)

    def reset_parameters(self):
        """Re-initialise the hidden stack and the classifier layer."""
        self.conv_layers.reset_parameters()
        self.classifier.reset_parameters()

    def _similarity_keep_mask(self, x, edge_index):
        """Return a bool mask (one entry per edge) of edges meeting the threshold."""
        if edge_index.device != x.device:
            edge_index = edge_index.to(x.device)
        similarities = torch.cosine_similarity(x[edge_index[0]], x[edge_index[1]], dim=1)
        return similarities >= self.similarity_threshold

    def filter_edges_by_similarity(self, x, edge_index):
        """Keep only edges whose endpoint features have cosine similarity
        ``>= self.similarity_threshold``.

        An empty ``edge_index`` is returned unchanged; otherwise the result
        lives on ``x``'s device.
        """
        if edge_index.shape[1] == 0:
            return edge_index

        if edge_index.device != x.device:
            edge_index = edge_index.to(x.device)

        return edge_index[:, self._similarity_keep_mask(x, edge_index)]

    def forward(self, x, edge_index, edge_weight=None, training=True, apply_auto_logic=True, return_stats=False):
        """Run the plain or the staged forward pass.

        Args:
            x: Node feature matrix (one row per node).
            edge_index: ``2 x E`` edge list.
            edge_weight: Optional per-edge weights, forwarded to every layer.
            training: Defaults to ``True``; the staged path runs only when this
                is explicitly ``False`` (it is independent of ``self.training``).
            apply_auto_logic: Set to ``False`` to force the plain path.
            return_stats: When true, return ``(logits, stats)`` instead of
                ``logits``; ``stats`` is ``{}`` on the plain path.

        Returns:
            Plain path: logits with ``out_channels + 1`` columns.
            Staged path: logits with ``out_channels`` columns.
        """
        if training or not apply_auto_logic:
            h = self.conv_layers(x, edge_index, edge_weight)
            result = self.classifier(h, edge_index, edge_weight)
            if return_stats:
                return result, {}
            return result

        stats = {}

        # Stage 1: attack detection on the unmodified graph.
        h = self.conv_layers(x, edge_index, edge_weight)
        stage1_logits = self.classifier(h, edge_index, edge_weight)
        text_attacked_mask = stage1_logits.argmax(dim=1) == self.attack_out_channels - 1
        num_detected = int(text_attacked_mask.sum().item())

        stats['text_attacked_detected'] = num_detected
        stats['total_nodes'] = x.shape[0]

        # Work on a copy restricted to the real classes; stages 2 and 3
        # overwrite the rows they are responsible for.
        final_logits = stage1_logits[:, :-1].clone()

        # Stage 2: re-predict flagged nodes with their own features zeroed.
        if num_detected > 0:
            masked_x = x.clone()
            masked_x[text_attacked_mask] = 0

            h_masked = self.conv_layers(masked_x, edge_index, edge_weight)
            stage2_logits = self.classifier(h_masked, edge_index, edge_weight)
            final_logits[text_attacked_mask] = stage2_logits[text_attacked_mask, :-1]

        stats['text_recovery_used'] = num_detected

        # Stage 3: build the filtered graph used for the non-flagged nodes.
        original_edges = edge_index.shape[1]
        keep_mask = torch.ones(original_edges, dtype=torch.bool, device=edge_index.device)

        if num_detected > 0:
            # Vectorised endpoint lookup; also well-defined for an empty edge list.
            flagged = text_attacked_mask.to(edge_index.device)
            keep_mask &= ~(flagged[edge_index[0]] | flagged[edge_index[1]])

        if original_edges > 0:
            keep_mask &= self._similarity_keep_mask(x, edge_index).to(edge_index.device)

        filtered_edge_index = edge_index[:, keep_mask]
        # Weights must be filtered with the same mask, otherwise their length
        # no longer matches the filtered edge list.
        filtered_edge_weight = None
        if edge_weight is not None:
            filtered_edge_weight = edge_weight[keep_mask.to(edge_weight.device)]
        filtered_edges = filtered_edge_index.shape[1]

        stats['original_edges'] = original_edges
        stats['filtered_edges'] = filtered_edges
        stats['edges_removed'] = original_edges - filtered_edges

        # Only recompute when the filter actually removed something.
        structure_recovery_used = 0
        if filtered_edges != original_edges:
            h_filtered = self.conv_layers(x, filtered_edge_index, filtered_edge_weight)
            stage3_logits = self.classifier(h_filtered, filtered_edge_index, filtered_edge_weight)

            non_text_attacked_mask = ~text_attacked_mask
            final_logits[non_text_attacked_mask] = stage3_logits[non_text_attacked_mask, :-1]
            structure_recovery_used = non_text_attacked_mask.sum().item()

        stats['structure_recovery_used'] = structure_recovery_used

        if return_stats:
            return final_logits, stats
        return final_logits


def create_training_data_with_attacks(train_data, attack_ratio=0.15, num_classes=None):
    """Return a copy of ``train_data`` in which a random subset of nodes is
    relabelled with an extra ``text_attacked`` class.

    Only the labels change: ``x``, ``edge_index`` and ``edge_attr`` are passed
    through as-is, so the relabelled nodes keep their original features.

    Args:
        train_data: Graph object exposing ``x``, ``y``, ``edge_index`` and
            optionally ``edge_attr``.
        attack_ratio: Fraction of nodes to relabel; at least one node is
            always relabelled.
        num_classes: Number of real classes, used as the index of the
            ``text_attacked`` label. Defaults to ``train_data.y.max() + 1``;
            pass it explicitly when the training split might not contain the
            highest class, otherwise the attack label collides with a real one.

    Returns:
        A new ``Data`` object carrying the augmented labels.

    Raises:
        ValueError: If ``num_classes`` does not exceed the largest label
            present in ``train_data.y``.
    """
    num_nodes = train_data.x.shape[0]
    max_label = train_data.y.max().item()
    if num_classes is None:
        num_classes = max_label + 1
    elif num_classes <= max_label:
        raise ValueError(
            f"num_classes={num_classes} collides with existing label {max_label}"
        )

    # Pick the nodes to relabel (same RNG usage as before: one CPU randperm).
    num_attack_samples = max(1, int(num_nodes * attack_ratio))
    attack_indices = torch.randperm(num_nodes)[:num_attack_samples]

    # The text_attacked class takes the first index after the real classes.
    augmented_y = train_data.y.clone()
    augmented_y[attack_indices] = num_classes

    augmented_train_data = Data(
        x=train_data.x,
        y=augmented_y,
        edge_index=train_data.edge_index,
        edge_attr=getattr(train_data, 'edge_attr', None)
    )

    return augmented_train_data


# ------------ Argument Parsing ------------
def _build_arg_parser():
    """Assemble the command-line interface of this evaluation script."""
    embedding_choices = ['bow', 'roberta', 'Mistral-7B', 'MiniLM']
    cli = argparse.ArgumentParser()

    # Where the data lives and which grid config to read
    cli.add_argument("--root_path", type=str, default="/path/to/GraphAD_data/")
    cli.add_argument("--config", type=str, default="config.yaml")

    # Dataset / attack selection
    cli.add_argument("--dataset", type=str, default="cora")
    cli.add_argument("--attack", type=str, default="gpt",
                     choices=['textfooler', 'llm', 'llm_Ministral', 'gpt'])
    cli.add_argument("--atk_emb_type", type=str, default="bow", choices=embedding_choices)
    cli.add_argument("--def_emb_type", type=str, default="bow", choices=embedding_choices)
    cli.add_argument("--re_split", type=int, default=2)
    cli.add_argument("--ptb_rate", type=float, default=0.1)

    # Training setup
    cli.add_argument("--device", type=int, default=0)
    cli.add_argument("--patience", type=int, default=50)
    cli.add_argument("--epochs", type=int, default=400)

    # AutoGCN-specific knobs
    cli.add_argument("--similarity_threshold", type=float, default=0.5)
    cli.add_argument("--attack_ratio", type=float, default=0.15)

    # Skip the run when its log file already exists
    cli.add_argument("--use_existing_logs", action='store_true')
    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# ------------ Load Config ------------
with open(args.config, 'r') as f:
    config = yaml.safe_load(f)

# Hyperparameter grid for the GCN, copied so later code never mutates the
# parsed config. ``or {}`` also covers a ``gcn:`` key that is present but empty.
model_config = config['hyperparams'].get('gcn') or {}
hyperparams = dict(model_config)

# Validate the split explicitly: ``assert`` is stripped under ``python -O`` and
# would let an unsupported split slip through silently.
if args.dataset != 'arxiv':
    if args.re_split != 2:
        parser.error("Only inductive split with re_split=2 (6/2/2) is supported")
    seeds = range(3)
else:
    # arxiv always runs on split 0 with a single seed, whatever was passed on
    # the command line (previously a blanket assert rejected ``--re_split 0``).
    args.re_split = 0
    seeds = [0]

# ------------ Fixed Settings ------------
root_path = args.root_path
# Use the requested CUDA device when available, otherwise fall back to CPU.
device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

# ------------ Logging Setup ------------
# One log file per (dataset, attack, embedding pair, ptb rate, threshold).
# NOTE(review): int(rate * 100) truncates (e.g. 0.29 -> 28); left unchanged
# because other paths in this script derive the percentage the same way.
log_dir = f"eval_logs_ind_text/{args.dataset}/auto_gcn/{args.attack}"
os.makedirs(log_dir, exist_ok=True)
log_file = f"{log_dir}/{args.atk_emb_type}_{args.def_emb_type}_ptb{int(args.ptb_rate*100)}_sim{int(args.similarity_threshold*100)}.log"

# An existing log is treated as a finished experiment when --use_existing_logs is set.
if args.use_existing_logs and os.path.exists(log_file):
    print(f"Log file {log_file} already exists. Experiment already completed. Skipping...")
    # sys.exit instead of the site-provided exit(), which is meant for the
    # interactive shell and is not guaranteed to exist.
    sys.exit(0)

# filemode='w' overwrites any previous log for the same configuration.
logging.basicConfig(filename=log_file, level=logging.INFO, filemode='w', format='%(message)s')
logger = logging.getLogger()

# Header block describing the run.
logger.info("=" * 80)
logger.info(f"EXPERIMENT - Dataset: {args.dataset}, Model: AutoGCN Text")
logger.info(f"Text Attack: {args.attack}")
logger.info(f"Embedding: atk={args.atk_emb_type}, def={args.def_emb_type}, ptb_rate={args.ptb_rate}")
logger.info(f"Auto Config: similarity_threshold={args.similarity_threshold}, attack_ratio={args.attack_ratio}")
logger.info(f"Training: epochs={args.epochs}, patience={args.patience}")
logger.info(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 80)

# ------------ Run over hyperparams ------------
# Full grid: every element of param_combinations is one tuple of values,
# aligned position-by-position with param_names.
param_names = [*hyperparams]
param_values = [hyperparams[name] for name in param_names]
param_combinations = [*product(*param_values)]

for params in param_combinations:
    param_dict = dict(zip(param_names, params))
    
    # Extract parameters
    lr = param_dict.get('learning_rates', 0.01)
    wd = param_dict.get('weight_decays', 0.0)
    dropout = param_dict.get('dropouts', 0.5)
    if args.dataset in ['cora', 'citeseer', 'pubmed', 'instagram', 'wikics']:
        hidden_dim = 128
    else:
        hidden_dim = 256
    
    clean_accs, attacked_accs = [], []
    clean_val_accs, attacked_val_accs = [], []
    
    logger.info("-" * 60)
    logger.info(f"Hyperparams: {param_dict}")
    start_time = time.time()
    
    for seed in seeds:
        set_seed(seed)
        
        # Load clean data
        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split, 
            path_prefix=root_path, emb_model=args.def_emb_type, seed=seed
        )
        full_data = full_data.to(device)
        train_data = train_data.to(device)
        val_data = val_data.to(device)
        test_data = test_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        
        # Create training data with text_attacked samples
        augmented_train_data = create_training_data_with_attacks(train_data, args.attack_ratio)
        
        # Initialize AutoGCN model
        model = AutoGCN(
            num_features, num_classes, 
            hids=[hidden_dim], acts=['relu'], 
            dropout=dropout,
            similarity_threshold=args.similarity_threshold
        )
        
        trainer = Trainer(model, device=device, verbose=0)
        trainer.reset_optimizer(lr=lr, weight_decay=wd)
        ckp = ModelCheckpoint(f'auto_gcn_text_{args.dataset}_{args.attack}_{args.atk_emb_type}_{args.def_emb_type}_{args.ptb_rate}_{seed}.pth', monitor='val_acc')
        early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')

        num_train_nodes = train_data.x.shape[0]
        num_val_nodes = val_data.x.shape[0] - num_train_nodes
        
        # Create validation mask
        val_mask_for_val_data = torch.cat([
            torch.zeros(num_train_nodes, dtype=torch.bool, device=device),
            torch.ones(num_val_nodes, dtype=torch.bool, device=device)
        ])
        
        # Training with augmented data (includes text_attacked class)
        trainer.fit((augmented_train_data, val_data), mask=(None, val_mask_for_val_data), 
                   verbose=1, callbacks=[ckp, early_stopping], epochs=args.epochs)
        
        # Evaluate on clean test set with auto logic
        model.eval()
        with torch.no_grad():
            test_logits, clean_stats = model(test_data.x, test_data.edge_index, 
                                           training=False, apply_auto_logic=True, return_stats=True)
            test_pred = torch.argmax(test_logits, dim=1)
            test_acc = (test_pred[full_data.test_mask] == test_data.y[full_data.test_mask]).float().mean()
        
        val_acc = ckp.best
        clean_val_accs.append(val_acc)
        clean_accs.append(test_acc.item())

        # Load text attacked data and evaluate with auto logic
        atk_meta_info = {
            'attack': args.attack,
            'ptb_rate': args.ptb_rate,
            'atk_emb_type': args.atk_emb_type,
            'seed': seed
        }
        
        try:
            atk_full_data, (atk_train_data, atk_val_data, atk_test_data) = load_inductive_atk_graph_dataset_for_gnn(
                args.dataset, device, atk_meta_info, re_split=args.re_split, 
                path_prefix=root_path, emb_model=args.def_emb_type, seed=seed
            )
            
            atk_full_data = atk_full_data.to(device)
            atk_test_data = atk_test_data.to(device)

            # Load attacked node information for detailed analysis
            attacked_node_ids = None
            try:
                import json
                from common import ATKG_PATH
                
                # Load attacked node IDs for gpt text attacks
                if args.attack == 'gpt':
                    atk_path = f"{ATKG_PATH}/{args.dataset}/llm_gpt-4o-mini_inductive/attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json"
                    if os.path.exists(atk_path):
                        with open(atk_path, 'r') as f:
                            attacked_data = json.load(f)
                        
                        if isinstance(attacked_data, dict) and "attacked_texts" in attacked_data:
                            attacked_texts_data = attacked_data["attacked_texts"]
                            attacked_node_ids = set(int(item['node_id']) for item in attacked_texts_data)
                        else:
                            attacked_node_ids = set(int(k) for k in attacked_data.keys())
                        
                        logger.info(f"  GPT Attack: Loaded {len(attacked_node_ids)} attacked node IDs from ground truth")
                    else:
                        logger.info(f"  Warning: Attack file not found: {atk_path}")
            except Exception as e:
                logger.info(f"  Failed to load attacked node information: {e}")

            with torch.no_grad():
                attacked_logits, attack_stats = model(atk_test_data.x, atk_test_data.edge_index, 
                                                    training=False, apply_auto_logic=True, return_stats=True)
                attacked_pred = torch.argmax(attacked_logits, dim=1)
                attacked_acc = (attacked_pred[full_data.test_mask] == atk_test_data.y[full_data.test_mask]).float().mean()
            
            attacked_accs.append(attacked_acc.item())
            attacked_val_accs.append(val_acc)  # Use same validation accuracy
            
            # Log detailed auto statistics
            logger.info(f"Seed {seed}: Clean={clean_accs[-1]:.4f}, Attacked={attacked_accs[-1]:.4f}")
            logger.info(f"  Clean data auto stats:")
            logger.info(f"    - Nodes detected as text_attacked: {clean_stats.get('text_attacked_detected', 0)}")
            logger.info(f"    - Nodes using text recovery: {clean_stats.get('text_recovery_used', 0)}")
            logger.info(f"  Attack data auto stats:")
            logger.info(f"    - Nodes detected as text_attacked: {attack_stats.get('text_attacked_detected', 0)}")
            logger.info(f"    - Nodes using text recovery: {attack_stats.get('text_recovery_used', 0)}")
            
            # Detailed analysis for GPT text attacks when ground truth is available
            if args.attack == 'gpt' and attacked_node_ids is not None:
                logger.info(f"\n=== GPT Text Attack Detailed Analysis ===")
                
                # Focus on TEST NODES ONLY (nodes we're actually predicting)
                test_mask = full_data.test_mask.cpu().numpy()
                test_node_indices = test_mask.nonzero()[0]
                
                # Map test indices to original node IDs (inductive setting required)
                test_original_ids = atk_test_data.node_ids.cpu().numpy().tolist()
                test_attacked_nodes = []
                test_non_attacked_nodes = []
                
                for i in test_node_indices:  # Only consider test split nodes
                    orig_id = test_original_ids[i]
                    if orig_id in attacked_node_ids:
                        test_attacked_nodes.append(i)
                    else:
                        test_non_attacked_nodes.append(i)
                
                # Get model predictions on test nodes
                stage1_logits = model.conv_layers(atk_test_data.x, atk_test_data.edge_index)
                stage1_logits = model.classifier(stage1_logits, atk_test_data.edge_index)
                stage1_pred = torch.argmax(stage1_logits, dim=1)
                text_attacked_detected = (stage1_pred == model.attack_out_channels - 1)
                
                # 1. Check how many attacked nodes are correctly identified as text_attacked
                correctly_detected_attacked = []
                missed_attacked = []
                
                for i in test_attacked_nodes:
                    if i < len(stage1_pred) and text_attacked_detected[i]:
                        correctly_detected_attacked.append(i)
                    else:
                        missed_attacked.append(i)
                
                # 2. Check how many non-attacked nodes are incorrectly flagged
                false_positives = []
                for i in test_non_attacked_nodes:
                    if i < len(stage1_pred) and text_attacked_detected[i]:
                        false_positives.append(i)
                
                # 3. Check recovery accuracy for correctly detected attacked nodes
                correctly_recovered_attacked = []
                if len(correctly_detected_attacked) > 0:
                    for i in correctly_detected_attacked:
                        final_pred = attacked_pred[i].item()
                        ground_truth = atk_test_data.y[i].item()
                        if final_pred == ground_truth:
                            correctly_recovered_attacked.append(i)
                
                # 4. CRITICAL ANALYSIS: Why are undetected attacked nodes still correct?
                undetected_attacked_nodes = missed_attacked
                undetected_correct = []
                undetected_wrong = []
                
                for i in undetected_attacked_nodes:
                    final_pred = attacked_pred[i].item()
                    ground_truth = atk_test_data.y[i].item()
                    if final_pred == ground_truth:
                        undetected_correct.append(i)
                    else:
                        undetected_wrong.append(i)
                
                # 5. Edge connectivity analysis for attacked nodes
                edge_index = atk_test_data.edge_index.cpu()
                total_edges_original = edge_index.shape[1]
                
                # Apply similarity filtering to see how many edges are removed
                filtered_edge_index = model.filter_edges_by_similarity(atk_test_data.x, edge_index).cpu()
                total_edges_filtered = filtered_edge_index.shape[1]
                
                # Count edges involving attacked nodes (before and after filtering)
                attacked_node_set = set(test_attacked_nodes)
                
                def count_edges_for_nodes(edge_idx, node_set):
                    """Count edges whose source / target endpoint lies in ``node_set``.

                    Returns ``(edges_as_source, edges_as_target)`` as Python ints.
                    """
                    # One vectorised membership test instead of two per-edge
                    # ``.item()`` loops over the edge tensor.
                    members = torch.tensor([int(n) for n in node_set],
                                           dtype=torch.long, device=edge_idx.device)
                    in_set = torch.isin(edge_idx, members)
                    return int(in_set[0].sum()), int(in_set[1].sum())
                
                orig_edges_src, orig_edges_tgt = count_edges_for_nodes(edge_index, attacked_node_set)
                filt_edges_src, filt_edges_tgt = count_edges_for_nodes(filtered_edge_index, attacked_node_set)
                
                # Calculate average similarity scores for attacked vs non-attacked nodes
                src_nodes = edge_index[0]
                tgt_nodes = edge_index[1]
                src_embs = atk_test_data.x[src_nodes]
                tgt_embs = atk_test_data.x[tgt_nodes]
                similarities = torch.cosine_similarity(src_embs, tgt_embs, dim=1)
                
                # Separate similarities for edges involving attacked nodes
                attacked_edges_mask = torch.tensor([
                    (src.item() in attacked_node_set) or (tgt.item() in attacked_node_set)
                    for src, tgt in zip(src_nodes, tgt_nodes)
                ])
                clean_edges_mask = ~attacked_edges_mask
                
                attacked_similarities = similarities[attacked_edges_mask]
                clean_similarities = similarities[clean_edges_mask]
                
                # 6. NEIGHBOR ANALYSIS: Check neighbors around text-attacked nodes before and after filtering
                def get_neighbors_dict(edge_idx):
                    """Map each endpoint to the list of nodes it shares an edge with.

                    Every edge ``(src, tgt)`` is recorded in both directions, so
                    duplicate or reciprocal edges yield repeated entries.
                    """
                    adjacency = {}
                    sources = edge_idx[0].tolist()
                    targets = edge_idx[1].tolist()
                    for u, v in zip(sources, targets):
                        adjacency.setdefault(u, []).append(v)
                        adjacency.setdefault(v, []).append(u)
                    return adjacency
                
                original_neighbors = get_neighbors_dict(edge_index)
                filtered_neighbors = get_neighbors_dict(filtered_edge_index)
                
                # Analyze neighbor changes for attacked nodes
                neighbor_analysis = {}
                total_neighbors_lost = 0
                total_neighbors_retained = 0
                
                for attacked_node in test_attacked_nodes:
                    orig_neighs = set(original_neighbors.get(attacked_node, []))
                    filt_neighs = set(filtered_neighbors.get(attacked_node, []))
                    
                    lost_neighs = orig_neighs - filt_neighs
                    retained_neighs = orig_neighs & filt_neighs
                    
                    neighbor_analysis[attacked_node] = {
                        'original_count': len(orig_neighs),
                        'filtered_count': len(filt_neighs),
                        'lost_count': len(lost_neighs),
                        'retained_count': len(retained_neighs),
                        'retention_rate': len(retained_neighs) / len(orig_neighs) if len(orig_neighs) > 0 else 0
                    }
                    
                    total_neighbors_lost += len(lost_neighs)
                    total_neighbors_retained += len(retained_neighs)
                
                # Compare with clean nodes neighbor analysis
                clean_neighbor_analysis = {}
                clean_neighbors_lost = 0
                clean_neighbors_retained = 0
                
                # Sample some clean nodes for comparison (to avoid too much computation)
                sample_clean_nodes = test_non_attacked_nodes[:min(50, len(test_non_attacked_nodes))]
                
                for clean_node in sample_clean_nodes:
                    orig_neighs = set(original_neighbors.get(clean_node, []))
                    filt_neighs = set(filtered_neighbors.get(clean_node, []))
                    
                    lost_neighs = orig_neighs - filt_neighs
                    retained_neighs = orig_neighs & filt_neighs
                    
                    clean_neighbor_analysis[clean_node] = {
                        'original_count': len(orig_neighs),
                        'filtered_count': len(filt_neighs),
                        'lost_count': len(lost_neighs),
                        'retained_count': len(retained_neighs),
                        'retention_rate': len(retained_neighs) / len(orig_neighs) if len(orig_neighs) > 0 else 0
                    }
                    
                    clean_neighbors_lost += len(lost_neighs)
                    clean_neighbors_retained += len(retained_neighs)
                
                # Log detailed statistics
                total_attacked = len(test_attacked_nodes)
                total_non_attacked = len(test_non_attacked_nodes)
                detected_attacked = len(correctly_detected_attacked)
                missed_attacked_count = len(missed_attacked)
                false_positive_count = len(false_positives)
                recovered_attacked = len(correctly_recovered_attacked)
                
                logger.info(f"GPT Attack Detection Results (TEST NODES ONLY):")
                logger.info(f"  - Total test nodes with GPT attacks: {total_attacked}")
                logger.info(f"  - Total test nodes without attacks: {total_non_attacked}")
                logger.info(f"  - Correctly detected as text_attacked: {detected_attacked}/{total_attacked} ({detected_attacked/total_attacked*100:.1f}%)")
                logger.info(f"  - Missed text attacks: {missed_attacked_count}/{total_attacked} ({missed_attacked_count/total_attacked*100:.1f}%)")
                logger.info(f"  - False positives (clean flagged as attacked): {false_positive_count}/{total_non_attacked} ({false_positive_count/total_non_attacked*100:.1f}%)")
                logger.info(f"  - Correctly recovered after detection: {recovered_attacked}/{detected_attacked} ({recovered_attacked/detected_attacked*100:.1f}% of detected)" if detected_attacked > 0 else "  - No correctly detected nodes to recover")
                
                # CRITICAL: Undetected attacked nodes analysis
                undetected_correct_count = len(undetected_correct)
                undetected_wrong_count = len(undetected_wrong)
                undetected_accuracy = undetected_correct_count / len(undetected_attacked_nodes) if len(undetected_attacked_nodes) > 0 else 0
                
                logger.info(f"\n=== CRITICAL: Undetected Attacked Nodes Analysis ===")
                logger.info(f"  - Undetected attacked nodes (missed): {len(undetected_attacked_nodes)}")
                logger.info(f"  - Undetected but STILL CORRECT: {undetected_correct_count}/{len(undetected_attacked_nodes)} ({undetected_accuracy*100:.1f}%)")
                logger.info(f"  - Undetected and wrong: {undetected_wrong_count}/{len(undetected_attacked_nodes)} ({(1-undetected_accuracy)*100:.1f}%)")
                
                # Edge connectivity analysis
                orig_total_edges = orig_edges_src + orig_edges_tgt
                filt_total_edges = filt_edges_src + filt_edges_tgt
                edge_retention_rate = filt_total_edges / orig_total_edges if orig_total_edges > 0 else 0
                
                logger.info(f"\n=== Edge Connectivity Analysis for Attacked Nodes ===")
                logger.info(f"  - Original edges involving attacked nodes: {orig_total_edges} ({orig_total_edges/total_edges_original*100:.1f}% of all edges)")
                logger.info(f"  - Filtered edges involving attacked nodes: {filt_total_edges} ({filt_total_edges/total_edges_filtered*100:.1f}% of filtered edges)")
                logger.info(f"  - Edge retention rate for attacked nodes: {edge_retention_rate:.3f}")
                logger.info(f"  - Total edges: {total_edges_original} → {total_edges_filtered} (removed: {total_edges_original - total_edges_filtered})")
                
                # Neighbor analysis logging: compare per-node neighbor retention of
                # attacked nodes against the sampled clean nodes.
                num_attacked_analyzed = len(neighbor_analysis)
                if num_attacked_analyzed > 0:
                    attacked_infos = list(neighbor_analysis.values())
                    clean_infos = list(clean_neighbor_analysis.values())
                    num_clean_analyzed = len(clean_infos)

                    # Attacked-node averages; the guard above guarantees a non-zero count.
                    attacked_avg_retention = sum(info['retention_rate'] for info in attacked_infos) / num_attacked_analyzed
                    attacked_original_avg = sum(info['original_count'] for info in attacked_infos) / num_attacked_analyzed
                    attacked_filtered_avg = sum(info['filtered_count'] for info in attacked_infos) / num_attacked_analyzed

                    # Clean-node averages default to 0 when no clean node was analysed.
                    if num_clean_analyzed > 0:
                        clean_avg_retention = sum(info['retention_rate'] for info in clean_infos) / num_clean_analyzed
                        clean_original_avg = sum(info['original_count'] for info in clean_infos) / num_clean_analyzed
                        clean_filtered_avg = sum(info['filtered_count'] for info in clean_infos) / num_clean_analyzed
                    else:
                        clean_avg_retention = 0
                        clean_original_avg = 0
                        clean_filtered_avg = 0

                    for report_line in (
                        "\n=== NEIGHBOR ANALYSIS: The Key to Understanding ===",
                        "  ATTACKED NODES:",
                        f"    - Average original neighbors: {attacked_original_avg:.1f}",
                        f"    - Average filtered neighbors: {attacked_filtered_avg:.1f}",
                        f"    - Average neighbor retention rate: {attacked_avg_retention:.3f}",
                        f"    - Total neighbors lost: {total_neighbors_lost}",
                        f"    - Total neighbors retained: {total_neighbors_retained}",
                        f"  CLEAN NODES (sample of {len(sample_clean_nodes)}):",
                        f"    - Average original neighbors: {clean_original_avg:.1f}",
                        f"    - Average filtered neighbors: {clean_filtered_avg:.1f}",
                        f"    - Average neighbor retention rate: {clean_avg_retention:.3f}",
                        f"    - Total neighbors lost: {clean_neighbors_lost}",
                        f"    - Total neighbors retained: {clean_neighbors_retained}",
                    ):
                        logger.info(report_line)

                    # A gap beyond ±0.1 in average retention is reported as a
                    # meaningful difference between the two groups.
                    retention_difference = attacked_avg_retention - clean_avg_retention
                    if retention_difference < -0.1:
                        verdict = "    → ATTACKED NODES LOSE MORE NEIGHBORS! This explains the robustness!"
                    elif retention_difference > 0.1:
                        verdict = "    → CLEAN NODES LOSE MORE NEIGHBORS! This is unexpected!"
                    else:
                        verdict = "    → Similar retention rates. The mechanism might be different!"

                    logger.info("  COMPARISON:")
                    logger.info(f"    - Retention rate difference (attacked - clean): {retention_difference:.3f}")
                    logger.info(verdict)
                
                # Similarity analysis: summarise the similarity scores of edges that
                # involve attacked nodes versus clean nodes, and count how many fall
                # below the model's similarity threshold.
                if len(attacked_similarities) > 0 and len(clean_similarities) > 0:
                    num_attacked_sims = len(attacked_similarities)
                    num_clean_sims = len(clean_similarities)

                    # Mean and standard deviation per group.
                    attacked_sim_mean = attacked_similarities.mean().item()
                    attacked_sim_std = attacked_similarities.std().item()
                    clean_sim_mean = clean_similarities.mean().item()
                    clean_sim_std = clean_similarities.std().item()

                    # Number (and percentage) of scores strictly below the threshold.
                    threshold = model.similarity_threshold
                    attacked_below_threshold = (attacked_similarities < threshold).sum().item()
                    clean_below_threshold = (clean_similarities < threshold).sum().item()
                    attacked_below_pct = attacked_below_threshold / num_attacked_sims * 100
                    clean_below_pct = clean_below_threshold / num_clean_sims * 100

                    for report_line in (
                        "\n=== Similarity Score Analysis ===",
                        f"  - Edges involving attacked nodes similarity: {attacked_sim_mean:.3f} ± {attacked_sim_std:.3f}",
                        f"  - Edges involving clean nodes similarity: {clean_sim_mean:.3f} ± {clean_sim_std:.3f}",
                        f"  - Attacked edges below threshold ({threshold}): {attacked_below_threshold}/{num_attacked_sims} ({attacked_below_pct:.1f}%)",
                        f"  - Clean edges below threshold ({threshold}): {clean_below_threshold}/{num_clean_sims} ({clean_below_pct:.1f}%)",
                    ):
                        logger.info(report_line)
                
                # Summary ratios; each one falls back to 0 when its denominator is
                # not positive.
                flagged_total = detected_attacked + false_positive_count
                if flagged_total > 0:
                    detection_precision = detected_attacked / flagged_total
                else:
                    detection_precision = 0

                if total_attacked > 0:
                    detection_recall = detected_attacked / total_attacked
                else:
                    detection_recall = 0

                if detected_attacked > 0:
                    recovery_rate = recovered_attacked / detected_attacked
                else:
                    recovery_rate = 0

                for report_line in (
                    "\nGPT Attack Summary Metrics:",
                    f"  - Detection Precision: {detection_precision:.3f}",
                    f"  - Detection Recall: {detection_recall:.3f}",
                    f"  - Recovery Rate: {recovery_rate:.3f}",
                    f"  - Undetected Accuracy: {undetected_accuracy:.3f} (THIS IS THE KEY!)",
                    "=" * 80,
                ):
                    logger.info(report_line)
            
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Missing attack data for seed {seed}: {e}")

    # Log results: mean ± std of each accuracy list, followed by the elapsed
    # wall-clock time divided by the number of seeds.
    for metric_label, metric_values in (
        ("Clean Val Acc", clean_val_accs),
        ("Clean Test Acc", clean_accs),
        ("Attacked Test Acc", attacked_accs),
    ):
        logger.info(f"> {metric_label}: {np.mean(metric_values):.4f} ± {np.std(metric_values):.4f}")

    seconds_per_seed = (time.time() - start_time) / len(seeds)
    logger.info(f"> Time: {seconds_per_seed:.2f}s")

# Final top-level statement: emit a completion marker once everything above has run.
logger.info("AutoGCN Text evaluation completed!")