import os
import torch
import numpy as np
import json
import random
from datetime import datetime
from tqdm import tqdm
import copy
import gc
import pickle
from sklearn.feature_extraction.text import CountVectorizer

# Text attack libraries
import textattack
from textattack.attack_recipes import TextFoolerJin2019, BAEGarg2019, PWWSRen2019, HotFlipEbrahimi2017
from textattack.models.wrappers import ModelWrapper

import sys
sys.path.append("../")
from common import TextEncoder


def load_vocabulary(vocab_cache_path="./vocab_cache/nltk_words.pkl"):
    """Unpickle and return the vocabulary stored at ``vocab_cache_path``.

    Args:
        vocab_cache_path: Path of the pickled vocabulary cache file.

    Returns:
        Whatever object was pickled into the cache file (its length is
        reported on stdout).

    Raises:
        FileNotFoundError: If the cache file does not exist.
    """
    cache_missing = not os.path.exists(vocab_cache_path)
    if cache_missing:
        message = f"Vocabulary cache not found at {vocab_cache_path}. Please run 'python preprocess_vocab.py' first to create vocabulary cache."
        raise FileNotFoundError(message)

    print(f"Loading vocabulary from cache: {vocab_cache_path}")
    with open(vocab_cache_path, 'rb') as cache_file:
        words = pickle.load(cache_file)
    print(f"Loaded {len(words)} words from cache")
    return words


# Load the vocabulary once at module import so random_perturb() can sample
# replacement words without re-reading the cache on every call.
# NOTE: this is an import-time side effect; importing this module raises
# FileNotFoundError when the default vocabulary cache file is missing.
vocab = load_vocabulary()


# Vectorizers that were already unpickled, keyed by the absolute path of their
# vocabulary file. The TextAttack wrappers request the vectorizer on every
# model query, so re-reading the pickle each time is wasted disk I/O.
_BOW_VECTORIZER_CACHE = {}


def load_bow_vectorizer(dataset, base_path="/path/to/GraphAD_data/datasets"):
    """Load the pickled BOW vectorizer saved for ``dataset``.

    The file is expected at ``<base_path>/vocab/<dataset>/bow_vocabulary.pkl``.
    The unpickled object is kept in an in-process cache, so repeated calls for
    the same file return the same object without touching the disk again.

    Args:
        dataset: Dataset name; used as the sub-directory under ``vocab``.
        base_path: Root directory that contains the ``vocab`` folder.

    Returns:
        The object stored in the vocabulary pickle.

    Raises:
        FileNotFoundError: If the vocabulary file does not exist and has not
            been loaded before.
    """
    vocab_path = os.path.join(base_path, "vocab", dataset, "bow_vocabulary.pkl")
    cache_key = os.path.abspath(vocab_path)

    cached = _BOW_VECTORIZER_CACHE.get(cache_key)
    if cached is not None:
        return cached

    if not os.path.exists(vocab_path):
        raise FileNotFoundError(f"BOW vocabulary not found at {vocab_path}")

    # NOTE: pickle.load can execute arbitrary code; only point this at
    # vocabulary files produced by a trusted preprocessing run.
    with open(vocab_path, 'rb') as f:
        vectorizer = pickle.load(f)

    _BOW_VECTORIZER_CACHE[cache_key] = vectorizer
    return vectorizer


def encode_texts_bow(texts, dataset, base_path="/path/to/GraphAD_data/datasets", device='cuda'):
    """Turn ``texts`` into a dense float tensor using the dataset's saved BOW vectorizer.

    The vectorizer is fetched with ``load_bow_vectorizer(dataset, base_path)``,
    its ``transform`` output is densified, and the result is moved to ``device``.
    """
    bow_vectorizer = load_bow_vectorizer(dataset, base_path)
    dense_counts = bow_vectorizer.transform(texts).toarray()
    features = torch.FloatTensor(dense_counts)
    return features.to(device)


