import os
import sys
import argparse
import torch
import numpy as np
import copy
import json
from datetime import datetime

# Add parent directory to path for local imports
sys.path.append("../")

# Third-party imports
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

# Local imports
from common import set_seed, load_graph_dataset_for_gnn, TextEncoder
from text_attack import (
    random_perturb, 
    apply_text_attack, 
    save_results, 
    encode_attacked_texts_efficiently, 
    degree_weighted_node_selection
)


def configure_tensorflow_gpu(device_id):
    """Configure TensorFlow GPU settings to avoid conflicts with PyTorch.

    Turns on memory growth for every GPU TensorFlow can see, so TensorFlow
    allocates GPU memory on demand instead of reserving it up front. It then
    exposes only GPU ``device_id`` to TensorFlow. When ``device_id`` is not a
    valid index into TensorFlow's GPU list (negative or too large),
    TensorFlow is limited to the CPU instead.

    Failures are reported, never raised. A missing TensorFlow install is
    printed as a warning. A ``RuntimeError`` from TensorFlow's config calls
    is also caught and printed; TensorFlow documents that these settings
    must be applied before its GPUs are initialized.

    Args:
        device_id: GPU index to expose to TensorFlow.
            NOTE(review): assumes TensorFlow enumerates GPUs in the same
            order as torch's CUDA ordinals; confirm if CUDA_DEVICE_ORDER
            is customised.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("⚠ TensorFlow not available")
        return

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if not gpus:
        print("✓ TensorFlow using CPU (no GPU available)")
        return

    try:
        # Set memory growth for all GPUs so TensorFlow does not grab all device memory.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        # Explicit lower bound: a negative id would otherwise index from the
        # end of the list and silently expose the last GPU.
        if 0 <= device_id < len(gpus):
            tf.config.experimental.set_visible_devices(gpus[device_id], 'GPU')
            print(f"✓ TensorFlow using GPU {device_id}")
        else:
            # Hide every GPU from TensorFlow so it runs on CPU.
            tf.config.experimental.set_visible_devices([], 'GPU')
            print("✓ TensorFlow using CPU (GPU not available)")
    except RuntimeError as e:
        print(f"⚠ TensorFlow GPU configuration failed: {e}")


def setup_gpu_device(device_id):
    """Pick the torch device for this run and align TensorFlow with it.

    Args:
        device_id: CUDA ordinal to use; any negative value requests CPU.

    Returns:
        ``torch.device`` on which models and tensors should be placed. This
        is CPU when CUDA is unavailable or when CPU was requested.

    Raises:
        RuntimeError: if ``device_id`` is not below ``torch.cuda.device_count()``.
    """
    cpu_device = torch.device('cpu')

    # Guard clauses: both CPU fallbacks exit before any CUDA call is made.
    if not torch.cuda.is_available():
        print("CUDA is not available. Using CPU.")
        return cpu_device
    if device_id < 0:
        print("Using CPU as requested.")
        return cpu_device

    gpu_count = torch.cuda.device_count()
    if device_id >= gpu_count:
        raise RuntimeError(
            f"GPU device {device_id} is not available. Only {gpu_count} GPUs are available."
        )

    # Make the chosen ordinal the process-wide default CUDA device.
    torch.cuda.set_device(device_id)
    chosen = torch.device(f'cuda:{device_id}')

    print(f"✓ Using GPU device: {device_id}")
    print(f"✓ GPU name: {torch.cuda.get_device_name(device_id)}")
    print(f"✓ Available GPUs: {gpu_count}")
    print(f"✓ Current device: {torch.cuda.current_device()}")

    # Keep TensorFlow (used by some attack back-ends) off the other GPUs.
    configure_tensorflow_gpu(device_id)
    return chosen


def parse_arguments(argv=None):
    """Parse command line arguments for the transductive text-attack experiment.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps normal
            command-line behaviour; passing a list lets tests or notebooks
            drive the parser.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Generate text attacks for transductive node classification")

    # Data paths
    parser.add_argument("--root_path", type=str, default="/path/to/GraphAD_data",
                        help="Root path for data")
    parser.add_argument("--graph_save_dir", type=str, default="/path/to/GraphAD_data/text_atkg",
                        help="Directory to save graph attack results")
    parser.add_argument("--atkg_save_dir", type=str, default="/path/to/GraphAD_data/atkg",
                        help="Directory to save attacked texts")

    # Dataset and attack configuration
    parser.add_argument("--dataset", type=str, default="cora",
                        help="Dataset name")
    parser.add_argument("--attack", type=str, default="random",
                        choices=["random", "textfooler", "bae", "pwws", "hotflip"],
                        help="Attack method")
    parser.add_argument("--emb_type", type=str, default="bow",
                        choices=["bow", "roberta", "MiniLM", "SentenceBert", "Mistral-7B"],
                        help="Text embedding type")

    # Attack parameters
    parser.add_argument("--ptb_rate", type=float, default=0.20,
                        help="Node perturbation rate")

    # Batch attack configuration.
    # `--use_batch` is store_true with default=True, so on its own it can
    # never turn batching off. `--no_use_batch` writes False to the same
    # dest and is the way to disable it.
    parser.add_argument("--use_batch", action="store_true", default=True,
                        help="Use batch processing for attacks (default: True)")
    parser.add_argument("--no_use_batch", dest="use_batch", action="store_false",
                        help="Disable batch processing for attacks")
    parser.add_argument("--attack_batch_size", type=int, default=8,
                        help="Batch size for attack processing")

    # Training configuration
    parser.add_argument("--seeds", type=int, default=3,
                        help="Number of random seeds")
    parser.add_argument("--re_split", type=int, default=1,
                        help="Data split configuration")
    parser.add_argument("--patience", type=int, default=15,
                        help="Early stopping patience")
    parser.add_argument("--epochs", type=int, default=200,
                        help="Maximum training epochs")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for text encoding")

    # System configuration (a negative id selects CPU in setup_gpu_device)
    parser.add_argument("--device", type=int, default=0,
                        help="GPU device ID")

    return parser.parse_args(argv)


