import math
import random
import secrets
from typing import Union, List

import numpy as np
import torch

class DifferentialPrivacy:
    """Differential-privacy helpers: Laplace-noise injection and gradient clipping.

    Args:
        epsilon: Privacy budget used by ``add_noise``; must be positive there.
        sensitivity: Query sensitivity; the Laplace scale is
            ``sensitivity / epsilon``.
    """

    def __init__(self, epsilon: float, sensitivity: float):
        self.epsilon = epsilon
        self.sensitivity = sensitivity

    def add_noise(self, data: Union[torch.Tensor, np.ndarray]) -> Union[torch.Tensor, np.ndarray]:
        """Return ``data`` plus i.i.d. Laplace(0, sensitivity / epsilon) noise.

        The input is not modified. Floating-point inputs keep their dtype
        (and, for tensors, their device). Non-floating inputs receive float64
        noise and are promoted by ``+`` as usual.

        Raises:
            ValueError: if ``epsilon`` is not positive (numpy additionally
                raises ``ValueError`` for a negative scale).
        """
        if self.epsilon <= 0:
            raise ValueError(f"epsilon must be positive, got {self.epsilon}")
        scale = self.sensitivity / self.epsilon

        if isinstance(data, torch.Tensor):
            noise = torch.as_tensor(np.random.laplace(0.0, scale, tuple(data.shape)))
            # numpy produces float64; adding that to a float32/float16 tensor
            # would silently upcast the result, so match the input's dtype.
            target_dtype = data.dtype if data.is_floating_point() else noise.dtype
            return data + noise.to(device=data.device, dtype=target_dtype)

        noise = np.random.laplace(0.0, scale, np.shape(data))
        if isinstance(data, np.ndarray) and np.issubdtype(data.dtype, np.floating):
            # Same reasoning as the tensor branch: keep float32 arrays float32.
            noise = noise.astype(data.dtype, copy=False)
        return data + noise

    def clip_gradients(self, gradients: List[torch.Tensor], clip_norm: float) -> List[torch.Tensor]:
        """Scale ``gradients`` in place so their combined L2 norm is at most ``clip_norm``.

        The global norm is the L2 norm over all tensors together. When it
        exceeds ``clip_norm``, every tensor is multiplied by the same factor
        ``clip_norm / (norm + 1e-6)``. The 1e-6 term avoids division by zero.

        Returns:
            The same list object that was passed in (tensors modified in place).
        """
        total_norm = sum(grad.data.norm(2).item() ** 2 for grad in gradients) ** 0.5

        clip_coef = clip_norm / (total_norm + 1e-6)
        if clip_coef < 1:
            for grad in gradients:
                grad.data.mul_(clip_coef)
        return gradients

