"""
WTGIA (Word-level Text Graph Injection Attack) Utilities
Clean and efficient implementation based on LLMGIAv2-crc FLIPGIA
"""

import random
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch_sparse import SparseTensor, fill_diag, mul, sum as sparsesum


# ============================================================================
# Core Classes
# ============================================================================

class EarlyStop:
    """
    Patience-based early stopping for a score that should decrease (e.g. a loss).

    A call counts as an improvement only when the score is lower than the best
    score seen so far by MORE than ``epsilon``. Improvements update the best
    score and reset the counter. Every other call increments the counter, and
    ``stop`` becomes True once the counter exceeds ``patience``.
    """

    def __init__(self, patience: int = 100, epsilon: float = 1e-4):
        self.patience = patience  # non-improving calls tolerated before stopping
        self.epsilon = epsilon    # minimum decrease that counts as an improvement
        self.min_score = None     # best (lowest) score so far; None until first call
        self.stop = False         # becomes True once patience is exhausted
        self.count = 0            # non-improving calls since the last improvement

    def __call__(self, score: float) -> None:
        """Record ``score`` and update the stopping state."""
        if self.min_score is None:
            self.min_score = score
            return

        # Gains of at most `epsilon` (and NaN scores) count as stagnation,
        # so `epsilon` acts as a tolerance rather than being ignored.
        if self.min_score - score > self.epsilon:
            self.min_score = score
            self.count = 0
            return

        self.count += 1
        if self.count > self.patience:
            self.stop = True

    def reset(self) -> None:
        """Clear all tracking state so the instance can be reused."""
        self.min_score = None
        self.stop = False
        self.count = 0


# ============================================================================
# Tensor/Matrix Conversion Utilities
# ============================================================================

def adj_to_tensor(adj: sp.spmatrix) -> SparseTensor:
    """
    Wrap a scipy sparse matrix as a torch_sparse ``SparseTensor``.

    The matrix is converted to COO form. Row/column indices become int64
    tensors, values become float32, and the sparse sizes equal ``adj.shape``.

    Args:
        adj: any scipy sparse matrix

    Returns:
        SparseTensor with the same sparsity pattern and values

    Raises:
        ValueError: if ``adj`` is not a scipy sparse matrix
    """
    if sp.issparse(adj):
        coo = adj.tocoo()
        n_rows, n_cols = coo.shape
        # torch.from_numpy avoids the copy/compat issues of torch.tensor(...) with numpy 2.x
        return SparseTensor(
            row=torch.from_numpy(coo.row.astype(np.int64)),
            col=torch.from_numpy(coo.col.astype(np.int64)),
            value=torch.from_numpy(coo.data.astype(np.float32)),
            sparse_sizes=(n_rows, n_cols),
        )
    raise ValueError("Input must be a scipy sparse matrix")


def tensor_to_adj(
    adj_tensor: Union[SparseTensor, torch.Tensor],
    num_nodes: Optional[int] = None,
) -> sp.csr_matrix:
    """
    Convert a PyTorch adjacency representation to a scipy CSR matrix.

    Supported inputs:
      * torch_sparse ``SparseTensor``: its stored values are kept (all ones if
        it stores none), and the output has its sparse sizes.
      * torch sparse COO tensor: coalesced, so its values and size are kept.
      * ``edge_index`` tensor of shape ``[2, num_edges]``: every edge gets weight 1.

    Args:
        adj_tensor: adjacency in one of the formats above
        num_nodes: only used for ``edge_index`` input. It is the side length of
            the square output. Defaults to ``max index + 1`` (0 for an empty
            edge list). That default drops trailing isolated nodes, so pass it
            explicitly when such nodes may exist.

    Returns:
        scipy sparse CSR matrix (duplicate coordinates are summed by scipy)
    """
    if isinstance(adj_tensor, SparseTensor):
        row, col, value = adj_tensor.coo()
        row_np = row.detach().cpu().numpy()
        col_np = col.detach().cpu().numpy()

        # A SparseTensor without explicit values represents an unweighted graph
        if value is None:
            value_np = np.ones(len(row_np), dtype=np.float32)
        else:
            value_np = value.detach().cpu().numpy()

        shape = (adj_tensor.size(0), adj_tensor.size(1))

    elif adj_tensor.is_sparse:
        # torch.sparse_coo_tensor: indexing it like an edge_index would be wrong
        coalesced = adj_tensor.detach().coalesce().cpu()
        indices = coalesced.indices().numpy()
        row_np, col_np = indices[0], indices[1]
        value_np = coalesced.values().numpy()
        shape = (coalesced.size(0), coalesced.size(1))

    else:
        # edge_index format [2, num_edges]
        row_np = adj_tensor[0].detach().cpu().numpy()
        col_np = adj_tensor[1].detach().cpu().numpy()
        value_np = np.ones(len(row_np), dtype=np.float32)
        if num_nodes is None:
            # Guard: .max() of an empty array raises
            num_nodes = int(max(row_np.max(), col_np.max())) + 1 if row_np.size else 0
        shape = (num_nodes, num_nodes)

    return sp.csr_matrix((value_np, (row_np, col_np)), shape=shape)