def save_attacked_texts(args, target_nodes, attacked_texts, seed):
    """Save attacked texts and their node IDs as JSON for later reuse.

    The file lives at
    ``<atkg_save_dir>/<dataset>/text_<attack>_transductive/attacked_texts_seed<seed>_ptb<pct>.json``.
    It holds a ``metadata`` section (ptb_rate, attack, emb_type, dataset,
    num_attacked_nodes) and an ``attacked_texts`` list of
    ``{"node_id", "attacked_text"}`` records. There is one record per
    (target node, attacked text) pair, in ``target_nodes`` order. Every
    target is stored, whether or not its text actually changed.

    Args:
        args: parsed CLI args (atkg_save_dir, dataset, attack, ptb_rate, emb_type are read).
        target_nodes: iterable of node ids (ints or 0-d tensors; converted with ``int``).
        attacked_texts: attacked text per target node, aligned with ``target_nodes``.
        seed: seed number embedded in the filename.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"text_{args.attack}_transductive")
    os.makedirs(save_dir, exist_ok=True)

    # ptb_rate is part of the filename so different rates never share a cache.
    # NOTE: int() truncates (e.g. 0.29*100 -> 28). It is kept so the path stays
    # identical to the one load_attacked_texts_if_exists reads.
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    if len(attacked_texts) != len(target_nodes):
        print(f"⚠ {len(target_nodes)} target nodes but {len(attacked_texts)} attacked texts; "
              f"only paired entries will be saved")

    records = [
        {"node_id": int(node), "attacked_text": text}
        for node, text in zip(target_nodes, attacked_texts)
    ]

    # Metadata lets the loader reject a cache produced with different parameters.
    data_to_save = {
        "metadata": {
            "ptb_rate": args.ptb_rate,
            "attack": args.attack,
            "emb_type": args.emb_type,
            "dataset": args.dataset,
            "num_attacked_nodes": len(target_nodes)
        },
        "attacked_texts": records
    }

    # Write to a sibling temp file, then atomically swap it in. An interrupted
    # run therefore never leaves a truncated JSON behind at save_path.
    tmp_path = save_path + ".tmp"
    with open(tmp_path, 'w') as f:
        json.dump(data_to_save, f, indent=4)
    os.replace(tmp_path, save_path)
    print(f"✓ Saved attacked texts to {save_path}")


def load_attacked_texts_if_exists(args, target_nodes, seed):
    """Return cached attacked texts for ``target_nodes``, or ``None`` to signal regeneration.

    Reads the JSON written by ``save_attacked_texts`` for this
    dataset/attack/seed/ptb_rate. The cache is accepted only if its metadata
    matches the current ``ptb_rate``, ``attack``, ``emb_type`` and target
    count, and every target node has an entry. In that case the texts are
    returned in ``target_nodes`` order.

    ``None`` is returned in every other case, each with a printed reason:
    the file is missing, unreadable or not valid JSON, it is in the old
    format without metadata, its parameters do not match, its entries are
    malformed, or a target node is missing from it.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"text_{args.attack}_transductive")
    # NOTE: int() truncates (e.g. 0.29*100 -> 28). It is kept so the path stays
    # identical to the one save_attacked_texts writes.
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    if not os.path.exists(save_path):
        return None

    # A crashed or interrupted earlier run may have left a corrupt file.
    # Regenerate instead of aborting the whole experiment.
    try:
        with open(save_path, 'r') as f:
            saved_data = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        print(f"⚠ Could not read {save_path} ({e}), will regenerate")
        return None

    metadata = saved_data.get("metadata") if isinstance(saved_data, dict) else None
    if not isinstance(metadata, dict):
        # Old format without metadata: assume a mismatch, for safety.
        print(f"⚠ Old format file found {save_path}, will regenerate with new format")
        return None

    params_match = (
        metadata.get("ptb_rate") == args.ptb_rate
        and metadata.get("attack") == args.attack
        and metadata.get("emb_type") == args.emb_type
        and metadata.get("num_attacked_nodes") == len(target_nodes)
    )
    if not params_match:
        print(f"⚠ Parameter mismatch in {save_path}, will regenerate")
        return None

    print(f"✓ Loading existing attacked texts from {save_path}")
    try:
        node_to_text = {item['node_id']: item['attacked_text'] for item in saved_data["attacked_texts"]}
    except (KeyError, TypeError):
        print(f"⚠ Malformed attacked-text entries in {save_path}, will regenerate")
        return None

    # Re-order to match target_nodes. Any gap invalidates the whole cache.
    attacked_texts = []
    for idx in target_nodes:
        node_id = int(idx)
        if node_id not in node_to_text:
            print(f"⚠ Missing node {node_id} in saved data, will regenerate")
            return None
        attacked_texts.append(node_to_text[node_id])

    return attacked_texts


