import os
import sys
import argparse
import torch
import numpy as np
import copy
import json
import asyncio
from datetime import datetime

# Add parent directory to path for local imports
sys.path.append("../")

# Third-party imports
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

# Local imports
from common import set_seed, load_inductive_graph_dataset_for_gnn, TextEncoder
from common.model_path import LLM_API_CONFIGS
from text_attack import encode_attacked_texts_efficiently, degree_weighted_node_selection
from text_llm_utils import apply_llm_text_attack, save_llm_attack_examples
from text_attack_db import TextAttackDB

def parse_arguments():
    """Build the command-line interface for the inductive LLM text-attack run and parse ``sys.argv``.

    Returns:
        argparse.Namespace with data paths, dataset / embedding choice, node
        perturbation rate, LLM provider and model, training hyper-parameters,
        and the GPU device id.
    """
    cli = argparse.ArgumentParser(
        description="Generate LLM-based text attacks for inductive node classification"
    )

    # One row per option: (flag, type, default, help text, allowed choices or None).
    # Rows are registered in order, which is also the order shown by --help.
    option_specs = [
        # Data paths
        ("--root_path", str, "/path/to/GraphAD_data", "Root path for data", None),
        ("--atkg_save_dir", str, "/path/to/GraphAD_data/atkg", "Directory to save attacked texts", None),
        # Dataset and attack configuration
        ("--dataset", str, "cora", "Dataset name", None),
        ("--emb_type", str, "bow", "Text embedding type",
         ["bow", "roberta", "MiniLM", "SentenceBert", "Mistral-7B"]),
        # Attack parameters
        ("--ptb_rate", float, 0.20, "Node perturbation rate", None),
        # LLM-specific parameters
        ("--llm_provider", str, "openai", "LLM API provider", ["openai", "deepseek", "zhipu"]),
        ("--llm_model", str, "gpt-3.5-turbo", "LLM model name", None),
        # Training configuration
        ("--seeds", int, 3, "Number of random seeds", None),
        ("--re_split", int, 2, "Data split configuration", None),
        ("--patience", int, 15, "Early stopping patience", None),
        ("--epochs", int, 200, "Maximum training epochs", None),
        ("--batch_size", int, 4, "Batch size for text encoding", None),
        # System configuration
        ("--device", int, 0, "GPU device ID", None),
    ]

    for flag, value_type, default_value, help_text, allowed in option_specs:
        # Only pass `choices` for options that restrict their values.
        extra = {} if allowed is None else {"choices": allowed}
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text, **extra)

    return cli.parse_args()


