import argparse
import os
import sys
import json
import csv
import time
import random
from dataclasses import dataclass, field
from typing import Optional, List

import torch
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    TrainingArguments as HFTrainingArguments,
    Trainer, 
    IntervalStrategy
)

from dataset import TextDataset

sys.path.append("../..")
from common import set_seed, load_atk_graph_dataset, load_inductive_atk_graph_dataset, load_inductive_graph_dataset, compute_acc_and_f1
from common import IGNORE_INDEX, UNKNOW
from common import CLASSES as classes 
from common import MODEL_PATHs as llm_paths
from common import get_prompt_template, get_available_prompt_types, get_supported_datasets, requires_neighbor_info, get_classes_str
from common import prepare_edge_list
from common.model_path import get_model_save_path, check_model_exists, get_setting_info
from common.sft_prompts import CLASSES_WITH_ATTACKED, AUTO_SIMILARITY_THRESHOLDS


# Import shared similarity filtering function
from auto_utils import filter_edges_by_similarity


@dataclass
class ModelArguments:
    """Identifies which base language model to use."""

    # Model name; ``main`` fills this from the ``--llm`` command-line option.
    model_name_or_path: Optional[str] = "Mistral-7B"


@dataclass
class DataArguments:
    """Dataset, prompt and attack settings for one run."""

    # ---- dataset / prompt construction ----
    dataset: str = field(
        default="cora",
        metadata={"help": "Dataset name"},
    )
    prompt_type: str = field(
        default="instruction_tuning",
        metadata={"help": "Type of prompt template to use"},
    )
    max_txt_length: int = field(
        default=128,
        metadata={"help": "Maximum length of query prompt"},
    )
    max_origin_txt_length: int = field(
        default=128,
        metadata={"help": "Maximum length of node's original text"},
    )
    max_ans_length: int = field(
        default=16,
        metadata={"help": "Maximum length of answer"},
    )
    re_split: int = field(
        default=0,
        metadata={"help": "Whether to use re-split"},
    )
    maximum_neighbor: int = field(
        default=8,
        metadata={"help": "Maximum number of neighbors to include"},
    )

    # ---- attack description ----
    attack: str = field(
        default="pgd",
        metadata={"help": "Attack type"},
    )
    atk_type: str = field(
        default="structure",
        metadata={"help": "Attack category"},
    )
    ptb_rate: float = field(
        default=0.1,
        metadata={"help": "Perturbation rate"},
    )
    atk_emb_type: str = field(
        default="bow",
        metadata={"help": "Embedding type used for attack"},
    )
    atk_seed: int = field(
        default=0,
        metadata={"help": "Seed used for attack generation"},
    )


@dataclass
class TrainingArguments(HFTrainingArguments):
    """``HFTrainingArguments`` extended with the extra fields declared below."""

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=8192,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = True
    num_epoch: int = 2
    batch_size: int = 8


@dataclass
class LoraArguments:
    """LoRA adapter hyper-parameters."""

    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    # A factory is required so every instance gets its own list object.
    lora_target_modules: List[str] = field(
        default_factory=lambda: list(("q_proj", "v_proj"))
    )
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False


def _mask_to_index_list(mask):
    """Return the positions of the True entries of a 1-D boolean mask as a list of ints."""
    # squeeze(-1), not squeeze(): with exactly one True entry a bare squeeze()
    # collapses to a 0-dim tensor and tolist() then yields an int, not a list.
    return mask.nonzero(as_tuple=False).squeeze(-1).cpu().numpy().tolist()


def _count_node_degrees(edge_index, num_nodes):
    """Count, per node, how many edge endpoints (row 0 plus row 1 of edge_index) reference it."""
    as_source = torch.bincount(edge_index[0], minlength=num_nodes)[:num_nodes]
    as_target = torch.bincount(edge_index[1], minlength=num_nodes)[:num_nodes]
    return (as_source + as_target).to(torch.long)


def _trim_to_highest_degree(neighbors, degree_of, maximum_neighbor):
    """Keep at most ``maximum_neighbor`` neighbors, preferring those with the highest degree."""
    if len(neighbors) <= maximum_neighbor:
        return neighbors
    # sorted() is stable (also with reverse=True), so ties keep adjacency order.
    return sorted(neighbors, key=lambda n: degree_of[n], reverse=True)[:maximum_neighbor]


def _maybe_inject_noise_neighbor(neighbors, local_node_id, num_nodes, maximum_neighbor, log_prefix):
    """With probability 0.1, append (or swap in) one random non-neighbor node."""
    if random.random() >= 0.1:
        return neighbors

    available_nodes = [i for i in range(num_nodes)
                       if i != local_node_id and i not in neighbors]
    if not available_nodes:
        return neighbors

    noise_neighbor = random.choice(available_nodes)
    # Work on a copy so the adjacency structure itself is never mutated.
    neighbors = list(neighbors)
    if len(neighbors) < maximum_neighbor:
        neighbors.append(noise_neighbor)
        print(f"{log_prefix}Added noise neighbor for node {local_node_id}: added node {noise_neighbor}")
    else:
        replace_idx = random.randint(0, len(neighbors) - 1)
        replaced_neighbor = neighbors[replace_idx]
        neighbors[replace_idx] = noise_neighbor
        print(f"{log_prefix}Added noise neighbor for node {local_node_id}: replaced neighbor {replaced_neighbor} with node {noise_neighbor}")
    return neighbors