# ============================================================================
# Feature Initialization
# ============================================================================

def init_feat(
    num: int,
    features: torch.Tensor,
    device: str,
    style: str = "zeros",
    feat_lim_min: float = 0,
    feat_lim_max: float = 1
) -> torch.Tensor:
    """
    Initialize features for injected nodes.

    Args:
        num: number of rows (injected nodes) to create
        features: original feature matrix; only its row count and column count
            are used, plus its rows for ``style="sample"``
        device: device the returned tensor is placed on
        style: one of
            ``"sample"``: copies of uniformly drawn rows of ``features`` (with replacement);
            ``"normal"``: i.i.d. standard normal entries;
            ``"zeros"``: all zeros;
            ``"ball"``: random directions normalized to L2 norm 1 per row;
            anything else: uniform in ``[feat_lim_min, feat_lim_max)``.
        feat_lim_min/max: bounds used ONLY by the uniform fallback style

    Returns:
        Tensor of shape ``(num, features.size(1))`` on ``device``
    """
    feat_dim = features.size(1)

    if style == "sample":
        # Draw indices on the features' own device, because indexing with
        # indices on another device fails. Then move the copy to `device`.
        indices = torch.randint(0, features.size(0), (num,), device=features.device)
        return features[indices].clone().to(device)

    elif style == "normal":
        return torch.randn(num, feat_dim, device=device)

    elif style == "zeros":
        return torch.zeros(num, feat_dim, device=device)

    elif style == "ball":
        # Normalize Gaussian directions to unit L2 norm (points on the unit sphere)
        directions = torch.randn(num, feat_dim, device=device)
        return directions / torch.norm(directions, p=2, dim=1, keepdim=True)

    else:
        # Random uniform in [feat_lim_min, feat_lim_max)
        return torch.rand(num, feat_dim, device=device) * (feat_lim_max - feat_lim_min) + feat_lim_min

# ============================================================================
# Feature Optimization (FGSM for Binary Features)
# ============================================================================


