import os
import sys
import argparse
import torch
import numpy as np
import copy
import json
import asyncio
from datetime import datetime

# Add parent directory to path for local imports
sys.path.append("../")

# Third-party imports
from greatx.nn.models import GCN
from greatx.training import Trainer
from greatx.training.callbacks import ModelCheckpoint, EarlyStopping

# Local imports
from common import set_seed, load_graph_dataset_for_gnn, TextEncoder
from common.model_path import LLM_API_CONFIGS
from text_attack import encode_attacked_texts_efficiently, degree_weighted_node_selection
from text_llm_utils import apply_llm_text_attack, save_llm_attack_examples
from text_attack_db import TextAttackDB

def parse_arguments():
    """Parse command-line options for the LLM text-attack script.

    Returns:
        argparse.Namespace holding the data paths, dataset / embedding
        selection, node perturbation rate, LLM provider and model, training
        settings and the GPU device id.
    """
    # (flag, value type, default, extra keyword arguments), in --help order.
    option_specs = (
        # Data paths
        ("--root_path", str, "/path/to/GraphAD_data",
         {"help": "Root path for data"}),
        ("--atkg_save_dir", str, "/path/to/GraphAD_data/atkg",
         {"help": "Directory to save attacked texts"}),
        # Dataset and attack configuration
        ("--dataset", str, "cora",
         {"help": "Dataset name"}),
        ("--emb_type", str, "bow",
         {"choices": ["bow", "roberta", "MiniLM", "SentenceBert", "Mistral-7B"],
          "help": "Text embedding type"}),
        # Attack parameters
        ("--ptb_rate", float, 0.20,
         {"help": "Node perturbation rate"}),
        # LLM-specific parameters
        ("--llm_provider", str, "openai",
         {"choices": ["openai", "deepseek", "zhipu"],
          "help": "LLM API provider"}),
        ("--llm_model", str, "gpt-3.5-turbo",
         {"help": "LLM model name"}),
        # Training configuration
        ("--seeds", int, 3,
         {"help": "Number of random seeds"}),
        ("--re_split", int, 1,
         {"help": "Data split configuration"}),
        ("--patience", int, 15,
         {"help": "Early stopping patience"}),
        ("--epochs", int, 200,
         {"help": "Maximum training epochs"}),
        ("--batch_size", int, 4,
         {"help": "Batch size for text encoding"}),
        # System configuration
        ("--device", int, 0,
         {"help": "GPU device ID"}),
    )

    parser = argparse.ArgumentParser(
        description="Generate LLM-based text attacks for transductive node classification"
    )
    for flag, value_type, default_value, extra in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value, **extra)
    return parser.parse_args()