def _summarize_accs(accs):
    """Return mean/std/per-run accuracies as percentages rounded to 2 decimals."""
    return {
        "mean": round(float(np.mean(accs)) * 100, 2),
        "std": round(float(np.std(accs)) * 100, 2),
        "all_runs": [round(float(acc) * 100, 2) for acc in accs]
    }


def save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results,
                          setting="transductive", log_root="./logs_text_attack"):
    """Write a JSON summary of the experiment and print a clean-vs-attacked accuracy line.

    The file is written to
    ``<log_root>/<dataset>/text_<attack>_<setting>/detailed_results_<emb_type>_<ptb%>_<timestamp>.json``.
    It contains the run arguments, clean/attacked accuracy statistics (in
    percent), the first 5 text-modification examples across all seeds, and
    the per-seed detailed results.

    Args:
        args: parsed CLI args (stored via ``vars``; dataset/attack/emb_type/ptb_rate shape the path).
        clean_accs, attacked_accs: per-seed test accuracies as fractions in [0, 1].
        all_examples: list (one entry per seed) of lists of example dicts.
        all_detailed_results: per-seed result dicts, stored verbatim.
        setting: label used in the directory name and the printed header.
        log_root: base directory for logs (defaults to the historical location).
    """
    # Flatten per-seed example lists, then keep the first 5 overall.
    aggregated_examples = [example for seed_examples in all_examples for example in seed_examples]
    summary_examples = aggregated_examples[:5]

    # One timestamp, so the recorded time and the filename agree.
    now = datetime.now()

    results = {
        "experiment_info": {
            "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
            "args": vars(args),
            "setting": setting
        },
        "performance": {
            "clean": _summarize_accs(clean_accs),
            "attacked": _summarize_accs(attacked_accs)
        },
        "text_modification_examples": summary_examples,
        "detailed_results": all_detailed_results
    }

    log_dir = os.path.join(log_root, args.dataset, f"text_{args.attack}_{setting}")
    os.makedirs(log_dir, exist_ok=True)
    # round(), not int(): int(0.29 * 100) truncates to 28 because of float error.
    ptb_pct = round(args.ptb_rate * 100)
    log_path = os.path.join(log_dir, f"detailed_results_{args.emb_type}_{ptb_pct}_{now.strftime('%Y%m%d_%H%M%S')}.json")

    with open(log_path, 'w') as f:
        json.dump(results, f, indent=4)

    clean, attacked = results['performance']['clean'], results['performance']['attacked']
    print(f"\n{setting.capitalize()} Text Attack Results Summary:")
    print(f"Clean - Acc: {clean['mean']:.2f} ± {clean['std']:.2f}")
    print(f"Attacked - Acc: {attacked['mean']:.2f} ± {attacked['std']:.2f}")
    print(f"✓ Results saved to: {log_path}")