def random_perturb(text, ratio=0.2):
    """Swap randomly chosen words of ``text`` for random vocabulary entries.

    The text is split on whitespace; ``int(len(words) * ratio)`` words (at
    least one, at most all of them) are replaced with ``random.choice(vocab)``
    and the words are re-joined with single spaces. Text without any words is
    returned unchanged.
    """
    words = text.split()
    word_count = len(words)
    if word_count == 0:
        return text

    # At least one replacement, never more than the number of words.
    swap_count = min(max(1, int(word_count * ratio)), word_count)
    for position in random.sample(range(word_count), swap_count):
        words[position] = random.choice(vocab)

    return ' '.join(words)


def encode_texts_batch(text_encoder, texts, batch_size=8, device='cuda'):
    """Encode ``texts`` into one tensor on ``device`` with one row per text.

    The encoder is called once per text; ``batch_size`` only controls how many
    texts are processed between CUDA cache clean-ups. Every embedding is
    normalised to shape ``(1, dim)`` before concatenation, so an encoder that
    returns a flat ``(dim,)`` vector still yields an ``(N, dim)`` result rather
    than a single flattened vector.

    Args:
        text_encoder: Callable mapping one text to an embedding (tensor, numpy
            array, or anything ``torch.tensor`` accepts). If it exposes
            ``model_name`` containing "roberta", the chunk size is capped at 4;
            if it exposes ``model``, that model is moved to ``device``.
        texts: Sequence of texts to encode.
        batch_size: Number of texts handled per chunk.
        device: Device on which embeddings are produced and returned.

    Returns:
        A tensor with one row per text. For empty ``texts`` an empty float32
        tensor is returned, shaped ``(0, text_encoder.dim)`` when the encoder
        exposes ``dim`` and ``(0,)`` otherwise.

    Raises:
        ValueError: If ``text_encoder`` is None while ``texts`` is non-empty.
    """
    if not texts:
        if text_encoder is not None and hasattr(text_encoder, 'dim'):
            return torch.empty(0, text_encoder.dim, dtype=torch.float32, device=device)
        return torch.empty(0, dtype=torch.float32, device=device)

    if text_encoder is None:
        raise ValueError("text_encoder is None. This should not happen with proper BOW/encoder type handling.")

    # Reduce batch size for large models like roberta
    if hasattr(text_encoder, 'model_name') and 'roberta' in text_encoder.model_name.lower():
        batch_size = min(batch_size, 4)

    # Ensure text encoder model is on target device
    if hasattr(text_encoder, 'model') and hasattr(text_encoder.model, 'parameters'):
        text_encoder.model.to(device)

    embeddings = []
    for start in range(0, len(texts), batch_size):
        with torch.no_grad():
            for text in texts[start:start + batch_size]:
                emb = text_encoder(text)
                if torch.is_tensor(emb):
                    emb = emb.to(device)
                elif isinstance(emb, np.ndarray):
                    emb = torch.from_numpy(emb).float().to(device)
                else:
                    # Convert other types to tensor on target device
                    emb = torch.tensor(emb, dtype=torch.float32, device=device)

                # Guarantee a leading "row" axis so torch.cat stacks one row
                # per text instead of flattening 1-D embeddings together.
                if emb.dim() < 2:
                    emb = emb.reshape(1, -1)
                embeddings.append(emb)

        # Clean up memory periodically
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # texts is non-empty here, so there is at least one embedding to join.
    return torch.cat(embeddings, dim=0)