def fgsm_update_features(attacker, model, adj_attack, features, features_attack, origin_labels, target_idx, 
                         sparsity_budget=0.2, batch_size=1, verbose=False):
    """
    Greedily flip entries of the injected-node features to increase the attack loss.

    Each iteration does the following:
      1. Computes the gradient of ``attacker.loss`` (mean over ``target_idx``)
         w.r.t. ``features_attack``.
      2. Scores every entry by its first-order gain: a 0-entry with positive
         gradient scores ``grad``, a 1-entry with non-positive gradient scores
         ``-grad``, all others score 0. Rows that already hold
         ``int(sparsity_budget * D)`` ones are masked to 0.
      3. Flips the ``batch_size`` top-scoring entries, as long as the entry's
         row is still under budget. Ties at 0 may include flips with no
         first-order gain.

    The loop ends when every row holds the budgeted number of ones, or when
    the total number of ones has not grown over the last 9 iterations
    (entries can keep flipping back and forth).

    Args:
        attacker: object exposing ``loss(pred, labels, reduction="none")``
        model: callable ``model(features, dense_adj)`` returning per-node predictions
        adj_attack: adjacency over original + injected nodes. It is converted
            with ``.to_dense()`` before use.
        features: features of the original nodes (first ``features.shape[0]`` rows)
        features_attack: 2-D features of the injected nodes. It is modified IN
            PLACE and gets ``requires_grad=True``. Presumably binary {0, 1};
            TODO confirm with callers.
        origin_labels: labels used as the loss target, indexed by node id
        target_idx: indices of the target nodes
        sparsity_budget: fraction of feature columns allowed to be 1 per row
        batch_size: maximum number of flips per iteration
        verbose: if True, print loss and total number of ones per iteration

    Returns:
        ``features_attack`` (the same tensor object, updated in place)

    Raises:
        RuntimeError: if the loss does not produce a gradient for ``features_attack``
    """
    model.eval()

    features_attack.requires_grad_(True)
    features_attack.retain_grad()
    n_cols = features_attack.shape[1]
    features_per_row = int(sparsity_budget * n_cols)
    n_total = features.shape[0]

    # Current number of ones per injected row
    flips_count = (features_attack == 1).sum(dim=1).int()
    his_flips_count = []
    adj_attack = adj_attack.to_dense()

    step = 0
    while bool((flips_count < features_per_row).any()):
        model.zero_grad(set_to_none=True)
        features_attack.grad = None
        features_concat = torch.cat([features, features_attack], dim=0)
        pred = model(features_concat, adj_attack)
        loss_vec = attacker.loss(
            pred[:n_total][target_idx],
            origin_labels[target_idx],
            reduction="none",
        )
        loss = loss_vec.mean()
        loss.backward()
        if features_attack.grad is None:
            raise RuntimeError("attack loss produced no gradient for features_attack")
        # NOTE(review): a positive rescale does not change the top-k ranking;
        # it only matters if tiny gradients underflow to 0.
        grad = features_attack.grad.detach() * 1e-5  # [M, D]

        with torch.no_grad():
            mask = (flips_count < features_per_row).float().unsqueeze(1).to(grad.device)
            valid_grad = grad * mask
            # +1: 0 -> 1 flip with grad > 0; -1: 1 -> 0 flip with grad <= 0; 0 otherwise
            flip_directions = (grad > 0).float() - (features_attack == 1).float()
            max_loss_grad = valid_grad * flip_directions

            # torch.topk requires k <= number of elements
            k = min(batch_size, max_loss_grad.numel())
            _, max_indices = torch.topk(max_loss_grad.reshape(-1), k)
            for idx in max_indices.tolist():
                r, c = divmod(idx, n_cols)
                if flips_count[r] < features_per_row:
                    features_attack[r, c] = 1 - features_attack[r, c]
                    flips_count[r] += 1 if features_attack[r, c] == 1 else -1

            total_ones = int(flips_count.sum().item())
            his_flips_count.append(total_ones)
            if verbose:
                print(f"[fgsm] step {step}: loss={loss.item():.6f}, ones={total_ones}")
            if len(his_flips_count) > 10 and his_flips_count[-1] <= his_flips_count[-10]:
                # Entries can keep flipping back and forth (e.g. co-occurring
                # features); stop once no net progress is made.
                break
        step += 1

    return features_attack


# ============================================================================
# Edge Injection Strategies
# ============================================================================

def _prepare_adjacency(adj: Union[SparseTensor, torch.Tensor]) -> Tuple[sp.csr_matrix, int]:
    """
    Bring an adjacency input into scipy sparse form for the injection routines.

    Torch inputs (a ``SparseTensor`` or any ``torch.Tensor``) are converted
    with ``tensor_to_adj``. Anything else is assumed to already be a scipy
    matrix and is passed through unchanged.

    Returns:
        Tuple of (scipy sparse matrix, its number of rows = number of nodes)
    """
    needs_conversion = isinstance(adj, SparseTensor) or torch.is_tensor(adj)
    if needs_conversion:
        adj = tensor_to_adj(adj)
    return adj, adj.shape[0]