def save_attacked_texts(args, target_nodes, attacked_texts, seed):
    """Save attacked texts, their node IDs and run metadata as a JSON file.

    The file is written to
    ``<atkg_save_dir>/<dataset>/llm_<llm_model>_transductive/attacked_texts_seed<seed>_ptb<int(ptb_rate*100)>.json``.
    It is written to a temporary file first and then moved into place with
    ``os.replace``, so an interrupted run does not leave a truncated JSON file
    behind for the next run to read.

    Args:
        args: parsed CLI namespace; reads ``atkg_save_dir``, ``dataset``,
            ``llm_model``, ``llm_provider``, ``ptb_rate`` and ``emb_type``.
        target_nodes: node indices (anything accepted by ``int()``), aligned
            element-by-element with ``attacked_texts``.
        attacked_texts: the rewritten text for each target node.
        seed: random seed, embedded in the filename.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"llm_{args.llm_model}_transductive")
    os.makedirs(save_dir, exist_ok=True)

    # Include ptb_rate in filename to ensure parameter consistency
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    # NOTE: zip stops at the shorter of the two inputs.
    modified_texts = [
        {"node_id": int(orig_idx), "attacked_text": new_text}
        for orig_idx, new_text in zip(target_nodes, attacked_texts)
    ]

    # Save with metadata for parameter validation
    data_to_save = {
        "metadata": {
            "ptb_rate": args.ptb_rate,
            "text_ptb_rate": 1.0,  # LLM can modify entire text, so set to 1.0
            "attack": f"llm_{args.llm_model}",  # Use LLM model name as attack type
            "emb_type": args.emb_type,
            "dataset": args.dataset,
            "num_attacked_nodes": len(target_nodes),
            # Additional LLM-specific metadata
            "llm_model": args.llm_model,
            "llm_provider": args.llm_provider
        },
        "attacked_texts": modified_texts
    }

    tmp_path = save_path + ".tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(data_to_save, f, indent=4)
        # Rename into place only after the whole document was written.
        os.replace(tmp_path, save_path)
    finally:
        # The temp file only survives if writing or renaming failed; don't leave it behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print(f"✓ Saved attacked texts to {save_path}")


def load_attacked_texts_if_exists(args, target_nodes, seed):
    """Return previously saved attacked texts for ``target_nodes``, or None.

    Reads ``<atkg_save_dir>/<dataset>/llm_<llm_model>_transductive/attacked_texts_seed<seed>_ptb<int(ptb_rate*100)>.json``.
    The texts are returned in the same order as ``target_nodes`` only when all
    of the following hold:

    * the file is readable JSON;
    * it has a metadata header;
    * the header's ptb_rate, attack name, llm_model, llm_provider and
      attacked-node count match the current run;
    * it contains an entry for every target node.

    In every other case the function returns None, so the caller regenerates
    the texts. This includes a missing, unreadable or truncated file, the old
    format without metadata, a parameter mismatch, malformed entries and a
    missing node.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"llm_{args.llm_model}_transductive")
    save_path = os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")

    if not os.path.exists(save_path):
        return None

    try:
        with open(save_path, 'r') as f:
            saved_data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError, e.g. a file truncated by an
        # interrupted write; treat it as a cache miss instead of crashing.
        print(f"⚠ Could not read {save_path} ({e}), will regenerate")
        return None

    # Old format without metadata, assume mismatch for safety
    if not isinstance(saved_data, dict) or "metadata" not in saved_data:
        print(f"⚠ Old format file found {save_path}, will regenerate with new format")
        return None

    # Verify key parameters match
    metadata = saved_data["metadata"]
    params_match = (
        isinstance(metadata, dict) and
        metadata.get("ptb_rate") == args.ptb_rate and
        metadata.get("attack") == f"llm_{args.llm_model}" and
        metadata.get("llm_model") == args.llm_model and
        metadata.get("llm_provider") == args.llm_provider and
        metadata.get("num_attacked_nodes") == len(target_nodes)
    )
    if not params_match:
        print(f"⚠ Parameter mismatch in {save_path}, will regenerate")
        return None

    print(f"✓ Loading existing attacked texts from {save_path}")

    # Create a mapping from node_id to attacked_text
    try:
        node_to_text = {item['node_id']: item['attacked_text'] for item in saved_data["attacked_texts"]}
    except (KeyError, TypeError):
        print(f"⚠ Malformed attacked-text entries in {save_path}, will regenerate")
        return None

    # Extract attacked texts in the same order as target_nodes
    attacked_texts = []
    for idx in target_nodes:
        node_id = int(idx)
        if node_id not in node_to_text:
            print(f"⚠ Missing node {node_id} in saved data, will regenerate")
            return None  # Missing data, need to regenerate
        attacked_texts.append(node_to_text[node_id])
    return attacked_texts


