import os
import sys
import argparse
import torch
import numpy as np
import copy
import json
from datetime import datetime

# Add "../" (resolved against the current working directory) to sys.path so the local modules below can be imported
sys.path.append("../")

# Third-party imports
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

# Local imports
from common import set_seed, load_inductive_graph_dataset_for_gnn, TextEncoder
from text_attack import (
    random_perturb, 
    apply_text_attack, 
    save_results, 
    encode_attacked_texts_efficiently, 
    degree_weighted_node_selection
)


def configure_tensorflow_gpu(device_id):
    """Best-effort TensorFlow GPU setup so it can coexist with PyTorch.

    Enables memory growth on every GPU TensorFlow reports, then restricts
    TensorFlow to the GPU at index ``device_id``. When that index is past the
    end of the GPU list, all GPUs are hidden from TensorFlow instead. Every
    outcome is reported with ``print``; nothing is returned and neither a
    missing TensorFlow install nor a ``RuntimeError`` during configuration is
    propagated.

    Args:
        device_id: Index into TensorFlow's list of physical GPUs.
    """
    try:
        import tensorflow as tf

        tf_config = tf.config.experimental
        physical_gpus = tf_config.list_physical_devices('GPU')

        # Nothing to configure when TensorFlow sees no GPU at all
        if not physical_gpus:
            print("✓ TensorFlow using CPU (no GPU available)")
            return

        try:
            # Grow memory on demand on every GPU to avoid OOM
            for physical_gpu in physical_gpus:
                tf_config.set_memory_growth(physical_gpu, True)

            if device_id >= len(physical_gpus):
                # Requested GPU is not in the list: hide all GPUs from TensorFlow
                tf_config.set_visible_devices([], 'GPU')
                print("✓ TensorFlow using CPU (GPU not available)")
            else:
                # Expose only the requested GPU to TensorFlow
                tf_config.set_visible_devices(physical_gpus[device_id], 'GPU')
                print(f"✓ TensorFlow using GPU {device_id}")
        except RuntimeError as e:
            print(f"⚠ TensorFlow GPU configuration failed: {e}")

    except ImportError:
        print("⚠ TensorFlow not available")


def setup_gpu_device(device_id):
    """Choose the torch device for the run and make it the CUDA default.

    Args:
        device_id: CUDA device index; a negative value requests the CPU.

    Returns:
        ``torch.device('cpu')`` when CUDA is unavailable or ``device_id`` is
        negative, otherwise ``torch.device(f'cuda:{device_id}')``.

    Raises:
        RuntimeError: If ``device_id`` is not below the number of CUDA devices.
    """
    # CUDA availability is checked first, so its message wins over "as requested"
    cpu_reason = None
    if not torch.cuda.is_available():
        cpu_reason = "CUDA is not available. Using CPU."
    elif device_id < 0:
        cpu_reason = "Using CPU as requested."

    if cpu_reason is not None:
        print(cpu_reason)
        return torch.device('cpu')

    gpu_count = torch.cuda.device_count()
    if not device_id < gpu_count:
        raise RuntimeError(f"GPU device {device_id} is not available. Only {gpu_count} GPUs are available.")

    # Make the chosen GPU the default for subsequent CUDA operations
    torch.cuda.set_device(device_id)
    selected_device = torch.device(f'cuda:{device_id}')

    print(f"✓ Using GPU device: {device_id}")
    print(f"✓ GPU name: {torch.cuda.get_device_name(device_id)}")
    print(f"✓ Available GPUs: {gpu_count}")
    print(f"✓ Current device: {torch.cuda.current_device()}")

    # Point TensorFlow (if installed) at the same GPU index
    configure_tensorflow_gpu(device_id)

    return selected_device