def _swap_in_mismatched_texts(modified_raw_texts, raw_texts, node_labels, class_names, focus_nodes,
                              node_degrees, num_nodes, has_known_label, log_prefix):
    """Overwrite the text of ~10% of above-average-degree focus nodes with text from a differently-labeled node."""
    if node_degrees is None:
        # Without degree information no node qualifies as "high degree".
        return

    avg_degree = node_degrees.float().mean().item()
    degree_of = node_degrees.tolist()
    high_degree_nodes = [node_id for node_id in focus_nodes if degree_of[node_id] > avg_degree]
    if not high_degree_nodes:
        return

    num_to_replace = max(1, int(len(high_degree_nodes) * 0.1))
    nodes_to_replace = random.sample(high_degree_nodes, min(num_to_replace, len(high_degree_nodes)))

    for node_id in nodes_to_replace:
        current_label = node_labels[node_id]
        # Only nodes with a different label whose label is allowed to be known qualify.
        replacement_candidates = [
            candidate_id for candidate_id in range(num_nodes)
            if candidate_id != node_id
            and node_labels[candidate_id] != current_label
            and has_known_label(candidate_id)
        ]
        if not replacement_candidates:
            continue

        replacement_node = random.choice(replacement_candidates)
        modified_raw_texts[node_id] = raw_texts[replacement_node]
        replacement_label = class_names[node_labels[replacement_node]]
        current_label_name = class_names[current_label]
        print(f"{log_prefix}: Replaced text for node {node_id} (label: {current_label_name}) with text from node {replacement_node} (label: {replacement_label})")


def prepare_graph_instruction_tuning_data_atk(graph_data, data_type="train", dataset_name="cora", prompt_type="instruction_tuning", maximum_neighbor=8, 
                                          inductive=False, full_graph_data=None):
    """Build the list of prompt records for one split of a (possibly attacked) graph.

    Args:
        graph_data: Graph object providing ``raw_texts``, ``y``, ``num_nodes``,
            ``edge_index`` and either split masks (transductive) or
            ``node_ids`` (inductive sub-graph).
        data_type: ``"train"``, ``"val"`` or ``"test"``.
        dataset_name: Key into the module-level ``classes`` mapping.
        prompt_type: Prompt template name; selects neighbor handling and the
            optional training-time noise (``noise``, ``noisetxt``,
            ``noisefull``) or test-time similarity filtering (``simf``).
        maximum_neighbor: Upper bound on neighbors kept per node.
        inductive: Whether ``graph_data`` is a split-specific sub-graph whose
            nodes map to the full graph through ``node_ids``.
        full_graph_data: Full graph carrying the split masks; required for
            inductive ``val``/``test``.

    Returns:
        A list of dicts with keys ``id``, ``input`` and ``output``, plus
        ``neighbor_texts`` and ``neighbor_labels`` when the prompt type needs
        neighbor information. The list is shuffled unless ``data_type`` is
        ``"test"``.

    Raises:
        ValueError: For ``auto``/``simf`` prompts with ``data_type="train"``,
            for an unknown ``data_type`` in inductive mode, or when
            ``full_graph_data`` is missing for inductive ``val``/``test``.
        KeyError: For an unknown ``data_type`` in transductive mode.

    Note:
        All sampling goes through the global ``random`` module.
    """
    # This function builds no training data for these prompt types; fail
    # before doing any expensive work.
    if prompt_type in ("auto", "simf") and data_type == "train":
        raise ValueError("Prompt type only support inductive setting")

    raw_texts = graph_data.raw_texts
    num_nodes = graph_data.num_nodes

    if inductive:
        if data_type not in ("train", "val", "test"):
            raise ValueError(f"Unknown data_type: {data_type!r}")
        if data_type != "train" and full_graph_data is None:
            raise ValueError("full_graph_data is required for inductive val/test data preparation")

        if data_type == "train":
            original_node_ids = graph_data.node_ids.cpu().numpy().tolist()
            focus_nodes = list(range(num_nodes))
            known_label_original_ids = set(original_node_ids)
        elif data_type == "val":
            train_original_ids = set(_mask_to_index_list(full_graph_data.train_mask))
            original_node_ids = graph_data.node_ids.cpu().numpy().tolist()
            # Focus on every node of the sub-graph that is not a training node.
            focus_nodes = [i for i, orig_id in enumerate(original_node_ids)
                           if orig_id not in train_original_ids]
            known_label_original_ids = set(original_node_ids)
        else:  # "test"
            train_val_mask = full_graph_data.train_mask | full_graph_data.val_mask
            test_original_ids = set(_mask_to_index_list(full_graph_data.test_mask))
            if hasattr(graph_data, 'node_ids'):
                original_node_ids = graph_data.node_ids.cpu().numpy().tolist()
            else:
                original_node_ids = list(range(num_nodes))
            focus_nodes = [i for i, orig_id in enumerate(original_node_ids)
                           if orig_id in test_original_ids]
            # Only train/val labels may be revealed at test time.
            known_label_original_ids = set(_mask_to_index_list(train_val_mask))
    else:
        focus_mask = {"train": graph_data.train_mask, "val": graph_data.val_mask, "test": graph_data.test_mask}[data_type]
        focus_nodes = _mask_to_index_list(focus_mask)
        known_label_original_ids = set(_mask_to_index_list(graph_data.train_mask | graph_data.val_mask))
        original_node_ids = None

    def has_known_label(node):
        """Tell whether the label of local node ``node`` may be exposed."""
        if inductive and original_node_ids:
            return original_node_ids[node] in known_label_original_ids
        return node in known_label_original_ids

    class_names = classes[dataset_name]
    # One bulk conversion instead of a tensor ``.item()`` call per lookup.
    node_labels = graph_data.y.reshape(-1).tolist()

    # Prepare edge list and degree information if neighbor information is needed
    edge_list = None
    node_degrees = None
    degree_of = None
    if requires_neighbor_info(prompt_type):
        edge_index = graph_data.edge_index.cpu()

        # Apply similarity filtering for simf during test phase only
        if prompt_type == "simf" and data_type == "test":
            if hasattr(graph_data, 'x') and graph_data.x is not None:
                print(f"SimF: Applying similarity filtering during {data_type} phase...")
                # Embeddings are moved to CPU to match the CPU edge_index.
                node_embeddings_cpu = graph_data.x.cpu()
                edge_index = filter_edges_by_similarity(node_embeddings_cpu, edge_index, threshold=0.5)
            else:
                print("SimF: Warning - Node embeddings not available, using original edges")

        edge_list = prepare_edge_list(edge_index, num_nodes)
        node_degrees = _count_node_degrees(edge_index, num_nodes)
        degree_of = node_degrees.tolist()

    # For noisetxt / noisefull: corrupt some node texts during training
    modified_raw_texts = list(raw_texts)
    text_noise_prefix = {"noisetxt": "NoiseText", "noisefull": "NoiseFull Text"}.get(prompt_type)
    if text_noise_prefix is not None and data_type == "train":
        _swap_in_mismatched_texts(
            modified_raw_texts, raw_texts, node_labels, class_names, focus_nodes,
            node_degrees, num_nodes, has_known_label, text_noise_prefix,
        )

    include_neighbors = bool(requires_neighbor_info(prompt_type)) and edge_list is not None
    # Log prefix for neighbor-noise prompts; None means no neighbor noise.
    neighbor_noise_prefix = None
    if data_type == "train":
        neighbor_noise_prefix = {"noise": "", "noisefull": "NoiseFull: "}.get(prompt_type)

    data_contents = []
    for local_node_id in focus_nodes:
        data_item = {
            "id": original_node_ids[local_node_id] if inductive else local_node_id,
            "input": modified_raw_texts[local_node_id],
            "output": class_names[node_labels[local_node_id]],
        }

        if include_neighbors:
            neighbors = edge_list[local_node_id]

            if prompt_type == "rand":
                if len(neighbors) > maximum_neighbor:
                    neighbors = random.sample(neighbors, maximum_neighbor)
            else:
                neighbors = _trim_to_highest_degree(neighbors, degree_of, maximum_neighbor)
                if neighbor_noise_prefix is not None:
                    neighbors = _maybe_inject_noise_neighbor(
                        neighbors, local_node_id, num_nodes, maximum_neighbor, neighbor_noise_prefix
                    )

            data_item["neighbor_texts"] = [modified_raw_texts[neigh] for neigh in neighbors]

            if prompt_type in ["neighbor_label"]:
                if inductive:
                    known_flags = [original_node_ids[neigh] in known_label_original_ids for neigh in neighbors]
                else:
                    known_flags = [neigh in known_label_original_ids for neigh in neighbors]
                data_item["neighbor_labels"] = [
                    class_names[node_labels[neigh]] if known else "unknown"
                    for neigh, known in zip(neighbors, known_flags)
                ]
            else:
                data_item["neighbor_labels"] = [None] * len(neighbors)

        data_contents.append(data_item)

    if data_type != "test":
        random.shuffle(data_contents)

    return data_contents


