import random

def is_prime(n):
    """Check if a number is prime."""
    if n <= 1:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True

class EllipticCurve:
    def __init__(self, a : int, b : int, p : int):
        self.a = a 
        self.b = b
        assert is_prime(p), "p is not prime"
        self.p = p # prime

    def is_on_curve(self, x : int, y : int) -> bool:
        return (y**2 - x**3 - self.a*x - self.b) % self.p == 0

    def mod_inverse(self, value: int) -> int:
        """ uses fermat's little theorem for modular inverse """
        return pow(value, self.p - 2, self.p)

    def point_addition(self, P: tuple, Q: tuple) -> tuple:
        """ returns the sum of the two points, P, Q 
            uses (None, None) to represent infinity """
        # cases where either point are infinity
        if P == (None, None) : return Q
        if Q == (None, None) : return P
        
        # P + (-P) = 0 or if the y coordinate is 0, return point at infinity
        if P[0] == Q[0] and (P[1] != Q[1] or P[1] == 0) : return (None, None)

        if P != Q:
            # the lambda (slope) calculation for two distinct points
            m = (Q[1] - P[1]) * self.mod_inverse(Q[0] - P[0] + self.p) % self.p
        else:
            # the lambda (slope) calculation for point doubling
            m = (3 * P[0]**2 + self.a) * self.mod_inverse(2 * P[1]) % self.p

        x_r = (m**2 - P[0] - Q[0]) % self.p
        y_r = (m * (P[0] - x_r) - P[1]) % self.p

        return (x_r, y_r)

    def point_double(self, P: tuple) -> tuple:
        """ double the given point """
        return self.point_addition(P, P)

    def point_multiplication(self, k: int, P: tuple) -> tuple:
        """scalar multiplication of P by k."""
        if P == (None, None) or k == 0:
            return (None, None)

        result = (None, None) # initialize result as inf,inf
        addend = P

        while k:
            if k & 1:
                result = self.point_addition(result, addend)
            addend = self.point_addition(addend, addend)
            k >>= 1

        return result

    def windowed_point_multiplication(self, k: int, P: tuple, w: int = 4) -> tuple:
        if P == (None, None) or k == 0 : return (None, None)

        # precompute the multiples of P: P, 2P, 3P, ..., (2^w-1)P
        precomputed, current = [(None, None)], P
        for _ in range(1, 2**w):
            precomputed.append(current)
            current = self.point_addition(current, P)

        Q = (None, None)
        k_bin = bin(k)[2:]  # convert k to binary string

        # crocess each bit from left to right (MSB to LSB)
        for bit in k_bin:
            Q = self.point_double(Q)  # always double Q for each bit shift
            if bit == '1':
                Q = self.point_addition(Q, P)

        return Q

    def generate_keypair(self, G: tuple, n: int, d : int) -> tuple:
        """ Given an initial point G and an order n, construct a keypair, and d, the private key """
        assert 1 <= d and d <= n-1
        Q = self.windowed_point_multiplication(d, G) # public key
        return (d, Q)

    def validate_keypair(self, d: int, Q: tuple, G: tuple, n: int) -> bool:
        """ Validate the given keypair, given an initial point G,
            a public key Q, a private key d, and a group order n """
        if not (1 <= d < n) : return False
        if not self.is_on_curve(Q[0], Q[1]) : return False
        return self.windowed_point_multiplication(d, G) == Q