def parse_arguments():
    """Build the command-line parser for the experiment and parse ``sys.argv``.

    Batch processing is on by default; pass ``--no_use_batch`` to turn it off
    (``--use_batch`` is still accepted and keeps it on).

    Returns:
        argparse.Namespace: Parsed options covering data paths, the
        dataset/attack/embedding choice, the perturbation rate, attack
        batching, training settings and the GPU device ID.
    """
    parser = argparse.ArgumentParser(description="Generate text attacks for inductive node classification")
    
    # Data paths
    parser.add_argument("--root_path", type=str, default="/path/to/GraphAD_data",
                       help="Root path for data")
    parser.add_argument("--graph_save_dir", type=str, default="/path/to/GraphAD_data/text_atkg",
                       help="Directory to save graph attack results")
    parser.add_argument("--atkg_save_dir", type=str, default="/path/to/GraphAD_data/atkg",
                       help="Directory to save attacked texts")
    
    # Dataset and attack configuration
    parser.add_argument("--dataset", type=str, default="cora",
                       help="Dataset name")
    parser.add_argument("--attack", type=str, default="random", 
                       choices=["random", "textfooler", "bae", "pwws", "hotflip"],
                       help="Attack method")
    parser.add_argument("--emb_type", type=str, default="bow", 
                       choices=["bow", "roberta", "MiniLM", "SentenceBert", "Mistral-7B"],
                       help="Text embedding type")
    
    # Attack parameters
    parser.add_argument("--ptb_rate", type=float, default=0.20,
                       help="Node perturbation rate")
    
    # Batch attack configuration
    # A store_true flag that defaults to True can never be switched off on its
    # own, so a store_false companion writes to the same destination.
    parser.add_argument("--use_batch", dest="use_batch", action="store_true", default=True,
                       help="Use batch processing for attacks (default: True)")
    parser.add_argument("--no_use_batch", dest="use_batch", action="store_false",
                       help="Disable batch processing for attacks")
    parser.add_argument("--attack_batch_size", type=int, default=8,
                       help="Batch size for attack processing")
    
    # Training configuration
    parser.add_argument("--seeds", type=int, default=3,
                       help="Number of random seeds")
    parser.add_argument("--re_split", type=int, default=2,
                       help="Data split configuration")
    parser.add_argument("--patience", type=int, default=15,
                       help="Early stopping patience")
    parser.add_argument("--epochs", type=int, default=200,
                       help="Maximum training epochs")
    parser.add_argument("--batch_size", type=int, default=4,
                       help="Batch size for text encoding")
    
    # System configuration
    parser.add_argument("--device", type=int, default=0,
                       help="GPU device ID")
    
    return parser.parse_args()