# Import the existing tokenization functions from the original train.py
from train import tokenizer_instruction_tuning_data, tokenizer_test_data, get_eval_steps, run_inference

def run_auto_inference_with_structure(trainer, tokenizer, test_contents, dataset_name, data_args, accelerator, model_name, graph_data=None, prompt_type='auto'):
    """Forward to ``auto_utils.run_auto_inference_with_structure`` with ``is_attack_eval=True``.

    The arguments are passed through unchanged; the result of the shared
    implementation is returned as-is.
    """
    # Imported lazily, as in the original wrapper.
    import auto_utils

    forwarded = (
        trainer, tokenizer, test_contents, dataset_name,
        data_args, accelerator, model_name, graph_data,
    )
    return auto_utils.run_auto_inference_with_structure(
        *forwarded, is_attack_eval=True, prompt_type=prompt_type
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--llm', type=str, default="Mistral-7B")
    parser.add_argument('--dataset', type=str, default='cora')
    parser.add_argument('--prompt_type', type=str, default='instruction_tuning', 
                       help=f'Type of prompt template to use. Available types: {", ".join(get_available_prompt_types())}')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--max_txt_length', type=int, default=128)
    parser.add_argument('--max_origin_txt_length', type=int, default=128)
    parser.add_argument('--max_ans_length', type=int, default=16)
    parser.add_argument('--maximum_neighbor', type=int, default=6)
    parser.add_argument('--num_epoch', type=int, default=2)
    parser.add_argument('--re_split', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--lora_r', type=int, default=8)
    parser.add_argument('--lora_alpha', type=int, default=16)
    parser.add_argument('--lora_dropout', type=float, default=0.1)
    parser.add_argument('--skip_inference', action='store_true', help='Skip inference after training')
    
    # Attack-specific arguments
    parser.add_argument('--attack', type=str, default="pgd", help="Attack type (pgd, grbcd, prbcd, text_fooler, etc.)")
    parser.add_argument('--atk_type', type=str, default="structure", choices=["structure", "text", "hybrid"], help="Attack category")
    parser.add_argument('--ptb_rate', type=float, default=0.1, help="Perturbation rate")
    parser.add_argument('--atk_emb_type', type=str, default="bow", help="Embedding type used for attack")
    parser.add_argument('--atk_seed', type=int, default=0, help="Seed used for attack generation")
    
    args = parser.parse_args()
    
    # Validate prompt_type and dataset combination
    available_prompt_types = get_available_prompt_types()
    if args.prompt_type not in available_prompt_types:
        raise ValueError(f"Invalid prompt_type: {args.prompt_type}. Available types: {available_prompt_types}")
    
    supported_datasets = get_supported_datasets(args.prompt_type)
    if args.dataset not in supported_datasets:
        raise ValueError(f"Dataset {args.dataset} not supported for prompt_type {args.prompt_type}. Supported datasets: {supported_datasets}")
    
    print(f"Using prompt_type: {args.prompt_type} for dataset: {args.dataset}")
    print(f"Attack Info: {args.attack} with ptb_rate={args.ptb_rate}")
    if args.prompt_type == "simf":
        print("SimF: Similarity filtering will be applied during test phase")
    
    # Create dataclass instances
    model_args = ModelArguments(model_name_or_path=args.llm)
    data_args = DataArguments(
        dataset=args.dataset,
        prompt_type=args.prompt_type,
        max_txt_length=args.max_txt_length,
        max_origin_txt_length=args.max_origin_txt_length,
        max_ans_length=args.max_ans_length,
        re_split=args.re_split,
        maximum_neighbor=args.maximum_neighbor,
        attack=args.attack,
        atk_type=args.atk_type,
        ptb_rate=args.ptb_rate,
        atk_emb_type=args.atk_emb_type,
        atk_seed=args.atk_seed
    )
    data_args.batch_size = args.batch_size  # Add batch_size to data_args for inference
    
    lora_args = LoraArguments(
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout
    )
    
    set_seed(args.seed)
    
    # Prepare attack metadata
    atk_meta_info = {
        'attack': args.attack,
        'ptb_rate': args.ptb_rate,
        'atk_emb_type': args.atk_emb_type,
        'seed': args.atk_seed,
        'atk_type': args.atk_type
    }
    
    # Unified is_inductive logic
    is_inductive = (args.re_split == 2) or (args.re_split == 0 and args.dataset == "arxiv")
    
    # For auto prompt type, only support inductive setting
    if args.prompt_type in ["auto", "simf"] and not is_inductive:
        raise ValueError("Prompt type only supports inductive setting (re_split=2 or arxiv with re_split=0)")
    
    # Generate model save path using unified path management based on re_split logic
    if is_inductive:
        if args.prompt_type == "simf":
            actual_prompt_for_path = "neighbor"
        else:
            actual_prompt_for_path = args.prompt_type
            
        save_path = get_model_save_path(
            model_name="InstructionTuning",
            dataset=args.dataset,
            re_split=args.re_split,
            llm=args.llm,
            seed=args.seed,
            atk_name=None,
            prompt_type=actual_prompt_for_path,
            num_epoch=args.num_epoch
        )
    else:
        if args.prompt_type == "simf":
            actual_prompt_for_path = "neighbor"
        else:
            actual_prompt_for_path = args.prompt_type
            
        save_path = get_model_save_path(
            model_name="InstructionTuning",
            dataset=args.dataset,
            re_split=args.re_split,
            llm=args.llm,
            seed=args.seed,
            atk_name=f"{args.attack}_ptb{int(args.ptb_rate*100)}",
            prompt_type=actual_prompt_for_path,
            num_epoch=args.num_epoch
        )
    
    alg_dir = f"../../results/InstructionTuning"
    save_dir = save_path
    
    # Initialize accelerator
    accelerator = Accelerator(
        mixed_precision='bf16',
        gradient_accumulation_steps=1,
        project_dir=save_dir,
    )
    
    if accelerator.is_main_process:
        print(f"Model save path for attack evaluation: {save_path}")
        if is_inductive:
            if args.prompt_type == "simf":
                print(f"Using neighbor model for SimF similarity filtering")
            else:
                print(f"Using {args.prompt_type} model for inductive evaluation")
        else:
            if args.prompt_type == "simf":
                print(f"Using neighbor model for SimF similarity filtering")
            else:
                print(f"Using {args.prompt_type} model for transductive evaluation")
    
    device = accelerator.device
    
    # Load tokenizer and model
    llm_path = llm_paths[args.llm]
    tokenizer = AutoTokenizer.from_pretrained(llm_path)
    
    # Universal tokenizer setup for different models
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    
    tokenizer.padding_side = 'left'
    
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        llm_path,
        torch_dtype=torch.bfloat16,
        device_map=None,  # Let accelerator handle device mapping
    )
    
    # Resize token embeddings if we added new tokens
    if len(tokenizer) > model.config.vocab_size:
        model.resize_token_embeddings(len(tokenizer))
    
    # Setup LoRA
    peft_config = LoraConfig(
        task_type="CAUSAL_LM", 
        inference_mode=False, 
        r=lora_args.lora_r, 
        lora_alpha=lora_args.lora_alpha, 
        lora_dropout=lora_args.lora_dropout,
        target_modules=lora_args.lora_target_modules
    )
    ft_model = get_peft_model(model, peft_config)
    
    # Move model to device
    ft_model = ft_model.to(device)
    
    # Check if trained model already exists based on re_split logic
    model_files = ["adapter_config.json", ["adapter_model.bin", "adapter_model.safetensors"]]  # LoRA adapter files, either .bin or .safetensors
    should_train = True
    train_secs = 0
    
    def check_model_files(path, required_files):
        for file in required_files:
            if isinstance(file, list):  # For files where we accept multiple formats
                if not any(os.path.exists(os.path.join(path, f)) for f in file):
                    return False
            elif not os.path.exists(os.path.join(path, file)):
                return False
        return True
    
    if is_inductive:
        # For inductive setting, must load clean pre-trained model
        if not check_model_files(save_path, model_files):
            if args.prompt_type == "simf":
                model_type = "neighbor-based"
            else:
                model_type = f"{args.prompt_type} pre-trained"
            raise FileNotFoundError(f"{model_type} LoRA model not found: {save_path}")
        try:
            ft_model = AutoPeftModelForCausalLM.from_pretrained(
                save_path,
                torch_dtype=torch.bfloat16,
                use_safetensors=True
            )
            ft_model = ft_model.to(device)
            should_train = False
            if accelerator.is_main_process:
                if args.prompt_type == "simf":
                    print(f"Loaded neighbor-based pre-trained model for SimF evaluation: {save_path}")
                else:
                    print(f"Loaded {args.prompt_type} pre-trained model for inductive attack evaluation: {save_path}")
        except Exception as e:
            if accelerator.is_main_process:
                if args.prompt_type == "simf":
                    model_type = "neighbor-based"
                else:
                    model_type = f"{args.prompt_type} pre-trained"
                print(f"Failed to load {model_type} model: {e}")
            raise e
    elif args.re_split == 1:
        # For re_split=1, always retrain attack model from scratch
        should_train = True
        if accelerator.is_main_process:
            print(f"Will retrain attack-specific model for supervised setting: {save_path}")
    else:
        # For other transductive settings (re_split=0, non-arxiv), check if attack model exists
        if check_model_files(save_path, model_files):
            try:
                ft_model = AutoPeftModelForCausalLM.from_pretrained(
                    save_path,
                    torch_dtype=torch.bfloat16,
                    use_safetensors=True
                )
                ft_model = ft_model.to(device)
                should_train = False
                if accelerator.is_main_process:
                    if args.prompt_type == "simf":
                        print(f"Loaded neighbor-based model for SimF evaluation: {save_path}")
                    else:
                        print(f"Loaded existing attack-specific model: {save_path}")
            except Exception as e:
                if accelerator.is_main_process:
                    model_type = "neighbor-based" if args.prompt_type == "simf" else "attack"
                    print(f"Failed to load existing {model_type} model: {e}")
                    print("Will proceed with training...")
                should_train = True
        else:
            should_train = True
            if accelerator.is_main_process:
                model_type = "neighbor-based" if args.prompt_type == "simf" else "attack"
                print(f"No existing {model_type} model found, will train: {save_path}")
    
    # Prepare attacked data
    if not is_inductive:
        # Transductive setting
        graph_data = load_atk_graph_dataset(dataset_name=args.dataset, device=device, atk_meta_info=atk_meta_info, re_split=args.re_split)
        train_contents = prepare_graph_instruction_tuning_data_atk(graph_data, "train", args.dataset, args.prompt_type, args.maximum_neighbor)
        val_contents = prepare_graph_instruction_tuning_data_atk(graph_data, "val", args.dataset, args.prompt_type, args.maximum_neighbor) 
        test_contents = prepare_graph_instruction_tuning_data_atk(graph_data, "test", args.dataset, args.prompt_type, args.maximum_neighbor)
    
    else:
        # Inductive setting (includes auto prompt special handling)
        # Load attacked data
        if accelerator.is_main_process:
            print(f"Loading inductive attack data for {args.dataset} with attack: {args.attack}")
            print(f"Attack metadata: {atk_meta_info}")
        
        graph_data, (train_data, val_data, test_data) = load_inductive_atk_graph_dataset(dataset_name=args.dataset, device=device, atk_meta_info=atk_meta_info, re_split=args.re_split, seed=args.seed)
        # Load roberta embeddings for simf and auto prompts (required for their functionality)
        if args.prompt_type in ['simf', 'auto']:
            PATH = "/path/to/GraphAD_data"
            node_feat = torch.load(f"{PATH}/datasets/roberta/{args.dataset}.pt", map_location=device, weights_only=False).to(device).type(torch.float)
            graph_data.x = node_feat
            if accelerator.is_main_process:
                print(f"{args.prompt_type.upper()}: Loaded roberta embeddings for {args.dataset} (shape: {node_feat.shape})")

        if accelerator.is_main_process:
            print(f"Loaded attack data successfully!")
            print(f"Test data edge_index shape: {test_data.edge_index.shape}")
            print(f"Test data num_nodes: {test_data.num_nodes}")
            print(f"Test data num_edges: {test_data.edge_index.shape[1] if test_data.edge_index.numel() > 0 else 0}")
            
            # Load clean data for comparison
            try:
                print("Loading clean data for comparison...")
                clean_graph_data, (clean_train, clean_val, clean_test) = load_inductive_graph_dataset(
                    dataset_name=args.dataset, device=device, re_split=args.re_split, seed=args.seed
                )
                clean_edges = clean_test.edge_index.shape[1] if clean_test.edge_index.numel() > 0 else 0
                attack_edges = test_data.edge_index.shape[1] if test_data.edge_index.numel() > 0 else 0
                
                print(f"Clean test data num_edges: {clean_edges}")
                print(f"Attack test data num_edges: {attack_edges}")
                
                if clean_edges == attack_edges:
                    print("⚠️  WARNING: Clean and attack data have same number of edges!")
                    print("This suggests attack might not be applied correctly.")
                else:
                    edge_diff = abs(clean_edges - attack_edges)
                    edge_diff_ratio = edge_diff / clean_edges if clean_edges > 0 else 0
                    print(f"✅ Edge difference: {edge_diff} ({edge_diff_ratio:.2%})")
                    print(f"Expected perturbation rate: {args.ptb_rate:.1%}")
                    
                del clean_graph_data, clean_train, clean_val, clean_test
                
            except Exception as e:
                print(f"Could not load clean data for comparison: {e}")
        
        # For inductive setting, only prepare test data since we're not training
        # Train and val data preparation is skipped to avoid unnecessary noise addition
        if accelerator.is_main_process:
            print("Inductive setting: Skipping train/val data preparation, only preparing test data")
        
        test_contents = prepare_graph_instruction_tuning_data_atk(test_data, "test", args.dataset, args.prompt_type, args.maximum_neighbor, inductive=True, 
                                                              full_graph_data=graph_data)
        
        train_contents = []
        val_contents = []
    
    # Create training datasets (minimal for inductive, full for transductive)
    if is_inductive:
        # For inductive, create minimal dummy datasets since we won't train

        dummy_sample = {
            "input_ids": torch.tensor([[0]]),
            "attention_mask": torch.tensor([[1]]),
            "labels": torch.tensor([[-100]])
        }
        train_dataset = TextDataset(dummy_sample)
        val_dataset = TextDataset(dummy_sample)
        
        if accelerator.is_main_process:
            print("Created minimal dummy datasets for inductive setting (no training)")
    else:
        # For transductive, create full datasets as usual
        train_encodings = tokenizer_instruction_tuning_data(
            train_contents, tokenizer, args.dataset, data_args.prompt_type, args.llm,
            max_txt_length=data_args.max_txt_length, 
            max_origin_txt_length=data_args.max_origin_txt_length, 
            max_ans_length=data_args.max_ans_length
        )
        train_dataset = TextDataset(train_encodings) 
        
        val_encodings = tokenizer_instruction_tuning_data(
            val_contents, tokenizer, args.dataset, data_args.prompt_type, args.llm,
            max_txt_length=data_args.max_txt_length, 
            max_origin_txt_length=data_args.max_origin_txt_length, 
            max_ans_length=data_args.max_ans_length
        )
        val_dataset = TextDataset(val_encodings)
    
    # Get evaluation steps
    eval_steps = get_eval_steps(args.dataset, args.re_split)
    if args.prompt_type in ["auto"]:
        eval_steps *= 2
    
    training_args = HFTrainingArguments(
        output_dir=save_dir,
        learning_rate=1e-4,                           
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size * 2,
        num_train_epochs=args.num_epoch,
        weight_decay=0.01,
        eval_strategy=IntervalStrategy.STEPS,
        eval_steps=eval_steps,
        save_steps=eval_steps,
        save_total_limit=1,
        load_best_model_at_end=True,
        report_to="none",
        warmup_ratio=0.1,                            
        lr_scheduler_type="cosine",                   
        optim="adamw_torch",                         
        bf16=True,                                   
        ddp_find_unused_parameters=False,
    )
    
    # Setup trainer
    trainer = Trainer(
        model=ft_model, 
        args=training_args,
        train_dataset=train_dataset, 
        eval_dataset=val_dataset,  
        tokenizer=tokenizer,
        data_collator=None,
    )
    
    # Prepare everything for distributed training
    trainer = accelerator.prepare(trainer)
    
    # Conditional training
    if should_train:
        # Train
        st_time = time.time()
        trainer.train()
        train_secs = time.time() - st_time
        
        # Save the final model
        if accelerator.is_main_process:
            # Save the LoRA adapter to the unified save_path
            trainer.model.save_pretrained(save_path)
            
            # Save training info
            training_info = {
                "dataset": args.dataset,
                "model": args.llm,
                "prompt_type": args.prompt_type,
                "seed": args.seed,
                "train_time_minutes": train_secs/60,
                "lora_adapter_path": save_path,
                "base_model_path": llm_path,
                "attack": args.attack,
                "ptb_rate": args.ptb_rate,
                "atk_emb_type": args.atk_emb_type,
                "atk_seed": args.atk_seed
            }
    else:
        pass
    
    # Run inference if not skipped
    if not args.skip_inference:
        # Load attacked node information for detailed analysis (only for gpt text attacks)
        attacked_node_ids = None
        if args.attack == 'gpt' and args.atk_type == 'text':
            try:
                import json
                from common import ATKG_PATH
                
                if is_inductive:
                    # Inductive setting - load attacked node IDs
                    atk_path = f"{ATKG_PATH}/{args.dataset}/llm_gpt-4o-mini_inductive/attacked_texts_seed{args.atk_seed}_ptb{int(args.ptb_rate*100)}.json"

                if os.path.exists(atk_path):
                    with open(atk_path, 'r') as f:
                        attacked_data = json.load(f)
                    
                    if isinstance(attacked_data, dict) and "attacked_texts" in attacked_data:
                        attacked_texts_data = attacked_data["attacked_texts"]
                        attacked_node_ids = set(int(item['node_id']) for item in attacked_texts_data)
                    else:
                        attacked_node_ids = set(int(k) for k in attacked_data.keys())
                    
                    if accelerator.is_main_process:
                        print(f"GPT Attack Analysis: Loaded {len(attacked_node_ids)} attacked node IDs from {atk_path}")
                else:
                    if accelerator.is_main_process:
                        print(f"Warning: Attack file not found: {atk_path}")
            except Exception as e:
                if accelerator.is_main_process:
                    print(f"Failed to load attacked node information: {e}")
                attacked_node_ids = None
        # Define correct file suffix and setting name for different re_split values
        re_split_str, setting_name = get_setting_info(args.re_split)
        
        st_time = time.time()
        
        # Use different inference functions based on prompt type
        if args.prompt_type in ["auto"]:
            # For auto prompt, use graph_data which must have embeddings loaded
            # Verify embeddings are present for auto prompt
            if not hasattr(graph_data, 'x') or graph_data.x is None:
                raise ValueError(f"Auto prompt requires node embeddings but graph_data.x is not available. "
                               f"Embeddings should be loaded from {PATH}/datasets/roberta/{args.dataset}.pt")
            pred_labels, gt_labels, results = run_auto_inference_with_structure(
                trainer, tokenizer, test_contents, args.dataset, data_args, accelerator, args.llm, graph_data, args.prompt_type
            )
        else:
            pred_labels, gt_labels, results = run_inference(
                trainer, tokenizer, test_contents, args.dataset, data_args, accelerator, args.llm
            )
        
        inference_secs = time.time() - st_time
        
        # Compute metrics
        acc, macro_f1, weight_f1 = compute_acc_and_f1(pred_labels, gt_labels)
        
        # Calculate unknown predictions ratio
        unknown_count = sum(1 for pred in pred_labels if pred == UNKNOW)
        unknown_ratio = unknown_count / len(pred_labels) if len(pred_labels) > 0 else 0
        
        if accelerator.is_main_process:
            print(f"Attack Results - {args.attack} (ptb_rate={args.ptb_rate}):")
            print(f"Accuracy: {acc:.4f}")
            print(f"Macro F1-Score: {macro_f1:.4f}")
            print(f"Weighted F1-Score: {weight_f1:.4f}")
            print(f"Inference time: {inference_secs:.2f} seconds")
            print(f"Unknown predictions: {unknown_count}/{len(pred_labels)} ({unknown_ratio:.2%})")
            
            # For auto prompt, also log attack detection statistics
            if args.prompt_type in ["auto"]:
                text_attacked_detected = sum(1 for r in results if r.get("stage1_pred") == "text_attacked")
                stage2_used = sum(1 for r in results if r.get("prediction_source") == "text_recovery")
                print(f"Auto-specific stats:")
                print(f"  - Nodes detected as text_attacked: {text_attacked_detected}")
                print(f"  - Nodes using stage2 recovery: {stage2_used}")
                
                # Additional analysis for GPT text attacks
                if args.attack == 'gpt' and args.atk_type == 'text' and attacked_node_ids is not None:
                    print(f"\n=== GPT Text Attack Detailed Analysis ===")
                    
                    # Prepare degree information based on inductive setting
                    if is_inductive:
                        # Inductive setting - use test_data
                        edge_index = test_data.edge_index.cpu()
                        node_degrees = torch.zeros(test_data.num_nodes, dtype=torch.long)
                        for i in range(test_data.num_nodes):
                            node_degrees[i] = (edge_index[0] == i).sum() + (edge_index[1] == i).sum()
                        
                        # Inductive setting: attacked_node_ids are global IDs, need to map to local test indices
                        test_attacked_nodes = []
                        test_non_attacked_nodes = []
                        
                        for i, result in enumerate(results):
                            node_id = result["idx"]
                            if node_id in attacked_node_ids:
                                test_attacked_nodes.append(i)
                            else:
                                test_non_attacked_nodes.append(i)
                    else:
                        # Transductive setting: use graph_data
                        edge_index = graph_data.edge_index.cpu()
                        node_degrees = torch.zeros(graph_data.num_nodes, dtype=torch.long)
                        for i in range(graph_data.num_nodes):
                            node_degrees[i] = (edge_index[0] == i).sum() + (edge_index[1] == i).sum()
                        
                        # Transductive setting: use node IDs directly
                        test_mask = graph_data.test_mask
                        test_node_indices = test_mask.nonzero(as_tuple=False).squeeze(-1).cpu().numpy().tolist()
                        
                        test_attacked_nodes = [i for i, node_id in enumerate(test_node_indices) 
                                             if node_id in attacked_node_ids]
                        test_non_attacked_nodes = [i for i, node_id in enumerate(test_node_indices) 
                                                 if node_id not in attacked_node_ids]
                    
                    # 1. Check how many attacked nodes are correctly identified as text_attacked
                    correctly_detected_attacked = []
                    missed_attacked = []
                    
                    for i in test_attacked_nodes:
                        if i < len(results) and results[i].get("stage1_pred") == "text_attacked":
                            correctly_detected_attacked.append(i)
                        else:
                            missed_attacked.append(i)
                    
                    # 2. Check how many attacked nodes are correctly recovered in stage2
                    correctly_recovered_attacked = []
                    
                    for i in correctly_detected_attacked:
                        if i < len(results):
                            # Check if node was recovered via stage2 and if prediction is correct
                            final_pred = results[i].get("final_pred", "")
                            ground_truth = results[i].get("ground-truth", "")
                            prediction_source = results[i].get("prediction_source", "")
                            
                            if prediction_source == "text_recovery" and final_pred == ground_truth:
                                correctly_recovered_attacked.append(i)
                    
                    print(f"GPT Attack Detection Results:")
                    print(f"  - Total nodes with GPT text attacks: {len(test_attacked_nodes)}")
                    print(f"  - Correctly detected as text_attacked (Stage 1): {len(correctly_detected_attacked)}/{len(test_attacked_nodes)} ({len(correctly_detected_attacked)/len(test_attacked_nodes)*100:.1f}%)")
                    print(f"  - Missed text attacks: {len(missed_attacked)}/{len(test_attacked_nodes)} ({len(missed_attacked)/len(test_attacked_nodes)*100:.1f}%)")
                    print(f"  - Correctly recovered after detection (Stage 2): {len(correctly_recovered_attacked)}/{len(correctly_detected_attacked)} ({len(correctly_recovered_attacked)/len(correctly_detected_attacked)*100:.1f}% of detected)")
                    
                    # 3. Degree analysis: attacked vs non-attacked nodes
                    if len(test_attacked_nodes) > 0 and len(test_non_attacked_nodes) > 0:
                        # Get degrees for attacked and non-attacked test nodes
                        if is_inductive:
                            # Inductive setting
                            attacked_degrees = [node_degrees[i].item() for i in test_attacked_nodes if i < len(node_degrees)]
                            non_attacked_degrees = [node_degrees[i].item() for i in test_non_attacked_nodes if i < len(node_degrees)]
                        else:
                            # Transductive setting
                            test_node_indices = graph_data.test_mask.nonzero(as_tuple=False).squeeze(-1).cpu().numpy().tolist()
                            attacked_degrees = [node_degrees[test_node_indices[i]].item() for i in test_attacked_nodes]
                            non_attacked_degrees = [node_degrees[test_node_indices[i]].item() for i in test_non_attacked_nodes]
                        
                        if attacked_degrees and non_attacked_degrees:
                            attacked_deg_mean = sum(attacked_degrees) / len(attacked_degrees)
                            non_attacked_deg_mean = sum(non_attacked_degrees) / len(non_attacked_degrees)
                            
                            print(f"\nDegree Distribution Analysis:")
                            print(f"  - Attacked nodes avg degree: {attacked_deg_mean:.2f} (min: {min(attacked_degrees)}, max: {max(attacked_degrees)})")
                            print(f"  - Non-attacked nodes avg degree: {non_attacked_deg_mean:.2f} (min: {min(non_attacked_degrees)}, max: {max(non_attacked_degrees)})")
                            print(f"  - Degree difference (attacked - non-attacked): {attacked_deg_mean - non_attacked_deg_mean:.2f}")
                    
                    # 4. Degree analysis: correctly detected vs missed attacked nodes
                    if len(correctly_detected_attacked) > 0 and len(missed_attacked) > 0:
                        if is_inductive:
                            # Inductive setting
                            detected_degrees = [node_degrees[i].item() for i in correctly_detected_attacked if i < len(node_degrees)]
                            missed_degrees = [node_degrees[i].item() for i in missed_attacked if i < len(node_degrees)]
                        else:
                            # Transductive setting
                            test_node_indices = graph_data.test_mask.nonzero(as_tuple=False).squeeze(-1).cpu().numpy().tolist()
                            detected_degrees = [node_degrees[test_node_indices[i]].item() for i in correctly_detected_attacked]
                            missed_degrees = [node_degrees[test_node_indices[i]].item() for i in missed_attacked]
                        
                        if detected_degrees and missed_degrees:
                            detected_deg_mean = sum(detected_degrees) / len(detected_degrees)
                            missed_deg_mean = sum(missed_degrees) / len(missed_degrees)
                            
                            print(f"\nStage 1 Detection vs Degree Analysis:")
                            print(f"  - Correctly detected attacked nodes avg degree: {detected_deg_mean:.2f}")
                            print(f"  - Missed attacked nodes avg degree: {missed_deg_mean:.2f}")
                            print(f"  - Degree difference (detected - missed): {detected_deg_mean - missed_deg_mean:.2f}")
                    
                    print(f"=" * 50)
            
            # Save results
            attack_dir = f"attack_{args.attack}_ptb{int(args.ptb_rate*100)}"
            write_dir = f"{alg_dir}/prediction/{attack_dir}"
            os.makedirs(write_dir, exist_ok=True)
            result_file = f"{write_dir}/{args.dataset}_{args.llm}_{args.prompt_type}{re_split_str}_seed{args.seed}.json"
            
            with open(result_file, "w") as f:
                for result in results:
                    result["attack"] = args.attack
                    result["ptb_rate"] = args.ptb_rate
                    result["atk_emb_type"] = args.atk_emb_type
                    result["atk_seed"] = args.atk_seed
                    result["atk_type"] = args.atk_type
                    result["model"] = args.llm
                    result["prompt_type"] = args.prompt_type
                    result["data_seed"] = args.seed
            
            print(f"Attack results saved to {result_file}")
            
            # Save summary
            summary_file = f"{alg_dir}/summary_attack.csv"
            with open(summary_file, 'a', newline='') as file:
                writer = csv.writer(file)
                writer.writerow([
                    args.dataset, args.llm, args.prompt_type, acc, macro_f1, weight_f1, 
                    setting_name, f"Seed-{args.seed}", 
                    f"Batch Size-{args.batch_size}", f"Epoch-{args.num_epoch}",
                    f"Train Minutes-{train_secs/60:.3f}", f"Inference Seconds-{inference_secs:.2f}",
                    args.attack, args.ptb_rate, args.atk_emb_type, args.atk_seed
                ])
    
    # Clean up
    accelerator.wait_for_everyone()
    accelerator.end_training()


# Script entry point: invoke main() only when this file is executed directly,
# so importing the module does not trigger a run.
if __name__ == "__main__":
    main() 