def save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results, setting="transductive"):
    """Write aggregated attack results to a timestamped JSON log and print a summary.

    The log goes to
    ``./logs_text_attack/<dataset>/llm_<setting>/results_<provider>_<model>_<emb_type>_<int(ptb_rate*100)>_<timestamp>.json``.
    It contains the run arguments, the mean/std/per-run clean and attacked
    accuracies (as percentages, rounded to 2 decimals), the first 5 text
    modification examples across all seeds, ``all_detailed_results`` verbatim
    and the LLM provider/model/base URL.

    Args:
        args: parsed CLI namespace (serialized with ``vars``).
        clean_accs: non-empty sequence of clean test accuracies in [0, 1].
        attacked_accs: non-empty sequence of attacked test accuracies in [0, 1].
        all_examples: one list of example records per seed.
        all_detailed_results: per-seed result dicts, stored as-is.
        setting: experiment setting label used in the log path and summary.
    """
    aggregated_examples = [example for seed_examples in all_examples for example in seed_examples]
    # Slicing already copes with fewer than 5 examples.
    summary_examples = aggregated_examples[:5]

    def _acc_stats(accs):
        # Percentages rounded to 2 decimals for readability in the log.
        return {
            "mean": round(float(np.mean(accs)) * 100, 2),
            "std": round(float(np.std(accs)) * 100, 2),
            "all_runs": [round(float(acc) * 100, 2) for acc in accs]
        }

    # One timestamp for both the JSON body and the filename, so they always agree.
    now = datetime.now()
    # A missing provider entry must not crash the run after all the expensive work is done.
    provider_config = LLM_API_CONFIGS.get(args.llm_provider) or {}

    results = {
        "experiment_info": {
            "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
            "args": vars(args),
            "setting": setting,
            "attack_method": "llm"
        },
        "performance": {
            "clean": _acc_stats(clean_accs),
            "attacked": _acc_stats(attacked_accs)
        },
        "text_modification_examples": summary_examples,
        "detailed_results": all_detailed_results,
        "llm_info": {
            "provider": args.llm_provider,
            "model": args.llm_model,
            "base_url": provider_config.get("base_url")
        }
    }

    log_dir = os.path.join("./logs_text_attack", args.dataset, f"llm_{setting}")
    os.makedirs(log_dir, exist_ok=True)
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(log_dir, f"results_{args.llm_provider}_{args.llm_model}_{args.emb_type}_{int(args.ptb_rate*100)}_{timestamp}.json")

    with open(log_path, 'w') as f:
        json.dump(results, f, indent=4)

    print(f"\n{setting.capitalize()} LLM Attack Results Summary:")
    print(f"Clean - Acc: {results['performance']['clean']['mean']:.2f} ± {results['performance']['clean']['std']:.2f}")
    print(f"Attacked - Acc: {results['performance']['attacked']['mean']:.2f} ± {results['performance']['attacked']['std']:.2f}")
    print(f"✓ Results saved to: {log_path}")


def _attacked_texts_path(args, seed):
    """Return the JSON path holding attacked texts for ``seed``.

    Mirrors the path built by save_attacked_texts / load_attacked_texts_if_exists;
    keep them in sync.
    """
    save_dir = os.path.join(args.atkg_save_dir, args.dataset, f"llm_{args.llm_model}_transductive")
    return os.path.join(save_dir, f"attacked_texts_seed{seed}_ptb{int(args.ptb_rate*100)}.json")


def _has_matching_attack_file(args, seed):
    """Return True (and log it) if this seed's attacked-text file matches the current run's parameters.

    Matching means the metadata's dataset, emb_type, llm_model, llm_provider and
    ptb_rate all equal the CLI arguments. An unreadable or malformed file is
    reported and treated as "no match".
    """
    save_path = _attacked_texts_path(args, seed)
    if not os.path.exists(save_path):
        return False
    try:
        with open(save_path, 'r') as f:
            saved = json.load(f)
        meta = saved.get("metadata", {})
        matches = (
            meta.get("dataset") == args.dataset and
            meta.get("emb_type") == args.emb_type and
            meta.get("llm_model") == args.llm_model and
            meta.get("llm_provider") == args.llm_provider and
            float(meta.get("ptb_rate")) == args.ptb_rate
        )
    except (OSError, ValueError, TypeError, AttributeError):
        # OSError: unreadable file; ValueError: bad JSON or a non-numeric ptb_rate;
        # TypeError: missing ptb_rate; AttributeError: top-level JSON or metadata is not a dict.
        print(f"⚠ Found existing file but failed to validate metadata: {save_path}. Will recompute.")
        return False
    if matches:
        print(f"✓ Found existing attacked texts: {save_path}")
        print("✓ Skipping GCN training/testing for this seed.")
    return matches


def _train_gcn(model, data, args, device, ckpt_path):
    """Fit ``model`` on ``data`` (train/val masks) with checkpointing and early stopping on val_acc; return the Trainer."""
    trainer = Trainer(model, device=device)
    trainer.reset_optimizer(lr=0.01, weight_decay=5e-4)
    checkpoint = ModelCheckpoint(ckpt_path, monitor='val_acc')
    early_stopping = EarlyStopping(monitor='val_acc', patience=args.patience, mode='max')
    trainer.fit(data, mask=(data.train_mask, data.val_mask),
                verbose=0, callbacks=[checkpoint, early_stopping], epochs=args.epochs)
    return trainer


