"""
WTGIA (Word-level Text Graph Injection Attack) for inductive setting
Implementation based on LLMGIAv2-crc FLIPGIA
"""

import os
import sys
import argparse
import torch
import numpy as np
import json
from datetime import datetime
import torch.nn.functional as F

# Add parent directory to path
sys.path.append("../")

# Third-party imports  
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

# Local imports
from common import set_seed, load_inductive_graph_dataset_for_gnn
from wtgia_utils import (
    EarlyStop, init_feat, fgsm_update_features,
    random_injection, tdgia_injection, atdgia_injection, random_class_injection,
    atdgia_ranking_select, avg_sparsity, adj_to_tensor, tensor_to_adj
)
from wtgia_text_utils import generate_wtgia_texts, save_wtgia_words

# Local Llama model path
LLAMA_MODEL_PATH = "/data/LLMBackbone/llama-3.1-8B-Instruct/"


class WTGIA:
    """
    WTGIA (Word-level Text Graph Injection Attack)
    Based on FLIPGIA from LLMGIAv2-crc for BoW features

    Nodes are injected in sequential batches. After each batch the edges of
    the new nodes are chosen by the configured injection strategy, and the
    features of *all* injected nodes are re-optimised by ``feat_upd_func``.
    """

    # Strategies that attack() dispatches on explicitly; any other value
    # is handled by random injection.
    _KNOWN_INJECTIONS = ("random", "tdgia", "atdgia")

    def __init__(self, epsilon=0.01, n_epoch=100, a_epoch=50, n_inject_max=60, n_edge_max=20,
                 feat_lim_min=0, feat_lim_max=1, loss=F.nll_loss, device='cpu',
                 early_stop=0, verbose=True, sequential_step=0.1,
                 feat_upd='flip', sp_level=0.05, batch_size=1,
                 injection="random", branching=False, iter_epoch=2, agia_pre=0.5):
        """
        Store the attack configuration.

        Args:
            epsilon, n_epoch, a_epoch, loss, iter_epoch, agia_pre: stored on
                the instance; not read by this class itself (the instance is
                handed to ``feat_upd_func``, which may read them).
            n_inject_max: total number of nodes to inject (coerced to int).
            n_edge_max: number of edges per injected node (coerced to int,
                because callers may compute it with float arithmetic).
            feat_lim_min, feat_lim_max: feature bounds forwarded to ``init_feat``.
            device: device the model is moved to and passed to the helpers.
            early_stop: patience for ``EarlyStop``; a falsy value disables it.
            verbose: print progress information.
            sequential_step: fraction of ``n_inject_max`` injected per round.
            feat_upd: accepted for interface compatibility; not used here.
            sp_level: sparsity budget; ``0`` means "derive from the clean
                features" (resolved lazily in ``attack``).
            batch_size: forwarded to ``feat_upd_func``.
            injection: injection strategy name (case-insensitive).
            branching: select a subset of targets per round with
                ``atdgia_ranking_select`` instead of using all targets.
        """
        self.sequential_step = sequential_step
        self.device = device
        self.epsilon = epsilon
        self.n_epoch = n_epoch
        self.a_epoch = a_epoch
        # Counts are used for tensor sizes and degree budgets, so force ints.
        self.n_inject_max = int(n_inject_max)
        self.n_edge_max = int(n_edge_max)
        self.feat_lim_min = feat_lim_min
        self.feat_lim_max = feat_lim_max
        self.loss = loss
        self.verbose = verbose
        self.sp_level = sp_level
        self.batch_size = batch_size
        self.injection = injection.lower()
        self.branching = branching
        self.iter_epoch = iter_epoch
        self.agia_pre = agia_pre

        # Early stop
        if early_stop:
            self.early_stop = EarlyStop(patience=early_stop, epsilon=1e-4)
        else:
            self.early_stop = early_stop

        # Feature update function (FGSM for BoW)
        self.feat_upd_func = fgsm_update_features

    def attack(self, model, adj, features, target_idx, labels=None):
        """
        Execute WTGIA attack

        Args:
            model: module called as ``model(features, adj)``.
            adj: adjacency of the clean graph, passed to the model and to the
                injection helpers.
            features: node features of the clean graph.
            target_idx: indices of the target nodes.
            labels: optional ground-truth labels; when ``None`` the model's
                own predictions on the clean graph are used instead.

        Returns:
            ``(adj_attack, features_attack)``: the adjacency including the
            injected nodes and the features of the injected nodes.
            ``features_attack`` is ``None`` when ``n_inject_max`` is 0.
        """
        model.to(self.device)
        model.eval()

        # The clean prediction is only used for label/target selection,
        # so no gradient graph is needed.
        with torch.no_grad():
            pred_orig = model(features, adj)
        if labels is None:
            origin_labels = torch.argmax(pred_orig, dim=1)
        else:
            origin_labels = labels.view(-1)
        # Later rounds feed softmax probabilities to the selection helpers;
        # normalise the first round the same way so all rounds are comparable.
        initial_pred = F.softmax(pred_orig, dim=1)

        # Initialize sparsity level if not set
        if self.sp_level == 0:
            self.sp_level = avg_sparsity(features)

        n_edge_max = int(self.n_edge_max)

        # Initialize adjacency degrees for injected nodes (integer budget per node)
        self.adj_degs = torch.full((self.n_inject_max,), n_edge_max, dtype=torch.long)

        if self.injection not in self._KNOWN_INJECTIONS:
            print(f"Warning: injection strategy '{self.injection}' is not implemented "
                  f"in WTGIA; falling back to random injection")

        # Sequential injection
        n_inject_total = 0
        adj_attack = adj
        features_attack = None

        tot_target_nodes = len(target_idx)

        print(f"Starting sequential injection: {self.n_inject_max} nodes, {n_edge_max} edges each")

        while n_inject_total < self.n_inject_max:
            # Current prediction for target selection
            if n_inject_total > 0:
                with torch.no_grad():
                    current_pred = F.softmax(
                        model(torch.cat((features, features_attack), dim=0), adj_attack), dim=1)
            else:
                current_pred = initial_pred

            # Determine injection batch size: at least one node per round,
            # never more than what is left of the budget.
            n_inject_cur = min(self.n_inject_max - n_inject_total,
                               max(1, int(self.n_inject_max * self.sequential_step)))
            n_target_cur = min(tot_target_nodes,
                               max(n_inject_cur * (n_edge_max + 1),
                                   int(tot_target_nodes * self.sequential_step)))

            # Select current targets
            if self.branching:
                cur_target_idx = atdgia_ranking_select(adj_attack, n_inject_cur, n_edge_max,
                                                       origin_labels, current_pred, target_idx,
                                                       ratio=n_target_cur / tot_target_nodes)
            else:
                cur_target_idx = target_idx

            if self.verbose:
                print(f"Sequential inject {n_inject_total + n_inject_cur}/{self.n_inject_max} nodes, "
                      f"target {len(cur_target_idx)}/{len(target_idx)} nodes")

            # Inject edges using specified strategy
            if self.injection == "tdgia":
                adj_attack = tdgia_injection(adj_attack, n_inject_cur, n_edge_max,
                                             origin_labels, current_pred, cur_target_idx, self.device)
            elif self.injection == "atdgia":
                adj_attack = atdgia_injection(adj_attack, n_inject_cur, n_edge_max,
                                              origin_labels, current_pred, cur_target_idx, self.device)
            else:
                # Default to random injection
                adj_attack = random_injection(adj_attack, n_inject_cur, n_edge_max,
                                              cur_target_idx, self.device)

            # Initialize features for injected nodes
            features_attack_new = init_feat(n_inject_cur, features, self.device, style="zeros",
                                            feat_lim_min=self.feat_lim_min, feat_lim_max=self.feat_lim_max)

            # Concatenate with existing injected features
            if features_attack is None:
                features_attack = features_attack_new
            else:
                features_attack = torch.cat((features_attack, features_attack_new), dim=0)

            n_inject_total += n_inject_cur

            # Update all injected features using FGSM. feat_upd_func is a plain
            # function stored on the instance, so `self` is passed explicitly.
            features_attack = self.feat_upd_func(
                self, model, adj_attack, features, features_attack, origin_labels, target_idx,
                sparsity_budget=self.sp_level,
                batch_size=self.batch_size, verbose=self.verbose
            )

        # Nothing to report when no node was injected.
        if self.verbose and features_attack is not None:
            print(f"Final sparsity: {avg_sparsity(features_attack):.4f}, target: {self.sp_level:.4f}")

        return adj_attack, features_attack