def _finalize_adjacency(
    adj_sparse: sp.csr_matrix,
    n_inject: int,
    n_current: int,
    new_edges_x: List[int],
    new_edges_y: List[int],
    new_data: List[int],
    device: str
) -> SparseTensor:
    """
    Grow the adjacency to include injected nodes and add the injected edges.

    The ``n_current x n_current`` input is embedded in the top-left corner of
    an ``(n_current + n_inject)``-square matrix. Then each entry
    ``(new_edges_x[k], new_edges_y[k])`` receives ``new_data[k]``.
    Coordinates that appear more than once are summed (scipy COO -> CSR).

    Args:
        adj_sparse: adjacency of the current graph, ``n_current x n_current``
        n_inject: number of injected nodes (new rows/columns)
        n_current: number of nodes in ``adj_sparse``
        new_edges_x / new_edges_y / new_data: row ids, column ids and values
            of the new entries (equal lengths; may be empty)
        device: device of the returned SparseTensor

    Returns:
        SparseTensor of size ``(n_current + n_inject)``-square on ``device``

    Raises:
        ValueError: on shape mismatch, unequal edge-list lengths, or
            out-of-range indices
    """
    if adj_sparse.shape != (n_current, n_current):
        raise ValueError(
            f"adjacency has shape {adj_sparse.shape}, expected ({n_current}, {n_current})"
        )
    if not (len(new_edges_x) == len(new_edges_y) == len(new_data)):
        raise ValueError("new_edges_x, new_edges_y and new_data must have equal lengths")

    n_total = n_current + n_inject
    base = sp.coo_matrix(adj_sparse)

    # Build index arrays with an explicit integer dtype: hstack-ing an empty
    # Python list onto them would otherwise yield float indices.
    rows = np.concatenate([base.row.astype(np.int64), np.asarray(new_edges_x, dtype=np.int64)])
    cols = np.concatenate([base.col.astype(np.int64), np.asarray(new_edges_y, dtype=np.int64)])
    data = np.concatenate([base.data.astype(np.float64), np.asarray(new_data, dtype=np.float64)])

    # The constructor validates indices against the shape, unlike assigning .row/.col
    adj_attack = sp.coo_matrix((data, (rows, cols)), shape=(n_total, n_total)).tocsr()

    return adj_to_tensor(adj_attack).to(device)


def random_injection(
    adj: Union[SparseTensor, torch.Tensor],
    n_inject: int,
    n_edge_max: int,
    target_idx: torch.Tensor,
    device: str
) -> SparseTensor:
    """
    Random edge injection strategy.

    Each injected node is linked (in both directions) to
    ``min(n_edge_max, len(target_idx))`` distinct targets, drawn uniformly at
    random with Python's ``random`` module.

    Args:
        adj: adjacency of the current graph (torch or scipy form)
        n_inject: number of nodes to inject
        n_edge_max: edge budget per injected node
        target_idx: node ids eligible as targets
        device: device of the returned SparseTensor

    Returns:
        Adjacency over ``n_node + n_inject`` nodes as a SparseTensor
    """
    adj_sparse, n_node = _prepare_adjacency(adj)
    target_idx = target_idx.cpu()
    n_test = target_idx.shape[0]

    # A node can have at most n_test distinct targets. Without this cap the
    # rejection-sampling loop below would never terminate (and an empty target
    # set would crash random.randint).
    n_links = min(n_edge_max, n_test)

    new_edges_x, new_edges_y, new_data = [], [], []

    for i in range(n_inject):
        x = i + n_node
        linked = set()
        for _ in range(n_links):
            # Rejection-sample a target not yet linked to this injected node
            while True:
                y_idx = random.randint(0, n_test - 1)
                if y_idx not in linked:
                    linked.add(y_idx)
                    break

            y = target_idx[y_idx].item()
            new_edges_x.extend([x, y])
            new_edges_y.extend([y, x])
            new_data.extend([1, 1])

    return _finalize_adjacency(adj_sparse, n_inject, n_node, new_edges_x, new_edges_y, new_data, device)


