"""
Shared utilities for auto prompt inference - handles 3-stage attack detection and recovery.
Used by both train.py (clean data) and train_atk.py (attacked data) to ensure consistency.
"""

import torch
from common import prepare_edge_list
from common.sft_prompts import CLASSES_WITH_ATTACKED, AUTO_SIMILARITY_THRESHOLDS, DUALAUTO_TRAIN_THRESHOLDS, DUALAUTO_TEST_THRESHOLDS
from common import CLASSES as classes, UNKNOW
from train import tokenizer_test_data


def filter_edges_by_similarity(node_embeddings, edge_index, threshold=0.5):
    """Drop edges whose endpoint embeddings are not similar enough (shared by auto and simf).

    Args:
        node_embeddings: Tensor indexed by node id along its first dimension;
            cosine similarity is taken along dimension 1 of the gathered rows.
        edge_index: Tensor whose row 0 holds source node ids and row 1 holds
            target node ids, one column per edge.
        threshold: Minimum cosine similarity (inclusive) for an edge to be kept.

    Returns:
        The columns of ``edge_index`` whose endpoints have cosine
        similarity >= ``threshold``, in their original order.
    """
    num_edges_before = edge_index.shape[1]
    print(f"Similarity filtering: Original edges: {num_edges_before}")

    sources, targets = edge_index[0], edge_index[1]

    # One similarity score per edge, computed between the two endpoint embeddings.
    edge_scores = torch.cosine_similarity(
        node_embeddings[sources],
        node_embeddings[targets],
        dim=1,
    )

    # Boolean mask over edges; True marks an edge that survives the filter.
    retained = edge_scores >= threshold
    kept_edges = edge_index[:, retained]
    kept_count = retained.sum().item()

    print(f"Similarity filtering: Filtered edges: {kept_edges.shape[1]} (kept {kept_count}/{len(retained)} edges)")
    return kept_edges