def parse_arguments():
    """Build the WTGIA command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    cli = argparse.ArgumentParser(description="WTGIA attack for inductive node classification")

    def option(flag, kind, default, text, **extra):
        # Typed option that carries a value.
        cli.add_argument(flag, type=kind, default=default, help=text, **extra)

    def switch(flag, text):
        # Boolean flag: False unless given on the command line.
        cli.add_argument(flag, action='store_true', help=text)

    # Data paths
    option("--root_path", str, "/path/to/GraphAD_data", "Root path for data")
    option("--save_dir", str, "/path/to/GraphAD_data/atkg", "Directory to save attack results")

    # Dataset configuration
    option("--dataset", str, "cora", "Dataset name")
    option("--emb_type", str, "bow", "Embedding type")

    # Attack parameters
    option("--n_inject", int, 60, "Number of nodes to inject")
    option("--n_edges", int, 20, "Number of edges per injected node")
    option("--ptb_rate", float, 0.20, "Target node selection rate")

    # WTGIA-specific parameters
    option("--epsilon", float, 0.01, "Step size for FGSM")
    option("--n_epoch", int, 100, "Number of epochs for feature optimization")
    option("--sequential_step", float, 0.2, "Sequential injection step size")
    option("--sp_level", float, 0.05, "Sparsity level for BoW features")
    option("--batch_size", int, 1, "Batch size for feature flipping")
    switch("--branching", "Use branching target selection")
    option("--injection", str, "random", "Injection strategy",
           choices=["random", "tdgia", "atdgia", "meta", "agia"])

    # Training configuration
    option("--seeds", int, 3, "Number of random seeds")
    option("--re_split", int, 2, "Data split configuration")
    option("--patience", int, 15, "Early stopping patience")
    option("--max_epochs", int, 200, "Maximum training epochs")

    # System configuration
    option("--device", int, 0, "GPU device ID")
    switch("--verbose", "Verbose output")

    # Evaluation parameters
    switch("--eval_robo", "Evaluate robustness")
    switch("--batch_eval", "Batch evaluation")
    option("--runs", int, 10, "Number of evaluation runs")

    # Text generation parameters
    option("--llm_model", str, "llama-3.1-8B", "LLM model for text generation")
    option("--model_path", str, LLAMA_MODEL_PATH, "Path to local Llama model")
    option("--text_save_dir", str, None, "Directory to save generated texts (default: auto)")

    return cli.parse_args()


def _evaluate_with_injection(trainer, test_data, injected_features, attacked_edge_index, test_mask):
    """Return the test accuracy on ``test_data`` extended with injected nodes.

    Args:
        trainer: trainer whose ``evaluate(data, mask=...)`` returns a dict with ``'acc'``.
        test_data: clean test graph; it is cloned, not modified.
        injected_features: features of the injected nodes, appended after the clean nodes.
        attacked_edge_index: edge index of the graph including the injected nodes.
        test_mask: boolean mask over the clean nodes selecting the evaluated nodes.
    """
    n_inject = injected_features.size(0)

    attacked_data = test_data.clone()
    attacked_data.x = torch.cat([test_data.x, injected_features.to(test_data.x.device)], dim=0)
    attacked_data.edge_index = attacked_edge_index
    # Placeholder labels for the injected nodes; new_zeros keeps the dtype
    # and device of the existing labels so the concatenation cannot mismatch.
    attacked_data.y = torch.cat([attacked_data.y, attacked_data.y.new_zeros(n_inject)])

    # Injected nodes are excluded from evaluation (mask padded with False,
    # on the same device as the original mask).
    extended_mask = torch.cat([test_mask, test_mask.new_zeros(n_inject)])
    return trainer.evaluate(attacked_data, mask=extended_mask)['acc']


def main():
    """Main execution function

    For every seed: train a clean GCN, run the WTGIA attack on the test
    graph, save the attacked edge index and injected features, turn the
    injected BoW features into text, re-compute BoW features from that text
    and (with ``--eval_robo``) evaluate the model on both feature variants.
    A summary over all seeds is printed at the end.
    """
    args = parse_arguments()
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    print(f"Running WTGIA attack on {args.dataset}")
    # NOTE: the per-node edge budget actually used by the attack is derived
    # from --ptb_rate further below; --n_edges is only reported here.
    print(f"Injecting {args.n_inject} nodes with {args.n_edges} edges each")
    print(f"Using {args.emb_type} embeddings with sparsity level {args.sp_level}")

    # Create save directory
    os.makedirs(args.save_dir, exist_ok=True)

    # Configure seeds
    if args.dataset == 'arxiv':
        seeds = [0]
        args.re_split = 0
    else:
        seeds = range(args.seeds)

    all_clean_accs = []
    all_attack_accs = []

    # Reuse of previously saved attack results is switched off: the attack
    # is always re-run. Set to True to load existing results instead.
    reuse_cached_attack = False

    ptb_tag = int(args.ptb_rate*100)
    vocab_path = f"{args.root_path}/datasets/vocab/{args.dataset}/bow_vocabulary.pkl"

    for seed in seeds:
        print(f"\n--- Seed {seed} ---")
        set_seed(seed)

        # Ensure BOW embeddings exist
        if args.emb_type == "bow":
            bow_path = f"{args.root_path}/datasets/{args.emb_type}/{args.dataset}.pt"

            if not os.path.exists(bow_path) or not os.path.exists(vocab_path):
                print("BOW embeddings not found, please run:")
                print(f"cd LLMEncoder/GNN && python embedding.py --dataset {args.dataset} --encoder_name bow")
                print("Exiting...")
                return

        # Load data
        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split,
            path_prefix=args.root_path, emb_model=args.emb_type, seed=seed
        )

        train_data = train_data.to(device)
        val_data = val_data.to(device)
        test_data = test_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        num_test_nodes = full_data.test_mask.sum().item()

        print(f"Dataset: Features={num_features}, Classes={num_classes}, Test nodes={num_test_nodes}")

        # Train clean model
        print("Training clean model...")
        model = GCN(num_features, num_classes, hids=[64])
        trainer = Trainer(model, device=device, verbose=0)
        trainer.reset_optimizer(lr=0.01)

        ckp = ModelCheckpoint(f'wtgia_model_{args.dataset}_{seed}.pth', monitor='val_acc')
        trainer.fit((train_data, val_data), verbose=0, callbacks=[ckp], epochs=args.max_epochs)

        # Evaluate clean performance
        clean_acc = trainer.evaluate(test_data, mask=full_data.test_mask)['acc']
        all_clean_accs.append(clean_acc)
        print(f"Clean accuracy: {clean_acc:.4f}")

        # Use all test nodes as targets (following LLMNodeBed convention)
        target_nodes = torch.where(full_data.test_mask)[0]
        print(f"Using all {len(target_nodes)} test nodes as targets")

        # Output locations - use consistent naming for evaluation
        attack_name = f"wtgia_{args.injection}"
        save_path = f"{args.save_dir}/{args.dataset}/{attack_name}/{args.emb_type}_{ptb_tag}_{seed}.pt"
        features_save_path = f"{args.save_dir}/{args.dataset}/{attack_name}_features/{args.emb_type}_{ptb_tag}_{seed}.pt"
        texts_save_path = f"{args.save_dir}/{args.dataset}/{attack_name}_texts/{args.llm_model}_{ptb_tag}_{seed}.json"

        if reuse_cached_attack and os.path.exists(save_path) and os.path.exists(features_save_path):
            print("✓ Attack results already exist, loading existing results...")
            # Saved tensors live on CPU; bring them back to the working device.
            attacked_edge_index = torch.load(save_path, map_location=device).to(device)
            features_attack = torch.load(features_save_path, map_location=device).to(device)
            print(f"✓ Loaded attacked edge_index from {save_path}")
            print(f"✓ Loaded injected features from {features_save_path}")
        else:
            # Initialize WTGIA attacker
            print("Attack results not found, executing WTGIA attack...")
            # Edge budget per injected node: ptb_rate of the edge_index
            # columns, halved, split evenly over the injected nodes. ptb_rate
            # is a float, so the floor division yields a float; an edge count
            # must be an int.
            n_edges = int(test_data.edge_index.shape[1] * args.ptb_rate / 2 // args.n_inject)
            print(n_edges)
            attacker = WTGIA(
                epsilon=args.epsilon,
                n_epoch=args.n_epoch,
                n_inject_max=args.n_inject,
                n_edge_max=n_edges,
                feat_lim_min=0,
                feat_lim_max=1,
                device=device,
                verbose=args.verbose,
                sequential_step=args.sequential_step,
                sp_level=args.sp_level,
                batch_size=args.batch_size,
                branching=args.branching,
                injection=args.injection
            )

            # Execute attack
            print("Executing WTGIA attack...")
            # Convert edge_index to SparseTensor for WTGIA
            from torch_sparse import SparseTensor
            row, col = test_data.edge_index
            adj_sparse = SparseTensor(row=row, col=col,
                                      sparse_sizes=(test_data.x.size(0), test_data.x.size(0)))
            adj_attack, features_attack = attacker.attack(
                model, adj_sparse, test_data.x, target_nodes, full_data.y
            )

            # Save attack results (following LLMNodeBed convention)
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            # Convert adj_attack to edge_index format if needed
            if hasattr(adj_attack, 'coo'):
                # SparseTensor format -> edge_index
                row, col, _ = adj_attack.coo()
                attacked_edge_index = torch.stack([row, col], dim=0)
            else:
                attacked_edge_index = adj_attack

            print(attacked_edge_index.shape, test_data.edge_index.shape)
            # Save only edge_index like other structure attacks
            torch.save(attacked_edge_index.cpu(), save_path)
            print(f"✓ Saved attacked edge_index to {save_path}")

            # Also save features for WTGIA evaluation
            os.makedirs(os.path.dirname(features_save_path), exist_ok=True)
            torch.save(features_attack.cpu(), features_save_path)
            print(f"✓ Saved injected features to {features_save_path}")

        # Generate natural text from BoW features and handle re-computation
        print("\nGenerating natural text from BoW features...")

        # Check if text already exists
        recomputed_save_path = f"{args.save_dir}/{args.dataset}/{attack_name}_recomputed_features/{args.emb_type}_{ptb_tag}_{seed}.pt"
        # True when the re-computed BoW features must be (re)built, either
        # because the texts were just generated or no saved features exist.
        text_needs_regeneration = False

        if os.path.exists(texts_save_path):
            print(f"✓ Generated texts already exist at {texts_save_path}")
            with open(texts_save_path, 'r', encoding='utf-8') as f:
                text_data = json.load(f)
                generated_texts = text_data['texts']
            print(f"✓ Loaded {len(generated_texts)} existing texts")

            # Check if embeddings need re-computation (this happens when text is re-generated)
            if not os.path.exists(recomputed_save_path):
                text_needs_regeneration = True
                print("Embeddings need to be re-computed from existing texts")
        else:
            text_needs_regeneration = True
            # Create directory for texts
            os.makedirs(os.path.dirname(texts_save_path), exist_ok=True)

            generated_texts = generate_wtgia_texts(
                features_attack=features_attack,
                dataset_name=args.dataset,
                base_path=args.root_path,
                model_path=args.model_path,
                save_dir=None  # We'll save manually
            )

            # Save texts with metadata
            text_data = {
                'texts': generated_texts,
                'dataset': args.dataset,
                'attack': attack_name,
                'llm_model': args.llm_model,
                'seed': seed,
                'n_inject': args.n_inject,
                'ptb_rate': args.ptb_rate,
                'sp_level': args.sp_level
            }

            with open(texts_save_path, 'w', encoding='utf-8') as f:
                json.dump(text_data, f, indent=2, ensure_ascii=False)

            print(f"✓ Successfully generated and saved {len(generated_texts)} texts")

            # Show example of generated text
            if len(generated_texts) > 0:
                print("\nExample generated text (Node 0):")
                print(f"'{generated_texts[0][:200]}...'")

        # Re-compute embeddings from generated texts when needed
        if text_needs_regeneration:
            print("\nRe-computing embeddings from generated texts...")

            # Import embedding computation utilities from wtgia_text_utils
            from wtgia_text_utils import compute_bow_embeddings_from_texts

            # Re-compute BoW features from generated texts with similarity analysis
            recomputed_features = compute_bow_embeddings_from_texts(
                texts=generated_texts,
                dataset_name=args.dataset,
                vocab_path=vocab_path,
                original_features=features_attack  # Pass original features for comparison
            )

            # Save re-computed features
            os.makedirs(os.path.dirname(recomputed_save_path), exist_ok=True)
            torch.save(recomputed_features.cpu(), recomputed_save_path)

            print(f"✓ Re-computed features shape: {recomputed_features.shape}")
            print(f"✓ Original features shape: {features_attack.shape}")
            print(f"✓ Saved re-computed features to {recomputed_save_path}")
        else:
            # Load existing re-computed features
            recomputed_features = torch.load(recomputed_save_path, map_location=device)
            print(f"✓ Loaded existing re-computed features from {recomputed_save_path}")
            print(f"✓ Re-computed features shape: {recomputed_features.shape}")

        # Evaluate attack (if requested)
        if args.eval_robo:
            print("\nEvaluating attack with different feature combinations...")

            # 1. Evaluate GCN with generated features (original features_attack)
            print("1. Evaluating with generated features...")
            attack_acc_generated = _evaluate_with_injection(
                trainer, test_data, features_attack, attacked_edge_index, full_data.test_mask
            )

            print(f"Attack accuracy (generated features): {attack_acc_generated:.4f}")
            print(f"Accuracy drop (generated features): {(clean_acc - attack_acc_generated):.4f}")

            # 2. Evaluate GCN with re-computed BoW features from text
            print("\n2. Evaluating with re-computed BoW features...")
            attack_acc_recomputed = _evaluate_with_injection(
                trainer, test_data, recomputed_features, attacked_edge_index, full_data.test_mask
            )

            print(f"Attack accuracy (re-computed BoW): {attack_acc_recomputed:.4f}")
            print(f"Accuracy drop (re-computed BoW): {(clean_acc - attack_acc_recomputed):.4f}")

            # Store results for summary
            all_attack_accs.append({
                'generated': attack_acc_generated,
                'recomputed': attack_acc_recomputed
            })

            # Calculate ASR for both (relative accuracy drop; 0 when clean accuracy is 0)
            asr_generated = (clean_acc - attack_acc_generated) / clean_acc if clean_acc > 0 else 0
            asr_recomputed = (clean_acc - attack_acc_recomputed) / clean_acc if clean_acc > 0 else 0

            print(f"ASR (generated features): {asr_generated * 100:.2f}%")
            print(f"ASR (re-computed BoW): {asr_recomputed * 100:.2f}%")

        # Cleanup
        del model, trainer
        torch.cuda.empty_cache()

    # Print summary
    if args.eval_robo and len(all_attack_accs) > 0:
        print(f"\n{'='*60}")
        print("WTGIA Attack Summary")
        print(f"{'='*60}")
        print(f"Clean accuracy: {np.mean(all_clean_accs):.4f} ± {np.std(all_clean_accs):.4f}")

        # Extract generated and recomputed accuracies
        generated_accs = [result['generated'] for result in all_attack_accs]
        recomputed_accs = [result['recomputed'] for result in all_attack_accs]

        print(f"Attack accuracy (generated): {np.mean(generated_accs):.4f} ± {np.std(generated_accs):.4f}")
        print(f"Attack accuracy (recomputed): {np.mean(recomputed_accs):.4f} ± {np.std(recomputed_accs):.4f}")

        # Calculate average accuracy drops
        clean_accs_array = np.array(all_clean_accs)
        generated_drops = clean_accs_array - np.array(generated_accs)
        recomputed_drops = clean_accs_array - np.array(recomputed_accs)

        print(f"Average accuracy drop (generated): {np.mean(generated_drops):.4f}")
        print(f"Average accuracy drop (recomputed): {np.mean(recomputed_drops):.4f}")

        # Calculate average ASRs; seeds with zero clean accuracy contribute 0,
        # matching the per-seed ASR computed inside the loop.
        has_clean_acc = clean_accs_array > 0
        generated_asrs = np.divide(generated_drops, clean_accs_array,
                                   out=np.zeros_like(generated_drops, dtype=float),
                                   where=has_clean_acc)
        recomputed_asrs = np.divide(recomputed_drops, clean_accs_array,
                                    out=np.zeros_like(recomputed_drops, dtype=float),
                                    where=has_clean_acc)

        print(f"Average ASR (generated): {np.mean(generated_asrs) * 100:.2f}%")
        print(f"Average ASR (recomputed): {np.mean(recomputed_asrs) * 100:.2f}%")

# Run the attack pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()