def tdgia_injection(
    adj: Union[SparseTensor, torch.Tensor],
    n_inject: int,
    n_edge_max: int,
    origin_labels: torch.Tensor,
    current_pred: torch.Tensor,
    target_idx: torch.Tensor,
    device: str,
    self_connect_ratio: float = 0.0,
    weight1: float = 0.9,
    weight2: float = 0.1
) -> SparseTensor:
    """
    TDGIA edge injection strategy.

    Each target node ``t`` is scored as
    ``weight1 * s / deg(t) + weight2 * s / sqrt(deg(t)) / sqrt(n_connect + n_self_connect)``,
    where ``s = current_pred[t, origin_labels[t]] + 2`` and ``deg`` is the
    column sum of ``adj`` plus one.

    The ``n_inject * n_connect`` highest-scoring targets are grouped by label.
    They are then handed out to the injected nodes so that each label group is
    consumed at roughly the same relative rate. Each selected target receives
    exactly one injected edge.

    Optionally, part of each node's budget goes to random links between
    injected nodes.

    Args:
        adj: adjacency of the current graph (torch or scipy form)
        n_inject: number of nodes to inject
        n_edge_max: edge budget per injected node
        origin_labels: label per node, indexed by node id
        current_pred: per-node class scores, indexed ``[node, class]``
        target_idx: node ids eligible as targets
        device: device of the returned SparseTensor
        self_connect_ratio: fraction of ``n_edge_max`` spent on injected-injected
            link attempts. Each attempt picks a random partner; self picks and
            repeated pairs are skipped, so fewer links may result.
        weight1, weight2: mixing weights of the two degree-normalized terms

    Returns:
        Adjacency over ``n_current + n_inject`` nodes as a SparseTensor
        (new edges are added in both directions)

    Raises:
        ValueError: if ``n_inject * n_connect`` exceeds the number of targets
    """
    adj_sparse, n_current = _prepare_adjacency(adj)
    target_idx = target_idx.cpu()
    n_test = target_idx.size(0)

    # Score on CPU numpy copies. This avoids per-node tensor ops and avoids
    # assigning (possibly CUDA) tensors into numpy arrays.
    labels_np = origin_labels.detach().cpu().numpy()
    pred_np = current_pred.detach().cpu().numpy()
    n_classes = int(labels_np.max()) + 1

    # Split the edge budget between target edges and injected-injected edges
    n_connect = int(n_edge_max * (1 - self_connect_ratio))
    n_self_connect = int(n_edge_max * self_connect_ratio)

    # Column sums (+1 so isolated nodes do not divide by zero)
    deg = np.asarray(adj_sparse.sum(0)).flatten() + 1.0

    tgt_nodes = target_idx.numpy()
    tgt_labels = labels_np[tgt_nodes]
    tgt_deg = deg[tgt_nodes]
    confidence = pred_np[tgt_nodes, tgt_labels] + 2
    # max(..., 1): avoid sqrt(0) when the whole budget is zero (nothing is injected then)
    budget_norm = np.sqrt(max(n_connect + n_self_connect, 1))
    scores = weight1 * confidence / tgt_deg + weight2 * (confidence / np.sqrt(tgt_deg)) / budget_norm

    n_select = n_inject * n_connect
    if n_select > n_test:
        raise ValueError(
            f"TDGIA needs n_inject * n_connect = {n_select} distinct targets, "
            f"but only {n_test} are available"
        )
    # argsort is ascending, so the tail holds the n_select highest scores
    # (explicit start index: a `[-0:]` slice would select everything)
    sorted_idx = scores.argsort()[n_test - n_select:]

    # Group selected targets (local indices into target_idx) by label
    class_groups = [[] for _ in range(n_classes)]
    for idx in sorted_idx:
        class_groups[tgt_labels[idx]].append(idx)

    new_edges_x, new_edges_y, new_data = [], [], []
    class_positions = np.zeros(n_classes, dtype=int)

    for i in range(n_inject):
        x = n_current + i
        for _ in range(n_connect):
            # Pick the non-empty label group consumed least relative to its size
            min_coverage = float('inf')
            selected_class = 0
            for c in range(n_classes):
                if class_groups[c] and (class_positions[c] / len(class_groups[c])) < min_coverage:
                    min_coverage = class_positions[c] / len(class_groups[c])
                    selected_class = c

            target_local_idx = class_groups[selected_class][class_positions[selected_class]]
            class_positions[selected_class] += 1

            y = int(tgt_nodes[target_local_idx])
            new_edges_x.extend([x, y])
            new_edges_y.extend([y, x])
            new_data.extend([1, 1])

    # Random links between injected nodes
    if n_self_connect > 0:
        connections = np.zeros((n_inject, n_inject))
        for i in range(n_inject):
            for _ in range(n_self_connect):
                j = random.randint(0, n_inject - 1)
                if i != j and connections[i][j] == 0:
                    connections[i][j] = connections[j][i] = 1

                    x = n_current + i
                    y = n_current + j

                    new_edges_x.extend([x, y])
                    new_edges_y.extend([y, x])
                    new_data.extend([1, 1])

    return _finalize_adjacency(adj_sparse, n_inject, n_current, new_edges_x, new_edges_y, new_data, device)


