import random
from typing import Optional
import numpy as np
from tqdm.auto import tqdm

from greatx.attack.untargeted.dice_attack import DICEAttack


class STRGAttack(DICEAttack):
    r"""Implementation of `STRG` attack based on the DICE principle
    with training-aware edge modifications.

    STRG conducts attacks by modifying edges while ensuring:

    1. Each modified edge connects to at least one training node
    2. Node selection considers degree (lower degree = higher probability)
    3. Follows the DICE principle of "Disconnect Internally, Connect
       Externally": only edges between nodes of the same label are removed
       and only edges between nodes of different labels are added

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be :obj:`__class__.__name__`,
        by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        from greatx.attack.untargeted import STRGAttack
        attacker = STRGAttack(data)
        attacker.reset()

        # train_mask is required; val_mask / test_mask are optional
        attacker.attack(0.05, train_mask=data.train_mask,
                        val_mask=data.val_mask, test_mask=data.test_mask)
        attacker.data() # get attacked graph

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    """

    # Number of random (u, v) proposals tried for a single perturbation
    # before that perturbation is given up on. Override on the class or on
    # an instance to trade speed for a higher chance of finding a valid edge.
    max_attempts = 100

    def attack(self, num_budgets=0.05, *, threshold=0.5, structure_attack=True,
               feature_attack=False, disable=False, train_mask=None,
               val_mask=None, test_mask=None):
        """Conduct STRG attack with training-aware edge modifications.

        Parameters
        ----------
        num_budgets : float or int
            the number/percentage of perturbations allowed
        threshold : float
            probability that a perturbation is an edge addition rather than
            an edge removal; must lie in ``(0, 1]``
        structure_attack : bool
            accepted for interface compatibility; not read by this method,
            which only performs edge perturbations
        feature_attack : bool
            accepted for interface compatibility; not read by this method
        disable : bool
            whether to disable progress bar
        train_mask : torch.Tensor
            boolean mask indicating training nodes (required)
        val_mask : torch.Tensor, optional
            boolean mask indicating validation nodes
        test_mask : torch.Tensor, optional
            boolean mask indicating test nodes

        Returns
        -------
        STRGAttack
            the attacker itself

        Raises
        ------
        ValueError
            if ``train_mask`` is missing or selects no node, or if
            ``threshold`` is not in ``(0, 1]``

        Notes
        -----
        A perturbation for which no valid edge is found within
        :attr:`max_attempts` proposals is skipped, so fewer than
        ``num_budgets`` edges may end up modified.
        """
        # Validate everything before touching any attacker state.
        if train_mask is None:
            raise ValueError("train_mask is required for STRG attack")
        if not 0 < threshold <= 1:
            raise ValueError(
                f"threshold must be in (0, 1], got {threshold!r}")

        train_nodes = self._mask_to_nodes(train_mask)
        if not train_nodes:
            raise ValueError("train_mask does not select any node")

        # Store masks (and the node-id sets they select) for edge selection.
        self.train_mask = train_mask
        self.train_nodes = train_nodes
        self.val_mask = val_mask
        self.val_nodes = self._mask_to_nodes(val_mask)
        self.test_mask = test_mask
        self.test_nodes = self._mask_to_nodes(test_mask)

        self.num_budgets = self._check_budget(
            num_budgets, max_perturbations=self.num_edges // 2)

        # Degree-based probabilities for node selection.
        self._calculate_node_probabilities()

        # Track edges already used so the same pair is not proposed twice.
        self.used_add_edges = set()
        self.used_remove_edges = set()

        # +1 -> add an edge (probability `threshold`), -1 -> remove an edge.
        random_arr = np.random.choice(2, self.num_budgets,
                                      p=[1 - threshold, threshold]) * 2 - 1

        # Every modified edge has at least one endpoint drawn from this list.
        influence_nodes = list(self.train_nodes)

        for it, remove_or_insert in tqdm(enumerate(random_arr),
                                         total=len(random_arr),
                                         desc='Perturbing graph with STRG...',
                                         disable=disable):
            if remove_or_insert > 0:
                edge = self.get_added_edge(influence_nodes)
                if edge is None:
                    continue  # no valid edge found: skip this perturbation
                u, v = edge
                self.add_edge(u, v, it)
                self.used_add_edges.add(tuple(sorted((u, v))))
            else:
                edge = self.get_removed_edge(influence_nodes)
                if edge is None:
                    continue  # no valid edge found: skip this perturbation
                u, v = edge
                self.remove_edge(u, v, it)
                self.used_remove_edges.add(tuple(sorted((u, v))))

        return self

    @staticmethod
    def _mask_to_nodes(mask):
        """Return the set of node ids selected by ``mask`` (empty if None)."""
        if mask is None:
            return set()
        return set(mask.nonzero().flatten().tolist())

    def _calculate_node_probabilities(self):
        """Set ``self.node_probs`` to per-node sampling probabilities that
        are inversely proportional to node degree.

        Isolated nodes (degree 0) are weighted like degree-1 nodes. Adding a
        tiny epsilon to the degree instead would give them ~1e8 times the
        weight of every other node. They would then be chosen almost every
        time, which starves edge removal (they have no edge to remove) and
        concentrates edge additions on them.
        """
        adj = self.adjacency_matrix.tocsr()
        # Stored entries per row: the same count the row's ``.nnz`` gives.
        degrees = np.diff(adj.indptr)
        inv_degrees = 1.0 / np.maximum(degrees, 1)
        self.node_probs = inv_degrees / inv_degrees.sum()

    def _select_node_by_degree(self, candidate_nodes):
        """Sample one node from ``candidate_nodes`` with probability
        proportional to ``self.node_probs`` (i.e. inverse degree).

        Returns ``None`` when ``candidate_nodes`` is empty.
        """
        candidate_list = list(candidate_nodes)
        if not candidate_list:
            return None

        candidate_probs = self.node_probs[candidate_list]
        candidate_probs = candidate_probs / candidate_probs.sum()

        selected_idx = np.random.choice(len(candidate_list), p=candidate_probs)
        return candidate_list[selected_idx]

    def _source_nodes(self, influence_nodes):
        """Candidate first endpoints: ``influence_nodes``, or the training
        nodes when it is None/empty."""
        sources = list(influence_nodes) if influence_nodes is not None else []
        return sources or list(self.train_nodes)

    def get_added_edge(self, influence_nodes: list) -> Optional[tuple]:
        """Propose an edge to add ("connect externally").

        The first endpoint ``u`` is drawn from ``influence_nodes`` (the
        training nodes when called from :meth:`attack`). The second endpoint
        ``v`` is drawn from the non-neighbors of ``u`` in
        ``self.adjacency_matrix``. Both draws use inverse-degree probability.
        The pair is accepted only if it was not added before, passes
        :meth:`is_legal_edge`, and its endpoints have different labels.

        Returns
        -------
        Optional[tuple]
            ``(u, v)`` on success, or ``None`` if no valid pair was found
            within :attr:`max_attempts` proposals.
        """
        sources = self._source_nodes(influence_nodes)
        if not sources:
            return None

        for _ in range(self.max_attempts):
            u = self._select_node_by_degree(sources)

            neighbors = set(self.adjacency_matrix[u].indices.tolist())
            non_neighbors = self.nodes_set - neighbors - {u}
            if not non_neighbors:
                continue

            v = self._select_node_by_degree(non_neighbors)

            # NOTE(review): self.label is read for an arbitrary v, not only
            # training nodes, so val/test labels are consulted. Confirm this
            # is intended.
            if (tuple(sorted((u, v))) not in self.used_add_edges
                    and self.is_legal_edge(u, v)
                    and self.label[u] != self.label[v]):
                return (u, v)

        return None

    def get_removed_edge(self, influence_nodes: list) -> Optional[tuple]:
        """Propose an edge to remove ("disconnect internally").

        The first endpoint ``u`` is drawn from ``influence_nodes`` (the
        training nodes when called from :meth:`attack`). The second endpoint
        ``v`` is drawn from the neighbors of ``u`` in
        ``self.adjacency_matrix``. Both draws use inverse-degree probability.
        The pair is accepted only if it was not removed before, is not a
        singleton edge, passes :meth:`is_legal_edge`, and its endpoints share
        the same label.

        Returns
        -------
        Optional[tuple]
            ``(u, v)`` on success, or ``None`` if no valid pair was found
            within :attr:`max_attempts` proposals.
        """
        sources = self._source_nodes(influence_nodes)
        if not sources:
            return None

        for _ in range(self.max_attempts):
            u = self._select_node_by_degree(sources)

            neighbors = self.adjacency_matrix[u].indices.tolist()
            if not neighbors:
                continue

            v = self._select_node_by_degree(set(neighbors))

            if (tuple(sorted((u, v))) not in self.used_remove_edges
                    and not self.is_singleton_edge(u, v)
                    and self.is_legal_edge(u, v)
                    and self.label[u] == self.label[v]):
                return (u, v)

        return None