import argparse
import os
import sys
import json
import csv
import time
import random
from dataclasses import dataclass, field
from typing import Optional, List

import torch
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    TrainingArguments as HFTrainingArguments,
    Trainer, 
    IntervalStrategy
)

from dataset import TextDataset

sys.path.append("../..")
from common import set_seed, load_graph_dataset, load_inductive_graph_dataset, compute_acc_and_f1
from common import IGNORE_INDEX, UNKNOW
from common import CLASSES as classes 
from common import MODEL_PATHs as llm_paths
from common import get_prompt_template, get_available_prompt_types, get_supported_datasets, requires_neighbor_info, get_classes_str
from common import prepare_edge_list
from common.model_path import get_model_save_path, check_model_exists, get_setting_info
from common.sft_prompts import CLASSES_WITH_ATTACKED, AUTO_SIMILARITY_THRESHOLDS


@dataclass
class ModelArguments:
    """Argument group that selects the base language model."""

    # Model identifier; defaults to "Mistral-7B". How the string is resolved
    # (local path, alias table, hub id) is decided by the code consuming it.
    model_name_or_path: Optional[str] = "Mistral-7B"


@dataclass
class DataArguments:
    """Argument group describing the dataset and the prompt-construction limits.

    Every field carries a ``help`` entry in its metadata; the defaults below are
    used when the option is not supplied.
    """

    dataset: str = field(
        metadata={"help": "Dataset name"},
        default="cora",
    )
    prompt_type: str = field(
        metadata={"help": "Type of prompt template to use"},
        default="instruction_tuning",
    )
    max_txt_length: int = field(
        metadata={"help": "Maximum length of query prompt"},
        default=128,
    )
    max_origin_txt_length: int = field(
        metadata={"help": "Maximum length of node's original text"},
        default=128,
    )
    max_ans_length: int = field(
        metadata={"help": "Maximum length of answer"},
        default=16,
    )
    # Stored as an int flag (0 by default) rather than a bool.
    re_split: int = field(
        metadata={"help": "Whether to use re-split"},
        default=0,
    )
    maximum_neighbor: int = field(
        metadata={"help": "Maximum number of neighbors to include"},
        default=8,
    )


@dataclass
class TrainingArguments(HFTrainingArguments):
    """Hugging Face training arguments extended with this script's own options.

    Fields declared here either add new options (``cache_dir``,
    ``model_max_length``, ``use_lora``, ``num_epoch``, ``batch_size``) or
    re-declare ``optim`` with a fixed default of ``"adamw_torch"``.
    """

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
        default=8192,
    )
    use_lora: bool = True
    num_epoch: int = 2
    batch_size: int = 8


@dataclass
class LoraArguments:
    """Argument group holding the LoRA adapter hyper-parameters."""

    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    # A factory is required so that each instance gets its own list object.
    lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False


def _true_indices(mask):
    """Return the positions of the set entries of a 1-D mask as a list of ints."""
    # squeeze(-1) keeps the result one-dimensional even when exactly one entry
    # is set. A bare squeeze() would collapse that case to a 0-d tensor whose
    # tolist() is a plain int, which then breaks set()/iteration in the caller.
    return mask.nonzero(as_tuple=False).squeeze(-1).cpu().tolist()


def _keep_top_degree_neighbors(neighbors, node_degrees, maximum_neighbor):
    """Return a fresh list of at most ``maximum_neighbor`` neighbors, highest degree first when trimmed."""
    if len(neighbors) <= maximum_neighbor:
        # Copy so later in-place edits never touch the shared adjacency list.
        return list(neighbors)
    # sorted() is stable, so ties keep their original relative order.
    ranked = sorted(neighbors, key=lambda neigh: node_degrees[neigh].item(), reverse=True)
    return ranked[:maximum_neighbor]


def _inject_noise_neighbor(neighbors, local_node_id, num_nodes, maximum_neighbor, log_prefix=""):
    """With probability 0.1, add (or substitute in) one random non-neighbor node; edits ``neighbors`` in place."""
    if random.random() >= 0.1:
        return
    available_nodes = [i for i in range(num_nodes)
                       if i != local_node_id and i not in neighbors]
    if not available_nodes:
        return
    noise_neighbor = random.choice(available_nodes)
    if len(neighbors) < maximum_neighbor:
        neighbors.append(noise_neighbor)
        print(f"{log_prefix}Added noise neighbor for node {local_node_id}: added node {noise_neighbor}")
    elif neighbors:
        # The list is already full: overwrite a random existing entry instead.
        replace_idx = random.randint(0, len(neighbors) - 1)
        replaced_neighbor = neighbors[replace_idx]
        neighbors[replace_idx] = noise_neighbor
        print(f"{log_prefix}Added noise neighbor for node {local_node_id}: replaced neighbor {replaced_neighbor} with node {noise_neighbor}")


def _replace_texts_with_noise(graph_data, dataset_name, focus_nodes, node_degrees, raw_texts,
                              modified_raw_texts, label_is_known, log_tag):
    """Overwrite the text of ~10% of the above-average-degree focus nodes with text from a differently-labelled node."""
    if node_degrees is None:
        # Degrees are only computed for neighbor-aware prompts; without them
        # there is no way to pick high-degree nodes, so leave the texts alone.
        return
    avg_degree = node_degrees.float().mean().item()
    high_degree_nodes = [node_id for node_id in focus_nodes if node_degrees[node_id].item() > avg_degree]
    if not high_degree_nodes:
        return
    num_to_replace = max(1, int(len(high_degree_nodes) * 0.1))
    nodes_to_replace = random.sample(high_degree_nodes, min(num_to_replace, len(high_degree_nodes)))
    for node_id in nodes_to_replace:
        current_label = graph_data.y[node_id].item()
        replacement_candidates = [
            candidate_id for candidate_id in range(graph_data.num_nodes)
            if candidate_id != node_id
            and graph_data.y[candidate_id].item() != current_label
            and label_is_known(candidate_id)
        ]
        if not replacement_candidates:
            continue
        replacement_node = random.choice(replacement_candidates)
        modified_raw_texts[node_id] = raw_texts[replacement_node]
        replacement_label = classes[dataset_name][graph_data.y[replacement_node].item()]
        current_label_name = classes[dataset_name][current_label]
        print(f"{log_tag}: Replaced text for node {node_id} (label: {current_label_name}) with text from node {replacement_node} (label: {replacement_label})")