class BertGCNWrapper(ModelWrapper):
    """TextAttack model wrapper for the transductive setting.

    For every queried text the wrapper encodes the text, writes the embedding
    into the feature rows of ``target_nodes`` in a copy of the graph's feature
    matrix, runs the GCN on the full graph and returns the softmax
    probabilities of the target nodes.
    """
    def __init__(self, text_encoder, gcn_model, data, target_nodes, device='cuda', dataset=None, emb_type=None):
        """Store the models and graph, moving them to ``device``.

        Args:
            text_encoder: Text encoder; may be None when ``emb_type`` is "bow".
            gcn_model: Model called as ``gcn_model(x, edge_index)``.
            data: Graph object exposing ``x``, ``edge_index`` and ``to(device)``.
            target_nodes: Indices of the node(s) whose features are replaced.
            device: Device used for all computation.
            dataset: Dataset name, required for BOW encoding.
            emb_type: Embedding type; "bow" selects BOW encoding.
        """
        self.text_encoder = text_encoder
        self.gcn_model = gcn_model
        self.data = data.to(device)  # Ensure data is on the target device
        self.target_nodes = target_nodes
        self.device = device
        self.dataset = dataset  # Needed for BOW encoding
        self.emb_type = emb_type  # Selects the encoding path
        self.cache = {}  # Maps a tuple of queried texts to its prediction tensor

        # Move models to target device - handle None text_encoder for BOW
        if text_encoder is not None and hasattr(text_encoder, 'model'):
            text_encoder.model.to(device)
        gcn_model.to(device)

        # Required by TextAttack - point to the GCN model
        self.model = gcn_model

    def __call__(self, text_list):
        """Return target-node class probabilities for each text in ``text_list``.

        Args:
            text_list: A single string or a sequence of strings.

        Returns:
            Tensor of softmax probabilities; the rows for each text (one per
            target node) are concatenated along dim 0.

        Raises:
            ValueError: If BOW encoding is requested without a dataset name,
                or a non-BOW encoding is requested without a text encoder.
        """
        if isinstance(text_list, str):
            text_list = [text_list]

        # Key on the texts themselves: a stringified hash could collide and
        # silently return the predictions of a different input.
        cache_key = tuple(text_list)
        if cache_key in self.cache:
            return self.cache[cache_key]

        with torch.no_grad():
            if self.emb_type == "bow":
                if self.dataset is None:
                    raise ValueError("Dataset name required for BOW encoding")
                x_perturbed = encode_texts_bow(text_list, self.dataset, device=self.device)
            else:
                if self.text_encoder is None:
                    raise ValueError(f"Text encoder cannot be None for embedding type: {self.emb_type}")
                x_perturbed = encode_texts_batch(self.text_encoder, text_list, batch_size=4, device=self.device)

            predictions = []
            for i in range(len(text_list)):
                # Only the feature matrix is modified, so cloning it is enough;
                # deep-copying the whole graph per query is unnecessary work.
                x = self.data.x.clone()
                if i < len(x_perturbed):
                    for node_idx in self.target_nodes:
                        x[node_idx] = x_perturbed[i]

                logits = self.gcn_model(x, self.data.edge_index)
                probs = torch.softmax(logits, dim=-1)
                predictions.append(probs[self.target_nodes])

            result = torch.cat(predictions, dim=0)

            # Cache result (kept on the compute device)
            self.cache[cache_key] = result

            return result