def save_attacked_texts(args, attack_global_indices, attacked_texts, seed):
    """Write one seed's attacked texts, plus validation metadata, to a JSON file.

    The file is ``<atkg_save_dir>/<dataset>/text_<attack>_inductive/
    attacked_texts_seed<seed>_ptb<int(ptb_rate*100)>.json``. It is written to a
    temporary sibling first and then moved into place, so an interrupted run
    cannot leave a truncated file at the final path.

    Args:
        args: Options object providing ``atkg_save_dir``, ``dataset``,
            ``attack``, ``emb_type`` and ``ptb_rate``.
        attack_global_indices: Node IDs of the attacked nodes; each element
            must be convertible with ``int()``.
        attacked_texts: Attacked text for each node, in the same order as
            ``attack_global_indices``.
        seed: Seed of the run, used in the file name.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"text_{args.attack}_inductive")
    os.makedirs(save_dir, exist_ok=True)

    # Include ptb_rate in filename to ensure parameter consistency
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    # Every selected node is stored, whether or not the attack changed its text
    attacked_entries = [
        {"node_id": int(node_idx), "attacked_text": new_text}
        for node_idx, new_text in zip(attack_global_indices, attacked_texts)
    ]

    # Metadata lets the loader reject files produced with different parameters
    data_to_save = {
        "metadata": {
            "ptb_rate": args.ptb_rate,
            "attack": args.attack,
            "emb_type": args.emb_type,
            "dataset": args.dataset,
            "num_attacked_nodes": len(attack_global_indices)
        },
        "attacked_texts": attacked_entries
    }

    # Write to a temporary file, then rename over the target in one step
    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(data_to_save, f, indent=4)
    os.replace(tmp_path, save_path)
    print(f"✓ Saved attacked texts to {save_path}")


def load_attacked_texts_if_exists(args, attack_global_indices, seed):
    """Return previously saved attacked texts when they match the current run.

    The file is looked up at the path ``save_attacked_texts`` writes to. It is
    accepted only if its metadata matches ``args.ptb_rate``, ``args.attack``,
    ``args.emb_type`` and the number of attacked nodes, and if it contains a
    text for every node in ``attack_global_indices``.

    Args:
        args: Options object providing ``atkg_save_dir``, ``dataset``,
            ``attack``, ``emb_type`` and ``ptb_rate``.
        attack_global_indices: Node IDs of the attacked nodes; each element
            must be convertible with ``int()``.
        seed: Seed of the run, used in the file name.

    Returns:
        The attacked texts ordered like ``attack_global_indices``, or ``None``
        when the file is missing, unreadable as JSON, in the old format,
        produced with different parameters, or incomplete. ``None`` tells the
        caller to regenerate the texts.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"text_{args.attack}_inductive")
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    try:
        with open(save_path, 'r') as f:
            saved_data = json.load(f)
    except FileNotFoundError:
        return None
    except ValueError as e:
        # Covers json.JSONDecodeError and UnicodeDecodeError, e.g. a file
        # truncated by an interrupted run; regenerate instead of crashing
        print(f"⚠ Could not parse {save_path} ({e}), will regenerate")
        return None

    # Old format without metadata, assume mismatch for safety
    if not isinstance(saved_data, dict) or "metadata" not in saved_data:
        print(f"⚠ Old format file found {save_path}, will regenerate with new format")
        return None

    # Verify key parameters match
    metadata = saved_data["metadata"]
    expected = {
        "ptb_rate": args.ptb_rate,
        "attack": args.attack,
        "emb_type": args.emb_type,
        "num_attacked_nodes": len(attack_global_indices),
    }
    if any(metadata.get(key) != value for key, value in expected.items()):
        print(f"⚠ Parameter mismatch in {save_path}, will regenerate")
        return None

    print(f"✓ Loading existing attacked texts from {save_path}")

    # Create a mapping from node_id to attacked_text
    try:
        node_to_text = {item['node_id']: item['attacked_text'] for item in saved_data["attacked_texts"]}
    except (KeyError, TypeError):
        print(f"⚠ Malformed attacked texts in {save_path}, will regenerate")
        return None

    # Extract attacked texts in the same order as attack_global_indices
    attacked_texts = []
    for idx in attack_global_indices:
        node_id = int(idx)
        if node_id not in node_to_text:
            print(f"⚠ Missing node {node_id} in saved data, will regenerate")
            return None  # Missing data, need to regenerate
        attacked_texts.append(node_to_text[node_id])

    return attacked_texts


