import numpy as np
import torch
from scipy.spatial.distance import cdist
from .utils import thresholded_gaussian_kernel


def compute_node_domain(position, adj, node_idx):
    """
    Calculate the perturbation domain of a node.

    The domain is bounded by the midpoints between the node and each of its
    adjacent nodes. It is centred on the node itself, and its radius is derived
    from the nearest midpoint, so a perturbed node stays well short of halfway
    towards any neighbour.

    Args:
        position: Position coordinates of all nodes [n_nodes, 2]
        adj: Adjacency matrix [n_nodes, n_nodes]; any entry > 0 counts as an
            edge. The diagonal (self-loop) is ignored.
        node_idx: Index of target node

    Returns:
        domain_center: Center point of the node domain (always the node's own
            position).
        domain_radius: Radius of the node domain (perturbation range), capped
            at 0.5:
            - no neighbours: 0.5
            - one neighbour: half the distance to the midpoint
            - two or more neighbours: 0.8 x distance to the nearest midpoint
    """
    center_pos = position[node_idx]

    # Neighbours of this node. Exclude the node itself: a self-loop (e.g. a
    # Gaussian-kernel adjacency whose diagonal is exp(0) = 1) would add a
    # zero-distance "midpoint" and collapse the radius to 0, disabling the
    # perturbation for that node entirely.
    neighbors = np.where(adj[node_idx] > 0)[0]
    neighbors = neighbors[neighbors != node_idx]

    if len(neighbors) == 0:
        # Isolated node: fall back to the maximum perturbation range
        return center_pos, 0.5

    # Midpoints between the node and each neighbour, and their distances
    midpoints = (center_pos + position[neighbors]) / 2
    distances_to_midpoints = np.linalg.norm(midpoints - center_pos, axis=1)

    if len(neighbors) == 1:
        # Only one neighbour: half the distance to its midpoint
        domain_radius = distances_to_midpoints[0] * 0.5
    else:
        # Two or more neighbours: 80% of the distance to the nearest midpoint
        domain_radius = np.min(distances_to_midpoints) * 0.8

    # Limit maximum perturbation range to 0.5
    return center_pos, min(domain_radius, 0.5)


def perturb_node_positions(position, adj, perturbation_strength=0.5, rng=None):
    """
    Randomly perturb every node position within its node domain.

    Each node is moved by an isotropic Gaussian offset with per-axis standard
    deviation min(domain_radius, perturbation_strength) / 3. The offset is then
    clipped radially so that it never leaves the node domain computed by
    ``compute_node_domain``.

    Args:
        position: Original node positions [n_nodes, 2]
        adj: Adjacency matrix [n_nodes, n_nodes]
        perturbation_strength: Upper bound on the noise scale (the per-axis std
            is at most perturbation_strength / 3), default 0.5
        rng: Optional random source exposing ``normal(loc, scale, size)``, e.g.
            ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state, matching the previous behaviour.

    Returns:
        perturbed_position: Perturbed node positions [n_nodes, 2]. Integer
            inputs are promoted to float64 so the offsets are not truncated.
    """
    if rng is None:
        rng = np.random

    position = np.asarray(position)
    # Copy into a floating dtype: writing float offsets into an integer array
    # would silently truncate them (usually to zero).
    if np.issubdtype(position.dtype, np.floating):
        out_dtype = position.dtype
    else:
        out_dtype = np.float64
    perturbed_position = position.astype(out_dtype, copy=True)

    for node_idx in range(position.shape[0]):
        domain_center, domain_radius = compute_node_domain(position, adj, node_idx)

        # Per-axis std is a third of the smaller of the domain radius and the
        # requested strength. For a 2-D isotropic Gaussian, |offset| <= 3*std
        # holds with probability 1 - exp(-4.5) ~= 98.9% (the 1-D "99.7%" rule
        # does not apply). Rarer, larger draws are clipped to the boundary below.
        std = min(domain_radius / 3, perturbation_strength / 3)
        perturbation = rng.normal(0, std, size=2)

        # Ensure perturbation doesn't exceed domain range
        perturbation_norm = np.linalg.norm(perturbation)
        if perturbation_norm > domain_radius:
            perturbation = perturbation * domain_radius / perturbation_norm

        perturbed_position[node_idx] = domain_center + perturbation

    return perturbed_position


def _nonzero_bandwidth(width):
    """Return ``width``, or 1.0 when it is exactly zero (all nodes coincide)."""
    # A zero std means every pairwise distance is 0; dividing by it would give
    # NaN weights, whereas any positive bandwidth yields the correct exp(0) = 1.
    return 1.0 if width == 0 else width


def recompute_adjacency_matrix(perturbed_position, original_adj, adj_threshold=0.1, dataset_name=""):
    """
    Recompute adjacency matrix based on perturbed positions

    Args:
        perturbed_position: Perturbed node positions [n_nodes, 2]
        original_adj: Original adjacency matrix (kept for interface
            compatibility; not used by the computation)
        adj_threshold: Weights below this value are set to 0
        dataset_name: Dataset name. 'la_point', 'bay_point' and 'pems07_point'
            use the DCRNN-style kernel below; every other name (including
            'aqi*') uses ``thresholded_gaussian_kernel``.

    Returns:
        new_adj: New adjacency matrix [n_nodes, n_nodes] with a zero diagonal
    """
    # Pairwise Euclidean distances between perturbed nodes
    distances = cdist(perturbed_position, perturbed_position, metric='euclidean')

    if dataset_name in ('la_point', 'bay_point', 'pems07_point'):
        # METR-LA, PEMS-BAY and PEMS07 use the DCRNN method: Gaussian kernel
        # whose bandwidth is the std of the non-infinite distances
        finite_dist = distances.reshape(-1)
        finite_dist = finite_dist[~np.isinf(finite_dist)]
        sigma = _nonzero_bandwidth(finite_dist.std())
        new_adj = np.exp(-np.square(distances / sigma))
        new_adj[new_adj < adj_threshold] = 0.
    else:
        # Air Quality ('aqi*') datasets and the default: thresholded Gaussian kernel
        theta = _nonzero_bandwidth(np.std(distances))
        new_adj = thresholded_gaussian_kernel(distances, theta=theta, threshold=adj_threshold)

    # Remove self-loops
    np.fill_diagonal(new_adj, 0.)

    return new_adj


def get_perturbed_adj_and_position(original_position, original_adj, dataset_name, adj_threshold=0.1, perturbation_strength=0.5):
    """
    Jitter node positions and rebuild the adjacency matrix from them.

    Convenience wrapper that chains ``perturb_node_positions`` and
    ``recompute_adjacency_matrix``.

    Args:
        original_position: Original node positions
        original_adj: Original adjacency matrix (defines each node's neighbours
            for the perturbation domain)
        dataset_name: Dataset name, selects the similarity kernel
        adj_threshold: Adjacency matrix threshold
        perturbation_strength: Perturbation strength

    Returns:
        A ``(perturbed_position, new_adj)`` tuple.
    """
    jittered = perturb_node_positions(original_position, original_adj, perturbation_strength)
    rebuilt_adj = recompute_adjacency_matrix(
        jittered,
        original_adj,
        adj_threshold=adj_threshold,
        dataset_name=dataset_name,
    )
    return jittered, rebuilt_adj
