import random

def is_prime(n):
    """Check if a number is prime."""
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

class EllipticCurve:
    def __init__(self, a : int, b : int, p : int):
        self.a = a 
        self.b = b
        assert is_prime(p), "p is not prime"
        self.p = p # prime

    def is_on_curve(self, x : int, y : int) -> bool:
        return (y**2 - x**3 - self.a*x - self.b) % self.p == 0

    def mod_inverse(self, value: int) -> int:
        """ uses fermat's little theorem for modular inverse """
        return pow(value, self.p - 2, self.p)

    def point_addition(self, P: tuple, Q: tuple) -> tuple:
        """ returns the sum of the two points, P, Q 
            uses (None, None) to represent infinity """
        # cases where either point are infinity
        if P == (None, None) : return Q
        if Q == (None, None) : return P

        # P + (-P) = 0 or if the y coordinate is 0, return point at infinity
        if P[0] == Q[0] and (P[1] != Q[1] or P[1] == 0) : return (None, None)

        if P != Q:
            # The lambda (slope) calculation for two distinct points
            m = (Q[1] - P[1]) * self.mod_inverse(Q[0] - P[0] + self.p) % self.p
        else:
            # The lambda (slope) calculation for point doubling
            m = (3 * P[0]**2 + self.a) * self.mod_inverse(2 * P[1]) % self.p

        x_r = (m**2 - P[0] - Q[0]) % self.p
        y_r = (m * (P[0] - x_r) - P[1]) % self.p

        return (x_r, y_r)
    
    def point_double(self, P: tuple) -> tuple:
        """ double the given point """
        return self.point_addition(P, P)

    def point_multiplication(self, k: int, P: tuple) -> tuple:
        """scalar multiplication of P by k."""
        if P == (None, None) or k == 0:
            return (None, None)

        result = (None, None)  # Initialize result as the identity element (infinity point)
        addend = P

        while k:
            if k & 1:
                result = self.point_addition(result, addend)
            addend = self.point_addition(addend, addend)
            k >>= 1

        return result

    def generate_keypair(self, G: tuple, n: int, d : int) -> tuple:
        """ Given an initial point G and an order n, construct a keypair, and d, the private key """
        assert 1 <= d and d <= n-1
        Q = self.point_multiplication(d, G) # public key
        return (d, Q) 

    def validate_keypair(self, d: int, Q: tuple, G: tuple, n: int) -> bool:
        """ Validate the given keypair, given an initial point G,
            a public key Q, a private key d, and a group order n """
        if not (1 <= d < n) : return False
        if not self.is_on_curve(Q[0], Q[1]) : return False
        return self.point_multiplication(d, G) == Q