class BertGCNInductiveWrapper(ModelWrapper):
    """TextAttack model wrapper for the inductive setting (single target node).

    For every queried text the wrapper encodes the text, writes the embedding
    into the feature row of ``target_node_idx`` in a copy of the test graph's
    feature matrix, runs the GCN and returns the softmax probabilities of that
    node.
    """
    def __init__(self, text_encoder, gcn_model, test_data, target_node_idx, device='cuda', dataset=None, emb_type=None):
        """Store the models and test graph, moving them to ``device``.

        Args:
            text_encoder: Text encoder; may be None when ``emb_type`` is "bow".
            gcn_model: Model called as ``gcn_model(x, edge_index)``.
            test_data: Graph object exposing ``x``, ``edge_index`` and ``to(device)``.
            target_node_idx: Index of the attacked node; must support ``.to(device)``.
            device: Device used for all computation.
            dataset: Dataset name, required for BOW encoding.
            emb_type: Embedding type; "bow" selects BOW encoding.
        """
        self.text_encoder = text_encoder
        self.gcn_model = gcn_model
        self.test_data = test_data.to(device)
        self.target_node_idx = target_node_idx.to(device)
        self.device = device
        self.dataset = dataset  # Needed for BOW encoding
        self.emb_type = emb_type  # Selects the encoding path
        self.cache = {}  # Maps a tuple of queried texts to its prediction tensor

        # Move models to target device - handle None text_encoder for BOW
        if text_encoder is not None and hasattr(text_encoder, 'model'):
            text_encoder.model.to(device)
        gcn_model.to(device)

        # Required by TextAttack - point to the GCN model
        self.model = gcn_model

    def __call__(self, text_list):
        """Return the target node's class probabilities for each text.

        Args:
            text_list: A single string or a sequence of strings.

        Returns:
            Tensor with one row of softmax probabilities per text.

        Raises:
            ValueError: If BOW encoding is requested without a dataset name,
                or a non-BOW encoding is requested without a text encoder.
        """
        if isinstance(text_list, str):
            text_list = [text_list]

        # Key on the texts themselves: a stringified hash could collide and
        # silently return the predictions of a different input.
        cache_key = tuple(text_list)
        if cache_key in self.cache:
            return self.cache[cache_key]

        with torch.no_grad():
            if self.emb_type == "bow":
                if self.dataset is None:
                    raise ValueError("Dataset name required for BOW encoding")
                x_perturbed = encode_texts_bow(text_list, self.dataset, device=self.device)
            else:
                if self.text_encoder is None:
                    raise ValueError(f"Text encoder cannot be None for embedding type: {self.emb_type}")
                x_perturbed = encode_texts_batch(self.text_encoder, text_list, batch_size=4, device=self.device)

            predictions = []
            for i in range(len(text_list)):
                # Only the feature matrix is modified, so cloning it is enough;
                # deep-copying the whole test graph per query is unnecessary.
                x = self.test_data.x.clone()
                if i < len(x_perturbed):
                    x[self.target_node_idx] = x_perturbed[i]

                logits = self.gcn_model(x, self.test_data.edge_index)
                probs = torch.softmax(logits, dim=-1)
                predictions.append(probs[self.target_node_idx].unsqueeze(0))

            result = torch.cat(predictions, dim=0)

            # Cache and return
            self.cache[cache_key] = result
            return result