def prepare_graph_instruction_tuning_data(graph_data, data_type="train", dataset_name="cora", prompt_type="instruction_tuning", maximum_neighbor=8, 
                                          inductive=False, full_graph_data=None):
    """Build the list of instruction-tuning samples for one data split.

    Args:
        graph_data: Graph object providing ``raw_texts``, ``y``, ``num_nodes``,
            ``edge_index`` and either split masks (transductive) or ``node_ids``
            (inductive). ``x`` is read only for the ``"simf"`` prompt at test time.
        data_type: ``"train"``, ``"val"`` or ``"test"``; selects the focus nodes.
        dataset_name: Key into the class-name table used to render labels.
        prompt_type: Prompt variant. Controls whether neighbors are attached and
            which training-time noise / augmentation is applied.
        maximum_neighbor: Upper bound on the neighbors attached to a sample.
        inductive: When true, ``graph_data`` is a sub-graph whose local indices
            map to full-graph ids through ``graph_data.node_ids``.
        full_graph_data: Full graph supplying the split masks; read only for the
            inductive ``"val"`` and ``"test"`` splits.

    Returns:
        A list of dicts with keys ``"id"``, ``"input"`` and ``"output"``, plus
        ``"neighbor_texts"`` / ``"neighbor_labels"`` for neighbor-aware prompts.
        For ``prompt_type == "auto"`` on the train split, extra attack samples
        (with ``"original_label"``) and recovery samples (with
        ``"recovery_type"``) are appended. The list is shuffled unless
        ``data_type == "test"``.

    Raises:
        ValueError: If ``inductive`` is true and ``data_type`` is not one of the
            three known splits.
        KeyError: If ``inductive`` is false and ``data_type`` is unknown.
    """
    raw_texts = graph_data.raw_texts
    if inductive:
        if data_type == "train":
            focus_nodes = list(range(graph_data.num_nodes))
            original_node_ids = graph_data.node_ids.cpu().tolist()
            known_label_original_ids = set(original_node_ids)
        elif data_type == "val":
            train_original_ids = set(_true_indices(full_graph_data.train_mask))
            original_node_ids = graph_data.node_ids.cpu().tolist()
            # Every node of the sub-graph that is not a training node is a focus node.
            focus_nodes = [i for i, orig_id in enumerate(original_node_ids) if orig_id not in train_original_ids]
            known_label_original_ids = set(original_node_ids)
        elif data_type == "test":
            train_val_mask = full_graph_data.train_mask | full_graph_data.val_mask
            test_original_ids = set(_true_indices(full_graph_data.test_mask))
            if hasattr(graph_data, 'node_ids'):
                original_node_ids = graph_data.node_ids.cpu().tolist()
            else:
                original_node_ids = list(range(graph_data.num_nodes))
            focus_nodes = [i for i, orig_id in enumerate(original_node_ids) if orig_id in test_original_ids]
            known_label_original_ids = set(_true_indices(train_val_mask))
        else:
            raise ValueError(f"Unknown data_type: {data_type!r}")
    else:
        focus_mask = {"train": graph_data.train_mask, "val": graph_data.val_mask, "test": graph_data.test_mask}[data_type]
        focus_nodes = _true_indices(focus_mask)
        known_label_original_ids = set(_true_indices(graph_data.train_mask | graph_data.val_mask))
        original_node_ids = None

    def label_is_known(node):
        """Tell whether ``node`` (a local index) belongs to the set of nodes with a usable label."""
        key = original_node_ids[node] if original_node_ids else node
        return key in known_label_original_ids

    def neighbor_labels_for(neighbors):
        """Render neighbor labels for the "neighbor_label" prompt, ``None`` placeholders otherwise."""
        if prompt_type != "neighbor_label":
            return [None] * len(neighbors)
        return [
            classes[dataset_name][graph_data.y[neigh].item()] if label_is_known(neigh) else "unknown"
            for neigh in neighbors
        ]

    # Prepare edge list and degree information if neighbor information is needed
    edge_list = None
    node_degrees = None
    if requires_neighbor_info(prompt_type):
        edge_index = graph_data.edge_index.cpu()

        # Apply similarity filtering for simf during test phase only
        if prompt_type == "simf" and data_type == "test":
            if hasattr(graph_data, 'x') and graph_data.x is not None:
                from auto_utils import filter_edges_by_similarity
                print(f"SimF: Applying similarity filtering during {data_type} phase...")
                node_embeddings_cpu = graph_data.x.cpu()
                edge_index = filter_edges_by_similarity(node_embeddings_cpu, edge_index, threshold=0.5)
            else:
                print(f"SimF: Warning - Node embeddings not available, using original edges")

        edge_list = prepare_edge_list(edge_index, graph_data.num_nodes)
        # Degree = number of times a node appears as either endpoint. A single
        # bincount replaces a per-node scan over every edge.
        endpoints = torch.cat([edge_index[0], edge_index[1]])
        node_degrees = torch.bincount(endpoints, minlength=graph_data.num_nodes)[:graph_data.num_nodes]

    # For noisetxt / noisefull: swap in mismatched texts during training
    modified_raw_texts = list(raw_texts)
    if data_type == "train" and prompt_type in ("noisetxt", "noisefull"):
        log_tag = "NoiseText" if prompt_type == "noisetxt" else "NoiseFull Text"
        _replace_texts_with_noise(graph_data, dataset_name, focus_nodes, node_degrees, raw_texts,
                                  modified_raw_texts, label_is_known, log_tag)

    include_neighbors = edge_list is not None and requires_neighbor_info(prompt_type)
    inject_noise = data_type == "train" and prompt_type in ("noise", "noisefull")
    noise_log_prefix = "NoiseFull: " if prompt_type == "noisefull" else ""

    data_contents = []
    for local_node_id in focus_nodes:
        data_item = {
            "id": original_node_ids[local_node_id] if inductive else local_node_id,
            "input": modified_raw_texts[local_node_id],
            "output": classes[dataset_name][graph_data.y[local_node_id].item()]
        }

        if include_neighbors:
            if prompt_type == "rand":
                neighbors = list(edge_list[local_node_id])
                if len(neighbors) > maximum_neighbor:
                    neighbors = random.sample(neighbors, maximum_neighbor)
            else:
                neighbors = _keep_top_degree_neighbors(edge_list[local_node_id], node_degrees, maximum_neighbor)
                if inject_noise:
                    _inject_noise_neighbor(neighbors, local_node_id, graph_data.num_nodes,
                                           maximum_neighbor, log_prefix=noise_log_prefix)

            data_item["neighbor_texts"] = [modified_raw_texts[neigh] for neigh in neighbors]
            data_item["neighbor_labels"] = neighbor_labels_for(neighbors)

        data_contents.append(data_item)

    if prompt_type in ["auto"] and data_type == "train":
        print(f"Auto: Starting attack detection and recovery training for {len(data_contents)} original samples")
        augmented_samples = []

        # Group label-known nodes by label. A key is created for every label
        # seen, even when none of its nodes is label-known, so the candidate
        # ordering below (and hence seeded random picks) stays stable.
        nodes_by_label = {}
        for node_id in range(graph_data.num_nodes):
            bucket = nodes_by_label.setdefault(graph_data.y[node_id].item(), [])
            if label_is_known(node_id):
                bucket.append(node_id)

        # Pre-compute replacement candidates for each label
        replacement_candidates_by_label = {
            target_label: [node for label, nodes in nodes_by_label.items() if label != target_label for node in nodes]
            for target_label in nodes_by_label
        }

        # Only samples that carry neighbors, with at least two of them, can be attacked
        text_attack_eligible = [
            idx for idx, local_node_id in enumerate(focus_nodes)
            if "neighbor_texts" in data_contents[idx] and edge_list is not None and len(edge_list[local_node_id]) >= 2
        ]
        print(f"Auto: {len(text_attack_eligible)} nodes eligible for text attack")

        # Configuration for attack ratios
        attack_ratio = min(1 / len(classes[dataset_name]), 0.15)
        num_text_attack = max(1, int(len(focus_nodes) * attack_ratio))
        text_attack_samples = []

        if text_attack_eligible:
            selected_indices = random.sample(text_attack_eligible, min(num_text_attack, len(text_attack_eligible)))
            for idx in selected_indices:
                original_sample = data_contents[idx]
                local_node_id = focus_nodes[idx]
                current_label = graph_data.y[local_node_id].item()

                # Find replacement text from a different class
                all_candidates = replacement_candidates_by_label.get(current_label, [])
                if not all_candidates:
                    continue
                replacement_node = random.choice(all_candidates)

                # Create text attack sample
                attack_sample = original_sample.copy()
                attack_sample["input"] = modified_raw_texts[replacement_node]
                attack_sample["output"] = "text_attacked"
                attack_sample["original_label"] = classes[dataset_name][current_label]

                if include_neighbors:
                    # Rebuild the neighbor lists so the attack sample does not
                    # share list objects with the sample it was copied from.
                    neighbors = _keep_top_degree_neighbors(edge_list[local_node_id], node_degrees, maximum_neighbor)
                    attack_sample["neighbor_texts"] = [modified_raw_texts[neigh] for neigh in neighbors]
                    attack_sample["neighbor_labels"] = neighbor_labels_for(neighbors)

                augmented_samples.append(attack_sample)
                text_attack_samples.append(attack_sample)
                print(f"Auto: Created text_attacked sample for node {original_sample['id']}")

        # Generate recovery samples
        num_recovery = max(1, int(len(focus_nodes) * attack_ratio))
        if text_attack_samples:
            recovery_indices = random.sample(range(len(text_attack_samples)), min(num_recovery, len(text_attack_samples)))
            for idx in recovery_indices:
                attack_sample = text_attack_samples[idx]
                recovery_sample = attack_sample.copy()
                recovery_sample["output"] = attack_sample["original_label"]
                recovery_sample["input"] = ""
                recovery_sample["recovery_type"] = "neighbors_only"
                augmented_samples.append(recovery_sample)
                print(f"Auto: Created recovery sample for node {recovery_sample['id']} using {recovery_sample['recovery_type']}")

        data_contents.extend(augmented_samples)
        print(f"Auto: Added {len(text_attack_samples)} text attack and {len(augmented_samples) - len(text_attack_samples)} recovery samples")
        # Guard the percentage against an empty split.
        increase_pct = len(augmented_samples) / len(focus_nodes) * 100 if focus_nodes else 0.0
        print(f"Auto: Total training samples: {len(data_contents)} (increased by {increase_pct:.1f}%)")

    if data_type != "test":
        random.shuffle(data_contents)

    return data_contents