def save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results, setting="inductive"):
    """Aggregate per-seed results, write them to a timestamped JSON log and print a summary.

    The log goes to ``./logs_text_attack/<dataset>/text_<attack>_<setting>/``
    relative to the current working directory.

    Args:
        args: Parsed options; ``dataset``, ``attack``, ``emb_type`` and
            ``ptb_rate`` shape the log path and ``vars(args)`` is stored in
            the log.
        clean_accs: Per-seed test accuracy before the attack, as fractions
            (multiplied by 100 for reporting).
        attacked_accs: Per-seed test accuracy after the attack, as fractions.
        all_examples: One list of text-modification example dicts per seed.
        all_detailed_results: One dict per seed containing at least
            ``clean_acc_attacked_nodes``, ``attacked_acc_attacked_nodes``,
            ``asr_accuracy_drop`` and ``asr_node_level``.
        setting: Label used in the log directory name and the printed header.
    """
    def summarize(values, scale):
        # Mean, std and per-run values, each multiplied by `scale` and rounded to 2 decimals
        return {
            "mean": round(float(np.mean(values)) * scale, 2),
            "std": round(float(np.std(values)) * scale, 2),
            "all_runs": [round(float(value) * scale, 2) for value in values]
        }

    def column(key):
        # One value per seed, taken from the detailed per-seed records
        return [run[key] for run in all_detailed_results]

    # Flatten the per-seed example lists; the summary keeps at most the first five
    flat_examples = [example for seed_examples in all_examples for example in seed_examples]
    summary_examples = flat_examples[:5]

    # Per-seed series for the attacked-node accuracies and the two ASR variants
    nodes_clean = column("clean_acc_attacked_nodes")
    nodes_attacked = column("attacked_acc_attacked_nodes")
    drop_asr = column("asr_accuracy_drop")
    node_asr = column("asr_node_level")

    # Accuracies are fractions and get scaled to percent; ASR values are stored as-is
    results = {
        "experiment_info": {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "args": vars(args),
            "setting": setting
        },
        "performance": {
            "clean": summarize(clean_accs, 100),
            "attacked": summarize(attacked_accs, 100),
            "attacked_nodes_only": {
                "clean": summarize(nodes_clean, 100),
                "attacked": summarize(nodes_attacked, 100)
            },
            "asr": {
                "accuracy_drop": summarize(drop_asr, 1),
                "node_level": summarize(node_asr, 1)
            }
        },
        "text_modification_examples": summary_examples,
        "detailed_results": all_detailed_results
    }

    log_dir = os.path.join("./logs_text_attack", args.dataset, f"text_{args.attack}_{setting}")
    os.makedirs(log_dir, exist_ok=True)
    file_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_name = f"detailed_results_{args.emb_type}_{int(args.ptb_rate*100)}_{file_stamp}.json"
    log_path = os.path.join(log_dir, log_name)

    with open(log_path, 'w') as f:
        json.dump(results, f, indent=4)

    perf = results["performance"]
    nodes_only = perf["attacked_nodes_only"]
    asr = perf["asr"]

    print(f"\n{setting.capitalize()} Text Attack Results Summary:")
    print("Overall Performance:")
    print(f"  Clean - Acc: {perf['clean']['mean']:.2f} ± {perf['clean']['std']:.2f}")
    print(f"  Attacked - Acc: {perf['attacked']['mean']:.2f} ± {perf['attacked']['std']:.2f}")
    print("Attacked Nodes Only:")
    print(f"  Clean - Acc: {nodes_only['clean']['mean']:.2f} ± {nodes_only['clean']['std']:.2f}")
    print(f"  Attacked - Acc: {nodes_only['attacked']['mean']:.2f} ± {nodes_only['attacked']['std']:.2f}")
    print("Attack Success Rate (ASR):")
    print(f"  Accuracy Drop: {asr['accuracy_drop']['mean']:.2f} ± {asr['accuracy_drop']['std']:.2f}%")
    print(f"  Node-level: {asr['node_level']['mean']:.2f} ± {asr['node_level']['std']:.2f}%")
    print(f"✓ Results saved to: {log_path}")


