import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, Any, Tuple, Optional, List
from torch_geometric.data import Data
from torch_geometric.utils import to_dense_adj, dense_to_sparse, degree
import copy


class NodeCentricPerturbations:
    """Random node-centric perturbations of a PyG graph.

    Four kinds of noise are supported: topology (edges), attributes (``x``),
    labels (``y``) and a sensitive attribute (``s``). Every perturbation method
    returns a perturbed clone and leaves the input ``Data`` object untouched.

    Default noise magnitudes and switches are read from
    ``config['perturbations']``. That dict may contain the sub-sections
    ``'topology'``, ``'attribute'``, ``'label'`` and ``'sensitive'``.
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.perturbation_config = config.get('perturbations', {})

    def _section(self, name: str) -> Dict[str, Any]:
        """Return the perturbation config sub-section ``name``, or ``{}`` when it is absent."""
        return self.perturbation_config.get(name, {})

    def topology_noise(self, data: Data, epsilon_e: Optional[float] = None) -> Data:
        """Randomly drop a fraction of edges, then optionally re-add random edges to restore degrees.

        The graph is treated as undirected. An edge ``(i, j)`` and its reverse
        ``(j, i)`` count as a single edge and are removed together.

        Args:
            data: Input graph. It is not modified.
            epsilon_e: Fraction of undirected edges to remove. Defaults to
                ``topology.epsilon_e_max`` from the config, or 0.15 if unset.

        Returns:
            A clone of ``data`` whose ``edge_index`` is rebuilt from the
            perturbed dense adjacency matrix.
        """
        topo_cfg = self._section('topology')
        if epsilon_e is None:
            epsilon_e = topo_cfg.get('epsilon_e_max', 0.15)
        perturbed_data = data.clone()
        num_nodes = data.num_nodes
        device = data.edge_index.device
        adj_matrix = to_dense_adj(data.edge_index, max_num_nodes=num_nodes)[0]
        original_degrees = degree(data.edge_index[0], num_nodes=num_nodes, dtype=torch.long)
        perturbed_adj = adj_matrix.clone()

        # List each undirected pair exactly once (i <= j in the symmetrised
        # adjacency). Sampling raw nonzero entries would see every undirected
        # edge twice. Because both directions are zeroed below, that would
        # remove roughly twice the requested fraction.
        present = (adj_matrix > 0) | (adj_matrix.t() > 0)
        rows, cols = present.nonzero(as_tuple=True)
        upper = rows <= cols
        rows, cols = rows[upper], cols[upper]

        num_existing_edges = rows.numel()
        num_edges_to_remove = int(epsilon_e * num_existing_edges)
        if num_edges_to_remove > 0:
            chosen = torch.randperm(num_existing_edges)[:num_edges_to_remove].to(device)
            drop_r, drop_c = rows[chosen], cols[chosen]
            perturbed_adj[drop_r, drop_c] = 0
            perturbed_adj[drop_c, drop_r] = 0

        if topo_cfg.get('degree_preserving', True):
            perturbed_adj = self._rebalance_degrees(adj_matrix, perturbed_adj, original_degrees)
        # NOTE(review): per-edge tensors on ``data`` (e.g. ``edge_attr``) are not
        # rebuilt here, so they will no longer line up with the new edge_index.
        perturbed_data.edge_index = dense_to_sparse(perturbed_adj)[0]
        return perturbed_data

    def _rebalance_degrees(self, original_adj: torch.Tensor, perturbed_adj: torch.Tensor,
                           original_degrees: torch.Tensor) -> torch.Tensor:
        """Greedily add random edges so nodes regain their original degree.

        Nodes are visited in index order. A node whose current row-sum degree is
        below ``original_degrees`` is linked to randomly chosen nodes it is not
        yet adjacent to (never itself). Edges are added in both directions, so a
        node can end up above its original degree when earlier nodes link to it.

        Args:
            original_adj: Dense adjacency before perturbation. Only its shape
                and device are used.
            perturbed_adj: Dense adjacency to repair. It is modified in place
                and also returned.
            original_degrees: Target degree for each node.

        Returns:
            ``perturbed_adj`` with the added edges.
        """
        num_nodes = original_adj.shape[0]
        device = original_adj.device
        node_ids = torch.arange(num_nodes, device=device)
        for node_i in range(num_nodes):
            # Read the degree from the live matrix. Edges added while handling
            # earlier nodes may already have filled this node's deficit; a
            # snapshot taken before the loop would overshoot.
            degree_deficit = original_degrees[node_i].item() - perturbed_adj[node_i].sum().item()
            if degree_deficit <= 0:
                continue
            unconnected_nodes = ((perturbed_adj[node_i] == 0) & (node_ids != node_i)).nonzero(as_tuple=True)[0]
            num_candidates = unconnected_nodes.numel()
            if num_candidates == 0:
                continue
            num_to_add = min(int(degree_deficit), num_candidates)
            picks = torch.randperm(num_candidates)[:num_to_add].to(device)
            selected_nodes = unconnected_nodes[picks]
            perturbed_adj[node_i, selected_nodes] = 1
            perturbed_adj[selected_nodes, node_i] = 1
        return perturbed_adj

    def attribute_noise(self, data: Data, epsilon_x: Optional[float] = None) -> Data:
        """Add random noise to ``data.x``, scaled per feature by that feature's standard deviation.

        Args:
            data: Input graph. It is not modified.
            epsilon_x: Global noise multiplier. Defaults to
                ``attribute.epsilon_x_max`` from the config, or 0.08 if unset.

        Returns:
            A clone of ``data`` with noisy ``x``. When ``attribute.adaptive_sigma``
            is enabled (the default), each feature's noise is further scaled by
            its mean absolute value relative to the largest such mean.

        Raises:
            ValueError: If ``attribute.noise_type`` is neither ``'gaussian'``
                nor ``'uniform'``.
        """
        attr_cfg = self._section('attribute')
        if epsilon_x is None:
            epsilon_x = attr_cfg.get('epsilon_x_max', 0.08)
        perturbed_data = data.clone()
        x = data.x
        # The unbiased std of a single row is NaN; treat that as zero spread.
        feature_stds = torch.nan_to_num(torch.std(x, dim=0), nan=0.0).unsqueeze(0)

        noise_type = attr_cfg.get('noise_type', 'gaussian')
        if noise_type == 'gaussian':
            noise = torch.randn_like(x) * feature_stds
        elif noise_type == 'uniform':
            noise = (torch.rand_like(x) - 0.5) * 2 * feature_stds
        else:
            raise ValueError(f"Unsupported noise type: {noise_type}")

        scaled_noise = noise * epsilon_x
        if attr_cfg.get('adaptive_sigma', True):
            feature_importance = torch.abs(x).mean(dim=0)
            # Clamp the denominator: all-zero features would otherwise give 0/0 = NaN.
            importance_weights = feature_importance / feature_importance.max().clamp_min(1e-12)
            scaled_noise = scaled_noise * importance_weights.unsqueeze(0)
        perturbed_data.x = x + scaled_noise
        return perturbed_data

    def label_noise(self, data: Data, epsilon_l: Optional[float] = None) -> Data:
        """Flip a fraction of node labels to a uniformly chosen different label.

        Replacement labels are drawn from the distinct values present in ``data.y``.
        Assumes ``data.y`` is a 1-D label tensor (TODO confirm for multi-label data).

        Args:
            data: Input graph. It is not modified.
            epsilon_l: Fraction of target nodes whose label is flipped. Defaults
                to ``label.epsilon_l_max`` from the config, or 0.25 if unset.

        Returns:
            A clone of ``data`` with flipped ``y``. When ``label.train_only`` is
            enabled (the default), only nodes in ``data.train_mask`` are targeted;
            otherwise all nodes are.
        """
        label_cfg = self._section('label')
        if epsilon_l is None:
            epsilon_l = label_cfg.get('epsilon_l_max', 0.25)
        perturbed_data = data.clone()
        if label_cfg.get('train_only', True):
            target_mask = data.train_mask
        else:
            target_mask = torch.ones(data.num_nodes, dtype=torch.bool, device=data.y.device)
        labeled_nodes = torch.where(target_mask)[0]
        unique_labels = torch.unique(data.y)  # sorted ascending, required by searchsorted
        num_classes = unique_labels.numel()
        num_flips = int(epsilon_l * labeled_nodes.numel())
        if num_flips == 0 or num_classes < 2:
            # Nothing to flip, or no alternative label exists.
            return perturbed_data

        order = torch.randperm(labeled_nodes.numel())[:num_flips].to(labeled_nodes.device)
        flip_indices = labeled_nodes[order]
        current_pos = torch.searchsorted(unique_labels, data.y[flip_indices])
        # Adding an offset in [1, C-1] modulo C picks uniformly among the other C-1 labels.
        offsets = torch.randint(1, num_classes, (num_flips,), device=unique_labels.device)
        perturbed_data.y[flip_indices] = unique_labels[(current_pos + offsets) % num_classes]
        return perturbed_data

    def sensitive_attribute_noise(self, data: Data, gamma: Optional[float] = None) -> Data:
        """Flip the binary sensitive attribute ``data.s`` for a fraction of nodes.

        With ``sensitive.balanced_flip`` enabled (the default), ``gamma * N // 2``
        nodes are flipped 0 -> 1 and the same number 1 -> 0. Each count is capped
        by the size of its group. If either group is empty, uniform flipping is
        used instead; otherwise the balanced scheme would flip nothing. Assumes
        ``s`` holds only the values 0 and 1.

        Args:
            data: Input graph. It is not modified.
            gamma: Fraction of nodes to flip. Defaults to ``sensitive.gamma_max``
                from the config, or 0.4 if unset.

        Returns:
            A clone of ``data`` with a perturbed ``s``. If ``data`` has no ``s``
            attribute, the clone is returned unchanged.
        """
        sens_cfg = self._section('sensitive')
        if gamma is None:
            gamma = sens_cfg.get('gamma_max', 0.4)
        perturbed_data = data.clone()
        if not hasattr(data, 's'):
            return perturbed_data
        sens_attr = data.s.clone()
        num_nodes = sens_attr.shape[0]
        num_flips = int(gamma * num_nodes)

        s0_indices = (sens_attr == 0).nonzero(as_tuple=True)[0]
        s1_indices = (sens_attr == 1).nonzero(as_tuple=True)[0]
        balanced = (sens_cfg.get('balanced_flip', True)
                    and s0_indices.numel() > 0 and s1_indices.numel() > 0)
        if balanced:
            half = num_flips // 2
            flip_s0 = s0_indices[torch.randperm(s0_indices.numel())[:half]]
            flip_s1 = s1_indices[torch.randperm(s1_indices.numel())[:half]]
            # The two index sets are disjoint, so the assignment order does not matter.
            sens_attr[flip_s0] = 1
            sens_attr[flip_s1] = 0
        else:
            flip_indices = torch.randperm(num_nodes)[:num_flips]
            sens_attr[flip_indices] = 1 - sens_attr[flip_indices]
        perturbed_data.s = sens_attr
        return perturbed_data

    def combined_noise(self, data: Data, epsilon_e: Optional[float] = None,
                       epsilon_x: Optional[float] = None, epsilon_l: Optional[float] = None,
                       gamma: Optional[float] = None) -> Data:
        """Apply every enabled perturbation in turn: topology, attribute, label, sensitive.

        Each stage is skipped when its config section sets ``enabled`` to False.
        A strength of ``None`` falls back to that stage's configured maximum.

        Returns:
            A perturbed clone of ``data``. This is a plain clone when every stage
            is disabled.
        """
        perturbed_data = data.clone()
        if self._section('topology').get('enabled', True):
            perturbed_data = self.topology_noise(perturbed_data, epsilon_e)
        if self._section('attribute').get('enabled', True):
            perturbed_data = self.attribute_noise(perturbed_data, epsilon_x)
        if self._section('label').get('enabled', True):
            perturbed_data = self.label_noise(perturbed_data, epsilon_l)
        if self._section('sensitive').get('enabled', True):
            perturbed_data = self.sensitive_attribute_noise(perturbed_data, gamma)
        return perturbed_data

    def generate_k_fold_perturbations(self, data: Data, K: int = 5, seed: int = 42) -> List[Data]:
        """Generate ``K`` independently perturbed copies of ``data``.

        For fold ``k``, each strength is drawn uniformly from ``[0, max)``, where
        ``max`` is that stage's configured maximum. The draws come from a NumPy
        Generator seeded with ``seed + k``. The global torch RNG is re-seeded with
        ``seed + k`` as well (this changes global state), so every fold is
        reproducible.

        Args:
            data: Input graph. It is not modified.
            K: Number of perturbed graphs to produce.
            seed: Base seed. Fold ``k`` uses ``seed + k``.

        Returns:
            A list of ``K`` perturbed graphs.
        """
        epsilon_e_max = self._section('topology').get('epsilon_e_max', 0.15)
        epsilon_x_max = self._section('attribute').get('epsilon_x_max', 0.08)
        epsilon_l_max = self._section('label').get('epsilon_l_max', 0.25)
        gamma_max = self._section('sensitive').get('gamma_max', 0.4)
        perturbed_graphs = []
        for k in range(K):
            fold_seed = seed + k
            torch.manual_seed(fold_seed)
            # A local generator keeps the strength draws reproducible without
            # touching NumPy's global RNG.
            rng = np.random.default_rng(fold_seed)
            epsilon_e = float(rng.uniform(0, epsilon_e_max))
            epsilon_x = float(rng.uniform(0, epsilon_x_max))
            epsilon_l = float(rng.uniform(0, epsilon_l_max))
            gamma = float(rng.uniform(0, gamma_max))
            perturbed_graphs.append(self.combined_noise(data, epsilon_e, epsilon_x, epsilon_l, gamma))
        return perturbed_graphs

    def compute_wasserstein_distance(self, data1: Data, data2: Data, kappa: Dict[str, float]) -> float:
        """Return a kappa-weighted sum of squared differences between two node-aligned graphs.

        Despite the name, this is not an optimal-transport distance. It assumes
        both graphs share the same node set and ordering, and adds:
        ``kappa['edge'] * ||A1 - A2||_F^2 + kappa['feature'] * ||X1 - X2||_F^2
        + kappa['label'] * ||y1 - y2||^2`` (+ ``kappa['sensitive'] * ||s1 - s2||^2``
        when both graphs carry ``s``). Missing kappa weights default to 1.0.
        """
        adj1 = to_dense_adj(data1.edge_index, max_num_nodes=data1.num_nodes)[0]
        adj2 = to_dense_adj(data2.edge_index, max_num_nodes=data2.num_nodes)[0]
        topo_dist = kappa.get('edge', 1.0) * torch.norm(adj1 - adj2, p='fro') ** 2
        feat_dist = kappa.get('feature', 1.0) * torch.norm(data1.x - data2.x, p='fro') ** 2
        label_dist = kappa.get('label', 1.0) * torch.norm(data1.y.float() - data2.y.float(), p=2) ** 2
        sens_dist = 0.0
        if hasattr(data1, 's') and hasattr(data2, 's'):
            sens_dist = kappa.get('sensitive', 1.0) * torch.norm(data1.s.float() - data2.s.float(), p=2) ** 2
        total_distance = topo_dist + feat_dist + label_dist + sens_dist
        return total_distance.item()

    def adaptive_perturbation_strength(self, data: Data, target_distance: float,
                                       kappa: Dict[str, float], max_iters: int = 10,
                                       tol: float = 0.1) -> Dict[str, float]:
        """Search for perturbation strengths whose measured distance is close to ``target_distance``.

        Each iteration perturbs ``data`` with the current strengths and measures
        ``compute_wasserstein_distance``. Every strength is then multiplied by
        ``sqrt(target / measured)`` (a heuristic damped step) and clamped to
        ``[0.01, 0.5]``. The search stops early once the measured distance is
        within ``tol`` of the target.

        Args:
            data: Reference graph.
            target_distance: Desired distance. Must be non-negative.
            kappa: Component weights forwarded to ``compute_wasserstein_distance``.
            max_iters: Maximum number of perturb-and-measure rounds.
            tol: Absolute tolerance for accepting a result.

        Returns:
            The evaluated strengths (keys ``epsilon_e``, ``epsilon_x``,
            ``epsilon_l``, ``gamma``) whose distance was closest to the target.
            These are the initial strengths when ``max_iters`` is 0.

        Raises:
            ValueError: If ``target_distance`` is negative.
        """
        if target_distance < 0:
            raise ValueError(f"target_distance must be non-negative, got {target_distance}")
        params = {
            'epsilon_e': 0.1,
            'epsilon_x': 0.05,
            'epsilon_l': 0.2,
            'gamma': 0.3
        }
        best_params = dict(params)
        best_gap = float('inf')
        for _ in range(max_iters):
            perturbed_data = self.combined_noise(data, **params)
            current_distance = self.compute_wasserstein_distance(data, perturbed_data, kappa)
            gap = abs(current_distance - target_distance)
            # Only remember parameters that were actually evaluated; the
            # rescaled guess made after the final round is never measured.
            if gap < best_gap:
                best_gap, best_params = gap, dict(params)
            if gap < tol:
                break
            step = (target_distance / (current_distance + 1e-8)) ** 0.5
            params = {key: max(0.01, min(value * step, 0.5)) for key, value in params.items()}
        return best_params