def run_auto_inference_with_structure(trainer, tokenizer, test_contents, dataset_name, data_args, accelerator, model_name, graph_data=None, is_attack_eval=False, prompt_type="auto"):
    """
    Run three-stage inference for auto prompt:
    1) detect text-attacked and structure-attacked nodes
    2) recover text-attacked nodes using neighbors only
    3) recover structure-attacked nodes by filtering dissimilar neighbors
    
    Args:
        is_attack_eval: If True, show detailed attack analysis logging (train_atk.py)
                       If False, show simplified clean data logging (train.py)
    """
    batch_size = data_args.batch_size * 2
    results = []
    
    # Get similarity threshold for this dataset
    similarity_threshold = AUTO_SIMILARITY_THRESHOLDS.get(dataset_name, 0.5)
    
    # Prepare edge and similarity information if graph_data is provided
    node_degrees = None
    node_embeddings = None
    edge_index = None
    original_edge_list = None
    filtered_edge_list = None
    
    if graph_data is not None:
        if hasattr(graph_data, 'edge_index'):
            edge_index = graph_data.edge_index.cpu()
            node_degrees = torch.zeros(graph_data.num_nodes, dtype=torch.long)
            for i in range(graph_data.num_nodes):
                node_degrees[i] = (edge_index[0] == i).sum() + (edge_index[1] == i).sum()
            
            # Prepare original and filtered edge lists
            original_edge_list = prepare_edge_list(edge_index, graph_data.num_nodes)
            filtered_edge_list = [neighbors.copy() for neighbors in original_edge_list]
        
        # Get node embeddings for similarity calculation (required for structure attack detection)
        if not hasattr(graph_data, 'x') or graph_data.x is None:
            raise ValueError("Auto prompt requires node embeddings for structure attack detection, but graph_data.x is None")
        node_embeddings = graph_data.x.cpu()
        
        if is_attack_eval:
            print(f"Auto: Using node embeddings (shape: {node_embeddings.shape})")
    
    print("Running Auto inference - Stage 1: Attack detection...")
    
    # Stage 1: Dual attack detection (structure + text)
    stage1_pred_labels = []
    stage1_gt_labels = []
    text_attacked_indices = []
    structure_attacked_indices = []
    
    # Structure attack detection (embedding-based, only during inference)
    if edge_index is not None:
        # Pre-compute adjacency list for O(1) neighbor lookup (massive speedup for large graphs)
        from collections import defaultdict
        adj_list = defaultdict(list)
        for i in range(edge_index.shape[1]):
            src, tgt = edge_index[0, i].item(), edge_index[1, i].item()
            adj_list[src].append(tgt)
            adj_list[tgt].append(src)
        
        for idx, sample in enumerate(test_contents):
            node_id = sample["id"]
            
            if "neighbor_texts" in sample and len(sample["neighbor_texts"]) > 0:
                # Fast O(1) neighbor lookup instead of O(E) scan
                neighbors = adj_list[node_id]
                
                if len(neighbors) > data_args.maximum_neighbor:
                    neighbors = neighbors[:data_args.maximum_neighbor]
                
                if len(neighbors) > 0:
                    center_embedding = node_embeddings[node_id].unsqueeze(0)
                    neighbor_embeddings = node_embeddings[neighbors]
                    similarities = torch.cosine_similarity(center_embedding, neighbor_embeddings, dim=1)
                    low_sim_count = (similarities <= similarity_threshold).sum().item()
                    
                    if low_sim_count >= len(neighbors) / 2:
                        structure_attacked_indices.append(idx)
                        sample["low_sim_neighbors"] = [neighbors[i] for i in range(len(neighbors)) if similarities[i] <= similarity_threshold]
    
    # Text attack detection (LLM-based)
    for i in range(0, len(test_contents), batch_size):
        batch_data = test_contents[i: min(i+batch_size, len(test_contents))]
        
        batch_encodings = tokenizer_test_data(
            batch_data, tokenizer, dataset_name, "auto", model_name,
            max_txt_length=data_args.max_txt_length, 
            max_origin_txt_length=data_args.max_origin_txt_length
        )
        
        batch_input_ids = batch_encodings["input_ids"]
        batch_attention_mask = batch_encodings["attention_mask"]
        batch_input_ids = batch_input_ids.to(accelerator.device)
        batch_attention_mask = batch_attention_mask.to(accelerator.device)
        
        model_for_generation = trainer.model
        if hasattr(model_for_generation, 'module'):
            model_for_generation = model_for_generation.module
        
        with torch.no_grad():
            output = model_for_generation.generate(
                input_ids=batch_input_ids, 
                attention_mask=batch_attention_mask, 
                max_new_tokens=data_args.max_ans_length,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        
        decode_output = tokenizer.batch_decode(output, skip_special_tokens=True)
        origin_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
        
        for idx, pred_content in enumerate(decode_output):
            global_idx = i + idx
            origin_prompt = origin_prompts[idx] 
            label, node_id = batch_data[idx]["output"], batch_data[idx]["id"]
            
            pred_label = pred_content.replace(origin_prompt, "").strip()
            lines = [line.strip() for line in pred_label.split('\n') if line.strip()]
            if lines:
                pred_label = lines[0]
            pred_label = pred_label.strip().strip('"').strip("'").strip()
            
            valid_classes = CLASSES_WITH_ATTACKED[dataset_name]
            pred_label = pred_label if pred_label in valid_classes else "unknown"
            
            stage1_pred_labels.append(pred_label)
            stage1_gt_labels.append(label)
            
            # Track text-attacked nodes (but not if already marked as structure-attacked)
            if pred_label == "text_attacked":
                text_attacked_indices.append(global_idx)
                if global_idx in structure_attacked_indices:
                    structure_attacked_indices.remove(global_idx)

            results.append({
                "idx": node_id, 
                "ground-truth": label,
                "stage1_pred": pred_label,
                "raw_stage1_output": pred_content,
                "cleaned_stage1_pred": pred_content.replace(origin_prompt, "").strip()
            })
    
    # Log attack detection summary
    if is_attack_eval:
        print(f"\n=== Attack Detection Summary ===")
        print(f"Total test nodes: {len(test_contents)}")
        print(f"Text-attacked nodes detected: {len(text_attacked_indices)} ({len(text_attacked_indices)/len(test_contents)*100:.1f}%)")
        print(f"Structure-attacked nodes detected: {len(structure_attacked_indices)} ({len(structure_attacked_indices)/len(test_contents)*100:.1f}%)")
        print(f"Total attacked nodes: {len(text_attacked_indices) + len(structure_attacked_indices)} ({(len(text_attacked_indices) + len(structure_attacked_indices))/len(test_contents)*100:.1f}%)")
        print(f"Clean nodes: {len(test_contents) - len(text_attacked_indices) - len(structure_attacked_indices)} ({(len(test_contents) - len(text_attacked_indices) - len(structure_attacked_indices))/len(test_contents)*100:.1f}%)")
        print(f"================================\n")
    else:
        print(f"Stage 1 completed: {len(text_attacked_indices)} text-attacked, {len(structure_attacked_indices)} structure-attacked detected")
    
    # Update filtered edge list: remove text-attacked nodes from others' neighborhoods
    if filtered_edge_list is not None:
        text_attacked_node_ids = set()
        for idx in text_attacked_indices:
            if idx < len(test_contents):
                text_attacked_node_ids.add(test_contents[idx]["id"])
        
        if text_attacked_node_ids:
            total_edges_removed = 0
            nodes_affected = 0
            for node_id in range(len(filtered_edge_list)):
                original_len = len(filtered_edge_list[node_id])
                filtered_edge_list[node_id] = [neighbor for neighbor in filtered_edge_list[node_id] 
                                             if neighbor not in text_attacked_node_ids]
                edges_removed = original_len - len(filtered_edge_list[node_id])
                if edges_removed > 0:
                    nodes_affected += 1
                    total_edges_removed += edges_removed
            
            if is_attack_eval and hasattr(data_args, 'atk_type'):
                # Detailed logging for attack evaluation
                attack_type = getattr(data_args, 'atk_type', 'unknown')
                if attack_type == 'structure':
                    ptb_rate = getattr(data_args, 'ptb_rate', 0.0)
                    total_original_edges = sum(len(neighbors) for neighbors in original_edge_list)
                    estimated_harmful_edges = int(total_original_edges * ptb_rate)
                    removal_percentage = (total_edges_removed / estimated_harmful_edges * 100) if estimated_harmful_edges > 0 else 0
                    
                    print(f"Auto: Structure attack analysis:")
                    print(f"  - Estimated harmful edges added: ~{estimated_harmful_edges} ({ptb_rate:.1%} of {total_original_edges})")
                    print(f"  - Edges removed by auto filtering: {total_edges_removed}")
                    print(f"  - Auto removal coverage: {removal_percentage:.1f}% of estimated harmful edges")
                    print(f"  - Nodes affected by filtering: {nodes_affected}")
                else:
                    print(f"Auto: Filtered {total_edges_removed} edges from {nodes_affected} nodes")
    
    # Stage 2: Text attack recovery (using original edges)
    stage2_text_pred_labels = {}
    
    if len(text_attacked_indices) > 0:
        print("Running Auto inference - Stage 2: Recovery for text-attacked nodes...")
        
        stage2_test_data = []
        for idx in text_attacked_indices:
            sample = test_contents[idx].copy()
            sample["input"] = ""  # No center node text
            sample["recovery_type"] = "neighbors_only"
            
            # For text-attacked nodes, use original (full) neighbors
            if original_edge_list is not None:
                node_id = sample["id"]
                original_neighbors = original_edge_list[node_id]
                
                if len(original_neighbors) > data_args.maximum_neighbor:
                    neighbor_degrees = [(neigh, node_degrees[neigh].item()) for neigh in original_neighbors]
                    neighbor_degrees.sort(key=lambda x: x[1], reverse=True)
                    selected_neighbors = [neigh for neigh, _ in neighbor_degrees[:data_args.maximum_neighbor]]
                else:
                    selected_neighbors = original_neighbors
                
                if hasattr(graph_data, 'raw_texts'):
                    sample["neighbor_texts"] = [graph_data.raw_texts[neigh] for neigh in selected_neighbors]
            
            stage2_test_data.append(sample)
        
        # Run Stage 2 inference
        for i in range(0, len(stage2_test_data), batch_size):
            batch_data = stage2_test_data[i: min(i+batch_size, len(stage2_test_data))]
            
            batch_encodings = tokenizer_test_data(
                batch_data, tokenizer, dataset_name, "auto", model_name,
                max_txt_length=data_args.max_txt_length, 
                max_origin_txt_length=data_args.max_origin_txt_length
            )
            
            batch_input_ids = batch_encodings["input_ids"]
            batch_attention_mask = batch_encodings["attention_mask"]
            batch_input_ids = batch_input_ids.to(accelerator.device)
            batch_attention_mask = batch_attention_mask.to(accelerator.device)
            
            with torch.no_grad():
                output = model_for_generation.generate(
                    input_ids=batch_input_ids, 
                    attention_mask=batch_attention_mask, 
                    max_new_tokens=data_args.max_ans_length,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            
            decode_output = tokenizer.batch_decode(output, skip_special_tokens=True)
            origin_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            
            for idx, pred_content in enumerate(decode_output):
                batch_global_idx = text_attacked_indices[i + idx]
                origin_prompt = origin_prompts[idx]
                
                pred_label = pred_content.replace(origin_prompt, "").strip()
                lines = [line.strip() for line in pred_label.split('\n') if line.strip()]
                if lines:
                    pred_label = lines[0]
                pred_label = pred_label.strip().strip('"').strip("'").strip()
                pred_label = pred_label if pred_label in classes[dataset_name] else "unknown"
                
                stage2_text_pred_labels[batch_global_idx] = pred_label
                results[batch_global_idx]["stage2_text_pred"] = pred_label
                results[batch_global_idx]["raw_stage2_text_output"] = pred_content
                results[batch_global_idx]["recovery_method"] = "neighbors_only"
    
    # Stage 3: Structure attack recovery (using filtered edges)
    stage3_struct_pred_labels = {}
    
    if len(structure_attacked_indices) > 0:
        print("Running Auto inference - Stage 3: Recovery for structure-attacked nodes...")
        
        stage3_test_data = []
        for idx in structure_attacked_indices:
            sample = test_contents[idx].copy()
            node_id = sample["id"]
            
            # Use filtered edge list for structure-attacked nodes
            if filtered_edge_list is not None:
                filtered_neighbors = filtered_edge_list[node_id]
                
                # Further filter out low-similarity neighbors if we have the information
                if "low_sim_neighbors" in sample:
                    low_sim_neighbors = sample["low_sim_neighbors"]
                    final_neighbors = [neigh for neigh in filtered_neighbors if neigh not in low_sim_neighbors]
                    edges_filtered = len(filtered_neighbors) - len(final_neighbors)
                    if edges_filtered > 0 and is_attack_eval:
                        print(f"Auto: Node {node_id} filtered {edges_filtered} additional low-sim neighbors")
                else:
                    final_neighbors = filtered_neighbors
                
                if len(final_neighbors) > data_args.maximum_neighbor:
                    neighbor_degrees = [(neigh, node_degrees[neigh].item()) for neigh in final_neighbors]
                    neighbor_degrees.sort(key=lambda x: x[1], reverse=True)
                    final_neighbors = [neigh for neigh, _ in neighbor_degrees[:data_args.maximum_neighbor]]
                
                if hasattr(graph_data, 'raw_texts'):
                    sample["neighbor_texts"] = [graph_data.raw_texts[neigh] for neigh in final_neighbors]
                    if "neighbor_labels" in sample:
                        sample["neighbor_labels"] = [None] * len(final_neighbors)
                
                sample["recovery_type"] = "filtered_neighbors"
            
            stage3_test_data.append(sample)
        
        # Run Stage 3 inference
        for i in range(0, len(stage3_test_data), batch_size):
            batch_data = stage3_test_data[i: min(i+batch_size, len(stage3_test_data))]
            
            batch_encodings = tokenizer_test_data(
                batch_data, tokenizer, dataset_name, "auto", model_name,
                max_txt_length=data_args.max_txt_length, 
                max_origin_txt_length=data_args.max_origin_txt_length
            )
            
            batch_input_ids = batch_encodings["input_ids"]
            batch_attention_mask = batch_encodings["attention_mask"]
            batch_input_ids = batch_input_ids.to(accelerator.device)
            batch_attention_mask = batch_attention_mask.to(accelerator.device)
            
            with torch.no_grad():
                output = model_for_generation.generate(
                    input_ids=batch_input_ids, 
                    attention_mask=batch_attention_mask, 
                    max_new_tokens=data_args.max_ans_length,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            
            decode_output = tokenizer.batch_decode(output, skip_special_tokens=True)
            origin_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
            
            for idx, pred_content in enumerate(decode_output):
                batch_global_idx = structure_attacked_indices[i + idx]
                origin_prompt = origin_prompts[idx]
                
                pred_label = pred_content.replace(origin_prompt, "").strip()
                lines = [line.strip() for line in pred_label.split('\n') if line.strip()]
                if lines:
                    pred_label = lines[0]
                pred_label = pred_label.strip().strip('"').strip("'").strip()
                pred_label = pred_label if pred_label in classes[dataset_name] else "unknown"
                
                stage3_struct_pred_labels[batch_global_idx] = pred_label
                results[batch_global_idx]["stage3_struct_pred"] = pred_label
                results[batch_global_idx]["raw_stage3_struct_output"] = pred_content
                results[batch_global_idx]["recovery_method"] = batch_data[idx]["recovery_type"]
    
    # Combine final predictions
    final_pred_labels = []
    final_gt_labels = []
    
    for i, result in enumerate(results):
        if i in text_attacked_indices and i in stage2_text_pred_labels:
            # Text-attacked: use stage 2 recovery
            final_pred = stage2_text_pred_labels[i]
            result["final_pred"] = final_pred
            result["prediction_source"] = "text_recovery"
        elif i in structure_attacked_indices and i in stage3_struct_pred_labels:
            # Structure-attacked: use stage 3 recovery
            final_pred = stage3_struct_pred_labels[i]
            result["final_pred"] = final_pred
            result["prediction_source"] = "structure_recovery"
        else:
            # Not attacked or not recovered: use stage 1 prediction
            final_pred = result["stage1_pred"]
            if final_pred == "text_attacked":
                final_pred = "unknown"
            result["final_pred"] = final_pred
            result["prediction_source"] = "direct_inference"
        
        final_pred_labels.append(final_pred)
        final_gt_labels.append(result["ground-truth"])
    
    if is_attack_eval:
        print(f"\nRecovery completed:")
        print(f"  - Text-attacked nodes recovered: {len(stage2_text_pred_labels)}")
        print(f"  - Structure-attacked nodes recovered: {len(stage3_struct_pred_labels)}")
        print(f"  - Direct predictions: {len([r for r in results if r['prediction_source'] == 'direct_inference'])}")
        print(f"  - Total valid predictions: {len([p for p in final_pred_labels if p != 'unknown'])}/{len(final_pred_labels)}")
    
    return final_pred_labels, final_gt_labels, results