def atdgia_injection(
    adj: Union[SparseTensor, torch.Tensor],
    n_inject: int,
    n_edge_max: int,
    origin_labels: torch.Tensor,
    current_pred: torch.Tensor,
    target_idx: torch.Tensor,
    device: str,
    self_connect_ratio: float = 0.0,
    weight1: float = 0.9,
    weight2: float = 0.1
) -> SparseTensor:
    """
    ATDGIA (Adaptive TDGIA) edge injection strategy.

    This is the same selection/assignment scheme as TDGIA, with a different
    target score:
    ``s = 1 - current_pred[t, origin_labels[t]]`` if ``argmax(current_pred[t])``
    equals the label, else 0. The score is combined as
    ``weight1 * s / deg(t) + weight2 * s / sqrt(deg(t)) / sqrt(n_connect + n_self_connect)``.
    Targets the model currently gets wrong therefore score 0.

    The ``n_inject * n_connect`` highest-scoring targets are grouped by label
    and handed out so that each label group is consumed at roughly the same
    relative rate. Each selected target receives exactly one injected edge.
    When ``self_connect_ratio > 0``, the reserved part of the budget is spent
    on random injected-injected links. The original version reserved that
    budget without ever using it.

    Args:
        adj: adjacency of the current graph (torch or scipy form)
        n_inject: number of nodes to inject
        n_edge_max: edge budget per injected node
        origin_labels: label per node, indexed by node id
        current_pred: per-node class scores, indexed ``[node, class]``
        target_idx: node ids eligible as targets
        device: device of the returned SparseTensor
        self_connect_ratio: fraction of ``n_edge_max`` spent on injected-injected
            link attempts (self picks and repeated pairs are skipped)
        weight1, weight2: mixing weights of the two degree-normalized terms

    Returns:
        Adjacency over ``n_current + n_inject`` nodes as a SparseTensor
        (new edges are added in both directions)

    Raises:
        ValueError: if ``n_inject * n_connect`` exceeds the number of targets
    """
    adj_sparse, n_current = _prepare_adjacency(adj)
    target_idx = target_idx.cpu()
    n_test = target_idx.size(0)

    # Score on CPU numpy copies (no per-node tensor ops, no CUDA->numpy assignment)
    labels_np = origin_labels.detach().cpu().numpy()
    pred_np = current_pred.detach().cpu().numpy()
    n_classes = int(labels_np.max()) + 1

    n_connect = int(n_edge_max * (1 - self_connect_ratio))
    n_self_connect = int(n_edge_max * self_connect_ratio)

    # Column sums (+1 so isolated nodes do not divide by zero)
    deg = np.asarray(adj_sparse.sum(0)).flatten() + 1.0

    tgt_nodes = target_idx.numpy()
    tgt_labels = labels_np[tgt_nodes]
    tgt_deg = deg[tgt_nodes]
    tgt_pred = pred_np[tgt_nodes]
    label_prob = tgt_pred[np.arange(n_test), tgt_labels]
    # Only currently-correct targets score > 0; the lower the score of the
    # true label, the higher the target's score.
    is_correct = tgt_pred.argmax(axis=1) == tgt_labels
    confidence = np.where(is_correct, 1.0 - label_prob, 0.0)
    # max(..., 1): avoid sqrt(0) when the whole budget is zero (nothing is injected then)
    budget_norm = np.sqrt(max(n_connect + n_self_connect, 1))
    scores = weight1 * confidence / tgt_deg + weight2 * (confidence / np.sqrt(tgt_deg)) / budget_norm

    n_select = n_inject * n_connect
    if n_select > n_test:
        raise ValueError(
            f"ATDGIA needs n_inject * n_connect = {n_select} distinct targets, "
            f"but only {n_test} are available"
        )
    # Ascending argsort; the tail holds the n_select highest scores
    sorted_idx = scores.argsort()[n_test - n_select:]

    class_groups = [[] for _ in range(n_classes)]
    for idx in sorted_idx:
        class_groups[tgt_labels[idx]].append(idx)

    new_edges_x, new_edges_y, new_data = [], [], []
    class_positions = np.zeros(n_classes, dtype=int)

    for i in range(n_inject):
        x = n_current + i
        for _ in range(n_connect):
            # Pick the non-empty label group consumed least relative to its size
            min_coverage = float('inf')
            selected_class = 0
            for c in range(n_classes):
                if class_groups[c] and (class_positions[c] / len(class_groups[c])) < min_coverage:
                    min_coverage = class_positions[c] / len(class_groups[c])
                    selected_class = c

            target_local_idx = class_groups[selected_class][class_positions[selected_class]]
            class_positions[selected_class] += 1

            y = int(tgt_nodes[target_local_idx])
            new_edges_x.extend([x, y])
            new_edges_y.extend([y, x])
            new_data.extend([1, 1])

    # Spend the reserved budget on random links between injected nodes,
    # consistent with tdgia_injection
    if n_self_connect > 0:
        connections = np.zeros((n_inject, n_inject))
        for i in range(n_inject):
            for _ in range(n_self_connect):
                j = random.randint(0, n_inject - 1)
                if i != j and connections[i][j] == 0:
                    connections[i][j] = connections[j][i] = 1

                    x = n_current + i
                    y = n_current + j

                    new_edges_x.extend([x, y])
                    new_edges_y.extend([y, x])
                    new_data.extend([1, 1])

    return _finalize_adjacency(adj_sparse, n_inject, n_current, new_edges_x, new_edges_y, new_data, device)