def apply_textattack_batch(attack_type, texts, labels, wrapper_info, text_ptb_rate=0.5, batch_size=8):
    """Attack ``texts`` one by one with a TextAttack recipe, cleaning up memory per batch.

    Each text gets its own model wrapper and attack instance (TextAttack
    attacks a single example at a time). A text whose attack raises is left
    unchanged and the error is printed; the remaining texts are still attacked.

    Args:
        attack_type: One of "textfooler", "bae", "pwws", "hotflip".
        texts: List of input texts (must support ``.copy()`` and slicing).
        labels: Ground-truth labels aligned with ``texts``; tensors or plain values.
        wrapper_info: Dict with keys 'type', 'text_encoder', 'gcn_model',
            'target_nodes', 'device', 'dataset', 'emb_type' and either
            'test_data' (type 'inductive') or 'data' (otherwise).
        text_ptb_rate: Value written to ``max_percent`` / ``modification_rate``
            on constraints that expose one of those attributes.
        batch_size: Number of texts processed between memory clean-ups.

    Returns:
        A copy of ``texts`` in which each attacked entry is replaced by the
        perturbed text returned by TextAttack.

    Raises:
        ValueError: If ``attack_type`` is unknown.
    """
    # Function-scope import: only the success bookkeeping needs this type.
    from textattack.attack_results import SuccessfulAttackResult

    print(f"Starting batch {attack_type} attack with batch size {batch_size}...")

    # Choose attack recipe
    if attack_type == "textfooler":
        attack_class = TextFoolerJin2019
    elif attack_type == "bae":
        attack_class = BAEGarg2019
    elif attack_type == "pwws":
        attack_class = PWWSRen2019
    elif attack_type == "hotflip":
        attack_class = HotFlipEbrahimi2017
    else:
        raise ValueError(f"Unknown attack type: {attack_type}")

    attacked_texts = texts.copy()
    success_count = 0

    # Process in batches
    for batch_start in tqdm(range(0, len(texts), batch_size), desc=f"Batch {attack_type} attack"):
        batch_end = min(batch_start + batch_size, len(texts))
        batch_texts = texts[batch_start:batch_end]
        batch_labels = labels[batch_start:batch_end]

        # Attack each text in the batch (TextAttack still requires individual attacks)
        for offset, (text, label) in enumerate(zip(batch_texts, batch_labels)):
            text_idx = batch_start + offset
            try:
                # Create single-node wrapper for this text
                if wrapper_info['type'] == 'inductive':
                    single_wrapper = BertGCNInductiveWrapper(
                        wrapper_info['text_encoder'],
                        wrapper_info['gcn_model'],
                        wrapper_info['test_data'],
                        wrapper_info['target_nodes'][text_idx],
                        wrapper_info['device'],
                        wrapper_info['dataset'],
                        wrapper_info['emb_type']
                    )
                else:
                    single_wrapper = BertGCNWrapper(
                        wrapper_info['text_encoder'],
                        wrapper_info['gcn_model'],
                        wrapper_info['data'],
                        wrapper_info['target_nodes'][text_idx:text_idx + 1],
                        wrapper_info['device'],
                        wrapper_info['dataset'],
                        wrapper_info['emb_type']
                    )

                attack = attack_class.build(single_wrapper)

                # Configure perturbation budget on constraints that support it
                if hasattr(attack, 'constraints'):
                    for constraint in attack.constraints:
                        if hasattr(constraint, 'max_percent'):
                            constraint.max_percent = text_ptb_rate
                        elif hasattr(constraint, 'modification_rate'):
                            constraint.modification_rate = text_ptb_rate

                result = attack.attack(text, label.item() if torch.is_tensor(label) else label)
                perturbed = result.perturbed_text()
                if perturbed:
                    attacked_texts[text_idx] = perturbed
                # Failed and skipped results also carry a perturbed text, so
                # success must be judged by the result type, not by the text.
                if isinstance(result, SuccessfulAttackResult):
                    success_count += 1

                # Cleanup
                del single_wrapper, attack

            except Exception as e:
                # Deliberate best effort: keep the original text and move on.
                print(f"Error attacking text {text_idx}: {e}")

        # Cleanup and memory management
        torch.cuda.empty_cache()
        gc.collect()

    total = len(texts)
    # Guard the ratio so an empty input does not raise ZeroDivisionError.
    success_rate = success_count / total if total else 0.0
    print(f"Batch attack success rate: {success_count}/{total} = {success_rate:.3f}")
    return attacked_texts


