import torch
import copy
import random
import numpy as np


def add_structure_noise(edge_index, num_nodes, noise_ratio=0.1, seed=None, train_mask=None):
    """
    Add structural noise by inserting random undirected edges.

    Each accepted noise pair ``(u, v)`` is appended in both directions, so the
    output stays symmetric if the input was. Self-loops and pairs that already
    exist (in either direction) are rejected. Sampling stops once at least
    ``num_noise_edges`` directed entries were appended, or after
    ``10 * num_noise_edges`` attempts, whichever comes first. Because entries
    are added in pairs, an odd budget may be overshot by one.

    Args:
        edge_index: Original edge index tensor of shape [2, num_edges].
        num_nodes: Number of nodes in the graph; random endpoints are drawn
            from ``[0, num_nodes)``.
        noise_ratio: Budget of appended directed entries as a fraction of the
            relevant edge count (default: 0.1).
        seed: If given, seeds the torch, NumPy and Python ``random`` global RNGs.
        train_mask: Optional boolean node mask. When given, the budget counts
            only edges with at least one training endpoint, every noise edge
            has at least one training endpoint, and nothing is added when
            fewer than two training nodes exist.
    Returns:
        A LongTensor on ``edge_index.device`` with the original edges first and
        the noise edges appended, or ``edge_index`` itself if nothing was added.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    edge_index_np = edge_index.cpu().numpy()
    existing_edges = set(zip(edge_index_np[0], edge_index_np[1]))

    train_node_list = None
    if train_mask is not None:
        # Flatten instead of squeeze(): with exactly one training node, squeeze()
        # gives a 0-d array whose tolist() is a bare int and set() would fail.
        train_idx = train_mask.nonzero(as_tuple=False).view(-1).cpu().numpy()
        train_nodes = set(train_idx.tolist())
        # Vectorized count of edges that touch at least one training node.
        touches_train = np.isin(edge_index_np[0], train_idx) | np.isin(edge_index_np[1], train_idx)
        train_edges = int(touches_train.sum())
        num_noise_edges = int(train_edges * noise_ratio)
        print(f"Training nodes: {len(train_nodes)}, Training-related edges: {train_edges}")
        # Built once, from the set, so random.choice sees the same ordering
        # (and therefore the same seeded picks) as before.
        train_node_list = list(train_nodes)
    else:
        num_noise_edges = int(edge_index.shape[1] * noise_ratio)

    # A training-restricted edge needs at least two distinct training nodes
    # to be guaranteed sampleable; otherwise add nothing.
    can_sample = train_node_list is None or len(train_node_list) >= 2

    noise_edges = []
    attempts = 0
    max_attempts = num_noise_edges * 10  # Prevent an unbounded rejection loop

    while can_sample and len(noise_edges) < num_noise_edges and attempts < max_attempts:
        if train_node_list is not None:
            # At least one endpoint is always a training node.
            if random.random() < 0.5:
                # Both endpoints are training nodes
                src = random.choice(train_node_list)
                dst = random.choice(train_node_list)
            else:
                # One training endpoint, one arbitrary endpoint
                src = random.choice(train_node_list)
                dst = np.random.randint(0, num_nodes)
        else:
            src = np.random.randint(0, num_nodes)
            dst = np.random.randint(0, num_nodes)

        # Reject self-loops and pairs already present in either direction.
        if src != dst and (src, dst) not in existing_edges and (dst, src) not in existing_edges:
            noise_edges.append([src, dst])
            noise_edges.append([dst, src])  # keep the graph undirected
            existing_edges.add((src, dst))
            existing_edges.add((dst, src))

        attempts += 1

    if not noise_edges:
        return edge_index

    noisy_edge_index = np.concatenate([edge_index_np, np.array(noise_edges).T], axis=1)
    return torch.as_tensor(noisy_edge_index, dtype=torch.long, device=edge_index.device)


def add_text_noise(raw_texts, noise_ratio=0.1, seed=None, train_mask=None):
    """
    Corrupt node texts by permuting the texts of a randomly drawn subset.

    A fraction ``noise_ratio`` of the candidate positions is drawn without
    replacement, and the texts at those positions are randomly permuted
    among themselves. Every other position keeps its original text.

    Args:
        raw_texts: Sequence of raw text strings (must support ``.copy()``).
        noise_ratio: Fraction of candidate texts to permute (default: 0.1).
        seed: If given, seeds Python's ``random`` and NumPy's global RNG.
        train_mask: Optional boolean node mask. When given, only training
            positions are candidates and a one-line summary is printed.
    Returns:
        A copy of ``raw_texts`` with the selected texts permuted.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Work on a copy so the caller's list is never mutated.
    result = raw_texts.copy()

    if train_mask is None:
        candidates = range(len(raw_texts))
    else:
        picked = train_mask.nonzero(as_tuple=False).squeeze().cpu().numpy().tolist()
        # A single selected node squeezes down to a scalar; re-wrap it.
        candidates = [picked] if isinstance(picked, int) else picked

    pool_size = len(candidates)
    budget = int(pool_size * noise_ratio)

    if budget > 0:
        chosen = random.sample(candidates, min(budget, pool_size))
        permuted = [result[pos] for pos in chosen]
        random.shuffle(permuted)
        for pos, text in zip(chosen, permuted):
            result[pos] = text

    if train_mask is not None:
        print(f"Applied text noise to {budget}/{pool_size} training texts")

    return result