def save_attacked_texts(args, attack_global_indices, attacked_texts, seed):
    """Save attacked texts and their node IDs as JSON, with metadata for later validation.

    The file is written to
    ``<atkg_save_dir>/<dataset>/llm_<llm_model>_inductive/attacked_texts_seed<seed>_ptb<int(ptb_rate*100)>.json``.
    It holds a ``metadata`` dict and an ``attacked_texts`` list of
    ``{"node_id", "attacked_text"}`` records. Node ids and texts are paired
    positionally with ``zip``, so any extra items in the longer sequence are dropped.

    The JSON is first written to a temporary sibling file and then moved over
    the target with ``os.replace``. An interrupted run therefore never leaves a
    truncated cache file for the next run to choke on.

    Args:
        args: Parsed CLI namespace; reads ``atkg_save_dir``, ``dataset``,
            ``llm_model``, ``llm_provider``, ``ptb_rate`` and ``emb_type``.
        attack_global_indices: Attacked node ids; each element must be
            convertible with ``int()``.
        attacked_texts: Attacked text for each node, in the same order.
        seed: Random seed, embedded in the file name.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"llm_{args.llm_model}_inductive")
    os.makedirs(save_dir, exist_ok=True)

    # Include ptb_rate in filename to ensure parameter consistency
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    modified_texts = [
        {"node_id": int(node_id), "attacked_text": text}
        for node_id, text in zip(attack_global_indices, attacked_texts)
    ]

    # Save with metadata for parameter validation
    data_to_save = {
        "metadata": {
            "ptb_rate": args.ptb_rate,
            "text_ptb_rate": 1.0,  # LLM can modify entire text, so set to 1.0
            "attack": f"llm_{args.llm_model}",  # Use LLM model name as attack type
            "emb_type": args.emb_type,
            "dataset": args.dataset,
            "num_attacked_nodes": len(attack_global_indices),
            # Additional LLM-specific metadata
            "llm_model": args.llm_model,
            "llm_provider": args.llm_provider
        },
        "attacked_texts": modified_texts
    }

    tmp_path = save_path + ".tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(data_to_save, f, indent=4)
        # Swap the complete file into place in one rename.
        os.replace(tmp_path, save_path)
    except Exception:
        # Don't leave a half-written temp file behind; surface the original error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"✓ Saved attacked texts to {save_path}")


def load_attacked_texts_if_exists(args, attack_global_indices, seed):
    """Load cached attacked texts for *seed* if a compatible file exists.

    Looks for
    ``<atkg_save_dir>/<dataset>/llm_<llm_model>_inductive/attacked_texts_seed<seed>_ptb<int(ptb_rate*100)>.json``,
    which is the layout written by ``save_attacked_texts``.

    Args:
        args: Parsed CLI namespace; reads ``atkg_save_dir``, ``dataset``,
            ``llm_model``, ``llm_provider`` and ``ptb_rate``.
        attack_global_indices: Node ids whose attacked texts are requested;
            each element must be convertible with ``int()``.
        seed: Random seed, embedded in the file name.

    Returns:
        A list of attacked texts in the same order as *attack_global_indices*.
        Returns ``None`` (meaning "regenerate") in any of these cases:
        the file is absent, unreadable or corrupted; it uses the old
        metadata-less format; its metadata does not match the current
        parameters; or any requested node is missing.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"llm_{args.llm_model}_inductive")
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    if not os.path.exists(save_path):
        return None

    try:
        with open(save_path, 'r', encoding='utf-8') as f:
            saved_data = json.load(f)
    except (OSError, ValueError) as exc:
        # A truncated or corrupted cache (e.g. from an interrupted write) must
        # trigger regeneration instead of aborting the whole experiment.
        print(f"⚠ Could not read {save_path} ({exc}), will regenerate")
        return None

    # Old format without metadata: assume mismatch for safety.
    if not isinstance(saved_data, dict) or "metadata" not in saved_data:
        print(f"⚠ Old format file found {save_path}, will regenerate with new format")
        return None

    # Verify key parameters match the current run.
    metadata = saved_data["metadata"]
    expected = {
        "ptb_rate": args.ptb_rate,
        "attack": f"llm_{args.llm_model}",
        "llm_model": args.llm_model,
        "llm_provider": args.llm_provider,
        "num_attacked_nodes": len(attack_global_indices),
    }
    if not isinstance(metadata, dict) or any(metadata.get(key) != value for key, value in expected.items()):
        print(f"⚠ Parameter mismatch in {save_path}, will regenerate")
        return None

    print(f"✓ Loading existing attacked texts from {save_path}")
    node_to_text = {
        item['node_id']: item['attacked_text']
        for item in saved_data.get("attacked_texts", [])
    }

    # Extract attacked texts in the same order as attack_global_indices.
    attacked_texts = []
    for idx in attack_global_indices:
        node_id = int(idx)
        if node_id not in node_to_text:
            print(f"⚠ Missing node {node_id} in saved data, will regenerate")
            return None  # Missing data, need to regenerate
        attacked_texts.append(node_to_text[node_id])

    return attacked_texts


def _summarize_runs(values, scale=1.0):
    """Return ``{"mean", "std", "all_runs"}`` for *values*, each scaled by *scale* and rounded to 2 decimals.

    Callers pass ``scale=100`` for accuracies (apparently stored as fractions)
    and leave ``scale=1`` for ASR values that are already percentages.
    """
    return {
        "mean": round(float(np.mean(values)) * scale, 2),
        "std": round(float(np.std(values)) * scale, 2),
        "all_runs": [round(float(value) * scale, 2) for value in values],
    }