class HomomorphicEncryption:
    """Additively homomorphic encryption based on the Paillier cryptosystem.

    NOTE: simplified, educational implementation (no constant-time
    arithmetic, no key serialization). Production code should use a
    dedicated, audited homomorphic-encryption library.

    ``key_size`` is the bit length of each prime factor, so the modulus ``n``
    has about ``2 * key_size`` bits. Plaintexts live in ``Z_n`` and
    ciphertexts in ``Z_{n^2}``.

    Attributes:
        public_key: ``(g, n)`` with ``g = n + 1``.
        private_key: ``(lam, n)`` with ``lam = (p - 1) * (q - 1)``.
    """

    # Bases for the Miller-Rabin test below (the first twelve primes).
    _MR_BASES = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)

    def __init__(self, key_size: int = 1024):
        # With fewer than 3 bits, every candidate from _generate_prime is the
        # same number (or invalid), so key generation could never finish.
        if key_size < 3:
            raise ValueError(f"key_size must be at least 3 bits, got {key_size}")
        self.key_size = key_size
        self.public_key, self.private_key = self._generate_keys()

    def _generate_keys(self):
        """Generate a Paillier key pair.

        Returns:
            ``((n + 1, n), (lam, n))`` where ``n = p * q`` for two distinct
            ``key_size``-bit primes and ``lam = (p - 1) * (q - 1)``.
        """
        while True:
            p = self._generate_prime(self.key_size)
            q = self._generate_prime(self.key_size)
            n = p * q
            lam = (p - 1) * (q - 1)
            # If p == q or gcd(n, lam) != 1, lam has no inverse mod n and
            # decryption would be impossible, so draw new primes.
            if p != q and math.gcd(n, lam) == 1:
                return (n + 1, n), (lam, n)

    def _generate_prime(self, bits: int) -> int:
        """Return a random prime with exactly ``bits`` bits, drawn from a CSPRNG."""
        while True:
            # Set the top bit (exact bit length) and the low bit (odd).
            num = secrets.randbits(bits) | (1 << (bits - 1)) | 1
            if self._is_prime(num):
                return num

    def _is_prime(self, n: int) -> bool:
        """Miller-Rabin primality test with a fixed set of small prime bases.

        With these bases the test is exact for n below roughly 3.3e24. For
        larger n it is probabilistic, which is adequate for the randomly
        drawn candidates used here.
        """
        if n <= 1:
            return False
        # Trial division by the bases also handles the bases themselves.
        for p in self._MR_BASES:
            if n % p == 0:
                return n == p
        # Write n - 1 = d * 2**s with d odd.
        d = n - 1
        s = 0
        while d % 2 == 0:
            d //= 2
            s += 1
        for a in self._MR_BASES:
            if a >= n:
                continue
            x = pow(a, d, n)
            if x == 1 or x == n - 1:
                continue
            for _ in range(s - 1):
                x = pow(x, 2, n)
                if x == n - 1:
                    break
            else:
                # Never reached -1: a is a witness that n is composite.
                return False
        return True

    def _modinv(self, a: int, m: int) -> int:
        """Return the inverse of ``a`` modulo ``m``.

        Raises:
            ValueError: if ``gcd(a, m) != 1`` (no inverse exists).
        """
        g, x, _ = self._extended_gcd(a, m)
        if g != 1:
            raise ValueError('模反元素不存在')
        return x % m

    def _extended_gcd(self, a: int, b: int):
        """Iterative extended Euclid: return ``(g, x, y)`` with ``a*x + b*y == g``."""
        old_r, r = a, b
        old_x, x = 1, 0
        old_y, y = 0, 1
        while r:
            quotient = old_r // r
            old_r, r = r, old_r - quotient * r
            old_x, x = x, old_x - quotient * x
            old_y, y = y, old_y - quotient * y
        return old_r, old_x, old_y

    def encrypt(self, plaintext: int) -> int:
        """Encrypt ``plaintext`` (reduced mod n) into a ciphertext in ``Z_{n^2}``.

        Encryption is randomized: encrypting the same value twice gives
        different ciphertexts that decrypt to the same plaintext.
        """
        g, n = self.public_key
        n_sq = n * n
        m = plaintext % n
        while True:
            r = secrets.randbelow(n - 1) + 1
            if math.gcd(r, n) == 1:
                break
        return (pow(g, m, n_sq) * pow(r, n, n_sq)) % n_sq

    def decrypt(self, ciphertext: int) -> int:
        """Decrypt ``ciphertext`` and return the plaintext in ``[0, n)``.

        Uses the ``g = n + 1`` simplification, under which
        ``mu = lam^-1 mod n``.
        """
        lam, n = self.private_key
        n_sq = n * n
        mu = self._modinv(lam, n)
        x = pow(ciphertext, lam, n_sq)
        # Paillier's L function: L(x) = (x - 1) / n.
        return ((x - 1) // n) * mu % n

    def add_encrypted(self, a: int, b: int) -> int:
        """Given ciphertexts of ``m_a`` and ``m_b``, return a ciphertext of ``(m_a + m_b) mod n``.

        In Paillier, multiplying ciphertexts modulo n^2 adds the plaintexts.
        """
        return (a * b) % (self.public_key[1] ** 2)