def apply_textattack(attack_type, texts, labels, wrapper_info, text_ptb_rate=0.5):
    """Attack ``texts`` serially with a TextAttack recipe, freeing memory after each text.

    Each text gets its own model wrapper and attack instance. Unlike the
    batch variant, an exception raised by an attack propagates to the caller.

    Args:
        attack_type: One of "textfooler", "bae", "pwws", "hotflip".
        texts: List of input texts (must support ``.copy()``).
        labels: Ground-truth labels aligned with ``texts``; tensors or plain values.
        wrapper_info: Dict with keys 'type', 'text_encoder', 'gcn_model',
            'target_nodes', 'device', 'dataset', 'emb_type' and either
            'test_data' (type 'inductive') or 'data' (otherwise).
        text_ptb_rate: Value written to ``max_percent`` / ``modification_rate``
            on constraints that expose one of those attributes.

    Returns:
        A copy of ``texts`` in which each entry is replaced by the perturbed
        text returned by TextAttack.

    Raises:
        ValueError: If ``attack_type`` is unknown.
    """
    # Function-scope import: only the success bookkeeping needs this type.
    from textattack.attack_results import SuccessfulAttackResult

    # Choose attack recipe
    if attack_type == "textfooler":
        attack_class = TextFoolerJin2019
    elif attack_type == "bae":
        attack_class = BAEGarg2019
    elif attack_type == "pwws":
        attack_class = PWWSRen2019
    elif attack_type == "hotflip":
        attack_class = HotFlipEbrahimi2017
    else:
        raise ValueError(f"Unknown attack type: {attack_type}")

    # Attack all provided texts
    attacked_texts = texts.copy()
    success_count = 0

    for i in tqdm(range(len(texts)), desc=f"Applying {attack_type} attack"):
        # Create individual wrapper for this node
        if wrapper_info['type'] == 'inductive':
            wrapper = BertGCNInductiveWrapper(
                wrapper_info['text_encoder'],
                wrapper_info['gcn_model'],
                wrapper_info['test_data'],
                wrapper_info['target_nodes'][i],
                wrapper_info['device'],
                wrapper_info['dataset'],
                wrapper_info['emb_type']
            )
        else:
            wrapper = BertGCNWrapper(
                wrapper_info['text_encoder'],
                wrapper_info['gcn_model'],
                wrapper_info['data'],
                wrapper_info['target_nodes'][i:i+1],
                wrapper_info['device'],
                wrapper_info['dataset'],
                wrapper_info['emb_type']
            )

        # Build attack for this specific wrapper
        attack = attack_class.build(wrapper)

        # Configure perturbation budget on constraints that support it
        if hasattr(attack, 'constraints'):
            for constraint in attack.constraints:
                if hasattr(constraint, 'max_percent'):
                    constraint.max_percent = text_ptb_rate
                elif hasattr(constraint, 'modification_rate'):
                    constraint.modification_rate = text_ptb_rate

        # Accept tensor labels as well as plain values, like the batch variant.
        label = labels[i]
        result = attack.attack(texts[i], label.item() if torch.is_tensor(label) else label)
        perturbed = result.perturbed_text()
        if perturbed:
            attacked_texts[i] = perturbed
        # Failed and skipped results also carry a perturbed text, so success
        # must be judged by the result type, not by the text.
        if isinstance(result, SuccessfulAttackResult):
            success_count += 1

        # Clean up memory after each attack
        del wrapper, attack
        torch.cuda.empty_cache()
        gc.collect()

    total = len(texts)
    # Guard the ratio so an empty input does not raise ZeroDivisionError.
    success_rate = success_count / total if total else 0.0
    print(f"Attack success rate: {success_count}/{total} = {success_rate:.3f}")
    return attacked_texts


def apply_text_attack(attack_type, texts, labels, text_encoder=None, gcn_model=None, 
                     data=None, target_nodes=None, text_ptb_rate=0.5, device='cuda', **kwargs):
    """Dispatch to the requested text attack and return the attacked texts.

    "random" perturbs every text with ``random_perturb``. The TextAttack
    recipes ("textfooler", "bae", "pwws", "hotflip") are run through
    ``apply_textattack_batch`` or, when ``use_batch`` is falsy, through
    ``apply_textattack``.

    Recognised keyword arguments: ``inductive`` (default False), ``test_data``
    (required when inductive), ``dataset``, ``emb_type``, ``use_batch``
    (default True) and ``batch_size`` (default 8).

    Raises:
        ValueError: If ``attack_type`` is not one of the supported names.
    """
    if attack_type == "random":
        return [random_perturb(sample, text_ptb_rate) for sample in texts]

    if attack_type not in ("textfooler", "bae", "pwws", "hotflip"):
        raise ValueError(f"Attack method {attack_type} not implemented")

    # Describe how the per-text model wrappers must be built.
    inductive = bool(kwargs.get('inductive', False))
    if inductive:
        print("Using inductive attack")
        wrapper_info = {
            'type': 'inductive',
            'text_encoder': text_encoder,
            'gcn_model': gcn_model,
            'test_data': kwargs['test_data'],
        }
    else:
        print("Using transductive attack")
        wrapper_info = {
            'type': 'transductive',
            'text_encoder': text_encoder,
            'gcn_model': gcn_model,
            'data': data,
        }
    wrapper_info['target_nodes'] = target_nodes
    wrapper_info['device'] = device
    wrapper_info['dataset'] = kwargs.get('dataset')
    wrapper_info['emb_type'] = kwargs.get('emb_type')

    # Pick the batched or the serial driver.
    run_batched = kwargs.get('use_batch', True)
    batch_size = kwargs.get('batch_size', 8)
    if not run_batched:
        print("Using serial attack")
        return apply_textattack(attack_type, texts, labels, wrapper_info, text_ptb_rate)

    print(f"Using batch attack with batch size {batch_size}")
    return apply_textattack_batch(attack_type, texts, labels, wrapper_info, text_ptb_rate, batch_size)