def atdgia_ranking_select(
    adj: Union[SparseTensor, torch.Tensor],
    n_inject: int,
    n_edge_max: int,
    origin_labels: torch.Tensor,
    current_pred: torch.Tensor,
    target_idx: torch.Tensor,
    ratio: float = 0.5,
    weight1: float = 0.9,
    weight2: float = 0.1
) -> torch.Tensor:
    """
    Select the most vulnerable target nodes according to the ATDGIA score.

    Score per target ``t``: ``s = 1 - current_pred[t, label]`` if the model's
    argmax equals the label, else 0. It is combined as
    ``weight1 * s / deg(t) + weight2 * s / sqrt(deg(t)) / sqrt(n_edge_max)``,
    where ``deg`` is the column sum of ``adj`` plus one.

    Args:
        adj: adjacency of the current graph (torch or scipy form)
        n_inject: unused; kept for interface compatibility
        n_edge_max: edge budget, used only in the score normalization
        origin_labels: label per node, indexed by node id
        current_pred: per-node class scores, indexed ``[node, class]``
        target_idx: candidate node ids
        ratio: fraction of candidates to keep (clamped to [0, 1] in effect)
        weight1, weight2: mixing weights of the two degree-normalized terms

    Returns:
        CPU tensor of the ``int(len(target_idx) * ratio)`` highest-scoring
        node ids, in ascending score order (empty when that count is 0)
    """
    adj_sparse, _ = _prepare_adjacency(adj)
    target_idx = target_idx.cpu()
    n_test = target_idx.size(0)

    labels_np = origin_labels.detach().cpu().numpy()
    pred_np = current_pred.detach().cpu().numpy()

    # Column sums (+1 so isolated nodes do not divide by zero)
    deg = np.asarray(adj_sparse.sum(0)).flatten() + 1.0

    tgt_nodes = target_idx.numpy()
    tgt_labels = labels_np[tgt_nodes]
    tgt_deg = deg[tgt_nodes]
    tgt_pred = pred_np[tgt_nodes]
    label_prob = tgt_pred[np.arange(n_test), tgt_labels]
    is_correct = tgt_pred.argmax(axis=1) == tgt_labels
    confidence = np.where(is_correct, 1.0 - label_prob, 0.0)
    # max(..., 1): avoid sqrt(0) -> inf/nan scores when n_edge_max == 0
    budget_norm = np.sqrt(max(n_edge_max, 1))
    scores = weight1 * confidence / tgt_deg + weight2 * (confidence / np.sqrt(tgt_deg)) / budget_norm

    # Explicit start index: `[-0:]` would return ALL candidates when select_num == 0
    select_num = min(max(int(n_test * ratio), 0), n_test)
    sorted_rank = scores.argsort()
    chosen = sorted_rank[n_test - select_num:]

    return target_idx[torch.as_tensor(chosen, dtype=torch.long)]