def save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results, setting="inductive"):
    """Save comprehensive attack results, including text modification examples, and print a summary.

    The JSON log goes to
    ``./logs_text_attack/<dataset>/llm_<setting>/results_<provider>_<model>_<emb_type>_<int(ptb_rate*100)>_<timestamp>.json``.

    Args:
        args: Parsed CLI namespace; it is stored verbatim under ``experiment_info.args``.
        clean_accs: Per-seed clean test accuracies.
        attacked_accs: Per-seed test accuracies after the attack.
        all_examples: One list of text-modification examples per seed; the
            first 5 examples overall are stored as a summary.
        all_detailed_results: Per-seed dicts. Each must contain
            ``clean_acc_attacked_nodes``, ``attacked_acc_attacked_nodes``,
            ``asr_accuracy_drop`` and ``asr_node_level``.
        setting: Experiment setting label, used in the log path and headings.
    """
    aggregated_examples = [example for seed_examples in all_examples for example in seed_examples]
    summary_examples = aggregated_examples[:5]  # slicing already copes with fewer than 5

    # Per-seed metrics restricted to the attacked nodes, plus the ASR variants
    clean_accs_attacked = [result["clean_acc_attacked_nodes"] for result in all_detailed_results]
    attacked_accs_attacked = [result["attacked_acc_attacked_nodes"] for result in all_detailed_results]
    asr_accuracy_drops = [result["asr_accuracy_drop"] for result in all_detailed_results]
    asr_node_levels = [result["asr_node_level"] for result in all_detailed_results]

    # Use .get so a provider config without "base_url" cannot raise here and
    # discard the results of the whole run before the log file is written.
    llm_config = LLM_API_CONFIGS.get(args.llm_provider, {})

    results = {
        "experiment_info": {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "args": vars(args),
            "setting": setting,
            "attack_method": "llm"
        },
        "performance": {
            "clean": _summarize_runs(clean_accs, 100),
            "attacked": _summarize_runs(attacked_accs, 100),
            "attacked_nodes_only": {
                "clean": _summarize_runs(clean_accs_attacked, 100),
                "attacked": _summarize_runs(attacked_accs_attacked, 100),
            },
            "asr": {
                "accuracy_drop": _summarize_runs(asr_accuracy_drops),
                "node_level": _summarize_runs(asr_node_levels),
            },
        },
        "text_modification_examples": summary_examples,
        "detailed_results": all_detailed_results,
        "llm_info": {
            "provider": args.llm_provider,
            "model": args.llm_model,
            "base_url": llm_config.get("base_url")
        }
    }

    log_dir = os.path.join("./logs_text_attack", args.dataset, f"llm_{setting}")
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(log_dir, f"results_{args.llm_provider}_{args.llm_model}_{args.emb_type}_{int(args.ptb_rate*100)}_{timestamp}.json")

    with open(log_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)

    perf = results["performance"]
    nodes_only = perf["attacked_nodes_only"]
    asr = perf["asr"]
    print(f"\n{setting.capitalize()} LLM Attack Results Summary:")
    print("Overall Performance:")
    print(f"  Clean - Acc: {perf['clean']['mean']:.2f} ± {perf['clean']['std']:.2f}")
    print(f"  Attacked - Acc: {perf['attacked']['mean']:.2f} ± {perf['attacked']['std']:.2f}")
    print("Attacked Nodes Only:")
    print(f"  Clean - Acc: {nodes_only['clean']['mean']:.2f} ± {nodes_only['clean']['std']:.2f}")
    print(f"  Attacked - Acc: {nodes_only['attacked']['mean']:.2f} ± {nodes_only['attacked']['std']:.2f}")
    print("Attack Success Rate (ASR):")
    print(f"  Accuracy Drop: {asr['accuracy_drop']['mean']:.2f} ± {asr['accuracy_drop']['std']:.2f}%")
    print(f"  Node-level: {asr['node_level']['mean']:.2f} ± {asr['node_level']['std']:.2f}%")
    print(f"✓ Results saved to: {log_path}")