def apply_noise_to_graph_data(graph_data, prompt_type=None, noise_ratio=0.1, seed=None):
    """
    Return a deep copy of ``graph_data`` with noise injected per ``prompt_type``.

    Args:
        graph_data: PyTorch Geometric Data-like object. ``edge_index`` and
            ``num_nodes`` are read for structural noise, ``raw_texts`` for text
            noise; an optional ``train_mask`` attribute restricts both.
        prompt_type: ``"noise"`` (structure), ``"noisetxt"`` (text) or
            ``"noisefull"`` (both). Any other value, including ``None``,
            yields an unmodified deep copy.
        noise_ratio: Ratio forwarded to the noise helpers.
        seed: Seed forwarded to each noise helper (each helper reseeds).
    Returns:
        The noisy deep copy; ``graph_data`` itself is left untouched.
    """
    labels = {
        "noise": "structural noise",
        "noisetxt": "text noise",
        "noisefull": "both structural and text noise",
    }
    structural = prompt_type in ("noise", "noisefull")
    textual = prompt_type in ("noisetxt", "noisefull")

    result = copy.deepcopy(graph_data)

    # A train_mask attribute indicates the transductive setting.
    train_mask = getattr(graph_data, 'train_mask', None)

    if not (structural or textual):
        return result

    target = "training nodes only" if train_mask is not None else "graph data"
    print(f"Applying {labels[prompt_type]} (ratio={noise_ratio}) to {target}...")

    # Structure first, then text: matches the original call (and RNG) order.
    if structural:
        result.edge_index = add_structure_noise(
            graph_data.edge_index, graph_data.num_nodes, noise_ratio, seed, train_mask
        )
    if textual:
        result.raw_texts = add_text_noise(graph_data.raw_texts, noise_ratio, seed, train_mask)

    if structural:
        print(f"Original edges: {graph_data.edge_index.shape[1]}, "
              f"Noisy edges: {result.edge_index.shape[1]}")
    if textual and train_mask is None:
        print(f"Applied text noise to {len(graph_data.raw_texts)} texts")

    return result


def apply_noise_to_inductive_data(train_data, val_data, test_data, full_graph_data, prompt_type=None, noise_ratio=0.1, seed=None):
    """
    Inject noise into the training split only, for the inductive setting.

    Args:
        train_data, val_data, test_data: Split data objects. Only
            ``train_data`` is read for noising (``edge_index``, ``num_nodes``,
            ``raw_texts``).
        full_graph_data: Full graph data object; deep-copied, never modified.
        prompt_type: ``"noise"`` (structure), ``"noisetxt"`` (text) or
            ``"noisefull"`` (both). ``None`` returns the inputs as-is (not
            copied); any other value returns unmodified deep copies.
        noise_ratio: Ratio forwarded to the noise helpers.
        seed: Seed forwarded to each noise helper (each helper reseeds).
    Returns:
        (noisy_train_data, clean_val_data, clean_test_data, updated_full_graph_data)
        where every element except the first is an unmodified deep copy.
    """
    if prompt_type is None:
        return train_data, val_data, test_data, full_graph_data

    train_out = copy.deepcopy(train_data)
    val_out = copy.deepcopy(val_data)
    test_out = copy.deepcopy(test_data)
    full_out = copy.deepcopy(full_graph_data)

    labels = {
        "noise": "structural noise",
        "noisetxt": "text noise",
        "noisefull": "both structural and text noise",
    }
    structural = prompt_type in ("noise", "noisefull")
    textual = prompt_type in ("noisetxt", "noisefull")

    if structural or textual:
        print(f"Applying {labels[prompt_type]} (ratio={noise_ratio}) to training data...")

    # Splits have no train_mask here: the whole training split is eligible.
    if structural:
        train_out.edge_index = add_structure_noise(
            train_data.edge_index, train_data.num_nodes, noise_ratio, seed, train_mask=None
        )
    if textual:
        train_out.raw_texts = add_text_noise(train_data.raw_texts, noise_ratio, seed, train_mask=None)

    if structural:
        print(f"Train edges: {train_data.edge_index.shape[1]} -> {train_out.edge_index.shape[1]}")
    if textual:
        print("Applied text noise to training texts")

    return train_out, val_out, test_out, full_out