import numpy as np
import random

# Beaver Triple Generator over F_p at Server 2
def F_5_triple_gen(a1, a2, a3, b1, b2, b3, p=5):
    """Split c = a * b (mod p) into three additive shares.

    The three shares of ``a`` and the three shares of ``b`` are summed
    modulo ``p`` to recover ``a`` and ``b``; their product ``c`` is then
    re-shared as ``(c1, c2, c3)`` with ``(c1 + c2 + c3) % p == c``.

    Args:
        a1, a2, a3: Additive shares of ``a`` modulo ``p``.
        b1, b2, b3: Additive shares of ``b`` modulo ``p``.
        p: Modulus of the field the shares live in (default 5).

    Returns:
        Tuple ``(c1, c2, c3)`` of shares of ``a * b`` modulo ``p``, each in
        the range ``[0, p - 1]``.

    NOTE: the ``random`` module is not a cryptographically secure generator;
    it is adequate for a simulation only.
    """
    a = (a1 + a2 + a3) % p
    b = (b1 + b2 + b3) % p
    c = (a * b) % p
    # The first two shares must be drawn from the WHOLE field.  randint is
    # inclusive on both ends, so the upper bound is p - 1 (a fixed bound of 3
    # would never produce the value 4 and would ignore p altogether).
    c1 = random.randint(0, p - 1)
    c2 = random.randint(0, p - 1)
    # The last share is fixed by the other two so that the sum equals c.
    c3 = (c - c1 - c2) % p
    return c1, c2, c3

# Protocol parameters
n = 3              # number of participating parties
p = 5              # modulus of the prime field (smallest prime larger than n)
num_beaver = 1000  # number of Beaver triples to precompute

# Tabulate the majority-vote polynomial F(x) = 2x^3 + 4x (mod p) at the odd
# inputs -3, -1, 1, 3 (the values a sum of three +/-1 votes can take).
print("*********************** TESTING POLYNOMIAL CHECKING **************************")
for ii in (-3, -1, 1, 3):
    # Factored form of 2x^3 + 4x: x * (2x^2 + 4).
    f_val = (ii * (2 * ii * ii + 4)) % p
    print(f"INPUT : {ii}, function val is {f_val}")
print("*********************** TESTING POLYNOMIAL CHECKING END **********************")

# Every party draws its private vote at random from {-1, 1}.
m_vec = np.random.choice([-1, 1], size=n)
print("Each party messages:")
print(m_vec)

# Plaintext reference: evaluate 2x^3 + 4x (mod p) directly on the sum of all
# votes, without any secret sharing, to compare against the protocol output.
m_sum = m_vec.sum()
m_res = (m_sum * (2 * m_sum**2 + 4)) % p
print(f"Exact majority vote result: {m_res}")

# Precompute Beaver triples.
# A[j, i] / B[j, i] hold party j's random share (drawn from [0, p)) of the
# a / b value of triple i; C[j, i] receives party j's share of a * b mod p.
A = np.random.randint(0, p, (n, num_beaver))
B = np.random.randint(0, p, (n, num_beaver))
C = np.zeros((n, num_beaver), dtype=int)
# One row per party: RAM is the scratch storage and COMM the outgoing
# message buffer used by the protocol rounds further down.
RAM = np.zeros((n, 1000), dtype=int)
COMM = np.zeros((n, 4), dtype=int)

for i in range(num_beaver):
    # The generator defined in this file is F_5_triple_gen; the previous call
    # to z5_triple_gen referenced an undefined name and raised NameError.
    c1, c2, c3 = F_5_triple_gen(A[0, i], A[1, i], A[2, i], B[0, i], B[1, i], B[2, i], p)
    C[:, i] = [c1, c2, c3]

# Check Beaver triple correctness for EVERY precomputed triple (not just a
# fixed prefix): summing the parties' shares column-wise reconstructs a, b
# and c, and each triple must satisfy c == a * b (mod p).
a_open = A.sum(axis=0) % p
b_open = B.sum(axis=0) % p
c_open = C.sum(axis=0) % p
if np.any(c_open != (a_open * b_open) % p):
    raise ValueError("Beaver triple error!")

print("!!!!!!!!!!!!!!!!!! PROTOCOL START !!!!!!!!!!!!!!!!!!")

# Round 0: every party masks its vote with its own shares of triple 0 and
# publishes the two results (its share of x - a and its share of x - b).
COMM[:, 0] = (m_vec - A[:, 0]) % p
COMM[:, 1] = (m_vec - B[:, 0]) % p

# Delivery: each party's RAM row receives all published values; the x - a
# shares land in columns 1..3 and the x - b shares in columns 5..7.
# The 1-D column of COMM is broadcast to every row of the RAM slice.
RAM[:, 1:4] = COMM[:, 0]
RAM[:, 5:8] = COMM[:, 1]

print("ROUND 0 IS END!!")

# Round 1: Beaver multiplication of x with itself using triple 0, leaving
# each party with an additive share of x^2 in RAM[party, 200].
for party in range(n):
    # Rebuild the opened values d = x - a and e = x - b from the shares
    # received in round 0 (columns 4 and 8 are never written and stay 0).
    d = RAM[party, 1:5].sum() % p
    e = RAM[party, 5:9].sum() % p

    share = A[party, 0] * e + B[party, 0] * d + C[party, 0]
    if party == 0:
        # The public product d * e is contributed by party 0 only.
        share += d * e
    RAM[party, 200] = share % p

    # Outgoing messages.  Columns 0 and 1 mask the x^2 share with triple 1;
    # columns 2 and 3 mask the vote and the x^2 share with triple 2, which is
    # what round 2 consumes.
    x_sq = RAM[party, 200]
    COMM[party, 0] = (x_sq - A[party, 1]) % p
    COMM[party, 1] = (x_sq - B[party, 1]) % p
    COMM[party, 2] = (m_vec[party] - A[party, 2]) % p
    COMM[party, 3] = (x_sq - B[party, 2]) % p

# Delivery: message column k of all parties is copied into RAM columns
# 4k+1 .. 4k+3 of every party's row (1:4, 5:8, 9:12, 13:16).
for col in range(4):
    RAM[:, 4 * col + 1:4 * col + 4] = COMM[:, col]

print("ROUND 1 IS END!!")

# Round 2: Beaver multiplication of x with x^2 using triple 2 yields a share
# of x^3; each party then forms its share of f(x) = 2x^3 + 4x locally.
for party in range(n):
    # Opened masked values from round 1 (columns 12 and 16 stay 0).
    d = RAM[party, 9:13].sum() % p
    e = RAM[party, 13:17].sum() % p

    cube = A[party, 2] * e + B[party, 2] * d + C[party, 2]
    if party == 0:
        # Only party 0 adds the public d * e term.
        cube += d * e
    RAM[party, 202] = cube % p

    # Linear combination of the x^3 share and the party's own vote.
    RAM[party, 203] = (2 * RAM[party, 202] + 4 * m_vec[party]) % p
    COMM[party, 0] = RAM[party, 203]

print("ROUND 2 is END!")

# Final aggregation at Server 1: add the parties' published shares modulo p.
FINAL_RESULT = COMM[:, 0].sum() % p
print(f"Server 1's aggregation result: {FINAL_RESULT}")