def tokenizer_instruction_tuning_data(raw_data, tokenizer, dataset_name, prompt_type, model_name, max_txt_length=128, max_origin_txt_length=128, max_ans_length=16, return_prompt=False):
    """Tokenize instruction-tuning samples into left-padded training tensors.

    Each sample's prompt is rendered through the tokenizer's chat template and
    followed by the answer plus the EOS token. Prompt positions (and padding)
    are masked with ``IGNORE_INDEX`` in the labels, so only answer tokens
    contribute to the loss.

    Args:
        raw_data: Iterable of sample dicts with ``"input"`` and ``"output"``,
            optionally ``"neighbor_texts"``, ``"neighbor_labels"`` and
            ``"recovery_type"``.
        tokenizer: Tokenizer exposing ``__call__``, ``decode``,
            ``apply_chat_template``, ``eos_token`` and ``pad_token_id``.
        dataset_name: Dataset key passed to the prompt/class helpers.
        prompt_type: Prompt variant; ``"auto"`` picks a template per sample.
        model_name: Accepted for interface compatibility; not read here.
        max_txt_length: Accepted for interface compatibility; not read here.
        max_origin_txt_length: Token budget applied to the node text and to
            each neighbor text before they are inserted into the prompt.
        max_ans_length: Accepted for interface compatibility; not read here.
        return_prompt: When true, the first rendered prompt is returned under
            ``"prompt"``.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
        tensors, all padded on the left to the longest sequence, plus
        ``"prompt"`` when ``return_prompt`` is set.

    Raises:
        ValueError: If ``raw_data`` yields no samples.
    """
    def clip_text(text):
        """Return (all token ids, kept token ids, decoded kept text) for ``text``."""
        token_ids = tokenizer(text, add_special_tokens=False).input_ids
        kept_ids = token_ids[:max_origin_txt_length]
        return token_ids, kept_ids, tokenizer.decode(kept_ids, skip_special_tokens=True)

    def numbered(snippets):
        """Join neighbor snippets as numbered lines."""
        return "\n".join([f"Neighbor {i+1}: {text}" for i, text in enumerate(snippets)])

    full_input_ids, full_attention_masks, full_labels = [], [], []
    query_length, txt_length, label_length = [], [], []
    prompts = []

    is_auto = prompt_type in ["auto"]
    if not is_auto:
        # A single template serves every sample of a non-auto prompt type.
        prompt_template = get_prompt_template(prompt_type, dataset_name)
        classes_str = get_classes_str(dataset_name)
        with_neighbors = requires_neighbor_info(prompt_type)

    for sample in raw_data:
        origin_txt = sample["input"]

        if is_auto:
            use_neighbors_only = sample.get("recovery_type", None) == "neighbors_only"
            if use_neighbors_only:
                # Text attacked recovery: the node text is unreliable, rely on neighbors
                sample_template = get_prompt_template("auto_neighbor_only", dataset_name)
                sample_classes = get_classes_str(dataset_name, include_attacked=False)
            else:
                # Normal samples and attack detection
                sample_template = get_prompt_template("auto_attack_detection", dataset_name)
                sample_classes = get_classes_str(dataset_name, include_attacked=True)

            if origin_txt == "":
                # Recovery samples carry no node text; count them as length 1 in the stats.
                truncated_origin_txt, origin_token_count = "", 1
            else:
                all_ids, _, truncated_origin_txt = clip_text(origin_txt)
                origin_token_count = len(all_ids)

            snippets = [clip_text(text)[2] for text in sample.get("neighbor_texts", [])]
            # The auto templates receive the neighbor block with a leading newline.
            neighbor_str = "\n" + numbered(snippets)

            if use_neighbors_only:
                user_content = sample_template.format(neighbor_text=neighbor_str, classes=sample_classes)
            else:
                user_content = sample_template.format(origin_text=truncated_origin_txt, neighbor_text=neighbor_str, classes=sample_classes)
        else:
            _, kept_ids, truncated_origin_txt = clip_text(origin_txt)
            origin_token_count = len(kept_ids)

            if with_neighbors:
                neighbor_labels = sample.get("neighbor_labels", [])
                snippets = []
                for i, neighbor_text in enumerate(sample.get("neighbor_texts", [])):
                    truncated_neighbor_txt = clip_text(neighbor_text)[2]
                    if prompt_type in ["neighbor_label"] and i < len(neighbor_labels):
                        label = neighbor_labels[i] if neighbor_labels[i] else "unknown"
                        snippets.append(f"node text: {truncated_neighbor_txt}, label: {label}")
                    else:
                        # For all other prompt types: neighbor, noise, rand, sim, noisetxt
                        snippets.append(truncated_neighbor_txt)
                user_content = prompt_template.format(origin_text=truncated_origin_txt, neighbor_text=numbered(snippets), classes=classes_str)
            else:
                user_content = prompt_template.format(node=truncated_origin_txt, classes=classes_str)

        # Create conversation using chat template
        formatted_input = tokenizer.apply_chat_template(
            [{"role": "user", "content": user_content}], tokenize=False, add_generation_prompt=True
        )
        if return_prompt:
            prompts.append(formatted_input)

        target = f"{sample['output']}{tokenizer.eos_token}"
        source_ids = tokenizer(formatted_input, add_special_tokens=False).input_ids
        target_ids = tokenizer(target, add_special_tokens=False).input_ids
        input_ids = source_ids + target_ids

        full_input_ids.append(input_ids)
        full_attention_masks.append([1] * len(input_ids))
        # Only the answer tokens are supervised.
        full_labels.append([IGNORE_INDEX] * len(source_ids) + target_ids)

        # Statistics for choosing suitable maximum lengths
        query_length.append(len(tokenizer(user_content, add_special_tokens=False).input_ids))
        txt_length.append(origin_token_count)
        label_length.append(len(tokenizer(sample["output"], add_special_tokens=False).input_ids))

    if not full_input_ids:
        raise ValueError("raw_data is empty: there are no samples to tokenize")

    max_length = max(len(ids) for ids in full_input_ids)
    print(f"Avg Query Prompt Length {sum(query_length)/len(query_length):.4f} | Avg OriginTxt Length {sum(txt_length)/len(txt_length):.4f} | Avg Output Length {sum(label_length)/len(label_length):.4f}")

    # Some tokenizers define no pad token; fall back to EOS so the id list
    # handed to torch.tensor never contains None.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # Pad sequences on the left
    for i, ids in enumerate(full_input_ids):
        pad_length = max_length - len(ids)
        full_input_ids[i] = [pad_id] * pad_length + ids
        full_attention_masks[i] = [0] * pad_length + full_attention_masks[i]
        full_labels[i] = [IGNORE_INDEX] * pad_length + full_labels[i]

    result = {
        "input_ids": torch.tensor(full_input_ids),
        "attention_mask": torch.tensor(full_attention_masks),
        "labels": torch.tensor(full_labels)
    }

    if return_prompt:
        result["prompt"] = prompts[0] if prompts else ""

    return result