def _attacked_texts_path(args, seed):
    """Return the cache path of the attacked-texts JSON for *seed*.

    This must match the path built in ``save_attacked_texts`` and
    ``load_attacked_texts_if_exists``. NOTE: ``int(ptb_rate*100)`` truncates
    (e.g. 0.29 -> ``ptb28``). It is kept as-is so existing cache files are still found.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"llm_{args.llm_model}_inductive")
    return os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")


def _existing_attack_file_matches(args, save_path):
    """Return True if *save_path* exists and its metadata matches the current run.

    The fields compared are dataset, emb_type, llm_model, llm_provider and ptb_rate.
    """
    if not os.path.exists(save_path):
        return False
    try:
        with open(save_path, 'r', encoding='utf-8') as f:
            saved = json.load(f)
        meta = saved.get("metadata", {})
        return (
            meta.get("dataset") == args.dataset and
            meta.get("emb_type") == args.emb_type and
            meta.get("llm_model") == args.llm_model and
            meta.get("llm_provider") == args.llm_provider and
            float(meta.get("ptb_rate")) == args.ptb_rate
        )
    except Exception:
        # Best-effort check: any read/parse/validation problem simply means "recompute".
        print(f"⚠ Found existing file but failed to validate metadata: {save_path}. Will recompute.")
        return False


def _resolve_device(device_id):
    """Return ``cuda:<device_id>`` when CUDA is available, otherwise fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device(f'cuda:{device_id}')
    print("⚠ CUDA not available, falling back to CPU")
    return torch.device('cpu')


def _node_level_asr(model, clean_data, attacked_data, node_indices):
    """Return the fraction of *node_indices* whose argmax prediction changes under attack."""
    # Force inference mode so stochastic layers (e.g. dropout) cannot make the
    # two forward passes disagree for reasons unrelated to the attack.
    model.eval()
    with torch.no_grad():
        clean_preds = model(clean_data.x, clean_data.edge_index).argmax(dim=1)
        attacked_preds = model(attacked_data.x, attacked_data.edge_index).argmax(dim=1)
    changed_predictions = clean_preds[node_indices] != attacked_preds[node_indices]
    return changed_predictions.float().mean().item()