def random_class_injection(
    adj: Union[SparseTensor, torch.Tensor],
    n_inject: int,
    n_edge_max: int,
    origin_labels: torch.Tensor,
    target_idx: torch.Tensor,
    device: str,
    not_full: bool = False
) -> SparseTensor:
    """
    Random class-based injection.

    Each injected node picks one label uniformly at random among the labels
    present in ``target_idx``. It is then linked (both directions) to
    ``min(n_edge_max, #targets with that label)`` distinct targets of that
    label, sampled without replacement.

    Args:
        adj: adjacency of the current graph (torch or scipy form)
        n_inject: number of nodes to inject
        n_edge_max: edge budget per injected node
        origin_labels: label per node, indexed by node id
        target_idx: node ids eligible as targets
        device: device of the returned SparseTensor
        not_full: currently unused; kept for interface compatibility

    Returns:
        Adjacency over ``n_node + n_inject`` nodes as a SparseTensor
    """
    adj_sparse, n_node = _prepare_adjacency(adj)
    target_idx = target_idx.cpu()
    labels_np = origin_labels.detach().cpu().numpy()
    n_classes = int(labels_np.max()) + 1

    # Local indices (positions in target_idx) grouped by label
    class_nodes = [[] for _ in range(n_classes)]
    for local_idx, node_id in enumerate(target_idx.tolist()):
        class_nodes[labels_np[node_id]].append(local_idx)

    # Only labels that actually have targets. Picking an empty label would
    # leave the injected node isolated. When every label is present,
    # random.choice consumes the RNG exactly like randint(0, n_classes - 1).
    candidate_classes = [c for c in range(n_classes) if class_nodes[c]]

    new_edges_x, new_edges_y, new_data = [], [], []

    for i in range(n_inject):
        if not candidate_classes:
            break
        class_id = random.choice(candidate_classes)
        n_connections = min(len(class_nodes[class_id]), n_edge_max)

        if n_connections > 0:
            selected = random.sample(class_nodes[class_id], n_connections)

            x = i + n_node
            for local_idx in selected:
                y = target_idx[local_idx].item()

                new_edges_x.extend([x, y])
                new_edges_y.extend([y, x])
                new_data.extend([1, 1])

    return _finalize_adjacency(adj_sparse, n_inject, n_node, new_edges_x, new_edges_y, new_data, device)


# ============================================================================
# Utility Functions
# ============================================================================

def avg_sparsity(X: torch.Tensor) -> float:
    """
    Return the fraction of entries of ``X`` that are non-zero.

    Despite the name, this measures density (1.0 means no zeros at all).
    An empty tensor yields 0.
    """
    nonzero_count = (X != 0).sum().item()
    n_elements = X.numel()
    if not n_elements:
        return 0
    return nonzero_count / n_elements