def tokenizer_test_data(batch_data, tokenizer, dataset_name, prompt_type, model_name, max_txt_length=128, max_origin_txt_length=128):
    """Tokenize a batch of samples into left-padded prompts for generation.

    Prompts are rendered through the tokenizer's chat template with the
    generation prompt appended; no answer tokens are added.

    Args:
        batch_data: Iterable of sample dicts with ``"input"``, optionally
            ``"neighbor_texts"``, ``"neighbor_labels"`` and ``"recovery_type"``.
        tokenizer: Tokenizer exposing ``__call__``, ``decode``,
            ``apply_chat_template`` and ``pad_token_id``.
        dataset_name: Dataset key passed to the prompt/class helpers.
        prompt_type: Prompt variant; ``"auto"`` picks a template per sample.
        model_name: Accepted for interface compatibility; not read here.
        max_txt_length: Accepted for interface compatibility; not read here.
        max_origin_txt_length: Token budget applied to the node text and to
            each neighbor text before they are inserted into the prompt.

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"`` tensors, padded on
        the left to the longest prompt in the batch.

    Raises:
        ValueError: If ``batch_data`` yields no samples.
    """
    def clip_text(text):
        """Decode the first ``max_origin_txt_length`` tokens of ``text``."""
        token_ids = tokenizer(text, add_special_tokens=False).input_ids[:max_origin_txt_length]
        return tokenizer.decode(token_ids, skip_special_tokens=True)

    def numbered(snippets):
        """Join neighbor snippets as numbered lines."""
        return "\n".join([f"Neighbor {i+1}: {text}" for i, text in enumerate(snippets)])

    full_input_ids, full_attention_masks = [], []

    is_auto = prompt_type in ["auto"]
    if not is_auto:
        # A single template serves every sample of a non-auto prompt type.
        prompt_template = get_prompt_template(prompt_type, dataset_name)
        classes_str = get_classes_str(dataset_name)
        with_neighbors = requires_neighbor_info(prompt_type)

    for sample in batch_data:
        origin_txt = sample["input"]

        if is_auto:
            use_neighbors_only = sample.get("recovery_type", None) == "neighbors_only"
            if use_neighbors_only:
                # Text attacked recovery: the node text is unreliable, rely on neighbors
                sample_template = get_prompt_template("auto_neighbor_only", dataset_name)
                sample_classes = get_classes_str(dataset_name, include_attacked=False)
            else:
                # Normal inference: attack-detection prompt
                sample_template = get_prompt_template("auto_attack_detection", dataset_name)
                sample_classes = get_classes_str(dataset_name, include_attacked=True)

            truncated_origin_txt = "" if origin_txt == "" else clip_text(origin_txt)
            snippets = [clip_text(text) for text in sample.get("neighbor_texts", [])]
            # The training tokenizer in this file prefixes the auto neighbor
            # block with a newline; use the same layout here so inference
            # prompts match the prompts the model was tuned on.
            neighbor_str = "\n" + numbered(snippets)

            if use_neighbors_only:
                user_content = sample_template.format(neighbor_text=neighbor_str, classes=sample_classes)
            else:
                user_content = sample_template.format(origin_text=truncated_origin_txt, neighbor_text=neighbor_str, classes=sample_classes)
        else:
            truncated_origin_txt = clip_text(origin_txt)

            if with_neighbors:
                neighbor_labels = sample.get("neighbor_labels", [])
                snippets = []
                for i, neighbor_text in enumerate(sample.get("neighbor_texts", [])):
                    truncated_neighbor_txt = clip_text(neighbor_text)
                    if prompt_type in ["neighbor_label"] and i < len(neighbor_labels):
                        label = neighbor_labels[i] if neighbor_labels[i] else "unknown"
                        snippets.append(f"node text: {truncated_neighbor_txt}, label: {label}")
                    else:
                        # For all other prompt types: neighbor, noise, rand, sim, noisetxt
                        snippets.append(truncated_neighbor_txt)
                user_content = prompt_template.format(origin_text=truncated_origin_txt, neighbor_text=numbered(snippets), classes=classes_str)
            else:
                user_content = prompt_template.format(node=truncated_origin_txt, classes=classes_str)

        # Create conversation using chat template
        formatted_input = tokenizer.apply_chat_template(
            [{"role": "user", "content": user_content}], tokenize=False, add_generation_prompt=True
        )

        prompt_ids = tokenizer(formatted_input, add_special_tokens=False).input_ids
        full_input_ids.append(prompt_ids)
        full_attention_masks.append([1] * len(prompt_ids))

    if not full_input_ids:
        raise ValueError("batch_data is empty: there are no samples to tokenize")

    # Some tokenizers define no pad token; fall back to EOS so the id list
    # handed to torch.tensor never contains None.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # Pad sequences on the left so generation continues from the real prompt end
    max_length = max(len(ids) for ids in full_input_ids)
    for i, ids in enumerate(full_input_ids):
        pad_length = max_length - len(ids)
        full_input_ids[i] = [pad_id] * pad_length + ids
        full_attention_masks[i] = [0] * pad_length + full_attention_masks[i]

    return {
        "input_ids": torch.tensor(full_input_ids),
        "attention_mask": torch.tensor(full_attention_masks)
    }


