#!/usr/bin/env python3
"""
Kolmogorov-Smirnov (KS) Test Analysis for Differential Privacy Effectiveness

This script performs statistical analysis to evaluate the effectiveness of differential
privacy mechanisms on:
1. Data encodings transmitted from clients to server
2. Relational graphs generated by the GAT model

The KS test helps determine whether the distributions of original and noisy data are
statistically different; such a difference is used here as an empirical indicator of
privacy protection, not as a formal differential-privacy guarantee.
"""

import logging
import os
import pickle
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy import stats

# Configure the root logger (INFO level, timestamped records). Note that this
# runs at import time; logging.basicConfig is a no-op if the root logger
# already has handlers configured by the importing application.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger used by all analysis code in this module.
logger = logging.getLogger('DP_KS_Analysis')


class DifferentialPrivacyAnalyzer:
    """
    Analyzer for differential privacy effectiveness using Kolmogorov-Smirnov tests.

    The analyzer adds Laplace noise (scale = sensitivity / epsilon) to client
    encodings or relational graphs and compares the original and noisy value
    distributions with a two-sample KS test plus several distance measures
    (TVD, Wasserstein, Jensen-Shannon). Results accumulate in
    ``encoding_results`` / ``graph_results`` and can be plotted or saved.
    """

    def __init__(self, epsilon: float = 1.0, sensitivity: float = 1.0, output_dir: str = './dump/dp_analysis'):
        """
        Initialize the analyzer

        Args:
            epsilon: Privacy budget
            sensitivity: Sensitivity of the function
            output_dir: Directory to save analysis results (created if missing)
        """
        self.epsilon = epsilon
        self.sensitivity = sensitivity
        self.scale = sensitivity / epsilon  # Laplace scale parameter b = sensitivity / epsilon
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        # Storage for test results (one entry appended per analyze_* call)
        self.encoding_results = []
        self.graph_results = []

    @staticmethod
    def _to_numpy(data) -> np.ndarray:
        """Return ``data`` as a NumPy array, detaching torch tensors and moving them to CPU."""
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _significant_ratio(metrics: Dict[str, Any]) -> float:
        """Fraction of analyzed clients with p < 0.05; NaN when no clients were analyzed."""
        total = metrics['total_clients']
        if not total:
            return float('nan')
        return metrics['significant_differences'] / total

    def add_laplace_noise(self, data: np.ndarray, scale: Optional[float] = None) -> np.ndarray:
        """
        Add Laplace noise to data

        Args:
            data: Original data
            scale: Scale parameter for Laplace distribution (uses self.scale if None)

        Returns:
            Noisy data (same shape as ``data``)
        """
        if scale is None:
            scale = self.scale

        noise = np.random.laplace(0, scale, data.shape)
        return data + noise

    def perform_ks_test(self, original: np.ndarray, noisy: np.ndarray,
                        hist_bins: int = 50) -> Dict[str, float]:
        """
        Perform Kolmogorov-Smirnov test between original and noisy distributions

        Both inputs are flattened, so only the marginal distribution of their
        values is compared. The histogram-based metrics (TVD, JS divergence)
        use one set of bin edges spanning the combined range of both samples,
        so noisy values outside the original's range are still counted.

        Args:
            original: Original data
            noisy: Noisy data
            hist_bins: Number of histogram bins used for the TVD / JS estimates

        Returns:
            Dictionary with 'ks_statistic', 'p_value', 'tvd',
            'wasserstein_distance', 'js_divergence' (natural log, in [0, ln 2]),
            'mean_diff', 'std_diff' and 'max_diff'.

        Raises:
            ValueError: If either input is empty.
        """
        original_arr = np.asarray(original)
        noisy_arr = np.asarray(noisy)
        original_flat = original_arr.ravel()
        noisy_flat = noisy_arr.ravel()
        if original_flat.size == 0 or noisy_flat.size == 0:
            raise ValueError("perform_ks_test requires non-empty inputs")

        # Two-sample KS test on the flattened values
        ks_statistic, p_value = stats.ks_2samp(original_flat, noisy_flat)

        # Shared bin edges over the union of both samples. Binning on the
        # original's range alone would silently drop noisy values outside it
        # and under-report the distance between the distributions.
        edges = np.histogram_bin_edges(np.concatenate([original_flat, noisy_flat]), bins=hist_bins)
        counts_original, _ = np.histogram(original_flat, bins=edges)
        counts_noisy, _ = np.histogram(noisy_flat, bins=edges)
        p = counts_original / counts_original.sum()
        q = counts_noisy / counts_noisy.sum()

        # Total Variation Distance between the binned probability masses
        tvd = 0.5 * np.sum(np.abs(p - q))

        # Wasserstein distance (computed on raw samples, no binning)
        wasserstein_dist = stats.wasserstein_distance(original_flat, noisy_flat)

        # Jensen-Shannon divergence: mean KL divergence to the mixture M = (P + Q) / 2.
        # M > 0 wherever P or Q is non-zero, so no smoothing epsilon is needed.
        m = 0.5 * (p + q)
        js_divergence = 0.5 * stats.entropy(p, m) + 0.5 * stats.entropy(q, m)

        return {
            'ks_statistic': ks_statistic,
            'p_value': p_value,
            'tvd': tvd,
            'wasserstein_distance': wasserstein_dist,
            'js_divergence': js_divergence,
            'mean_diff': np.abs(np.mean(original_flat) - np.mean(noisy_flat)),
            'std_diff': np.abs(np.std(original_flat) - np.std(noisy_flat)),
            'max_diff': np.max(np.abs(original_arr - noisy_arr))
        }

    def analyze_encodings(self, encodings: List[torch.Tensor], client_ids: List[int],
                          task_id: int, round_id: int) -> Dict[str, Any]:
        """
        Analyze the effectiveness of DP on client encodings

        Each encoding is perturbed with Laplace noise and compared with its
        noisy version via ``perform_ks_test``. The result is appended to
        ``self.encoding_results``.

        Args:
            encodings: List of encoding tensors (or arrays) from clients
            client_ids: List of client IDs, paired positionally with ``encodings``
            task_id: Current task ID
            round_id: Current round ID

        Returns:
            Analysis results with per-client metrics and aggregates. The
            'total_clients' value counts the (encoding, client_id) pairs that
            were actually analyzed.
        """
        logger.info(f"Analyzing encodings for Task {task_id}, Round {round_id}")

        if len(encodings) != len(client_ids):
            # zip() below truncates to the shorter list; make that visible.
            logger.warning(
                f"Got {len(encodings)} encodings but {len(client_ids)} client IDs; "
                "only the paired prefix will be analyzed"
            )

        results = {
            'task_id': task_id,
            'round_id': round_id,
            'client_results': {},
            'aggregate_metrics': {}
        }

        all_ks_stats = []
        all_p_values = []
        all_tvds = []

        for encoding, client_id in zip(encodings, client_ids):
            encoding_np = self._to_numpy(encoding)
            noisy_encoding = self.add_laplace_noise(encoding_np)
            test_results = self.perform_ks_test(encoding_np, noisy_encoding)

            results['client_results'][client_id] = test_results
            all_ks_stats.append(test_results['ks_statistic'])
            all_p_values.append(test_results['p_value'])
            all_tvds.append(test_results['tvd'])

        # Aggregate metrics
        results['aggregate_metrics'] = {
            'mean_ks_statistic': np.mean(all_ks_stats),
            'std_ks_statistic': np.std(all_ks_stats),
            'mean_p_value': np.mean(all_p_values),
            'mean_tvd': np.mean(all_tvds),
            'significant_differences': int(sum(p < 0.05 for p in all_p_values)),
            'total_clients': len(all_p_values)
        }

        self.encoding_results.append(results)
        return results

    def analyze_relational_graph(self, original_graph: np.ndarray, task_id: int) -> Dict[str, Any]:
        """
        Analyze the effectiveness of DP on relational graphs

        The graph is perturbed with ``add_laplace_noise_to_graph`` and compared
        as a whole and row by row. The result is appended to ``self.graph_results``.

        Args:
            original_graph: Original relational graph (2-D array or tensor)
            task_id: Current task ID

        Returns:
            Analysis results

        Raises:
            ValueError: If ``original_graph`` is not 2-D.
        """
        original_graph = self._to_numpy(original_graph)
        if original_graph.ndim != 2:
            raise ValueError(
                f"original_graph must be a 2-D matrix, got shape {original_graph.shape}"
            )

        logger.info(f"Analyzing relational graph for Task {task_id}")

        # Add noise to graph while maintaining properties
        noisy_graph = self.add_laplace_noise_to_graph(original_graph)

        results = {
            'task_id': task_id,
            'graph_metrics': {},
            'row_analysis': {},
            'privacy_metrics': {}
        }

        # Overall graph analysis
        results['graph_metrics'] = self.perform_ks_test(original_graph, noisy_graph)

        # Row-wise analysis (one row per client)
        row_ks_stats = []
        row_p_values = []

        for i, (original_row, noisy_row) in enumerate(zip(original_graph, noisy_graph)):
            row_test = self.perform_ks_test(original_row, noisy_row)
            results['row_analysis'][f'client_{i}'] = row_test
            row_ks_stats.append(row_test['ks_statistic'])
            row_p_values.append(row_test['p_value'])

        difference = original_graph - noisy_graph

        # Privacy-specific metrics
        results['privacy_metrics'] = {
            'mean_row_ks_statistic': np.mean(row_ks_stats),
            'std_row_ks_statistic': np.std(row_ks_stats),
            'mean_row_p_value': np.mean(row_p_values),
            'significant_row_differences': int(sum(p < 0.05 for p in row_p_values)),
            'frobenius_norm_diff': np.linalg.norm(difference, 'fro'),
            'spectral_norm_diff': np.linalg.norm(difference, 2),
            'rank_original': np.linalg.matrix_rank(original_graph),
            'rank_noisy': np.linalg.matrix_rank(noisy_graph),
            'connectivity_preserved': self._check_connectivity_preservation(original_graph, noisy_graph)
        }

        self.graph_results.append(results)
        return results

    def add_laplace_noise_to_graph(self, graph: np.ndarray) -> np.ndarray:
        """
        Add Laplace noise to relational graph while maintaining graph properties

        After adding noise, negative entries are clipped to 0 and each row is
        rescaled to sum to 1. A row that is all zeros after clipping stays all
        zeros.

        Args:
            graph: Original graph (2-D)

        Returns:
            Noisy graph
        """
        noisy_graph = self.add_laplace_noise(graph)

        # Ensure non-negative values
        noisy_graph = np.maximum(noisy_graph, 0)

        # Normalize rows to maintain attention property
        row_sums = noisy_graph.sum(axis=1, keepdims=True)
        row_sums = np.maximum(row_sums, 1e-8)  # Avoid division by zero
        return noisy_graph / row_sums

    def _check_connectivity_preservation(self, original: np.ndarray, noisy: np.ndarray,
                                         threshold: float = 0.1) -> float:
        """
        Check how well connectivity is preserved after adding noise

        Args:
            original: Original graph
            noisy: Noisy graph
            threshold: An edge exists where the weight is strictly greater than this

        Returns:
            Proportion of original edges that are still edges in the noisy
            graph (1.0 if the original graph has no edges)
        """
        original_edges = original > threshold
        noisy_edges = noisy > threshold

        preserved = np.sum(original_edges & noisy_edges)
        total_original = np.sum(original_edges)

        if total_original == 0:
            return 1.0

        return preserved / total_original

    def visualize_encoding_analysis(self, save_path: Optional[str] = None):
        """
        Create visualizations for encoding analysis results

        Args:
            save_path: Output image path (defaults to
                ``<output_dir>/encoding_privacy_analysis.png``)
        """
        if not self.encoding_results:
            logger.warning("No encoding results to visualize")
            return

        aggregates = [result['aggregate_metrics'] for result in self.encoding_results]
        mean_ks_stats = [m['mean_ks_statistic'] for m in aggregates]
        mean_p_values = [m['mean_p_value'] for m in aggregates]
        mean_tvds = [m['mean_tvd'] for m in aggregates]
        significant_ratios = [self._significant_ratio(m) for m in aggregates]
        progress = range(len(aggregates))

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        try:
            # Plot 1: KS Statistics over time
            ax = axes[0, 0]
            ax.plot(progress, mean_ks_stats, 'o-', linewidth=2, markersize=8)
            ax.set_xlabel('Training Progress', fontsize=12)
            ax.set_ylabel('Mean KS Statistic', fontsize=12)
            ax.set_title('KS Statistic Evolution\n(Higher = More Privacy)', fontsize=14)
            ax.grid(True, alpha=0.3)

            # Plot 2: P-values over time
            ax = axes[0, 1]
            ax.plot(progress, mean_p_values, 'o-', linewidth=2, markersize=8)
            ax.axhline(y=0.05, color='r', linestyle='--', label='α=0.05')
            ax.set_xlabel('Training Progress', fontsize=12)
            ax.set_ylabel('Mean P-value', fontsize=12)
            ax.set_title('P-value Evolution\n(Lower = Distributions More Different)', fontsize=14)
            ax.grid(True, alpha=0.3)
            ax.legend()

            # Plot 3: Total Variation Distance
            ax = axes[1, 0]
            ax.plot(progress, mean_tvds, 'o-', linewidth=2, markersize=8)
            ax.set_xlabel('Training Progress', fontsize=12)
            ax.set_ylabel('Mean TVD', fontsize=12)
            ax.set_title('Total Variation Distance\n(Higher = More Privacy)', fontsize=14)
            ax.grid(True, alpha=0.3)

            # Plot 4: Proportion of significant differences
            ax = axes[1, 1]
            ax.plot(progress, significant_ratios, 'o-', linewidth=2, markersize=8)
            ax.set_xlabel('Training Progress', fontsize=12)
            ax.set_ylabel('Proportion of Clients', fontsize=12)
            ax.set_title('Clients with Significant Distribution Changes\n(Higher = Better Privacy)', fontsize=14)
            ax.set_ylim(0, 1.1)
            ax.grid(True, alpha=0.3)

            fig.suptitle(f'Encoding Privacy Analysis (ε={self.epsilon})', fontsize=16)
            fig.tight_layout()

            if save_path is None:
                save_path = os.path.join(self.output_dir, 'encoding_privacy_analysis.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        logger.info(f"Saved encoding analysis visualization to {save_path}")

    def visualize_graph_analysis(self, save_path: Optional[str] = None):
        """
        Create visualizations for graph analysis results

        Args:
            save_path: Output image path (defaults to
                ``<output_dir>/graph_privacy_analysis.png``)
        """
        if not self.graph_results:
            logger.warning("No graph results to visualize")
            return

        tasks = [r['task_id'] for r in self.graph_results]
        ks_stats = [r['graph_metrics']['ks_statistic'] for r in self.graph_results]
        p_values = [r['graph_metrics']['p_value'] for r in self.graph_results]
        privacy = [r['privacy_metrics'] for r in self.graph_results]
        frobenius_diffs = [m['frobenius_norm_diff'] for m in privacy]
        spectral_diffs = [m['spectral_norm_diff'] for m in privacy]
        connectivity_preserved = [m['connectivity_preserved'] for m in privacy]
        rank_changes = [abs(m['rank_original'] - m['rank_noisy']) for m in privacy]

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        try:
            # Plot 1: KS Statistics
            ax = axes[0, 0]
            ax.bar(tasks, ks_stats, color='skyblue', edgecolor='black')
            ax.set_xlabel('Task', fontsize=12)
            ax.set_ylabel('KS Statistic', fontsize=12)
            ax.set_title('Graph KS Statistics by Task', fontsize=14)
            ax.grid(True, alpha=0.3, axis='y')

            # Plot 2: P-values
            ax = axes[0, 1]
            ax.bar(tasks, p_values, color='lightcoral', edgecolor='black')
            ax.axhline(y=0.05, color='r', linestyle='--', label='α=0.05')
            ax.set_xlabel('Task', fontsize=12)
            ax.set_ylabel('P-value', fontsize=12)
            ax.set_title('Graph P-values by Task', fontsize=14)
            ax.grid(True, alpha=0.3, axis='y')
            ax.legend()

            # Plot 3: Frobenius Norm Differences
            ax = axes[0, 2]
            ax.bar(tasks, frobenius_diffs, color='lightgreen', edgecolor='black')
            ax.set_xlabel('Task', fontsize=12)
            ax.set_ylabel('Frobenius Norm Diff', fontsize=12)
            ax.set_title('Matrix Norm Differences', fontsize=14)
            ax.grid(True, alpha=0.3, axis='y')

            # Plot 4: Spectral Norm Differences
            ax = axes[1, 0]
            ax.bar(tasks, spectral_diffs, color='lightyellow', edgecolor='black')
            ax.set_xlabel('Task', fontsize=12)
            ax.set_ylabel('Spectral Norm Diff', fontsize=12)
            ax.set_title('Spectral Norm Differences', fontsize=14)
            ax.grid(True, alpha=0.3, axis='y')

            # Plot 5: Connectivity Preservation
            ax = axes[1, 1]
            ax.bar(tasks, connectivity_preserved, color='lightblue', edgecolor='black')
            ax.set_xlabel('Task', fontsize=12)
            ax.set_ylabel('Proportion Preserved', fontsize=12)
            ax.set_title('Edge Connectivity Preservation', fontsize=14)
            ax.set_ylim(0, 1.1)
            ax.grid(True, alpha=0.3, axis='y')

            # Plot 6: Rank Changes
            ax = axes[1, 2]
            ax.bar(tasks, rank_changes, color='lavender', edgecolor='black')
            ax.set_xlabel('Task', fontsize=12)
            ax.set_ylabel('Rank Change', fontsize=12)
            ax.set_title('Matrix Rank Changes', fontsize=14)
            ax.grid(True, alpha=0.3, axis='y')

            fig.suptitle(f'Relational Graph Privacy Analysis (ε={self.epsilon})', fontsize=16)
            fig.tight_layout()

            if save_path is None:
                save_path = os.path.join(self.output_dir, 'graph_privacy_analysis.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

        logger.info(f"Saved graph analysis visualization to {save_path}")

    def generate_privacy_utility_tradeoff(self, epsilons: List[float],
                                          encoding_samples: List[np.ndarray],
                                          graph_sample: np.ndarray,
                                          save_path: Optional[str] = None):
        """
        Generate privacy-utility tradeoff curves for different epsilon values

        ``self.epsilon`` / ``self.scale`` are temporarily overwritten for each
        value in ``epsilons`` and are always restored afterwards, even if the
        sweep raises.

        Args:
            epsilons: List of epsilon values to test
            encoding_samples: Sample encodings to test
            graph_sample: Sample graph to test
            save_path: Path to save the visualization (defaults to
                ``<output_dir>/privacy_utility_tradeoff.png``)
        """
        encoding_ks_stats = []
        encoding_utilities = []
        graph_ks_stats = []
        graph_utilities = []

        original_epsilon = self.epsilon
        original_scale = self.scale

        try:
            for eps in epsilons:
                self.epsilon = eps
                self.scale = self.sensitivity / eps

                # Test encodings
                ks_stats = []
                utilities = []

                for encoding in encoding_samples:
                    encoding = self._to_numpy(encoding)
                    noisy = self.add_laplace_noise(encoding)
                    ks_stats.append(self.perform_ks_test(encoding, noisy)['ks_statistic'])

                    # Utility measured as 1 / (1 + MSE): 1 means no distortion
                    mse = np.mean((encoding - noisy) ** 2)
                    utilities.append(1 / (1 + mse))

                encoding_ks_stats.append(np.mean(ks_stats))
                encoding_utilities.append(np.mean(utilities))

                # Test graph
                noisy_graph = self.add_laplace_noise_to_graph(graph_sample)
                graph_ks_stats.append(self.perform_ks_test(graph_sample, noisy_graph)['ks_statistic'])

                # Graph utility measured as connectivity preservation
                graph_utilities.append(self._check_connectivity_preservation(graph_sample, noisy_graph))
        finally:
            # Restore the analyzer's configured privacy budget
            self.epsilon = original_epsilon
            self.scale = original_scale

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
        try:
            # Plot 1: Encodings
            ax1_twin = ax1.twinx()

            line1 = ax1.plot(epsilons, encoding_ks_stats, 'b-o', linewidth=2,
                             markersize=8, label='Privacy (KS Statistic)')
            line2 = ax1_twin.plot(epsilons, encoding_utilities, 'r-s', linewidth=2,
                                  markersize=8, label='Utility')

            ax1.set_xlabel('Privacy Budget (ε)', fontsize=12)
            ax1.set_ylabel('KS Statistic', fontsize=12, color='b')
            ax1_twin.set_ylabel('Utility Score', fontsize=12, color='r')
            ax1.tick_params(axis='y', labelcolor='b')
            ax1_twin.tick_params(axis='y', labelcolor='r')
            ax1.set_title('Encoding Privacy-Utility Tradeoff', fontsize=14)
            ax1.grid(True, alpha=0.3)

            # One legend covering both y-axes
            lines1 = line1 + line2
            ax1.legend(lines1, [l.get_label() for l in lines1], loc='best')

            # Plot 2: Graphs
            ax2_twin = ax2.twinx()

            line3 = ax2.plot(epsilons, graph_ks_stats, 'b-o', linewidth=2,
                             markersize=8, label='Privacy (KS Statistic)')
            line4 = ax2_twin.plot(epsilons, graph_utilities, 'r-s', linewidth=2,
                                  markersize=8, label='Utility (Connectivity)')

            ax2.set_xlabel('Privacy Budget (ε)', fontsize=12)
            ax2.set_ylabel('KS Statistic', fontsize=12, color='b')
            ax2_twin.set_ylabel('Connectivity Preserved', fontsize=12, color='r')
            ax2.tick_params(axis='y', labelcolor='b')
            ax2_twin.tick_params(axis='y', labelcolor='r')
            ax2.set_title('Graph Privacy-Utility Tradeoff', fontsize=14)
            ax2.grid(True, alpha=0.3)

            lines2 = line3 + line4
            ax2.legend(lines2, [l.get_label() for l in lines2], loc='best')

            fig.suptitle('Privacy-Utility Tradeoff Analysis', fontsize=16)
            fig.tight_layout()

            if save_path is None:
                save_path = os.path.join(self.output_dir, 'privacy_utility_tradeoff.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)

        logger.info(f"Saved privacy-utility tradeoff visualization to {save_path}")

    def save_results(self):
        """
        Save all analysis results to disk

        Writes ``dp_analysis_results.pkl`` (full results) and, when available,
        ``encoding_analysis_summary.csv`` / ``graph_analysis_summary.csv`` into
        ``self.output_dir``.
        """
        results = {
            'epsilon': self.epsilon,
            'sensitivity': self.sensitivity,
            'scale': self.scale,
            'encoding_results': self.encoding_results,
            'graph_results': self.graph_results
        }

        # Save as pickle (locally produced data only; never load untrusted pickles)
        pickle_path = os.path.join(self.output_dir, 'dp_analysis_results.pkl')
        with open(pickle_path, 'wb') as f:
            pickle.dump(results, f)

        # Save summary as CSV
        if self.encoding_results:
            encoding_df = pd.DataFrame([
                {
                    'task_id': r['task_id'],
                    'round_id': r['round_id'],
                    'mean_ks_statistic': r['aggregate_metrics']['mean_ks_statistic'],
                    'mean_p_value': r['aggregate_metrics']['mean_p_value'],
                    'mean_tvd': r['aggregate_metrics']['mean_tvd'],
                    'significant_ratio': self._significant_ratio(r['aggregate_metrics'])
                }
                for r in self.encoding_results
            ])
            encoding_df.to_csv(os.path.join(self.output_dir, 'encoding_analysis_summary.csv'),
                               index=False)

        if self.graph_results:
            graph_df = pd.DataFrame([
                {
                    'task_id': r['task_id'],
                    'ks_statistic': r['graph_metrics']['ks_statistic'],
                    'p_value': r['graph_metrics']['p_value'],
                    'frobenius_norm_diff': r['privacy_metrics']['frobenius_norm_diff'],
                    'connectivity_preserved': r['privacy_metrics']['connectivity_preserved']
                }
                for r in self.graph_results
            ])
            graph_df.to_csv(os.path.join(self.output_dir, 'graph_analysis_summary.csv'),
                            index=False)

        logger.info(f"Saved analysis results to {self.output_dir}")


def run_dp_analysis_example():
    """
    Demonstrate DifferentialPrivacyAnalyzer end to end on synthetic data.

    Analyzes random client encodings over several rounds, then random
    row-normalized graphs for several tasks. Plots both result sets, sweeps
    epsilon for the privacy-utility tradeoff, and saves everything under the
    analyzer's default output directory.
    """
    analyzer = DifferentialPrivacyAnalyzer(epsilon=1.0, sensitivity=1.0)

    num_clients, encoding_dim, num_rounds = 10, 512, 5
    client_ids = list(range(num_clients))

    def _random_attention_graph():
        # Uniform random weights rescaled so each row sums to 1, like attention scores.
        weights = np.random.rand(num_clients, num_clients)
        return weights / weights.sum(axis=1, keepdims=True)

    # Encodings: one batch of fake (32 x encoding_dim) features per client, per round.
    for round_id in range(num_rounds):
        round_encodings = [np.random.randn(32, encoding_dim) for _ in range(num_clients)]
        analyzer.analyze_encodings(round_encodings, client_ids, task_id=0, round_id=round_id)

    # Relational graphs: one fake graph per task.
    for task_id in range(4):
        analyzer.analyze_relational_graph(_random_attention_graph(), task_id)

    analyzer.visualize_encoding_analysis()
    analyzer.visualize_graph_analysis()

    # Epsilon sweep on a fresh set of probe samples.
    tradeoff_epsilons = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]
    probe_encodings = [np.random.randn(32, encoding_dim) for _ in range(5)]
    probe_graph = _random_attention_graph()
    analyzer.generate_privacy_utility_tradeoff(tradeoff_epsilons, probe_encodings, probe_graph)

    analyzer.save_results()

    logger.info("DP analysis example completed!")


if __name__ == "__main__":
    # Script entry point: run the synthetic demo (outputs go to ./dump/dp_analysis).
    run_dp_analysis_example()