def _train_gcn(data, num_features, num_classes, checkpoint_path, args, device):
    """Train a fresh GCN (hids=[64]) on ``data`` and return ``(model, trainer)``.

    The clean and the attacked run share this one recipe, so the two runs
    stay comparable. It uses lr=0.01 and weight_decay=5e-4, a ModelCheckpoint
    monitored on val_acc, and early stopping on val_acc (mode='max') with
    ``args.patience``, for at most ``args.epochs`` epochs. Training uses
    ``data.train_mask`` and validation uses ``data.val_mask``.
    """
    model = GCN(num_features, num_classes, hids=[64])
    trainer = Trainer(model, device=device)
    trainer.reset_optimizer(lr=0.01, weight_decay=5e-4)

    checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_acc')
    early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')

    trainer.fit(data, mask=(data.train_mask, data.val_mask),
                verbose=0, callbacks=[checkpoint, early_stopping], epochs=args.epochs)
    return model, trainer


def _collect_examples(target_nodes, target_texts, attacked_texts, target_labels, limit=5):
    """Return up to ``limit`` original-vs-attacked text records for logging."""
    count = min(limit, len(target_texts), len(attacked_texts))
    return [
        {
            "node_id": int(target_nodes[i]),
            "original_text": target_texts[i],
            "attacked_text": attacked_texts[i],
            "label": int(target_labels[i]),
            "changed": target_texts[i] != attacked_texts[i]
        }
        for i in range(count)
    ]