def save_results(args, clean_accs, attacked_accs, setting="transductive"):
    """Write clean/attacked accuracy statistics to a timestamped JSON file and print a summary.

    The file is created under
    ``./logs/<args.dataset>/text_<args.attack>_<setting>/`` and named
    ``results_<args.emb_type>_<ptb percent>_<timestamp>.json``.

    Args:
        args: Namespace-like object; ``vars(args)`` is stored in the JSON and
            ``dataset``, ``attack``, ``emb_type`` and ``ptb_rate`` are read.
        clean_accs: Per-run accuracies (fractions) without attack.
        attacked_accs: Per-run accuracies (fractions) under attack.
        setting: Label used in the directory name and the printed header.
    """
    def _summarize(accs):
        # Report percentages rounded to two decimals.
        return {
            "mean": round(float(np.mean(accs)) * 100, 2),
            "std": round(float(np.std(accs)) * 100, 2),
            "all_runs": [round(float(acc) * 100, 2) for acc in accs]
        }

    results = {
        "args": vars(args),
        "clean": _summarize(clean_accs),
        "attacked": _summarize(attacked_accs),
    }

    log_dir = os.path.join("./logs", args.dataset, f"text_{args.attack}_{setting}")
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # round() instead of int(): 0.29 * 100 is 28.999999999999996 in floating
    # point, which int() would truncate to 28 and mislabel the file.
    ptb_percent = round(args.ptb_rate * 100)
    log_path = os.path.join(log_dir, f"results_{args.emb_type}_{ptb_percent}_{timestamp}.json")

    with open(log_path, 'w') as f:
        json.dump(results, f, indent=4)

    print(f"\n{setting.capitalize()} Text Attack Results Summary:")
    print(f"Clean - Acc: {results['clean']['mean']:.2f} ± {results['clean']['std']:.2f}")
    print(f"Attacked - Acc: {results['attacked']['mean']:.2f} ± {results['attacked']['std']:.2f}")


def encode_attacked_texts_efficiently(text_encoder, original_texts, attacked_texts, 
                                    target_nodes, batch_size=4, dataset=None, emb_type=None, device='cuda'):
    """Re-encode only those texts that the attack actually modified.

    Returns:
        ``(new_embeddings, changed_indices, changed_nodes)`` where
        ``changed_indices`` are positions in the input lists whose text
        differs, ``changed_nodes`` the matching entries of ``target_nodes``
        and ``new_embeddings`` the encodings of the changed texts on
        ``device``. When nothing changed, ``(None, [], [])`` is returned.

    Raises:
        ValueError: If ``emb_type`` is "bow" and ``dataset`` is None.
    """
    # One pass over the pairs, keeping everything needed for a modified text.
    modified = [
        (position, after, target_nodes[position])
        for position, (before, after) in enumerate(zip(original_texts, attacked_texts))
        if before != after
    ]
    if not modified:
        return None, [], []

    changed_indices, changed_texts, changed_nodes = (list(column) for column in zip(*modified))

    if emb_type != "bow":
        # Encoder path: encode in small batches for memory efficiency.
        new_embeddings = encode_texts_batch(text_encoder, changed_texts, batch_size, device=device)
    elif dataset is None:
        raise ValueError("Dataset name required for BOW encoding")
    else:
        new_embeddings = encode_texts_bow(changed_texts, dataset, device=device)

    return new_embeddings, changed_indices, changed_nodes