async def main():
    """Run the inductive LLM text-attack experiment for every configured seed.

    For each seed the driver does the following:
      1. Train a GCN on clean features.
      2. Select correctly predicted test nodes via ``degree_weighted_node_selection``.
      3. Obtain LLM-rewritten texts, either from the on-disk cache or from
         ``apply_llm_text_attack``.
      4. Re-encode the changed texts into the test features.
      5. Measure overall accuracy, accuracy on the attacked nodes, and two ASR variants.

    Seeds whose cached attacked texts already match the current parameters are
    skipped entirely. Seeds whose LLM attack fails are also skipped; their
    progress is left in ``TextAttackDB``. Only seeds that complete end to end
    contribute to the aggregated results written by ``save_detailed_results``.
    """
    args = parse_arguments()
    device = _resolve_device(args.device)

    print(f"✓ Using LLM provider: {args.llm_provider}")
    print(f"✓ Using LLM model: {args.llm_model}")

    os.makedirs(args.atkg_save_dir, exist_ok=True)

    # Initialize database for progress tracking
    db = TextAttackDB()

    # Configure seeds based on dataset (arxiv runs a single seed with re_split forced to 0)
    if args.dataset == 'arxiv':
        seeds = [0]
        args.re_split = 0
    else:
        seeds = range(args.seeds)

    # These four lists are appended together only for seeds that finish, so they stay aligned.
    clean_accs, attacked_accs = [], []
    all_examples = []
    all_detailed_results = []

    print(f"\n{'='*60}")
    print(f"Starting LLM text attack experiment on {args.dataset}")
    print(f"{'='*60}")

    for seed in seeds:
        print(f"\n--- Seed {seed} ---")
        set_seed(seed)

        # Early skip if attacked texts already exist for this seed (skip GCN training/testing)
        save_path = _attacked_texts_path(args, seed)
        if _existing_attack_file_matches(args, save_path):
            print(f"✓ Found existing attacked texts: {save_path}")
            print("✓ Skipping GCN training/testing for this seed.")
            continue

        # Generate experiment hash for this seed
        experiment_hash = db.get_experiment_hash(
            args.dataset, args.llm_provider, args.llm_model,
            args.emb_type, args.ptb_rate, seed, "inductive"
        )

        if db.check_experiment_completion(experiment_hash):
            # A matching attacked-texts file would have triggered the early skip
            # above, so the cached output is missing, unreadable or mismatched.
            print(f"⚠ Experiment for seed {seed} is marked completed in the database, "
                  f"but no matching attacked-texts file was found; will recompute")

        # Load data
        full_data, (train_data, val_data, test_data) = load_inductive_graph_dataset_for_gnn(
            args.dataset, device, re_split=args.re_split,
            path_prefix=args.root_path, emb_model=args.emb_type, seed=seed)

        train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        num_test_nodes = full_data.test_mask.sum().item()

        print(f"Dataset: {args.dataset} | Test Nodes: {num_test_nodes} | Features: {num_features}")

        # Setup text encoder (none for bow; encode_attacked_texts_efficiently then receives None)
        text_encoder = None
        if args.emb_type != "bow":
            text_encoder = TextEncoder(args.emb_type, "LM" if args.emb_type != "Mistral-7B" else "LLM", device)
            print(f"Text encoder: {args.emb_type}")

        # Initialize and train model
        gcn_model = GCN(num_features, num_classes, hids=[128])
        trainer = Trainer(gcn_model, device=device, verbose=0)
        trainer.reset_optimizer(lr=0.01, weight_decay=5e-4)

        ckp = ModelCheckpoint(f'model_{args.dataset}_{args.emb_type}_{seed}.pth', monitor='val_acc')
        early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')
        trainer.fit((train_data, val_data), verbose=0, callbacks=[ckp, early_stopping], epochs=args.epochs)

        # Evaluate clean performance. It is recorded only after the whole seed succeeds,
        # so a failed LLM attack cannot leave an unmatched entry in clean_accs.
        clean_acc = trainer.evaluate(test_data, mask=full_data.test_mask)['acc']
        print(f"Clean accuracy: {clean_acc:.4f}")

        # Select attack targets (low-degree nodes)
        requested_attack_nodes = max(1, int(num_test_nodes * args.ptb_rate))
        attack_global_indices, _attack_local_test_indices = degree_weighted_node_selection(
            full_data, test_data, requested_attack_nodes, device, model=gcn_model, only_correct_predictions=True)
        # Use the number actually selected. With only_correct_predictions=True the
        # selector may return fewer than requested (TODO confirm in text_attack).
        num_attack_nodes = len(attack_global_indices)
        if num_attack_nodes == 0:
            print(f"⚠ No test nodes selected for attack on seed {seed}, skipping")
            continue

        # Evaluate clean performance on attacked nodes only
        attack_mask = torch.zeros_like(full_data.test_mask)
        attack_mask[attack_global_indices] = True
        clean_acc_attacked_nodes = trainer.evaluate(test_data, mask=attack_mask)['acc']
        print(f"Clean accuracy on attacked nodes: {clean_acc_attacked_nodes:.4f}")

        # Get texts and labels for attack
        attack_texts = [full_data.raw_texts[idx.item()] for idx in attack_global_indices]
        attack_labels = full_data.y[attack_global_indices]

        print(f"Attacking {num_attack_nodes}/{num_test_nodes} test nodes with LLM")

        # Try to load existing attacked texts first
        attacked_texts = load_attacked_texts_if_exists(args, attack_global_indices, seed)

        if attacked_texts is None:
            try:
                # Generate LLM attacks if not found - now with database support
                attacked_texts = await apply_llm_text_attack(
                    args.dataset, attack_global_indices, attack_texts, attack_labels,
                    full_data.label_name, args.llm_provider, args.llm_model,
                    full_data.edge_index, full_data.y, args.emb_type, args.ptb_rate, seed, "inductive"
                )

                # Save attacked texts
                save_attacked_texts(args, attack_global_indices, attacked_texts, seed)

            except Exception as e:
                # Per-seed boundary: report, show DB progress, and move on to the next seed.
                print(f"❌ LLM attack failed for seed {seed}: {str(e)}")
                print("Progress has been saved to database. Rerun the script to resume.")

                # Show current experiment progress
                progress = db.get_experiment_progress(experiment_hash)
                print(f"Current progress: {progress['completed_batches']}/{progress['total_batches']} batches "
                      f"({progress['completion_percentage']:.1f}%) completed")

                # Skip this seed and continue with next
                continue

        # Collect examples for logging
        seed_examples = save_llm_attack_examples(
            args.dataset, attack_global_indices, attack_texts, attacked_texts,
            attack_labels, full_data.label_name, full_data.edge_index, full_data.y, seed)

        # Apply attacks to test data
        attacked_test_data = copy.deepcopy(test_data)
        new_embeddings, changed_local_indices, _ = encode_attacked_texts_efficiently(
            text_encoder, attack_texts, attacked_texts, list(range(len(attack_texts))),
            batch_size=args.batch_size, dataset=args.dataset, emb_type=args.emb_type, device=device)
        num_changed = len(changed_local_indices) if changed_local_indices is not None else 0

        if new_embeddings is not None:
            # NOTE(review): rows of test_data.x are indexed by global node id here,
            # which assumes test_data keeps the full node set — confirm for the inductive split.
            for i, local_idx in enumerate(changed_local_indices):
                global_node_idx = attack_global_indices[local_idx]
                attacked_test_data.x[global_node_idx] = new_embeddings[i]
            print(f"Re-encoded {num_changed}/{len(attack_texts)} changed texts")

        # Evaluate attacked performance
        attacked_acc = trainer.evaluate(attacked_test_data, mask=full_data.test_mask)['acc']
        print(f"Attacked accuracy: {attacked_acc:.4f}")

        # Evaluate attacked performance on attacked nodes only
        attacked_acc_attacked_nodes = trainer.evaluate(attacked_test_data, mask=attack_mask)['acc']
        print(f"Attacked accuracy on attacked nodes: {attacked_acc_attacked_nodes:.4f}")

        # ASR (accuracy drop): relative accuracy loss on the attacked nodes, in percent
        if clean_acc_attacked_nodes > 0:
            asr_percentage = (clean_acc_attacked_nodes - attacked_acc_attacked_nodes) / clean_acc_attacked_nodes * 100
        else:
            asr_percentage = 0
        print(f"ASR (Attack Success Rate): {asr_percentage:.2f}%")

        # Node-level ASR: share of attacked nodes whose prediction changed
        node_level_asr = _node_level_asr(trainer.model, test_data, attacked_test_data, attack_global_indices)
        print(f"Node-level ASR (changed predictions): {node_level_asr * 100:.2f}%")

        # Record this seed only now that every metric is available
        clean_accs.append(clean_acc)
        attacked_accs.append(attacked_acc)
        all_examples.append(seed_examples)
        all_detailed_results.append({
            "seed": seed,
            "clean_acc": float(clean_acc),
            "attacked_acc": float(attacked_acc),
            "clean_acc_attacked_nodes": float(clean_acc_attacked_nodes),
            "attacked_acc_attacked_nodes": float(attacked_acc_attacked_nodes),
            "asr_accuracy_drop": float(asr_percentage),
            "asr_node_level": float(node_level_asr * 100),
            "num_attack_nodes": num_attack_nodes,
            "num_test_nodes": num_test_nodes,
            "attack_success_rate": num_changed / len(attack_texts) if attack_texts else 0,
            "examples": seed_examples
        })

        # Cleanup
        del gcn_model, trainer, test_data, attacked_test_data
        if text_encoder is not None:
            del text_encoder
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Save final results
    if len(clean_accs) > 0 and len(attacked_accs) > 0:
        save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results, setting="inductive")
        print(f"\n{'='*60}")
        print("LLM attack experiment completed successfully!")
        print(f"{'='*60}")
    else:
        print(f"\n{'='*60}")
        print("⚠ No complete results to save. Check database for partial progress.")
        print("Rerun the script to resume interrupted experiments.")

        # Show overall progress for all experiments
        experiments = db.list_experiments()
        if experiments:
            print("\nCurrent experiment status in database:")
            for exp in experiments:
                print(f"  Experiment {exp['short_hash']}: {exp['completed_batches']}/{exp['total_batches']} batches "
                      f"({exp['completion_percentage']:.1f}%) completed")
        print(f"{'='*60}")


if __name__ == "__main__":
    # Script entry point: build the top-level coroutine, then let asyncio.run
    # drive it to completion on a fresh event loop.
    experiment = main()
    asyncio.run(experiment)