def get_eval_steps(dataset_name, re_split=1):
    """Look up the evaluation interval (in steps) for a dataset.

    Each known dataset group has one interval that applies when
    ``re_split == 1`` and another for every other ``re_split`` value.
    ``arxiv`` uses the same interval either way, and a dataset that is not
    listed falls back to 100.
    """
    # (dataset names, steps when re_split == 1, steps for any other re_split)
    step_table = (
        (("cora", "citeseer"), 20, 100),
        (("wikics", "instagram", "pubmed"), 40, 200),
        (("reddit", "history", "photo", "computer"), 80, 400),
        (("arxiv",), 2000, 2000),
    )

    for names, resplit_steps, other_steps in step_table:
        if dataset_name not in names:
            continue
        if re_split == 1:
            return resplit_steps
        return other_steps

    # Unlisted dataset: generic default.
    return 100


def run_inference(trainer, tokenizer, test_contents, dataset_name, data_args, accelerator, model_name):
    """Run batched greedy generation over the test items and collect predictions.

    Args:
        trainer: Object whose ``model`` attribute is the model used for
            generation (a ``.module`` wrapper, if present, is unwrapped).
        tokenizer: Tokenizer used to build the prompts and decode the outputs.
        test_contents: Sequence of items; each item provides ``"output"``
            (the ground-truth label) and ``"id"``.
        dataset_name: Key into ``classes`` selecting the valid label strings.
        data_args: Provides ``batch_size``, ``prompt_type``,
            ``max_txt_length``, ``max_origin_txt_length`` and
            ``max_ans_length``.
        accelerator: Supplies the device the input tensors are moved to.
        model_name: Forwarded unchanged to ``tokenizer_test_data``.

    Returns:
        ``(pred_labels, gt_labels, results)``. ``pred_labels`` holds one
        label per item (``UNKNOW`` when the cleaned generation is not a valid
        class), ``gt_labels`` the matching ground truth, and ``results`` one
        dict per item with the raw and cleaned model output.
    """
    # Inference uses twice the training batch size.
    batch_size = data_args.batch_size * 2
    pred_labels, gt_labels = [], []
    results = []

    # Unwrap a distributed-style wrapper once; it does not change per batch.
    model_for_generation = trainer.model
    if hasattr(model_for_generation, 'module'):
        model_for_generation = model_for_generation.module

    # Generation must not run with dropout active (e.g. when this is called
    # right after training). Remember the current mode so it can be restored.
    was_training = getattr(model_for_generation, "training", False)
    model_for_generation.eval()

    print("Running inference...")
    try:
        for start in range(0, len(test_contents), batch_size):
            batch_data = test_contents[start:start + batch_size]
            batch_encodings = tokenizer_test_data(
                batch_data, tokenizer, dataset_name, data_args.prompt_type, model_name,
                max_txt_length=data_args.max_txt_length,
                max_origin_txt_length=data_args.max_origin_txt_length
            )

            batch_input_ids = batch_encodings["input_ids"].to(accelerator.device)
            batch_attention_mask = batch_encodings["attention_mask"].to(accelerator.device)

            with torch.no_grad():
                output = model_for_generation.generate(
                    input_ids=batch_input_ids,
                    attention_mask=batch_attention_mask,
                    max_new_tokens=data_args.max_ans_length,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            # The returned sequences start with the prompt tokens. Slice them
            # off by length instead of string-replacing the decoded prompt,
            # which breaks whenever decoding the full sequence does not
            # reproduce the decoded prompt verbatim.
            prompt_length = batch_input_ids.shape[1]
            raw_outputs = tokenizer.batch_decode(output, skip_special_tokens=True)
            generated_texts = tokenizer.batch_decode(
                output[:, prompt_length:], skip_special_tokens=True
            )

            for item, raw_output, generated_text in zip(batch_data, raw_outputs, generated_texts):
                label, node_id = item["output"], item["id"]
                cleaned_pred = generated_text.strip()

                # Keep only the first non-empty line to cope with verbose outputs.
                lines = [line.strip() for line in cleaned_pred.split('\n') if line.strip()]
                pred_label = lines[0] if lines else cleaned_pred

                # Drop surrounding quotes and whitespace.
                pred_label = pred_label.strip().strip('"').strip("'").strip()

                # Anything that is not an exact class name counts as unknown.
                pred_label = pred_label if pred_label in classes[dataset_name] else UNKNOW

                pred_labels.append(pred_label)
                gt_labels.append(label)

                results.append({
                    "idx": node_id,
                    "ground-truth": label,
                    "pred": pred_label,
                    "raw_output": raw_output,
                    "cleaned_pred": cleaned_pred
                })
    finally:
        # Leave the model in the mode the caller had it in.
        if was_training:
            model_for_generation.train()

    return pred_labels, gt_labels, results


def run_simf_inference(trainer, tokenizer, test_contents, dataset_name, data_args, accelerator, model_name, graph_data=None):
    """Predict labels for the simf prompt type.

    Delegates to ``run_inference`` with the given arguments and then prints
    a short summary. ``graph_data`` is accepted for call-site compatibility
    but is not used here.

    Returns:
        The ``(pred_labels, gt_labels, results)`` triple from ``run_inference``.
    """
    print("Running SimF inference (similarity filtering already applied)...")

    # No extra filtering step here: the test items are expected to have been
    # filtered when they were prepared, so the standard path is used as-is.
    outcome = run_inference(
        trainer, tokenizer, test_contents, dataset_name, data_args, accelerator, model_name
    )
    pred_labels, gt_labels, results = outcome

    # SimF-specific summary.
    summary_lines = (
        f"\nSimF Summary:",
        f"  - Total test nodes: {len(test_contents)}",
        f"  - Similarity filtering applied during data preparation",
    )
    for summary_line in summary_lines:
        print(summary_line)

    return pred_labels, gt_labels, results


def run_auto_inference(trainer, tokenizer, test_contents, dataset_name, data_args, accelerator, model_name, graph_data=None, prompt_type="auto"):
    """Run the auto-prompt inference pipeline on clean data and print a summary.

    The pipeline itself is implemented by
    ``auto_utils.run_auto_inference_with_structure``, called here with
    ``is_attack_eval=False``. Its intended stages are:
    1) detect text-attacked and structure-attacked nodes
    2) recover text-attacked nodes using neighbors only
    3) recover structure-attacked nodes by filtering dissimilar neighbors

    Returns:
        The ``(pred_labels, gt_labels, results)`` triple produced by the helper.
    """
    # Imported lazily, at call time, exactly as the shared implementation expects.
    from auto_utils import run_auto_inference_with_structure

    pred_labels, gt_labels, results = run_auto_inference_with_structure(
        trainer,
        tokenizer,
        test_contents,
        dataset_name,
        data_args,
        accelerator,
        model_name,
        graph_data,
        is_attack_eval=False,
        prompt_type=prompt_type,
    )

    def count_where(field_name, expected):
        """Count result records whose ``field_name`` equals ``expected``."""
        return sum(1 for record in results if record.get(field_name) == expected)

    # Summary counters for the clean-data run.
    num_text_attacked = count_where("stage1_pred", "text_attacked")
    num_structure_attacked = count_where("prediction_source", "structure_recovery")
    num_text_recovered = count_where("prediction_source", "text_recovery")

    print(f"\nClean Data Auto Summary:")
    print(f"  - Nodes detected as text-attacked: {num_text_attacked}")
    print(f"  - Nodes detected as structure-attacked: {num_structure_attacked}")
    print(f"  - Text recovery applied: {num_text_recovered}")
    print(f"  - Total test nodes: {len(test_contents)}")

    return pred_labels, gt_labels, results


def main():
    """Command-line entry point: fine-tune (or load) a LoRA adapter and evaluate it.

    Steps, in order:
      1. Parse CLI arguments and validate the prompt_type/dataset combination.
      2. Resolve the adapter save path and create the Accelerator.
      3. Load tokenizer and base model, and wrap the model with LoRA.
      4. If adapter files already exist at the save path, load them and skip
         training ("simf" requires this and never trains).
      5. Build train/val/test contents and tokenize train/val.
      6. Train when needed, then save the adapter and a training_info.json.
      7. Unless --skip_inference is given, run inference on the test
         contents, print metrics, write per-node results and append a row to
         the summary CSV (main process only).

    Raises:
        ValueError: For an unknown prompt_type, an unsupported dataset, or an
            "auto"/"simf" run outside the inductive setting.
        FileNotFoundError: For "simf" when the neighbor adapter is missing or
            fails to load.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--llm', type=str, default="Mistral-7B")
    parser.add_argument('--dataset', type=str, default='cora')
    parser.add_argument('--prompt_type', type=str, default='instruction_tuning', 
                       help=f'Type of prompt template to use. Available types: {", ".join(get_available_prompt_types())}')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--max_txt_length', type=int, default=128)
    parser.add_argument('--max_origin_txt_length', type=int, default=128)
    parser.add_argument('--max_ans_length', type=int, default=16)
    parser.add_argument('--maximum_neighbor', type=int, default=6)
    parser.add_argument('--num_epoch', type=int, default=2)
    parser.add_argument('--re_split', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lora_r', type=int, default=8)
    parser.add_argument('--lora_alpha', type=int, default=16)
    parser.add_argument('--lora_dropout', type=float, default=0.1)
    parser.add_argument('--skip_inference', action='store_true', help='Skip inference after training')
    
    args = parser.parse_args()
    
    # Validate prompt_type and dataset combination
    available_prompt_types = get_available_prompt_types()
    if args.prompt_type not in available_prompt_types:
        raise ValueError(f"Invalid prompt_type: {args.prompt_type}. Available types: {available_prompt_types}")
    
    supported_datasets = get_supported_datasets(args.prompt_type)
    if args.dataset not in supported_datasets:
        raise ValueError(f"Dataset {args.dataset} not supported for prompt_type {args.prompt_type}. Supported datasets: {supported_datasets}")
    
    print(f"Using prompt_type: {args.prompt_type} for dataset: {args.dataset}")
    
    # Create dataclass instances
    model_args = ModelArguments(model_name_or_path=args.llm)
    data_args = DataArguments(
        dataset=args.dataset,
        prompt_type=args.prompt_type,
        max_txt_length=args.max_txt_length,
        max_origin_txt_length=args.max_origin_txt_length,
        max_ans_length=args.max_ans_length,
        re_split=args.re_split,
        maximum_neighbor=args.maximum_neighbor
    )
    # batch_size is not a declared DataArguments field; it is attached
    # dynamically because the inference helpers read data_args.batch_size.
    data_args.batch_size = args.batch_size  # Add batch_size to data_args for inference
    
    lora_args = LoraArguments(
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout
    )
    
    set_seed(args.seed)
    
    # Unified is_inductive logic: re_split == 2 is always inductive, and
    # arxiv is inductive when re_split == 0.
    is_inductive = (args.re_split == 2) or (args.re_split == 0 and args.dataset == "arxiv")
    
    # For auto and simf prompt types, only support inductive setting
    if args.prompt_type in ["auto", "simf"] and not is_inductive:
        raise ValueError(f"{args.prompt_type} prompt type only supports inductive setting (re_split=2 or arxiv with re_split=0)")
    
    # Generate model save path - simf uses neighbor model path
    if args.prompt_type == "simf":
        actual_prompt_for_path = "neighbor"
    else:
        actual_prompt_for_path = args.prompt_type
        
    save_path = get_model_save_path(
        model_name="InstructionTuning",
        dataset=args.dataset,
        re_split=args.re_split,
        llm=args.llm,
        seed=args.seed,
        atk_name=None,
        prompt_type=actual_prompt_for_path,
        num_epoch=args.num_epoch
    )
    
    # Setup directories
    # alg_dir receives predictions and the summary CSV; save_dir (the adapter
    # path) is also used as the Trainer output directory.
    alg_dir = f"../../results/InstructionTuning"
    save_dir = save_path
    
    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision='bf16',
        gradient_accumulation_steps=1,
        project_dir=save_dir,
    )

    if accelerator.is_main_process:
        print(f"Model save path: {save_path}")
        if args.prompt_type == "simf":
            print(f"SimF: Using neighbor model for similarity filtering")
        elif args.prompt_type == "auto":
            print(f"Auto: Using auto model for attack detection and recovery")
    
    device = accelerator.device
    
    # Load tokenizer and model
    llm_path = llm_paths[args.llm]
    tokenizer = AutoTokenizer.from_pretrained(llm_path)
    
    # Universal tokenizer setup for different models:
    # reuse EOS as the pad token when there is one, otherwise add a new
    # '[PAD]' token (the embeddings are resized further down if needed).
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    
    tokenizer.padding_side = 'left'
    
    if accelerator.is_main_process:
        print(f"Tokenizer setup: pad_token='{tokenizer.pad_token}', pad_token_id={tokenizer.pad_token_id}")
        print(f"Chat template available: {hasattr(tokenizer, 'apply_chat_template') and tokenizer.chat_template is not None}")
    
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        llm_path,
        torch_dtype=torch.bfloat16,
        device_map=None,  # Let accelerator handle device mapping
    )
    
    # Resize token embeddings if we added new tokens
    if len(tokenizer) > model.config.vocab_size:
        model.resize_token_embeddings(len(tokenizer))
    
    # Setup LoRA
    peft_config = LoraConfig(
        task_type="CAUSAL_LM", 
        inference_mode=False, 
        r=lora_args.lora_r, 
        lora_alpha=lora_args.lora_alpha, 
        lora_dropout=lora_args.lora_dropout,
        target_modules=lora_args.lora_target_modules
    )
    ft_model = get_peft_model(model, peft_config)
    
    if accelerator.is_main_process:
        ft_model.print_trainable_parameters()
    
    # Move model to device
    # NOTE(review): when a saved adapter is loaded below, ft_model is rebound
    # to a second model while `model` still references this one - confirm the
    # device-memory impact of holding both.
    ft_model = ft_model.to(device)
    
    # Check if trained model already exists
    model_files = ["adapter_config.json", ["adapter_model.bin", "adapter_model.safetensors"]]  # LoRA adapter files, either .bin or .safetensors
    should_train = True
    train_secs = 0  # stays 0 when training is skipped; reported in the summary row
    
    def check_model_files(path, required_files):
        # True when every entry exists under `path`; a nested list means
        # "at least one of these alternatives must exist".
        for file in required_files:
            if isinstance(file, list):  # For files where we accept multiple formats
                if not any(os.path.exists(os.path.join(path, f)) for f in file):
                    return False
            elif not os.path.exists(os.path.join(path, file)):
                return False
        return True
    
    # Handle model loading based on prompt type
    if args.prompt_type == "simf":
        # SimF must use existing neighbor model, never retrain
        if not check_model_files(save_path, model_files):
            raise FileNotFoundError(f"SimF requires existing neighbor model but not found: {save_path}. Train neighbor model first.")
        
        # NOTE(review): the file check accepts adapter_model.bin, but the load
        # passes use_safetensors=True - confirm a .bin-only adapter loads.
        try:
            ft_model = AutoPeftModelForCausalLM.from_pretrained(
                save_path,
                torch_dtype=torch.bfloat16,
                use_safetensors=True
            )
            ft_model = ft_model.to(device)
            should_train = False
            if accelerator.is_main_process:
                print(f"SimF: Loaded existing neighbor model from {save_path}")
        except Exception as e:
            # Any load failure is reported as FileNotFoundError.
            raise FileNotFoundError(f"SimF: Failed to load neighbor model: {e}")
            
    else:
        # Auto and other prompt types - standard behavior: load own model or train
        if check_model_files(save_path, model_files):
            if accelerator.is_main_process:
                print(f"Loading existing {args.prompt_type} model from {save_path}")
            try:
                ft_model = AutoPeftModelForCausalLM.from_pretrained(
                    save_path,
                    torch_dtype=torch.bfloat16,
                    use_safetensors=True
                )
                ft_model = ft_model.to(device)
                should_train = False
                if accelerator.is_main_process:
                    print(f"{args.prompt_type}: Skipping training as model already exists")
            except Exception as e:
                # Best-effort load: on failure fall back to training with the
                # freshly initialised LoRA model built above.
                if accelerator.is_main_process:
                    print(f"{args.prompt_type}: Failed to load existing model: {e}")
                    print(f"{args.prompt_type}: Will proceed with training...")
                should_train = True
        else:
            should_train = True
            if accelerator.is_main_process:
                print(f"{args.prompt_type}: No existing model found, proceeding with training...")
    
    
    # Data loading section
    if not is_inductive:
        # Transductive setting: all three splits are built from the same graph
        graph_data = load_graph_dataset(dataset_name=args.dataset, device=device, re_split=args.re_split)
        train_contents = prepare_graph_instruction_tuning_data(graph_data, "train", args.dataset, args.prompt_type, args.maximum_neighbor)
        val_contents = prepare_graph_instruction_tuning_data(graph_data, "val", args.dataset, args.prompt_type, args.maximum_neighbor) 
        test_contents = prepare_graph_instruction_tuning_data(graph_data, "test", args.dataset, args.prompt_type, args.maximum_neighbor)
    
    else:
        # Inductive setting: each split has its own graph object
        graph_data, (train_data, val_data, test_data) = load_inductive_graph_dataset(dataset_name=args.dataset, device=device, re_split=args.re_split, seed=args.seed)
        
        # Load roberta embeddings for simf and auto prompts (required for their functionality)
        if args.prompt_type in ['simf', 'auto']:
            # NOTE(review): placeholder path - must point at the real data root before use.
            PATH = "/path/to/GraphAD_data"
            node_feat = torch.load(f"{PATH}/datasets/roberta/{args.dataset}.pt", map_location=device, weights_only=False).to(device).type(torch.float)
            # Only the test split and the full graph get the new features.
            test_data.x = node_feat
            graph_data.x = node_feat
            if accelerator.is_main_process:
                print(f"{args.prompt_type.upper()}: Loaded roberta embeddings for {args.dataset} (shape: {node_feat.shape})")
        
        train_contents = prepare_graph_instruction_tuning_data(train_data, "train", args.dataset, args.prompt_type, args.maximum_neighbor, inductive=True, 
                                                              full_graph_data=graph_data)
        val_contents = prepare_graph_instruction_tuning_data(val_data, "val", args.dataset, args.prompt_type, args.maximum_neighbor, inductive=True, 
                                                              full_graph_data=graph_data) 
        test_contents = prepare_graph_instruction_tuning_data(test_data, "test", args.dataset, args.prompt_type, args.maximum_neighbor, inductive=True, 
                                                              full_graph_data=graph_data)
    
    # Train/val sets are tokenized even when training ends up being skipped;
    # the test contents are tokenized later, batch by batch, during inference.
    train_encodings = tokenizer_instruction_tuning_data(
        train_contents, tokenizer, args.dataset, data_args.prompt_type, args.llm,
        max_txt_length=data_args.max_txt_length, 
        max_origin_txt_length=data_args.max_origin_txt_length, 
        max_ans_length=data_args.max_ans_length
    )
    train_dataset = TextDataset(train_encodings) 
    
    val_encodings = tokenizer_instruction_tuning_data(
        val_contents, tokenizer, args.dataset, data_args.prompt_type, args.llm,
        max_txt_length=data_args.max_txt_length, 
        max_origin_txt_length=data_args.max_origin_txt_length, 
        max_ans_length=data_args.max_ans_length
    )
    val_dataset = TextDataset(val_encodings)

    if accelerator.is_main_process:
        print(f"Train dataset size: {len(train_dataset)}, Val dataset size: {len(val_contents)}, Test dataset size: {len(test_contents)}")
    
    # Get evaluation steps (doubled for the auto prompt type)
    eval_steps = get_eval_steps(args.dataset, args.re_split)
    if args.prompt_type == "auto":
        eval_steps *= 2
    
    # Checkpoints are saved on the same step interval as evaluation.
    training_args = HFTrainingArguments(
        output_dir=save_dir,
        learning_rate=1e-4,                           
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        num_train_epochs=args.num_epoch,
        weight_decay=0.01,
        eval_strategy=IntervalStrategy.STEPS,
        eval_steps=eval_steps,
        save_steps=eval_steps,
        save_total_limit=1,
        load_best_model_at_end=True,
        report_to="none",
        warmup_ratio=0.1,                            
        lr_scheduler_type="cosine",                   
        optim="adamw_torch",                         
        bf16=True,                                   
        ddp_find_unused_parameters=False,
    )
    
    # Setup trainer
    trainer = Trainer(
        model=ft_model, 
        args=training_args,
        train_dataset=train_dataset, 
        eval_dataset=val_dataset,  
        tokenizer=tokenizer,
        data_collator=None,
    )
    
    # Prepare everything for distributed training
    # NOTE(review): confirm Accelerator.prepare has any effect on a Trainer
    # instance (as opposed to a model/optimizer/dataloader).
    trainer = accelerator.prepare(trainer)
    
    # Conditional training
    if should_train:
        # Train
        st_time = time.time()
        trainer.train()
        train_secs = time.time() - st_time
        
        # Save the final model
        if accelerator.is_main_process:
            # Save the LoRA adapter to the unified save_path
            trainer.model.save_pretrained(save_path)
            print(f"LoRA adapter saved to {save_path}")
            
            # Save training info
            training_info = {
                "dataset": args.dataset,
                "model": args.llm,
                "prompt_type": args.prompt_type,
                "seed": args.seed,
                "train_time_minutes": train_secs/60,
                "lora_adapter_path": save_path,
                "base_model_path": llm_path
            }
            
            with open(f"{save_path}/training_info.json", "w") as f:
                json.dump(training_info, f, indent=2)
            
            print(f"Training completed in {train_secs/60:.3f} minutes")
    else:
        if accelerator.is_main_process:
            print("Skipped training as model already exists")
    
    # Run inference if not skipped
    if not args.skip_inference:
        # Define correct file suffix and setting name for different re_split values
        re_split_str, setting_name = get_setting_info(args.re_split)
        
        st_time = time.time()
        
        # Use different inference functions based on prompt type
        if args.prompt_type in ["auto"]:
            pred_labels, gt_labels, results = run_auto_inference(
                trainer, tokenizer, test_contents, args.dataset, data_args, accelerator, args.llm, graph_data, prompt_type=args.prompt_type
            )
        elif args.prompt_type == "simf":
            # For simf, pass test_data which has the filtered edges
            # (test_data exists here because simf is restricted to the inductive setting above)
            pred_labels, gt_labels, results = run_simf_inference(
                trainer, tokenizer, test_contents, args.dataset, data_args, accelerator, args.llm, test_data
            )
        else:
            pred_labels, gt_labels, results = run_inference(
                trainer, tokenizer, test_contents, args.dataset, data_args, accelerator, args.llm
            )
        
        inference_secs = time.time() - st_time
        
        # Compute metrics
        acc, macro_f1, weight_f1 = compute_acc_and_f1(pred_labels, gt_labels)
        
        # Calculate unknown predictions ratio
        unknown_count = sum(1 for pred in pred_labels if pred == UNKNOW)
        unknown_ratio = unknown_count / len(pred_labels) if len(pred_labels) > 0 else 0
        
        # Reporting and file output happen on the main process only.
        if accelerator.is_main_process:
            print(f"Accuracy: {acc:.4f}")
            print(f"Macro F1-Score: {macro_f1:.4f}")
            print(f"Weighted F1-Score: {weight_f1:.4f}")
            print(f"Inference time: {inference_secs:.2f} seconds")
            print(f"Unknown predictions: {unknown_count}/{len(pred_labels)} ({unknown_ratio:.2%})")
            
            # For auto prompt, also log attack detection statistics
            if args.prompt_type == "auto":
                text_attacked_detected = sum(1 for r in results if r.get("stage1_pred") == "text_attacked")
                # NOTE(review): run_auto_inference in this file counts
                # "text_recovery"/"structure_recovery" as prediction_source
                # values - confirm "recovery_inference" is ever emitted.
                stage2_used = sum(1 for r in results if r.get("prediction_source") == "recovery_inference")
                print(f"Auto-specific stats:")
                print(f"  - Nodes detected as text_attacked: {text_attacked_detected}")
                print(f"  - Nodes using stage2 recovery: {stage2_used}")
                
                # Print metrics if available (read from the first result record only)
                if results and "text_attack_metrics" in results[0]:
                    metrics = results[0]["text_attack_metrics"]
                    print(f"  - Text attack detection rate: {metrics['detection_rate']:.2%}")
                    print(f"  - Avg degree (all attacked): {metrics['avg_degree_all_attacked']:.2f}")
                    print(f"  - Avg degree (correctly detected): {metrics['avg_degree_correctly_detected']:.2f}")
            
            # Save results: one JSON object per line, despite the .json extension
            write_dir = f"{alg_dir}/prediction"
            os.makedirs(write_dir, exist_ok=True)
            result_file = f"{write_dir}/{args.dataset}_{args.llm}_{args.prompt_type}{re_split_str}_seed{args.seed}.json"
            
            with open(result_file, "w") as f:
                for result in results:
                    f.write(json.dumps(result) + "\n")
            
            print(f"Results saved to {result_file}")
            
            # Save summary: append one row per run; no header row is written here
            summary_file = f"{alg_dir}/summary.csv"
            with open(summary_file, 'a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([
                    args.dataset, args.llm, args.prompt_type, acc, macro_f1, weight_f1, 
                    setting_name, f"Seed-{args.seed}", 
                    f"Batch Size-{args.batch_size}", f"Epoch-{args.num_epoch}",
                    f"Train Minutes-{train_secs/60:.3f}", f"Inference Seconds-{inference_secs:.2f}"
                ])
    
    # Clean up: wait for all processes to reach this point before shutting down
    accelerator.wait_for_everyone()
    
    accelerator.end_training()


# Run the training/inference pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main() 