def degree_weighted_node_selection(full_data, test_data, num_nodes, device, atk_type='inductive', 
                                   model=None, only_correct_predictions=True):
    """Sample candidate nodes with probability inversely proportional to their out-degree.

    Candidates are the nodes flagged by ``full_data.test_mask`` when
    ``atk_type`` is 'inductive' and by ``full_data.train_mask`` for any other
    value. Optionally only candidates the model classifies correctly are kept.

    Args:
        full_data: Graph with ``x``, ``edge_index``, ``y``, ``test_mask`` and
            ``train_mask``; degrees are computed from its ``edge_index``.
        test_data: Graph used for the correctness check in the inductive case.
        num_nodes: Number of nodes to select; clamped to the number of
            available candidates.
        device: Device on which degrees and sampling are computed.
        atk_type: 'inductive' or anything else (treated as transductive).
        model: Optional model called as ``model(x, edge_index)``; when given
            and ``only_correct_predictions`` is True, misclassified candidates
            are dropped.
        only_correct_predictions: Whether to apply the correctness filter.

    Returns:
        ``(selected_global_indices, selected_local_indices)``: node indices in
        ``full_data`` and the positions of those nodes within the unfiltered
        candidate list (the nodes selected by the mask). Both are empty long
        tensors when nothing can be selected.
    """
    def _no_selection():
        # Index tensors must be integer typed so callers can index with them.
        return torch.empty(0, dtype=torch.long, device=device)

    # Out-degree of every node: how often it appears as an edge source.
    degrees = torch.bincount(
        full_data.edge_index[0], minlength=full_data.x.shape[0]
    ).to(device=device, dtype=torch.float32)

    candidate_mask = full_data.test_mask if atk_type == 'inductive' else full_data.train_mask
    target_node_indices = torch.where(candidate_mask)[0]
    target_degrees = degrees[target_node_indices]
    # Position of each candidate in the unfiltered candidate list; filtered in
    # lock-step below so the returned local indices keep referring to it.
    local_positions = torch.arange(target_node_indices.numel(), device=target_node_indices.device)

    # Filter for correctly classified nodes if requested
    if only_correct_predictions and model is not None:
        eval_data = test_data if atk_type == 'inductive' else full_data
        with torch.no_grad():
            predictions = model(eval_data.x, eval_data.edge_index).argmax(dim=1)
        correct_predictions = predictions[target_node_indices] == eval_data.y[target_node_indices]

        target_node_indices = target_node_indices[correct_predictions]
        target_degrees = target_degrees[correct_predictions]
        local_positions = local_positions[correct_predictions]

        print(f"Filtered to {len(target_node_indices)} correctly classified nodes out of {len(correct_predictions)} total nodes")

        if len(target_node_indices) == 0:
            print("Warning: No correctly classified nodes found!")
            return _no_selection(), _no_selection()

    # Never request more nodes than available: sampling without replacement
    # would raise otherwise.
    num_nodes = min(num_nodes, len(target_node_indices))
    if num_nodes <= 0:
        return _no_selection(), _no_selection()

    # Inverse weighting: lower degree = higher probability.
    # The small epsilon avoids division by zero for isolated nodes.
    inverse_weights = 1.0 / (target_degrees + 1e-8)
    probabilities = inverse_weights / inverse_weights.sum()

    sampled = torch.multinomial(probabilities, num_nodes, replacement=False)
    selected_global_indices = target_node_indices[sampled]
    selected_local_indices = local_positions[sampled]

    return selected_global_indices, selected_local_indices