def main():
    """Run the inductive text-attack experiment end to end.

    For every seed: load the dataset splits, train a GCN, measure clean test
    accuracy, select the nodes to attack, obtain attacked texts (reusing a
    saved copy when one matches the current parameters), write the re-encoded
    embeddings into a copy of the test data, and measure accuracy and attack
    success rates. The per-seed results are then aggregated and written by
    ``save_detailed_results``.
    """
    # Parse arguments and setup
    args = parse_arguments()
    device = setup_gpu_device(args.device)
    
    # Create output directories
    os.makedirs(args.graph_save_dir, exist_ok=True)
    
    # Configure seeds based on dataset: arxiv runs a single seed with re_split forced to 0
    if args.dataset == 'arxiv':
        seeds = [0]
        args.re_split = 0
    else:
        seeds = range(args.seeds)
    
    # Initialize result containers
    clean_accs, attacked_accs = [], []
    all_examples = []  # Store text modification examples from all runs
    all_detailed_results = []  # Store detailed results from all runs
    
    print(f"\n{'='*60}")
    print(f"Starting text attack experiment: {args.attack} on {args.dataset}")
    print(f"{'='*60}")
    
    for seed in seeds:
        print(f"\n--- Seed {seed} ---")
        set_seed(seed)
        
        # Load data
        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split, 
            path_prefix=args.root_path, emb_model=args.emb_type, seed=seed)
        
        train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device)
        
        # Dataset info
        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        num_test_nodes = full_data.test_mask.sum().item()
        
        print(f"Dataset: {args.dataset} | Test Nodes: {num_test_nodes} | Features: {num_features}")
        
        # Setup text encoder (no encoder object is created for bag-of-words embeddings)
        text_encoder = None
        if args.emb_type != "bow":
            text_encoder = TextEncoder(args.emb_type, "LM" if args.emb_type != "Mistral-7B" else "LLM", device)
            print(f"Text encoder: {args.emb_type}")
        
        # Initialize and train model
        gcn_model = GCN(num_features, num_classes, hids=[128])
        trainer = Trainer(gcn_model, device=device, verbose=0)
        trainer.reset_optimizer(lr=0.01, weight_decay=5e-4)
        
        ckp = ModelCheckpoint(f'model_{args.dataset}_{args.emb_type}_{seed}.pth', monitor='val_acc')
        early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')
        trainer.fit((train_data, val_data), verbose=0, callbacks=[ckp, early_stopping], epochs=args.epochs)

        # Evaluate clean performance
        clean_acc = trainer.evaluate(test_data, mask=full_data.test_mask)['acc']
        clean_accs.append(clean_acc)
        print(f"Clean accuracy: {clean_acc:.4f}")
        
        # Select attack targets; the selection itself is delegated to
        # degree_weighted_node_selection with only_correct_predictions=True
        num_attack_nodes = max(1, int(num_test_nodes * args.ptb_rate))
        
        attack_global_indices, attack_local_test_indices = degree_weighted_node_selection(
            full_data, test_data, num_attack_nodes, device, model=gcn_model, only_correct_predictions=True)
        
        # Evaluate clean performance on attacked nodes only
        attack_mask = torch.zeros_like(full_data.test_mask)
        attack_mask[attack_global_indices] = True
        clean_acc_attacked_nodes = trainer.evaluate(test_data, mask=attack_mask)['acc']
        print(f"Clean accuracy on attacked nodes: {clean_acc_attacked_nodes:.4f}")
        
        # Get texts and labels for attack
        attack_texts = [test_data.raw_texts[idx.item()] for idx in attack_global_indices]
        attack_labels = test_data.y[attack_global_indices]
        
        print(f"Attacking {num_attack_nodes}/{num_test_nodes} test nodes with {args.attack}")
        
        # Try to load existing attacked texts first
        attacked_texts = load_attacked_texts_if_exists(args, attack_global_indices, seed)
        
        if attacked_texts is None:
            # Generate attacks if not found
            if args.attack == "random":
                attacked_texts = [random_perturb(text, args.ptb_rate) for text in attack_texts]
            else:
                print(f"Attacking {num_attack_nodes} nodes using batch mode...")
                attacked_texts = apply_text_attack(
                    args.attack, attack_texts, attack_labels, text_encoder, gcn_model, full_data, 
                    attack_global_indices, device=device, inductive=True, test_data=test_data,
                    dataset=args.dataset, emb_type=args.emb_type,
                    batch_size=args.attack_batch_size, use_batch=args.use_batch)
            
            # Save attacked texts
            save_attacked_texts(args, attack_global_indices, attacked_texts, seed)
        
        # Collect up to five before/after examples for logging
        seed_examples = []
        num_examples = min(5, len(attack_texts))
        for i in range(num_examples):
            example = {
                "node_id": int(attack_global_indices[i]),
                "original_text": attack_texts[i],
                "attacked_text": attacked_texts[i],
                "label": int(attack_labels[i]),
                "changed": attack_texts[i] != attacked_texts[i]
            }
            seed_examples.append(example)
        all_examples.append(seed_examples)
        
        # Apply attacks to a copy of the test data so the clean copy stays usable
        attacked_test_data = copy.deepcopy(test_data)
        new_embeddings, changed_local_indices, _ = encode_attacked_texts_efficiently(
            text_encoder, attack_texts, attacked_texts, list(range(len(attack_texts))), 
            batch_size=args.batch_size, dataset=args.dataset, emb_type=args.emb_type, device=device)
        
        if new_embeddings is not None:
            # changed_local_indices index into attack_texts; map each back to its node row
            for i, local_idx in enumerate(changed_local_indices):
                global_node_idx = attack_global_indices[local_idx]
                attacked_test_data.x[global_node_idx] = new_embeddings[i]
            print(f"Re-encoded {len(changed_local_indices)}/{len(attack_texts)} changed texts")
        
        # Evaluate attacked performance
        attacked_acc = trainer.evaluate(attacked_test_data, mask=full_data.test_mask)['acc']
        attacked_accs.append(attacked_acc)
        print(f"Attacked accuracy: {attacked_acc:.4f}")
        
        # Evaluate attacked performance on attacked nodes only
        attacked_acc_attacked_nodes = trainer.evaluate(attacked_test_data, mask=attack_mask)['acc']
        print(f"Attacked accuracy on attacked nodes: {attacked_acc_attacked_nodes:.4f}")
        
        # Calculate ASR (Attack Success Rate): relative accuracy drop on the
        # attacked nodes, in percent; defined as 0 when the clean accuracy is 0
        if clean_acc_attacked_nodes > 0:
            asr_percentage = (clean_acc_attacked_nodes - attacked_acc_attacked_nodes) / clean_acc_attacked_nodes * 100
        else:
            asr_percentage = 0
        print(f"ASR (Attack Success Rate): {asr_percentage:.2f}%")
        
        # Calculate node-level success rate (nodes that changed prediction).
        # Switch to eval mode explicitly instead of relying on the preceding
        # trainer.evaluate() call having left the model in that state.
        trainer.model.eval()
        with torch.no_grad():
            clean_preds = trainer.model(test_data.x, test_data.edge_index).argmax(dim=1)
            attacked_preds = trainer.model(attacked_test_data.x, attacked_test_data.edge_index).argmax(dim=1)
            
            # Check which attacked nodes changed their predictions
            changed_predictions = clean_preds[attack_global_indices] != attacked_preds[attack_global_indices]
            node_level_asr = changed_predictions.float().mean().item()
            print(f"Node-level ASR (changed predictions): {node_level_asr * 100:.2f}%")
        
        # Store detailed results
        seed_detailed = {
            "seed": seed,
            "clean_acc": float(clean_acc),
            "attacked_acc": float(attacked_acc),
            "clean_acc_attacked_nodes": float(clean_acc_attacked_nodes),
            "attacked_acc_attacked_nodes": float(attacked_acc_attacked_nodes),
            "asr_accuracy_drop": float(asr_percentage),
            "asr_node_level": float(node_level_asr * 100),
            "num_attack_nodes": num_attack_nodes,
            "num_test_nodes": num_test_nodes,
            # Despite the key name, this is the fraction of attacked texts that were re-encoded as changed
            "attack_success_rate": len(changed_local_indices) / len(attack_texts) if attack_texts else 0,
            "examples": seed_examples
        }
        all_detailed_results.append(seed_detailed)
        
        # Cleanup: drop per-seed objects before the next seed allocates its own
        del gcn_model, trainer, test_data, attacked_test_data
        if text_encoder is not None:
            del text_encoder
        torch.cuda.empty_cache()
    
    # Save final results
    save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results, setting="inductive")
    print(f"\n{'='*60}")
    print("Experiment completed successfully!")
    print(f"{'='*60}")

# Run the experiment only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