async def main():
    """Run the transductive LLM text-attack experiment for every seed.

    For each seed ``0 .. args.seeds - 1``:
      1. Skip the seed entirely if a matching attacked-text file already exists.
      2. Load the graph, train a clean GCN and measure its test accuracy.
      3. Select target training nodes with ``degree_weighted_node_selection``.
         Reuse saved attacked texts when they are valid; otherwise call
         ``apply_llm_text_attack`` and save the result.
      4. Re-encode the changed texts, retrain the GCN from scratch on the
         attacked features and measure its test accuracy.

    Clean and attacked accuracies are recorded together, and only for seeds
    whose attack succeeded, so the two lists stay aligned seed-by-seed.
    Aggregated results are written by ``save_detailed_results``.
    """
    args = parse_arguments()
    device = torch.device(f'cuda:{args.device}')

    print(f"✓ Using LLM provider: {args.llm_provider}")
    print(f"✓ Using LLM model: {args.llm_model}")

    os.makedirs(args.atkg_save_dir, exist_ok=True)

    # Initialize database for progress tracking
    db = TextAttackDB()

    # clean_accs[i] and attacked_accs[i] always come from the same seed
    clean_accs, attacked_accs = [], []
    all_examples = []
    all_detailed_results = []

    print(f"\n{'='*60}")
    print(f"Starting LLM text attack experiment on {args.dataset} (transductive)")
    print(f"{'='*60}")

    for seed in range(args.seeds):
        print(f"\n--- Seed {seed} ---")
        set_seed(seed)

        # Early skip if attacked texts already exist for this seed (skip GCN training/testing)
        if _has_matching_attack_file(args, seed):
            continue

        # Generate experiment hash for this seed
        experiment_hash = db.get_experiment_hash(
            args.dataset, args.llm_provider, args.llm_model,
            args.emb_type, args.ptb_rate, seed, "transductive"
        )

        # A "completed" DB entry is not enough to skip the seed: the file check above
        # already failed, so there are no valid saved texts to reuse.
        if db.check_experiment_completion(experiment_hash):
            print(f"⚠ Seed {seed} is marked completed in the database, but no matching "
                  f"attacked-text file was found; will recompute")

        # Load data
        full_data = load_graph_dataset_for_gnn(args.dataset, device, re_split=args.re_split,
                                               path_prefix=args.root_path, emb_model=args.emb_type)
        full_data = full_data.to(device)

        num_features = full_data.x.shape[-1]
        num_classes = full_data.y.max().item() + 1
        num_train_nodes = full_data.train_mask.sum().item()

        print(f"Dataset: {args.dataset} | Train Nodes: {num_train_nodes} | Features: {num_features}")

        # Fall back to placeholder texts when the dataset provides no raw texts
        node_texts = full_data.raw_texts if hasattr(full_data, 'raw_texts') and full_data.raw_texts else \
                    [f"Text content for node {i}" for i in range(full_data.num_nodes)]

        # No language-model encoder is constructed for bag-of-words features
        text_encoder = None
        if args.emb_type != "bow":
            text_encoder = TextEncoder(args.emb_type, "LM" if args.emb_type != "Mistral-7B" else "LLM", device)
            print(f"Text encoder: {args.emb_type}")

        # Train the clean model and evaluate it
        gcn_model = GCN(num_features, num_classes, hids=[64])
        trainer = _train_gcn(gcn_model, full_data, args, device,
                             f'model_{args.dataset}_{args.emb_type}_{seed}.pt')
        clean_acc = trainer.evaluate(full_data, mask=full_data.test_mask)['acc']
        print(f"Clean accuracy: {clean_acc:.4f}")

        # Select attack targets (low-degree, correctly classified training nodes)
        num_attack_nodes = max(1, int(num_train_nodes * args.ptb_rate))
        # In the transductive setting the test data is the full graph itself
        target_nodes, _ = degree_weighted_node_selection(
            full_data, full_data, num_attack_nodes, device, atk_type='transductive',
            model=gcn_model, only_correct_predictions=True)

        # Get texts and labels for attack
        target_texts = [node_texts[i.cpu().item()] for i in target_nodes]
        target_labels = full_data.y[target_nodes]

        print(f"Attacking {num_attack_nodes}/{num_train_nodes} train nodes with LLM")

        # Try to load existing attacked texts first
        attacked_texts = load_attacked_texts_if_exists(args, target_nodes, seed)

        if attacked_texts is None:
            try:
                attacked_texts = await apply_llm_text_attack(
                    args.dataset, target_nodes, target_texts, target_labels,
                    full_data.label_name, args.llm_provider, args.llm_model,
                    full_data.edge_index, full_data.y, args.emb_type, args.ptb_rate, seed, "transductive"
                )

                # Save attacked texts
                save_attacked_texts(args, target_nodes, attacked_texts, seed)

            except Exception as e:
                # Best effort: report, show DB progress and move on to the next seed.
                # Nothing has been appended to the accuracy lists for this seed yet.
                print(f"❌ LLM attack failed for seed {seed}: {str(e)}")
                print("Progress has been saved to database. Rerun the script to resume.")

                # Show current experiment progress
                progress = db.get_experiment_progress(experiment_hash)
                print(f"Current progress: {progress['completed_batches']}/{progress['total_batches']} batches "
                      f"({progress['completion_percentage']:.1f}%) completed")
                continue

        # Collect examples for logging
        seed_examples = save_llm_attack_examples(
            args.dataset, target_nodes, target_texts, attacked_texts,
            target_labels, full_data.label_name, full_data.edge_index, full_data.y, seed)
        all_examples.append(seed_examples)

        # Apply attacks to data
        attacked_data = copy.deepcopy(full_data)
        new_embeddings, changed_indices, changed_nodes = encode_attacked_texts_efficiently(
            text_encoder, target_texts, attacked_texts, target_nodes,
            batch_size=args.batch_size, dataset=args.dataset, emb_type=args.emb_type, device=device)

        if new_embeddings is not None and len(changed_nodes) > 0:
            for i, node_idx in enumerate(changed_nodes):
                attacked_data.x[node_idx] = new_embeddings[i]
            print(f"Re-encoded {len(changed_nodes)}/{len(target_texts)} changed texts")

        # Evaluate attacked performance on attacked data, retrained from scratch
        gcn_model.reset_parameters()
        trainer = _train_gcn(gcn_model, attacked_data, args, device,
                             f'model_{args.dataset}_{args.emb_type}_{seed}_attacked.pt')
        attacked_acc = trainer.evaluate(attacked_data, mask=attacked_data.test_mask)['acc']
        print(f"Attacked accuracy: {attacked_acc:.4f}")

        # Record both accuracies only now, so a failed attack never leaves an unpaired clean accuracy
        clean_accs.append(clean_acc)
        attacked_accs.append(attacked_acc)

        # Store detailed results
        seed_detailed = {
            "seed": seed,
            "clean_acc": float(clean_acc),
            "attacked_acc": float(attacked_acc),
            "num_attack_nodes": num_attack_nodes,
            "num_train_nodes": num_train_nodes,
            "attack_success_rate": len(changed_nodes) / len(target_texts) if target_texts else 0,
            "examples": seed_examples
        }
        all_detailed_results.append(seed_detailed)

        # Cleanup
        del gcn_model, trainer, attacked_data
        if text_encoder:
            del text_encoder
        torch.cuda.empty_cache()

    # Save final results
    if clean_accs and attacked_accs:
        save_detailed_results(args, clean_accs, attacked_accs, all_examples, all_detailed_results, setting="transductive")
        print(f"\n{'='*60}")
        print("LLM attack experiment completed successfully!")
        print(f"{'='*60}")
    else:
        print(f"\n{'='*60}")
        print("⚠ No complete results to save. Check database for partial progress.")
        print("Rerun the script to resume interrupted experiments.")

        # Show overall progress for all experiments
        experiments = db.list_experiments()
        if experiments:
            print("\nCurrent experiment status in database:")
            for exp in experiments:
                print(f"  Experiment {exp['short_hash']}: {exp['completed_batches']}/{exp['total_batches']} batches "
                      f"({exp['completion_percentage']:.1f}%) completed")
        print(f"{'='*60}")


if __name__ == "__main__":
    # Script entry point: drive the async experiment pipeline to completion on a fresh event loop.
    asyncio.run(main())