def main():
    """Run the transductive text-attack experiment for ``args.seeds`` seeds.

    Each seed goes through these steps:
    1. Load the graph and train a clean GCN; record its test accuracy.
    2. Select training nodes to attack.
    3. Perturb their raw texts, reusing a matching cached result when one
       exists.
    4. Re-encode the changed texts into node features.
    5. Retrain a fresh GCN on the perturbed graph; record its test accuracy.

    A JSON summary across all seeds is written at the end.
    """
    args = parse_arguments()
    device = setup_gpu_device(args.device)

    os.makedirs(args.graph_save_dir, exist_ok=True)

    clean_accs, attacked_accs = [], []
    all_examples = []          # text modification examples, one list per seed
    all_detailed_results = []  # per-seed result dicts

    print(f"\n{'='*60}")
    print(f"Starting text attack experiment: {args.attack} on {args.dataset}")
    print(f"{'='*60}")

    for seed in range(args.seeds):
        print(f"\n--- Seed {seed} ---")
        set_seed(seed)

        # Load data
        full_data = load_graph_dataset_for_gnn(args.dataset, device, re_split=args.re_split,
                                               path_prefix=args.root_path, emb_model=args.emb_type)
        full_data = full_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        num_test_nodes = full_data.test_mask.sum().item()
        print(f"Dataset: {args.dataset} | Features: {num_features} | Classes: {num_classes} | Test Nodes: {num_test_nodes}")

        # Use the real texts when present. A placeholder text per node is the fallback.
        # len() instead of bare truthiness so array-like raw_texts are handled too.
        raw_texts = getattr(full_data, 'raw_texts', None)
        if raw_texts is not None and len(raw_texts) > 0:
            node_texts = raw_texts
        else:
            node_texts = [f"Text content for node {i}" for i in range(full_data.num_nodes)]

        # bow features need no encoder. Mistral-7B is loaded as an "LLM"; every other model as an "LM".
        text_encoder = None
        if args.emb_type != "bow":
            encoder_kind = "LLM" if args.emb_type == "Mistral-7B" else "LM"
            text_encoder = TextEncoder(args.emb_type, encoder_kind, device)
            print(f"Text encoder: {args.emb_type}")

        # Clean model
        gcn_model, trainer = _train_gcn(full_data, num_features, num_classes,
                                        f'model_{args.dataset}_{args.emb_type}_{seed}.pt', args, device)
        clean_acc = trainer.evaluate(full_data, mask=full_data.test_mask)['acc']
        clean_accs.append(clean_acc)
        print(f"Clean accuracy: {clean_acc:.4f}")

        # Attack budget: ptb_rate of the training nodes, at least one node.
        num_train_nodes = full_data.train_mask.sum().item()
        num_attack_nodes = max(1, int(num_train_nodes * args.ptb_rate))

        # Transductive setting: the selector's "test data" is the full graph itself.
        # only_correct_predictions=True is passed, so the selector may return
        # fewer than num_attack_nodes targets. Counts below use len(target_nodes).
        target_nodes, _ = degree_weighted_node_selection(
            full_data, full_data, num_attack_nodes, device, atk_type='transductive',
            model=gcn_model, only_correct_predictions=True)
        num_targets = len(target_nodes)

        # int() works for plain ints and for 0-d tensors on any device.
        target_texts = [node_texts[int(i)] for i in target_nodes]
        target_labels = full_data.y[target_nodes]

        print(f"Attacking {num_targets}/{num_train_nodes} train nodes with {args.attack}")

        # Reuse cached attacked texts when parameters match. Otherwise generate and cache them.
        attacked_texts = load_attacked_texts_if_exists(args, target_nodes, seed)
        if attacked_texts is None:
            if args.attack == "random":
                attacked_texts = [random_perturb(text, args.ptb_rate) for text in target_texts]
            else:
                mode = "batch" if args.use_batch else "non-batch"
                print(f"Attacking {num_targets} nodes using {mode} mode...")
                attacked_texts = apply_text_attack(
                    args.attack, target_texts, target_labels, text_encoder, gcn_model, full_data,
                    target_nodes, device=device, dataset=args.dataset, emb_type=args.emb_type,
                    batch_size=args.attack_batch_size, use_batch=args.use_batch, inductive=False)
            save_attacked_texts(args, target_nodes, attacked_texts, seed)

        seed_examples = _collect_examples(target_nodes, target_texts, attacked_texts, target_labels)
        all_examples.append(seed_examples)

        # Work on a copy so full_data keeps the clean features.
        attacked_data = copy.deepcopy(full_data)
        new_embeddings, _, changed_nodes = encode_attacked_texts_efficiently(
            text_encoder, target_texts, attacked_texts, target_nodes,
            batch_size=args.batch_size, dataset=args.dataset, emb_type=args.emb_type, device=device)

        num_changed = len(changed_nodes) if changed_nodes is not None else 0
        if new_embeddings is not None and num_changed > 0:
            # Row i of new_embeddings belongs to changed_nodes[i].
            for row, node_idx in enumerate(changed_nodes):
                attacked_data.x[node_idx] = new_embeddings[row]
            print(f"Re-encoded {num_changed}/{len(target_texts)} changed texts")

        # Poisoning evaluation: retrain from scratch on the attacked features.
        gcn_model_2, trainer = _train_gcn(attacked_data, num_features, num_classes,
                                          f'model_{args.dataset}_{args.emb_type}_{seed}_attacked.pt', args, device)
        attacked_acc = trainer.evaluate(attacked_data, mask=attacked_data.test_mask)['acc']
        attacked_accs.append(attacked_acc)
        print(f"Attacked accuracy: {attacked_acc:.4f}")

        all_detailed_results.append({
            "seed": seed,
            "clean_acc": float(clean_acc),
            "attacked_acc": float(attacked_acc),
            "num_attack_nodes": num_attack_nodes,   # requested budget
            "num_target_nodes": num_targets,        # nodes actually selected
            "num_test_nodes": num_test_nodes,
            # Fraction of target texts whose content changed (re-encoded),
            # not a misclassification rate.
            "attack_success_rate": num_changed / len(target_texts) if target_texts else 0,
            "examples": seed_examples
        })

        # Release per-seed models/encoder before the next seed allocates new ones.
        del gcn_model, gcn_model_2, trainer, attacked_data
        text_encoder = None
        torch.cuda.empty_cache()

    save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results, setting="transductive")
    print(f"\n{'='*60}")
    print("Experiment completed successfully!")
    print(f"{'='*60}")


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()
